1
0
mirror of https://github.com/fafhrd91/actix-net synced 2024-11-27 20:12:58 +01:00

fix actix-tls tests (#241)

This commit is contained in:
fakeshadow 2020-12-29 19:36:17 +08:00 committed by GitHub
parent 0934078947
commit 03eb96d6d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 291 additions and 207 deletions

View File

@ -27,3 +27,7 @@ actix-tracing = { path = "actix-tracing" }
actix-utils = { path = "actix-utils" } actix-utils = { path = "actix-utils" }
actix-router = { path = "router" } actix-router = { path = "router" }
bytestring = { path = "string" } bytestring = { path = "string" }
# FIXME: remove override
trust-dns-proto = { git = "https://github.com/bluejekyll/trust-dns.git", branch = "main" }
trust-dns-resolver = { git = "https://github.com/bluejekyll/trust-dns.git", branch = "main" }

View File

@ -19,7 +19,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
derive_more = "0.99.2" derive_more = "0.99.2"
futures-channel = "0.3.1" futures-channel = "0.3.7"
parking_lot = "0.11" parking_lot = "0.11"
lazy_static = "1.3" lazy_static = "1.3"
log = "0.4" log = "0.4"

View File

@ -12,7 +12,7 @@ license = "MIT OR Apache-2.0"
edition = "2018" edition = "2018"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["openssl", "rustls", "native-tls", "accept", "connect", "http"] features = ["openssl", "rustls", "native-tls", "accept", "connect", "uri"]
[lib] [lib]
name = "actix_tls" name = "actix_tls"
@ -23,13 +23,13 @@ name = "basic"
required-features = ["accept", "rustls"] required-features = ["accept", "rustls"]
[features] [features]
default = ["accept", "connect", "http"] default = ["accept", "connect", "uri"]
# enable acceptor services # enable acceptor services
accept = [] accept = []
# enable connector services # enable connector services
connect = [] connect = ["trust-dns-proto/tokio-runtime", "trust-dns-resolver/tokio-runtime", "trust-dns-resolver/system-config"]
# use openssl impls # use openssl impls
openssl = ["tls-openssl", "tokio-openssl"] openssl = ["tls-openssl", "tokio-openssl"]
@ -40,6 +40,9 @@ rustls = ["tls-rustls", "webpki", "webpki-roots", "tokio-rustls"]
# use native-tls impls # use native-tls impls
native-tls = ["tls-native-tls", "tokio-native-tls"] native-tls = ["tls-native-tls", "tokio-native-tls"]
# support http::Uri as connect address
uri = ["http"]
[dependencies] [dependencies]
actix-codec = "0.4.0-beta.1" actix-codec = "0.4.0-beta.1"
actix-rt = "2.0.0-beta.1" actix-rt = "2.0.0-beta.1"
@ -49,10 +52,12 @@ actix-utils = "3.0.0-beta.1"
derive_more = "0.99.5" derive_more = "0.99.5"
either = "1.6" either = "1.6"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
http = { version = "0.2.0", optional = true } http = { version = "0.2.2", optional = true }
log = "0.4" log = "0.4"
trust-dns-proto = { version = "0.19", default-features = false, features = ["tokio-runtime"] }
trust-dns-resolver = { version = "0.19", default-features = false, features = ["tokio-runtime", "system-config"] } # resolver
trust-dns-proto = { version = "0.20.0-alpha.3", default-features = false, optional = true }
trust-dns-resolver = { version = "0.20.0-alpha.3", default-features = false, optional = true }
# openssl # openssl
tls-openssl = { package = "openssl", version = "0.10", optional = true } tls-openssl = { package = "openssl", version = "0.10", optional = true }
@ -74,5 +79,6 @@ tokio-native-tls = { version = "0.3", optional = true }
actix-server = "2.0.0-beta.1" actix-server = "2.0.0-beta.1"
actix-testing = "2.0.0-beta.1" actix-testing = "2.0.0-beta.1"
bytes = "1" bytes = "1"
log = "0.4"
env_logger = "0.8" env_logger = "0.8"
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
log = "0.4"

View File

@ -1,10 +1,9 @@
use std::marker::PhantomData;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::counter::Counter; use actix_utils::counter::Counter;
use futures_util::future::{self, FutureExt, LocalBoxFuture, TryFutureExt}; use futures_util::future::{ready, LocalBoxFuture, Ready};
pub use native_tls::Error; pub use native_tls::Error;
pub use tokio_native_tls::{TlsAcceptor, TlsStream}; pub use tokio_native_tls::{TlsAcceptor, TlsStream};
@ -14,75 +13,64 @@ use super::MAX_CONN_COUNTER;
/// Accept TLS connections via `native-tls` package. /// Accept TLS connections via `native-tls` package.
/// ///
/// `native-tls` feature enables this `Acceptor` type. /// `native-tls` feature enables this `Acceptor` type.
pub struct Acceptor<T> { pub struct Acceptor {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
io: PhantomData<T>,
} }
impl<T> Acceptor<T> impl Acceptor {
where
T: AsyncRead + AsyncWrite + Unpin,
{
/// Create `native-tls` based `Acceptor` service factory. /// Create `native-tls` based `Acceptor` service factory.
#[inline] #[inline]
pub fn new(acceptor: TlsAcceptor) -> Self { pub fn new(acceptor: TlsAcceptor) -> Self {
Acceptor { Acceptor { acceptor }
acceptor,
io: PhantomData,
}
} }
} }
impl<T> Clone for Acceptor<T> { impl Clone for Acceptor {
#[inline] #[inline]
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
io: PhantomData,
} }
} }
} }
impl<T> ServiceFactory<T> for Acceptor<T> impl<T> ServiceFactory<T> for Acceptor
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Response = TlsStream<T>; type Response = TlsStream<T>;
type Error = Error; type Error = Error;
type Service = NativeTlsAcceptorService<T>;
type Config = (); type Config = ();
type Service = NativeTlsAcceptorService;
type InitError = (); type InitError = ();
type Future = future::Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
MAX_CONN_COUNTER.with(|conns| { MAX_CONN_COUNTER.with(|conns| {
future::ok(NativeTlsAcceptorService { ready(Ok(NativeTlsAcceptorService {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
conns: conns.clone(), conns: conns.clone(),
io: PhantomData, }))
})
}) })
} }
} }
pub struct NativeTlsAcceptorService<T> { pub struct NativeTlsAcceptorService {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
io: PhantomData<T>,
conns: Counter, conns: Counter,
} }
impl<T> Clone for NativeTlsAcceptorService<T> { impl Clone for NativeTlsAcceptorService {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
io: PhantomData,
conns: self.conns.clone(), conns: self.conns.clone(),
} }
} }
} }
impl<T> Service<T> for NativeTlsAcceptorService<T> impl<T> Service<T> for NativeTlsAcceptorService
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
@ -101,12 +89,10 @@ where
fn call(&mut self, io: T) -> Self::Future { fn call(&mut self, io: T) -> Self::Future {
let guard = self.conns.get(); let guard = self.conns.get();
let this = self.clone(); let this = self.clone();
async move { this.acceptor.accept(io).await } Box::pin(async move {
.map_ok(move |io| { let io = this.acceptor.accept(io).await;
// Required to preserve `CounterGuard` until `Self::Future` is completely resolved. drop(guard);
let _ = guard; io
io })
})
.boxed_local()
} }
} }

View File

@ -1,5 +1,4 @@
use std::future::Future; use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -7,7 +6,7 @@ use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::counter::{Counter, CounterGuard}; use actix_utils::counter::{Counter, CounterGuard};
use futures_util::{ use futures_util::{
future::{ok, Ready}, future::{ready, Ready},
ready, ready,
}; };
@ -21,61 +20,54 @@ use super::MAX_CONN_COUNTER;
/// Accept TLS connections via `openssl` package. /// Accept TLS connections via `openssl` package.
/// ///
/// `openssl` feature enables this `Acceptor` type. /// `openssl` feature enables this `Acceptor` type.
pub struct Acceptor<T: AsyncRead + AsyncWrite> { pub struct Acceptor {
acceptor: SslAcceptor, acceptor: SslAcceptor,
io: PhantomData<T>,
} }
impl<T: AsyncRead + AsyncWrite> Acceptor<T> { impl Acceptor {
/// Create OpenSSL based `Acceptor` service factory. /// Create OpenSSL based `Acceptor` service factory.
#[inline] #[inline]
pub fn new(acceptor: SslAcceptor) -> Self { pub fn new(acceptor: SslAcceptor) -> Self {
Acceptor { Acceptor { acceptor }
acceptor,
io: PhantomData,
}
} }
} }
impl<T: AsyncRead + AsyncWrite> Clone for Acceptor<T> { impl Clone for Acceptor {
#[inline] #[inline]
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
io: PhantomData,
} }
} }
} }
impl<T> ServiceFactory<T> for Acceptor<T> impl<T> ServiceFactory<T> for Acceptor
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Response = SslStream<T>; type Response = SslStream<T>;
type Error = SslError; type Error = SslError;
type Config = (); type Config = ();
type Service = AcceptorService<T>; type Service = AcceptorService;
type InitError = (); type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
MAX_CONN_COUNTER.with(|conns| { MAX_CONN_COUNTER.with(|conns| {
ok(AcceptorService { ready(Ok(AcceptorService {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
conns: conns.clone(), conns: conns.clone(),
io: PhantomData, }))
})
}) })
} }
} }
pub struct AcceptorService<T> { pub struct AcceptorService {
acceptor: SslAcceptor, acceptor: SslAcceptor,
conns: Counter, conns: Counter,
io: PhantomData<T>,
} }
impl<T> Service<T> for AcceptorService<T> impl<T> Service<T> for AcceptorService
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {

View File

@ -1,6 +1,5 @@
use std::future::Future; use std::future::Future;
use std::io; use std::io;
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -8,7 +7,7 @@ use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::counter::{Counter, CounterGuard}; use actix_utils::counter::{Counter, CounterGuard};
use futures_util::future::{ok, Ready}; use futures_util::future::{ready, Ready};
use tokio_rustls::{Accept, TlsAcceptor}; use tokio_rustls::{Accept, TlsAcceptor};
pub use rustls::{ServerConfig, Session}; pub use rustls::{ServerConfig, Session};
@ -20,66 +19,58 @@ use super::MAX_CONN_COUNTER;
/// Accept TLS connections via `rustls` package. /// Accept TLS connections via `rustls` package.
/// ///
/// `rustls` feature enables this `Acceptor` type. /// `rustls` feature enables this `Acceptor` type.
pub struct Acceptor<T> { pub struct Acceptor {
config: Arc<ServerConfig>, config: Arc<ServerConfig>,
io: PhantomData<T>,
} }
impl<T> Acceptor<T> impl Acceptor {
where
T: AsyncRead + AsyncWrite,
{
/// Create Rustls based `Acceptor` service factory. /// Create Rustls based `Acceptor` service factory.
#[inline] #[inline]
pub fn new(config: ServerConfig) -> Self { pub fn new(config: ServerConfig) -> Self {
Acceptor { Acceptor {
config: Arc::new(config), config: Arc::new(config),
io: PhantomData,
} }
} }
} }
impl<T> Clone for Acceptor<T> { impl Clone for Acceptor {
#[inline] #[inline]
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
config: self.config.clone(), config: self.config.clone(),
io: PhantomData,
} }
} }
} }
impl<T> ServiceFactory<T> for Acceptor<T> impl<T> ServiceFactory<T> for Acceptor
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
type Response = TlsStream<T>; type Response = TlsStream<T>;
type Error = io::Error; type Error = io::Error;
type Service = AcceptorService<T>;
type Config = (); type Config = ();
type Service = AcceptorService;
type InitError = (); type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
MAX_CONN_COUNTER.with(|conns| { MAX_CONN_COUNTER.with(|conns| {
ok(AcceptorService { ready(Ok(AcceptorService {
acceptor: self.config.clone().into(), acceptor: self.config.clone().into(),
conns: conns.clone(), conns: conns.clone(),
io: PhantomData, }))
})
}) })
} }
} }
/// Rustls based `Acceptor` service /// Rustls based `Acceptor` service
pub struct AcceptorService<T> { pub struct AcceptorService {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
io: PhantomData<T>,
conns: Counter, conns: Counter,
} }
impl<T> Service<T> for AcceptorService<T> impl<T> Service<T> for AcceptorService
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
@ -119,11 +110,6 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
Pin::new(&mut this.fut).poll(cx)
let res = futures_util::ready!(Pin::new(&mut this.fut).poll(cx));
match res {
Ok(io) => Poll::Ready(Ok(io)),
Err(e) => Poll::Ready(Err(e)),
}
} }
} }

View File

@ -8,7 +8,7 @@ use std::task::{Context, Poll};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{err, ok, BoxFuture, Either, FutureExt, Ready}; use futures_util::future::{ready, Ready};
use log::{error, trace}; use log::{error, trace};
use super::connect::{Address, Connect, Connection}; use super::connect::{Address, Connect, Connection};
@ -50,7 +50,7 @@ impl<T: Address> ServiceFactory<Connect<T>> for TcpConnectorFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(self.service()) ready(Ok(self.service()))
} }
} }
@ -73,8 +73,7 @@ impl<T> Clone for TcpConnector<T> {
impl<T: Address> Service<Connect<T>> for TcpConnector<T> { impl<T: Address> Service<Connect<T>> for TcpConnector<T> {
type Response = Connection<T, TcpStream>; type Response = Connection<T, TcpStream>;
type Error = ConnectError; type Error = ConnectError;
#[allow(clippy::type_complexity)] type Future = TcpConnectorResponse<T>;
type Future = Either<TcpConnectorResponse<T>, Ready<Result<Self::Response, Self::Error>>>;
actix_service::always_ready!(); actix_service::always_ready!();
@ -83,21 +82,26 @@ impl<T: Address> Service<Connect<T>> for TcpConnector<T> {
let Connect { req, addr, .. } = req; let Connect { req, addr, .. } = req;
if let Some(addr) = addr { if let Some(addr) = addr {
Either::Left(TcpConnectorResponse::new(req, port, addr)) TcpConnectorResponse::new(req, port, addr)
} else { } else {
error!("TCP connector: got unresolved address"); error!("TCP connector: got unresolved address");
Either::Right(err(ConnectError::Unresolved)) TcpConnectorResponse::Error(Some(ConnectError::Unresolved))
} }
} }
} }
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
#[doc(hidden)] #[doc(hidden)]
/// TCP stream connector response future /// TCP stream connector response future
pub struct TcpConnectorResponse<T> { pub enum TcpConnectorResponse<T> {
req: Option<T>, Response {
port: u16, req: Option<T>,
addrs: Option<VecDeque<SocketAddr>>, port: u16,
stream: Option<BoxFuture<'static, Result<TcpStream, io::Error>>>, addrs: Option<VecDeque<SocketAddr>>,
stream: Option<LocalBoxFuture<'static, Result<TcpStream, io::Error>>>,
},
Error(Option<ConnectError>),
} }
impl<T: Address> TcpConnectorResponse<T> { impl<T: Address> TcpConnectorResponse<T> {
@ -113,13 +117,13 @@ impl<T: Address> TcpConnectorResponse<T> {
); );
match addr { match addr {
either::Either::Left(addr) => TcpConnectorResponse { either::Either::Left(addr) => TcpConnectorResponse::Response {
req: Some(req), req: Some(req),
port, port,
addrs: None, addrs: None,
stream: Some(TcpStream::connect(addr).boxed()), stream: Some(Box::pin(TcpStream::connect(addr))),
}, },
either::Either::Right(addrs) => TcpConnectorResponse { either::Either::Right(addrs) => TcpConnectorResponse::Response {
req: Some(req), req: Some(req),
port, port,
addrs: Some(addrs), addrs: Some(addrs),
@ -134,36 +138,43 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
match this {
// connect TcpConnectorResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
loop { // connect
if let Some(new) = this.stream.as_mut() { TcpConnectorResponse::Response {
match new.as_mut().poll(cx) { req,
Poll::Ready(Ok(sock)) => { port,
let req = this.req.take().unwrap(); addrs,
trace!( stream,
"TCP connector - successfully connected to connecting to {:?} - {:?}", } => loop {
req.host(), sock.peer_addr() if let Some(new) = stream.as_mut() {
); match new.as_mut().poll(cx) {
return Poll::Ready(Ok(Connection::new(sock, req))); Poll::Ready(Ok(sock)) => {
} let req = req.take().unwrap();
Poll::Pending => return Poll::Pending, trace!(
Poll::Ready(Err(err)) => { "TCP connector - successfully connected to connecting to {:?} - {:?}",
trace!( req.host(), sock.peer_addr()
"TCP connector - failed to connect to connecting to {:?} port: {}", );
this.req.as_ref().unwrap().host(), return Poll::Ready(Ok(Connection::new(sock, req)));
this.port, }
); Poll::Pending => return Poll::Pending,
if this.addrs.is_none() || this.addrs.as_ref().unwrap().is_empty() { Poll::Ready(Err(err)) => {
return Poll::Ready(Err(err.into())); trace!(
"TCP connector - failed to connect to connecting to {:?} port: {}",
req.as_ref().unwrap().host(),
port,
);
if addrs.is_none() || addrs.as_ref().unwrap().is_empty() {
return Poll::Ready(Err(err.into()));
}
} }
} }
} }
}
// try to connect // try to connect
let addr = this.addrs.as_mut().unwrap().pop_front().unwrap(); let addr = addrs.as_mut().unwrap().pop_front().unwrap();
this.stream = Some(TcpStream::connect(addr).boxed()); *stream = Some(Box::pin(TcpStream::connect(addr)));
},
} }
} }
} }

View File

@ -11,6 +11,7 @@ mod error;
mod resolve; mod resolve;
mod service; mod service;
pub mod ssl; pub mod ssl;
#[cfg(feature = "uri")]
mod uri; mod uri;
use actix_rt::{net::TcpStream, Arbiter}; use actix_rt::{net::TcpStream, Arbiter};
@ -35,7 +36,7 @@ pub async fn start_resolver(
cfg: ResolverConfig, cfg: ResolverConfig,
opts: ResolverOpts, opts: ResolverOpts,
) -> Result<AsyncResolver, ConnectError> { ) -> Result<AsyncResolver, ConnectError> {
Ok(AsyncResolver::tokio(cfg, opts).await?) Ok(AsyncResolver::tokio(cfg, opts)?)
} }
struct DefaultResolver(AsyncResolver); struct DefaultResolver(AsyncResolver);
@ -52,7 +53,7 @@ pub(crate) async fn get_default_resolver() -> Result<AsyncResolver, ConnectError
} }
}; };
let resolver = AsyncResolver::tokio(cfg, opts).await?; let resolver = AsyncResolver::tokio(cfg, opts)?;
Arbiter::set_item(DefaultResolver(resolver.clone())); Arbiter::set_item(DefaultResolver(resolver.clone()));
Ok(resolver) Ok(resolver)

View File

@ -8,7 +8,7 @@ use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::{ use futures_util::{
future::{err, ok, Either, Ready}, future::{ready, Either, Ready},
ready, ready,
}; };
use log::trace; use log::trace;
@ -21,43 +21,31 @@ use crate::connect::{
}; };
/// OpenSSL connector factory /// OpenSSL connector factory
pub struct OpensslConnector<T, U> { pub struct OpensslConnector {
connector: SslConnector, connector: SslConnector,
_t: PhantomData<(T, U)>,
} }
impl<T, U> OpensslConnector<T, U> { impl OpensslConnector {
pub fn new(connector: SslConnector) -> Self { pub fn new(connector: SslConnector) -> Self {
OpensslConnector { OpensslConnector { connector }
connector,
_t: PhantomData,
}
} }
} }
impl<T, U> OpensslConnector<T, U> impl OpensslConnector {
where pub fn service(connector: SslConnector) -> OpensslConnectorService {
T: Address + 'static, OpensslConnectorService { connector }
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
pub fn service(connector: SslConnector) -> OpensslConnectorService<T, U> {
OpensslConnectorService {
connector,
_t: PhantomData,
}
} }
} }
impl<T, U> Clone for OpensslConnector<T, U> { impl Clone for OpensslConnector {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData,
} }
} }
} }
impl<T, U> ServiceFactory<Connection<T, U>> for OpensslConnector<T, U> impl<T, U> ServiceFactory<Connection<T, U>> for OpensslConnector
where where
T: Address + 'static, T: Address + 'static,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
@ -65,33 +53,30 @@ where
type Response = Connection<T, SslStream<U>>; type Response = Connection<T, SslStream<U>>;
type Error = io::Error; type Error = io::Error;
type Config = (); type Config = ();
type Service = OpensslConnectorService<T, U>; type Service = OpensslConnectorService;
type InitError = (); type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(OpensslConnectorService { ready(Ok(OpensslConnectorService {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData, }))
})
} }
} }
pub struct OpensslConnectorService<T, U> { pub struct OpensslConnectorService {
connector: SslConnector, connector: SslConnector,
_t: PhantomData<(T, U)>,
} }
impl<T, U> Clone for OpensslConnectorService<T, U> { impl Clone for OpensslConnectorService {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData,
} }
} }
} }
impl<T, U> Service<Connection<T, U>> for OpensslConnectorService<T, U> impl<T, U> Service<Connection<T, U>> for OpensslConnectorService
where where
T: Address + 'static, T: Address + 'static,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
@ -109,7 +94,7 @@ where
let host = stream.host().to_string(); let host = stream.host().to_string();
match self.connector.configure() { match self.connector.configure() {
Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))), Err(e) => Either::Right(ready(Err(io::Error::new(io::ErrorKind::Other, e)))),
Ok(config) => { Ok(config) => {
let ssl = config let ssl = config
.into_ssl(&host) .into_ssl(&host)
@ -156,7 +141,7 @@ where
pub struct OpensslConnectServiceFactory<T> { pub struct OpensslConnectServiceFactory<T> {
tcp: ConnectServiceFactory<T>, tcp: ConnectServiceFactory<T>,
openssl: OpensslConnector<T, TcpStream>, openssl: OpensslConnector,
} }
impl<T> OpensslConnectServiceFactory<T> { impl<T> OpensslConnectServiceFactory<T> {
@ -182,7 +167,6 @@ impl<T> OpensslConnectServiceFactory<T> {
tcp: self.tcp.service(), tcp: self.tcp.service(),
openssl: OpensslConnectorService { openssl: OpensslConnectorService {
connector: self.openssl.connector.clone(), connector: self.openssl.connector.clone(),
_t: PhantomData,
}, },
} }
} }
@ -206,14 +190,14 @@ impl<T: Address + 'static> ServiceFactory<Connect<T>> for OpensslConnectServiceF
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(self.service()) ready(Ok(self.service()))
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct OpensslConnectService<T> { pub struct OpensslConnectService<T> {
tcp: ConnectService<T>, tcp: ConnectService<T>,
openssl: OpensslConnectorService<T, TcpStream>, openssl: OpensslConnectorService,
} }
impl<T: Address + 'static> Service<Connect<T>> for OpensslConnectService<T> { impl<T: Address + 'static> Service<Connect<T>> for OpensslConnectService<T> {
@ -234,10 +218,8 @@ impl<T: Address + 'static> Service<Connect<T>> for OpensslConnectService<T> {
pub struct OpensslConnectServiceResponse<T: Address + 'static> { pub struct OpensslConnectServiceResponse<T: Address + 'static> {
fut1: Option<<ConnectService<T> as Service<Connect<T>>>::Future>, fut1: Option<<ConnectService<T> as Service<Connect<T>>>::Future>,
fut2: Option< fut2: Option<<OpensslConnectorService as Service<Connection<T, TcpStream>>>::Future>,
<OpensslConnectorService<T, TcpStream> as Service<Connection<T, TcpStream>>>::Future, openssl: OpensslConnectorService,
>,
openssl: OpensslConnectorService<T, TcpStream>,
} }
impl<T: Address> Future for OpensslConnectServiceResponse<T> { impl<T: Address> Future for OpensslConnectServiceResponse<T> {

View File

@ -1,6 +1,5 @@
use std::fmt; use std::fmt;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -10,7 +9,10 @@ pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{ok, Ready}; use futures_util::{
future::{ready, Ready},
ready,
};
use log::trace; use log::trace;
use tokio_rustls::{Connect, TlsConnector}; use tokio_rustls::{Connect, TlsConnector};
use webpki::DNSNameRef; use webpki::DNSNameRef;
@ -18,77 +20,63 @@ use webpki::DNSNameRef;
use crate::connect::{Address, Connection}; use crate::connect::{Address, Connection};
/// Rustls connector factory /// Rustls connector factory
pub struct RustlsConnector<T, U> { pub struct RustlsConnector {
connector: Arc<ClientConfig>, connector: Arc<ClientConfig>,
_t: PhantomData<(T, U)>,
} }
impl<T, U> RustlsConnector<T, U> { impl RustlsConnector {
pub fn new(connector: Arc<ClientConfig>) -> Self { pub fn new(connector: Arc<ClientConfig>) -> Self {
RustlsConnector { RustlsConnector { connector }
connector,
_t: PhantomData,
}
} }
} }
impl<T, U> RustlsConnector<T, U> impl RustlsConnector {
where pub fn service(connector: Arc<ClientConfig>) -> RustlsConnectorService {
T: Address, RustlsConnectorService { connector }
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
{
pub fn service(connector: Arc<ClientConfig>) -> RustlsConnectorService<T, U> {
RustlsConnectorService {
connector,
_t: PhantomData,
}
} }
} }
impl<T, U> Clone for RustlsConnector<T, U> { impl Clone for RustlsConnector {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData,
} }
} }
} }
impl<T: Address, U> ServiceFactory<Connection<T, U>> for RustlsConnector<T, U> impl<T: Address, U> ServiceFactory<Connection<T, U>> for RustlsConnector
where where
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
{ {
type Response = Connection<T, TlsStream<U>>; type Response = Connection<T, TlsStream<U>>;
type Error = std::io::Error; type Error = std::io::Error;
type Config = (); type Config = ();
type Service = RustlsConnectorService<T, U>; type Service = RustlsConnectorService;
type InitError = (); type InitError = ();
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(RustlsConnectorService { ready(Ok(RustlsConnectorService {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData, }))
})
} }
} }
pub struct RustlsConnectorService<T, U> { pub struct RustlsConnectorService {
connector: Arc<ClientConfig>, connector: Arc<ClientConfig>,
_t: PhantomData<(T, U)>,
} }
impl<T, U> Clone for RustlsConnectorService<T, U> { impl Clone for RustlsConnectorService {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData,
} }
} }
} }
impl<T: Address, U> Service<Connection<T, U>> for RustlsConnectorService<T, U> impl<T, U> Service<Connection<T, U>> for RustlsConnectorService
where where
T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
{ {
type Response = Connection<T, TlsStream<U>>; type Response = Connection<T, TlsStream<U>>;
@ -114,20 +102,18 @@ pub struct ConnectAsyncExt<T, U> {
stream: Option<Connection<T, ()>>, stream: Option<Connection<T, ()>>,
} }
impl<T: Address, U> Future for ConnectAsyncExt<T, U> impl<T, U> Future for ConnectAsyncExt<T, U>
where where
T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
{ {
type Output = Result<Connection<T, TlsStream<U>>, std::io::Error>; type Output = Result<Connection<T, TlsStream<U>>, std::io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
Poll::Ready( let stream = ready!(Pin::new(&mut this.fut).poll(cx))?;
futures_util::ready!(Pin::new(&mut this.fut).poll(cx)).map(|stream| { let s = this.stream.take().unwrap();
let s = this.stream.take().unwrap(); trace!("SSL Handshake success: {:?}", s.host());
trace!("SSL Handshake success: {:?}", s.host()); Poll::Ready(Ok(s.replace(stream).1))
s.replace(stream).1
}),
)
} }
} }

View File

@ -0,0 +1,130 @@
use std::io;
use actix_codec::{BytesCodec, Framed};
use actix_rt::net::TcpStream;
use actix_service::{fn_service, Service, ServiceFactory};
use actix_testing::TestServer;
use bytes::Bytes;
use futures_util::sink::SinkExt;
use actix_tls::connect::{
self as actix_connect,
resolver::{ResolverConfig, ResolverOpts},
Connect,
};
#[cfg(all(feature = "connect", feature = "openssl"))]
#[actix_rt::test]
async fn test_string() {
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
let mut conn = actix_connect::default_connector();
let addr = format!("localhost:{}", srv.port());
let con = conn.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
}
#[cfg(feature = "rustls")]
#[actix_rt::test]
async fn test_rustls_string() {
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
let mut conn = actix_connect::default_connector();
let addr = format!("localhost:{}", srv.port());
let con = conn.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
}
#[actix_rt::test]
async fn test_static_str() {
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
let resolver = actix_connect::start_default_resolver().await.unwrap();
let mut conn = actix_connect::new_connector(resolver.clone());
let con = conn.call(Connect::with("10", srv.addr())).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
let connect = Connect::new(srv.host().to_owned());
let mut conn = actix_connect::new_connector(resolver);
let con = conn.call(connect).await;
assert!(con.is_err());
}
#[actix_rt::test]
async fn test_new_service() {
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
let resolver =
actix_connect::start_resolver(ResolverConfig::default(), ResolverOpts::default())
.await
.unwrap();
let factory = actix_connect::new_connector_factory(resolver);
let mut conn = factory.new_service(()).await.unwrap();
let con = conn.call(Connect::with("10", srv.addr())).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
}
#[cfg(all(feature = "openssl", feature = "uri"))]
#[actix_rt::test]
async fn test_openssl_uri() {
use std::convert::TryFrom;
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
let mut conn = actix_connect::default_connector();
let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap();
let con = conn.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
}
#[cfg(all(feature = "rustls", feature = "uri"))]
#[actix_rt::test]
async fn test_rustls_uri() {
use std::convert::TryFrom;
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
let mut conn = actix_connect::default_connector();
let addr = http::Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap();
let con = conn.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
}

View File

@ -15,7 +15,7 @@ name = "bytestring"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
bytes = "0.5.3" bytes = "1"
serde = { version = "1.0", optional = true } serde = { version = "1.0", optional = true }
[dev-dependencies] [dev-dependencies]