1
0
mirror of https://github.com/fafhrd91/actix-net synced 2024-11-30 16:34:36 +01:00

add tests for custom resolver

This commit is contained in:
Rob Ede 2021-01-26 08:05:19 +00:00
parent 636cef8868
commit eaefe21b98
No known key found for this signature in database
GPG Key ID: C2A3B36E841A91E6
13 changed files with 294 additions and 217 deletions

View File

@ -1,3 +1,4 @@
use alloc::rc::Rc;
use core::{ use core::{
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
@ -5,7 +6,6 @@ use core::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use alloc::rc::Rc;
use futures_core::ready; use futures_core::ready;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;

View File

@ -1,3 +1,4 @@
use alloc::rc::Rc;
use core::{ use core::{
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
@ -5,7 +6,6 @@ use core::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use alloc::rc::Rc;
use futures_core::ready; use futures_core::ready;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;

View File

@ -1,3 +1,4 @@
use alloc::rc::Rc;
use core::{ use core::{
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
@ -5,7 +6,6 @@ use core::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use alloc::rc::Rc;
use futures_core::ready; use futures_core::ready;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;

View File

@ -1,3 +1,4 @@
use alloc::{rc::Rc, sync::Arc};
use core::{ use core::{
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
@ -5,7 +6,6 @@ use core::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use alloc::{rc::Rc, sync::Arc};
use futures_core::ready; use futures_core::ready;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;

View File

@ -1,14 +1,15 @@
# Changes # Changes
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
* Remove `trust-dns-proto` and `trust-dns-resolver` [#248] * Remove `trust-dns-proto` and `trust-dns-resolver`. [#248]
* Use `tokio::net::lookup_host` as simple and basic default resolver [#248] * Use `std::net::ToSocketAddrs` as simple and basic default resolver. [#248]
* Add `Resolve` trait for custom dns resolver. [#248] * Add `Resolve` trait for custom dns resolver. [#248]
* Add `Resolver::new_custom` function to construct custom resolvers. [#248] * Add `Resolver::new_custom` function to construct custom resolvers. [#248]
* Export `webpki_roots::TLS_SERVER_ROOTS` in `actix_tls::connect` mod and remove * Export `webpki_roots::TLS_SERVER_ROOTS` in `actix_tls::connect` mod and remove
the export from `actix_tls::accept` [#248] the export from `actix_tls::accept` [#248]
* Remove `ConnectTakeAddrsIter`. `Connect::take_addrs` would return * Remove `ConnectTakeAddrsIter`. `Connect::take_addrs` now returns `ConnectAddrsIter<'static>`
`ConnectAddrsIter<'static>` as owned iterator. [#248] as owned iterator. [#248]
* Rename `Address::{host => hostname}` to more accurately describe which URL segment is returned.
[#248]: https://github.com/actix/actix-net/pull/248 [#248]: https://github.com/actix/actix-net/pull/248

View File

@ -18,10 +18,6 @@ features = ["openssl", "rustls", "native-tls", "accept", "connect", "uri"]
name = "actix_tls" name = "actix_tls"
path = "src/lib.rs" path = "src/lib.rs"
[[example]]
name = "basic"
required-features = ["accept", "rustls"]
[features] [features]
default = ["accept", "connect", "uri"] default = ["accept", "connect", "uri"]
@ -77,4 +73,8 @@ bytes = "1"
env_logger = "0.8" env_logger = "0.8"
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] } futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
log = "0.4" log = "0.4"
trust-dns-resolver = "0.20.0" trust-dns-resolver = "0.20.0"
[[example]]
name = "basic"
required-features = ["accept", "rustls"]

View File

@ -1,135 +1,147 @@
use std::{ use std::{
collections::{vec_deque, VecDeque}, collections::{vec_deque, VecDeque},
fmt, fmt,
iter::{FromIterator, FusedIterator}, iter::{self, FromIterator as _},
mem,
net::SocketAddr, net::SocketAddr,
}; };
/// Connect request /// Parse a host into parts (hostname and port).
pub trait Address: Unpin + 'static { pub trait Address: Unpin + 'static {
/// Host name of the request /// Get hostname part.
fn host(&self) -> &str; fn hostname(&self) -> &str;
/// Port of the request /// Get optional port part.
fn port(&self) -> Option<u16>; fn port(&self) -> Option<u16> {
None
}
} }
impl Address for String { impl Address for String {
fn host(&self) -> &str { fn hostname(&self) -> &str {
&self &self
} }
fn port(&self) -> Option<u16> {
None
}
} }
impl Address for &'static str { impl Address for &'static str {
fn host(&self) -> &str { fn hostname(&self) -> &str {
self self
} }
fn port(&self) -> Option<u16> {
None
}
} }
/// Connect request #[derive(Debug, Eq, PartialEq, Hash)]
#[derive(Eq, PartialEq, Debug, Hash)] pub(crate) enum ConnectAddrs {
None,
One(SocketAddr),
Multi(VecDeque<SocketAddr>),
}
impl ConnectAddrs {
pub(crate) fn is_none(&self) -> bool {
matches!(self, Self::None)
}
pub(crate) fn is_some(&self) -> bool {
!self.is_none()
}
}
impl Default for ConnectAddrs {
fn default() -> Self {
Self::None
}
}
impl From<Option<SocketAddr>> for ConnectAddrs {
fn from(addr: Option<SocketAddr>) -> Self {
match addr {
Some(addr) => ConnectAddrs::One(addr),
None => ConnectAddrs::None,
}
}
}
/// Connection info.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Connect<T> { pub struct Connect<T> {
pub(crate) req: T, pub(crate) req: T,
pub(crate) port: u16, pub(crate) port: u16,
pub(crate) addr: ConnectAddrs, pub(crate) addr: ConnectAddrs,
} }
#[derive(Eq, PartialEq, Debug, Hash)]
pub(crate) enum ConnectAddrs {
One(Option<SocketAddr>),
Multi(VecDeque<SocketAddr>),
}
impl ConnectAddrs {
pub(crate) fn is_none(&self) -> bool {
matches!(*self, Self::One(None))
}
}
impl Default for ConnectAddrs {
fn default() -> Self {
Self::One(None)
}
}
impl<T: Address> Connect<T> { impl<T: Address> Connect<T> {
/// Create `Connect` instance by splitting the string by ':' and convert the second part to u16 /// Create `Connect` instance by splitting the string by ':' and convert the second part to u16
pub fn new(req: T) -> Connect<T> { pub fn new(req: T) -> Connect<T> {
let (_, port) = parse(req.host()); let (_, port) = parse_host(req.hostname());
Connect { Connect {
req, req,
port: port.unwrap_or(0), port: port.unwrap_or(0),
addr: ConnectAddrs::One(None), addr: ConnectAddrs::None,
} }
} }
/// Create new `Connect` instance from host and address. Connector skips name resolution stage /// Create new `Connect` instance from host and address. Connector skips name resolution stage
/// for such connect messages. /// for such connect messages.
pub fn with(req: T, addr: SocketAddr) -> Connect<T> { pub fn with_addr(req: T, addr: SocketAddr) -> Connect<T> {
Connect { Connect {
req, req,
port: 0, port: 0,
addr: ConnectAddrs::One(Some(addr)), addr: ConnectAddrs::One(addr),
} }
} }
/// Use port if address does not provide one. /// Use port if address does not provide one.
/// ///
/// By default it set to 0 /// Default value is 0.
pub fn set_port(mut self, port: u16) -> Self { pub fn set_port(mut self, port: u16) -> Self {
self.port = port; self.port = port;
self self
} }
/// Use address. /// Set address.
pub fn set_addr(mut self, addr: Option<SocketAddr>) -> Self { pub fn set_addr(mut self, addr: Option<SocketAddr>) -> Self {
self.addr = ConnectAddrs::One(addr); self.addr = ConnectAddrs::from(addr);
self self
} }
/// Use addresses. /// Set list of addresses.
pub fn set_addrs<I>(mut self, addrs: I) -> Self pub fn set_addrs<I>(mut self, addrs: I) -> Self
where where
I: IntoIterator<Item = SocketAddr>, I: IntoIterator<Item = SocketAddr>,
{ {
let mut addrs = VecDeque::from_iter(addrs); let mut addrs = VecDeque::from_iter(addrs);
self.addr = if addrs.len() < 2 { self.addr = if addrs.len() < 2 {
ConnectAddrs::One(addrs.pop_front()) ConnectAddrs::from(addrs.pop_front())
} else { } else {
ConnectAddrs::Multi(addrs) ConnectAddrs::Multi(addrs)
}; };
self self
} }
/// Host name /// Get hostname.
pub fn host(&self) -> &str { pub fn hostname(&self) -> &str {
self.req.host() self.req.hostname()
} }
/// Port of the request /// Get request port.
pub fn port(&self) -> u16 { pub fn port(&self) -> u16 {
self.req.port().unwrap_or(self.port) self.req.port().unwrap_or(self.port)
} }
/// Pre-resolved addresses of the request. /// Get resolved request addresses.
pub fn addrs(&self) -> ConnectAddrsIter<'_> { pub fn addrs(&self) -> ConnectAddrsIter<'_> {
match self.addr { match self.addr {
ConnectAddrs::None => ConnectAddrsIter::None,
ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr),
ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()),
} }
} }
/// Takes pre-resolved addresses of the request. /// Take resolved request addresses.
pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> { pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> {
match std::mem::take(&mut self.addr) { match mem::take(&mut self.addr) {
ConnectAddrs::None => ConnectAddrsIter::None,
ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr),
ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()),
} }
@ -144,14 +156,15 @@ impl<T: Address> From<T> for Connect<T> {
impl<T: Address> fmt::Display for Connect<T> { impl<T: Address> fmt::Display for Connect<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}", self.host(), self.port()) write!(f, "{}:{}", self.hostname(), self.port())
} }
} }
/// Iterator over addresses in a [`Connect`] request. /// Iterator over addresses in a [`Connect`] request.
#[derive(Clone)] #[derive(Clone)]
pub enum ConnectAddrsIter<'a> { pub enum ConnectAddrsIter<'a> {
One(Option<SocketAddr>), None,
One(SocketAddr),
Multi(vec_deque::Iter<'a, SocketAddr>), Multi(vec_deque::Iter<'a, SocketAddr>),
MultiOwned(vec_deque::IntoIter<SocketAddr>), MultiOwned(vec_deque::IntoIter<SocketAddr>),
} }
@ -161,7 +174,8 @@ impl Iterator for ConnectAddrsIter<'_> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
match *self { match *self {
Self::One(ref mut addr) => addr.take(), Self::None => None,
Self::One(addr) => Some(addr),
Self::Multi(ref mut iter) => iter.next().copied(), Self::Multi(ref mut iter) => iter.next().copied(),
Self::MultiOwned(ref mut iter) => iter.next(), Self::MultiOwned(ref mut iter) => iter.next(),
} }
@ -169,8 +183,8 @@ impl Iterator for ConnectAddrsIter<'_> {
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
match *self { match *self {
Self::One(None) => (0, Some(0)), Self::None => (0, Some(0)),
Self::One(Some(_)) => (1, Some(1)), Self::One(_) => (1, Some(1)),
Self::Multi(ref iter) => iter.size_hint(), Self::Multi(ref iter) => iter.size_hint(),
Self::MultiOwned(ref iter) => iter.size_hint(), Self::MultiOwned(ref iter) => iter.size_hint(),
} }
@ -183,23 +197,9 @@ impl fmt::Debug for ConnectAddrsIter<'_> {
} }
} }
impl ExactSizeIterator for ConnectAddrsIter<'_> {} impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {}
impl FusedIterator for ConnectAddrsIter<'_> {} impl iter::FusedIterator for ConnectAddrsIter<'_> {}
fn parse(host: &str) -> (&str, Option<u16>) {
let mut parts_iter = host.splitn(2, ':');
if let Some(host) = parts_iter.next() {
let port_str = parts_iter.next().unwrap_or("");
if let Ok(port) = port_str.parse::<u16>() {
(host, Some(port))
} else {
(host, None)
}
} else {
(host, None)
}
}
pub struct Connection<T, U> { pub struct Connection<T, U> {
io: U, io: U,
@ -224,25 +224,25 @@ impl<T, U> Connection<T, U> {
} }
/// Replace inclosed object, return new Stream and old object /// Replace inclosed object, return new Stream and old object
pub fn replace<Y>(self, io: Y) -> (U, Connection<T, Y>) { pub fn replace_io<Y>(self, io: Y) -> (U, Connection<T, Y>) {
(self.io, Connection { io, req: self.req }) (self.io, Connection { io, req: self.req })
} }
/// Returns a shared reference to the underlying stream. /// Returns a shared reference to the underlying stream.
pub fn get_ref(&self) -> &U { pub fn io_ref(&self) -> &U {
&self.io &self.io
} }
/// Returns a mutable reference to the underlying stream. /// Returns a mutable reference to the underlying stream.
pub fn get_mut(&mut self) -> &mut U { pub fn io_mut(&mut self) -> &mut U {
&mut self.io &mut self.io
} }
} }
impl<T: Address, U> Connection<T, U> { impl<T: Address, U> Connection<T, U> {
/// Get request /// Get hostname.
pub fn host(&self) -> &str { pub fn host(&self) -> &str {
&self.req.host() self.req.hostname()
} }
} }
@ -265,3 +265,31 @@ impl<T, U: fmt::Debug> fmt::Debug for Connection<T, U> {
write!(f, "Stream {{{:?}}}", self.io) write!(f, "Stream {{{:?}}}", self.io)
} }
} }
fn parse_host(host: &str) -> (&str, Option<u16>) {
let mut parts_iter = host.splitn(2, ':');
match parts_iter.next() {
Some(hostname) => {
let port_str = parts_iter.next().unwrap_or("");
let port = port_str.parse::<u16>().ok();
(hostname, port)
}
None => (host, None),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_host_parser() {
assert_eq!(parse_host("example.com"), ("example.com", None));
assert_eq!(parse_host("example.com:8080"), ("example.com", Some(8080)));
assert_eq!(parse_host("example:8080"), ("example", Some(8080)));
assert_eq!(parse_host("example.com:false"), ("example.com", None));
assert_eq!(parse_host("example.com:false:false"), ("example.com", None));
}
}

View File

@ -9,14 +9,14 @@ use std::{
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_core::future::LocalBoxFuture; use futures_core::{future::LocalBoxFuture, ready};
use log::{error, trace}; use log::{error, trace};
use super::connect::{Address, Connect, ConnectAddrs, Connection}; use super::connect::{Address, Connect, ConnectAddrs, Connection};
use super::error::ConnectError; use super::error::ConnectError;
/// TCP connector service factory /// TCP connector service factory
#[derive(Copy, Clone, Debug)] #[derive(Debug, Copy, Clone)]
pub struct TcpConnectorFactory; pub struct TcpConnectorFactory;
impl TcpConnectorFactory { impl TcpConnectorFactory {
@ -41,7 +41,7 @@ impl<T: Address> ServiceFactory<Connect<T>> for TcpConnectorFactory {
} }
/// TCP connector service /// TCP connector service
#[derive(Copy, Clone, Debug)] #[derive(Debug, Copy, Clone)]
pub struct TcpConnector; pub struct TcpConnector;
impl<T: Address> Service<Connect<T>> for TcpConnector { impl<T: Address> Service<Connect<T>> for TcpConnector {
@ -59,7 +59,6 @@ impl<T: Address> Service<Connect<T>> for TcpConnector {
} }
} }
#[doc(hidden)]
/// TCP stream connector response future /// TCP stream connector response future
pub enum TcpConnectorResponse<T> { pub enum TcpConnectorResponse<T> {
Response { Response {
@ -73,23 +72,27 @@ pub enum TcpConnectorResponse<T> {
impl<T: Address> TcpConnectorResponse<T> { impl<T: Address> TcpConnectorResponse<T> {
pub(crate) fn new(req: T, port: u16, addr: ConnectAddrs) -> TcpConnectorResponse<T> { pub(crate) fn new(req: T, port: u16, addr: ConnectAddrs) -> TcpConnectorResponse<T> {
if addr.is_none() {
error!("TCP connector: unresolved connection address");
return TcpConnectorResponse::Error(Some(ConnectError::Unresolved));
}
trace!( trace!(
"TCP connector - connecting to {:?} port:{}", "TCP connector: connecting to {} on port {}",
req.host(), req.hostname(),
port port
); );
match addr { match addr {
ConnectAddrs::One(None) => { ConnectAddrs::None => unreachable!("none variant already checked"),
error!("TCP connector: got unresolved address");
TcpConnectorResponse::Error(Some(ConnectError::Unresolved)) ConnectAddrs::One(addr) => TcpConnectorResponse::Response {
}
ConnectAddrs::One(Some(addr)) => TcpConnectorResponse::Response {
req: Some(req), req: Some(req),
port, port,
addrs: None, addrs: None,
stream: Some(Box::pin(TcpStream::connect(addr))), stream: Some(Box::pin(TcpStream::connect(addr))),
}, },
// when resolver returns multiple socket addr for request they would be popped from // when resolver returns multiple socket addr for request they would be popped from
// front end of queue and returns with the first successful tcp connection. // front end of queue and returns with the first successful tcp connection.
ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response { ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response {
@ -106,10 +109,9 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
type Output = Result<Connection<T, TcpStream>, ConnectError>; type Output = Result<Connection<T, TcpStream>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); match self.get_mut() {
match this { TcpConnectorResponse::Error(err) => Poll::Ready(Err(err.take().unwrap())),
TcpConnectorResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
// connect
TcpConnectorResponse::Response { TcpConnectorResponse::Response {
req, req,
port, port,
@ -117,22 +119,24 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
stream, stream,
} => loop { } => loop {
if let Some(new) = stream.as_mut() { if let Some(new) = stream.as_mut() {
match new.as_mut().poll(cx) { match ready!(new.as_mut().poll(cx)) {
Poll::Ready(Ok(sock)) => { Ok(sock) => {
let req = req.take().unwrap(); let req = req.take().unwrap();
trace!( trace!(
"TCP connector - successfully connected to connecting to {:?} - {:?}", "TCP connector: successfully connected to {:?} - {:?}",
req.host(), sock.peer_addr() req.hostname(),
sock.peer_addr()
); );
return Poll::Ready(Ok(Connection::new(sock, req))); return Poll::Ready(Ok(Connection::new(sock, req)));
} }
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => { Err(err) => {
trace!( trace!(
"TCP connector - failed to connect to connecting to {:?} port: {}", "TCP connector: failed to connect to {:?} port: {}",
req.as_ref().unwrap().host(), req.as_ref().unwrap().hostname(),
port, port,
); );
if addrs.is_none() || addrs.as_ref().unwrap().is_empty() { if addrs.is_none() || addrs.as_ref().unwrap().is_empty() {
return Poll::Ready(Err(ConnectError::Io(err))); return Poll::Ready(Err(ConnectError::Io(err)));
} }

View File

@ -1,21 +1,21 @@
//! TCP connector service for Actix ecosystem. //! TCP connector services for Actix ecosystem.
//! //!
//! ## Package feature //! # Stages of the TCP connector service:
//! - Resolve [`Address`] with given [`Resolver`] and collect list of socket addresses.
//! - Establish TCP connection and return [`TcpStream`].
//! //!
//! # Stages of TLS connector services:
//! - Establish [`TcpStream`] with connector service.
//! - Wrap the stream and perform connect handshake with remote peer.
//! - Return certain stream type that impls `AsyncRead` and `AsyncWrite`.
//!
//! # Package feature
//! * `openssl` - enables TLS support via `openssl` crate //! * `openssl` - enables TLS support via `openssl` crate
//! * `rustls` - enables TLS support via `rustls` crate //! * `rustls` - enables TLS support via `rustls` crate
//! //!
//! ## Workflow of connector service: //! [`TcpStream`]: actix_rt::net::TcpStream
//! - resolve [`Address`](self::connect::Address) with given [`Resolver`](self::resolve::Resolver)
//! and collect [`SocketAddrs`](std::net::SocketAddr).
//! - establish Tcp connection and return [`TcpStream`](tokio::net::TcpStream).
//!
//! ## Workflow of tls connector services:
//! - Establish [`TcpStream`](tokio::net::TcpStream) with connector service.
//! - Wrap around the stream and do connect handshake with remote address.
//! - Return certain stream type impl [`AsyncRead`](tokio::io::AsyncRead) and
//! [`AsyncWrite`](tokio::io::AsyncWrite)
#[allow(clippy::module_inception)]
mod connect; mod connect;
mod connector; mod connector;
mod error; mod error;

View File

@ -16,7 +16,7 @@ use log::trace;
use super::connect::{Address, Connect}; use super::connect::{Address, Connect};
use super::error::ConnectError; use super::error::ConnectError;
/// DNS Resolver Service factory /// DNS Resolver Service Factory
#[derive(Clone)] #[derive(Clone)]
pub struct ResolverFactory { pub struct ResolverFactory {
resolver: Resolver, resolver: Resolver,
@ -53,16 +53,16 @@ pub enum Resolver {
Custom(Rc<dyn Resolve>), Custom(Rc<dyn Resolve>),
} }
/// trait for custom lookup with self defined resolver. /// An interface for custom async DNS resolvers.
/// ///
/// # Example: /// # Usage
/// ```rust /// ```rust
/// use std::net::SocketAddr; /// use std::net::SocketAddr;
/// ///
/// use actix_tls::connect::{Resolve, Resolver}; /// use actix_tls::connect::{Resolve, Resolver};
/// use futures_util::future::LocalBoxFuture; /// use futures_util::future::LocalBoxFuture;
/// ///
/// // use trust_dns_resolver as custom resolver. /// // use trust-dns async tokio resolver
/// use trust_dns_resolver::TokioAsyncResolver; /// use trust_dns_resolver::TokioAsyncResolver;
/// ///
/// struct MyResolver { /// struct MyResolver {
@ -103,7 +103,7 @@ pub enum Resolver {
/// // resolver can be passed to connector factory where returned service factory /// // resolver can be passed to connector factory where returned service factory
/// // can be used to construct new connector services. /// // can be used to construct new connector services.
/// let factory = actix_tls::connect::new_connector_factory::<&str>(resolver); /// let factory = actix_tls::connect::new_connector_factory::<&str>(resolver);
///``` /// ```
pub trait Resolve { pub trait Resolve {
fn lookup<'a>( fn lookup<'a>(
&'a self, &'a self,
@ -120,10 +120,10 @@ impl Resolver {
// look up with default resolver variant. // look up with default resolver variant.
fn look_up<T: Address>(req: &Connect<T>) -> JoinHandle<io::Result<IntoIter<SocketAddr>>> { fn look_up<T: Address>(req: &Connect<T>) -> JoinHandle<io::Result<IntoIter<SocketAddr>>> {
let host = req.host(); let host = req.hostname();
// TODO: Connect should always return host with port if possible. // TODO: Connect should always return host with port if possible.
let host = if req let host = if req
.host() .hostname()
.splitn(2, ':') .splitn(2, ':')
.last() .last()
.and_then(|p| p.parse::<u16>().ok()) .and_then(|p| p.parse::<u16>().ok())
@ -135,6 +135,7 @@ impl Resolver {
format!("{}:{}", host, req.port()) format!("{}:{}", host, req.port())
}; };
// run blocking DNS lookup in thread pool
spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host)) spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host))
} }
} }
@ -147,14 +148,14 @@ impl<T: Address> Service<Connect<T>> for Resolver {
actix_service::always_ready!(); actix_service::always_ready!();
fn call(&self, req: Connect<T>) -> Self::Future { fn call(&self, req: Connect<T>) -> Self::Future {
if !req.addr.is_none() { if req.addr.is_some() {
ResolverFuture::Connected(Some(req)) ResolverFuture::Connected(Some(req))
} else if let Ok(ip) = req.host().parse() { } else if let Ok(ip) = req.hostname().parse() {
let addr = SocketAddr::new(ip, req.port()); let addr = SocketAddr::new(ip, req.port());
let req = req.set_addr(Some(addr)); let req = req.set_addr(Some(addr));
ResolverFuture::Connected(Some(req)) ResolverFuture::Connected(Some(req))
} else { } else {
trace!("DNS resolver: resolving host {:?}", req.host()); trace!("DNS resolver: resolving host {:?}", req.hostname());
match self { match self {
Self::Default => { Self::Default => {
@ -166,7 +167,7 @@ impl<T: Address> Service<Connect<T>> for Resolver {
let resolver = Rc::clone(&resolver); let resolver = Rc::clone(&resolver);
ResolverFuture::LookupCustom(Box::pin(async move { ResolverFuture::LookupCustom(Box::pin(async move {
let addrs = resolver let addrs = resolver
.lookup(req.host(), req.port()) .lookup(req.hostname(), req.port())
.await .await
.map_err(ConnectError::Resolver)?; .map_err(ConnectError::Resolver)?;
@ -201,6 +202,7 @@ impl<T: Address> Future for ResolverFuture<T> {
Self::Connected(conn) => Poll::Ready(Ok(conn Self::Connected(conn) => Poll::Ready(Ok(conn
.take() .take()
.expect("ResolverFuture polled after finished"))), .expect("ResolverFuture polled after finished"))),
Self::LookUp(fut, req) => { Self::LookUp(fut, req) => {
let res = match ready!(Pin::new(fut).poll(cx)) { let res = match ready!(Pin::new(fut).poll(cx)) {
Ok(Ok(res)) => Ok(res), Ok(Ok(res)) => Ok(res),
@ -210,20 +212,21 @@ impl<T: Address> Future for ResolverFuture<T> {
let req = req.take().unwrap(); let req = req.take().unwrap();
let addrs = res.map_err(|e| { let addrs = res.map_err(|err| {
trace!( trace!(
"DNS resolver: failed to resolve host {:?} err: {:?}", "DNS resolver: failed to resolve host {:?} err: {:?}",
req.host(), req.hostname(),
e err
); );
e
err
})?; })?;
let req = req.set_addrs(addrs); let req = req.set_addrs(addrs);
trace!( trace!(
"DNS resolver: host {:?} resolved to {:?}", "DNS resolver: host {:?} resolved to {:?}",
req.host(), req.hostname(),
req.addrs() req.addrs()
); );
@ -233,6 +236,7 @@ impl<T: Address> Future for ResolverFuture<T> {
Poll::Ready(Ok(req)) Poll::Ready(Ok(req))
} }
} }
Self::LookupCustom(fut) => fut.as_mut().poll(cx), Self::LookupCustom(fut) => fut.as_mut().poll(cx),
} }
} }

View File

@ -3,35 +3,41 @@ use http::Uri;
use super::Address; use super::Address;
impl Address for Uri { impl Address for Uri {
fn host(&self) -> &str { fn hostname(&self) -> &str {
self.host().unwrap_or("") self.host().unwrap_or("")
} }
fn port(&self) -> Option<u16> { fn port(&self) -> Option<u16> {
if let Some(port) = self.port_u16() { match self.port_u16() {
Some(port) Some(port) => Some(port),
} else { None => scheme_to_port(self.scheme_str()),
port(self.scheme_str())
} }
} }
} }
// TODO: load data from file // Get port from well-known URL schemes.
fn port(scheme: Option<&str>) -> Option<u16> { fn scheme_to_port(scheme: Option<&str>) -> Option<u16> {
if let Some(scheme) = scheme { match scheme {
match scheme { // HTTP
"http" => Some(80), Some("http") => Some(80),
"https" => Some(443), Some("https") => Some(443),
"ws" => Some(80),
"wss" => Some(443), // WebSockets
"amqp" => Some(5672), Some("ws") => Some(80),
"amqps" => Some(5671), Some("wss") => Some(443),
"sb" => Some(5671),
"mqtt" => Some(1883), // Advanced Message Queuing Protocol (AMQP)
"mqtts" => Some(8883), Some("amqp") => Some(5672),
_ => None, Some("amqps") => Some(5671),
}
} else { // Message Queuing Telemetry Transport (MQTT)
None Some("mqtt") => Some(1883),
Some("mqtts") => Some(8883),
// File Transfer Protocol (FTP)
Some("ftp") => Some(1883),
Some("ftps") => Some(990),
_ => None,
} }
} }

View File

@ -1,15 +1,13 @@
use std::io; use std::io;
use std::net::SocketAddr;
use actix_codec::{BytesCodec, Framed}; use actix_codec::{BytesCodec, Framed};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_server::TestServer; use actix_server::TestServer;
use actix_service::{fn_service, Service, ServiceFactory}; use actix_service::{fn_service, Service, ServiceFactory};
use bytes::Bytes; use bytes::Bytes;
use futures_core::future::LocalBoxFuture;
use futures_util::sink::SinkExt; use futures_util::sink::SinkExt;
use actix_tls::connect::{self as actix_connect, Connect, Resolve, Resolver}; use actix_tls::connect::{self as actix_connect, Connect};
#[cfg(all(feature = "connect", feature = "openssl"))] #[cfg(all(feature = "connect", feature = "openssl"))]
#[actix_rt::test] #[actix_rt::test]
@ -57,7 +55,10 @@ async fn test_static_str() {
let conn = actix_connect::default_connector(); let conn = actix_connect::default_connector();
let con = conn.call(Connect::with("10", srv.addr())).await.unwrap(); let con = conn
.call(Connect::with_addr("10", srv.addr()))
.await
.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());
let connect = Connect::new(srv.host().to_owned()); let connect = Connect::new(srv.host().to_owned());
@ -80,55 +81,10 @@ async fn test_new_service() {
let factory = actix_connect::default_connector_factory(); let factory = actix_connect::default_connector_factory();
let conn = factory.new_service(()).await.unwrap(); let conn = factory.new_service(()).await.unwrap();
let con = conn.call(Connect::with("10", srv.addr())).await.unwrap(); let con = conn
assert_eq!(con.peer_addr().unwrap(), srv.addr()); .call(Connect::with_addr("10", srv.addr()))
} .await
.unwrap();
#[actix_rt::test]
async fn test_custom_resolver() {
use trust_dns_resolver::TokioAsyncResolver;
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
struct MyResolver {
trust_dns: TokioAsyncResolver,
}
impl Resolve for MyResolver {
fn lookup<'a>(
&'a self,
host: &'a str,
port: u16,
) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>> {
Box::pin(async move {
let res = self
.trust_dns
.lookup_ip(host)
.await?
.iter()
.map(|ip| SocketAddr::new(ip, port))
.collect();
Ok(res)
})
}
}
let resolver = MyResolver {
trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(),
};
let resolver = Resolver::new_custom(resolver);
let factory = actix_connect::new_connector_factory(resolver);
let conn = factory.new_service(()).await.unwrap();
let con = conn.call(Connect::with("10", srv.addr())).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());
} }

View File

@ -0,0 +1,78 @@
use std::{
io,
net::{Ipv4Addr, SocketAddr},
};
use actix_rt::net::TcpStream;
use actix_server::TestServer;
use actix_service::{fn_service, Service, ServiceFactory};
use futures_core::future::LocalBoxFuture;
use actix_tls::connect::{new_connector_factory, Connect, Resolve, Resolver};
#[actix_rt::test]
async fn custom_resolver() {
/// Always resolves to localhost with the given port.
struct LocalOnlyResolver;
impl Resolve for LocalOnlyResolver {
fn lookup<'a>(
&'a self,
_host: &'a str,
port: u16,
) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>> {
Box::pin(async move {
let local = format!("127.0.0.1:{}", port).parse().unwrap();
Ok(vec![local])
})
}
}
let addr = LocalOnlyResolver.lookup("example.com", 8080).await.unwrap()[0];
assert_eq!(addr, SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080))
}
#[actix_rt::test]
async fn custom_resolver_connect() {
use trust_dns_resolver::TokioAsyncResolver;
let srv =
TestServer::with(|| fn_service(|_io: TcpStream| async { Ok::<_, io::Error>(()) }));
struct MyResolver {
trust_dns: TokioAsyncResolver,
}
impl Resolve for MyResolver {
fn lookup<'a>(
&'a self,
host: &'a str,
port: u16,
) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>> {
Box::pin(async move {
let res = self
.trust_dns
.lookup_ip(host)
.await?
.iter()
.map(|ip| SocketAddr::new(ip, port))
.collect();
Ok(res)
})
}
}
let resolver = MyResolver {
trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(),
};
let resolver = Resolver::new_custom(resolver);
let factory = new_connector_factory(resolver);
let conn = factory.new_service(()).await.unwrap();
let con = conn
.call(Connect::with_addr("example.com", srv.addr()))
.await
.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr());
}