From dc19a9f862ec7de82f1351d2d5c29e8fd501c799 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 29 Oct 2018 20:29:47 -0700 Subject: [PATCH] refactor Resolver service --- examples/basic.rs | 3 +- examples/ssl.rs | 6 +-- src/codec/framed.rs | 3 +- src/codec/framed2.rs | 3 +- src/codec/framed_read.rs | 3 +- src/codec/framed_write.rs | 3 +- src/connector.rs | 67 +++++++++++++++----------------- src/lib.rs | 5 ++- src/resolver.rs | 82 +++++++++------------------------------ src/server/accept.rs | 19 ++++++--- src/server/server.rs | 10 ++++- src/server/worker.rs | 13 ++++--- src/service/mod.rs | 8 +++- src/timeout.rs | 16 ++++---- 14 files changed, 107 insertions(+), 134 deletions(-) diff --git a/examples/basic.rs b/examples/basic.rs index f31b6d76..8162fcc6 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -80,8 +80,7 @@ fn main() { future::ok(()) }) }, - ) - .unwrap() + ).unwrap() .start(); sys.run(); diff --git a/examples/ssl.rs b/examples/ssl.rs index 25bc6aad..1d43f24d 100644 --- a/examples/ssl.rs +++ b/examples/ssl.rs @@ -24,7 +24,8 @@ struct ServiceState { } fn service( - st: &mut ServiceState, _: T, + st: &mut ServiceState, + _: T, ) -> impl Future { let num = st.num.fetch_add(1, Ordering::Relaxed); println!("got ssl connection {:?}", num); @@ -60,8 +61,7 @@ fn main() { println!("got ssl connection {:?}", num); future::ok(()) }) - }) - .unwrap() + }).unwrap() .start(); sys.run(); diff --git a/src/codec/framed.rs b/src/codec/framed.rs index c73238bc..388f10c9 100644 --- a/src/codec/framed.rs +++ b/src/codec/framed.rs @@ -176,7 +176,8 @@ where type SinkError = U::Error; fn start_send( - &mut self, item: Self::SinkItem, + &mut self, + item: Self::SinkItem, ) -> StartSend { self.inner.get_mut().start_send(item) } diff --git a/src/codec/framed2.rs b/src/codec/framed2.rs index 08211335..2bc3992c 100644 --- a/src/codec/framed2.rs +++ b/src/codec/framed2.rs @@ -191,7 +191,8 @@ where type SinkError = E::Error; fn start_send( - &mut self, item: Self::SinkItem, + &mut self, + item: Self::SinkItem, ) -> StartSend { self.inner.get_mut().start_send(item) } diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 065e2920..3813d85e 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -98,7 +98,8 @@ where type SinkError = T::SinkError; fn start_send( - &mut self, item: Self::SinkItem, + &mut self, + item: Self::SinkItem, ) -> StartSend { self.inner.inner.0.start_send(item) } diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index a39f6d03..1348b09c 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -197,8 +197,7 @@ where io::ErrorKind::WriteZero, "failed to \ write frame to transport", - ) - .into()); + ).into()); } // TODO: Add a way to `bytes` to do this w/o returning the drained diff --git a/src/connector.rs b/src/connector.rs index 465dd6ea..901b97ba 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -11,7 +11,7 @@ use tokio_tcp::{ConnectFuture, TcpStream}; use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; use trust_dns_resolver::system_conf::read_system_conf; -use super::resolver::{HostAware, Resolver, ResolverError, ResolverFuture}; +use super::resolver::{HostAware, ResolveError, Resolver, ResolverFuture}; use super::service::{NewService, Service}; // #[derive(Fail, Debug)] @@ -19,9 +19,9 @@ use super::service::{NewService, Service}; pub enum ConnectorError { /// Failed to resolve the hostname // #[fail(display = "Failed resolving hostname: {}", _0)] - Resolver(ResolverError), + Resolver(ResolveError), - /// Not dns records + /// No dns records // #[fail(display = "Invalid input: {}", _0)] NoRecords, @@ -29,17 +29,21 @@ pub enum ConnectorError { // #[fail(display = "Timeout out while establishing connection")] Timeout, + /// Invalid input + InvalidInput, + /// Connection io error // #[fail(display = "{}", _0)] IoError(io::Error), } -impl From for ConnectorError { - fn from(err: ResolverError) -> Self { +impl From for ConnectorError { + fn from(err: ResolveError) -> Self { ConnectorError::Resolver(err) } } +/// Connect request #[derive(Eq, PartialEq, Debug, Hash)] pub struct Connect { pub host: String, @@ -48,6 +52,7 @@ pub struct Connect { } impl Connect { + /// Create new `Connect` instance. pub fn new>(host: T, port: u16) -> Connect { Connect { port, @@ -56,20 +61,19 @@ impl Connect { } } - /// split the string by ':' and convert the second part to u16 - pub fn parse>(host: T) -> Option { + /// Create `Connect` instance by spliting the string by ':' and convert the second part to u16 + pub fn with>(host: T) -> Result { let mut parts_iter = host.as_ref().splitn(2, ':'); - if let Some(host) = parts_iter.next() { - let port_str = parts_iter.next().unwrap_or(""); - if let Ok(port) = port_str.parse::() { - return Some(Connect { - port, - host: host.to_owned(), - timeout: Duration::from_secs(1), - }); - } - } - None + let host = parts_iter.next().ok_or(ConnectorError::InvalidInput)?; + let port_str = parts_iter.next().unwrap_or(""); + let port = port_str + .parse::() + .map_err(|_| ConnectorError::InvalidInput)?; + Ok(Connect { + port, + host: host.to_owned(), + timeout: Duration::from_secs(1), + }) } /// Set connect timeout @@ -93,6 +97,7 @@ impl fmt::Display for Connect { } } +/// Tcp connector pub struct Connector { resolver: Resolver, } @@ -110,12 +115,14 @@ impl Default for Connector { } impl Connector { + /// Create new connector with resolver configuration pub fn new(cfg: ResolverConfig, opts: ResolverOpts) -> Self { Connector { resolver: Resolver::new(cfg, opts), } } + /// Create new connector with custom resolver pub fn with_resolver( resolver: Resolver, ) -> impl Service @@ -123,17 +130,10 @@ impl Connector { Connector { resolver } } - pub fn new_service() -> impl NewService< - Request = Connect, - Response = (Connect, TcpStream), - Error = ConnectorError, - InitError = E, - > + Clone { - || -> FutureResult { ok(Connector::default()) } - } - + /// Create new default connector service pub fn new_service_with_config( - cfg: ResolverConfig, opts: ResolverOpts, + cfg: ResolverConfig, + opts: ResolverOpts, ) -> impl NewService< Request = Connect, Response = (Connect, TcpStream), @@ -185,19 +185,14 @@ impl Future for ConnectorFuture { return fut.poll(); } match self.fut.poll().map_err(ConnectorError::from)? { - Async::Ready((req, _, mut addrs)) => { + Async::Ready((req, mut addrs)) => { if addrs.is_empty() { Err(ConnectorError::NoRecords) } else { for addr in &mut addrs { match addr { - SocketAddr::V4(ref mut addr) if addr.port() == 0 => { - addr.set_port(req.port) - } - SocketAddr::V6(ref mut addr) if addr.port() == 0 => { - addr.set_port(req.port) - } - _ => (), + SocketAddr::V4(ref mut addr) => addr.set_port(req.port), + SocketAddr::V6(ref mut addr) => addr.set_port(req.port), } } self.fut2 = Some(TcpConnector::new(req, addrs)); diff --git a/src/lib.rs b/src/lib.rs index 4240d062..0fd6e917 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,10 @@ #![cfg_attr( feature = "cargo-clippy", - allow(declare_interior_mutable_const, borrow_interior_mutable_const) + allow( + declare_interior_mutable_const, + borrow_interior_mutable_const + ) )] #[macro_use] diff --git a/src/resolver.rs b/src/resolver.rs index 4606b1e5..a4b8141b 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -6,19 +6,13 @@ use futures::{Async, Future, Poll}; use tokio_current_thread::spawn; use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; -use trust_dns_resolver::error::ResolveError; +pub use trust_dns_resolver::error::ResolveError; use trust_dns_resolver::lookup_ip::LookupIpFuture; use trust_dns_resolver::system_conf::read_system_conf; use trust_dns_resolver::{AsyncResolver, Background}; use super::service::Service; -#[derive(Debug)] -pub enum ResolverError { - Resolve(ResolveError), - InvalidInput, -} - pub trait HostAware { fn host(&self) -> &str; } @@ -75,8 +69,8 @@ impl Clone for Resolver { impl Service for Resolver { type Request = T; - type Response = (T, String, VecDeque); - type Error = ResolverError; + type Response = (T, VecDeque); + type Error = ResolveError; type Future = ResolverFuture; fn poll_ready(&mut self) -> Poll<(), Self::Error> { @@ -84,7 +78,7 @@ impl Service for Resolver { } fn call(&mut self, req: Self::Request) -> Self::Future { - ResolverFuture::new(req, 0, &self.resolver) + ResolverFuture::new(req, &self.resolver) } } @@ -92,78 +86,38 @@ impl Service for Resolver { /// Resolver future pub struct ResolverFuture { req: Option, - port: u16, lookup: Option>, addrs: Option>, - error: Option, - host: Option, } impl ResolverFuture { - pub fn new(addr: T, port: u16, resolver: &AsyncResolver) -> Self { + pub fn new(addr: T, resolver: &AsyncResolver) -> Self { // we need to do dns resolution - match ResolverFuture::::parse(addr.host(), port) { - Ok((host, port)) => { - let lookup = Some(resolver.lookup_ip(host.as_str())); - ResolverFuture { - port, - lookup, - req: Some(addr), - host: Some(host.to_owned()), - addrs: None, - error: None, - } - } - Err(err) => ResolverFuture { - port, - req: None, - host: None, - lookup: None, - addrs: None, - error: Some(err), - }, + let lookup = Some(resolver.lookup_ip(addr.host())); + ResolverFuture { + lookup, + req: Some(addr), + addrs: None, } } - - fn parse(addr: &str, port: u16) -> Result<(String, u16), ResolverError> { - // split the string by ':' and convert the second part to u16 - let mut parts_iter = addr.splitn(2, ':'); - let host = parts_iter.next().ok_or(ResolverError::InvalidInput)?; - let port_str = parts_iter.next().unwrap_or(""); - let port: u16 = port_str.parse().unwrap_or(port); - - Ok((host.to_owned(), port)) - } } impl Future for ResolverFuture { - type Item = (T, String, VecDeque); - type Error = ResolverError; + type Item = (T, VecDeque); + type Error = ResolveError; fn poll(&mut self) -> Poll { - if let Some(err) = self.error.take() { - Err(err) - } else if let Some(addrs) = self.addrs.take() { - Ok(Async::Ready(( - self.req.take().unwrap(), - self.host.take().unwrap(), - addrs, - ))) + if let Some(addrs) = self.addrs.take() { + Ok(Async::Ready((self.req.take().unwrap(), addrs))) } else { match self.lookup.as_mut().unwrap().poll() { Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::Ready(ips)) => { - let addrs: VecDeque<_> = ips - .iter() - .map(|ip| SocketAddr::new(ip, self.port)) - .collect(); - Ok(Async::Ready(( - self.req.take().unwrap(), - self.host.take().unwrap(), - addrs, - ))) + let addrs: VecDeque<_> = + ips.iter().map(|ip| SocketAddr::new(ip, 0)).collect(); + Ok(Async::Ready((self.req.take().unwrap(), addrs))) } - Err(err) => Err(ResolverError::Resolve(err)), + Err(err) => Err(err), } } } diff --git a/src/server/accept.rs b/src/server/accept.rs index 9135b1f6..bc16d9d6 100644 --- a/src/server/accept.rs +++ b/src/server/accept.rs @@ -87,7 +87,9 @@ impl AcceptLoop { } pub(crate) fn start( - &mut self, socks: Vec<(Token, net::TcpListener)>, workers: Vec, + &mut self, + socks: Vec<(Token, net::TcpListener)>, + workers: Vec, ) -> mpsc::UnboundedReceiver { let (tx, rx) = self.srv.take().expect("Can not re-use AcceptInfo"); @@ -135,9 +137,12 @@ fn connection_error(e: &io::Error) -> bool { impl Accept { #![cfg_attr(feature = "cargo-clippy", allow(too_many_arguments))] pub(crate) fn start( - rx: sync_mpsc::Receiver, cmd_reg: mio::Registration, - notify_reg: mio::Registration, socks: Vec<(Token, net::TcpListener)>, - srv: mpsc::UnboundedSender, workers: Vec, + rx: sync_mpsc::Receiver, + cmd_reg: mio::Registration, + notify_reg: mio::Registration, + socks: Vec<(Token, net::TcpListener)>, + srv: mpsc::UnboundedSender, + workers: Vec, ) { let sys = System::current(); @@ -173,8 +178,10 @@ impl Accept { } fn new( - rx: sync_mpsc::Receiver, socks: Vec<(Token, net::TcpListener)>, - workers: Vec, srv: mpsc::UnboundedSender, + rx: sync_mpsc::Receiver, + socks: Vec<(Token, net::TcpListener)>, + workers: Vec, + srv: mpsc::UnboundedSender, ) -> Accept { // Create a poll instance let poll = match mio::Poll::new() { diff --git a/src/server/server.rs b/src/server/server.rs index 484a2abf..09e795d0 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -137,7 +137,10 @@ impl Server { /// Add new service to server pub fn listen>( - mut self, name: N, lst: net::TcpListener, factory: F, + mut self, + name: N, + lst: net::TcpListener, + factory: F, ) -> Self where F: StreamServiceFactory, @@ -151,7 +154,10 @@ impl Server { /// Add new service to server pub fn listen2>( - mut self, name: N, lst: net::TcpListener, factory: F, + mut self, + name: N, + lst: net::TcpListener, + factory: F, ) -> Self where F: ServiceFactory, diff --git a/src/server/worker.rs b/src/server/worker.rs index ae8ee333..68e9cbf7 100644 --- a/src/server/worker.rs +++ b/src/server/worker.rs @@ -61,7 +61,9 @@ pub(crate) struct WorkerClient { impl WorkerClient { pub fn new( - idx: usize, tx: UnboundedSender, avail: WorkerAvailability, + idx: usize, + tx: UnboundedSender, + avail: WorkerAvailability, ) -> Self { WorkerClient { idx, tx, avail } } @@ -128,8 +130,10 @@ pub(crate) struct Worker { impl Worker { pub(crate) fn start( - rx: UnboundedReceiver, factories: Vec>, - availability: WorkerAvailability, shutdown_timeout: time::Duration, + rx: UnboundedReceiver, + factories: Vec>, + availability: WorkerAvailability, + shutdown_timeout: time::Duration, ) { availability.set(false); let mut wrk = MAX_CONNS_COUNTER.with(|conns| Worker { @@ -151,8 +155,7 @@ impl Worker { .map_err(|e| { error!("Can not start worker: {:?}", e); Arbiter::current().do_send(StopArbiter(0)); - }) - .and_then(move |services| { + }).and_then(move |services| { wrk.services.extend(services); wrk }), diff --git a/src/service/mod.rs b/src/service/mod.rs index 0b486f39..f8de5627 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -27,7 +27,9 @@ pub trait ServiceExt: Service { /// Apply function to specified service and use it as a next service in /// chain. fn apply( - self, service: I, f: F, + self, + service: I, + f: F, ) -> AndThen> where Self: Sized, @@ -120,7 +122,9 @@ pub trait ServiceExt: Service { pub trait NewServiceExt: NewService { fn apply( - self, service: I, f: F, + self, + service: I, + f: F, ) -> AndThenNewService> where Self: Sized, diff --git a/src/timeout.rs b/src/timeout.rs index 1262e6f0..98398c24 100644 --- a/src/timeout.rs +++ b/src/timeout.rs @@ -1,4 +1,4 @@ -//! Tower middleware that applies a timeout to requests. +//! Service that applies a timeout to requests. //! //! If the response does not complete within the specified timeout, the response //! will be aborted. @@ -34,13 +34,6 @@ impl fmt::Debug for TimeoutError { } } -/// `Timeout` response future -#[derive(Debug)] -pub struct TimeoutFut { - fut: T::Future, - timeout: Duration, -} - impl Timeout where T: NewService + Clone, @@ -69,6 +62,13 @@ where } } +/// `Timeout` response future +#[derive(Debug)] +pub struct TimeoutFut { + fut: T::Future, + timeout: Duration, +} + impl Future for TimeoutFut where T: NewService,