1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-27 17:52:56 +01:00

Add ability to pass a custom TlsConnector (#491)

This commit is contained in:
Markus Unterwaditzer 2018-08-29 20:53:31 +02:00 committed by Armin Ronacher
parent 5906971b6d
commit 4bab50c861

View File

@ -17,14 +17,16 @@ use tokio_io::{AsyncRead, AsyncWrite};
use tokio_timer::Delay; use tokio_timer::Delay;
#[cfg(feature = "alpn")] #[cfg(feature = "alpn")]
use openssl::ssl::{Error as OpensslError, SslConnector, SslMethod}; use {
#[cfg(feature = "alpn")] openssl::ssl::{Error as SslError, SslConnector, SslMethod},
use tokio_openssl::SslConnectorExt; tokio_openssl::SslConnectorExt
};
#[cfg(all(feature = "tls", not(feature = "alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
use native_tls::{Error as TlsError, TlsConnector as NativeTlsConnector}; use {
#[cfg(all(feature = "tls", not(feature = "alpn")))] native_tls::{Error as SslError, TlsConnector as NativeTlsConnector},
use tokio_tls::{TlsConnector}; tokio_tls::TlsConnector as SslConnector
};
#[cfg( #[cfg(
all( all(
@ -32,42 +34,25 @@ use tokio_tls::{TlsConnector};
not(any(feature = "alpn", feature = "tls")) not(any(feature = "alpn", feature = "tls"))
) )
)] )]
use rustls::ClientConfig; use {
rustls::ClientConfig,
std::io::Error as SslError,
std::sync::Arc,
tokio_rustls::ClientConfigExt,
webpki::DNSNameRef,
webpki_roots,
};
#[cfg( #[cfg(
all( all(
feature = "rust-tls", feature = "rust-tls",
not(any(feature = "alpn", feature = "tls")) not(any(feature = "alpn", feature = "tls"))
) )
)] )]
use std::io::Error as TLSError; type SslConnector = Arc<ClientConfig>;
#[cfg(
all( #[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))]
feature = "rust-tls", type SslConnector = ();
not(any(feature = "alpn", feature = "tls"))
)
)]
use std::sync::Arc;
#[cfg(
all(
feature = "rust-tls",
not(any(feature = "alpn", feature = "tls"))
)
)]
use tokio_rustls::ClientConfigExt;
#[cfg(
all(
feature = "rust-tls",
not(any(feature = "alpn", feature = "tls"))
)
)]
use webpki::DNSNameRef;
#[cfg(
all(
feature = "rust-tls",
not(any(feature = "alpn", feature = "tls"))
)
)]
use webpki_roots;
use server::IoStream; use server::IoStream;
use {HAS_OPENSSL, HAS_RUSTLS, HAS_TLS}; use {HAS_OPENSSL, HAS_RUSTLS, HAS_TLS};
@ -173,24 +158,9 @@ pub enum ClientConnectorError {
SslIsNotSupported, SslIsNotSupported,
/// SSL error /// SSL error
#[cfg(feature = "alpn")] #[cfg(any(feature = "tls", feature = "alpn", feature = "rust-tls"))]
#[fail(display = "{}", _0)] #[fail(display = "{}", _0)]
SslError(#[cause] OpensslError), SslError(#[cause] SslError),
/// SSL error
#[cfg(all(feature = "tls", not(feature = "alpn")))]
#[fail(display = "{}", _0)]
SslError(#[cause] TlsError),
/// SSL error
#[cfg(
all(
feature = "rust-tls",
not(any(feature = "alpn", feature = "tls"))
)
)]
#[fail(display = "{}", _0)]
SslError(#[cause] TLSError),
/// Resolver error /// Resolver error
#[fail(display = "{}", _0)] #[fail(display = "{}", _0)]
@ -242,17 +212,7 @@ impl Paused {
/// `ClientConnector` type is responsible for transport layer of a /// `ClientConnector` type is responsible for transport layer of a
/// client connection. /// client connection.
pub struct ClientConnector { pub struct ClientConnector {
#[cfg(all(feature = "alpn"))]
connector: SslConnector, connector: SslConnector,
#[cfg(all(feature = "tls", not(feature = "alpn")))]
connector: TlsConnector,
#[cfg(
all(
feature = "rust-tls",
not(any(feature = "alpn", feature = "tls"))
)
)]
connector: Arc<ClientConfig>,
stats: ClientConnectorStats, stats: ClientConnectorStats,
subscriber: Option<Recipient<ClientConnectorStats>>, subscriber: Option<Recipient<ClientConnectorStats>>,
@ -293,71 +253,32 @@ impl SystemService for ClientConnector {}
impl Default for ClientConnector { impl Default for ClientConnector {
fn default() -> ClientConnector { fn default() -> ClientConnector {
#[cfg(all(feature = "alpn"))] let connector = {
{ #[cfg(all(feature = "alpn"))]
let builder = SslConnector::builder(SslMethod::tls()).unwrap(); { SslConnector::builder(SslMethod::tls()).unwrap().build() }
ClientConnector::with_connector(builder.build())
}
#[cfg(all(feature = "tls", not(feature = "alpn")))]
{
let (tx, rx) = mpsc::unbounded();
let builder = NativeTlsConnector::builder();
ClientConnector {
stats: ClientConnectorStats::default(),
subscriber: None,
acq_tx: tx,
acq_rx: Some(rx),
resolver: None,
connector: builder.build().unwrap().into(),
conn_lifetime: Duration::from_secs(75),
conn_keep_alive: Duration::from_secs(15),
limit: 100,
limit_per_host: 0,
acquired: 0,
acquired_per_host: HashMap::new(),
available: HashMap::new(),
to_close: Vec::new(),
waiters: Some(HashMap::new()),
wait_timeout: None,
paused: Paused::No,
}
}
#[cfg(
all(
feature = "rust-tls",
not(any(feature = "alpn", feature = "tls"))
)
)]
{
let mut config = ClientConfig::new();
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
ClientConnector::with_connector(config)
}
#[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
{ { NativeTlsConnector::builder().build().unwrap().into() }
let (tx, rx) = mpsc::unbounded();
ClientConnector { #[cfg(
stats: ClientConnectorStats::default(), all(
subscriber: None, feature = "rust-tls",
acq_tx: tx, not(any(feature = "alpn", feature = "tls"))
acq_rx: Some(rx), )
resolver: None, )]
conn_lifetime: Duration::from_secs(75), {
conn_keep_alive: Duration::from_secs(15), let mut config = ClientConfig::new();
limit: 100, config
limit_per_host: 0, .root_store
acquired: 0, .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
acquired_per_host: HashMap::new(), Arc::new(config)
available: HashMap::new(),
to_close: Vec::new(),
waiters: Some(HashMap::new()),
wait_timeout: None,
paused: Paused::No,
} }
}
#[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))]
{ () }
};
ClientConnector::with_connector_impl(connector)
} }
} }
@ -402,27 +323,8 @@ impl ClientConnector {
/// } /// }
/// ``` /// ```
pub fn with_connector(connector: SslConnector) -> ClientConnector { pub fn with_connector(connector: SslConnector) -> ClientConnector {
let (tx, rx) = mpsc::unbounded(); // keep level of indirection for docstrings matching featureflags
Self::with_connector_impl(connector)
ClientConnector {
connector,
stats: ClientConnectorStats::default(),
subscriber: None,
acq_tx: tx,
acq_rx: Some(rx),
resolver: None,
conn_lifetime: Duration::from_secs(75),
conn_keep_alive: Duration::from_secs(15),
limit: 100,
limit_per_host: 0,
acquired: 0,
acquired_per_host: HashMap::new(),
available: HashMap::new(),
to_close: Vec::new(),
waiters: Some(HashMap::new()),
wait_timeout: None,
paused: Paused::No,
}
} }
#[cfg( #[cfg(
@ -476,10 +378,27 @@ impl ClientConnector {
/// } /// }
/// ``` /// ```
pub fn with_connector(connector: ClientConfig) -> ClientConnector { pub fn with_connector(connector: ClientConfig) -> ClientConnector {
// keep level of indirection for docstrings matching featureflags
Self::with_connector_impl(Arc::new(connector))
}
#[cfg(
all(
feature = "tls",
not(any(feature = "alpn", feature = "rust-tls"))
)
)]
pub fn with_connector(connector: SslConnector) -> ClientConnector {
// keep level of indirection for docstrings matching featureflags
Self::with_connector_impl(connector)
}
#[inline]
fn with_connector_impl(connector: SslConnector) -> ClientConnector {
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
ClientConnector { ClientConnector {
connector: Arc::new(connector), connector,
stats: ClientConnectorStats::default(), stats: ClientConnectorStats::default(),
subscriber: None, subscriber: None,
acq_tx: tx, acq_tx: tx,
@ -1364,4 +1283,4 @@ impl<Io: IoStream> IoStream for TlsStream<Io> {
fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()> { fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()> {
self.get_mut().get_mut().set_linger(dur) self.get_mut().get_mut().set_linger(dur)
} }
} }