From 874e5f2e505d08c6b5df74b12dc0809e5b5d9a94 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Fri, 22 Jan 2021 17:33:50 -0800 Subject: [PATCH] change default name resolver and allow custom resolvers (#248) --- actix-rt/src/arbiter.rs | 55 +---- actix-rt/src/lib.rs | 2 +- actix-rt/tests/integration_tests.rs | 37 ---- actix-server/Cargo.toml | 4 +- actix-tls/CHANGES.md | 10 + actix-tls/Cargo.toml | 15 +- actix-tls/src/accept/nativetls.rs | 13 +- actix-tls/src/accept/openssl.rs | 24 +- actix-tls/src/accept/rustls.rs | 26 ++- actix-tls/src/connect/connect.rs | 119 +++++----- actix-tls/src/connect/connector.rs | 99 +++------ actix-tls/src/connect/error.rs | 7 +- actix-tls/src/connect/mod.rs | 68 ++---- actix-tls/src/connect/resolve.rs | 316 +++++++++++++++------------ actix-tls/src/connect/service.rs | 183 ++++++---------- actix-tls/src/connect/ssl/openssl.rs | 93 ++++---- actix-tls/src/connect/ssl/rustls.rs | 25 +-- actix-tls/tests/test_connect.rs | 64 +++++- actix-utils/Cargo.toml | 2 +- 19 files changed, 513 insertions(+), 649 deletions(-) mode change 100644 => 100755 actix-server/Cargo.toml mode change 100644 => 100755 actix-tls/Cargo.toml mode change 100644 => 100755 actix-tls/src/connect/connect.rs mode change 100644 => 100755 actix-tls/src/connect/connector.rs mode change 100644 => 100755 actix-tls/src/connect/resolve.rs mode change 100644 => 100755 actix-tls/src/connect/service.rs mode change 100644 => 100755 actix-tls/src/connect/ssl/openssl.rs mode change 100644 => 100755 actix-tls/src/connect/ssl/rustls.rs mode change 100644 => 100755 actix-tls/tests/test_connect.rs diff --git a/actix-rt/src/arbiter.rs b/actix-rt/src/arbiter.rs index 7aae7cd2..95b40b25 100644 --- a/actix-rt/src/arbiter.rs +++ b/actix-rt/src/arbiter.rs @@ -9,9 +9,6 @@ use std::{fmt, thread}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot::{channel, error::RecvError as Canceled, Sender}; -// use futures_util::stream::FuturesUnordered; -// use tokio::task::JoinHandle; -// use tokio::stream::StreamExt; use tokio::task::LocalSet; use crate::runtime::Runtime; @@ -19,12 +16,6 @@ use crate::system::System; thread_local!( static ADDR: RefCell> = RefCell::new(None); - // TODO: Commented out code are for Arbiter::local_join function. - // It can be safely removed if this function is not used in actix-*. - // - // /// stores join handle for spawned async tasks. - // static HANDLE: RefCell>> = - // RefCell::new(FuturesUnordered::new()); static STORAGE: RefCell>> = RefCell::new(HashMap::new()); ); @@ -154,11 +145,6 @@ impl Arbiter { where F: Future + 'static, { - // HANDLE.with(|handle| { - // let handle = handle.borrow(); - // handle.push(tokio::task::spawn_local(future)); - // }); - // let _ = tokio::task::spawn_local(CleanupPending); let _ = tokio::task::spawn_local(future); } @@ -277,32 +263,12 @@ impl Arbiter { /// Returns a future that will be completed once all currently spawned futures /// have completed. - #[deprecated(since = "1.2.0", note = "Arbiter::local_join function is removed.")] + #[deprecated(since = "2.0.0", note = "Arbiter::local_join function is removed.")] pub async fn local_join() { - // let handle = HANDLE.with(|fut| std::mem::take(&mut *fut.borrow_mut())); - // async move { - // handle.collect::>().await; - // } unimplemented!("Arbiter::local_join function is removed.") } } -// /// Future used for cleaning-up already finished `JoinHandle`s -// /// from the `PENDING` list so the vector doesn't grow indefinitely -// struct CleanupPending; -// -// impl Future for CleanupPending { -// type Output = (); -// -// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -// HANDLE.with(move |handle| { -// recycle_join_handle(&mut *handle.borrow_mut(), cx); -// }); -// -// Poll::Ready(()) -// } -// } - struct ArbiterController { rx: UnboundedReceiver, } @@ -330,11 +296,6 @@ impl Future for ArbiterController { Poll::Ready(Some(item)) => match item { ArbiterCommand::Stop => return Poll::Ready(()), ArbiterCommand::Execute(fut) => { - // HANDLE.with(|handle| { - // let mut handle = handle.borrow_mut(); - // handle.push(tokio::task::spawn_local(fut)); - // recycle_join_handle(&mut *handle, cx); - // }); tokio::task::spawn_local(fut); } ArbiterCommand::ExecuteFn(f) => { @@ -347,20 +308,6 @@ impl Future for ArbiterController { } } -// fn recycle_join_handle(handle: &mut FuturesUnordered>, cx: &mut Context<'_>) { -// let _ = Pin::new(&mut *handle).poll_next(cx); -// -// // Try to recycle more join handles and free up memory. -// // -// // this is a guess. The yield limit for FuturesUnordered is 32. -// // So poll an extra 3 times would make the total poll below 128. -// if handle.len() > 64 { -// (0..3).for_each(|_| { -// let _ = Pin::new(&mut *handle).poll_next(cx); -// }) -// } -// } - #[derive(Debug)] pub(crate) enum SystemCommand { Exit(i32), diff --git a/actix-rt/src/lib.rs b/actix-rt/src/lib.rs index 30fa2d78..6e9d0464 100644 --- a/actix-rt/src/lib.rs +++ b/actix-rt/src/lib.rs @@ -66,7 +66,7 @@ pub mod time { pub use tokio::time::{timeout, Timeout}; } -/// Blocking task management. +/// Task management. pub mod task { pub use tokio::task::{spawn_blocking, yield_now, JoinHandle}; } diff --git a/actix-rt/tests/integration_tests.rs b/actix-rt/tests/integration_tests.rs index 12ceb4ef..f338602d 100644 --- a/actix-rt/tests/integration_tests.rs +++ b/actix-rt/tests/integration_tests.rs @@ -62,43 +62,6 @@ fn join_another_arbiter() { ); } -// #[test] -// fn join_current_arbiter() { -// let time = Duration::from_secs(2); -// -// let instant = Instant::now(); -// actix_rt::System::new("test_join_current_arbiter").block_on(async move { -// actix_rt::spawn(async move { -// tokio::time::delay_for(time).await; -// actix_rt::Arbiter::current().stop(); -// }); -// actix_rt::Arbiter::local_join().await; -// }); -// assert!( -// instant.elapsed() >= time, -// "Join on current arbiter should wait for all spawned futures" -// ); -// -// let large_timer = Duration::from_secs(20); -// let instant = Instant::now(); -// actix_rt::System::new("test_join_current_arbiter").block_on(async move { -// actix_rt::spawn(async move { -// tokio::time::delay_for(time).await; -// actix_rt::Arbiter::current().stop(); -// }); -// let f = actix_rt::Arbiter::local_join(); -// actix_rt::spawn(async move { -// tokio::time::delay_for(large_timer).await; -// actix_rt::Arbiter::current().stop(); -// }); -// f.await; -// }); -// assert!( -// instant.elapsed() < large_timer, -// "local_join should await only for the already spawned futures" -// ); -// } - #[test] fn non_static_block_on() { let string = String::from("test_str"); diff --git a/actix-server/Cargo.toml b/actix-server/Cargo.toml old mode 100644 new mode 100755 index ead85de0..25095039 --- a/actix-server/Cargo.toml +++ b/actix-server/Cargo.toml @@ -24,11 +24,11 @@ default = [] [dependencies] actix-codec = "0.4.0-beta.1" -actix-rt = "2.0.0-beta.2" +actix-rt = { version = "2.0.0-beta.2", default-features = false } actix-service = "2.0.0-beta.3" actix-utils = "3.0.0-beta.1" -futures-core = { version = "0.3.7", default-features = false } +futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } log = "0.4" mio = { version = "0.7.6", features = ["os-poll", "net"] } num_cpus = "1.13" diff --git a/actix-tls/CHANGES.md b/actix-tls/CHANGES.md index 1a7ef7a7..d73a7830 100644 --- a/actix-tls/CHANGES.md +++ b/actix-tls/CHANGES.md @@ -1,6 +1,16 @@ # Changes ## Unreleased - 2021-xx-xx +* Remove `trust-dns-proto` and `trust-dns-resolver` [#248] +* Use `tokio::net::lookup_host` as simple and basic default resolver [#248] +* Add `Resolve` trait for custom dns resolver. [#248] +* Add `Resolver::new_custom` function to construct custom resolvers. [#248] +* Export `webpki_roots::TLS_SERVER_ROOTS` in `actix_tls::connect` mod and remove + the export from `actix_tls::accept` [#248] +* Remove `ConnectTakeAddrsIter`. `Connect::take_addrs` would return + `ConnectAddrsIter<'static>` as owned iterator. [#248] + +[#248]: https://github.com/actix/actix-net/pull/248 ## 3.0.0-beta.2 - 2021-xx-xx diff --git a/actix-tls/Cargo.toml b/actix-tls/Cargo.toml old mode 100644 new mode 100755 index 8f23d60c..0886b27e --- a/actix-tls/Cargo.toml +++ b/actix-tls/Cargo.toml @@ -29,7 +29,7 @@ default = ["accept", "connect", "uri"] accept = [] # enable connector services -connect = ["trust-dns-proto/tokio-runtime", "trust-dns-resolver/tokio-runtime", "trust-dns-resolver/system-config"] +connect = [] # use openssl impls openssl = ["tls-openssl", "tokio-openssl"] @@ -45,20 +45,15 @@ uri = ["http"] [dependencies] actix-codec = "0.4.0-beta.1" -actix-rt = "2.0.0-beta.2" +actix-rt = { version = "2.0.0-beta.2", default-features = false } actix-service = "2.0.0-beta.3" actix-utils = "3.0.0-beta.1" derive_more = "0.99.5" -either = "1.6" -futures-util = { version = "0.3.7", default-features = false } -http = { version = "0.2.2", optional = true } +futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } +http = { version = "0.2.3", optional = true } log = "0.4" -# resolver -trust-dns-proto = { version = "0.20.0", default-features = false, optional = true } -trust-dns-resolver = { version = "0.20.0", default-features = false, optional = true } - # openssl tls-openssl = { package = "openssl", version = "0.10", optional = true } tokio-openssl = { version = "0.6", optional = true } @@ -76,8 +71,10 @@ tls-native-tls = { package = "native-tls", version = "0.2", optional = true } tokio-native-tls = { version = "0.3", optional = true } [dev-dependencies] +actix-rt = "2.0.0-beta.2" actix-server = "2.0.0-beta.2" bytes = "1" env_logger = "0.8" futures-util = { version = "0.3.7", default-features = false, features = ["sink"] } log = "0.4" +trust-dns-resolver = "0.20.0" \ No newline at end of file diff --git a/actix-tls/src/accept/nativetls.rs b/actix-tls/src/accept/nativetls.rs index 5d80ce8b..9b2aeefb 100644 --- a/actix-tls/src/accept/nativetls.rs +++ b/actix-tls/src/accept/nativetls.rs @@ -3,7 +3,7 @@ use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::Counter; -use futures_util::future::{ready, LocalBoxFuture, Ready}; +use futures_core::future::LocalBoxFuture; pub use native_tls::Error; pub use tokio_native_tls::{TlsAcceptor, TlsStream}; @@ -44,15 +44,16 @@ where type Service = NativeTlsAcceptorService; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - MAX_CONN_COUNTER.with(|conns| { - ready(Ok(NativeTlsAcceptorService { + let res = MAX_CONN_COUNTER.with(|conns| { + Ok(NativeTlsAcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), - })) - }) + }) + }); + Box::pin(async { res }) } } diff --git a/actix-tls/src/accept/openssl.rs b/actix-tls/src/accept/openssl.rs index efda5c38..5f2d2fc2 100644 --- a/actix-tls/src/accept/openssl.rs +++ b/actix-tls/src/accept/openssl.rs @@ -1,14 +1,13 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::{Counter, CounterGuard}; -use futures_util::{ - future::{ready, Ready}, - ready, -}; +use futures_core::{future::LocalBoxFuture, ready}; pub use openssl::ssl::{ AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder, @@ -50,15 +49,16 @@ where type Config = (); type Service = AcceptorService; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - MAX_CONN_COUNTER.with(|conns| { - ready(Ok(AcceptorService { + let res = MAX_CONN_COUNTER.with(|conns| { + Ok(AcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), - })) - }) + }) + }); + Box::pin(async { res }) } } diff --git a/actix-tls/src/accept/rustls.rs b/actix-tls/src/accept/rustls.rs index a6686f44..e7efaa3f 100644 --- a/actix-tls/src/accept/rustls.rs +++ b/actix-tls/src/accept/rustls.rs @@ -1,18 +1,19 @@ -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; +use std::{ + future::Future, + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::{Counter, CounterGuard}; -use futures_util::future::{ready, Ready}; +use futures_core::future::LocalBoxFuture; use tokio_rustls::{Accept, TlsAcceptor}; pub use rustls::{ServerConfig, Session}; pub use tokio_rustls::server::TlsStream; -pub use webpki_roots::TLS_SERVER_ROOTS; use super::MAX_CONN_COUNTER; @@ -52,15 +53,16 @@ where type Service = AcceptorService; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - MAX_CONN_COUNTER.with(|conns| { - ready(Ok(AcceptorService { + let res = MAX_CONN_COUNTER.with(|conns| { + Ok(AcceptorService { acceptor: self.config.clone().into(), conns: conns.clone(), - })) - }) + }) + }); + Box::pin(async { res }) } } diff --git a/actix-tls/src/connect/connect.rs b/actix-tls/src/connect/connect.rs old mode 100644 new mode 100755 index 37fa9c6e..3928870d --- a/actix-tls/src/connect/connect.rs +++ b/actix-tls/src/connect/connect.rs @@ -1,9 +1,9 @@ -use std::collections::{vec_deque, VecDeque}; -use std::fmt; -use std::iter::{FromIterator, FusedIterator}; -use std::net::SocketAddr; - -use either::Either; +use std::{ + collections::{vec_deque, VecDeque}, + fmt, + iter::{FromIterator, FusedIterator}, + net::SocketAddr, +}; /// Connect request pub trait Address: Unpin + 'static { @@ -39,7 +39,25 @@ impl Address for &'static str { pub struct Connect { pub(crate) req: T, pub(crate) port: u16, - pub(crate) addr: Option>>, + pub(crate) addr: ConnectAddrs, +} + +#[derive(Eq, PartialEq, Debug, Hash)] +pub(crate) enum ConnectAddrs { + One(Option), + Multi(VecDeque), +} + +impl ConnectAddrs { + pub(crate) fn is_none(&self) -> bool { + matches!(*self, Self::One(None)) + } +} + +impl Default for ConnectAddrs { + fn default() -> Self { + Self::One(None) + } } impl Connect { @@ -49,7 +67,7 @@ impl Connect { Connect { req, port: port.unwrap_or(0), - addr: None, + addr: ConnectAddrs::One(None), } } @@ -59,7 +77,7 @@ impl Connect { Connect { req, port: 0, - addr: Some(Either::Left(addr)), + addr: ConnectAddrs::One(Some(addr)), } } @@ -73,9 +91,7 @@ impl Connect { /// Use address. pub fn set_addr(mut self, addr: Option) -> Self { - if let Some(addr) = addr { - self.addr = Some(Either::Left(addr)); - } + self.addr = ConnectAddrs::One(addr); self } @@ -86,9 +102,9 @@ impl Connect { { let mut addrs = VecDeque::from_iter(addrs); self.addr = if addrs.len() < 2 { - addrs.pop_front().map(Either::Left) + ConnectAddrs::One(addrs.pop_front()) } else { - Some(Either::Right(addrs)) + ConnectAddrs::Multi(addrs) }; self } @@ -105,24 +121,18 @@ impl Connect { /// Pre-resolved addresses of the request. pub fn addrs(&self) -> ConnectAddrsIter<'_> { - let inner = match self.addr { - None => Either::Left(None), - Some(Either::Left(addr)) => Either::Left(Some(addr)), - Some(Either::Right(ref addrs)) => Either::Right(addrs.iter()), - }; - - ConnectAddrsIter { inner } + match self.addr { + ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), + ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), + } } /// Takes pre-resolved addresses of the request. - pub fn take_addrs(&mut self) -> ConnectTakeAddrsIter { - let inner = match self.addr.take() { - None => Either::Left(None), - Some(Either::Left(addr)) => Either::Left(Some(addr)), - Some(Either::Right(addrs)) => Either::Right(addrs.into_iter()), - }; - - ConnectTakeAddrsIter { inner } + pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> { + match std::mem::take(&mut self.addr) { + ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), + ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), + } } } @@ -140,25 +150,29 @@ impl fmt::Display for Connect { /// Iterator over addresses in a [`Connect`] request. #[derive(Clone)] -pub struct ConnectAddrsIter<'a> { - inner: Either, vec_deque::Iter<'a, SocketAddr>>, +pub enum ConnectAddrsIter<'a> { + One(Option), + Multi(vec_deque::Iter<'a, SocketAddr>), + MultiOwned(vec_deque::IntoIter), } impl Iterator for ConnectAddrsIter<'_> { type Item = SocketAddr; fn next(&mut self) -> Option { - match self.inner { - Either::Left(ref mut opt) => opt.take(), - Either::Right(ref mut iter) => iter.next().copied(), + match *self { + Self::One(ref mut addr) => addr.take(), + Self::Multi(ref mut iter) => iter.next().copied(), + Self::MultiOwned(ref mut iter) => iter.next(), } } fn size_hint(&self) -> (usize, Option) { - match self.inner { - Either::Left(Some(_)) => (1, Some(1)), - Either::Left(None) => (0, Some(0)), - Either::Right(ref iter) => iter.size_hint(), + match *self { + Self::One(None) => (0, Some(0)), + Self::One(Some(_)) => (1, Some(1)), + Self::Multi(ref iter) => iter.size_hint(), + Self::MultiOwned(ref iter) => iter.size_hint(), } } } @@ -173,35 +187,6 @@ impl ExactSizeIterator for ConnectAddrsIter<'_> {} impl FusedIterator for ConnectAddrsIter<'_> {} -/// Owned iterator over addresses in a [`Connect`] request. -#[derive(Debug)] -pub struct ConnectTakeAddrsIter { - inner: Either, vec_deque::IntoIter>, -} - -impl Iterator for ConnectTakeAddrsIter { - type Item = SocketAddr; - - fn next(&mut self) -> Option { - match self.inner { - Either::Left(ref mut opt) => opt.take(), - Either::Right(ref mut iter) => iter.next(), - } - } - - fn size_hint(&self) -> (usize, Option) { - match self.inner { - Either::Left(Some(_)) => (1, Some(1)), - Either::Left(None) => (0, Some(0)), - Either::Right(ref iter) => iter.size_hint(), - } - } -} - -impl ExactSizeIterator for ConnectTakeAddrsIter {} - -impl FusedIterator for ConnectTakeAddrsIter {} - fn parse(host: &str) -> (&str, Option) { let mut parts_iter = host.splitn(2, ':'); if let Some(host) = parts_iter.next() { diff --git a/actix-tls/src/connect/connector.rs b/actix-tls/src/connect/connector.rs old mode 100644 new mode 100755 index a0a6b8b5..fe4ceec6 --- a/actix-tls/src/connect/connector.rs +++ b/actix-tls/src/connect/connector.rs @@ -1,76 +1,50 @@ -use std::collections::VecDeque; -use std::future::Future; -use std::io; -use std::marker::PhantomData; -use std::net::SocketAddr; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + collections::VecDeque, + future::Future, + io, + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ready, Ready}; +use futures_core::future::LocalBoxFuture; use log::{error, trace}; -use super::connect::{Address, Connect, Connection}; +use super::connect::{Address, Connect, ConnectAddrs, Connection}; use super::error::ConnectError; /// TCP connector service factory -#[derive(Debug)] -pub struct TcpConnectorFactory(PhantomData); - -impl TcpConnectorFactory { - pub fn new() -> Self { - TcpConnectorFactory(PhantomData) - } +#[derive(Copy, Clone, Debug)] +pub struct TcpConnectorFactory; +impl TcpConnectorFactory { /// Create TCP connector service - pub fn service(&self) -> TcpConnector { - TcpConnector(PhantomData) + pub fn service(&self) -> TcpConnector { + TcpConnector } } -impl Default for TcpConnectorFactory { - fn default() -> Self { - TcpConnectorFactory(PhantomData) - } -} - -impl Clone for TcpConnectorFactory { - fn clone(&self) -> Self { - TcpConnectorFactory(PhantomData) - } -} - -impl ServiceFactory> for TcpConnectorFactory { +impl ServiceFactory> for TcpConnectorFactory { type Response = Connection; type Error = ConnectError; type Config = (); - type Service = TcpConnector; + type Service = TcpConnector; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - ready(Ok(self.service())) + let service = self.service(); + Box::pin(async move { Ok(service) }) } } /// TCP connector service -#[derive(Default, Debug)] -pub struct TcpConnector(PhantomData); +#[derive(Copy, Clone, Debug)] +pub struct TcpConnector; -impl TcpConnector { - pub fn new() -> Self { - TcpConnector(PhantomData) - } -} - -impl Clone for TcpConnector { - fn clone(&self) -> Self { - TcpConnector(PhantomData) - } -} - -impl Service> for TcpConnector { +impl Service> for TcpConnector { type Response = Connection; type Error = ConnectError; type Future = TcpConnectorResponse; @@ -81,17 +55,10 @@ impl Service> for TcpConnector { let port = req.port(); let Connect { req, addr, .. } = req; - if let Some(addr) = addr { - TcpConnectorResponse::new(req, port, addr) - } else { - error!("TCP connector: got unresolved address"); - TcpConnectorResponse::Error(Some(ConnectError::Unresolved)) - } + TcpConnectorResponse::new(req, port, addr) } } -type LocalBoxFuture<'a, T> = Pin + 'a>>; - #[doc(hidden)] /// TCP stream connector response future pub enum TcpConnectorResponse { @@ -105,11 +72,7 @@ pub enum TcpConnectorResponse { } impl TcpConnectorResponse { - pub fn new( - req: T, - port: u16, - addr: either::Either>, - ) -> TcpConnectorResponse { + pub(crate) fn new(req: T, port: u16, addr: ConnectAddrs) -> TcpConnectorResponse { trace!( "TCP connector - connecting to {:?} port:{}", req.host(), @@ -117,13 +80,19 @@ impl TcpConnectorResponse { ); match addr { - either::Either::Left(addr) => TcpConnectorResponse::Response { + ConnectAddrs::One(None) => { + error!("TCP connector: got unresolved address"); + TcpConnectorResponse::Error(Some(ConnectError::Unresolved)) + } + ConnectAddrs::One(Some(addr)) => TcpConnectorResponse::Response { req: Some(req), port, addrs: None, stream: Some(Box::pin(TcpStream::connect(addr))), }, - either::Either::Right(addrs) => TcpConnectorResponse::Response { + // when resolver returns multiple socket addr for request they would be popped from + // front end of queue and returns with the first successful tcp connection. + ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response { req: Some(req), port, addrs: Some(addrs), @@ -165,7 +134,7 @@ impl Future for TcpConnectorResponse { port, ); if addrs.is_none() || addrs.as_ref().unwrap().is_empty() { - return Poll::Ready(Err(err.into())); + return Poll::Ready(Err(ConnectError::Io(err))); } } } diff --git a/actix-tls/src/connect/error.rs b/actix-tls/src/connect/error.rs index 84b363dc..5d8cb9db 100644 --- a/actix-tls/src/connect/error.rs +++ b/actix-tls/src/connect/error.rs @@ -1,13 +1,12 @@ use std::io; -use derive_more::{Display, From}; -use trust_dns_resolver::error::ResolveError; +use derive_more::Display; -#[derive(Debug, From, Display)] +#[derive(Debug, Display)] pub enum ConnectError { /// Failed to resolve the hostname #[display(fmt = "Failed resolving hostname: {}", _0)] - Resolver(ResolveError), + Resolver(Box), /// No dns records #[display(fmt = "No dns records found for the input")] diff --git a/actix-tls/src/connect/mod.rs b/actix-tls/src/connect/mod.rs index 75312c59..c864886d 100644 --- a/actix-tls/src/connect/mod.rs +++ b/actix-tls/src/connect/mod.rs @@ -4,6 +4,17 @@ //! //! * `openssl` - enables TLS support via `openssl` crate //! * `rustls` - enables TLS support via `rustls` crate +//! +//! ## Workflow of connector service: +//! - resolve [`Address`](self::connect::Address) with given [`Resolver`](self::resolve::Resolver) +//! and collect [`SocketAddrs`](std::net::SocketAddr). +//! - establish Tcp connection and return [`TcpStream`](tokio::net::TcpStream). +//! +//! ## Workflow of tls connector services: +//! - Establish [`TcpStream`](tokio::net::TcpStream) with connector service. +//! - Wrap around the stream and do connect handshake with remote address. +//! - Return certain stream type impl [`AsyncRead`](tokio::io::AsyncRead) and +//! [`AsyncWrite`](tokio::io::AsyncWrite) mod connect; mod connector; @@ -14,67 +25,26 @@ pub mod ssl; #[cfg(feature = "uri")] mod uri; -use actix_rt::{net::TcpStream, Arbiter}; +use actix_rt::net::TcpStream; use actix_service::{pipeline, pipeline_factory, Service, ServiceFactory}; -use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; -use trust_dns_resolver::system_conf::read_system_conf; -use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; - -pub mod resolver { - pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; - pub use trust_dns_resolver::system_conf::read_system_conf; - pub use trust_dns_resolver::{error::ResolveError, AsyncResolver}; -} pub use self::connect::{Address, Connect, Connection}; pub use self::connector::{TcpConnector, TcpConnectorFactory}; pub use self::error::ConnectError; -pub use self::resolve::{Resolver, ResolverFactory}; +pub use self::resolve::{Resolve, Resolver, ResolverFactory}; pub use self::service::{ConnectService, ConnectServiceFactory, TcpConnectService}; -pub async fn start_resolver( - cfg: ResolverConfig, - opts: ResolverOpts, -) -> Result { - Ok(AsyncResolver::tokio(cfg, opts)?) -} - -struct DefaultResolver(AsyncResolver); - -pub(crate) async fn get_default_resolver() -> Result { - if Arbiter::contains_item::() { - Ok(Arbiter::get_item(|item: &DefaultResolver| item.0.clone())) - } else { - let (cfg, opts) = match read_system_conf() { - Ok((cfg, opts)) => (cfg, opts), - Err(e) => { - log::error!("TRust-DNS can not load system config: {}", e); - (ResolverConfig::default(), ResolverOpts::default()) - } - }; - - let resolver = AsyncResolver::tokio(cfg, opts)?; - - Arbiter::set_item(DefaultResolver(resolver.clone())); - Ok(resolver) - } -} - -pub async fn start_default_resolver() -> Result { - get_default_resolver().await -} - /// Create TCP connector service. pub fn new_connector( - resolver: AsyncResolver, + resolver: Resolver, ) -> impl Service, Response = Connection, Error = ConnectError> + Clone { - pipeline(Resolver::new(resolver)).and_then(TcpConnector::new()) + pipeline(resolver).and_then(TcpConnector) } /// Create TCP connector service factory. pub fn new_connector_factory( - resolver: AsyncResolver, + resolver: Resolver, ) -> impl ServiceFactory< Connect, Config = (), @@ -82,14 +52,14 @@ pub fn new_connector_factory( Error = ConnectError, InitError = (), > + Clone { - pipeline_factory(ResolverFactory::new(resolver)).and_then(TcpConnectorFactory::new()) + pipeline_factory(ResolverFactory::new(resolver)).and_then(TcpConnectorFactory) } /// Create connector service with default parameters. pub fn default_connector( ) -> impl Service, Response = Connection, Error = ConnectError> + Clone { - pipeline(Resolver::default()).and_then(TcpConnector::new()) + new_connector(Resolver::Default) } /// Create connector service factory with default parameters. @@ -100,5 +70,5 @@ pub fn default_connector_factory() -> impl ServiceFactory< Error = ConnectError, InitError = (), > + Clone { - pipeline_factory(ResolverFactory::default()).and_then(TcpConnectorFactory::new()) + new_connector_factory(Resolver::Default) } diff --git a/actix-tls/src/connect/resolve.rs b/actix-tls/src/connect/resolve.rs old mode 100644 new mode 100755 index 61535faa..211da387 --- a/actix-tls/src/connect/resolve.rs +++ b/actix-tls/src/connect/resolve.rs @@ -1,184 +1,225 @@ -use std::future::Future; -use std::marker::PhantomData; -use std::net::SocketAddr; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, + vec::IntoIter, +}; +use actix_rt::task::{spawn_blocking, JoinHandle}; use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ok, Either, Ready}; +use futures_core::{future::LocalBoxFuture, ready}; use log::trace; -use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; -use trust_dns_resolver::{error::ResolveError, lookup_ip::LookupIp}; use super::connect::{Address, Connect}; use super::error::ConnectError; -use super::get_default_resolver; /// DNS Resolver Service factory -pub struct ResolverFactory { - resolver: Option, - _t: PhantomData, +#[derive(Clone)] +pub struct ResolverFactory { + resolver: Resolver, } -impl ResolverFactory { - /// Create new resolver instance with custom configuration and options. - pub fn new(resolver: AsyncResolver) -> Self { - ResolverFactory { - resolver: Some(resolver), - _t: PhantomData, - } +impl ResolverFactory { + pub fn new(resolver: Resolver) -> Self { + Self { resolver } } - pub fn service(&self) -> Resolver { - Resolver { - resolver: self.resolver.clone(), - _t: PhantomData, - } + pub fn service(&self) -> Resolver { + self.resolver.clone() } } -impl Default for ResolverFactory { - fn default() -> Self { - ResolverFactory { - resolver: None, - _t: PhantomData, - } - } -} - -impl Clone for ResolverFactory { - fn clone(&self) -> Self { - ResolverFactory { - resolver: self.resolver.clone(), - _t: PhantomData, - } - } -} - -impl ServiceFactory> for ResolverFactory { +impl ServiceFactory> for ResolverFactory { type Response = Connect; type Error = ConnectError; type Config = (); - type Service = Resolver; + type Service = Resolver; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - ok(self.service()) + let service = self.resolver.clone(); + Box::pin(async { Ok(service) }) } } /// DNS Resolver Service -pub struct Resolver { - resolver: Option, - _t: PhantomData, +#[derive(Clone)] +pub enum Resolver { + Default, + Custom(Rc), } -impl Resolver { - /// Create new resolver instance with custom configuration and options. - pub fn new(resolver: AsyncResolver) -> Self { - Resolver { - resolver: Some(resolver), - _t: PhantomData, - } +/// trait for custom lookup with self defined resolver. +/// +/// # Example: +/// ```rust +/// use std::net::SocketAddr; +/// +/// use actix_tls::connect::{Resolve, Resolver}; +/// use futures_util::future::LocalBoxFuture; +/// +/// // use trust_dns_resolver as custom resolver. +/// use trust_dns_resolver::TokioAsyncResolver; +/// +/// struct MyResolver { +/// trust_dns: TokioAsyncResolver, +/// }; +/// +/// // impl Resolve trait and convert given host address str and port to SocketAddr. +/// impl Resolve for MyResolver { +/// fn lookup<'a>( +/// &'a self, +/// host: &'a str, +/// port: u16, +/// ) -> LocalBoxFuture<'a, Result, Box>> { +/// Box::pin(async move { +/// let res = self +/// .trust_dns +/// .lookup_ip(host) +/// .await? +/// .iter() +/// .map(|ip| SocketAddr::new(ip, port)) +/// .collect(); +/// Ok(res) +/// }) +/// } +/// } +/// +/// let resolver = MyResolver { +/// trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), +/// }; +/// +/// // construct custom resolver +/// let resolver = Resolver::new_custom(resolver); +/// +/// // pass custom resolver to connector builder. +/// // connector would then be usable as a service or awc's connector. +/// let connector = actix_tls::connect::new_connector::<&str>(resolver.clone()); +/// +/// // resolver can be passed to connector factory where returned service factory +/// // can be used to construct new connector services. +/// let factory = actix_tls::connect::new_connector_factory::<&str>(resolver); +///``` +pub trait Resolve { + fn lookup<'a>( + &'a self, + host: &'a str, + port: u16, + ) -> LocalBoxFuture<'a, Result, Box>>; +} + +impl Resolver { + /// Constructor for custom Resolve trait object and use it as resolver. + pub fn new_custom(resolver: impl Resolve + 'static) -> Self { + Self::Custom(Rc::new(resolver)) + } + + // look up with default resolver variant. + fn look_up(req: &Connect) -> JoinHandle>> { + let host = req.host(); + // TODO: Connect should always return host with port if possible. + let host = if req + .host() + .splitn(2, ':') + .last() + .and_then(|p| p.parse::().ok()) + .map(|p| p == req.port()) + .unwrap_or(false) + { + host.to_string() + } else { + format!("{}:{}", host, req.port()) + }; + + spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host)) } } -impl Default for Resolver { - fn default() -> Self { - Resolver { - resolver: None, - _t: PhantomData, - } - } -} - -impl Clone for Resolver { - fn clone(&self) -> Self { - Resolver { - resolver: self.resolver.clone(), - _t: PhantomData, - } - } -} - -impl Service> for Resolver { +impl Service> for Resolver { type Response = Connect; type Error = ConnectError; - #[allow(clippy::type_complexity)] - type Future = Either< - Pin>>>, - Ready, Self::Error>>, - >; + type Future = ResolverFuture; actix_service::always_ready!(); - fn call(&mut self, mut req: Connect) -> Self::Future { - if req.addr.is_some() { - Either::Right(ok(req)) + fn call(&mut self, req: Connect) -> Self::Future { + if !req.addr.is_none() { + ResolverFuture::Connected(Some(req)) } else if let Ok(ip) = req.host().parse() { - req.addr = Some(either::Either::Left(SocketAddr::new(ip, req.port()))); - Either::Right(ok(req)) + let addr = SocketAddr::new(ip, req.port()); + let req = req.set_addr(Some(addr)); + ResolverFuture::Connected(Some(req)) } else { - let resolver = self.resolver.as_ref().map(AsyncResolver::clone); - Either::Left(Box::pin(async move { - trace!("DNS resolver: resolving host {:?}", req.host()); - let resolver = if let Some(resolver) = resolver { - resolver - } else { - get_default_resolver() - .await - .expect("Failed to get default resolver") - }; - ResolverFuture::new(req, &resolver).await - })) + trace!("DNS resolver: resolving host {:?}", req.host()); + + match self { + Self::Default => { + let fut = Self::look_up(&req); + ResolverFuture::LookUp(fut, Some(req)) + } + + Self::Custom(resolver) => { + let resolver = Rc::clone(&resolver); + ResolverFuture::LookupCustom(Box::pin(async move { + let addrs = resolver + .lookup(req.host(), req.port()) + .await + .map_err(ConnectError::Resolver)?; + + let req = req.set_addrs(addrs); + + if req.addr.is_none() { + Err(ConnectError::NoRecords) + } else { + Ok(req) + } + })) + } + } } } } -type LookupIpFuture = Pin>>>; - -#[doc(hidden)] -/// Resolver future -pub struct ResolverFuture { - req: Option>, - lookup: LookupIpFuture, -} - -impl ResolverFuture { - pub fn new(req: Connect, resolver: &AsyncResolver) -> Self { - let host = if let Some(host) = req.host().splitn(2, ':').next() { - host - } else { - req.host() - }; - - // Clone data to be moved to the lookup future - let host_clone = host.to_owned(); - let resolver_clone = resolver.clone(); - - ResolverFuture { - lookup: Box::pin(async move { - let resolver = resolver_clone; - resolver.lookup_ip(host_clone).await - }), - req: Some(req), - } - } +pub enum ResolverFuture { + Connected(Option>), + LookUp( + JoinHandle>>, + Option>, + ), + LookupCustom(LocalBoxFuture<'static, Result, ConnectError>>), } impl Future for ResolverFuture { type Output = Result, ConnectError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + match self.get_mut() { + Self::Connected(conn) => Poll::Ready(Ok(conn + .take() + .expect("ResolverFuture polled after finished"))), + Self::LookUp(fut, req) => { + let res = match ready!(Pin::new(fut).poll(cx)) { + Ok(Ok(res)) => Ok(res), + Ok(Err(e)) => Err(ConnectError::Resolver(Box::new(e))), + Err(e) => Err(ConnectError::Io(e.into())), + }; - match Pin::new(&mut this.lookup).poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(ips)) => { - let req = this.req.take().unwrap(); - let port = req.port(); - let req = req.set_addrs(ips.iter().map(|ip| SocketAddr::new(ip, port))); + let req = req.take().unwrap(); + + let addrs = res.map_err(|e| { + trace!( + "DNS resolver: failed to resolve host {:?} err: {:?}", + req.host(), + e + ); + e + })?; + + let req = req.set_addrs(addrs); trace!( "DNS resolver: host {:?} resolved to {:?}", @@ -192,14 +233,7 @@ impl Future for ResolverFuture { Poll::Ready(Ok(req)) } } - Poll::Ready(Err(e)) => { - trace!( - "DNS resolver: failed to resolve host {:?} err: {}", - this.req.as_ref().unwrap().host(), - e - ); - Poll::Ready(Err(e.into())) - } + Self::LookupCustom(fut) => fut.as_mut().poll(cx), } } } diff --git a/actix-tls/src/connect/service.rs b/actix-tls/src/connect/service.rs old mode 100644 new mode 100755 index 59fe20cc..6fa8a453 --- a/actix-tls/src/connect/service.rs +++ b/actix-tls/src/connect/service.rs @@ -1,42 +1,34 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; -use either::Either; -use futures_util::future::{ok, Ready}; -use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; +use futures_core::{future::LocalBoxFuture, ready}; use super::connect::{Address, Connect, Connection}; use super::connector::{TcpConnector, TcpConnectorFactory}; use super::error::ConnectError; use super::resolve::{Resolver, ResolverFactory}; -pub struct ConnectServiceFactory { - tcp: TcpConnectorFactory, - resolver: ResolverFactory, +pub struct ConnectServiceFactory { + tcp: TcpConnectorFactory, + resolver: ResolverFactory, } -impl ConnectServiceFactory { +impl ConnectServiceFactory { /// Construct new ConnectService factory - pub fn new() -> Self { + pub fn new(resolver: Resolver) -> Self { ConnectServiceFactory { - tcp: TcpConnectorFactory::default(), - resolver: ResolverFactory::default(), - } - } - - /// Construct new connect service with custom dns resolver - pub fn with_resolver(resolver: AsyncResolver) -> Self { - ConnectServiceFactory { - tcp: TcpConnectorFactory::default(), + tcp: TcpConnectorFactory, resolver: ResolverFactory::new(resolver), } } /// Construct new service - pub fn service(&self) -> ConnectService { + pub fn service(&self) -> ConnectService { ConnectService { tcp: self.tcp.service(), resolver: self.resolver.service(), @@ -44,7 +36,7 @@ impl ConnectServiceFactory { } /// Construct new tcp stream service - pub fn tcp_service(&self) -> TcpConnectService { + pub fn tcp_service(&self) -> TcpConnectService { TcpConnectService { tcp: self.tcp.service(), resolver: self.resolver.service(), @@ -52,44 +44,36 @@ impl ConnectServiceFactory { } } -impl Default for ConnectServiceFactory { - fn default() -> Self { - ConnectServiceFactory { - tcp: TcpConnectorFactory::default(), - resolver: ResolverFactory::default(), - } - } -} - -impl Clone for ConnectServiceFactory { +impl Clone for ConnectServiceFactory { fn clone(&self) -> Self { ConnectServiceFactory { - tcp: self.tcp.clone(), + tcp: self.tcp, resolver: self.resolver.clone(), } } } -impl ServiceFactory> for ConnectServiceFactory { +impl ServiceFactory> for ConnectServiceFactory { type Response = Connection; type Error = ConnectError; type Config = (); - type Service = ConnectService; + type Service = ConnectService; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - ok(self.service()) + let service = self.service(); + Box::pin(async move { Ok(service) }) } } #[derive(Clone)] -pub struct ConnectService { - tcp: TcpConnector, - resolver: Resolver, +pub struct ConnectService { + tcp: TcpConnector, + resolver: Resolver, } -impl Service> for ConnectService { +impl Service> for ConnectService { type Response = Connection; type Error = ConnectError; type Future = ConnectServiceResponse; @@ -98,65 +82,67 @@ impl Service> for ConnectService { fn call(&mut self, req: Connect) -> Self::Future { ConnectServiceResponse { - state: ConnectState::Resolve(self.resolver.call(req)), - tcp: self.tcp.clone(), + fut: ConnectFuture::Resolve(self.resolver.call(req)), + tcp: self.tcp, } } } -enum ConnectState { - Resolve( as Service>>::Future), - Connect( as Service>>::Future), +// helper enum to generic over futures of resolve and connect phase. +pub(crate) enum ConnectFuture { + Resolve(>>::Future), + Connect(>>::Future), } -impl ConnectState { - #[allow(clippy::type_complexity)] - fn poll( +// helper enum to contain the future output of ConnectFuture +pub(crate) enum ConnectOutput { + Resolved(Connect), + Connected(Connection), +} + +impl ConnectFuture { + fn poll_connect( &mut self, cx: &mut Context<'_>, - ) -> Either, ConnectError>>, Connect> { + ) -> Poll, ConnectError>> { match self { - ConnectState::Resolve(ref mut fut) => match Pin::new(fut).poll(cx) { - Poll::Pending => Either::Left(Poll::Pending), - Poll::Ready(Ok(res)) => Either::Right(res), - Poll::Ready(Err(err)) => Either::Left(Poll::Ready(Err(err))), - }, - ConnectState::Connect(ref mut fut) => Either::Left(Pin::new(fut).poll(cx)), + ConnectFuture::Resolve(ref mut fut) => { + Pin::new(fut).poll(cx).map_ok(ConnectOutput::Resolved) + } + ConnectFuture::Connect(ref mut fut) => { + Pin::new(fut).poll(cx).map_ok(ConnectOutput::Connected) + } } } } pub struct ConnectServiceResponse { - state: ConnectState, - tcp: TcpConnector, + fut: ConnectFuture, + tcp: TcpConnector, } impl Future for ConnectServiceResponse { type Output = Result, ConnectError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let res = match self.state.poll(cx) { - Either::Right(res) => { - self.state = ConnectState::Connect(self.tcp.call(res)); - self.state.poll(cx) + loop { + match ready!(self.fut.poll_connect(cx))? { + ConnectOutput::Resolved(res) => { + self.fut = ConnectFuture::Connect(self.tcp.call(res)); + } + ConnectOutput::Connected(res) => return Poll::Ready(Ok(res)), } - Either::Left(res) => return res, - }; - - match res { - Either::Left(res) => res, - Either::Right(_) => panic!(), } } } #[derive(Clone)] -pub struct TcpConnectService { - tcp: TcpConnector, - resolver: Resolver, +pub struct TcpConnectService { + tcp: TcpConnector, + resolver: Resolver, } -impl Service> for TcpConnectService { +impl Service> for TcpConnectService { type Response = TcpStream; type Error = ConnectError; type Future = TcpConnectServiceResponse; @@ -165,61 +151,28 @@ impl Service> for TcpConnectService { fn call(&mut self, req: Connect) -> Self::Future { TcpConnectServiceResponse { - state: TcpConnectState::Resolve(self.resolver.call(req)), - tcp: self.tcp.clone(), + fut: ConnectFuture::Resolve(self.resolver.call(req)), + tcp: self.tcp, } } } -enum TcpConnectState { - Resolve( as Service>>::Future), - Connect( as Service>>::Future), -} - -impl TcpConnectState { - fn poll( - &mut self, - cx: &mut Context<'_>, - ) -> Either>, Connect> { - match self { - TcpConnectState::Resolve(ref mut fut) => match Pin::new(fut).poll(cx) { - Poll::Pending => (), - Poll::Ready(Ok(res)) => return Either::Right(res), - Poll::Ready(Err(err)) => return Either::Left(Poll::Ready(Err(err))), - }, - TcpConnectState::Connect(ref mut fut) => { - if let Poll::Ready(res) = Pin::new(fut).poll(cx) { - return match res { - Ok(conn) => Either::Left(Poll::Ready(Ok(conn.into_parts().0))), - Err(err) => Either::Left(Poll::Ready(Err(err))), - }; - } - } - } - Either::Left(Poll::Pending) - } -} - pub struct TcpConnectServiceResponse { - state: TcpConnectState, - tcp: TcpConnector, + fut: ConnectFuture, + tcp: TcpConnector, } impl Future for TcpConnectServiceResponse { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let res = match self.state.poll(cx) { - Either::Right(res) => { - self.state = TcpConnectState::Connect(self.tcp.call(res)); - self.state.poll(cx) + loop { + match ready!(self.fut.poll_connect(cx))? { + ConnectOutput::Resolved(res) => { + self.fut = ConnectFuture::Connect(self.tcp.call(res)); + } + ConnectOutput::Connected(conn) => return Poll::Ready(Ok(conn.into_parts().0)), } - Either::Left(res) => return res, - }; - - match res { - Either::Left(res) => res, - Either::Right(_) => panic!(), } } } diff --git a/actix-tls/src/connect/ssl/openssl.rs b/actix-tls/src/connect/ssl/openssl.rs old mode 100644 new mode 100755 index 5193ce37..d5d71edc --- a/actix-tls/src/connect/ssl/openssl.rs +++ b/actix-tls/src/connect/ssl/openssl.rs @@ -1,23 +1,23 @@ -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{fmt, io}; +use std::{ + fmt, + future::Future, + io, + pin::Pin, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; -use futures_util::{ - future::{ready, Either, Ready}, - ready, -}; +use futures_core::{future::LocalBoxFuture, ready}; use log::trace; + pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; pub use tokio_openssl::SslStream; -use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; +use crate::connect::resolve::Resolve; use crate::connect::{ - Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection, + Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection, Resolver, }; /// OpenSSL connector factory @@ -29,9 +29,7 @@ impl OpensslConnector { pub fn new(connector: SslConnector) -> Self { OpensslConnector { connector } } -} -impl OpensslConnector { pub fn service(connector: SslConnector) -> OpensslConnectorService { OpensslConnectorService { connector } } @@ -55,12 +53,11 @@ where type Config = (); type Service = OpensslConnectorService; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - ready(Ok(OpensslConnectorService { - connector: self.connector.clone(), - })) + let connector = self.connector.clone(); + Box::pin(async { Ok(OpensslConnectorService { connector }) }) } } @@ -83,29 +80,27 @@ where { type Response = Connection>; type Error = io::Error; - #[allow(clippy::type_complexity)] - type Future = Either, Ready>>; + type Future = ConnectAsyncExt; actix_service::always_ready!(); fn call(&mut self, stream: Connection) -> Self::Future { trace!("SSL Handshake start for: {:?}", stream.host()); let (io, stream) = stream.replace(()); - let host = stream.host().to_string(); + let host = stream.host(); - match self.connector.configure() { - Err(e) => Either::Right(ready(Err(io::Error::new(io::ErrorKind::Other, e)))), - Ok(config) => { - let ssl = config - .into_ssl(&host) - .expect("SSL connect configuration was invalid."); + let config = self + .connector + .configure() + .expect("SSL connect configuration was invalid."); - Either::Left(ConnectAsyncExt { - io: Some(SslStream::new(ssl, io).unwrap()), - stream: Some(stream), - _t: PhantomData, - }) - } + let ssl = config + .into_ssl(host) + .expect("SSL connect configuration was invalid."); + + ConnectAsyncExt { + io: Some(SslStream::new(ssl, io).unwrap()), + stream: Some(stream), } } } @@ -113,7 +108,6 @@ where pub struct ConnectAsyncExt { io: Option>, stream: Option>, - _t: PhantomData, } impl Future for ConnectAsyncExt @@ -139,30 +133,30 @@ where } } -pub struct OpensslConnectServiceFactory { - tcp: ConnectServiceFactory, +pub struct OpensslConnectServiceFactory { + tcp: ConnectServiceFactory, openssl: OpensslConnector, } -impl OpensslConnectServiceFactory { +impl OpensslConnectServiceFactory { /// Construct new OpensslConnectService factory pub fn new(connector: SslConnector) -> Self { OpensslConnectServiceFactory { - tcp: ConnectServiceFactory::default(), + tcp: ConnectServiceFactory::new(Resolver::Default), openssl: OpensslConnector::new(connector), } } /// Construct new connect service with custom DNS resolver - pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self { + pub fn with_resolver(connector: SslConnector, resolver: impl Resolve + 'static) -> Self { OpensslConnectServiceFactory { - tcp: ConnectServiceFactory::with_resolver(resolver), + tcp: ConnectServiceFactory::new(Resolver::new_custom(resolver)), openssl: OpensslConnector::new(connector), } } /// Construct OpenSSL connect service - pub fn service(&self) -> OpensslConnectService { + pub fn service(&self) -> OpensslConnectService { OpensslConnectService { tcp: self.tcp.service(), openssl: OpensslConnectorService { @@ -172,7 +166,7 @@ impl OpensslConnectServiceFactory { } } -impl Clone for OpensslConnectServiceFactory { +impl Clone for OpensslConnectServiceFactory { fn clone(&self) -> Self { OpensslConnectServiceFactory { tcp: self.tcp.clone(), @@ -181,26 +175,27 @@ impl Clone for OpensslConnectServiceFactory { } } -impl ServiceFactory> for OpensslConnectServiceFactory { +impl ServiceFactory> for OpensslConnectServiceFactory { type Response = SslStream; type Error = ConnectError; type Config = (); - type Service = OpensslConnectService; + type Service = OpensslConnectService; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - ready(Ok(self.service())) + let service = self.service(); + Box::pin(async { Ok(service) }) } } #[derive(Clone)] -pub struct OpensslConnectService { - tcp: ConnectService, +pub struct OpensslConnectService { + tcp: ConnectService, openssl: OpensslConnectorService, } -impl Service> for OpensslConnectService { +impl Service> for OpensslConnectService { type Response = SslStream; type Error = ConnectError; type Future = OpensslConnectServiceResponse; @@ -217,7 +212,7 @@ impl Service> for OpensslConnectService { } pub struct OpensslConnectServiceResponse { - fut1: Option< as Service>>::Future>, + fut1: Option<>>::Future>, fut2: Option<>>::Future>, openssl: OpensslConnectorService, } diff --git a/actix-tls/src/connect/ssl/rustls.rs b/actix-tls/src/connect/ssl/rustls.rs old mode 100644 new mode 100755 index 390ba413..e12ab7ec --- a/actix-tls/src/connect/ssl/rustls.rs +++ b/actix-tls/src/connect/ssl/rustls.rs @@ -1,18 +1,18 @@ -use std::fmt; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; +use std::{ + fmt, + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; pub use rustls::Session; pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; +pub use webpki_roots::TLS_SERVER_ROOTS; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::{Service, ServiceFactory}; -use futures_util::{ - future::{ready, Ready}, - ready, -}; +use futures_core::{future::LocalBoxFuture, ready}; use log::trace; use tokio_rustls::{Connect, TlsConnector}; use webpki::DNSNameRef; @@ -53,12 +53,11 @@ where type Config = (); type Service = RustlsConnectorService; type InitError = (); - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - ready(Ok(RustlsConnectorService { - connector: self.connector.clone(), - })) + let connector = self.connector.clone(); + Box::pin(async { Ok(RustlsConnectorService { connector }) }) } } diff --git a/actix-tls/tests/test_connect.rs b/actix-tls/tests/test_connect.rs old mode 100644 new mode 100755 index aa773c7f..32d2ac0f --- a/actix-tls/tests/test_connect.rs +++ b/actix-tls/tests/test_connect.rs @@ -1,17 +1,15 @@ use std::io; +use std::net::SocketAddr; use actix_codec::{BytesCodec, Framed}; use actix_rt::net::TcpStream; use actix_server::TestServer; use actix_service::{fn_service, Service, ServiceFactory}; use bytes::Bytes; +use futures_core::future::LocalBoxFuture; use futures_util::sink::SinkExt; -use actix_tls::connect::{ - self as actix_connect, - resolver::{ResolverConfig, ResolverOpts}, - Connect, -}; +use actix_tls::connect::{self as actix_connect, Connect, Resolve, Resolver}; #[cfg(all(feature = "connect", feature = "openssl"))] #[actix_rt::test] @@ -57,14 +55,13 @@ async fn test_static_str() { }) }); - let resolver = actix_connect::start_default_resolver().await.unwrap(); - let mut conn = actix_connect::new_connector(resolver.clone()); + let mut conn = actix_connect::default_connector(); let con = conn.call(Connect::with("10", srv.addr())).await.unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); let connect = Connect::new(srv.host().to_owned()); - let mut conn = actix_connect::new_connector(resolver); + let mut conn = actix_connect::default_connector(); let con = conn.call(connect).await; assert!(con.is_err()); } @@ -79,10 +76,53 @@ async fn test_new_service() { }) }); - let resolver = - actix_connect::start_resolver(ResolverConfig::default(), ResolverOpts::default()) - .await - .unwrap(); + let factory = actix_connect::default_connector_factory(); + + let mut conn = factory.new_service(()).await.unwrap(); + let con = conn.call(Connect::with("10", srv.addr())).await.unwrap(); + assert_eq!(con.peer_addr().unwrap(), srv.addr()); +} + +#[actix_rt::test] +async fn test_custom_resolver() { + use trust_dns_resolver::TokioAsyncResolver; + + let srv = TestServer::with(|| { + fn_service(|io: TcpStream| async { + let mut framed = Framed::new(io, BytesCodec); + framed.send(Bytes::from_static(b"test")).await?; + Ok::<_, io::Error>(()) + }) + }); + + struct MyResolver { + trust_dns: TokioAsyncResolver, + } + + impl Resolve for MyResolver { + fn lookup<'a>( + &'a self, + host: &'a str, + port: u16, + ) -> LocalBoxFuture<'a, Result, Box>> { + Box::pin(async move { + let res = self + .trust_dns + .lookup_ip(host) + .await? + .iter() + .map(|ip| SocketAddr::new(ip, port)) + .collect(); + Ok(res) + }) + } + } + + let resolver = MyResolver { + trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), + }; + + let resolver = Resolver::new_custom(resolver); let factory = actix_connect::new_connector_factory(resolver); diff --git a/actix-utils/Cargo.toml b/actix-utils/Cargo.toml index f038414c..c82cf79e 100644 --- a/actix-utils/Cargo.toml +++ b/actix-utils/Cargo.toml @@ -17,7 +17,7 @@ path = "src/lib.rs" [dependencies] actix-codec = "0.4.0-beta.1" -actix-rt = "2.0.0-beta.2" +actix-rt = { version = "2.0.0-beta.2", default-features = false } actix-service = "2.0.0-beta.3" futures-core = { version = "0.3.7", default-features = false }