From eaefe21b9841664c3242563b20a4fbe3ef25aba1 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Tue, 26 Jan 2021 08:05:19 +0000 Subject: [PATCH] add tests for custom resolver --- actix-service/src/and_then.rs | 2 +- actix-service/src/apply_cfg.rs | 2 +- actix-service/src/then.rs | 2 +- actix-service/src/transform.rs | 2 +- actix-tls/CHANGES.md | 9 +- actix-tls/Cargo.toml | 10 +- actix-tls/src/connect/connect.rs | 184 +++++++++++++++++------------ actix-tls/src/connect/connector.rs | 50 ++++---- actix-tls/src/connect/mod.rs | 24 ++-- actix-tls/src/connect/resolve.rs | 36 +++--- actix-tls/src/connect/uri.rs | 50 ++++---- actix-tls/tests/test_connect.rs | 62 ++-------- actix-tls/tests/test_resolvers.rs | 78 ++++++++++++ 13 files changed, 294 insertions(+), 217 deletions(-) create mode 100644 actix-tls/tests/test_resolvers.rs diff --git a/actix-service/src/and_then.rs b/actix-service/src/and_then.rs index 54a132da..e3b293ea 100644 --- a/actix-service/src/and_then.rs +++ b/actix-service/src/and_then.rs @@ -1,3 +1,4 @@ +use alloc::rc::Rc; use core::{ future::Future, marker::PhantomData, @@ -5,7 +6,6 @@ use core::{ task::{Context, Poll}, }; -use alloc::rc::Rc; use futures_core::ready; use pin_project_lite::pin_project; diff --git a/actix-service/src/apply_cfg.rs b/actix-service/src/apply_cfg.rs index 276efc0f..25fc5fc2 100644 --- a/actix-service/src/apply_cfg.rs +++ b/actix-service/src/apply_cfg.rs @@ -1,3 +1,4 @@ +use alloc::rc::Rc; use core::{ future::Future, marker::PhantomData, @@ -5,7 +6,6 @@ use core::{ task::{Context, Poll}, }; -use alloc::rc::Rc; use futures_core::ready; use pin_project_lite::pin_project; diff --git a/actix-service/src/then.rs b/actix-service/src/then.rs index 9cd311ad..c9428824 100644 --- a/actix-service/src/then.rs +++ b/actix-service/src/then.rs @@ -1,3 +1,4 @@ +use alloc::rc::Rc; use core::{ future::Future, marker::PhantomData, @@ -5,7 +6,6 @@ use core::{ task::{Context, Poll}, }; -use alloc::rc::Rc; use futures_core::ready; use pin_project_lite::pin_project; diff --git a/actix-service/src/transform.rs b/actix-service/src/transform.rs index d5cbcd88..7f477e54 100644 --- a/actix-service/src/transform.rs +++ b/actix-service/src/transform.rs @@ -1,3 +1,4 @@ +use alloc::{rc::Rc, sync::Arc}; use core::{ future::Future, marker::PhantomData, @@ -5,7 +6,6 @@ use core::{ task::{Context, Poll}, }; -use alloc::{rc::Rc, sync::Arc}; use futures_core::ready; use pin_project_lite::pin_project; diff --git a/actix-tls/CHANGES.md b/actix-tls/CHANGES.md index d73a7830..11a1a410 100644 --- a/actix-tls/CHANGES.md +++ b/actix-tls/CHANGES.md @@ -1,14 +1,15 @@ # Changes ## Unreleased - 2021-xx-xx -* Remove `trust-dns-proto` and `trust-dns-resolver` [#248] -* Use `tokio::net::lookup_host` as simple and basic default resolver [#248] +* Remove `trust-dns-proto` and `trust-dns-resolver`. [#248] +* Use `std::net::ToSocketAddrs` as simple and basic default resolver. [#248] * Add `Resolve` trait for custom dns resolver. [#248] * Add `Resolver::new_custom` function to construct custom resolvers. [#248] * Export `webpki_roots::TLS_SERVER_ROOTS` in `actix_tls::connect` mod and remove the export from `actix_tls::accept` [#248] -* Remove `ConnectTakeAddrsIter`. `Connect::take_addrs` would return - `ConnectAddrsIter<'static>` as owned iterator. [#248] +* Remove `ConnectTakeAddrsIter`. `Connect::take_addrs` now returns `ConnectAddrsIter<'static>` + as owned iterator. [#248] +* Rename `Address::{host => hostname}` to more accurately describe which URL segment is returned. [#248]: https://github.com/actix/actix-net/pull/248 diff --git a/actix-tls/Cargo.toml b/actix-tls/Cargo.toml index 0886b27e..c31cded0 100755 --- a/actix-tls/Cargo.toml +++ b/actix-tls/Cargo.toml @@ -18,10 +18,6 @@ features = ["openssl", "rustls", "native-tls", "accept", "connect", "uri"] name = "actix_tls" path = "src/lib.rs" -[[example]] -name = "basic" -required-features = ["accept", "rustls"] - [features] default = ["accept", "connect", "uri"] @@ -77,4 +73,8 @@ bytes = "1" env_logger = "0.8" futures-util = { version = "0.3.7", default-features = false, features = ["sink"] } log = "0.4" -trust-dns-resolver = "0.20.0" \ No newline at end of file +trust-dns-resolver = "0.20.0" + +[[example]] +name = "basic" +required-features = ["accept", "rustls"] diff --git a/actix-tls/src/connect/connect.rs b/actix-tls/src/connect/connect.rs index 3928870d..ab7e03dc 100755 --- a/actix-tls/src/connect/connect.rs +++ b/actix-tls/src/connect/connect.rs @@ -1,135 +1,147 @@ use std::{ collections::{vec_deque, VecDeque}, fmt, - iter::{FromIterator, FusedIterator}, + iter::{self, FromIterator as _}, + mem, net::SocketAddr, }; -/// Connect request +/// Parse a host into parts (hostname and port). pub trait Address: Unpin + 'static { - /// Host name of the request - fn host(&self) -> &str; + /// Get hostname part. + fn hostname(&self) -> &str; - /// Port of the request - fn port(&self) -> Option; + /// Get optional port part. + fn port(&self) -> Option { + None + } } impl Address for String { - fn host(&self) -> &str { + fn hostname(&self) -> &str { &self } - - fn port(&self) -> Option { - None - } } impl Address for &'static str { - fn host(&self) -> &str { + fn hostname(&self) -> &str { self } - - fn port(&self) -> Option { - None - } } -/// Connect request -#[derive(Eq, PartialEq, Debug, Hash)] +#[derive(Debug, Eq, PartialEq, Hash)] +pub(crate) enum ConnectAddrs { + None, + One(SocketAddr), + Multi(VecDeque), +} + +impl ConnectAddrs { + pub(crate) fn is_none(&self) -> bool { + matches!(self, Self::None) + } + + pub(crate) fn is_some(&self) -> bool { + !self.is_none() + } +} + +impl Default for ConnectAddrs { + fn default() -> Self { + Self::None + } +} + +impl From> for ConnectAddrs { + fn from(addr: Option) -> Self { + match addr { + Some(addr) => ConnectAddrs::One(addr), + None => ConnectAddrs::None, + } + } +} + +/// Connection info. +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Connect { pub(crate) req: T, pub(crate) port: u16, pub(crate) addr: ConnectAddrs, } -#[derive(Eq, PartialEq, Debug, Hash)] -pub(crate) enum ConnectAddrs { - One(Option), - Multi(VecDeque), -} - -impl ConnectAddrs { - pub(crate) fn is_none(&self) -> bool { - matches!(*self, Self::One(None)) - } -} - -impl Default for ConnectAddrs { - fn default() -> Self { - Self::One(None) - } -} - impl Connect { /// Create `Connect` instance by splitting the string by ':' and convert the second part to u16 pub fn new(req: T) -> Connect { - let (_, port) = parse(req.host()); + let (_, port) = parse_host(req.hostname()); + Connect { req, port: port.unwrap_or(0), - addr: ConnectAddrs::One(None), + addr: ConnectAddrs::None, } } /// Create new `Connect` instance from host and address. Connector skips name resolution stage /// for such connect messages. - pub fn with(req: T, addr: SocketAddr) -> Connect { + pub fn with_addr(req: T, addr: SocketAddr) -> Connect { Connect { req, port: 0, - addr: ConnectAddrs::One(Some(addr)), + addr: ConnectAddrs::One(addr), } } /// Use port if address does not provide one. /// - /// By default it set to 0 + /// Default value is 0. pub fn set_port(mut self, port: u16) -> Self { self.port = port; self } - /// Use address. + /// Set address. pub fn set_addr(mut self, addr: Option) -> Self { - self.addr = ConnectAddrs::One(addr); + self.addr = ConnectAddrs::from(addr); self } - /// Use addresses. + /// Set list of addresses. pub fn set_addrs(mut self, addrs: I) -> Self where I: IntoIterator, { let mut addrs = VecDeque::from_iter(addrs); self.addr = if addrs.len() < 2 { - ConnectAddrs::One(addrs.pop_front()) + ConnectAddrs::from(addrs.pop_front()) } else { ConnectAddrs::Multi(addrs) }; self } - /// Host name - pub fn host(&self) -> &str { - self.req.host() + /// Get hostname. + pub fn hostname(&self) -> &str { + self.req.hostname() } - /// Port of the request + /// Get request port. pub fn port(&self) -> u16 { self.req.port().unwrap_or(self.port) } - /// Pre-resolved addresses of the request. + /// Get resolved request addresses. pub fn addrs(&self) -> ConnectAddrsIter<'_> { match self.addr { + ConnectAddrs::None => ConnectAddrsIter::None, ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), } } - /// Takes pre-resolved addresses of the request. + /// Take resolved request addresses. pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> { - match std::mem::take(&mut self.addr) { + match mem::take(&mut self.addr) { + ConnectAddrs::None => ConnectAddrsIter::None, ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), } @@ -144,14 +156,15 @@ impl From for Connect { impl fmt::Display for Connect { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}:{}", self.host(), self.port()) + write!(f, "{}:{}", self.hostname(), self.port()) } } /// Iterator over addresses in a [`Connect`] request. #[derive(Clone)] pub enum ConnectAddrsIter<'a> { - One(Option), + None, + One(SocketAddr), Multi(vec_deque::Iter<'a, SocketAddr>), MultiOwned(vec_deque::IntoIter), } @@ -161,7 +174,8 @@ impl Iterator for ConnectAddrsIter<'_> { fn next(&mut self) -> Option { match *self { - Self::One(ref mut addr) => addr.take(), + Self::None => None, + Self::One(addr) => Some(addr), Self::Multi(ref mut iter) => iter.next().copied(), Self::MultiOwned(ref mut iter) => iter.next(), } @@ -169,8 +183,8 @@ impl Iterator for ConnectAddrsIter<'_> { fn size_hint(&self) -> (usize, Option) { match *self { - Self::One(None) => (0, Some(0)), - Self::One(Some(_)) => (1, Some(1)), + Self::None => (0, Some(0)), + Self::One(_) => (1, Some(1)), Self::Multi(ref iter) => iter.size_hint(), Self::MultiOwned(ref iter) => iter.size_hint(), } @@ -183,23 +197,9 @@ impl fmt::Debug for ConnectAddrsIter<'_> { } } -impl ExactSizeIterator for ConnectAddrsIter<'_> {} +impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {} -impl FusedIterator for ConnectAddrsIter<'_> {} - -fn parse(host: &str) -> (&str, Option) { - let mut parts_iter = host.splitn(2, ':'); - if let Some(host) = parts_iter.next() { - let port_str = parts_iter.next().unwrap_or(""); - if let Ok(port) = port_str.parse::() { - (host, Some(port)) - } else { - (host, None) - } - } else { - (host, None) - } -} +impl iter::FusedIterator for ConnectAddrsIter<'_> {} pub struct Connection { io: U, @@ -224,25 +224,25 @@ impl Connection { } /// Replace inclosed object, return new Stream and old object - pub fn replace(self, io: Y) -> (U, Connection) { + pub fn replace_io(self, io: Y) -> (U, Connection) { (self.io, Connection { io, req: self.req }) } /// Returns a shared reference to the underlying stream. - pub fn get_ref(&self) -> &U { + pub fn io_ref(&self) -> &U { &self.io } /// Returns a mutable reference to the underlying stream. - pub fn get_mut(&mut self) -> &mut U { + pub fn io_mut(&mut self) -> &mut U { &mut self.io } } impl Connection { - /// Get request + /// Get hostname. pub fn host(&self) -> &str { - &self.req.host() + self.req.hostname() } } @@ -265,3 +265,31 @@ impl fmt::Debug for Connection { write!(f, "Stream {{{:?}}}", self.io) } } + +fn parse_host(host: &str) -> (&str, Option) { + let mut parts_iter = host.splitn(2, ':'); + + match parts_iter.next() { + Some(hostname) => { + let port_str = parts_iter.next().unwrap_or(""); + let port = port_str.parse::().ok(); + (hostname, port) + } + + None => (host, None), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_host_parser() { + assert_eq!(parse_host("example.com"), ("example.com", None)); + assert_eq!(parse_host("example.com:8080"), ("example.com", Some(8080))); + assert_eq!(parse_host("example:8080"), ("example", Some(8080))); + assert_eq!(parse_host("example.com:false"), ("example.com", None)); + assert_eq!(parse_host("example.com:false:false"), ("example.com", None)); + } +} diff --git a/actix-tls/src/connect/connector.rs b/actix-tls/src/connect/connector.rs index dc4fbc72..5284eff4 100755 --- a/actix-tls/src/connect/connector.rs +++ b/actix-tls/src/connect/connector.rs @@ -9,14 +9,14 @@ use std::{ use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; -use futures_core::future::LocalBoxFuture; +use futures_core::{future::LocalBoxFuture, ready}; use log::{error, trace}; use super::connect::{Address, Connect, ConnectAddrs, Connection}; use super::error::ConnectError; /// TCP connector service factory -#[derive(Copy, Clone, Debug)] +#[derive(Debug, Copy, Clone)] pub struct TcpConnectorFactory; impl TcpConnectorFactory { @@ -41,7 +41,7 @@ impl ServiceFactory> for TcpConnectorFactory { } /// TCP connector service -#[derive(Copy, Clone, Debug)] +#[derive(Debug, Copy, Clone)] pub struct TcpConnector; impl Service> for TcpConnector { @@ -59,7 +59,6 @@ impl Service> for TcpConnector { } } -#[doc(hidden)] /// TCP stream connector response future pub enum TcpConnectorResponse { Response { @@ -73,23 +72,27 @@ pub enum TcpConnectorResponse { impl TcpConnectorResponse { pub(crate) fn new(req: T, port: u16, addr: ConnectAddrs) -> TcpConnectorResponse { + if addr.is_none() { + error!("TCP connector: unresolved connection address"); + return TcpConnectorResponse::Error(Some(ConnectError::Unresolved)); + } + trace!( - "TCP connector - connecting to {:?} port:{}", - req.host(), + "TCP connector: connecting to {} on port {}", + req.hostname(), port ); match addr { - ConnectAddrs::One(None) => { - error!("TCP connector: got unresolved address"); - TcpConnectorResponse::Error(Some(ConnectError::Unresolved)) - } - ConnectAddrs::One(Some(addr)) => TcpConnectorResponse::Response { + ConnectAddrs::None => unreachable!("none variant already checked"), + + ConnectAddrs::One(addr) => TcpConnectorResponse::Response { req: Some(req), port, addrs: None, stream: Some(Box::pin(TcpStream::connect(addr))), }, + // when resolver returns multiple socket addr for request they would be popped from // front end of queue and returns with the first successful tcp connection. ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response { @@ -106,10 +109,9 @@ impl Future for TcpConnectorResponse { type Output = Result, ConnectError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - match this { - TcpConnectorResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())), - // connect + match self.get_mut() { + TcpConnectorResponse::Error(err) => Poll::Ready(Err(err.take().unwrap())), + TcpConnectorResponse::Response { req, port, @@ -117,22 +119,24 @@ impl Future for TcpConnectorResponse { stream, } => loop { if let Some(new) = stream.as_mut() { - match new.as_mut().poll(cx) { - Poll::Ready(Ok(sock)) => { + match ready!(new.as_mut().poll(cx)) { + Ok(sock) => { let req = req.take().unwrap(); trace!( - "TCP connector - successfully connected to connecting to {:?} - {:?}", - req.host(), sock.peer_addr() + "TCP connector: successfully connected to {:?} - {:?}", + req.hostname(), + sock.peer_addr() ); return Poll::Ready(Ok(Connection::new(sock, req))); } - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => { + + Err(err) => { trace!( - "TCP connector - failed to connect to connecting to {:?} port: {}", - req.as_ref().unwrap().host(), + "TCP connector: failed to connect to {:?} port: {}", + req.as_ref().unwrap().hostname(), port, ); + if addrs.is_none() || addrs.as_ref().unwrap().is_empty() { return Poll::Ready(Err(ConnectError::Io(err))); } diff --git a/actix-tls/src/connect/mod.rs b/actix-tls/src/connect/mod.rs index c864886d..4010e3cb 100644 --- a/actix-tls/src/connect/mod.rs +++ b/actix-tls/src/connect/mod.rs @@ -1,21 +1,21 @@ -//! TCP connector service for Actix ecosystem. +//! TCP connector services for Actix ecosystem. //! -//! ## Package feature +//! # Stages of the TCP connector service: +//! - Resolve [`Address`] with given [`Resolver`] and collect list of socket addresses. +//! - Establish TCP connection and return [`TcpStream`]. //! +//! # Stages of TLS connector services: +//! - Establish [`TcpStream`] with connector service. +//! - Wrap the stream and perform connect handshake with remote peer. +//! - Return certain stream type that impls `AsyncRead` and `AsyncWrite`. +//! +//! # Package feature //! * `openssl` - enables TLS support via `openssl` crate //! * `rustls` - enables TLS support via `rustls` crate //! -//! ## Workflow of connector service: -//! - resolve [`Address`](self::connect::Address) with given [`Resolver`](self::resolve::Resolver) -//! and collect [`SocketAddrs`](std::net::SocketAddr). -//! - establish Tcp connection and return [`TcpStream`](tokio::net::TcpStream). -//! -//! ## Workflow of tls connector services: -//! - Establish [`TcpStream`](tokio::net::TcpStream) with connector service. -//! - Wrap around the stream and do connect handshake with remote address. -//! - Return certain stream type impl [`AsyncRead`](tokio::io::AsyncRead) and -//! [`AsyncWrite`](tokio::io::AsyncWrite) +//! [`TcpStream`]: actix_rt::net::TcpStream +#[allow(clippy::module_inception)] mod connect; mod connector; mod error; diff --git a/actix-tls/src/connect/resolve.rs b/actix-tls/src/connect/resolve.rs index a672b971..32e442bf 100755 --- a/actix-tls/src/connect/resolve.rs +++ b/actix-tls/src/connect/resolve.rs @@ -16,7 +16,7 @@ use log::trace; use super::connect::{Address, Connect}; use super::error::ConnectError; -/// DNS Resolver Service factory +/// DNS Resolver Service Factory #[derive(Clone)] pub struct ResolverFactory { resolver: Resolver, @@ -53,16 +53,16 @@ pub enum Resolver { Custom(Rc), } -/// trait for custom lookup with self defined resolver. +/// An interface for custom async DNS resolvers. /// -/// # Example: +/// # Usage /// ```rust /// use std::net::SocketAddr; /// /// use actix_tls::connect::{Resolve, Resolver}; /// use futures_util::future::LocalBoxFuture; /// -/// // use trust_dns_resolver as custom resolver. +/// // use trust-dns async tokio resolver /// use trust_dns_resolver::TokioAsyncResolver; /// /// struct MyResolver { @@ -103,7 +103,7 @@ pub enum Resolver { /// // resolver can be passed to connector factory where returned service factory /// // can be used to construct new connector services. /// let factory = actix_tls::connect::new_connector_factory::<&str>(resolver); -///``` +/// ``` pub trait Resolve { fn lookup<'a>( &'a self, @@ -120,10 +120,10 @@ impl Resolver { // look up with default resolver variant. fn look_up(req: &Connect) -> JoinHandle>> { - let host = req.host(); + let host = req.hostname(); // TODO: Connect should always return host with port if possible. let host = if req - .host() + .hostname() .splitn(2, ':') .last() .and_then(|p| p.parse::().ok()) @@ -135,6 +135,7 @@ impl Resolver { format!("{}:{}", host, req.port()) }; + // run blocking DNS lookup in thread pool spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host)) } } @@ -147,14 +148,14 @@ impl Service> for Resolver { actix_service::always_ready!(); fn call(&self, req: Connect) -> Self::Future { - if !req.addr.is_none() { + if req.addr.is_some() { ResolverFuture::Connected(Some(req)) - } else if let Ok(ip) = req.host().parse() { + } else if let Ok(ip) = req.hostname().parse() { let addr = SocketAddr::new(ip, req.port()); let req = req.set_addr(Some(addr)); ResolverFuture::Connected(Some(req)) } else { - trace!("DNS resolver: resolving host {:?}", req.host()); + trace!("DNS resolver: resolving host {:?}", req.hostname()); match self { Self::Default => { @@ -166,7 +167,7 @@ impl Service> for Resolver { let resolver = Rc::clone(&resolver); ResolverFuture::LookupCustom(Box::pin(async move { let addrs = resolver - .lookup(req.host(), req.port()) + .lookup(req.hostname(), req.port()) .await .map_err(ConnectError::Resolver)?; @@ -201,6 +202,7 @@ impl Future for ResolverFuture { Self::Connected(conn) => Poll::Ready(Ok(conn .take() .expect("ResolverFuture polled after finished"))), + Self::LookUp(fut, req) => { let res = match ready!(Pin::new(fut).poll(cx)) { Ok(Ok(res)) => Ok(res), @@ -210,20 +212,21 @@ impl Future for ResolverFuture { let req = req.take().unwrap(); - let addrs = res.map_err(|e| { + let addrs = res.map_err(|err| { trace!( "DNS resolver: failed to resolve host {:?} err: {:?}", - req.host(), - e + req.hostname(), + err ); - e + + err })?; let req = req.set_addrs(addrs); trace!( "DNS resolver: host {:?} resolved to {:?}", - req.host(), + req.hostname(), req.addrs() ); @@ -233,6 +236,7 @@ impl Future for ResolverFuture { Poll::Ready(Ok(req)) } } + Self::LookupCustom(fut) => fut.as_mut().poll(cx), } } diff --git a/actix-tls/src/connect/uri.rs b/actix-tls/src/connect/uri.rs index b208a8b3..2d54b618 100644 --- a/actix-tls/src/connect/uri.rs +++ b/actix-tls/src/connect/uri.rs @@ -3,35 +3,41 @@ use http::Uri; use super::Address; impl Address for Uri { - fn host(&self) -> &str { + fn hostname(&self) -> &str { self.host().unwrap_or("") } fn port(&self) -> Option { - if let Some(port) = self.port_u16() { - Some(port) - } else { - port(self.scheme_str()) + match self.port_u16() { + Some(port) => Some(port), + None => scheme_to_port(self.scheme_str()), } } } -// TODO: load data from file -fn port(scheme: Option<&str>) -> Option { - if let Some(scheme) = scheme { - match scheme { - "http" => Some(80), - "https" => Some(443), - "ws" => Some(80), - "wss" => Some(443), - "amqp" => Some(5672), - "amqps" => Some(5671), - "sb" => Some(5671), - "mqtt" => Some(1883), - "mqtts" => Some(8883), - _ => None, - } - } else { - None +// Get port from well-known URL schemes. +fn scheme_to_port(scheme: Option<&str>) -> Option { + match scheme { + // HTTP + Some("http") => Some(80), + Some("https") => Some(443), + + // WebSockets + Some("ws") => Some(80), + Some("wss") => Some(443), + + // Advanced Message Queuing Protocol (AMQP) + Some("amqp") => Some(5672), + Some("amqps") => Some(5671), + + // Message Queuing Telemetry Transport (MQTT) + Some("mqtt") => Some(1883), + Some("mqtts") => Some(8883), + + // File Transfer Protocol (FTP) + Some("ftp") => Some(1883), + Some("ftps") => Some(990), + + _ => None, } } diff --git a/actix-tls/tests/test_connect.rs b/actix-tls/tests/test_connect.rs index 392c76c6..7ee7afda 100755 --- a/actix-tls/tests/test_connect.rs +++ b/actix-tls/tests/test_connect.rs @@ -1,15 +1,13 @@ use std::io; -use std::net::SocketAddr; use actix_codec::{BytesCodec, Framed}; use actix_rt::net::TcpStream; use actix_server::TestServer; use actix_service::{fn_service, Service, ServiceFactory}; use bytes::Bytes; -use futures_core::future::LocalBoxFuture; use futures_util::sink::SinkExt; -use actix_tls::connect::{self as actix_connect, Connect, Resolve, Resolver}; +use actix_tls::connect::{self as actix_connect, Connect}; #[cfg(all(feature = "connect", feature = "openssl"))] #[actix_rt::test] @@ -57,7 +55,10 @@ async fn test_static_str() { let conn = actix_connect::default_connector(); - let con = conn.call(Connect::with("10", srv.addr())).await.unwrap(); + let con = conn + .call(Connect::with_addr("10", srv.addr())) + .await + .unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); let connect = Connect::new(srv.host().to_owned()); @@ -80,55 +81,10 @@ async fn test_new_service() { let factory = actix_connect::default_connector_factory(); let conn = factory.new_service(()).await.unwrap(); - let con = conn.call(Connect::with("10", srv.addr())).await.unwrap(); - assert_eq!(con.peer_addr().unwrap(), srv.addr()); -} - -#[actix_rt::test] -async fn test_custom_resolver() { - use trust_dns_resolver::TokioAsyncResolver; - - let srv = TestServer::with(|| { - fn_service(|io: TcpStream| async { - let mut framed = Framed::new(io, BytesCodec); - framed.send(Bytes::from_static(b"test")).await?; - Ok::<_, io::Error>(()) - }) - }); - - struct MyResolver { - trust_dns: TokioAsyncResolver, - } - - impl Resolve for MyResolver { - fn lookup<'a>( - &'a self, - host: &'a str, - port: u16, - ) -> LocalBoxFuture<'a, Result, Box>> { - Box::pin(async move { - let res = self - .trust_dns - .lookup_ip(host) - .await? - .iter() - .map(|ip| SocketAddr::new(ip, port)) - .collect(); - Ok(res) - }) - } - } - - let resolver = MyResolver { - trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), - }; - - let resolver = Resolver::new_custom(resolver); - - let factory = actix_connect::new_connector_factory(resolver); - - let conn = factory.new_service(()).await.unwrap(); - let con = conn.call(Connect::with("10", srv.addr())).await.unwrap(); + let con = conn + .call(Connect::with_addr("10", srv.addr())) + .await + .unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); } diff --git a/actix-tls/tests/test_resolvers.rs b/actix-tls/tests/test_resolvers.rs new file mode 100644 index 00000000..0f49c486 --- /dev/null +++ b/actix-tls/tests/test_resolvers.rs @@ -0,0 +1,78 @@ +use std::{ + io, + net::{Ipv4Addr, SocketAddr}, +}; + +use actix_rt::net::TcpStream; +use actix_server::TestServer; +use actix_service::{fn_service, Service, ServiceFactory}; +use futures_core::future::LocalBoxFuture; + +use actix_tls::connect::{new_connector_factory, Connect, Resolve, Resolver}; + +#[actix_rt::test] +async fn custom_resolver() { + /// Always resolves to localhost with the given port. + struct LocalOnlyResolver; + + impl Resolve for LocalOnlyResolver { + fn lookup<'a>( + &'a self, + _host: &'a str, + port: u16, + ) -> LocalBoxFuture<'a, Result, Box>> { + Box::pin(async move { + let local = format!("127.0.0.1:{}", port).parse().unwrap(); + Ok(vec![local]) + }) + } + } + + let addr = LocalOnlyResolver.lookup("example.com", 8080).await.unwrap()[0]; + assert_eq!(addr, SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080)) +} + +#[actix_rt::test] +async fn custom_resolver_connect() { + use trust_dns_resolver::TokioAsyncResolver; + + let srv = + TestServer::with(|| fn_service(|_io: TcpStream| async { Ok::<_, io::Error>(()) })); + + struct MyResolver { + trust_dns: TokioAsyncResolver, + } + + impl Resolve for MyResolver { + fn lookup<'a>( + &'a self, + host: &'a str, + port: u16, + ) -> LocalBoxFuture<'a, Result, Box>> { + Box::pin(async move { + let res = self + .trust_dns + .lookup_ip(host) + .await? + .iter() + .map(|ip| SocketAddr::new(ip, port)) + .collect(); + Ok(res) + }) + } + } + + let resolver = MyResolver { + trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), + }; + + let resolver = Resolver::new_custom(resolver); + let factory = new_connector_factory(resolver); + + let conn = factory.new_service(()).await.unwrap(); + let con = conn + .call(Connect::with_addr("example.com", srv.addr())) + .await + .unwrap(); + assert_eq!(con.peer_addr().unwrap(), srv.addr()); +}