1
0
mirror of https://github.com/fafhrd91/actix-net synced 2024-11-27 19:12:56 +01:00

actix-tls release candidate prep (#422)

This commit is contained in:
Rob Ede 2021-11-29 23:53:06 +00:00 committed by GitHub
parent 5556afd524
commit 5dc2bfcb01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1608 additions and 1456 deletions

View File

@ -1,6 +1,33 @@
# Changes # Changes
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
### Added
* Derive `Debug` for `connect::Connection`. [#422]
* Implement `Display` for `accept::TlsError`. [#422]
* Implement `Error` for `accept::TlsError` where both types also implement `Error`. [#422]
* Implement `Default` for `connect::Resolver`. [#422]
* Implement `Error` for `connect::ConnectError`. [#422]
### Changed
* There are now no default features. [#422]
* Useful re-exports from underlying TLS crates are exposed in a `reexports` modules in all acceptors and connectors.
* Convert `connect::ResolverService` from enum to struct. [#422]
* Make `ConnectAddrsIter` private. [#422]
* Rename `accept::native_tls::{NativeTlsAcceptorService => AcceptorService}`. [#422]
* Rename `connect::{Address => Host}` trait. [#422]
* Rename method `connect::Connection::{host => hostname}`. [#422]
* Rename struct `connect::{Connect => ConnectInfo}`. [#422]
* Rename struct `connect::{ConnectService => ConnectorService}`. [#422]
* Rename struct `connect::{ConnectServiceFactory => Connector}`. [#422]
* Rename TLS acceptor service future types and hide from docs. [#422]
* Unbox some service futures types. [#422]
### Removed
* Remove `connect::{new_connector, new_connector_factory, default_connector, default_connector_factory}` methods. [#422]
* Remove `connect::native_tls::Connector::service` method. [#422]
* Remove redundant `connect::Connection::from_parts` method. [#422]
[#422]: https://github.com/actix/actix-net/pull/422
## 3.0.0-beta.9 - 2021-11-22 ## 3.0.0-beta.9 - 2021-11-22
@ -41,8 +68,7 @@
* Remove `connect::ssl::openssl::OpensslConnectService`. [#297] * Remove `connect::ssl::openssl::OpensslConnectService`. [#297]
* Add `connect::ssl::native_tls` module for native tls support. [#295] * Add `connect::ssl::native_tls` module for native tls support. [#295]
* Rename `accept::{nativetls => native_tls}`. [#295] * Rename `accept::{nativetls => native_tls}`. [#295]
* Remove `connect::TcpConnectService` type. service caller expect a `TcpStream` should use * Remove `connect::TcpConnectService` type. Service caller expecting a `TcpStream` should use `connect::ConnectService` instead and call `Connection<T, TcpStream>::into_parts`. [#299]
`connect::ConnectService` instead and call `Connection<T, TcpStream>::into_parts`. [#299]
[#295]: https://github.com/actix/actix-net/pull/295 [#295]: https://github.com/actix/actix-net/pull/295
[#296]: https://github.com/actix/actix-net/pull/296 [#296]: https://github.com/actix/actix-net/pull/296

View File

@ -13,14 +13,15 @@ license = "MIT OR Apache-2.0"
edition = "2018" edition = "2018"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["openssl", "rustls", "native-tls", "accept", "connect", "uri"] all-features = true
rustdoc-args = ["--cfg", "docsrs"]
[lib] [lib]
name = "actix_tls" name = "actix_tls"
path = "src/lib.rs" path = "src/lib.rs"
[features] [features]
default = ["accept", "connect", "uri"] default = []
# enable acceptor services # enable acceptor services
accept = [] accept = []
@ -48,11 +49,13 @@ actix-utils = "3.0.0"
derive_more = "0.99.5" derive_more = "0.99.5"
futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] }
http = { version = "0.2.3", optional = true }
log = "0.4" log = "0.4"
pin-project-lite = "0.2.7" pin-project-lite = "0.2.7"
tokio-util = { version = "0.6.3", default-features = false } tokio-util = { version = "0.6.3", default-features = false }
# uri
http = { version = "0.2.3", optional = true }
# openssl # openssl
tls-openssl = { package = "openssl", version = "0.10.9", optional = true } tls-openssl = { package = "openssl", version = "0.10.9", optional = true }
tokio-openssl = { version = "0.6", optional = true } tokio-openssl = { version = "0.6", optional = true }

View File

@ -1,4 +1,4 @@
//! TLS acceptor services. //! TLS connection acceptor services.
use std::{ use std::{
convert::Infallible, convert::Infallible,
@ -6,14 +6,18 @@ use std::{
}; };
use actix_utils::counter::Counter; use actix_utils::counter::Counter;
use derive_more::{Display, Error};
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
pub mod openssl; pub mod openssl;
#[cfg(feature = "rustls")] #[cfg(feature = "rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
pub mod rustls; pub mod rustls;
#[cfg(feature = "native-tls")] #[cfg(feature = "native-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))]
pub mod native_tls; pub mod native_tls;
pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256); pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256);
@ -41,15 +45,18 @@ pub fn max_concurrent_tls_connect(num: usize) {
/// All TLS acceptors from this crate will return the `SvcErr` type parameter as [`Infallible`], /// All TLS acceptors from this crate will return the `SvcErr` type parameter as [`Infallible`],
/// which can be cast to your own service type, inferred or otherwise, /// which can be cast to your own service type, inferred or otherwise,
/// using [`into_service_error`](Self::into_service_error). /// using [`into_service_error`](Self::into_service_error).
#[derive(Debug)] #[derive(Debug, Display, Error)]
pub enum TlsError<TlsErr, SvcErr> { pub enum TlsError<TlsErr, SvcErr> {
/// TLS handshake has timed-out. /// TLS handshake has timed-out.
#[display(fmt = "TLS handshake has timed-out")]
Timeout, Timeout,
/// Wraps TLS service errors. /// Wraps TLS service errors.
#[display(fmt = "TLS handshake error")]
Tls(TlsErr), Tls(TlsErr),
/// Wraps inner service errors. /// Wraps service errors.
#[display(fmt = "Service error")]
Service(SvcErr), Service(SvcErr),
} }

View File

@ -1,7 +1,10 @@
//! `native-tls` based TLS connection acceptor service.
//!
//! See [`Acceptor`] for main service factory docs.
use std::{ use std::{
convert::Infallible, convert::Infallible,
io::{self, IoSlice}, io::{self, IoSlice},
ops::{Deref, DerefMut},
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
@ -13,37 +16,27 @@ use actix_rt::{
time::timeout, time::timeout,
}; };
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::counter::Counter; use actix_utils::{
counter::Counter,
future::{ready, Ready as FutReady},
};
use derive_more::{Deref, DerefMut, From};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use tokio_native_tls::{native_tls::Error, TlsAcceptor};
pub use tokio_native_tls::{native_tls::Error, TlsAcceptor};
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
/// Wrapper type for `tokio_native_tls::TlsStream` in order to impl `ActixStream` trait. pub mod reexports {
pub struct TlsStream<T>(tokio_native_tls::TlsStream<T>); //! Re-exports from `native-tls` that are useful for acceptors.
impl<T> From<tokio_native_tls::TlsStream<T>> for TlsStream<T> { pub use tokio_native_tls::{native_tls::Error, TlsAcceptor};
fn from(stream: tokio_native_tls::TlsStream<T>) -> Self {
Self(stream)
}
} }
impl<T: ActixStream> Deref for TlsStream<T> { /// Wraps a `native-tls` based async TLS stream in order to implement [`ActixStream`].
type Target = tokio_native_tls::TlsStream<T>; #[derive(Deref, DerefMut, From)]
pub struct TlsStream<IO>(tokio_native_tls::TlsStream<IO>);
fn deref(&self) -> &Self::Target { impl<IO: ActixStream> AsyncRead for TlsStream<IO> {
&self.0
}
}
impl<T: ActixStream> DerefMut for TlsStream<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T: ActixStream> AsyncRead for TlsStream<T> {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@ -53,7 +46,7 @@ impl<T: ActixStream> AsyncRead for TlsStream<T> {
} }
} }
impl<T: ActixStream> AsyncWrite for TlsStream<T> { impl<IO: ActixStream> AsyncWrite for TlsStream<IO> {
fn poll_write( fn poll_write(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@ -83,27 +76,24 @@ impl<T: ActixStream> AsyncWrite for TlsStream<T> {
} }
} }
impl<T: ActixStream> ActixStream for TlsStream<T> { impl<IO: ActixStream> ActixStream for TlsStream<IO> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
T::poll_read_ready((&**self).get_ref().get_ref().get_ref(), cx) IO::poll_read_ready((&**self).get_ref().get_ref().get_ref(), cx)
} }
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
T::poll_write_ready((&**self).get_ref().get_ref().get_ref(), cx) IO::poll_write_ready((&**self).get_ref().get_ref().get_ref(), cx)
} }
} }
/// Accept TLS connections via `native-tls` package. /// Accept TLS connections via the `native-tls` crate.
///
/// `native-tls` feature enables this `Acceptor` type.
pub struct Acceptor { pub struct Acceptor {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
handshake_timeout: Duration, handshake_timeout: Duration,
} }
impl Acceptor { impl Acceptor {
/// Create `native-tls` based `Acceptor` service factory. /// Constructs `native-tls` based `Acceptor` service factory.
#[inline]
pub fn new(acceptor: TlsAcceptor) -> Self { pub fn new(acceptor: TlsAcceptor) -> Self {
Acceptor { Acceptor {
acceptor, acceptor,
@ -130,35 +120,36 @@ impl Clone for Acceptor {
} }
} }
impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor { impl<IO: ActixStream + 'static> ServiceFactory<IO> for Acceptor {
type Response = TlsStream<T>; type Response = TlsStream<IO>;
type Error = TlsError<Error, Infallible>; type Error = TlsError<Error, Infallible>;
type Config = (); type Config = ();
type Service = NativeTlsAcceptorService; type Service = AcceptorService;
type InitError = (); type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>; type Future = FutReady<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
let res = MAX_CONN_COUNTER.with(|conns| { let res = MAX_CONN_COUNTER.with(|conns| {
Ok(NativeTlsAcceptorService { Ok(AcceptorService {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
conns: conns.clone(), conns: conns.clone(),
handshake_timeout: self.handshake_timeout, handshake_timeout: self.handshake_timeout,
}) })
}); });
Box::pin(async { res }) ready(res)
} }
} }
pub struct NativeTlsAcceptorService { /// Native-TLS based acceptor service.
pub struct AcceptorService {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
conns: Counter, conns: Counter,
handshake_timeout: Duration, handshake_timeout: Duration,
} }
impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService { impl<IO: ActixStream + 'static> Service<IO> for AcceptorService {
type Response = TlsStream<T>; type Response = TlsStream<IO>;
type Error = TlsError<Error, Infallible>; type Error = TlsError<Error, Infallible>;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
@ -170,7 +161,7 @@ impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService {
} }
} }
fn call(&self, io: T) -> Self::Future { fn call(&self, io: IO) -> Self::Future {
let guard = self.conns.get(); let guard = self.conns.get();
let acceptor = self.acceptor.clone(); let acceptor = self.acceptor.clone();

View File

@ -1,8 +1,11 @@
//! `openssl` based TLS acceptor service.
//!
//! See [`Acceptor`] for main service factory docs.
use std::{ use std::{
convert::Infallible, convert::Infallible,
future::Future, future::Future,
io::{self, IoSlice}, io::{self, IoSlice},
ops::{Deref, DerefMut},
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
@ -14,40 +17,29 @@ use actix_rt::{
time::{sleep, Sleep}, time::{sleep, Sleep},
}; };
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::counter::{Counter, CounterGuard}; use actix_utils::{
use futures_core::future::LocalBoxFuture; counter::{Counter, CounterGuard},
future::{ready, Ready as FutReady},
pub use openssl::ssl::{
AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder,
}; };
use derive_more::{Deref, DerefMut, From};
use openssl::ssl::{Error, Ssl, SslAcceptor};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait. pub mod reexports {
pub struct TlsStream<T>(tokio_openssl::SslStream<T>); //! Re-exports from `openssl` that are useful for acceptors.
impl<T> From<tokio_openssl::SslStream<T>> for TlsStream<T> { pub use openssl::ssl::{
fn from(stream: tokio_openssl::SslStream<T>) -> Self { AlpnError, Error, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder,
Self(stream) };
}
} }
impl<T> Deref for TlsStream<T> { /// Wraps an `openssl` based async TLS stream in order to implement [`ActixStream`].
type Target = tokio_openssl::SslStream<T>; #[derive(Deref, DerefMut, From)]
pub struct TlsStream<IO>(tokio_openssl::SslStream<IO>);
fn deref(&self) -> &Self::Target { impl<IO: ActixStream> AsyncRead for TlsStream<IO> {
&self.0
}
}
impl<T> DerefMut for TlsStream<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T: ActixStream> AsyncRead for TlsStream<T> {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@ -57,7 +49,7 @@ impl<T: ActixStream> AsyncRead for TlsStream<T> {
} }
} }
impl<T: ActixStream> AsyncWrite for TlsStream<T> { impl<IO: ActixStream> AsyncWrite for TlsStream<IO> {
fn poll_write( fn poll_write(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@ -87,19 +79,17 @@ impl<T: ActixStream> AsyncWrite for TlsStream<T> {
} }
} }
impl<T: ActixStream> ActixStream for TlsStream<T> { impl<IO: ActixStream> ActixStream for TlsStream<IO> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
T::poll_read_ready((&**self).get_ref(), cx) IO::poll_read_ready((&**self).get_ref(), cx)
} }
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
T::poll_write_ready((&**self).get_ref(), cx) IO::poll_write_ready((&**self).get_ref(), cx)
} }
} }
/// Accept TLS connections via `openssl` package. /// Accept TLS connections via the `openssl` crate.
///
/// `openssl` feature enables this `Acceptor` type.
pub struct Acceptor { pub struct Acceptor {
acceptor: SslAcceptor, acceptor: SslAcceptor,
handshake_timeout: Duration, handshake_timeout: Duration,
@ -134,13 +124,13 @@ impl Clone for Acceptor {
} }
} }
impl<T: ActixStream> ServiceFactory<T> for Acceptor { impl<IO: ActixStream> ServiceFactory<IO> for Acceptor {
type Response = TlsStream<T>; type Response = TlsStream<IO>;
type Error = TlsError<SslError, Infallible>; type Error = TlsError<Error, Infallible>;
type Config = (); type Config = ();
type Service = AcceptorService; type Service = AcceptorService;
type InitError = (); type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>; type Future = FutReady<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
let res = MAX_CONN_COUNTER.with(|conns| { let res = MAX_CONN_COUNTER.with(|conns| {
@ -151,20 +141,21 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
}) })
}); });
Box::pin(async { res }) ready(res)
} }
} }
/// OpenSSL based acceptor service.
pub struct AcceptorService { pub struct AcceptorService {
acceptor: SslAcceptor, acceptor: SslAcceptor,
conns: Counter, conns: Counter,
handshake_timeout: Duration, handshake_timeout: Duration,
} }
impl<T: ActixStream> Service<T> for AcceptorService { impl<IO: ActixStream> Service<IO> for AcceptorService {
type Response = TlsStream<T>; type Response = TlsStream<IO>;
type Error = TlsError<SslError, Infallible>; type Error = TlsError<Error, Infallible>;
type Future = AcceptorServiceResponse<T>; type Future = AcceptFut<IO>;
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.conns.available(ctx) { if self.conns.available(ctx) {
@ -174,11 +165,11 @@ impl<T: ActixStream> Service<T> for AcceptorService {
} }
} }
fn call(&self, io: T) -> Self::Future { fn call(&self, io: IO) -> Self::Future {
let ssl_ctx = self.acceptor.context(); let ssl_ctx = self.acceptor.context();
let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid."); let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid.");
AcceptorServiceResponse { AcceptFut {
_guard: self.conns.get(), _guard: self.conns.get(),
timeout: sleep(self.handshake_timeout), timeout: sleep(self.handshake_timeout),
stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()), stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()),
@ -187,16 +178,18 @@ impl<T: ActixStream> Service<T> for AcceptorService {
} }
pin_project! { pin_project! {
pub struct AcceptorServiceResponse<T: ActixStream> { /// Accept future for OpenSSL service.
stream: Option<tokio_openssl::SslStream<T>>, #[doc(hidden)]
pub struct AcceptFut<IO: ActixStream> {
stream: Option<tokio_openssl::SslStream<IO>>,
#[pin] #[pin]
timeout: Sleep, timeout: Sleep,
_guard: CounterGuard, _guard: CounterGuard,
} }
} }
impl<T: ActixStream> Future for AcceptorServiceResponse<T> { impl<IO: ActixStream> Future for AcceptFut<IO> {
type Output = Result<TlsStream<T>, TlsError<SslError, Infallible>>; type Output = Result<TlsStream<IO>, TlsError<Error, Infallible>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();

View File

@ -1,8 +1,11 @@
//! `rustls` based TLS connection acceptor service.
//!
//! See [`Acceptor`] for main service factory docs.
use std::{ use std::{
convert::Infallible, convert::Infallible,
future::Future, future::Future,
io::{self, IoSlice}, io::{self, IoSlice},
ops::{Deref, DerefMut},
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
@ -15,39 +18,28 @@ use actix_rt::{
time::{sleep, Sleep}, time::{sleep, Sleep},
}; };
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::counter::{Counter, CounterGuard}; use actix_utils::{
use futures_core::future::LocalBoxFuture; counter::{Counter, CounterGuard},
future::{ready, Ready as FutReady},
};
use derive_more::{Deref, DerefMut, From};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::{Accept, TlsAcceptor}; use tokio_rustls::{Accept, TlsAcceptor};
pub use tokio_rustls::rustls::ServerConfig;
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait. pub mod reexports {
pub struct TlsStream<T>(tokio_rustls::server::TlsStream<T>); //! Re-exports from `rustls` that are useful for acceptors.
impl<T> From<tokio_rustls::server::TlsStream<T>> for TlsStream<T> { pub use tokio_rustls::rustls::ServerConfig;
fn from(stream: tokio_rustls::server::TlsStream<T>) -> Self {
Self(stream)
}
} }
impl<T> Deref for TlsStream<T> { /// Wraps a `rustls` based async TLS stream in order to implement [`ActixStream`].
type Target = tokio_rustls::server::TlsStream<T>; #[derive(Deref, DerefMut, From)]
pub struct TlsStream<IO>(tokio_rustls::server::TlsStream<IO>);
fn deref(&self) -> &Self::Target { impl<IO: ActixStream> AsyncRead for TlsStream<IO> {
&self.0
}
}
impl<T> DerefMut for TlsStream<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T: ActixStream> AsyncRead for TlsStream<T> {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@ -57,7 +49,7 @@ impl<T: ActixStream> AsyncRead for TlsStream<T> {
} }
} }
impl<T: ActixStream> AsyncWrite for TlsStream<T> { impl<IO: ActixStream> AsyncWrite for TlsStream<IO> {
fn poll_write( fn poll_write(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@ -87,27 +79,24 @@ impl<T: ActixStream> AsyncWrite for TlsStream<T> {
} }
} }
impl<T: ActixStream> ActixStream for TlsStream<T> { impl<IO: ActixStream> ActixStream for TlsStream<IO> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
T::poll_read_ready((&**self).get_ref().0, cx) IO::poll_read_ready((&**self).get_ref().0, cx)
} }
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
T::poll_write_ready((&**self).get_ref().0, cx) IO::poll_write_ready((&**self).get_ref().0, cx)
} }
} }
/// Accept TLS connections via `rustls` package. /// Accept TLS connections via the `rustls` crate.
///
/// `rustls` feature enables this `Acceptor` type.
pub struct Acceptor { pub struct Acceptor {
config: Arc<ServerConfig>, config: Arc<ServerConfig>,
handshake_timeout: Duration, handshake_timeout: Duration,
} }
impl Acceptor { impl Acceptor {
/// Create Rustls based `Acceptor` service factory. /// Constructs Rustls based acceptor service factory.
#[inline]
pub fn new(config: ServerConfig) -> Self { pub fn new(config: ServerConfig) -> Self {
Acceptor { Acceptor {
config: Arc::new(config), config: Arc::new(config),
@ -125,7 +114,6 @@ impl Acceptor {
} }
impl Clone for Acceptor { impl Clone for Acceptor {
#[inline]
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
config: self.config.clone(), config: self.config.clone(),
@ -134,13 +122,13 @@ impl Clone for Acceptor {
} }
} }
impl<T: ActixStream> ServiceFactory<T> for Acceptor { impl<IO: ActixStream> ServiceFactory<IO> for Acceptor {
type Response = TlsStream<T>; type Response = TlsStream<IO>;
type Error = TlsError<io::Error, Infallible>; type Error = TlsError<io::Error, Infallible>;
type Config = (); type Config = ();
type Service = AcceptorService; type Service = AcceptorService;
type InitError = (); type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>; type Future = FutReady<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
let res = MAX_CONN_COUNTER.with(|conns| { let res = MAX_CONN_COUNTER.with(|conns| {
@ -151,21 +139,21 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
}) })
}); });
Box::pin(async { res }) ready(res)
} }
} }
/// Rustls based `Acceptor` service /// Rustls based acceptor service.
pub struct AcceptorService { pub struct AcceptorService {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
conns: Counter, conns: Counter,
handshake_timeout: Duration, handshake_timeout: Duration,
} }
impl<T: ActixStream> Service<T> for AcceptorService { impl<IO: ActixStream> Service<IO> for AcceptorService {
type Response = TlsStream<T>; type Response = TlsStream<IO>;
type Error = TlsError<io::Error, Infallible>; type Error = TlsError<io::Error, Infallible>;
type Future = AcceptorServiceFut<T>; type Future = AcceptFut<IO>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.conns.available(cx) { if self.conns.available(cx) {
@ -175,8 +163,8 @@ impl<T: ActixStream> Service<T> for AcceptorService {
} }
} }
fn call(&self, req: T) -> Self::Future { fn call(&self, req: IO) -> Self::Future {
AcceptorServiceFut { AcceptFut {
fut: self.acceptor.accept(req), fut: self.acceptor.accept(req),
timeout: sleep(self.handshake_timeout), timeout: sleep(self.handshake_timeout),
_guard: self.conns.get(), _guard: self.conns.get(),
@ -185,16 +173,18 @@ impl<T: ActixStream> Service<T> for AcceptorService {
} }
pin_project! { pin_project! {
pub struct AcceptorServiceFut<T: ActixStream> { /// Accept future for Rustls service.
fut: Accept<T>, #[doc(hidden)]
pub struct AcceptFut<IO: ActixStream> {
fut: Accept<IO>,
#[pin] #[pin]
timeout: Sleep, timeout: Sleep,
_guard: CounterGuard, _guard: CounterGuard,
} }
} }
impl<T: ActixStream> Future for AcceptorServiceFut<T> { impl<IO: ActixStream> Future for AcceptFut<IO> {
type Output = Result<TlsStream<T>, TlsError<io::Error, Infallible>>; type Output = Result<TlsStream<IO>, TlsError<io::Error, Infallible>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project(); let mut this = self.project();

View File

@ -1,361 +0,0 @@
use std::{
collections::{vec_deque, VecDeque},
fmt,
iter::{self, FromIterator as _},
mem,
net::{IpAddr, SocketAddr},
};
/// Parse a host into parts (hostname and port).
pub trait Address: Unpin + 'static {
/// Get hostname part.
fn hostname(&self) -> &str;
/// Get optional port part.
fn port(&self) -> Option<u16> {
None
}
}
impl Address for String {
fn hostname(&self) -> &str {
self
}
}
impl Address for &'static str {
fn hostname(&self) -> &str {
self
}
}
#[derive(Debug, Eq, PartialEq, Hash)]
pub(crate) enum ConnectAddrs {
None,
One(SocketAddr),
Multi(VecDeque<SocketAddr>),
}
impl ConnectAddrs {
pub(crate) fn is_none(&self) -> bool {
matches!(self, Self::None)
}
pub(crate) fn is_some(&self) -> bool {
!self.is_none()
}
}
impl Default for ConnectAddrs {
fn default() -> Self {
Self::None
}
}
impl From<Option<SocketAddr>> for ConnectAddrs {
fn from(addr: Option<SocketAddr>) -> Self {
match addr {
Some(addr) => ConnectAddrs::One(addr),
None => ConnectAddrs::None,
}
}
}
/// Connection info.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Connect<R> {
pub(crate) req: R,
pub(crate) port: u16,
pub(crate) addr: ConnectAddrs,
pub(crate) local_addr: Option<IpAddr>,
}
impl<R: Address> Connect<R> {
/// Create `Connect` instance by splitting the string by ':' and convert the second part to u16
pub fn new(req: R) -> Connect<R> {
let (_, port) = parse_host(req.hostname());
Connect {
req,
port: port.unwrap_or(0),
addr: ConnectAddrs::None,
local_addr: None,
}
}
/// Create new `Connect` instance from host and address. Connector skips name resolution stage
/// for such connect messages.
pub fn with_addr(req: R, addr: SocketAddr) -> Connect<R> {
Connect {
req,
port: 0,
addr: ConnectAddrs::One(addr),
local_addr: None,
}
}
/// Use port if address does not provide one.
///
/// Default value is 0.
pub fn set_port(mut self, port: u16) -> Self {
self.port = port;
self
}
/// Set address.
pub fn set_addr(mut self, addr: Option<SocketAddr>) -> Self {
self.addr = ConnectAddrs::from(addr);
self
}
/// Set list of addresses.
pub fn set_addrs<I>(mut self, addrs: I) -> Self
where
I: IntoIterator<Item = SocketAddr>,
{
let mut addrs = VecDeque::from_iter(addrs);
self.addr = if addrs.len() < 2 {
ConnectAddrs::from(addrs.pop_front())
} else {
ConnectAddrs::Multi(addrs)
};
self
}
/// Set local_addr of connect.
pub fn set_local_addr(mut self, addr: impl Into<IpAddr>) -> Self {
self.local_addr = Some(addr.into());
self
}
/// Get hostname.
pub fn hostname(&self) -> &str {
self.req.hostname()
}
/// Get request port.
pub fn port(&self) -> u16 {
self.req.port().unwrap_or(self.port)
}
/// Get resolved request addresses.
pub fn addrs(&self) -> ConnectAddrsIter<'_> {
match self.addr {
ConnectAddrs::None => ConnectAddrsIter::None,
ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr),
ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()),
}
}
/// Take resolved request addresses.
pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> {
match mem::take(&mut self.addr) {
ConnectAddrs::None => ConnectAddrsIter::None,
ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr),
ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()),
}
}
/// Returns a reference to the connection request.
pub fn request(&self) -> &R {
&self.req
}
}
impl<R: Address> From<R> for Connect<R> {
fn from(addr: R) -> Self {
Connect::new(addr)
}
}
impl<R: Address> fmt::Display for Connect<R> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}", self.hostname(), self.port())
}
}
/// Iterator over addresses in a [`Connect`] request.
#[derive(Clone)]
pub enum ConnectAddrsIter<'a> {
None,
One(SocketAddr),
Multi(vec_deque::Iter<'a, SocketAddr>),
MultiOwned(vec_deque::IntoIter<SocketAddr>),
}
impl Iterator for ConnectAddrsIter<'_> {
type Item = SocketAddr;
fn next(&mut self) -> Option<Self::Item> {
match *self {
Self::None => None,
Self::One(addr) => {
*self = Self::None;
Some(addr)
}
Self::Multi(ref mut iter) => iter.next().copied(),
Self::MultiOwned(ref mut iter) => iter.next(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
match *self {
Self::None => (0, Some(0)),
Self::One(_) => (1, Some(1)),
Self::Multi(ref iter) => iter.size_hint(),
Self::MultiOwned(ref iter) => iter.size_hint(),
}
}
}
impl fmt::Debug for ConnectAddrsIter<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.clone()).finish()
}
}
impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {}
impl iter::FusedIterator for ConnectAddrsIter<'_> {}
pub struct Connection<T, U> {
io: U,
req: T,
}
impl<T, U> Connection<T, U> {
pub fn new(io: U, req: T) -> Self {
Self { io, req }
}
}
impl<T, U> Connection<T, U> {
/// Reconstruct from a parts.
pub fn from_parts(io: U, req: T) -> Self {
Self { io, req }
}
/// Deconstruct into a parts.
pub fn into_parts(self) -> (U, T) {
(self.io, self.req)
}
/// Replace inclosed object, return new Stream and old object
pub fn replace_io<Y>(self, io: Y) -> (U, Connection<T, Y>) {
(self.io, Connection { io, req: self.req })
}
/// Returns a shared reference to the underlying stream.
pub fn io_ref(&self) -> &U {
&self.io
}
/// Returns a mutable reference to the underlying stream.
pub fn io_mut(&mut self) -> &mut U {
&mut self.io
}
}
impl<T: Address, U> Connection<T, U> {
/// Get hostname.
pub fn host(&self) -> &str {
self.req.hostname()
}
}
impl<T, U> std::ops::Deref for Connection<T, U> {
type Target = U;
fn deref(&self) -> &U {
&self.io
}
}
impl<T, U> std::ops::DerefMut for Connection<T, U> {
fn deref_mut(&mut self) -> &mut U {
&mut self.io
}
}
impl<T, U: fmt::Debug> fmt::Debug for Connection<T, U> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Stream {{{:?}}}", self.io)
}
}
fn parse_host(host: &str) -> (&str, Option<u16>) {
let mut parts_iter = host.splitn(2, ':');
match parts_iter.next() {
Some(hostname) => {
let port_str = parts_iter.next().unwrap_or("");
let port = port_str.parse::<u16>().ok();
(hostname, port)
}
None => (host, None),
}
}
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use super::*;
#[test]
fn test_host_parser() {
assert_eq!(parse_host("example.com"), ("example.com", None));
assert_eq!(parse_host("example.com:8080"), ("example.com", Some(8080)));
assert_eq!(parse_host("example:8080"), ("example", Some(8080)));
assert_eq!(parse_host("example.com:false"), ("example.com", None));
assert_eq!(parse_host("example.com:false:false"), ("example.com", None));
}
#[test]
fn test_addr_iter_multi() {
let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080));
let unspecified = SocketAddr::from((IpAddr::from(Ipv4Addr::UNSPECIFIED), 8080));
let mut addrs = VecDeque::new();
addrs.push_back(localhost);
addrs.push_back(unspecified);
let mut iter = ConnectAddrsIter::Multi(addrs.iter());
assert_eq!(iter.next(), Some(localhost));
assert_eq!(iter.next(), Some(unspecified));
assert_eq!(iter.next(), None);
let mut iter = ConnectAddrsIter::MultiOwned(addrs.into_iter());
assert_eq!(iter.next(), Some(localhost));
assert_eq!(iter.next(), Some(unspecified));
assert_eq!(iter.next(), None);
}
#[test]
fn test_addr_iter_single() {
let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080));
let mut iter = ConnectAddrsIter::One(localhost);
assert_eq!(iter.next(), Some(localhost));
assert_eq!(iter.next(), None);
let mut iter = ConnectAddrsIter::None;
assert_eq!(iter.next(), None);
}
#[test]
fn test_local_addr() {
let conn = Connect::new("hello").set_local_addr([127, 0, 0, 1]);
assert_eq!(
conn.local_addr.unwrap(),
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
)
}
#[test]
fn request_ref() {
let conn = Connect::new("hello");
assert_eq!(conn.request(), &"hello")
}
}

View File

@ -0,0 +1,82 @@
use std::{
collections::{vec_deque, VecDeque},
fmt, iter,
net::SocketAddr,
};
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub(crate) enum ConnectAddrs {
None,
One(SocketAddr),
// TODO: consider using smallvec
Multi(VecDeque<SocketAddr>),
}
impl ConnectAddrs {
pub(crate) fn is_unresolved(&self) -> bool {
matches!(self, Self::None)
}
pub(crate) fn is_resolved(&self) -> bool {
!self.is_unresolved()
}
}
impl Default for ConnectAddrs {
fn default() -> Self {
Self::None
}
}
impl From<Option<SocketAddr>> for ConnectAddrs {
fn from(addr: Option<SocketAddr>) -> Self {
match addr {
Some(addr) => ConnectAddrs::One(addr),
None => ConnectAddrs::None,
}
}
}
/// Iterator over addresses in a [`Connect`] request.
#[derive(Clone)]
pub(crate) enum ConnectAddrsIter<'a> {
None,
One(SocketAddr),
Multi(vec_deque::Iter<'a, SocketAddr>),
MultiOwned(vec_deque::IntoIter<SocketAddr>),
}
impl Iterator for ConnectAddrsIter<'_> {
type Item = SocketAddr;
fn next(&mut self) -> Option<Self::Item> {
match *self {
Self::None => None,
Self::One(addr) => {
*self = Self::None;
Some(addr)
}
Self::Multi(ref mut iter) => iter.next().copied(),
Self::MultiOwned(ref mut iter) => iter.next(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
match *self {
Self::None => (0, Some(0)),
Self::One(_) => (1, Some(1)),
Self::Multi(ref iter) => iter.size_hint(),
Self::MultiOwned(ref iter) => iter.size_hint(),
}
}
}
impl fmt::Debug for ConnectAddrsIter<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.clone()).finish()
}
}
impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {}
impl iter::FusedIterator for ConnectAddrsIter<'_> {}

View File

@ -0,0 +1,54 @@
use derive_more::{Deref, DerefMut};
use super::Host;
/// Wraps underlying I/O and the connection request that initiated it.
#[derive(Debug, Deref, DerefMut)]
pub struct Connection<R, IO> {
pub(crate) req: R,
#[deref]
#[deref_mut]
pub(crate) io: IO,
}
impl<R, IO> Connection<R, IO> {
/// Construct new `Connection` from request and IO parts.
pub(crate) fn new(req: R, io: IO) -> Self {
Self { req, io }
}
}
impl<R, IO> Connection<R, IO> {
/// Deconstructs into IO and request parts.
pub fn into_parts(self) -> (IO, R) {
(self.io, self.req)
}
/// Replaces underlying IO, returning old IO and new `Connection`.
pub fn replace_io<IO2>(self, io: IO2) -> (IO, Connection<R, IO2>) {
(self.io, Connection { io, req: self.req })
}
/// Returns a shared reference to the underlying IO.
pub fn io_ref(&self) -> &IO {
&self.io
}
/// Returns a mutable reference to the underlying IO.
pub fn io_mut(&mut self) -> &mut IO {
&mut self.io
}
/// Returns a reference to the connection request.
pub fn request(&self) -> &R {
&self.req
}
}
impl<R: Host, IO> Connection<R, IO> {
/// Returns hostname.
pub fn hostname(&self) -> &str {
self.req.hostname()
}
}

235
actix-tls/src/connect/connector.rs Executable file → Normal file
View File

@ -1,194 +1,127 @@
use std::{ use std::{
collections::VecDeque,
future::Future, future::Future,
io,
net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6},
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use actix_rt::net::{TcpSocket, TcpStream}; use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready}; use actix_utils::future::{ok, Ready};
use log::{error, trace}; use futures_core::ready;
use tokio_util::sync::ReusableBoxFuture;
use super::connect::{Address, Connect, ConnectAddrs, Connection}; use super::{
use super::error::ConnectError; error::ConnectError,
resolver::{Resolver, ResolverService},
tcp::{TcpConnector, TcpConnectorService},
ConnectInfo, Connection, Host,
};
/// TCP connector service factory /// Combined resolver and TCP connector service factory.
#[derive(Debug, Copy, Clone)] ///
pub struct TcpConnectorFactory; /// Used to create [`ConnectorService`]s which receive connection information, resolve DNS if
/// required, and return a TCP stream.
#[derive(Clone, Default)]
pub struct Connector {
resolver: Resolver,
}
impl TcpConnectorFactory { impl Connector {
/// Create TCP connector service /// Constructs new connector factory with the given resolver.
pub fn service(&self) -> TcpConnector { pub fn new(resolver: Resolver) -> Self {
TcpConnector Connector { resolver }
}
/// Build connector service.
pub fn service(&self) -> ConnectorService {
ConnectorService {
tcp: TcpConnector.service(),
resolver: self.resolver.service(),
}
} }
} }
impl<T: Address> ServiceFactory<Connect<T>> for TcpConnectorFactory { impl<R: Host> ServiceFactory<ConnectInfo<R>> for Connector {
type Response = Connection<T, TcpStream>; type Response = Connection<R, TcpStream>;
type Error = ConnectError; type Error = ConnectError;
type Config = (); type Config = ();
type Service = TcpConnector; type Service = ConnectorService;
type InitError = (); type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
let service = self.service(); ok(self.service())
Box::pin(async move { Ok(service) })
} }
} }
/// TCP connector service /// Combined resolver and TCP connector service.
#[derive(Debug, Copy, Clone)] ///
pub struct TcpConnector; /// Service implementation receives connection information, resolves DNS if required, and returns
/// a TCP stream.
#[derive(Clone)]
pub struct ConnectorService {
tcp: TcpConnectorService,
resolver: ResolverService,
}
impl<T: Address> Service<Connect<T>> for TcpConnector { impl<R: Host> Service<ConnectInfo<R>> for ConnectorService {
type Response = Connection<T, TcpStream>; type Response = Connection<R, TcpStream>;
type Error = ConnectError; type Error = ConnectError;
type Future = TcpConnectorResponse<T>; type Future = ConnectServiceResponse<R>;
actix_service::always_ready!(); actix_service::always_ready!();
fn call(&self, req: Connect<T>) -> Self::Future { fn call(&self, req: ConnectInfo<R>) -> Self::Future {
let port = req.port(); ConnectServiceResponse {
let Connect { fut: ConnectFut::Resolve(self.resolver.call(req)),
req, tcp: self.tcp,
addr, }
local_addr,
..
} = req;
TcpConnectorResponse::new(req, port, local_addr, addr)
} }
} }
/// TCP stream connector response future /// Helper enum to generic over futures of resolve and connect steps.
pub enum TcpConnectorResponse<T> { pub(crate) enum ConnectFut<R: Host> {
Response { Resolve(<ResolverService as Service<ConnectInfo<R>>>::Future),
req: Option<T>, Connect(<TcpConnectorService as Service<ConnectInfo<R>>>::Future),
port: u16,
local_addr: Option<IpAddr>,
addrs: Option<VecDeque<SocketAddr>>,
stream: ReusableBoxFuture<Result<TcpStream, io::Error>>,
},
Error(Option<ConnectError>),
} }
impl<T: Address> TcpConnectorResponse<T> { /// Helper enum to contain the future output of `ConnectFuture`.
pub(crate) fn new( pub(crate) enum ConnectOutput<R: Host> {
req: T, Resolved(ConnectInfo<R>),
port: u16, Connected(Connection<R, TcpStream>),
local_addr: Option<IpAddr>, }
addr: ConnectAddrs,
) -> TcpConnectorResponse<T> {
if addr.is_none() {
error!("TCP connector: unresolved connection address");
return TcpConnectorResponse::Error(Some(ConnectError::Unresolved));
}
trace!( impl<R: Host> ConnectFut<R> {
"TCP connector: connecting to {} on port {}", fn poll_connect(
req.hostname(), &mut self,
port cx: &mut Context<'_>,
); ) -> Poll<Result<ConnectOutput<R>, ConnectError>> {
match self {
match addr { ConnectFut::Resolve(ref mut fut) => {
ConnectAddrs::None => unreachable!("none variant already checked"), Pin::new(fut).poll(cx).map_ok(ConnectOutput::Resolved)
}
ConnectAddrs::One(addr) => TcpConnectorResponse::Response { ConnectFut::Connect(ref mut fut) => {
req: Some(req), Pin::new(fut).poll(cx).map_ok(ConnectOutput::Connected)
port,
local_addr,
addrs: None,
stream: ReusableBoxFuture::new(connect(addr, local_addr)),
},
// when resolver returns multiple socket addr for request they would be popped from
// front end of queue and returns with the first successful tcp connection.
ConnectAddrs::Multi(mut addrs) => {
let addr = addrs.pop_front().unwrap();
TcpConnectorResponse::Response {
req: Some(req),
port,
local_addr,
addrs: Some(addrs),
stream: ReusableBoxFuture::new(connect(addr, local_addr)),
}
} }
} }
} }
} }
impl<T: Address> Future for TcpConnectorResponse<T> { pub struct ConnectServiceResponse<R: Host> {
type Output = Result<Connection<T, TcpStream>, ConnectError>; fut: ConnectFut<R>,
tcp: TcpConnectorService,
}
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { impl<R: Host> Future for ConnectServiceResponse<R> {
match self.get_mut() { type Output = Result<Connection<R, TcpStream>, ConnectError>;
TcpConnectorResponse::Error(err) => Poll::Ready(Err(err.take().unwrap())),
TcpConnectorResponse::Response { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
req, loop {
port, match ready!(self.fut.poll_connect(cx))? {
local_addr, ConnectOutput::Resolved(res) => {
addrs, self.fut = ConnectFut::Connect(self.tcp.call(res));
stream,
} => loop {
match ready!(stream.poll(cx)) {
Ok(sock) => {
let req = req.take().unwrap();
trace!(
"TCP connector: successfully connected to {:?} - {:?}",
req.hostname(),
sock.peer_addr()
);
return Poll::Ready(Ok(Connection::new(sock, req)));
}
Err(err) => {
trace!(
"TCP connector: failed to connect to {:?} port: {}",
req.as_ref().unwrap().hostname(),
port,
);
if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) {
stream.set(connect(addr, *local_addr));
} else {
return Poll::Ready(Err(ConnectError::Io(err)));
}
}
} }
}, ConnectOutput::Connected(res) => return Poll::Ready(Ok(res)),
}
} }
} }
} }
async fn connect(addr: SocketAddr, local_addr: Option<IpAddr>) -> io::Result<TcpStream> {
// use local addr if connect asks for it.
match local_addr {
Some(ip_addr) => {
let socket = match ip_addr {
IpAddr::V4(ip_addr) => {
let socket = TcpSocket::new_v4()?;
let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0));
socket.bind(addr)?;
socket
}
IpAddr::V6(ip_addr) => {
let socket = TcpSocket::new_v6()?;
let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0));
socket.bind(addr)?;
socket
}
};
socket.connect(addr).await
}
None => TcpStream::connect(addr).await,
}
}

View File

@ -1,15 +1,16 @@
use std::io; use std::{error::Error, io};
use derive_more::Display; use derive_more::Display;
/// Errors that can result from using a connector service.
#[derive(Debug, Display)] #[derive(Debug, Display)]
pub enum ConnectError { pub enum ConnectError {
/// Failed to resolve the hostname /// Failed to resolve the hostname
#[display(fmt = "Failed resolving hostname: {}", _0)] #[display(fmt = "Failed resolving hostname")]
Resolver(Box<dyn std::error::Error>), Resolver(Box<dyn std::error::Error>),
/// No dns records /// No DNS records
#[display(fmt = "No dns records found for the input")] #[display(fmt = "No DNS records found for the input")]
NoRecords, NoRecords,
/// Invalid input /// Invalid input
@ -23,3 +24,13 @@ pub enum ConnectError {
#[display(fmt = "{}", _0)] #[display(fmt = "{}", _0)]
Io(io::Error), Io(io::Error),
} }
impl Error for ConnectError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::Resolver(err) => Some(&**err),
Self::Io(err) => Some(err),
Self::NoRecords | Self::InvalidInput | Self::Unresolved => None,
}
}
}

View File

@ -0,0 +1,71 @@
//! The [`Host`] trait.
/// An interface for types where host parts (hostname and port) can be derived.
///
/// The [WHATWG URL Standard] defines the terminology used for this trait and its methods.
///
/// ```plain
/// +------------------------+
/// | host |
/// +-----------------+------+
/// | hostname | port |
/// | | |
/// | sub.example.com : 8080 |
/// +-----------------+------+
/// ```
///
/// [WHATWG URL Standard]: https://url.spec.whatwg.org/
pub trait Host: Unpin + 'static {
/// Extract hostname.
fn hostname(&self) -> &str;
/// Extract optional port.
fn port(&self) -> Option<u16> {
None
}
}
impl Host for String {
fn hostname(&self) -> &str {
self.split_once(':')
.map(|(hostname, _)| hostname)
.unwrap_or(self)
}
fn port(&self) -> Option<u16> {
self.split_once(':').and_then(|(_, port)| port.parse().ok())
}
}
impl Host for &'static str {
fn hostname(&self) -> &str {
self.split_once(':')
.map(|(hostname, _)| hostname)
.unwrap_or(self)
}
fn port(&self) -> Option<u16> {
self.split_once(':').and_then(|(_, port)| port.parse().ok())
}
}
#[cfg(test)]
mod tests {
use super::*;
macro_rules! assert_connection_info_eq {
($req:expr, $hostname:expr, $port:expr) => {{
assert_eq!($req.hostname(), $hostname);
assert_eq!($req.port(), $port);
}};
}
#[test]
fn host_parsing() {
assert_connection_info_eq!("example.com", "example.com", None);
assert_connection_info_eq!("example.com:8080", "example.com", Some(8080));
assert_connection_info_eq!("example:8080", "example", Some(8080));
assert_connection_info_eq!("example.com:false", "example.com", None);
assert_connection_info_eq!("example.com:false:false", "example.com", None);
}
}

View File

@ -0,0 +1,249 @@
//! Connection info struct.
use std::{
collections::VecDeque,
fmt,
iter::{self, FromIterator as _},
mem,
net::{IpAddr, SocketAddr},
};
use super::{
connect_addrs::{ConnectAddrs, ConnectAddrsIter},
Host,
};
/// Connection request information.
///
/// May contain known/pre-resolved socket address(es) or a host that needs resolving with DNS.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ConnectInfo<R> {
pub(crate) request: R,
pub(crate) port: u16,
pub(crate) addr: ConnectAddrs,
pub(crate) local_addr: Option<IpAddr>,
}
impl<R: Host> ConnectInfo<R> {
/// Constructs new connection info using a request.
pub fn new(request: R) -> ConnectInfo<R> {
let port = request.port();
ConnectInfo {
request,
port: port.unwrap_or(0),
addr: ConnectAddrs::None,
local_addr: None,
}
}
/// Constructs new connection info from request and known socket address.
///
/// Since socket address is known, [`Connector`](super::Connector) will skip the DNS
/// resolution step.
pub fn with_addr(request: R, addr: SocketAddr) -> ConnectInfo<R> {
ConnectInfo {
request,
port: 0,
addr: ConnectAddrs::One(addr),
local_addr: None,
}
}
/// Set connection port.
///
/// If request provided a port, this will override it.
pub fn set_port(mut self, port: u16) -> Self {
self.port = port;
self
}
/// Set connection socket address.
pub fn set_addr(mut self, addr: impl Into<Option<SocketAddr>>) -> Self {
self.addr = ConnectAddrs::from(addr.into());
self
}
/// Set list of addresses.
pub fn set_addrs<I>(mut self, addrs: I) -> Self
where
I: IntoIterator<Item = SocketAddr>,
{
let mut addrs = VecDeque::from_iter(addrs);
self.addr = if addrs.len() < 2 {
ConnectAddrs::from(addrs.pop_front())
} else {
ConnectAddrs::Multi(addrs)
};
self
}
/// Set local address to connection with.
///
/// Useful in situations where the IP address bound to a particular network interface is known.
/// This would make sure the socket is opened through that interface.
pub fn set_local_addr(mut self, addr: impl Into<IpAddr>) -> Self {
self.local_addr = Some(addr.into());
self
}
/// Returns a reference to the connection request.
pub fn request(&self) -> &R {
&self.request
}
/// Returns request hostname.
pub fn hostname(&self) -> &str {
self.request.hostname()
}
/// Returns request port.
pub fn port(&self) -> u16 {
self.request.port().unwrap_or(self.port)
}
/// Get borrowed iterator of resolved request addresses.
///
/// # Examples
/// ```
/// # use std::net::SocketAddr;
/// # use actix_tls::connect::ConnectInfo;
/// let addr = SocketAddr::from(([127, 0, 0, 1], 4242));
///
/// let conn = ConnectInfo::new("localhost");
/// let mut addrs = conn.addrs();
/// assert!(addrs.next().is_none());
///
/// let conn = ConnectInfo::with_addr("localhost", addr);
/// let mut addrs = conn.addrs();
/// assert_eq!(addrs.next().unwrap(), addr);
/// ```
pub fn addrs(
&self,
) -> impl Iterator<Item = SocketAddr>
+ ExactSizeIterator
+ iter::FusedIterator
+ Clone
+ fmt::Debug
+ '_ {
match self.addr {
ConnectAddrs::None => ConnectAddrsIter::None,
ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr),
ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()),
}
}
/// Take owned iterator resolved request addresses.
///
/// # Examples
/// ```
/// # use std::net::SocketAddr;
/// # use actix_tls::connect::ConnectInfo;
/// let addr = SocketAddr::from(([127, 0, 0, 1], 4242));
///
/// let mut conn = ConnectInfo::new("localhost");
/// let mut addrs = conn.take_addrs();
/// assert!(addrs.next().is_none());
///
/// let mut conn = ConnectInfo::with_addr("localhost", addr);
/// let mut addrs = conn.take_addrs();
/// assert_eq!(addrs.next().unwrap(), addr);
/// ```
pub fn take_addrs(
&mut self,
) -> impl Iterator<Item = SocketAddr>
+ ExactSizeIterator
+ iter::FusedIterator
+ Clone
+ fmt::Debug
+ 'static {
match mem::take(&mut self.addr) {
ConnectAddrs::None => ConnectAddrsIter::None,
ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr),
ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()),
}
}
}
impl<R: Host> From<R> for ConnectInfo<R> {
fn from(addr: R) -> Self {
ConnectInfo::new(addr)
}
}
impl<R: Host> fmt::Display for ConnectInfo<R> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}", self.hostname(), self.port())
}
}
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use super::*;
#[test]
fn test_addr_iter_multi() {
let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080));
let unspecified = SocketAddr::from((IpAddr::from(Ipv4Addr::UNSPECIFIED), 8080));
let mut addrs = VecDeque::new();
addrs.push_back(localhost);
addrs.push_back(unspecified);
let mut iter = ConnectAddrsIter::Multi(addrs.iter());
assert_eq!(iter.next(), Some(localhost));
assert_eq!(iter.next(), Some(unspecified));
assert_eq!(iter.next(), None);
let mut iter = ConnectAddrsIter::MultiOwned(addrs.into_iter());
assert_eq!(iter.next(), Some(localhost));
assert_eq!(iter.next(), Some(unspecified));
assert_eq!(iter.next(), None);
}
#[test]
fn test_addr_iter_single() {
let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080));
let mut iter = ConnectAddrsIter::One(localhost);
assert_eq!(iter.next(), Some(localhost));
assert_eq!(iter.next(), None);
let mut iter = ConnectAddrsIter::None;
assert_eq!(iter.next(), None);
}
#[test]
fn test_local_addr() {
let conn = ConnectInfo::new("hello").set_local_addr([127, 0, 0, 1]);
assert_eq!(
conn.local_addr.unwrap(),
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
)
}
#[test]
fn request_ref() {
let conn = ConnectInfo::new("hello");
assert_eq!(conn.request(), &"hello")
}
#[test]
fn set_connect_addr_into_option() {
let addr = SocketAddr::from(([127, 0, 0, 1], 4242));
let conn = ConnectInfo::new("hello").set_addr(None);
let mut addrs = conn.addrs();
assert!(addrs.next().is_none());
let conn = ConnectInfo::new("hello").set_addr(addr);
let mut addrs = conn.addrs();
assert_eq!(addrs.next().unwrap(), addr);
let conn = ConnectInfo::new("hello").set_addr(Some(addr));
let mut addrs = conn.addrs();
assert_eq!(addrs.next().unwrap(), addr);
}
}

View File

@ -1,35 +1,46 @@
//! TCP and TLS connector services. //! TCP and TLS connector services.
//! //!
//! # Stages of the TCP connector service: //! # Stages of the TCP connector service:
//! - Resolve [`Address`] with given [`Resolver`] and collect list of socket addresses. //! 1. Resolve [`Host`] (if needed) with given [`Resolver`] and collect list of socket addresses.
//! - Establish TCP connection and return [`TcpStream`]. //! 1. Establish TCP connection and return [`TcpStream`].
//! //!
//! # Stages of TLS connector services: //! # Stages of TLS connector services:
//! - Establish [`TcpStream`] with connector service. //! 1. Resolve DNS and establish a [`TcpStream`] with the TCP connector service.
//! - Wrap the stream and perform connect handshake with remote peer. //! 1. Wrap the stream and perform connect handshake with remote peer.
//! - Return certain stream type that impls `AsyncRead` and `AsyncWrite`. //! 1. Return wrapped stream type that implements `AsyncRead` and `AsyncWrite`.
//! //!
//! [`TcpStream`]: actix_rt::net::TcpStream //! [`TcpStream`]: actix_rt::net::TcpStream
#[allow(clippy::module_inception)] mod connect_addrs;
mod connect; mod connection;
mod connector; mod connector;
mod error; mod error;
mod host;
mod info;
mod resolve; mod resolve;
mod service; mod resolver;
pub mod tls; pub mod tcp;
// TODO: remove `ssl` mod re-export in next break change
#[doc(hidden)]
pub use tls as ssl;
mod tcp;
#[cfg(feature = "uri")] #[cfg(feature = "uri")]
#[cfg_attr(docsrs, doc(cfg(feature = "uri")))]
mod uri; mod uri;
pub use self::connect::{Address, Connect, Connection}; #[cfg(feature = "openssl")]
pub use self::connector::{TcpConnector, TcpConnectorFactory}; #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
pub mod openssl;
#[cfg(feature = "rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
pub mod rustls;
#[cfg(feature = "native-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))]
pub mod native_tls;
pub use self::connection::Connection;
pub use self::connector::{Connector, ConnectorService};
pub use self::error::ConnectError; pub use self::error::ConnectError;
pub use self::resolve::{Resolve, Resolver, ResolverFactory}; pub use self::host::Host;
pub use self::service::{ConnectService, ConnectServiceFactory}; pub use self::info::ConnectInfo;
pub use self::tcp::{ pub use self::resolve::Resolve;
default_connector, default_connector_factory, new_connector, new_connector_factory, pub use self::resolver::{Resolver, ResolverService};
};

View File

@ -0,0 +1,90 @@
//! Native-TLS based connector service.
//!
//! See [`TlsConnector`] for main connector service factory docs.
use std::io;
use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory};
use actix_utils::future::{ok, Ready};
use futures_core::future::LocalBoxFuture;
use log::trace;
use tokio_native_tls::{
native_tls::TlsConnector as NativeTlsConnector, TlsConnector as TokioNativeTlsConnector,
TlsStream,
};
use crate::connect::{Connection, Host};
pub mod reexports {
//! Re-exports from `native-tls` that are useful for connectors.
pub use tokio_native_tls::native_tls::TlsConnector;
}
/// Connector service and factory using `native-tls`.
#[derive(Clone)]
pub struct TlsConnector {
connector: TokioNativeTlsConnector,
}
impl TlsConnector {
/// Constructs new connector service from a `native-tls` connector.
///
/// This type is it's own service factory, so it can be used in that setting, too.
pub fn new(connector: NativeTlsConnector) -> Self {
Self {
connector: TokioNativeTlsConnector::from(connector),
}
}
}
impl<R: Host, IO> ServiceFactory<Connection<R, IO>> for TlsConnector
where
IO: ActixStream + 'static,
{
type Response = Connection<R, TlsStream<IO>>;
type Error = io::Error;
type Config = ();
type Service = Self;
type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(self.clone())
}
}
/// The `native-tls` connector is both it's ServiceFactory and Service impl type.
/// As the factory and service share the same type and state.
impl<R, IO> Service<Connection<R, IO>> for TlsConnector
where
R: Host,
IO: ActixStream + 'static,
{
type Response = Connection<R, TlsStream<IO>>;
type Error = io::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
actix_service::always_ready!();
fn call(&self, stream: Connection<R, IO>) -> Self::Future {
let (io, stream) = stream.replace_io(());
let connector = self.connector.clone();
Box::pin(async move {
trace!("SSL Handshake start for: {:?}", stream.hostname());
connector
.connect(stream.hostname(), io)
.await
.map(|res| {
trace!("SSL Handshake success: {:?}", stream.hostname());
stream.replace_io(res).1
})
.map_err(|e| {
trace!("SSL Handshake error: {:?}", e);
io::Error::new(io::ErrorKind::Other, format!("{}", e))
})
})
}
}

View File

@ -0,0 +1,146 @@
//! OpenSSL based connector service.
//!
//! See [`TlsConnector`] for main connector service factory docs.
use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};
use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory};
use actix_utils::future::{ok, Ready};
use futures_core::ready;
use log::trace;
use openssl::ssl::SslConnector;
use tokio_openssl::SslStream;
use crate::connect::{Connection, Host};
pub mod reexports {
//! Re-exports from `openssl` that are useful for connectors.
pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
}
/// Connector service factory using `openssl`.
pub struct TlsConnector {
connector: SslConnector,
}
impl TlsConnector {
/// Constructs new connector service factory from an `openssl` connector.
pub fn new(connector: SslConnector) -> Self {
TlsConnector { connector }
}
/// Constructs new connector service from an `openssl` connector.
pub fn service(connector: SslConnector) -> TlsConnectorService {
TlsConnectorService { connector }
}
}
impl Clone for TlsConnector {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<R, IO> ServiceFactory<Connection<R, IO>> for TlsConnector
where
R: Host,
IO: ActixStream + 'static,
{
type Response = Connection<R, SslStream<IO>>;
type Error = io::Error;
type Config = ();
type Service = TlsConnectorService;
type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(TlsConnectorService {
connector: self.connector.clone(),
})
}
}
/// Connector service using `openssl`.
pub struct TlsConnectorService {
connector: SslConnector,
}
impl Clone for TlsConnectorService {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<R, IO> Service<Connection<R, IO>> for TlsConnectorService
where
R: Host,
IO: ActixStream,
{
type Response = Connection<R, SslStream<IO>>;
type Error = io::Error;
type Future = ConnectFut<R, IO>;
actix_service::always_ready!();
fn call(&self, stream: Connection<R, IO>) -> Self::Future {
trace!("SSL Handshake start for: {:?}", stream.hostname());
let (io, stream) = stream.replace_io(());
let host = stream.hostname();
let config = self
.connector
.configure()
.expect("SSL connect configuration was invalid.");
let ssl = config
.into_ssl(host)
.expect("SSL connect configuration was invalid.");
ConnectFut {
io: Some(SslStream::new(ssl, io).unwrap()),
stream: Some(stream),
}
}
}
/// Connect future for OpenSSL service.
#[doc(hidden)]
pub struct ConnectFut<R, IO> {
io: Option<SslStream<IO>>,
stream: Option<Connection<R, ()>>,
}
impl<R: Host, IO> Future for ConnectFut<R, IO>
where
R: Host,
IO: ActixStream,
{
type Output = Result<Connection<R, SslStream<IO>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match ready!(Pin::new(this.io.as_mut().unwrap()).poll_connect(cx)) {
Ok(_) => {
let stream = this.stream.take().unwrap();
trace!("SSL Handshake success: {:?}", stream.hostname());
Poll::Ready(Ok(stream.replace_io(this.io.take().unwrap()).1))
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
}
}
}
}

207
actix-tls/src/connect/resolve.rs Executable file → Normal file
View File

@ -1,61 +1,12 @@
use std::{ //! The [`Resolve`] trait.
future::Future,
io,
net::SocketAddr,
pin::Pin,
rc::Rc,
task::{Context, Poll},
vec::IntoIter,
};
use actix_rt::task::{spawn_blocking, JoinHandle}; use std::{error::Error as StdError, net::SocketAddr};
use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready};
use log::trace;
use super::connect::{Address, Connect}; use futures_core::future::LocalBoxFuture;
use super::error::ConnectError;
/// DNS Resolver Service Factory /// Custom async DNS resolvers.
#[derive(Clone)]
pub struct ResolverFactory {
resolver: Resolver,
}
impl ResolverFactory {
pub fn new(resolver: Resolver) -> Self {
Self { resolver }
}
pub fn service(&self) -> Resolver {
self.resolver.clone()
}
}
impl<T: Address> ServiceFactory<Connect<T>> for ResolverFactory {
type Response = Connect<T>;
type Error = ConnectError;
type Config = ();
type Service = Resolver;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
let service = self.resolver.clone();
Box::pin(async { Ok(service) })
}
}
/// DNS Resolver Service
#[derive(Clone)]
pub enum Resolver {
Default,
Custom(Rc<dyn Resolve>),
}
/// An interface for custom async DNS resolvers.
/// ///
/// # Usage /// # Examples
/// ``` /// ```
/// use std::net::SocketAddr; /// use std::net::SocketAddr;
/// ///
@ -89,155 +40,23 @@ pub enum Resolver {
/// } /// }
/// } /// }
/// ///
/// let resolver = MyResolver { /// let my_resolver = MyResolver {
/// trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), /// trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(),
/// }; /// };
/// ///
/// // construct custom resolver /// // wrap custom resolver
/// let resolver = Resolver::new_custom(resolver); /// let resolver = Resolver::custom(my_resolver);
///
/// // pass custom resolver to connector builder.
/// // connector would then be usable as a service or awc's connector.
/// let connector = actix_tls::connect::new_connector::<&str>(resolver.clone());
/// ///
/// // resolver can be passed to connector factory where returned service factory /// // resolver can be passed to connector factory where returned service factory
/// // can be used to construct new connector services. /// // can be used to construct new connector services for use in clients
/// let factory = actix_tls::connect::new_connector_factory::<&str>(resolver); /// let factory = actix_tls::connect::Connector::new(resolver);
/// let connector = factory.service();
/// ``` /// ```
pub trait Resolve { pub trait Resolve {
/// Given DNS lookup information, returns a future that completes with socket information.
fn lookup<'a>( fn lookup<'a>(
&'a self, &'a self,
host: &'a str, host: &'a str,
port: u16, port: u16,
) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>>; ) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn StdError>>>;
}
impl Resolver {
/// Constructor for custom Resolve trait object and use it as resolver.
pub fn new_custom(resolver: impl Resolve + 'static) -> Self {
Self::Custom(Rc::new(resolver))
}
// look up with default resolver variant.
fn look_up<T: Address>(req: &Connect<T>) -> JoinHandle<io::Result<IntoIter<SocketAddr>>> {
let host = req.hostname();
// TODO: Connect should always return host with port if possible.
let host = if req
.hostname()
.splitn(2, ':')
.last()
.and_then(|p| p.parse::<u16>().ok())
.map(|p| p == req.port())
.unwrap_or(false)
{
host.to_string()
} else {
format!("{}:{}", host, req.port())
};
// run blocking DNS lookup in thread pool
spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host))
}
}
impl<T: Address> Service<Connect<T>> for Resolver {
type Response = Connect<T>;
type Error = ConnectError;
type Future = ResolverFuture<T>;
actix_service::always_ready!();
fn call(&self, req: Connect<T>) -> Self::Future {
if req.addr.is_some() {
ResolverFuture::Connected(Some(req))
} else if let Ok(ip) = req.hostname().parse() {
let addr = SocketAddr::new(ip, req.port());
let req = req.set_addr(Some(addr));
ResolverFuture::Connected(Some(req))
} else {
trace!("DNS resolver: resolving host {:?}", req.hostname());
match self {
Self::Default => {
let fut = Self::look_up(&req);
ResolverFuture::LookUp(fut, Some(req))
}
Self::Custom(resolver) => {
let resolver = Rc::clone(resolver);
ResolverFuture::LookupCustom(Box::pin(async move {
let addrs = resolver
.lookup(req.hostname(), req.port())
.await
.map_err(ConnectError::Resolver)?;
let req = req.set_addrs(addrs);
if req.addr.is_none() {
Err(ConnectError::NoRecords)
} else {
Ok(req)
}
}))
}
}
}
}
}
pub enum ResolverFuture<T: Address> {
Connected(Option<Connect<T>>),
LookUp(
JoinHandle<io::Result<IntoIter<SocketAddr>>>,
Option<Connect<T>>,
),
LookupCustom(LocalBoxFuture<'static, Result<Connect<T>, ConnectError>>),
}
impl<T: Address> Future for ResolverFuture<T> {
type Output = Result<Connect<T>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Self::Connected(conn) => Poll::Ready(Ok(conn
.take()
.expect("ResolverFuture polled after finished"))),
Self::LookUp(fut, req) => {
let res = match ready!(Pin::new(fut).poll(cx)) {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => Err(ConnectError::Resolver(Box::new(e))),
Err(e) => Err(ConnectError::Io(e.into())),
};
let req = req.take().unwrap();
let addrs = res.map_err(|err| {
trace!(
"DNS resolver: failed to resolve host {:?} err: {:?}",
req.hostname(),
err
);
err
})?;
let req = req.set_addrs(addrs);
trace!(
"DNS resolver: host {:?} resolved to {:?}",
req.hostname(),
req.addrs()
);
if req.addr.is_none() {
Poll::Ready(Err(ConnectError::NoRecords))
} else {
Poll::Ready(Ok(req))
}
}
Self::LookupCustom(fut) => fut.as_mut().poll(cx),
}
}
} }

View File

@ -0,0 +1,201 @@
use std::{
future::Future,
io,
net::SocketAddr,
pin::Pin,
rc::Rc,
task::{Context, Poll},
vec::IntoIter,
};
use actix_rt::task::{spawn_blocking, JoinHandle};
use actix_service::{Service, ServiceFactory};
use actix_utils::future::{ok, Ready};
use futures_core::{future::LocalBoxFuture, ready};
use log::trace;
use super::{ConnectError, ConnectInfo, Host, Resolve};
/// DNS resolver service factory.
#[derive(Clone, Default)]
pub struct Resolver {
resolver: ResolverService,
}
impl Resolver {
/// Constructs a new resolver factory with a custom resolver.
pub fn custom(resolver: impl Resolve + 'static) -> Self {
Self {
resolver: ResolverService::custom(resolver),
}
}
/// Returns a new resolver service.
pub fn service(&self) -> ResolverService {
self.resolver.clone()
}
}
impl<R: Host> ServiceFactory<ConnectInfo<R>> for Resolver {
type Response = ConnectInfo<R>;
type Error = ConnectError;
type Config = ();
type Service = ResolverService;
type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(self.resolver.clone())
}
}
#[derive(Clone)]
enum ResolverKind {
/// Built-in DNS resolver.
///
/// See [`std::net::ToSocketAddrs`] trait.
Default,
/// Custom, user-provided DNS resolver.
Custom(Rc<dyn Resolve>),
}
impl Default for ResolverKind {
fn default() -> Self {
Self::Default
}
}
/// DNS resolver service.
#[derive(Clone, Default)]
pub struct ResolverService {
kind: ResolverKind,
}
impl ResolverService {
/// Constructor for custom Resolve trait object and use it as resolver.
pub fn custom(resolver: impl Resolve + 'static) -> Self {
Self {
kind: ResolverKind::Custom(Rc::new(resolver)),
}
}
/// Resolve DNS with default resolver.
fn default_lookup<R: Host>(
req: &ConnectInfo<R>,
) -> JoinHandle<io::Result<IntoIter<SocketAddr>>> {
// reconstruct host; concatenate hostname and port together
let host = format!("{}:{}", req.hostname(), req.port());
// run blocking DNS lookup in thread pool since DNS lookups can take upwards of seconds on
// some platforms if conditions are poor and OS-level cache is not populated
spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host))
}
}
impl<R: Host> Service<ConnectInfo<R>> for ResolverService {
type Response = ConnectInfo<R>;
type Error = ConnectError;
type Future = ResolverFut<R>;
actix_service::always_ready!();
fn call(&self, req: ConnectInfo<R>) -> Self::Future {
if req.addr.is_resolved() {
// socket address(es) already resolved; return existing connection request
ResolverFut::Resolved(Some(req))
} else if let Ok(ip) = req.hostname().parse() {
// request hostname is valid ip address; add address to request and return
let addr = SocketAddr::new(ip, req.port());
let req = req.set_addr(Some(addr));
ResolverFut::Resolved(Some(req))
} else {
trace!("DNS resolver: resolving host {:?}", req.hostname());
match &self.kind {
ResolverKind::Default => {
let fut = Self::default_lookup(&req);
ResolverFut::LookUp(fut, Some(req))
}
ResolverKind::Custom(resolver) => {
let resolver = Rc::clone(resolver);
ResolverFut::LookupCustom(Box::pin(async move {
let addrs = resolver
.lookup(req.hostname(), req.port())
.await
.map_err(ConnectError::Resolver)?;
let req = req.set_addrs(addrs);
if req.addr.is_unresolved() {
Err(ConnectError::NoRecords)
} else {
Ok(req)
}
}))
}
}
}
}
}
/// Future for resolver service.
#[doc(hidden)]
pub enum ResolverFut<R: Host> {
Resolved(Option<ConnectInfo<R>>),
LookUp(
JoinHandle<io::Result<IntoIter<SocketAddr>>>,
Option<ConnectInfo<R>>,
),
LookupCustom(LocalBoxFuture<'static, Result<ConnectInfo<R>, ConnectError>>),
}
impl<R: Host> Future for ResolverFut<R> {
type Output = Result<ConnectInfo<R>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Self::Resolved(conn) => Poll::Ready(Ok(conn
.take()
.expect("ResolverFuture polled after finished"))),
Self::LookUp(fut, req) => {
let res = match ready!(Pin::new(fut).poll(cx)) {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => Err(ConnectError::Resolver(Box::new(e))),
Err(e) => Err(ConnectError::Io(e.into())),
};
let req = req.take().unwrap();
let addrs = res.map_err(|err| {
trace!(
"DNS resolver: failed to resolve host {:?} err: {:?}",
req.hostname(),
err
);
err
})?;
let req = req.set_addrs(addrs);
trace!(
"DNS resolver: host {:?} resolved to {:?}",
req.hostname(),
req.addrs()
);
if req.addr.is_unresolved() {
Poll::Ready(Err(ConnectError::NoRecords))
} else {
Poll::Ready(Ok(req))
}
}
Self::LookupCustom(fut) => fut.as_mut().poll(cx),
}
}
}

View File

@ -0,0 +1,148 @@
//! Rustls based connector service.
//!
//! See [`TlsConnector`] for main connector service factory docs.
use std::{
convert::TryFrom,
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory};
use actix_utils::future::{ok, Ready};
use futures_core::ready;
use log::trace;
use tokio_rustls::rustls::{client::ServerName, OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
use tokio_rustls::{Connect as RustlsConnect, TlsConnector as RustlsTlsConnector};
use webpki_roots::TLS_SERVER_ROOTS;
use crate::connect::{Connection, Host};
pub mod reexports {
//! Re-exports from `rustls` and `webpki_roots` that are useful for connectors.
pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
pub use webpki_roots::TLS_SERVER_ROOTS;
}
/// Returns standard root certificates from `webpki-roots` crate as a rustls certificate store.
pub fn webpki_roots_cert_store() -> RootCertStore {
let mut root_certs = RootCertStore::empty();
for cert in TLS_SERVER_ROOTS.0 {
let cert = OwnedTrustAnchor::from_subject_spki_name_constraints(
cert.subject,
cert.spki,
cert.name_constraints,
);
let certs = vec![cert].into_iter();
root_certs.add_server_trust_anchors(certs);
}
root_certs
}
/// Connector service factory using `rustls`.
#[derive(Clone)]
pub struct TlsConnector {
connector: Arc<ClientConfig>,
}
impl TlsConnector {
/// Constructs new connector service factory from a `rustls` client configuration.
pub fn new(connector: Arc<ClientConfig>) -> Self {
TlsConnector { connector }
}
/// Constructs new connector service from a `rustls` client configuration.
pub fn service(connector: Arc<ClientConfig>) -> TlsConnectorService {
TlsConnectorService { connector }
}
}
impl<R, IO> ServiceFactory<Connection<R, IO>> for TlsConnector
where
R: Host,
IO: ActixStream + 'static,
{
type Response = Connection<R, TlsStream<IO>>;
type Error = io::Error;
type Config = ();
type Service = TlsConnectorService;
type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(TlsConnectorService {
connector: self.connector.clone(),
})
}
}
/// Connector service using `rustls`.
#[derive(Clone)]
pub struct TlsConnectorService {
connector: Arc<ClientConfig>,
}
impl<R, IO> Service<Connection<R, IO>> for TlsConnectorService
where
R: Host,
IO: ActixStream,
{
type Response = Connection<R, TlsStream<IO>>;
type Error = io::Error;
type Future = ConnectFut<R, IO>;
actix_service::always_ready!();
fn call(&self, connection: Connection<R, IO>) -> Self::Future {
trace!("SSL Handshake start for: {:?}", connection.hostname());
let (stream, connection) = connection.replace_io(());
match ServerName::try_from(connection.hostname()) {
Ok(host) => ConnectFut::Future {
connect: RustlsTlsConnector::from(self.connector.clone()).connect(host, stream),
connection: Some(connection),
},
Err(_) => ConnectFut::InvalidDns,
}
}
}
/// Connect future for Rustls service.
#[doc(hidden)]
pub enum ConnectFut<R, IO> {
/// See issue <https://github.com/briansmith/webpki/issues/54>
InvalidDns,
Future {
connect: RustlsConnect<IO>,
connection: Option<Connection<R, ()>>,
},
}
impl<R, IO> Future for ConnectFut<R, IO>
where
R: Host,
IO: ActixStream,
{
type Output = Result<Connection<R, TlsStream<IO>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Self::InvalidDns => Poll::Ready(Err(
io::Error::new(io::ErrorKind::Other, "rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54")
)),
Self::Future { connect, connection } => {
let stream = ready!(Pin::new(connect).poll(cx))?;
let connection = connection.take().unwrap();
trace!("SSL Handshake success: {:?}", connection.hostname());
Poll::Ready(Ok(connection.replace_io(stream).1))
}
}
}
}

View File

@ -1,129 +0,0 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready};
use super::connect::{Address, Connect, Connection};
use super::connector::{TcpConnector, TcpConnectorFactory};
use super::error::ConnectError;
use super::resolve::{Resolver, ResolverFactory};
pub struct ConnectServiceFactory {
tcp: TcpConnectorFactory,
resolver: ResolverFactory,
}
impl ConnectServiceFactory {
/// Construct new ConnectService factory
pub fn new(resolver: Resolver) -> Self {
ConnectServiceFactory {
tcp: TcpConnectorFactory,
resolver: ResolverFactory::new(resolver),
}
}
/// Construct new service
pub fn service(&self) -> ConnectService {
ConnectService {
tcp: self.tcp.service(),
resolver: self.resolver.service(),
}
}
}
impl Clone for ConnectServiceFactory {
fn clone(&self) -> Self {
ConnectServiceFactory {
tcp: self.tcp,
resolver: self.resolver.clone(),
}
}
}
impl<T: Address> ServiceFactory<Connect<T>> for ConnectServiceFactory {
type Response = Connection<T, TcpStream>;
type Error = ConnectError;
type Config = ();
type Service = ConnectService;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
let service = self.service();
Box::pin(async { Ok(service) })
}
}
#[derive(Clone)]
pub struct ConnectService {
tcp: TcpConnector,
resolver: Resolver,
}
impl<T: Address> Service<Connect<T>> for ConnectService {
type Response = Connection<T, TcpStream>;
type Error = ConnectError;
type Future = ConnectServiceResponse<T>;
actix_service::always_ready!();
fn call(&self, req: Connect<T>) -> Self::Future {
ConnectServiceResponse {
fut: ConnectFuture::Resolve(self.resolver.call(req)),
tcp: self.tcp,
}
}
}
// helper enum to generic over futures of resolve and connect phase.
pub(crate) enum ConnectFuture<T: Address> {
Resolve(<Resolver as Service<Connect<T>>>::Future),
Connect(<TcpConnector as Service<Connect<T>>>::Future),
}
// helper enum to contain the future output of ConnectFuture
pub(crate) enum ConnectOutput<T: Address> {
Resolved(Connect<T>),
Connected(Connection<T, TcpStream>),
}
impl<T: Address> ConnectFuture<T> {
fn poll_connect(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<ConnectOutput<T>, ConnectError>> {
match self {
ConnectFuture::Resolve(ref mut fut) => {
Pin::new(fut).poll(cx).map_ok(ConnectOutput::Resolved)
}
ConnectFuture::Connect(ref mut fut) => {
Pin::new(fut).poll(cx).map_ok(ConnectOutput::Connected)
}
}
}
}
pub struct ConnectServiceResponse<T: Address> {
fut: ConnectFuture<T>,
tcp: TcpConnector,
}
impl<T: Address> Future for ConnectServiceResponse<T> {
type Output = Result<Connection<T, TcpStream>, ConnectError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match ready!(self.fut.poll_connect(cx))? {
ConnectOutput::Resolved(res) => {
self.fut = ConnectFuture::Connect(self.tcp.call(res));
}
ConnectOutput::Connected(res) => return Poll::Ready(Ok(res)),
}
}
}
}

View File

@ -1,43 +1,202 @@
use actix_rt::net::TcpStream; //! TCP connector service.
//!
//! See [`TcpConnector`] for main connector service factory docs.
use std::{
collections::VecDeque,
future::Future,
io,
net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6},
pin::Pin,
task::{Context, Poll},
};
use actix_rt::net::{TcpSocket, TcpStream};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::future::{ok, Ready};
use futures_core::ready;
use log::{error, trace};
use tokio_util::sync::ReusableBoxFuture;
use super::{Address, Connect, ConnectError, ConnectServiceFactory, Connection, Resolver}; use super::{connect_addrs::ConnectAddrs, error::ConnectError, ConnectInfo, Connection, Host};
/// Create TCP connector service. /// TCP connector service factory.
pub fn new_connector<T: Address + 'static>( #[derive(Debug, Copy, Clone)]
resolver: Resolver, pub struct TcpConnector;
) -> impl Service<Connect<T>, Response = Connection<T, TcpStream>, Error = ConnectError> + Clone
{ impl TcpConnector {
ConnectServiceFactory::new(resolver).service() /// Returns a new TCP connector service.
pub fn service(&self) -> TcpConnectorService {
TcpConnectorService
}
} }
/// Create TCP connector service factory. impl<R: Host> ServiceFactory<ConnectInfo<R>> for TcpConnector {
pub fn new_connector_factory<T: Address + 'static>( type Response = Connection<R, TcpStream>;
resolver: Resolver, type Error = ConnectError;
) -> impl ServiceFactory< type Config = ();
Connect<T>, type Service = TcpConnectorService;
Config = (), type InitError = ();
Response = Connection<T, TcpStream>, type Future = Ready<Result<Self::Service, Self::InitError>>;
Error = ConnectError,
InitError = (), fn new_service(&self, _: ()) -> Self::Future {
> + Clone { ok(self.service())
ConnectServiceFactory::new(resolver) }
} }
/// Create TCP connector service with default parameters. /// TCP connector service.
pub fn default_connector<T: Address + 'static>( #[derive(Debug, Copy, Clone)]
) -> impl Service<Connect<T>, Response = Connection<T, TcpStream>, Error = ConnectError> + Clone pub struct TcpConnectorService;
{
new_connector(Resolver::Default) impl<R: Host> Service<ConnectInfo<R>> for TcpConnectorService {
type Response = Connection<R, TcpStream>;
type Error = ConnectError;
type Future = TcpConnectorFut<R>;
actix_service::always_ready!();
fn call(&self, req: ConnectInfo<R>) -> Self::Future {
let port = req.port();
let ConnectInfo {
request: req,
addr,
local_addr,
..
} = req;
TcpConnectorFut::new(req, port, local_addr, addr)
}
} }
/// Create TCP connector service factory with default parameters. /// Connect future for TCP service.
pub fn default_connector_factory<T: Address + 'static>() -> impl ServiceFactory< #[doc(hidden)]
Connect<T>, pub enum TcpConnectorFut<R> {
Config = (), Response {
Response = Connection<T, TcpStream>, req: Option<R>,
Error = ConnectError, port: u16,
InitError = (), local_addr: Option<IpAddr>,
> + Clone { addrs: Option<VecDeque<SocketAddr>>,
new_connector_factory(Resolver::Default) stream: ReusableBoxFuture<Result<TcpStream, io::Error>>,
},
Error(Option<ConnectError>),
}
impl<R: Host> TcpConnectorFut<R> {
pub(crate) fn new(
req: R,
port: u16,
local_addr: Option<IpAddr>,
addr: ConnectAddrs,
) -> TcpConnectorFut<R> {
if addr.is_unresolved() {
error!("TCP connector: unresolved connection address");
return TcpConnectorFut::Error(Some(ConnectError::Unresolved));
}
trace!(
"TCP connector: connecting to {} on port {}",
req.hostname(),
port
);
match addr {
ConnectAddrs::None => unreachable!("none variant already checked"),
ConnectAddrs::One(addr) => TcpConnectorFut::Response {
req: Some(req),
port,
local_addr,
addrs: None,
stream: ReusableBoxFuture::new(connect(addr, local_addr)),
},
// when resolver returns multiple socket addr for request they would be popped from
// front end of queue and returns with the first successful tcp connection.
ConnectAddrs::Multi(mut addrs) => {
let addr = addrs.pop_front().unwrap();
TcpConnectorFut::Response {
req: Some(req),
port,
local_addr,
addrs: Some(addrs),
stream: ReusableBoxFuture::new(connect(addr, local_addr)),
}
}
}
}
}
impl<R: Host> Future for TcpConnectorFut<R> {
type Output = Result<Connection<R, TcpStream>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
TcpConnectorFut::Error(err) => Poll::Ready(Err(err.take().unwrap())),
TcpConnectorFut::Response {
req,
port,
local_addr,
addrs,
stream,
} => loop {
match ready!(stream.poll(cx)) {
Ok(sock) => {
let req = req.take().unwrap();
trace!(
"TCP connector: successfully connected to {:?} - {:?}",
req.hostname(),
sock.peer_addr()
);
return Poll::Ready(Ok(Connection::new(req, sock)));
}
Err(err) => {
trace!(
"TCP connector: failed to connect to {:?} port: {}",
req.as_ref().unwrap().hostname(),
port,
);
if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) {
stream.set(connect(addr, *local_addr));
} else {
return Poll::Ready(Err(ConnectError::Io(err)));
}
}
}
},
}
}
}
async fn connect(addr: SocketAddr, local_addr: Option<IpAddr>) -> io::Result<TcpStream> {
// use local addr if connect asks for it
match local_addr {
Some(ip_addr) => {
let socket = match ip_addr {
IpAddr::V4(ip_addr) => {
let socket = TcpSocket::new_v4()?;
let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0));
socket.bind(addr)?;
socket
}
IpAddr::V6(ip_addr) => {
let socket = TcpSocket::new_v6()?;
let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0));
socket.bind(addr)?;
socket
}
};
socket.connect(addr).await
}
None => TcpStream::connect(addr).await,
}
} }

View File

@ -1,10 +0,0 @@
//! TLS Services
#[cfg(feature = "openssl")]
pub mod openssl;
#[cfg(feature = "rustls")]
pub mod rustls;
#[cfg(feature = "native-tls")]
pub mod native_tls;

View File

@ -1,88 +0,0 @@
use std::io;
use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory};
use futures_core::future::LocalBoxFuture;
use log::trace;
use tokio_native_tls::{TlsConnector as TokioNativetlsConnector, TlsStream};
pub use tokio_native_tls::native_tls::TlsConnector;
use crate::connect::{Address, Connection};
/// Native-tls connector factory and service
pub struct NativetlsConnector {
connector: TokioNativetlsConnector,
}
impl NativetlsConnector {
pub fn new(connector: TlsConnector) -> Self {
Self {
connector: TokioNativetlsConnector::from(connector),
}
}
}
impl NativetlsConnector {
pub fn service(connector: TlsConnector) -> Self {
Self::new(connector)
}
}
impl Clone for NativetlsConnector {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<T: Address, U> ServiceFactory<Connection<T, U>> for NativetlsConnector
where
U: ActixStream + 'static,
{
type Response = Connection<T, TlsStream<U>>;
type Error = io::Error;
type Config = ();
type Service = Self;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
let connector = self.clone();
Box::pin(async { Ok(connector) })
}
}
// NativetlsConnector is both it's ServiceFactory and Service impl type.
// As the factory and service share the same type and state.
impl<T, U> Service<Connection<T, U>> for NativetlsConnector
where
T: Address,
U: ActixStream + 'static,
{
type Response = Connection<T, TlsStream<U>>;
type Error = io::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
actix_service::always_ready!();
fn call(&self, stream: Connection<T, U>) -> Self::Future {
let (io, stream) = stream.replace_io(());
let connector = self.connector.clone();
Box::pin(async move {
trace!("SSL Handshake start for: {:?}", stream.host());
connector
.connect(stream.host(), io)
.await
.map(|res| {
trace!("SSL Handshake success: {:?}", stream.host());
stream.replace_io(res).1
})
.map_err(|e| {
trace!("SSL Handshake error: {:?}", e);
io::Error::new(io::ErrorKind::Other, format!("{}", e))
})
})
}
}

View File

@ -1,130 +0,0 @@
use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};
use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready};
use log::trace;
pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
pub use tokio_openssl::SslStream;
use crate::connect::{Address, Connection};
/// OpenSSL connector factory
pub struct OpensslConnector {
connector: SslConnector,
}
impl OpensslConnector {
pub fn new(connector: SslConnector) -> Self {
OpensslConnector { connector }
}
pub fn service(connector: SslConnector) -> OpensslConnectorService {
OpensslConnectorService { connector }
}
}
impl Clone for OpensslConnector {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<T, U> ServiceFactory<Connection<T, U>> for OpensslConnector
where
T: Address,
U: ActixStream + 'static,
{
type Response = Connection<T, SslStream<U>>;
type Error = io::Error;
type Config = ();
type Service = OpensslConnectorService;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
let connector = self.connector.clone();
Box::pin(async { Ok(OpensslConnectorService { connector }) })
}
}
pub struct OpensslConnectorService {
connector: SslConnector,
}
impl Clone for OpensslConnectorService {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<T, U> Service<Connection<T, U>> for OpensslConnectorService
where
T: Address,
U: ActixStream,
{
type Response = Connection<T, SslStream<U>>;
type Error = io::Error;
type Future = ConnectAsyncExt<T, U>;
actix_service::always_ready!();
fn call(&self, stream: Connection<T, U>) -> Self::Future {
trace!("SSL Handshake start for: {:?}", stream.host());
let (io, stream) = stream.replace_io(());
let host = stream.host();
let config = self
.connector
.configure()
.expect("SSL connect configuration was invalid.");
let ssl = config
.into_ssl(host)
.expect("SSL connect configuration was invalid.");
ConnectAsyncExt {
io: Some(SslStream::new(ssl, io).unwrap()),
stream: Some(stream),
}
}
}
pub struct ConnectAsyncExt<T, U> {
io: Option<SslStream<U>>,
stream: Option<Connection<T, ()>>,
}
impl<T: Address, U> Future for ConnectAsyncExt<T, U>
where
T: Address,
U: ActixStream,
{
type Output = Result<Connection<T, SslStream<U>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match ready!(Pin::new(this.io.as_mut().unwrap()).poll_connect(cx)) {
Ok(_) => {
let stream = this.stream.take().unwrap();
trace!("SSL Handshake success: {:?}", stream.host());
Poll::Ready(Ok(stream.replace_io(this.io.take().unwrap()).1))
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
}
}
}
}

View File

@ -1,146 +0,0 @@
use std::{
convert::TryFrom,
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
pub use webpki_roots::TLS_SERVER_ROOTS;
use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready};
use log::trace;
use tokio_rustls::rustls::{client::ServerName, OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{Connect, TlsConnector};
use crate::connect::{Address, Connection};
/// Returns standard root certificates from `webpki-roots` crate as a rustls certificate store.
pub fn webpki_roots_cert_store() -> RootCertStore {
let mut root_certs = RootCertStore::empty();
for cert in TLS_SERVER_ROOTS.0 {
let cert = OwnedTrustAnchor::from_subject_spki_name_constraints(
cert.subject,
cert.spki,
cert.name_constraints,
);
let certs = vec![cert].into_iter();
root_certs.add_server_trust_anchors(certs);
}
root_certs
}
/// Rustls connector factory
pub struct RustlsConnector {
connector: Arc<ClientConfig>,
}
impl RustlsConnector {
pub fn new(connector: Arc<ClientConfig>) -> Self {
RustlsConnector { connector }
}
}
impl RustlsConnector {
pub fn service(connector: Arc<ClientConfig>) -> RustlsConnectorService {
RustlsConnectorService { connector }
}
}
impl Clone for RustlsConnector {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<T, U> ServiceFactory<Connection<T, U>> for RustlsConnector
where
T: Address,
U: ActixStream + 'static,
{
type Response = Connection<T, TlsStream<U>>;
type Error = io::Error;
type Config = ();
type Service = RustlsConnectorService;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
let connector = self.connector.clone();
Box::pin(async { Ok(RustlsConnectorService { connector }) })
}
}
pub struct RustlsConnectorService {
connector: Arc<ClientConfig>,
}
impl Clone for RustlsConnectorService {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<T, U> Service<Connection<T, U>> for RustlsConnectorService
where
T: Address,
U: ActixStream,
{
type Response = Connection<T, TlsStream<U>>;
type Error = io::Error;
type Future = RustlsConnectorServiceFuture<T, U>;
actix_service::always_ready!();
fn call(&self, connection: Connection<T, U>) -> Self::Future {
trace!("SSL Handshake start for: {:?}", connection.host());
let (stream, connection) = connection.replace_io(());
match ServerName::try_from(connection.host()) {
Ok(host) => RustlsConnectorServiceFuture::Future {
connect: TlsConnector::from(self.connector.clone()).connect(host, stream),
connection: Some(connection),
},
Err(_) => RustlsConnectorServiceFuture::InvalidDns,
}
}
}
pub enum RustlsConnectorServiceFuture<T, U> {
/// See issue <https://github.com/briansmith/webpki/issues/54>
InvalidDns,
Future {
connect: Connect<U>,
connection: Option<Connection<T, ()>>,
},
}
impl<T, U> Future for RustlsConnectorServiceFuture<T, U>
where
T: Address,
U: ActixStream,
{
type Output = Result<Connection<T, TlsStream<U>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Self::InvalidDns => Poll::Ready(Err(
io::Error::new(io::ErrorKind::Other, "rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54")
)),
Self::Future { connect, connection } => {
let stream = ready!(Pin::new(connect).poll(cx))?;
let connection = connection.take().unwrap();
trace!("SSL Handshake success: {:?}", connection.host());
Poll::Ready(Ok(connection.replace_io(stream).1))
}
}
}
}

View File

@ -1,8 +1,8 @@
use http::Uri; use http::Uri;
use super::Address; use super::Host;
impl Address for Uri { impl Host for Uri {
fn hostname(&self) -> &str { fn hostname(&self) -> &str {
self.host().unwrap_or("") self.host().unwrap_or("")
} }
@ -35,9 +35,18 @@ fn scheme_to_port(scheme: Option<&str>) -> Option<u16> {
Some("mqtts") => Some(8883), Some("mqtts") => Some(8883),
// File Transfer Protocol (FTP) // File Transfer Protocol (FTP)
Some("ftp") => Some(1883), Some("ftp") => Some(21),
Some("ftps") => Some(990), Some("ftps") => Some(990),
// Redis
Some("redis") => Some(6379),
// MySQL
Some("mysql") => Some(3306),
// PostgreSQL
Some("postgres") => Some(5432),
_ => None, _ => None,
} }
} }

View File

@ -1,14 +1,19 @@
//! TLS acceptor and connector services for Actix ecosystem //! TLS acceptor and connector services for the Actix ecosystem.
#![deny(rust_2018_idioms, nonstandard_style)] #![deny(rust_2018_idioms, nonstandard_style)]
#![warn(missing_docs)]
#![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_logo_url = "https://actix.rs/img/logo.png")]
#![doc(html_favicon_url = "https://actix.rs/favicon.ico")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
#[allow(unused_extern_crates)] #[allow(unused_extern_crates)]
extern crate tls_openssl as openssl; extern crate tls_openssl as openssl;
#[cfg(feature = "accept")] #[cfg(feature = "accept")]
#[cfg_attr(docsrs, doc(cfg(feature = "accept")))]
pub mod accept; pub mod accept;
#[cfg(feature = "connect")] #[cfg(feature = "connect")]
#[cfg_attr(docsrs, doc(cfg(feature = "connect")))]
pub mod connect; pub mod connect;

View File

@ -7,13 +7,15 @@
feature = "openssl" feature = "openssl"
))] ))]
extern crate tls_openssl as openssl;
use std::io::{BufReader, Write}; use std::io::{BufReader, Write};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_server::TestServer; use actix_server::TestServer;
use actix_service::ServiceFactoryExt as _; use actix_service::ServiceFactoryExt as _;
use actix_tls::accept::rustls::{Acceptor, TlsStream}; use actix_tls::accept::rustls::{Acceptor, TlsStream};
use actix_tls::connect::tls::openssl::SslConnector; use actix_tls::connect::openssl::reexports::SslConnector;
use actix_utils::future::ok; use actix_utils::future::ok;
use rustls_pemfile::{certs, pkcs8_private_keys}; use rustls_pemfile::{certs, pkcs8_private_keys};
use tls_openssl::ssl::SslVerifyMode; use tls_openssl::ssl::SslVerifyMode;
@ -53,13 +55,13 @@ fn rustls_server_config(cert: String, key: String) -> rustls::ServerConfig {
} }
fn openssl_connector(cert: String, key: String) -> SslConnector { fn openssl_connector(cert: String, key: String) -> SslConnector {
use actix_tls::connect::tls::openssl::{SslConnector as OpensslConnector, SslMethod}; use actix_tls::connect::openssl::reexports::SslMethod;
use tls_openssl::{pkey::PKey, x509::X509}; use openssl::{pkey::PKey, x509::X509};
let cert = X509::from_pem(cert.as_bytes()).unwrap(); let cert = X509::from_pem(cert.as_bytes()).unwrap();
let key = PKey::private_key_from_pem(key.as_bytes()).unwrap(); let key = PKey::private_key_from_pem(key.as_bytes()).unwrap();
let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();
ssl.set_verify(SslVerifyMode::NONE); ssl.set_verify(SslVerifyMode::NONE);
ssl.set_certificate(&cert).unwrap(); ssl.set_certificate(&cert).unwrap();
ssl.set_private_key(&key).unwrap(); ssl.set_private_key(&key).unwrap();

63
actix-tls/tests/test_connect.rs Executable file → Normal file
View File

@ -12,7 +12,7 @@ use actix_service::{fn_service, Service, ServiceFactory};
use bytes::Bytes; use bytes::Bytes;
use futures_util::sink::SinkExt; use futures_util::sink::SinkExt;
use actix_tls::connect::{self as actix_connect, Connect}; use actix_tls::connect::{ConnectError, ConnectInfo, Connection, Connector, Host};
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
#[actix_rt::test] #[actix_rt::test]
@ -25,9 +25,9 @@ async fn test_string() {
}) })
}); });
let conn = actix_connect::default_connector(); let connector = Connector::default().service();
let addr = format!("localhost:{}", srv.port()); let addr = format!("localhost:{}", srv.port());
let con = conn.call(addr.into()).await.unwrap(); let con = connector.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());
} }
@ -42,7 +42,7 @@ async fn test_rustls_string() {
}) })
}); });
let conn = actix_connect::default_connector(); let conn = Connector::default().service();
let addr = format!("localhost:{}", srv.port()); let addr = format!("localhost:{}", srv.port());
let con = conn.call(addr.into()).await.unwrap(); let con = conn.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());
@ -58,23 +58,29 @@ async fn test_static_str() {
}) })
}); });
let conn = actix_connect::default_connector(); let info = ConnectInfo::with_addr("10", srv.addr());
let connector = Connector::default().service();
let conn = connector.call(info).await.unwrap();
assert_eq!(conn.peer_addr().unwrap(), srv.addr());
let con = conn let info = ConnectInfo::new(srv.host().to_owned());
.call(Connect::with_addr("10", srv.addr())) let connector = Connector::default().service();
.await let conn = connector.call(info).await;
.unwrap(); assert!(conn.is_err());
assert_eq!(con.peer_addr().unwrap(), srv.addr());
let connect = Connect::new(srv.host().to_owned());
let conn = actix_connect::default_connector();
let con = conn.call(connect).await;
assert!(con.is_err());
} }
#[actix_rt::test] #[actix_rt::test]
async fn test_new_service() { async fn service_factory() {
pub fn default_connector_factory<T: Host + 'static>() -> impl ServiceFactory<
ConnectInfo<T>,
Config = (),
Response = Connection<T, TcpStream>,
Error = ConnectError,
InitError = (),
> {
Connector::default()
}
let srv = TestServer::with(|| { let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async { fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec); let mut framed = Framed::new(io, BytesCodec);
@ -83,14 +89,11 @@ async fn test_new_service() {
}) })
}); });
let factory = actix_connect::default_connector_factory(); let info = ConnectInfo::with_addr("10", srv.addr());
let factory = default_connector_factory();
let conn = factory.new_service(()).await.unwrap(); let connector = factory.new_service(()).await.unwrap();
let con = conn let con = connector.call(info).await;
.call(Connect::with_addr("10", srv.addr())) assert_eq!(con.unwrap().peer_addr().unwrap(), srv.addr());
.await
.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
} }
#[cfg(all(feature = "openssl", feature = "uri"))] #[cfg(all(feature = "openssl", feature = "uri"))]
@ -106,9 +109,9 @@ async fn test_openssl_uri() {
}) })
}); });
let conn = actix_connect::default_connector(); let connector = Connector::default().service();
let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap(); let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap();
let con = conn.call(addr.into()).await.unwrap(); let con = connector.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());
} }
@ -125,7 +128,7 @@ async fn test_rustls_uri() {
}) })
}); });
let conn = actix_connect::default_connector(); let conn = Connector::default().service();
let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap(); let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap();
let con = conn.call(addr.into()).await.unwrap(); let con = conn.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());
@ -141,11 +144,11 @@ async fn test_local_addr() {
}) })
}); });
let conn = actix_connect::default_connector(); let conn = Connector::default().service();
let local = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3)); let local = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3));
let (con, _) = conn let (con, _) = conn
.call(Connect::with_addr("10", srv.addr()).set_local_addr(local)) .call(ConnectInfo::with_addr("10", srv.addr()).set_local_addr(local))
.await .await
.unwrap() .unwrap()
.into_parts(); .into_parts();

View File

@ -10,7 +10,9 @@ use actix_server::TestServer;
use actix_service::{fn_service, Service, ServiceFactory}; use actix_service::{fn_service, Service, ServiceFactory};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use actix_tls::connect::{new_connector_factory, Connect, Resolve, Resolver}; use actix_tls::connect::{
ConnectError, ConnectInfo, Connection, Connector, Host, Resolve, Resolver,
};
#[actix_rt::test] #[actix_rt::test]
async fn custom_resolver() { async fn custom_resolver() {
@ -36,6 +38,18 @@ async fn custom_resolver() {
#[actix_rt::test] #[actix_rt::test]
async fn custom_resolver_connect() { async fn custom_resolver_connect() {
pub fn connector_factory<T: Host + 'static>(
resolver: Resolver,
) -> impl ServiceFactory<
ConnectInfo<T>,
Config = (),
Response = Connection<T, TcpStream>,
Error = ConnectError,
InitError = (),
> {
Connector::new(resolver)
}
use trust_dns_resolver::TokioAsyncResolver; use trust_dns_resolver::TokioAsyncResolver;
let srv = let srv =
@ -68,12 +82,11 @@ async fn custom_resolver_connect() {
trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(),
}; };
let resolver = Resolver::new_custom(resolver); let factory = connector_factory(Resolver::custom(resolver));
let factory = new_connector_factory(resolver);
let conn = factory.new_service(()).await.unwrap(); let conn = factory.new_service(()).await.unwrap();
let con = conn let con = conn
.call(Connect::with_addr("example.com", srv.addr())) .call(ConnectInfo::with_addr("example.com", srv.addr()))
.await .await
.unwrap(); .unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());