From 5dc2bfcb01b899d2d1e5d92e172fa3ab75bdaf74 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Mon, 29 Nov 2021 23:53:06 +0000 Subject: [PATCH] actix-tls release candidate prep (#422) --- actix-tls/CHANGES.md | 30 +- actix-tls/Cargo.toml | 9 +- actix-tls/src/accept/mod.rs | 13 +- actix-tls/src/accept/native_tls.rs | 77 +++-- actix-tls/src/accept/openssl.rs | 89 +++--- actix-tls/src/accept/rustls.rs | 88 +++--- actix-tls/src/connect/connect.rs | 361 ------------------------ actix-tls/src/connect/connect_addrs.rs | 82 ++++++ actix-tls/src/connect/connection.rs | 54 ++++ actix-tls/src/connect/connector.rs | 235 ++++++--------- actix-tls/src/connect/error.rs | 19 +- actix-tls/src/connect/host.rs | 71 +++++ actix-tls/src/connect/info.rs | 249 ++++++++++++++++ actix-tls/src/connect/mod.rs | 51 ++-- actix-tls/src/connect/native_tls.rs | 90 ++++++ actix-tls/src/connect/openssl.rs | 146 ++++++++++ actix-tls/src/connect/resolve.rs | 207 +------------- actix-tls/src/connect/resolver.rs | 201 +++++++++++++ actix-tls/src/connect/rustls.rs | 148 ++++++++++ actix-tls/src/connect/service.rs | 129 --------- actix-tls/src/connect/tcp.rs | 225 ++++++++++++--- actix-tls/src/connect/tls/mod.rs | 10 - actix-tls/src/connect/tls/native_tls.rs | 88 ------ actix-tls/src/connect/tls/openssl.rs | 130 --------- actix-tls/src/connect/tls/rustls.rs | 146 ---------- actix-tls/src/connect/uri.rs | 15 +- actix-tls/src/lib.rs | 7 +- actix-tls/tests/accept-rustls.rs | 10 +- actix-tls/tests/test_connect.rs | 63 +++-- actix-tls/tests/test_resolvers.rs | 21 +- 30 files changed, 1608 insertions(+), 1456 deletions(-) delete mode 100755 actix-tls/src/connect/connect.rs create mode 100644 actix-tls/src/connect/connect_addrs.rs create mode 100644 actix-tls/src/connect/connection.rs mode change 100755 => 100644 actix-tls/src/connect/connector.rs create mode 100644 actix-tls/src/connect/host.rs create mode 100644 actix-tls/src/connect/info.rs create mode 100644 actix-tls/src/connect/native_tls.rs create mode 100644 actix-tls/src/connect/openssl.rs mode change 100755 => 100644 actix-tls/src/connect/resolve.rs create mode 100644 actix-tls/src/connect/resolver.rs create mode 100644 actix-tls/src/connect/rustls.rs delete mode 100755 actix-tls/src/connect/service.rs delete mode 100644 actix-tls/src/connect/tls/mod.rs delete mode 100644 actix-tls/src/connect/tls/native_tls.rs delete mode 100755 actix-tls/src/connect/tls/openssl.rs delete mode 100755 actix-tls/src/connect/tls/rustls.rs mode change 100755 => 100644 actix-tls/tests/test_connect.rs diff --git a/actix-tls/CHANGES.md b/actix-tls/CHANGES.md index 7f2a0c5e..78ba3665 100644 --- a/actix-tls/CHANGES.md +++ b/actix-tls/CHANGES.md @@ -1,6 +1,33 @@ # Changes ## Unreleased - 2021-xx-xx +### Added +* Derive `Debug` for `connect::Connection`. [#422] +* Implement `Display` for `accept::TlsError`. [#422] +* Implement `Error` for `accept::TlsError` where both types also implement `Error`. [#422] +* Implement `Default` for `connect::Resolver`. [#422] +* Implement `Error` for `connect::ConnectError`. [#422] + +### Changed +* There are now no default features. [#422] +* Useful re-exports from underlying TLS crates are exposed in a `reexports` modules in all acceptors and connectors. +* Convert `connect::ResolverService` from enum to struct. [#422] +* Make `ConnectAddrsIter` private. [#422] +* Rename `accept::native_tls::{NativeTlsAcceptorService => AcceptorService}`. [#422] +* Rename `connect::{Address => Host}` trait. [#422] +* Rename method `connect::Connection::{host => hostname}`. [#422] +* Rename struct `connect::{Connect => ConnectInfo}`. [#422] +* Rename struct `connect::{ConnectService => ConnectorService}`. [#422] +* Rename struct `connect::{ConnectServiceFactory => Connector}`. [#422] +* Rename TLS acceptor service future types and hide from docs. [#422] +* Unbox some service futures types. [#422] + +### Removed +* Remove `connect::{new_connector, new_connector_factory, default_connector, default_connector_factory}` methods. [#422] +* Remove `connect::native_tls::Connector::service` method. [#422] +* Remove redundant `connect::Connection::from_parts` method. [#422] + +[#422]: https://github.com/actix/actix-net/pull/422 ## 3.0.0-beta.9 - 2021-11-22 @@ -41,8 +68,7 @@ * Remove `connect::ssl::openssl::OpensslConnectService`. [#297] * Add `connect::ssl::native_tls` module for native tls support. [#295] * Rename `accept::{nativetls => native_tls}`. [#295] -* Remove `connect::TcpConnectService` type. service caller expect a `TcpStream` should use - `connect::ConnectService` instead and call `Connection::into_parts`. [#299] +* Remove `connect::TcpConnectService` type. Service caller expecting a `TcpStream` should use `connect::ConnectService` instead and call `Connection::into_parts`. [#299] [#295]: https://github.com/actix/actix-net/pull/295 [#296]: https://github.com/actix/actix-net/pull/296 diff --git a/actix-tls/Cargo.toml b/actix-tls/Cargo.toml index 5349f330..29feff73 100755 --- a/actix-tls/Cargo.toml +++ b/actix-tls/Cargo.toml @@ -13,14 +13,15 @@ license = "MIT OR Apache-2.0" edition = "2018" [package.metadata.docs.rs] -features = ["openssl", "rustls", "native-tls", "accept", "connect", "uri"] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] [lib] name = "actix_tls" path = "src/lib.rs" [features] -default = ["accept", "connect", "uri"] +default = [] # enable acceptor services accept = [] @@ -48,11 +49,13 @@ actix-utils = "3.0.0" derive_more = "0.99.5" futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } -http = { version = "0.2.3", optional = true } log = "0.4" pin-project-lite = "0.2.7" tokio-util = { version = "0.6.3", default-features = false } +# uri +http = { version = "0.2.3", optional = true } + # openssl tls-openssl = { package = "openssl", version = "0.10.9", optional = true } tokio-openssl = { version = "0.6", optional = true } diff --git a/actix-tls/src/accept/mod.rs b/actix-tls/src/accept/mod.rs index 300e1767..de220ac5 100644 --- a/actix-tls/src/accept/mod.rs +++ b/actix-tls/src/accept/mod.rs @@ -1,4 +1,4 @@ -//! TLS acceptor services. +//! TLS connection acceptor services. use std::{ convert::Infallible, @@ -6,14 +6,18 @@ use std::{ }; use actix_utils::counter::Counter; +use derive_more::{Display, Error}; #[cfg(feature = "openssl")] +#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] pub mod openssl; #[cfg(feature = "rustls")] +#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] pub mod rustls; #[cfg(feature = "native-tls")] +#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] pub mod native_tls; pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256); @@ -41,15 +45,18 @@ pub fn max_concurrent_tls_connect(num: usize) { /// All TLS acceptors from this crate will return the `SvcErr` type parameter as [`Infallible`], /// which can be cast to your own service type, inferred or otherwise, /// using [`into_service_error`](Self::into_service_error). -#[derive(Debug)] +#[derive(Debug, Display, Error)] pub enum TlsError { /// TLS handshake has timed-out. + #[display(fmt = "TLS handshake has timed-out")] Timeout, /// Wraps TLS service errors. + #[display(fmt = "TLS handshake error")] Tls(TlsErr), - /// Wraps inner service errors. + /// Wraps service errors. + #[display(fmt = "Service error")] Service(SvcErr), } diff --git a/actix-tls/src/accept/native_tls.rs b/actix-tls/src/accept/native_tls.rs index e61300e6..534dc58d 100644 --- a/actix-tls/src/accept/native_tls.rs +++ b/actix-tls/src/accept/native_tls.rs @@ -1,7 +1,10 @@ +//! `native-tls` based TLS connection acceptor service. +//! +//! See [`Acceptor`] for main service factory docs. + use std::{ convert::Infallible, io::{self, IoSlice}, - ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -13,37 +16,27 @@ use actix_rt::{ time::timeout, }; use actix_service::{Service, ServiceFactory}; -use actix_utils::counter::Counter; +use actix_utils::{ + counter::Counter, + future::{ready, Ready as FutReady}, +}; +use derive_more::{Deref, DerefMut, From}; use futures_core::future::LocalBoxFuture; - -pub use tokio_native_tls::{native_tls::Error, TlsAcceptor}; +use tokio_native_tls::{native_tls::Error, TlsAcceptor}; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; -/// Wrapper type for `tokio_native_tls::TlsStream` in order to impl `ActixStream` trait. -pub struct TlsStream(tokio_native_tls::TlsStream); +pub mod reexports { + //! Re-exports from `native-tls` that are useful for acceptors. -impl From> for TlsStream { - fn from(stream: tokio_native_tls::TlsStream) -> Self { - Self(stream) - } + pub use tokio_native_tls::{native_tls::Error, TlsAcceptor}; } -impl Deref for TlsStream { - type Target = tokio_native_tls::TlsStream; +/// Wraps a `native-tls` based async TLS stream in order to implement [`ActixStream`]. +#[derive(Deref, DerefMut, From)] +pub struct TlsStream(tokio_native_tls::TlsStream); - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TlsStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl AsyncRead for TlsStream { +impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -53,7 +46,7 @@ impl AsyncRead for TlsStream { } } -impl AsyncWrite for TlsStream { +impl AsyncWrite for TlsStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -83,27 +76,24 @@ impl AsyncWrite for TlsStream { } } -impl ActixStream for TlsStream { +impl ActixStream for TlsStream { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { - T::poll_read_ready((&**self).get_ref().get_ref().get_ref(), cx) + IO::poll_read_ready((&**self).get_ref().get_ref().get_ref(), cx) } fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { - T::poll_write_ready((&**self).get_ref().get_ref().get_ref(), cx) + IO::poll_write_ready((&**self).get_ref().get_ref().get_ref(), cx) } } -/// Accept TLS connections via `native-tls` package. -/// -/// `native-tls` feature enables this `Acceptor` type. +/// Accept TLS connections via the `native-tls` crate. pub struct Acceptor { acceptor: TlsAcceptor, handshake_timeout: Duration, } impl Acceptor { - /// Create `native-tls` based `Acceptor` service factory. - #[inline] + /// Constructs `native-tls` based `Acceptor` service factory. pub fn new(acceptor: TlsAcceptor) -> Self { Acceptor { acceptor, @@ -130,35 +120,36 @@ impl Clone for Acceptor { } } -impl ServiceFactory for Acceptor { - type Response = TlsStream; +impl ServiceFactory for Acceptor { + type Response = TlsStream; type Error = TlsError; type Config = (); - type Service = NativeTlsAcceptorService; + type Service = AcceptorService; type InitError = (); - type Future = LocalBoxFuture<'static, Result>; + type Future = FutReady>; fn new_service(&self, _: ()) -> Self::Future { let res = MAX_CONN_COUNTER.with(|conns| { - Ok(NativeTlsAcceptorService { + Ok(AcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), handshake_timeout: self.handshake_timeout, }) }); - Box::pin(async { res }) + ready(res) } } -pub struct NativeTlsAcceptorService { +/// Native-TLS based acceptor service. +pub struct AcceptorService { acceptor: TlsAcceptor, conns: Counter, handshake_timeout: Duration, } -impl Service for NativeTlsAcceptorService { - type Response = TlsStream; +impl Service for AcceptorService { + type Response = TlsStream; type Error = TlsError; type Future = LocalBoxFuture<'static, Result>; @@ -170,7 +161,7 @@ impl Service for NativeTlsAcceptorService { } } - fn call(&self, io: T) -> Self::Future { + fn call(&self, io: IO) -> Self::Future { let guard = self.conns.get(); let acceptor = self.acceptor.clone(); diff --git a/actix-tls/src/accept/openssl.rs b/actix-tls/src/accept/openssl.rs index cb1887ea..a91000cc 100644 --- a/actix-tls/src/accept/openssl.rs +++ b/actix-tls/src/accept/openssl.rs @@ -1,8 +1,11 @@ +//! `openssl` based TLS acceptor service. +//! +//! See [`Acceptor`] for main service factory docs. + use std::{ convert::Infallible, future::Future, io::{self, IoSlice}, - ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -14,40 +17,29 @@ use actix_rt::{ time::{sleep, Sleep}, }; use actix_service::{Service, ServiceFactory}; -use actix_utils::counter::{Counter, CounterGuard}; -use futures_core::future::LocalBoxFuture; - -pub use openssl::ssl::{ - AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder, +use actix_utils::{ + counter::{Counter, CounterGuard}, + future::{ready, Ready as FutReady}, }; +use derive_more::{Deref, DerefMut, From}; +use openssl::ssl::{Error, Ssl, SslAcceptor}; use pin_project_lite::pin_project; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; -/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait. -pub struct TlsStream(tokio_openssl::SslStream); +pub mod reexports { + //! Re-exports from `openssl` that are useful for acceptors. -impl From> for TlsStream { - fn from(stream: tokio_openssl::SslStream) -> Self { - Self(stream) - } + pub use openssl::ssl::{ + AlpnError, Error, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder, + }; } -impl Deref for TlsStream { - type Target = tokio_openssl::SslStream; +/// Wraps an `openssl` based async TLS stream in order to implement [`ActixStream`]. +#[derive(Deref, DerefMut, From)] +pub struct TlsStream(tokio_openssl::SslStream); - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TlsStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl AsyncRead for TlsStream { +impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -57,7 +49,7 @@ impl AsyncRead for TlsStream { } } -impl AsyncWrite for TlsStream { +impl AsyncWrite for TlsStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -87,19 +79,17 @@ impl AsyncWrite for TlsStream { } } -impl ActixStream for TlsStream { +impl ActixStream for TlsStream { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { - T::poll_read_ready((&**self).get_ref(), cx) + IO::poll_read_ready((&**self).get_ref(), cx) } fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { - T::poll_write_ready((&**self).get_ref(), cx) + IO::poll_write_ready((&**self).get_ref(), cx) } } -/// Accept TLS connections via `openssl` package. -/// -/// `openssl` feature enables this `Acceptor` type. +/// Accept TLS connections via the `openssl` crate. pub struct Acceptor { acceptor: SslAcceptor, handshake_timeout: Duration, @@ -134,13 +124,13 @@ impl Clone for Acceptor { } } -impl ServiceFactory for Acceptor { - type Response = TlsStream; - type Error = TlsError; +impl ServiceFactory for Acceptor { + type Response = TlsStream; + type Error = TlsError; type Config = (); type Service = AcceptorService; type InitError = (); - type Future = LocalBoxFuture<'static, Result>; + type Future = FutReady>; fn new_service(&self, _: ()) -> Self::Future { let res = MAX_CONN_COUNTER.with(|conns| { @@ -151,20 +141,21 @@ impl ServiceFactory for Acceptor { }) }); - Box::pin(async { res }) + ready(res) } } +/// OpenSSL based acceptor service. pub struct AcceptorService { acceptor: SslAcceptor, conns: Counter, handshake_timeout: Duration, } -impl Service for AcceptorService { - type Response = TlsStream; - type Error = TlsError; - type Future = AcceptorServiceResponse; +impl Service for AcceptorService { + type Response = TlsStream; + type Error = TlsError; + type Future = AcceptFut; fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll> { if self.conns.available(ctx) { @@ -174,11 +165,11 @@ impl Service for AcceptorService { } } - fn call(&self, io: T) -> Self::Future { + fn call(&self, io: IO) -> Self::Future { let ssl_ctx = self.acceptor.context(); let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid."); - AcceptorServiceResponse { + AcceptFut { _guard: self.conns.get(), timeout: sleep(self.handshake_timeout), stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()), @@ -187,16 +178,18 @@ impl Service for AcceptorService { } pin_project! { - pub struct AcceptorServiceResponse { - stream: Option>, + /// Accept future for OpenSSL service. + #[doc(hidden)] + pub struct AcceptFut { + stream: Option>, #[pin] timeout: Sleep, _guard: CounterGuard, } } -impl Future for AcceptorServiceResponse { - type Output = Result, TlsError>; +impl Future for AcceptFut { + type Output = Result, TlsError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); diff --git a/actix-tls/src/accept/rustls.rs b/actix-tls/src/accept/rustls.rs index b0f31365..b6f3a8fe 100644 --- a/actix-tls/src/accept/rustls.rs +++ b/actix-tls/src/accept/rustls.rs @@ -1,8 +1,11 @@ +//! `rustls` based TLS connection acceptor service. +//! +//! See [`Acceptor`] for main service factory docs. + use std::{ convert::Infallible, future::Future, io::{self, IoSlice}, - ops::{Deref, DerefMut}, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -15,39 +18,28 @@ use actix_rt::{ time::{sleep, Sleep}, }; use actix_service::{Service, ServiceFactory}; -use actix_utils::counter::{Counter, CounterGuard}; -use futures_core::future::LocalBoxFuture; +use actix_utils::{ + counter::{Counter, CounterGuard}, + future::{ready, Ready as FutReady}, +}; +use derive_more::{Deref, DerefMut, From}; use pin_project_lite::pin_project; +use tokio_rustls::rustls::ServerConfig; use tokio_rustls::{Accept, TlsAcceptor}; -pub use tokio_rustls::rustls::ServerConfig; - use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; -/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait. -pub struct TlsStream(tokio_rustls::server::TlsStream); +pub mod reexports { + //! Re-exports from `rustls` that are useful for acceptors. -impl From> for TlsStream { - fn from(stream: tokio_rustls::server::TlsStream) -> Self { - Self(stream) - } + pub use tokio_rustls::rustls::ServerConfig; } -impl Deref for TlsStream { - type Target = tokio_rustls::server::TlsStream; +/// Wraps a `rustls` based async TLS stream in order to implement [`ActixStream`]. +#[derive(Deref, DerefMut, From)] +pub struct TlsStream(tokio_rustls::server::TlsStream); - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TlsStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl AsyncRead for TlsStream { +impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -57,7 +49,7 @@ impl AsyncRead for TlsStream { } } -impl AsyncWrite for TlsStream { +impl AsyncWrite for TlsStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -87,27 +79,24 @@ impl AsyncWrite for TlsStream { } } -impl ActixStream for TlsStream { +impl ActixStream for TlsStream { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { - T::poll_read_ready((&**self).get_ref().0, cx) + IO::poll_read_ready((&**self).get_ref().0, cx) } fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { - T::poll_write_ready((&**self).get_ref().0, cx) + IO::poll_write_ready((&**self).get_ref().0, cx) } } -/// Accept TLS connections via `rustls` package. -/// -/// `rustls` feature enables this `Acceptor` type. +/// Accept TLS connections via the `rustls` crate. pub struct Acceptor { config: Arc, handshake_timeout: Duration, } impl Acceptor { - /// Create Rustls based `Acceptor` service factory. - #[inline] + /// Constructs Rustls based acceptor service factory. pub fn new(config: ServerConfig) -> Self { Acceptor { config: Arc::new(config), @@ -125,7 +114,6 @@ impl Acceptor { } impl Clone for Acceptor { - #[inline] fn clone(&self) -> Self { Self { config: self.config.clone(), @@ -134,13 +122,13 @@ impl Clone for Acceptor { } } -impl ServiceFactory for Acceptor { - type Response = TlsStream; +impl ServiceFactory for Acceptor { + type Response = TlsStream; type Error = TlsError; type Config = (); type Service = AcceptorService; type InitError = (); - type Future = LocalBoxFuture<'static, Result>; + type Future = FutReady>; fn new_service(&self, _: ()) -> Self::Future { let res = MAX_CONN_COUNTER.with(|conns| { @@ -151,21 +139,21 @@ impl ServiceFactory for Acceptor { }) }); - Box::pin(async { res }) + ready(res) } } -/// Rustls based `Acceptor` service +/// Rustls based acceptor service. pub struct AcceptorService { acceptor: TlsAcceptor, conns: Counter, handshake_timeout: Duration, } -impl Service for AcceptorService { - type Response = TlsStream; +impl Service for AcceptorService { + type Response = TlsStream; type Error = TlsError; - type Future = AcceptorServiceFut; + type Future = AcceptFut; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { if self.conns.available(cx) { @@ -175,8 +163,8 @@ impl Service for AcceptorService { } } - fn call(&self, req: T) -> Self::Future { - AcceptorServiceFut { + fn call(&self, req: IO) -> Self::Future { + AcceptFut { fut: self.acceptor.accept(req), timeout: sleep(self.handshake_timeout), _guard: self.conns.get(), @@ -185,16 +173,18 @@ impl Service for AcceptorService { } pin_project! { - pub struct AcceptorServiceFut { - fut: Accept, + /// Accept future for Rustls service. + #[doc(hidden)] + pub struct AcceptFut { + fut: Accept, #[pin] timeout: Sleep, _guard: CounterGuard, } } -impl Future for AcceptorServiceFut { - type Output = Result, TlsError>; +impl Future for AcceptFut { + type Output = Result, TlsError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); diff --git a/actix-tls/src/connect/connect.rs b/actix-tls/src/connect/connect.rs deleted file mode 100755 index 65d9e05e..00000000 --- a/actix-tls/src/connect/connect.rs +++ /dev/null @@ -1,361 +0,0 @@ -use std::{ - collections::{vec_deque, VecDeque}, - fmt, - iter::{self, FromIterator as _}, - mem, - net::{IpAddr, SocketAddr}, -}; - -/// Parse a host into parts (hostname and port). -pub trait Address: Unpin + 'static { - /// Get hostname part. - fn hostname(&self) -> &str; - - /// Get optional port part. - fn port(&self) -> Option { - None - } -} - -impl Address for String { - fn hostname(&self) -> &str { - self - } -} - -impl Address for &'static str { - fn hostname(&self) -> &str { - self - } -} - -#[derive(Debug, Eq, PartialEq, Hash)] -pub(crate) enum ConnectAddrs { - None, - One(SocketAddr), - Multi(VecDeque), -} - -impl ConnectAddrs { - pub(crate) fn is_none(&self) -> bool { - matches!(self, Self::None) - } - - pub(crate) fn is_some(&self) -> bool { - !self.is_none() - } -} - -impl Default for ConnectAddrs { - fn default() -> Self { - Self::None - } -} - -impl From> for ConnectAddrs { - fn from(addr: Option) -> Self { - match addr { - Some(addr) => ConnectAddrs::One(addr), - None => ConnectAddrs::None, - } - } -} - -/// Connection info. -#[derive(Debug, PartialEq, Eq, Hash)] -pub struct Connect { - pub(crate) req: R, - pub(crate) port: u16, - pub(crate) addr: ConnectAddrs, - pub(crate) local_addr: Option, -} - -impl Connect { - /// Create `Connect` instance by splitting the string by ':' and convert the second part to u16 - pub fn new(req: R) -> Connect { - let (_, port) = parse_host(req.hostname()); - - Connect { - req, - port: port.unwrap_or(0), - addr: ConnectAddrs::None, - local_addr: None, - } - } - - /// Create new `Connect` instance from host and address. Connector skips name resolution stage - /// for such connect messages. - pub fn with_addr(req: R, addr: SocketAddr) -> Connect { - Connect { - req, - port: 0, - addr: ConnectAddrs::One(addr), - local_addr: None, - } - } - - /// Use port if address does not provide one. - /// - /// Default value is 0. - pub fn set_port(mut self, port: u16) -> Self { - self.port = port; - self - } - - /// Set address. - pub fn set_addr(mut self, addr: Option) -> Self { - self.addr = ConnectAddrs::from(addr); - self - } - - /// Set list of addresses. - pub fn set_addrs(mut self, addrs: I) -> Self - where - I: IntoIterator, - { - let mut addrs = VecDeque::from_iter(addrs); - self.addr = if addrs.len() < 2 { - ConnectAddrs::from(addrs.pop_front()) - } else { - ConnectAddrs::Multi(addrs) - }; - self - } - - /// Set local_addr of connect. - pub fn set_local_addr(mut self, addr: impl Into) -> Self { - self.local_addr = Some(addr.into()); - self - } - - /// Get hostname. - pub fn hostname(&self) -> &str { - self.req.hostname() - } - - /// Get request port. - pub fn port(&self) -> u16 { - self.req.port().unwrap_or(self.port) - } - - /// Get resolved request addresses. - pub fn addrs(&self) -> ConnectAddrsIter<'_> { - match self.addr { - ConnectAddrs::None => ConnectAddrsIter::None, - ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), - ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), - } - } - - /// Take resolved request addresses. - pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> { - match mem::take(&mut self.addr) { - ConnectAddrs::None => ConnectAddrsIter::None, - ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), - ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), - } - } - - /// Returns a reference to the connection request. - pub fn request(&self) -> &R { - &self.req - } -} - -impl From for Connect { - fn from(addr: R) -> Self { - Connect::new(addr) - } -} - -impl fmt::Display for Connect { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}:{}", self.hostname(), self.port()) - } -} - -/// Iterator over addresses in a [`Connect`] request. -#[derive(Clone)] -pub enum ConnectAddrsIter<'a> { - None, - One(SocketAddr), - Multi(vec_deque::Iter<'a, SocketAddr>), - MultiOwned(vec_deque::IntoIter), -} - -impl Iterator for ConnectAddrsIter<'_> { - type Item = SocketAddr; - - fn next(&mut self) -> Option { - match *self { - Self::None => None, - Self::One(addr) => { - *self = Self::None; - Some(addr) - } - Self::Multi(ref mut iter) => iter.next().copied(), - Self::MultiOwned(ref mut iter) => iter.next(), - } - } - - fn size_hint(&self) -> (usize, Option) { - match *self { - Self::None => (0, Some(0)), - Self::One(_) => (1, Some(1)), - Self::Multi(ref iter) => iter.size_hint(), - Self::MultiOwned(ref iter) => iter.size_hint(), - } - } -} - -impl fmt::Debug for ConnectAddrsIter<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.clone()).finish() - } -} - -impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {} - -impl iter::FusedIterator for ConnectAddrsIter<'_> {} - -pub struct Connection { - io: U, - req: T, -} - -impl Connection { - pub fn new(io: U, req: T) -> Self { - Self { io, req } - } -} - -impl Connection { - /// Reconstruct from a parts. - pub fn from_parts(io: U, req: T) -> Self { - Self { io, req } - } - - /// Deconstruct into a parts. - pub fn into_parts(self) -> (U, T) { - (self.io, self.req) - } - - /// Replace inclosed object, return new Stream and old object - pub fn replace_io(self, io: Y) -> (U, Connection) { - (self.io, Connection { io, req: self.req }) - } - - /// Returns a shared reference to the underlying stream. - pub fn io_ref(&self) -> &U { - &self.io - } - - /// Returns a mutable reference to the underlying stream. - pub fn io_mut(&mut self) -> &mut U { - &mut self.io - } -} - -impl Connection { - /// Get hostname. - pub fn host(&self) -> &str { - self.req.hostname() - } -} - -impl std::ops::Deref for Connection { - type Target = U; - - fn deref(&self) -> &U { - &self.io - } -} - -impl std::ops::DerefMut for Connection { - fn deref_mut(&mut self) -> &mut U { - &mut self.io - } -} - -impl fmt::Debug for Connection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Stream {{{:?}}}", self.io) - } -} - -fn parse_host(host: &str) -> (&str, Option) { - let mut parts_iter = host.splitn(2, ':'); - - match parts_iter.next() { - Some(hostname) => { - let port_str = parts_iter.next().unwrap_or(""); - let port = port_str.parse::().ok(); - (hostname, port) - } - - None => (host, None), - } -} - -#[cfg(test)] -mod tests { - use std::net::Ipv4Addr; - - use super::*; - - #[test] - fn test_host_parser() { - assert_eq!(parse_host("example.com"), ("example.com", None)); - assert_eq!(parse_host("example.com:8080"), ("example.com", Some(8080))); - assert_eq!(parse_host("example:8080"), ("example", Some(8080))); - assert_eq!(parse_host("example.com:false"), ("example.com", None)); - assert_eq!(parse_host("example.com:false:false"), ("example.com", None)); - } - - #[test] - fn test_addr_iter_multi() { - let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); - let unspecified = SocketAddr::from((IpAddr::from(Ipv4Addr::UNSPECIFIED), 8080)); - - let mut addrs = VecDeque::new(); - addrs.push_back(localhost); - addrs.push_back(unspecified); - - let mut iter = ConnectAddrsIter::Multi(addrs.iter()); - assert_eq!(iter.next(), Some(localhost)); - assert_eq!(iter.next(), Some(unspecified)); - assert_eq!(iter.next(), None); - - let mut iter = ConnectAddrsIter::MultiOwned(addrs.into_iter()); - assert_eq!(iter.next(), Some(localhost)); - assert_eq!(iter.next(), Some(unspecified)); - assert_eq!(iter.next(), None); - } - - #[test] - fn test_addr_iter_single() { - let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); - - let mut iter = ConnectAddrsIter::One(localhost); - assert_eq!(iter.next(), Some(localhost)); - assert_eq!(iter.next(), None); - - let mut iter = ConnectAddrsIter::None; - assert_eq!(iter.next(), None); - } - - #[test] - fn test_local_addr() { - let conn = Connect::new("hello").set_local_addr([127, 0, 0, 1]); - assert_eq!( - conn.local_addr.unwrap(), - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) - ) - } - - #[test] - fn request_ref() { - let conn = Connect::new("hello"); - assert_eq!(conn.request(), &"hello") - } -} diff --git a/actix-tls/src/connect/connect_addrs.rs b/actix-tls/src/connect/connect_addrs.rs new file mode 100644 index 00000000..13e4c4fa --- /dev/null +++ b/actix-tls/src/connect/connect_addrs.rs @@ -0,0 +1,82 @@ +use std::{ + collections::{vec_deque, VecDeque}, + fmt, iter, + net::SocketAddr, +}; + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub(crate) enum ConnectAddrs { + None, + One(SocketAddr), + // TODO: consider using smallvec + Multi(VecDeque), +} + +impl ConnectAddrs { + pub(crate) fn is_unresolved(&self) -> bool { + matches!(self, Self::None) + } + + pub(crate) fn is_resolved(&self) -> bool { + !self.is_unresolved() + } +} + +impl Default for ConnectAddrs { + fn default() -> Self { + Self::None + } +} + +impl From> for ConnectAddrs { + fn from(addr: Option) -> Self { + match addr { + Some(addr) => ConnectAddrs::One(addr), + None => ConnectAddrs::None, + } + } +} + +/// Iterator over addresses in a [`Connect`] request. +#[derive(Clone)] +pub(crate) enum ConnectAddrsIter<'a> { + None, + One(SocketAddr), + Multi(vec_deque::Iter<'a, SocketAddr>), + MultiOwned(vec_deque::IntoIter), +} + +impl Iterator for ConnectAddrsIter<'_> { + type Item = SocketAddr; + + fn next(&mut self) -> Option { + match *self { + Self::None => None, + Self::One(addr) => { + *self = Self::None; + Some(addr) + } + Self::Multi(ref mut iter) => iter.next().copied(), + Self::MultiOwned(ref mut iter) => iter.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match *self { + Self::None => (0, Some(0)), + Self::One(_) => (1, Some(1)), + Self::Multi(ref iter) => iter.size_hint(), + Self::MultiOwned(ref iter) => iter.size_hint(), + } + } +} + +impl fmt::Debug for ConnectAddrsIter<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.clone()).finish() + } +} + +impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {} + +impl iter::FusedIterator for ConnectAddrsIter<'_> {} diff --git a/actix-tls/src/connect/connection.rs b/actix-tls/src/connect/connection.rs new file mode 100644 index 00000000..68972a2a --- /dev/null +++ b/actix-tls/src/connect/connection.rs @@ -0,0 +1,54 @@ +use derive_more::{Deref, DerefMut}; + +use super::Host; + +/// Wraps underlying I/O and the connection request that initiated it. +#[derive(Debug, Deref, DerefMut)] +pub struct Connection { + pub(crate) req: R, + + #[deref] + #[deref_mut] + pub(crate) io: IO, +} + +impl Connection { + /// Construct new `Connection` from request and IO parts. + pub(crate) fn new(req: R, io: IO) -> Self { + Self { req, io } + } +} + +impl Connection { + /// Deconstructs into IO and request parts. + pub fn into_parts(self) -> (IO, R) { + (self.io, self.req) + } + + /// Replaces underlying IO, returning old IO and new `Connection`. + pub fn replace_io(self, io: IO2) -> (IO, Connection) { + (self.io, Connection { io, req: self.req }) + } + + /// Returns a shared reference to the underlying IO. + pub fn io_ref(&self) -> &IO { + &self.io + } + + /// Returns a mutable reference to the underlying IO. + pub fn io_mut(&mut self) -> &mut IO { + &mut self.io + } + + /// Returns a reference to the connection request. + pub fn request(&self) -> &R { + &self.req + } +} + +impl Connection { + /// Returns hostname. + pub fn hostname(&self) -> &str { + self.req.hostname() + } +} diff --git a/actix-tls/src/connect/connector.rs b/actix-tls/src/connect/connector.rs old mode 100755 new mode 100644 index 9438404e..f5717661 --- a/actix-tls/src/connect/connector.rs +++ b/actix-tls/src/connect/connector.rs @@ -1,194 +1,127 @@ use std::{ - collections::VecDeque, future::Future, - io, - net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, pin::Pin, task::{Context, Poll}, }; -use actix_rt::net::{TcpSocket, TcpStream}; +use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; -use log::{error, trace}; -use tokio_util::sync::ReusableBoxFuture; +use actix_utils::future::{ok, Ready}; +use futures_core::ready; -use super::connect::{Address, Connect, ConnectAddrs, Connection}; -use super::error::ConnectError; +use super::{ + error::ConnectError, + resolver::{Resolver, ResolverService}, + tcp::{TcpConnector, TcpConnectorService}, + ConnectInfo, Connection, Host, +}; -/// TCP connector service factory -#[derive(Debug, Copy, Clone)] -pub struct TcpConnectorFactory; +/// Combined resolver and TCP connector service factory. +/// +/// Used to create [`ConnectorService`]s which receive connection information, resolve DNS if +/// required, and return a TCP stream. +#[derive(Clone, Default)] +pub struct Connector { + resolver: Resolver, +} -impl TcpConnectorFactory { - /// Create TCP connector service - pub fn service(&self) -> TcpConnector { - TcpConnector +impl Connector { + /// Constructs new connector factory with the given resolver. + pub fn new(resolver: Resolver) -> Self { + Connector { resolver } + } + + /// Build connector service. + pub fn service(&self) -> ConnectorService { + ConnectorService { + tcp: TcpConnector.service(), + resolver: self.resolver.service(), + } } } -impl ServiceFactory> for TcpConnectorFactory { - type Response = Connection; +impl ServiceFactory> for Connector { + type Response = Connection; type Error = ConnectError; type Config = (); - type Service = TcpConnector; + type Service = ConnectorService; type InitError = (); - type Future = LocalBoxFuture<'static, Result>; + type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { - let service = self.service(); - Box::pin(async move { Ok(service) }) + ok(self.service()) } } -/// TCP connector service -#[derive(Debug, Copy, Clone)] -pub struct TcpConnector; +/// Combined resolver and TCP connector service. +/// +/// Service implementation receives connection information, resolves DNS if required, and returns +/// a TCP stream. +#[derive(Clone)] +pub struct ConnectorService { + tcp: TcpConnectorService, + resolver: ResolverService, +} -impl Service> for TcpConnector { - type Response = Connection; +impl Service> for ConnectorService { + type Response = Connection; type Error = ConnectError; - type Future = TcpConnectorResponse; + type Future = ConnectServiceResponse; actix_service::always_ready!(); - fn call(&self, req: Connect) -> Self::Future { - let port = req.port(); - let Connect { - req, - addr, - local_addr, - .. - } = req; - - TcpConnectorResponse::new(req, port, local_addr, addr) + fn call(&self, req: ConnectInfo) -> Self::Future { + ConnectServiceResponse { + fut: ConnectFut::Resolve(self.resolver.call(req)), + tcp: self.tcp, + } } } -/// TCP stream connector response future -pub enum TcpConnectorResponse { - Response { - req: Option, - port: u16, - local_addr: Option, - addrs: Option>, - stream: ReusableBoxFuture>, - }, - Error(Option), +/// Helper enum to generic over futures of resolve and connect steps. +pub(crate) enum ConnectFut { + Resolve(>>::Future), + Connect(>>::Future), } -impl TcpConnectorResponse { - pub(crate) fn new( - req: T, - port: u16, - local_addr: Option, - addr: ConnectAddrs, - ) -> TcpConnectorResponse { - if addr.is_none() { - error!("TCP connector: unresolved connection address"); - return TcpConnectorResponse::Error(Some(ConnectError::Unresolved)); - } +/// Helper enum to contain the future output of `ConnectFuture`. +pub(crate) enum ConnectOutput { + Resolved(ConnectInfo), + Connected(Connection), +} - trace!( - "TCP connector: connecting to {} on port {}", - req.hostname(), - port - ); - - match addr { - ConnectAddrs::None => unreachable!("none variant already checked"), - - ConnectAddrs::One(addr) => TcpConnectorResponse::Response { - req: Some(req), - port, - local_addr, - addrs: None, - stream: ReusableBoxFuture::new(connect(addr, local_addr)), - }, - - // when resolver returns multiple socket addr for request they would be popped from - // front end of queue and returns with the first successful tcp connection. - ConnectAddrs::Multi(mut addrs) => { - let addr = addrs.pop_front().unwrap(); - - TcpConnectorResponse::Response { - req: Some(req), - port, - local_addr, - addrs: Some(addrs), - stream: ReusableBoxFuture::new(connect(addr, local_addr)), - } +impl ConnectFut { + fn poll_connect( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, ConnectError>> { + match self { + ConnectFut::Resolve(ref mut fut) => { + Pin::new(fut).poll(cx).map_ok(ConnectOutput::Resolved) + } + ConnectFut::Connect(ref mut fut) => { + Pin::new(fut).poll(cx).map_ok(ConnectOutput::Connected) } } } } -impl Future for TcpConnectorResponse { - type Output = Result, ConnectError>; +pub struct ConnectServiceResponse { + fut: ConnectFut, + tcp: TcpConnectorService, +} - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.get_mut() { - TcpConnectorResponse::Error(err) => Poll::Ready(Err(err.take().unwrap())), +impl Future for ConnectServiceResponse { + type Output = Result, ConnectError>; - TcpConnectorResponse::Response { - req, - port, - local_addr, - addrs, - stream, - } => loop { - match ready!(stream.poll(cx)) { - Ok(sock) => { - let req = req.take().unwrap(); - trace!( - "TCP connector: successfully connected to {:?} - {:?}", - req.hostname(), - sock.peer_addr() - ); - return Poll::Ready(Ok(Connection::new(sock, req))); - } - - Err(err) => { - trace!( - "TCP connector: failed to connect to {:?} port: {}", - req.as_ref().unwrap().hostname(), - port, - ); - - if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) { - stream.set(connect(addr, *local_addr)); - } else { - return Poll::Ready(Err(ConnectError::Io(err))); - } - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match ready!(self.fut.poll_connect(cx))? { + ConnectOutput::Resolved(res) => { + self.fut = ConnectFut::Connect(self.tcp.call(res)); } - }, + ConnectOutput::Connected(res) => return Poll::Ready(Ok(res)), + } } } } - -async fn connect(addr: SocketAddr, local_addr: Option) -> io::Result { - // use local addr if connect asks for it. - match local_addr { - Some(ip_addr) => { - let socket = match ip_addr { - IpAddr::V4(ip_addr) => { - let socket = TcpSocket::new_v4()?; - let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0)); - socket.bind(addr)?; - socket - } - IpAddr::V6(ip_addr) => { - let socket = TcpSocket::new_v6()?; - let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0)); - socket.bind(addr)?; - socket - } - }; - - socket.connect(addr).await - } - - None => TcpStream::connect(addr).await, - } -} diff --git a/actix-tls/src/connect/error.rs b/actix-tls/src/connect/error.rs index 5d8cb9db..46944988 100644 --- a/actix-tls/src/connect/error.rs +++ b/actix-tls/src/connect/error.rs @@ -1,15 +1,16 @@ -use std::io; +use std::{error::Error, io}; use derive_more::Display; +/// Errors that can result from using a connector service. #[derive(Debug, Display)] pub enum ConnectError { /// Failed to resolve the hostname - #[display(fmt = "Failed resolving hostname: {}", _0)] + #[display(fmt = "Failed resolving hostname")] Resolver(Box), - /// No dns records - #[display(fmt = "No dns records found for the input")] + /// No DNS records + #[display(fmt = "No DNS records found for the input")] NoRecords, /// Invalid input @@ -23,3 +24,13 @@ pub enum ConnectError { #[display(fmt = "{}", _0)] Io(io::Error), } + +impl Error for ConnectError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::Resolver(err) => Some(&**err), + Self::Io(err) => Some(err), + Self::NoRecords | Self::InvalidInput | Self::Unresolved => None, + } + } +} diff --git a/actix-tls/src/connect/host.rs b/actix-tls/src/connect/host.rs new file mode 100644 index 00000000..c4ff9a01 --- /dev/null +++ b/actix-tls/src/connect/host.rs @@ -0,0 +1,71 @@ +//! The [`Host`] trait. + +/// An interface for types where host parts (hostname and port) can be derived. +/// +/// The [WHATWG URL Standard] defines the terminology used for this trait and its methods. +/// +/// ```plain +/// +------------------------+ +/// | host | +/// +-----------------+------+ +/// | hostname | port | +/// | | | +/// | sub.example.com : 8080 | +/// +-----------------+------+ +/// ``` +/// +/// [WHATWG URL Standard]: https://url.spec.whatwg.org/ +pub trait Host: Unpin + 'static { + /// Extract hostname. + fn hostname(&self) -> &str; + + /// Extract optional port. + fn port(&self) -> Option { + None + } +} + +impl Host for String { + fn hostname(&self) -> &str { + self.split_once(':') + .map(|(hostname, _)| hostname) + .unwrap_or(self) + } + + fn port(&self) -> Option { + self.split_once(':').and_then(|(_, port)| port.parse().ok()) + } +} + +impl Host for &'static str { + fn hostname(&self) -> &str { + self.split_once(':') + .map(|(hostname, _)| hostname) + .unwrap_or(self) + } + + fn port(&self) -> Option { + self.split_once(':').and_then(|(_, port)| port.parse().ok()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! assert_connection_info_eq { + ($req:expr, $hostname:expr, $port:expr) => {{ + assert_eq!($req.hostname(), $hostname); + assert_eq!($req.port(), $port); + }}; + } + + #[test] + fn host_parsing() { + assert_connection_info_eq!("example.com", "example.com", None); + assert_connection_info_eq!("example.com:8080", "example.com", Some(8080)); + assert_connection_info_eq!("example:8080", "example", Some(8080)); + assert_connection_info_eq!("example.com:false", "example.com", None); + assert_connection_info_eq!("example.com:false:false", "example.com", None); + } +} diff --git a/actix-tls/src/connect/info.rs b/actix-tls/src/connect/info.rs new file mode 100644 index 00000000..7bd1e5f3 --- /dev/null +++ b/actix-tls/src/connect/info.rs @@ -0,0 +1,249 @@ +//! Connection info struct. + +use std::{ + collections::VecDeque, + fmt, + iter::{self, FromIterator as _}, + mem, + net::{IpAddr, SocketAddr}, +}; + +use super::{ + connect_addrs::{ConnectAddrs, ConnectAddrsIter}, + Host, +}; + +/// Connection request information. +/// +/// May contain known/pre-resolved socket address(es) or a host that needs resolving with DNS. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ConnectInfo { + pub(crate) request: R, + pub(crate) port: u16, + pub(crate) addr: ConnectAddrs, + pub(crate) local_addr: Option, +} + +impl ConnectInfo { + /// Constructs new connection info using a request. + pub fn new(request: R) -> ConnectInfo { + let port = request.port(); + + ConnectInfo { + request, + port: port.unwrap_or(0), + addr: ConnectAddrs::None, + local_addr: None, + } + } + + /// Constructs new connection info from request and known socket address. + /// + /// Since socket address is known, [`Connector`](super::Connector) will skip the DNS + /// resolution step. + pub fn with_addr(request: R, addr: SocketAddr) -> ConnectInfo { + ConnectInfo { + request, + port: 0, + addr: ConnectAddrs::One(addr), + local_addr: None, + } + } + + /// Set connection port. + /// + /// If request provided a port, this will override it. + pub fn set_port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Set connection socket address. + pub fn set_addr(mut self, addr: impl Into>) -> Self { + self.addr = ConnectAddrs::from(addr.into()); + self + } + + /// Set list of addresses. + pub fn set_addrs(mut self, addrs: I) -> Self + where + I: IntoIterator, + { + let mut addrs = VecDeque::from_iter(addrs); + self.addr = if addrs.len() < 2 { + ConnectAddrs::from(addrs.pop_front()) + } else { + ConnectAddrs::Multi(addrs) + }; + self + } + + /// Set local address to connection with. + /// + /// Useful in situations where the IP address bound to a particular network interface is known. + /// This would make sure the socket is opened through that interface. + pub fn set_local_addr(mut self, addr: impl Into) -> Self { + self.local_addr = Some(addr.into()); + self + } + + /// Returns a reference to the connection request. + pub fn request(&self) -> &R { + &self.request + } + + /// Returns request hostname. + pub fn hostname(&self) -> &str { + self.request.hostname() + } + + /// Returns request port. + pub fn port(&self) -> u16 { + self.request.port().unwrap_or(self.port) + } + + /// Get borrowed iterator of resolved request addresses. + /// + /// # Examples + /// ``` + /// # use std::net::SocketAddr; + /// # use actix_tls::connect::ConnectInfo; + /// let addr = SocketAddr::from(([127, 0, 0, 1], 4242)); + /// + /// let conn = ConnectInfo::new("localhost"); + /// let mut addrs = conn.addrs(); + /// assert!(addrs.next().is_none()); + /// + /// let conn = ConnectInfo::with_addr("localhost", addr); + /// let mut addrs = conn.addrs(); + /// assert_eq!(addrs.next().unwrap(), addr); + /// ``` + pub fn addrs( + &self, + ) -> impl Iterator + + ExactSizeIterator + + iter::FusedIterator + + Clone + + fmt::Debug + + '_ { + match self.addr { + ConnectAddrs::None => ConnectAddrsIter::None, + ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), + ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), + } + } + + /// Take owned iterator resolved request addresses. + /// + /// # Examples + /// ``` + /// # use std::net::SocketAddr; + /// # use actix_tls::connect::ConnectInfo; + /// let addr = SocketAddr::from(([127, 0, 0, 1], 4242)); + /// + /// let mut conn = ConnectInfo::new("localhost"); + /// let mut addrs = conn.take_addrs(); + /// assert!(addrs.next().is_none()); + /// + /// let mut conn = ConnectInfo::with_addr("localhost", addr); + /// let mut addrs = conn.take_addrs(); + /// assert_eq!(addrs.next().unwrap(), addr); + /// ``` + pub fn take_addrs( + &mut self, + ) -> impl Iterator + + ExactSizeIterator + + iter::FusedIterator + + Clone + + fmt::Debug + + 'static { + match mem::take(&mut self.addr) { + ConnectAddrs::None => ConnectAddrsIter::None, + ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), + ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), + } + } +} + +impl From for ConnectInfo { + fn from(addr: R) -> Self { + ConnectInfo::new(addr) + } +} + +impl fmt::Display for ConnectInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}:{}", self.hostname(), self.port()) + } +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use super::*; + + #[test] + fn test_addr_iter_multi() { + let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); + let unspecified = SocketAddr::from((IpAddr::from(Ipv4Addr::UNSPECIFIED), 8080)); + + let mut addrs = VecDeque::new(); + addrs.push_back(localhost); + addrs.push_back(unspecified); + + let mut iter = ConnectAddrsIter::Multi(addrs.iter()); + assert_eq!(iter.next(), Some(localhost)); + assert_eq!(iter.next(), Some(unspecified)); + assert_eq!(iter.next(), None); + + let mut iter = ConnectAddrsIter::MultiOwned(addrs.into_iter()); + assert_eq!(iter.next(), Some(localhost)); + assert_eq!(iter.next(), Some(unspecified)); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_addr_iter_single() { + let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); + + let mut iter = ConnectAddrsIter::One(localhost); + assert_eq!(iter.next(), Some(localhost)); + assert_eq!(iter.next(), None); + + let mut iter = ConnectAddrsIter::None; + assert_eq!(iter.next(), None); + } + + #[test] + fn test_local_addr() { + let conn = ConnectInfo::new("hello").set_local_addr([127, 0, 0, 1]); + assert_eq!( + conn.local_addr.unwrap(), + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) + ) + } + + #[test] + fn request_ref() { + let conn = ConnectInfo::new("hello"); + assert_eq!(conn.request(), &"hello") + } + + #[test] + fn set_connect_addr_into_option() { + let addr = SocketAddr::from(([127, 0, 0, 1], 4242)); + + let conn = ConnectInfo::new("hello").set_addr(None); + let mut addrs = conn.addrs(); + assert!(addrs.next().is_none()); + + let conn = ConnectInfo::new("hello").set_addr(addr); + let mut addrs = conn.addrs(); + assert_eq!(addrs.next().unwrap(), addr); + + let conn = ConnectInfo::new("hello").set_addr(Some(addr)); + let mut addrs = conn.addrs(); + assert_eq!(addrs.next().unwrap(), addr); + } +} diff --git a/actix-tls/src/connect/mod.rs b/actix-tls/src/connect/mod.rs index c2aee9c7..3511dd58 100644 --- a/actix-tls/src/connect/mod.rs +++ b/actix-tls/src/connect/mod.rs @@ -1,35 +1,46 @@ //! TCP and TLS connector services. //! //! # Stages of the TCP connector service: -//! - Resolve [`Address`] with given [`Resolver`] and collect list of socket addresses. -//! - Establish TCP connection and return [`TcpStream`]. +//! 1. Resolve [`Host`] (if needed) with given [`Resolver`] and collect list of socket addresses. +//! 1. Establish TCP connection and return [`TcpStream`]. //! //! # Stages of TLS connector services: -//! - Establish [`TcpStream`] with connector service. -//! - Wrap the stream and perform connect handshake with remote peer. -//! - Return certain stream type that impls `AsyncRead` and `AsyncWrite`. +//! 1. Resolve DNS and establish a [`TcpStream`] with the TCP connector service. +//! 1. Wrap the stream and perform connect handshake with remote peer. +//! 1. Return wrapped stream type that implements `AsyncRead` and `AsyncWrite`. //! //! [`TcpStream`]: actix_rt::net::TcpStream -#[allow(clippy::module_inception)] -mod connect; +mod connect_addrs; +mod connection; mod connector; mod error; +mod host; +mod info; mod resolve; -mod service; -pub mod tls; -// TODO: remove `ssl` mod re-export in next break change -#[doc(hidden)] -pub use tls as ssl; -mod tcp; +mod resolver; +pub mod tcp; + #[cfg(feature = "uri")] +#[cfg_attr(docsrs, doc(cfg(feature = "uri")))] mod uri; -pub use self::connect::{Address, Connect, Connection}; -pub use self::connector::{TcpConnector, TcpConnectorFactory}; +#[cfg(feature = "openssl")] +#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] +pub mod openssl; + +#[cfg(feature = "rustls")] +#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] +pub mod rustls; + +#[cfg(feature = "native-tls")] +#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] +pub mod native_tls; + +pub use self::connection::Connection; +pub use self::connector::{Connector, ConnectorService}; pub use self::error::ConnectError; -pub use self::resolve::{Resolve, Resolver, ResolverFactory}; -pub use self::service::{ConnectService, ConnectServiceFactory}; -pub use self::tcp::{ - default_connector, default_connector_factory, new_connector, new_connector_factory, -}; +pub use self::host::Host; +pub use self::info::ConnectInfo; +pub use self::resolve::Resolve; +pub use self::resolver::{Resolver, ResolverService}; diff --git a/actix-tls/src/connect/native_tls.rs b/actix-tls/src/connect/native_tls.rs new file mode 100644 index 00000000..eba89cb2 --- /dev/null +++ b/actix-tls/src/connect/native_tls.rs @@ -0,0 +1,90 @@ +//! Native-TLS based connector service. +//! +//! See [`TlsConnector`] for main connector service factory docs. + +use std::io; + +use actix_rt::net::ActixStream; +use actix_service::{Service, ServiceFactory}; +use actix_utils::future::{ok, Ready}; +use futures_core::future::LocalBoxFuture; +use log::trace; +use tokio_native_tls::{ + native_tls::TlsConnector as NativeTlsConnector, TlsConnector as TokioNativeTlsConnector, + TlsStream, +}; + +use crate::connect::{Connection, Host}; + +pub mod reexports { + //! Re-exports from `native-tls` that are useful for connectors. + + pub use tokio_native_tls::native_tls::TlsConnector; +} + +/// Connector service and factory using `native-tls`. +#[derive(Clone)] +pub struct TlsConnector { + connector: TokioNativeTlsConnector, +} + +impl TlsConnector { + /// Constructs new connector service from a `native-tls` connector. + /// + /// This type is it's own service factory, so it can be used in that setting, too. + pub fn new(connector: NativeTlsConnector) -> Self { + Self { + connector: TokioNativeTlsConnector::from(connector), + } + } +} + +impl ServiceFactory> for TlsConnector +where + IO: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Config = (); + type Service = Self; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(self.clone()) + } +} + +/// The `native-tls` connector is both it's ServiceFactory and Service impl type. +/// As the factory and service share the same type and state. +impl Service> for TlsConnector +where + R: Host, + IO: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::always_ready!(); + + fn call(&self, stream: Connection) -> Self::Future { + let (io, stream) = stream.replace_io(()); + let connector = self.connector.clone(); + + Box::pin(async move { + trace!("SSL Handshake start for: {:?}", stream.hostname()); + connector + .connect(stream.hostname(), io) + .await + .map(|res| { + trace!("SSL Handshake success: {:?}", stream.hostname()); + stream.replace_io(res).1 + }) + .map_err(|e| { + trace!("SSL Handshake error: {:?}", e); + io::Error::new(io::ErrorKind::Other, format!("{}", e)) + }) + }) + } +} diff --git a/actix-tls/src/connect/openssl.rs b/actix-tls/src/connect/openssl.rs new file mode 100644 index 00000000..3db37284 --- /dev/null +++ b/actix-tls/src/connect/openssl.rs @@ -0,0 +1,146 @@ +//! OpenSSL based connector service. +//! +//! See [`TlsConnector`] for main connector service factory docs. + +use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_rt::net::ActixStream; +use actix_service::{Service, ServiceFactory}; +use actix_utils::future::{ok, Ready}; +use futures_core::ready; +use log::trace; +use openssl::ssl::SslConnector; +use tokio_openssl::SslStream; + +use crate::connect::{Connection, Host}; + +pub mod reexports { + //! Re-exports from `openssl` that are useful for connectors. + + pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; +} + +/// Connector service factory using `openssl`. +pub struct TlsConnector { + connector: SslConnector, +} + +impl TlsConnector { + /// Constructs new connector service factory from an `openssl` connector. + pub fn new(connector: SslConnector) -> Self { + TlsConnector { connector } + } + + /// Constructs new connector service from an `openssl` connector. + pub fn service(connector: SslConnector) -> TlsConnectorService { + TlsConnectorService { connector } + } +} + +impl Clone for TlsConnector { + fn clone(&self) -> Self { + Self { + connector: self.connector.clone(), + } + } +} + +impl ServiceFactory> for TlsConnector +where + R: Host, + IO: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Config = (); + type Service = TlsConnectorService; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(TlsConnectorService { + connector: self.connector.clone(), + }) + } +} + +/// Connector service using `openssl`. +pub struct TlsConnectorService { + connector: SslConnector, +} + +impl Clone for TlsConnectorService { + fn clone(&self) -> Self { + Self { + connector: self.connector.clone(), + } + } +} + +impl Service> for TlsConnectorService +where + R: Host, + IO: ActixStream, +{ + type Response = Connection>; + type Error = io::Error; + type Future = ConnectFut; + + actix_service::always_ready!(); + + fn call(&self, stream: Connection) -> Self::Future { + trace!("SSL Handshake start for: {:?}", stream.hostname()); + let (io, stream) = stream.replace_io(()); + let host = stream.hostname(); + + let config = self + .connector + .configure() + .expect("SSL connect configuration was invalid."); + + let ssl = config + .into_ssl(host) + .expect("SSL connect configuration was invalid."); + + ConnectFut { + io: Some(SslStream::new(ssl, io).unwrap()), + stream: Some(stream), + } + } +} + +/// Connect future for OpenSSL service. +#[doc(hidden)] +pub struct ConnectFut { + io: Option>, + stream: Option>, +} + +impl Future for ConnectFut +where + R: Host, + IO: ActixStream, +{ + type Output = Result>, io::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match ready!(Pin::new(this.io.as_mut().unwrap()).poll_connect(cx)) { + Ok(_) => { + let stream = this.stream.take().unwrap(); + trace!("SSL Handshake success: {:?}", stream.hostname()); + Poll::Ready(Ok(stream.replace_io(this.io.take().unwrap()).1)) + } + Err(e) => { + trace!("SSL Handshake error: {:?}", e); + Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) + } + } + } +} diff --git a/actix-tls/src/connect/resolve.rs b/actix-tls/src/connect/resolve.rs old mode 100755 new mode 100644 index 335e69d8..33e2c676 --- a/actix-tls/src/connect/resolve.rs +++ b/actix-tls/src/connect/resolve.rs @@ -1,61 +1,12 @@ -use std::{ - future::Future, - io, - net::SocketAddr, - pin::Pin, - rc::Rc, - task::{Context, Poll}, - vec::IntoIter, -}; +//! The [`Resolve`] trait. -use actix_rt::task::{spawn_blocking, JoinHandle}; -use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; -use log::trace; +use std::{error::Error as StdError, net::SocketAddr}; -use super::connect::{Address, Connect}; -use super::error::ConnectError; +use futures_core::future::LocalBoxFuture; -/// DNS Resolver Service Factory -#[derive(Clone)] -pub struct ResolverFactory { - resolver: Resolver, -} - -impl ResolverFactory { - pub fn new(resolver: Resolver) -> Self { - Self { resolver } - } - - pub fn service(&self) -> Resolver { - self.resolver.clone() - } -} - -impl ServiceFactory> for ResolverFactory { - type Response = Connect; - type Error = ConnectError; - type Config = (); - type Service = Resolver; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let service = self.resolver.clone(); - Box::pin(async { Ok(service) }) - } -} - -/// DNS Resolver Service -#[derive(Clone)] -pub enum Resolver { - Default, - Custom(Rc), -} - -/// An interface for custom async DNS resolvers. +/// Custom async DNS resolvers. /// -/// # Usage +/// # Examples /// ``` /// use std::net::SocketAddr; /// @@ -89,155 +40,23 @@ pub enum Resolver { /// } /// } /// -/// let resolver = MyResolver { +/// let my_resolver = MyResolver { /// trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), /// }; /// -/// // construct custom resolver -/// let resolver = Resolver::new_custom(resolver); -/// -/// // pass custom resolver to connector builder. -/// // connector would then be usable as a service or awc's connector. -/// let connector = actix_tls::connect::new_connector::<&str>(resolver.clone()); +/// // wrap custom resolver +/// let resolver = Resolver::custom(my_resolver); /// /// // resolver can be passed to connector factory where returned service factory -/// // can be used to construct new connector services. -/// let factory = actix_tls::connect::new_connector_factory::<&str>(resolver); +/// // can be used to construct new connector services for use in clients +/// let factory = actix_tls::connect::Connector::new(resolver); +/// let connector = factory.service(); /// ``` pub trait Resolve { + /// Given DNS lookup information, returns a future that completes with socket information. fn lookup<'a>( &'a self, host: &'a str, port: u16, - ) -> LocalBoxFuture<'a, Result, Box>>; -} - -impl Resolver { - /// Constructor for custom Resolve trait object and use it as resolver. - pub fn new_custom(resolver: impl Resolve + 'static) -> Self { - Self::Custom(Rc::new(resolver)) - } - - // look up with default resolver variant. - fn look_up(req: &Connect) -> JoinHandle>> { - let host = req.hostname(); - // TODO: Connect should always return host with port if possible. - let host = if req - .hostname() - .splitn(2, ':') - .last() - .and_then(|p| p.parse::().ok()) - .map(|p| p == req.port()) - .unwrap_or(false) - { - host.to_string() - } else { - format!("{}:{}", host, req.port()) - }; - - // run blocking DNS lookup in thread pool - spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host)) - } -} - -impl Service> for Resolver { - type Response = Connect; - type Error = ConnectError; - type Future = ResolverFuture; - - actix_service::always_ready!(); - - fn call(&self, req: Connect) -> Self::Future { - if req.addr.is_some() { - ResolverFuture::Connected(Some(req)) - } else if let Ok(ip) = req.hostname().parse() { - let addr = SocketAddr::new(ip, req.port()); - let req = req.set_addr(Some(addr)); - ResolverFuture::Connected(Some(req)) - } else { - trace!("DNS resolver: resolving host {:?}", req.hostname()); - - match self { - Self::Default => { - let fut = Self::look_up(&req); - ResolverFuture::LookUp(fut, Some(req)) - } - - Self::Custom(resolver) => { - let resolver = Rc::clone(resolver); - ResolverFuture::LookupCustom(Box::pin(async move { - let addrs = resolver - .lookup(req.hostname(), req.port()) - .await - .map_err(ConnectError::Resolver)?; - - let req = req.set_addrs(addrs); - - if req.addr.is_none() { - Err(ConnectError::NoRecords) - } else { - Ok(req) - } - })) - } - } - } - } -} - -pub enum ResolverFuture { - Connected(Option>), - LookUp( - JoinHandle>>, - Option>, - ), - LookupCustom(LocalBoxFuture<'static, Result, ConnectError>>), -} - -impl Future for ResolverFuture { - type Output = Result, ConnectError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.get_mut() { - Self::Connected(conn) => Poll::Ready(Ok(conn - .take() - .expect("ResolverFuture polled after finished"))), - - Self::LookUp(fut, req) => { - let res = match ready!(Pin::new(fut).poll(cx)) { - Ok(Ok(res)) => Ok(res), - Ok(Err(e)) => Err(ConnectError::Resolver(Box::new(e))), - Err(e) => Err(ConnectError::Io(e.into())), - }; - - let req = req.take().unwrap(); - - let addrs = res.map_err(|err| { - trace!( - "DNS resolver: failed to resolve host {:?} err: {:?}", - req.hostname(), - err - ); - - err - })?; - - let req = req.set_addrs(addrs); - - trace!( - "DNS resolver: host {:?} resolved to {:?}", - req.hostname(), - req.addrs() - ); - - if req.addr.is_none() { - Poll::Ready(Err(ConnectError::NoRecords)) - } else { - Poll::Ready(Ok(req)) - } - } - - Self::LookupCustom(fut) => fut.as_mut().poll(cx), - } - } + ) -> LocalBoxFuture<'a, Result, Box>>; } diff --git a/actix-tls/src/connect/resolver.rs b/actix-tls/src/connect/resolver.rs new file mode 100644 index 00000000..8e700deb --- /dev/null +++ b/actix-tls/src/connect/resolver.rs @@ -0,0 +1,201 @@ +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, + vec::IntoIter, +}; + +use actix_rt::task::{spawn_blocking, JoinHandle}; +use actix_service::{Service, ServiceFactory}; +use actix_utils::future::{ok, Ready}; +use futures_core::{future::LocalBoxFuture, ready}; +use log::trace; + +use super::{ConnectError, ConnectInfo, Host, Resolve}; + +/// DNS resolver service factory. +#[derive(Clone, Default)] +pub struct Resolver { + resolver: ResolverService, +} + +impl Resolver { + /// Constructs a new resolver factory with a custom resolver. + pub fn custom(resolver: impl Resolve + 'static) -> Self { + Self { + resolver: ResolverService::custom(resolver), + } + } + + /// Returns a new resolver service. + pub fn service(&self) -> ResolverService { + self.resolver.clone() + } +} + +impl ServiceFactory> for Resolver { + type Response = ConnectInfo; + type Error = ConnectError; + type Config = (); + type Service = ResolverService; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(self.resolver.clone()) + } +} + +#[derive(Clone)] +enum ResolverKind { + /// Built-in DNS resolver. + /// + /// See [`std::net::ToSocketAddrs`] trait. + Default, + + /// Custom, user-provided DNS resolver. + Custom(Rc), +} + +impl Default for ResolverKind { + fn default() -> Self { + Self::Default + } +} + +/// DNS resolver service. +#[derive(Clone, Default)] +pub struct ResolverService { + kind: ResolverKind, +} + +impl ResolverService { + /// Constructor for custom Resolve trait object and use it as resolver. + pub fn custom(resolver: impl Resolve + 'static) -> Self { + Self { + kind: ResolverKind::Custom(Rc::new(resolver)), + } + } + + /// Resolve DNS with default resolver. + fn default_lookup( + req: &ConnectInfo, + ) -> JoinHandle>> { + // reconstruct host; concatenate hostname and port together + let host = format!("{}:{}", req.hostname(), req.port()); + + // run blocking DNS lookup in thread pool since DNS lookups can take upwards of seconds on + // some platforms if conditions are poor and OS-level cache is not populated + spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host)) + } +} + +impl Service> for ResolverService { + type Response = ConnectInfo; + type Error = ConnectError; + type Future = ResolverFut; + + actix_service::always_ready!(); + + fn call(&self, req: ConnectInfo) -> Self::Future { + if req.addr.is_resolved() { + // socket address(es) already resolved; return existing connection request + ResolverFut::Resolved(Some(req)) + } else if let Ok(ip) = req.hostname().parse() { + // request hostname is valid ip address; add address to request and return + let addr = SocketAddr::new(ip, req.port()); + let req = req.set_addr(Some(addr)); + ResolverFut::Resolved(Some(req)) + } else { + trace!("DNS resolver: resolving host {:?}", req.hostname()); + + match &self.kind { + ResolverKind::Default => { + let fut = Self::default_lookup(&req); + ResolverFut::LookUp(fut, Some(req)) + } + + ResolverKind::Custom(resolver) => { + let resolver = Rc::clone(resolver); + + ResolverFut::LookupCustom(Box::pin(async move { + let addrs = resolver + .lookup(req.hostname(), req.port()) + .await + .map_err(ConnectError::Resolver)?; + + let req = req.set_addrs(addrs); + + if req.addr.is_unresolved() { + Err(ConnectError::NoRecords) + } else { + Ok(req) + } + })) + } + } + } + } +} + +/// Future for resolver service. +#[doc(hidden)] +pub enum ResolverFut { + Resolved(Option>), + LookUp( + JoinHandle>>, + Option>, + ), + LookupCustom(LocalBoxFuture<'static, Result, ConnectError>>), +} + +impl Future for ResolverFut { + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + Self::Resolved(conn) => Poll::Ready(Ok(conn + .take() + .expect("ResolverFuture polled after finished"))), + + Self::LookUp(fut, req) => { + let res = match ready!(Pin::new(fut).poll(cx)) { + Ok(Ok(res)) => Ok(res), + Ok(Err(e)) => Err(ConnectError::Resolver(Box::new(e))), + Err(e) => Err(ConnectError::Io(e.into())), + }; + + let req = req.take().unwrap(); + + let addrs = res.map_err(|err| { + trace!( + "DNS resolver: failed to resolve host {:?} err: {:?}", + req.hostname(), + err + ); + + err + })?; + + let req = req.set_addrs(addrs); + + trace!( + "DNS resolver: host {:?} resolved to {:?}", + req.hostname(), + req.addrs() + ); + + if req.addr.is_unresolved() { + Poll::Ready(Err(ConnectError::NoRecords)) + } else { + Poll::Ready(Ok(req)) + } + } + + Self::LookupCustom(fut) => fut.as_mut().poll(cx), + } + } +} diff --git a/actix-tls/src/connect/rustls.rs b/actix-tls/src/connect/rustls.rs new file mode 100644 index 00000000..a98ae04e --- /dev/null +++ b/actix-tls/src/connect/rustls.rs @@ -0,0 +1,148 @@ +//! Rustls based connector service. +//! +//! See [`TlsConnector`] for main connector service factory docs. + +use std::{ + convert::TryFrom, + future::Future, + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use actix_rt::net::ActixStream; +use actix_service::{Service, ServiceFactory}; +use actix_utils::future::{ok, Ready}; +use futures_core::ready; +use log::trace; +use tokio_rustls::rustls::{client::ServerName, OwnedTrustAnchor, RootCertStore}; +use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; +use tokio_rustls::{Connect as RustlsConnect, TlsConnector as RustlsTlsConnector}; +use webpki_roots::TLS_SERVER_ROOTS; + +use crate::connect::{Connection, Host}; + +pub mod reexports { + //! Re-exports from `rustls` and `webpki_roots` that are useful for connectors. + + pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; + + pub use webpki_roots::TLS_SERVER_ROOTS; +} + +/// Returns standard root certificates from `webpki-roots` crate as a rustls certificate store. +pub fn webpki_roots_cert_store() -> RootCertStore { + let mut root_certs = RootCertStore::empty(); + for cert in TLS_SERVER_ROOTS.0 { + let cert = OwnedTrustAnchor::from_subject_spki_name_constraints( + cert.subject, + cert.spki, + cert.name_constraints, + ); + let certs = vec![cert].into_iter(); + root_certs.add_server_trust_anchors(certs); + } + root_certs +} + +/// Connector service factory using `rustls`. +#[derive(Clone)] +pub struct TlsConnector { + connector: Arc, +} + +impl TlsConnector { + /// Constructs new connector service factory from a `rustls` client configuration. + pub fn new(connector: Arc) -> Self { + TlsConnector { connector } + } + + /// Constructs new connector service from a `rustls` client configuration. + pub fn service(connector: Arc) -> TlsConnectorService { + TlsConnectorService { connector } + } +} + +impl ServiceFactory> for TlsConnector +where + R: Host, + IO: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Config = (); + type Service = TlsConnectorService; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(TlsConnectorService { + connector: self.connector.clone(), + }) + } +} + +/// Connector service using `rustls`. +#[derive(Clone)] +pub struct TlsConnectorService { + connector: Arc, +} + +impl Service> for TlsConnectorService +where + R: Host, + IO: ActixStream, +{ + type Response = Connection>; + type Error = io::Error; + type Future = ConnectFut; + + actix_service::always_ready!(); + + fn call(&self, connection: Connection) -> Self::Future { + trace!("SSL Handshake start for: {:?}", connection.hostname()); + let (stream, connection) = connection.replace_io(()); + + match ServerName::try_from(connection.hostname()) { + Ok(host) => ConnectFut::Future { + connect: RustlsTlsConnector::from(self.connector.clone()).connect(host, stream), + connection: Some(connection), + }, + Err(_) => ConnectFut::InvalidDns, + } + } +} + +/// Connect future for Rustls service. +#[doc(hidden)] +pub enum ConnectFut { + /// See issue + InvalidDns, + Future { + connect: RustlsConnect, + connection: Option>, + }, +} + +impl Future for ConnectFut +where + R: Host, + IO: ActixStream, +{ + type Output = Result>, io::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + Self::InvalidDns => Poll::Ready(Err( + io::Error::new(io::ErrorKind::Other, "rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54") + )), + Self::Future { connect, connection } => { + let stream = ready!(Pin::new(connect).poll(cx))?; + let connection = connection.take().unwrap(); + trace!("SSL Handshake success: {:?}", connection.hostname()); + Poll::Ready(Ok(connection.replace_io(stream).1)) + } + } + } +} diff --git a/actix-tls/src/connect/service.rs b/actix-tls/src/connect/service.rs deleted file mode 100755 index 9961498e..00000000 --- a/actix-tls/src/connect/service.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use actix_rt::net::TcpStream; -use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; - -use super::connect::{Address, Connect, Connection}; -use super::connector::{TcpConnector, TcpConnectorFactory}; -use super::error::ConnectError; -use super::resolve::{Resolver, ResolverFactory}; - -pub struct ConnectServiceFactory { - tcp: TcpConnectorFactory, - resolver: ResolverFactory, -} - -impl ConnectServiceFactory { - /// Construct new ConnectService factory - pub fn new(resolver: Resolver) -> Self { - ConnectServiceFactory { - tcp: TcpConnectorFactory, - resolver: ResolverFactory::new(resolver), - } - } - - /// Construct new service - pub fn service(&self) -> ConnectService { - ConnectService { - tcp: self.tcp.service(), - resolver: self.resolver.service(), - } - } -} - -impl Clone for ConnectServiceFactory { - fn clone(&self) -> Self { - ConnectServiceFactory { - tcp: self.tcp, - resolver: self.resolver.clone(), - } - } -} - -impl ServiceFactory> for ConnectServiceFactory { - type Response = Connection; - type Error = ConnectError; - type Config = (); - type Service = ConnectService; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let service = self.service(); - Box::pin(async { Ok(service) }) - } -} - -#[derive(Clone)] -pub struct ConnectService { - tcp: TcpConnector, - resolver: Resolver, -} - -impl Service> for ConnectService { - type Response = Connection; - type Error = ConnectError; - type Future = ConnectServiceResponse; - - actix_service::always_ready!(); - - fn call(&self, req: Connect) -> Self::Future { - ConnectServiceResponse { - fut: ConnectFuture::Resolve(self.resolver.call(req)), - tcp: self.tcp, - } - } -} - -// helper enum to generic over futures of resolve and connect phase. -pub(crate) enum ConnectFuture { - Resolve(>>::Future), - Connect(>>::Future), -} - -// helper enum to contain the future output of ConnectFuture -pub(crate) enum ConnectOutput { - Resolved(Connect), - Connected(Connection), -} - -impl ConnectFuture { - fn poll_connect( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, ConnectError>> { - match self { - ConnectFuture::Resolve(ref mut fut) => { - Pin::new(fut).poll(cx).map_ok(ConnectOutput::Resolved) - } - ConnectFuture::Connect(ref mut fut) => { - Pin::new(fut).poll(cx).map_ok(ConnectOutput::Connected) - } - } - } -} - -pub struct ConnectServiceResponse { - fut: ConnectFuture, - tcp: TcpConnector, -} - -impl Future for ConnectServiceResponse { - type Output = Result, ConnectError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - match ready!(self.fut.poll_connect(cx))? { - ConnectOutput::Resolved(res) => { - self.fut = ConnectFuture::Connect(self.tcp.call(res)); - } - ConnectOutput::Connected(res) => return Poll::Ready(Ok(res)), - } - } - } -} diff --git a/actix-tls/src/connect/tcp.rs b/actix-tls/src/connect/tcp.rs index 57059c99..8f566da7 100644 --- a/actix-tls/src/connect/tcp.rs +++ b/actix-tls/src/connect/tcp.rs @@ -1,43 +1,202 @@ -use actix_rt::net::TcpStream; +//! TCP connector service. +//! +//! See [`TcpConnector`] for main connector service factory docs. + +use std::{ + collections::VecDeque, + future::Future, + io, + net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_rt::net::{TcpSocket, TcpStream}; use actix_service::{Service, ServiceFactory}; +use actix_utils::future::{ok, Ready}; +use futures_core::ready; +use log::{error, trace}; +use tokio_util::sync::ReusableBoxFuture; -use super::{Address, Connect, ConnectError, ConnectServiceFactory, Connection, Resolver}; +use super::{connect_addrs::ConnectAddrs, error::ConnectError, ConnectInfo, Connection, Host}; -/// Create TCP connector service. -pub fn new_connector( - resolver: Resolver, -) -> impl Service, Response = Connection, Error = ConnectError> + Clone -{ - ConnectServiceFactory::new(resolver).service() +/// TCP connector service factory. +#[derive(Debug, Copy, Clone)] +pub struct TcpConnector; + +impl TcpConnector { + /// Returns a new TCP connector service. + pub fn service(&self) -> TcpConnectorService { + TcpConnectorService + } } -/// Create TCP connector service factory. -pub fn new_connector_factory( - resolver: Resolver, -) -> impl ServiceFactory< - Connect, - Config = (), - Response = Connection, - Error = ConnectError, - InitError = (), -> + Clone { - ConnectServiceFactory::new(resolver) +impl ServiceFactory> for TcpConnector { + type Response = Connection; + type Error = ConnectError; + type Config = (); + type Service = TcpConnectorService; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(self.service()) + } } -/// Create TCP connector service with default parameters. -pub fn default_connector( -) -> impl Service, Response = Connection, Error = ConnectError> + Clone -{ - new_connector(Resolver::Default) +/// TCP connector service. +#[derive(Debug, Copy, Clone)] +pub struct TcpConnectorService; + +impl Service> for TcpConnectorService { + type Response = Connection; + type Error = ConnectError; + type Future = TcpConnectorFut; + + actix_service::always_ready!(); + + fn call(&self, req: ConnectInfo) -> Self::Future { + let port = req.port(); + + let ConnectInfo { + request: req, + addr, + local_addr, + .. + } = req; + + TcpConnectorFut::new(req, port, local_addr, addr) + } } -/// Create TCP connector service factory with default parameters. -pub fn default_connector_factory() -> impl ServiceFactory< - Connect, - Config = (), - Response = Connection, - Error = ConnectError, - InitError = (), -> + Clone { - new_connector_factory(Resolver::Default) +/// Connect future for TCP service. +#[doc(hidden)] +pub enum TcpConnectorFut { + Response { + req: Option, + port: u16, + local_addr: Option, + addrs: Option>, + stream: ReusableBoxFuture>, + }, + + Error(Option), +} + +impl TcpConnectorFut { + pub(crate) fn new( + req: R, + port: u16, + local_addr: Option, + addr: ConnectAddrs, + ) -> TcpConnectorFut { + if addr.is_unresolved() { + error!("TCP connector: unresolved connection address"); + return TcpConnectorFut::Error(Some(ConnectError::Unresolved)); + } + + trace!( + "TCP connector: connecting to {} on port {}", + req.hostname(), + port + ); + + match addr { + ConnectAddrs::None => unreachable!("none variant already checked"), + + ConnectAddrs::One(addr) => TcpConnectorFut::Response { + req: Some(req), + port, + local_addr, + addrs: None, + stream: ReusableBoxFuture::new(connect(addr, local_addr)), + }, + + // when resolver returns multiple socket addr for request they would be popped from + // front end of queue and returns with the first successful tcp connection. + ConnectAddrs::Multi(mut addrs) => { + let addr = addrs.pop_front().unwrap(); + + TcpConnectorFut::Response { + req: Some(req), + port, + local_addr, + addrs: Some(addrs), + stream: ReusableBoxFuture::new(connect(addr, local_addr)), + } + } + } + } +} + +impl Future for TcpConnectorFut { + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + TcpConnectorFut::Error(err) => Poll::Ready(Err(err.take().unwrap())), + + TcpConnectorFut::Response { + req, + port, + local_addr, + addrs, + stream, + } => loop { + match ready!(stream.poll(cx)) { + Ok(sock) => { + let req = req.take().unwrap(); + + trace!( + "TCP connector: successfully connected to {:?} - {:?}", + req.hostname(), + sock.peer_addr() + ); + + return Poll::Ready(Ok(Connection::new(req, sock))); + } + + Err(err) => { + trace!( + "TCP connector: failed to connect to {:?} port: {}", + req.as_ref().unwrap().hostname(), + port, + ); + + if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) { + stream.set(connect(addr, *local_addr)); + } else { + return Poll::Ready(Err(ConnectError::Io(err))); + } + } + } + }, + } + } +} + +async fn connect(addr: SocketAddr, local_addr: Option) -> io::Result { + // use local addr if connect asks for it + match local_addr { + Some(ip_addr) => { + let socket = match ip_addr { + IpAddr::V4(ip_addr) => { + let socket = TcpSocket::new_v4()?; + let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0)); + socket.bind(addr)?; + socket + } + IpAddr::V6(ip_addr) => { + let socket = TcpSocket::new_v6()?; + let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0)); + socket.bind(addr)?; + socket + } + }; + + socket.connect(addr).await + } + + None => TcpStream::connect(addr).await, + } } diff --git a/actix-tls/src/connect/tls/mod.rs b/actix-tls/src/connect/tls/mod.rs deleted file mode 100644 index 7f48d06c..00000000 --- a/actix-tls/src/connect/tls/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! TLS Services - -#[cfg(feature = "openssl")] -pub mod openssl; - -#[cfg(feature = "rustls")] -pub mod rustls; - -#[cfg(feature = "native-tls")] -pub mod native_tls; diff --git a/actix-tls/src/connect/tls/native_tls.rs b/actix-tls/src/connect/tls/native_tls.rs deleted file mode 100644 index de08ea2a..00000000 --- a/actix-tls/src/connect/tls/native_tls.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::io; - -use actix_rt::net::ActixStream; -use actix_service::{Service, ServiceFactory}; -use futures_core::future::LocalBoxFuture; -use log::trace; -use tokio_native_tls::{TlsConnector as TokioNativetlsConnector, TlsStream}; - -pub use tokio_native_tls::native_tls::TlsConnector; - -use crate::connect::{Address, Connection}; - -/// Native-tls connector factory and service -pub struct NativetlsConnector { - connector: TokioNativetlsConnector, -} - -impl NativetlsConnector { - pub fn new(connector: TlsConnector) -> Self { - Self { - connector: TokioNativetlsConnector::from(connector), - } - } -} - -impl NativetlsConnector { - pub fn service(connector: TlsConnector) -> Self { - Self::new(connector) - } -} - -impl Clone for NativetlsConnector { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl ServiceFactory> for NativetlsConnector -where - U: ActixStream + 'static, -{ - type Response = Connection>; - type Error = io::Error; - type Config = (); - type Service = Self; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let connector = self.clone(); - Box::pin(async { Ok(connector) }) - } -} - -// NativetlsConnector is both it's ServiceFactory and Service impl type. -// As the factory and service share the same type and state. -impl Service> for NativetlsConnector -where - T: Address, - U: ActixStream + 'static, -{ - type Response = Connection>; - type Error = io::Error; - type Future = LocalBoxFuture<'static, Result>; - - actix_service::always_ready!(); - - fn call(&self, stream: Connection) -> Self::Future { - let (io, stream) = stream.replace_io(()); - let connector = self.connector.clone(); - Box::pin(async move { - trace!("SSL Handshake start for: {:?}", stream.host()); - connector - .connect(stream.host(), io) - .await - .map(|res| { - trace!("SSL Handshake success: {:?}", stream.host()); - stream.replace_io(res).1 - }) - .map_err(|e| { - trace!("SSL Handshake error: {:?}", e); - io::Error::new(io::ErrorKind::Other, format!("{}", e)) - }) - }) - } -} diff --git a/actix-tls/src/connect/tls/openssl.rs b/actix-tls/src/connect/tls/openssl.rs deleted file mode 100755 index b4298fed..00000000 --- a/actix-tls/src/connect/tls/openssl.rs +++ /dev/null @@ -1,130 +0,0 @@ -use std::{ - future::Future, - io, - pin::Pin, - task::{Context, Poll}, -}; - -use actix_rt::net::ActixStream; -use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; -use log::trace; - -pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; -pub use tokio_openssl::SslStream; - -use crate::connect::{Address, Connection}; - -/// OpenSSL connector factory -pub struct OpensslConnector { - connector: SslConnector, -} - -impl OpensslConnector { - pub fn new(connector: SslConnector) -> Self { - OpensslConnector { connector } - } - - pub fn service(connector: SslConnector) -> OpensslConnectorService { - OpensslConnectorService { connector } - } -} - -impl Clone for OpensslConnector { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl ServiceFactory> for OpensslConnector -where - T: Address, - U: ActixStream + 'static, -{ - type Response = Connection>; - type Error = io::Error; - type Config = (); - type Service = OpensslConnectorService; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let connector = self.connector.clone(); - Box::pin(async { Ok(OpensslConnectorService { connector }) }) - } -} - -pub struct OpensslConnectorService { - connector: SslConnector, -} - -impl Clone for OpensslConnectorService { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl Service> for OpensslConnectorService -where - T: Address, - U: ActixStream, -{ - type Response = Connection>; - type Error = io::Error; - type Future = ConnectAsyncExt; - - actix_service::always_ready!(); - - fn call(&self, stream: Connection) -> Self::Future { - trace!("SSL Handshake start for: {:?}", stream.host()); - let (io, stream) = stream.replace_io(()); - let host = stream.host(); - - let config = self - .connector - .configure() - .expect("SSL connect configuration was invalid."); - - let ssl = config - .into_ssl(host) - .expect("SSL connect configuration was invalid."); - - ConnectAsyncExt { - io: Some(SslStream::new(ssl, io).unwrap()), - stream: Some(stream), - } - } -} - -pub struct ConnectAsyncExt { - io: Option>, - stream: Option>, -} - -impl Future for ConnectAsyncExt -where - T: Address, - U: ActixStream, -{ - type Output = Result>, io::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - match ready!(Pin::new(this.io.as_mut().unwrap()).poll_connect(cx)) { - Ok(_) => { - let stream = this.stream.take().unwrap(); - trace!("SSL Handshake success: {:?}", stream.host()); - Poll::Ready(Ok(stream.replace_io(this.io.take().unwrap()).1)) - } - Err(e) => { - trace!("SSL Handshake error: {:?}", e); - Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) - } - } - } -} diff --git a/actix-tls/src/connect/tls/rustls.rs b/actix-tls/src/connect/tls/rustls.rs deleted file mode 100755 index 139aadbe..00000000 --- a/actix-tls/src/connect/tls/rustls.rs +++ /dev/null @@ -1,146 +0,0 @@ -use std::{ - convert::TryFrom, - future::Future, - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; -pub use webpki_roots::TLS_SERVER_ROOTS; - -use actix_rt::net::ActixStream; -use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; -use log::trace; -use tokio_rustls::rustls::{client::ServerName, OwnedTrustAnchor, RootCertStore}; -use tokio_rustls::{Connect, TlsConnector}; - -use crate::connect::{Address, Connection}; - -/// Returns standard root certificates from `webpki-roots` crate as a rustls certificate store. -pub fn webpki_roots_cert_store() -> RootCertStore { - let mut root_certs = RootCertStore::empty(); - for cert in TLS_SERVER_ROOTS.0 { - let cert = OwnedTrustAnchor::from_subject_spki_name_constraints( - cert.subject, - cert.spki, - cert.name_constraints, - ); - let certs = vec![cert].into_iter(); - root_certs.add_server_trust_anchors(certs); - } - root_certs -} - -/// Rustls connector factory -pub struct RustlsConnector { - connector: Arc, -} - -impl RustlsConnector { - pub fn new(connector: Arc) -> Self { - RustlsConnector { connector } - } -} - -impl RustlsConnector { - pub fn service(connector: Arc) -> RustlsConnectorService { - RustlsConnectorService { connector } - } -} - -impl Clone for RustlsConnector { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl ServiceFactory> for RustlsConnector -where - T: Address, - U: ActixStream + 'static, -{ - type Response = Connection>; - type Error = io::Error; - type Config = (); - type Service = RustlsConnectorService; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let connector = self.connector.clone(); - Box::pin(async { Ok(RustlsConnectorService { connector }) }) - } -} - -pub struct RustlsConnectorService { - connector: Arc, -} - -impl Clone for RustlsConnectorService { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl Service> for RustlsConnectorService -where - T: Address, - U: ActixStream, -{ - type Response = Connection>; - type Error = io::Error; - type Future = RustlsConnectorServiceFuture; - - actix_service::always_ready!(); - - fn call(&self, connection: Connection) -> Self::Future { - trace!("SSL Handshake start for: {:?}", connection.host()); - let (stream, connection) = connection.replace_io(()); - - match ServerName::try_from(connection.host()) { - Ok(host) => RustlsConnectorServiceFuture::Future { - connect: TlsConnector::from(self.connector.clone()).connect(host, stream), - connection: Some(connection), - }, - Err(_) => RustlsConnectorServiceFuture::InvalidDns, - } - } -} - -pub enum RustlsConnectorServiceFuture { - /// See issue - InvalidDns, - Future { - connect: Connect, - connection: Option>, - }, -} - -impl Future for RustlsConnectorServiceFuture -where - T: Address, - U: ActixStream, -{ - type Output = Result>, io::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.get_mut() { - Self::InvalidDns => Poll::Ready(Err( - io::Error::new(io::ErrorKind::Other, "rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54") - )), - Self::Future { connect, connection } => { - let stream = ready!(Pin::new(connect).poll(cx))?; - let connection = connection.take().unwrap(); - trace!("SSL Handshake success: {:?}", connection.host()); - Poll::Ready(Ok(connection.replace_io(stream).1)) - } - } - } -} diff --git a/actix-tls/src/connect/uri.rs b/actix-tls/src/connect/uri.rs index 2d54b618..b1c7f0fe 100644 --- a/actix-tls/src/connect/uri.rs +++ b/actix-tls/src/connect/uri.rs @@ -1,8 +1,8 @@ use http::Uri; -use super::Address; +use super::Host; -impl Address for Uri { +impl Host for Uri { fn hostname(&self) -> &str { self.host().unwrap_or("") } @@ -35,9 +35,18 @@ fn scheme_to_port(scheme: Option<&str>) -> Option { Some("mqtts") => Some(8883), // File Transfer Protocol (FTP) - Some("ftp") => Some(1883), + Some("ftp") => Some(21), Some("ftps") => Some(990), + // Redis + Some("redis") => Some(6379), + + // MySQL + Some("mysql") => Some(3306), + + // PostgreSQL + Some("postgres") => Some(5432), + _ => None, } } diff --git a/actix-tls/src/lib.rs b/actix-tls/src/lib.rs index dbda8834..68ca0e35 100644 --- a/actix-tls/src/lib.rs +++ b/actix-tls/src/lib.rs @@ -1,14 +1,19 @@ -//! TLS acceptor and connector services for Actix ecosystem +//! TLS acceptor and connector services for the Actix ecosystem. #![deny(rust_2018_idioms, nonstandard_style)] +#![warn(missing_docs)] #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] +#![cfg_attr(docsrs, feature(doc_cfg))] #[cfg(feature = "openssl")] #[allow(unused_extern_crates)] extern crate tls_openssl as openssl; #[cfg(feature = "accept")] +#[cfg_attr(docsrs, doc(cfg(feature = "accept")))] pub mod accept; + #[cfg(feature = "connect")] +#[cfg_attr(docsrs, doc(cfg(feature = "connect")))] pub mod connect; diff --git a/actix-tls/tests/accept-rustls.rs b/actix-tls/tests/accept-rustls.rs index a083ebba..2c922a68 100644 --- a/actix-tls/tests/accept-rustls.rs +++ b/actix-tls/tests/accept-rustls.rs @@ -7,13 +7,15 @@ feature = "openssl" ))] +extern crate tls_openssl as openssl; + use std::io::{BufReader, Write}; use actix_rt::net::TcpStream; use actix_server::TestServer; use actix_service::ServiceFactoryExt as _; use actix_tls::accept::rustls::{Acceptor, TlsStream}; -use actix_tls::connect::tls::openssl::SslConnector; +use actix_tls::connect::openssl::reexports::SslConnector; use actix_utils::future::ok; use rustls_pemfile::{certs, pkcs8_private_keys}; use tls_openssl::ssl::SslVerifyMode; @@ -53,13 +55,13 @@ fn rustls_server_config(cert: String, key: String) -> rustls::ServerConfig { } fn openssl_connector(cert: String, key: String) -> SslConnector { - use actix_tls::connect::tls::openssl::{SslConnector as OpensslConnector, SslMethod}; - use tls_openssl::{pkey::PKey, x509::X509}; + use actix_tls::connect::openssl::reexports::SslMethod; + use openssl::{pkey::PKey, x509::X509}; let cert = X509::from_pem(cert.as_bytes()).unwrap(); let key = PKey::private_key_from_pem(key.as_bytes()).unwrap(); - let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); + let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap(); ssl.set_verify(SslVerifyMode::NONE); ssl.set_certificate(&cert).unwrap(); ssl.set_private_key(&key).unwrap(); diff --git a/actix-tls/tests/test_connect.rs b/actix-tls/tests/test_connect.rs old mode 100755 new mode 100644 index 564151ce..d3373c90 --- a/actix-tls/tests/test_connect.rs +++ b/actix-tls/tests/test_connect.rs @@ -12,7 +12,7 @@ use actix_service::{fn_service, Service, ServiceFactory}; use bytes::Bytes; use futures_util::sink::SinkExt; -use actix_tls::connect::{self as actix_connect, Connect}; +use actix_tls::connect::{ConnectError, ConnectInfo, Connection, Connector, Host}; #[cfg(feature = "openssl")] #[actix_rt::test] @@ -25,9 +25,9 @@ async fn test_string() { }) }); - let conn = actix_connect::default_connector(); + let connector = Connector::default().service(); let addr = format!("localhost:{}", srv.port()); - let con = conn.call(addr.into()).await.unwrap(); + let con = connector.call(addr.into()).await.unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); } @@ -42,7 +42,7 @@ async fn test_rustls_string() { }) }); - let conn = actix_connect::default_connector(); + let conn = Connector::default().service(); let addr = format!("localhost:{}", srv.port()); let con = conn.call(addr.into()).await.unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); @@ -58,23 +58,29 @@ async fn test_static_str() { }) }); - let conn = actix_connect::default_connector(); + let info = ConnectInfo::with_addr("10", srv.addr()); + let connector = Connector::default().service(); + let conn = connector.call(info).await.unwrap(); + assert_eq!(conn.peer_addr().unwrap(), srv.addr()); - let con = conn - .call(Connect::with_addr("10", srv.addr())) - .await - .unwrap(); - assert_eq!(con.peer_addr().unwrap(), srv.addr()); - - let connect = Connect::new(srv.host().to_owned()); - - let conn = actix_connect::default_connector(); - let con = conn.call(connect).await; - assert!(con.is_err()); + let info = ConnectInfo::new(srv.host().to_owned()); + let connector = Connector::default().service(); + let conn = connector.call(info).await; + assert!(conn.is_err()); } #[actix_rt::test] -async fn test_new_service() { +async fn service_factory() { + pub fn default_connector_factory() -> impl ServiceFactory< + ConnectInfo, + Config = (), + Response = Connection, + Error = ConnectError, + InitError = (), + > { + Connector::default() + } + let srv = TestServer::with(|| { fn_service(|io: TcpStream| async { let mut framed = Framed::new(io, BytesCodec); @@ -83,14 +89,11 @@ async fn test_new_service() { }) }); - let factory = actix_connect::default_connector_factory(); - - let conn = factory.new_service(()).await.unwrap(); - let con = conn - .call(Connect::with_addr("10", srv.addr())) - .await - .unwrap(); - assert_eq!(con.peer_addr().unwrap(), srv.addr()); + let info = ConnectInfo::with_addr("10", srv.addr()); + let factory = default_connector_factory(); + let connector = factory.new_service(()).await.unwrap(); + let con = connector.call(info).await; + assert_eq!(con.unwrap().peer_addr().unwrap(), srv.addr()); } #[cfg(all(feature = "openssl", feature = "uri"))] @@ -106,9 +109,9 @@ async fn test_openssl_uri() { }) }); - let conn = actix_connect::default_connector(); + let connector = Connector::default().service(); let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap(); - let con = conn.call(addr.into()).await.unwrap(); + let con = connector.call(addr.into()).await.unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); } @@ -125,7 +128,7 @@ async fn test_rustls_uri() { }) }); - let conn = actix_connect::default_connector(); + let conn = Connector::default().service(); let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap(); let con = conn.call(addr.into()).await.unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); @@ -141,11 +144,11 @@ async fn test_local_addr() { }) }); - let conn = actix_connect::default_connector(); + let conn = Connector::default().service(); let local = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3)); let (con, _) = conn - .call(Connect::with_addr("10", srv.addr()).set_local_addr(local)) + .call(ConnectInfo::with_addr("10", srv.addr()).set_local_addr(local)) .await .unwrap() .into_parts(); diff --git a/actix-tls/tests/test_resolvers.rs b/actix-tls/tests/test_resolvers.rs index 40ee21fa..987b229c 100644 --- a/actix-tls/tests/test_resolvers.rs +++ b/actix-tls/tests/test_resolvers.rs @@ -10,7 +10,9 @@ use actix_server::TestServer; use actix_service::{fn_service, Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; -use actix_tls::connect::{new_connector_factory, Connect, Resolve, Resolver}; +use actix_tls::connect::{ + ConnectError, ConnectInfo, Connection, Connector, Host, Resolve, Resolver, +}; #[actix_rt::test] async fn custom_resolver() { @@ -36,6 +38,18 @@ async fn custom_resolver() { #[actix_rt::test] async fn custom_resolver_connect() { + pub fn connector_factory( + resolver: Resolver, + ) -> impl ServiceFactory< + ConnectInfo, + Config = (), + Response = Connection, + Error = ConnectError, + InitError = (), + > { + Connector::new(resolver) + } + use trust_dns_resolver::TokioAsyncResolver; let srv = @@ -68,12 +82,11 @@ async fn custom_resolver_connect() { trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), }; - let resolver = Resolver::new_custom(resolver); - let factory = new_connector_factory(resolver); + let factory = connector_factory(Resolver::custom(resolver)); let conn = factory.new_service(()).await.unwrap(); let con = conn - .call(Connect::with_addr("example.com", srv.addr())) + .call(ConnectInfo::with_addr("example.com", srv.addr())) .await .unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr());