diff --git a/Cargo.toml b/Cargo.toml index fcc169b76..e586f34ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,16 +34,26 @@ session = ["cookie/secure"] cell = ["actix-net/cell"] +# tls +tls = ["native-tls", "actix-net/tls"] + +# openssl +ssl = ["openssl", "actix-net/ssl"] + +# rustls +rust-tls = ["rustls", "actix-net/rust-tls"] + [dependencies] actix = "0.7.5" -actix-net = "0.2.0" -#actix-net = { git="https://github.com/actix/actix-net.git" } +#actix-net = "0.2.0" +actix-net = { git="https://github.com/actix/actix-net.git" } base64 = "0.9" bitflags = "1.0" http = "0.1.8" httparse = "1.3" failure = "0.1.2" +indexmap = "1.0" log = "0.4" mime = "0.3" rand = "0.5" @@ -61,6 +71,7 @@ url = { version="1.7", features=["query_encoding"] } # io net2 = "0.2" +slab = "0.4" bytes = "0.4" byteorder = "1.2" futures = "0.1" @@ -70,6 +81,17 @@ tokio-io = "0.1" tokio-tcp = "0.1" tokio-timer = "0.2" tokio-current-thread = "0.1" +trust-dns-proto = "0.5.0" +trust-dns-resolver = "0.10.0" + +# native-tls +native-tls = { version="0.2", optional = true } + +# openssl +openssl = { version="0.10", optional = true } + +#rustls +rustls = { version = "^0.14", optional = true } [dev-dependencies] actix-web = "0.7" diff --git a/src/client/connect.rs b/src/client/connect.rs new file mode 100644 index 000000000..40c3e8ec9 --- /dev/null +++ b/src/client/connect.rs @@ -0,0 +1,80 @@ +use actix_net::connector::RequestPort; +use actix_net::resolver::RequestHost; +use http::uri::Uri; +use http::{Error as HttpError, HttpTryFrom}; + +use super::error::{ConnectorError, InvalidUrlKind}; +use super::pool::Key; + +#[derive(Debug)] +/// `Connect` type represents a message that can be sent to +/// `Connector` with a connection request. +pub struct Connect { + pub(crate) uri: Uri, +} + +impl Connect { + /// Construct `Uri` instance and create `Connect` message. + pub fn new(uri: U) -> Result + where + Uri: HttpTryFrom, + { + Ok(Connect { + uri: Uri::try_from(uri).map_err(|e| e.into())?, + }) + } + + /// Create `Connect` message for specified `Uri` + pub fn with(uri: Uri) -> Connect { + Connect { uri } + } + + pub(crate) fn is_secure(&self) -> bool { + if let Some(scheme) = self.uri.scheme_part() { + scheme.as_str() == "https" + } else { + false + } + } + + pub(crate) fn key(&self) -> Key { + self.uri.authority_part().unwrap().clone().into() + } + + pub(crate) fn validate(&self) -> Result<(), ConnectorError> { + if self.uri.host().is_none() { + Err(ConnectorError::InvalidUrl(InvalidUrlKind::MissingHost)) + } else if self.uri.scheme_part().is_none() { + Err(ConnectorError::InvalidUrl(InvalidUrlKind::MissingScheme)) + } else if let Some(scheme) = self.uri.scheme_part() { + match scheme.as_str() { + "http" | "ws" | "https" | "wss" => Ok(()), + _ => Err(ConnectorError::InvalidUrl(InvalidUrlKind::UnknownScheme)), + } + } else { + Ok(()) + } + } +} + +impl RequestHost for Connect { + fn host(&self) -> &str { + &self.uri.host().unwrap() + } +} + +impl RequestPort for Connect { + fn port(&self) -> u16 { + if let Some(port) = self.uri.port() { + port + } else if let Some(scheme) = self.uri.scheme_part() { + match scheme.as_str() { + "http" | "ws" => 80, + "https" | "wss" => 443, + _ => 80, + } + } else { + 80 + } + } +} diff --git a/src/client/connection.rs b/src/client/connection.rs new file mode 100644 index 000000000..294e100c8 --- /dev/null +++ b/src/client/connection.rs @@ -0,0 +1,79 @@ +use std::{fmt, io, time}; + +use futures::Poll; +use tokio_io::{AsyncRead, AsyncWrite}; + +use super::pool::Acquired; + +/// HTTP client connection +pub struct Connection { + io: T, + created: time::Instant, + pool: Option>, +} + +impl fmt::Debug for Connection +where + T: AsyncRead + AsyncWrite + fmt::Debug + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Connection {:?}", self.io) + } +} + +impl Connection { + pub(crate) fn new(io: T, created: time::Instant, pool: Acquired) -> Self { + Connection { + io, + created, + pool: Some(pool), + } + } + + /// Raw IO stream + pub fn get_mut(&mut self) -> &mut T { + &mut self.io + } + + /// Close connection + pub fn close(mut self) { + if let Some(mut pool) = self.pool.take() { + pool.close(self) + } + } + + /// Release this connection to the connection pool + pub fn release(mut self) { + if let Some(mut pool) = self.pool.take() { + pool.release(self) + } + } + + pub(crate) fn into_inner(self) -> (T, time::Instant) { + (self.io, self.created) + } +} + +impl io::Read for Connection { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.io.read(buf) + } +} + +impl AsyncRead for Connection {} + +impl io::Write for Connection { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.io.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.io.flush() + } +} + +impl AsyncWrite for Connection { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.io.shutdown() + } +} diff --git a/src/client/connector.rs b/src/client/connector.rs new file mode 100644 index 000000000..97da074d1 --- /dev/null +++ b/src/client/connector.rs @@ -0,0 +1,500 @@ +use std::time::Duration; +use std::{fmt, io}; + +use actix_net::connector::TcpConnector; +use actix_net::resolver::Resolver; +use actix_net::service::{Service, ServiceExt}; +use actix_net::timeout::{TimeoutError, TimeoutService}; +use futures::future::Either; +use futures::Poll; +use tokio_io::{AsyncRead, AsyncWrite}; +use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; + +use super::connect::Connect; +use super::connection::Connection; +use super::error::ConnectorError; +use super::pool::ConnectionPool; + +#[cfg(feature = "ssl")] +use actix_net::ssl::OpensslConnector; +#[cfg(feature = "ssl")] +use openssl::ssl::{SslConnector, SslMethod}; + +#[cfg(not(feature = "ssl"))] +type SslConnector = (); + +/// Http client connector builde instance. +/// `Connector` type uses builder-like pattern for connector service construction. +pub struct Connector { + resolver: Resolver, + timeout: Duration, + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Duration, + limit: usize, + #[allow(dead_code)] + connector: SslConnector, +} + +impl Default for Connector { + fn default() -> Connector { + let connector = { + #[cfg(feature = "ssl")] + { + SslConnector::builder(SslMethod::tls()).unwrap().build() + } + #[cfg(not(feature = "ssl"))] + { + () + } + }; + + Connector { + connector, + resolver: Resolver::default(), + timeout: Duration::from_secs(1), + conn_lifetime: Duration::from_secs(75), + conn_keep_alive: Duration::from_secs(15), + disconnect_timeout: Duration::from_millis(3000), + limit: 100, + } + } +} + +impl Connector { + /// Use custom resolver configuration. + pub fn resolver_config(mut self, cfg: ResolverConfig, opts: ResolverOpts) -> Self { + self.resolver = Resolver::new(cfg, opts); + self + } + + /// Connection timeout, i.e. max time to connect to remote host including dns name resolution. + /// Set to 1 second by default. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + #[cfg(feature = "ssl")] + /// Use custom `SslConnector` instance. + pub fn ssl(mut self, connector: SslConnector) -> Self { + self.connector = connector; + self + } + + /// Set total number of simultaneous connections per type of scheme. + /// + /// If limit is 0, the connector has no limit. + /// The default limit size is 100. + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + /// Set keep-alive period for opened connection. + /// + /// Keep-alive period is the period between connection usage. If + /// the delay between repeated usages of the same connection + /// exceeds this period, the connection is closed. + /// Default keep-alive period is 15 seconds. + pub fn conn_keep_alive(mut self, dur: Duration) -> Self { + self.conn_keep_alive = dur; + self + } + + /// Set max lifetime period for connection. + /// + /// Connection lifetime is max lifetime of any opened connection + /// until it is closed regardless of keep-alive period. + /// Default lifetime period is 75 seconds. + pub fn conn_lifetime(mut self, dur: Duration) -> Self { + self.conn_lifetime = dur; + self + } + + /// Set server connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the socket get dropped. This timeout affects only secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 3000 milliseconds. + pub fn disconnect_timeout(mut self, dur: Duration) -> Self { + self.disconnect_timeout = dur; + self + } + + /// Finish configuration process and create connector service. + pub fn service( + self, + ) -> impl Service< + Request = Connect, + Response = impl AsyncRead + AsyncWrite + fmt::Debug, + Error = ConnectorError, + > + Clone { + #[cfg(not(feature = "ssl"))] + { + let connector = TimeoutService::new( + self.timeout, + self.resolver + .map_err(ConnectorError::from) + .and_then(TcpConnector::default().from_err()), + ).map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectorError::Timeout, + }); + + connect_impl::InnerConnector { + tcp_pool: ConnectionPool::new( + connector, + self.conn_lifetime, + self.conn_keep_alive, + None, + self.limit, + ), + } + } + #[cfg(feature = "ssl")] + { + let ssl_service = TimeoutService::new( + self.timeout, + self.resolver + .clone() + .map_err(ConnectorError::from) + .and_then(TcpConnector::default().from_err()) + .and_then( + OpensslConnector::service(self.connector) + .map_err(ConnectorError::SslError), + ), + ).map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectorError::Timeout, + }); + + let tcp_service = TimeoutService::new( + self.timeout, + self.resolver + .map_err(ConnectorError::from) + .and_then(TcpConnector::default().from_err()), + ).map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectorError::Timeout, + }); + + connect_impl::InnerConnector { + tcp_pool: ConnectionPool::new( + tcp_service, + self.conn_lifetime, + self.conn_keep_alive, + None, + self.limit, + ), + ssl_pool: ConnectionPool::new( + ssl_service, + self.conn_lifetime, + self.conn_keep_alive, + Some(self.disconnect_timeout), + self.limit, + ), + } + } + } +} + +#[cfg(not(feature = "ssl"))] +mod connect_impl { + use super::*; + use futures::future::{err, FutureResult}; + + pub(crate) struct InnerConnector + where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, + { + pub(crate) tcp_pool: ConnectionPool, + } + + impl Clone for InnerConnector + where + Io: AsyncRead + AsyncWrite + 'static, + T: Service + + Clone, + { + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + } + } + } + + impl Service for InnerConnector + where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, + { + type Request = Connect; + type Response = Connection; + type Error = ConnectorError; + type Future = Either< + as Service>::Future, + FutureResult, ConnectorError>, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.tcp_pool.poll_ready() + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + if req.is_secure() { + Either::B(err(ConnectorError::SslIsNotSupported)) + } else if let Err(e) = req.validate() { + Either::B(err(e)) + } else { + Either::A(self.tcp_pool.call(req)) + } + } + } +} + +#[cfg(feature = "ssl")] +mod connect_impl { + use std::marker::PhantomData; + + use futures::future::{err, FutureResult}; + use futures::{Async, Future, Poll}; + + use super::*; + + pub(crate) struct InnerConnector + where + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + T1: Service< + Request = Connect, + Response = (Connect, Io1), + Error = ConnectorError, + >, + T2: Service< + Request = Connect, + Response = (Connect, Io2), + Error = ConnectorError, + >, + { + pub(crate) tcp_pool: ConnectionPool, + pub(crate) ssl_pool: ConnectionPool, + } + + impl Clone for InnerConnector + where + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + T1: Service< + Request = Connect, + Response = (Connect, Io1), + Error = ConnectorError, + > + Clone, + T2: Service< + Request = Connect, + Response = (Connect, Io2), + Error = ConnectorError, + > + Clone, + { + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + ssl_pool: self.ssl_pool.clone(), + } + } + } + + impl Service for InnerConnector + where + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + T1: Service< + Request = Connect, + Response = (Connect, Io1), + Error = ConnectorError, + >, + T2: Service< + Request = Connect, + Response = (Connect, Io2), + Error = ConnectorError, + >, + { + type Request = Connect; + type Response = IoEither, Connection>; + type Error = ConnectorError; + type Future = Either< + FutureResult, + Either< + InnerConnectorResponseA, + InnerConnectorResponseB, + >, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.tcp_pool.poll_ready() + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + if let Err(e) = req.validate() { + Either::A(err(e)) + } else if req.is_secure() { + Either::B(Either::A(InnerConnectorResponseA { + fut: self.tcp_pool.call(req), + _t: PhantomData, + })) + } else { + Either::B(Either::B(InnerConnectorResponseB { + fut: self.ssl_pool.call(req), + _t: PhantomData, + })) + } + } + } + + pub(crate) struct InnerConnectorResponseA + where + Io1: AsyncRead + AsyncWrite + 'static, + T: Service, + { + fut: as Service>::Future, + _t: PhantomData, + } + + impl Future for InnerConnectorResponseA + where + T: Service, + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + { + type Item = IoEither, Connection>; + type Error = ConnectorError; + + fn poll(&mut self) -> Poll { + match self.fut.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(res) => Ok(Async::Ready(IoEither::A(res))), + } + } + } + + pub(crate) struct InnerConnectorResponseB + where + Io2: AsyncRead + AsyncWrite + 'static, + T: Service, + { + fut: as Service>::Future, + _t: PhantomData, + } + + impl Future for InnerConnectorResponseB + where + T: Service, + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + { + type Item = IoEither, Connection>; + type Error = ConnectorError; + + fn poll(&mut self) -> Poll { + match self.fut.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(res) => Ok(Async::Ready(IoEither::B(res))), + } + } + } +} + +pub(crate) enum IoEither { + A(Io1), + B(Io2), +} + +impl io::Read for IoEither +where + Io1: io::Read, + Io2: io::Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + IoEither::A(ref mut io) => io.read(buf), + IoEither::B(ref mut io) => io.read(buf), + } + } +} + +impl AsyncRead for IoEither +where + Io1: AsyncRead, + Io2: AsyncRead, +{ + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match self { + IoEither::A(ref io) => io.prepare_uninitialized_buffer(buf), + IoEither::B(ref io) => io.prepare_uninitialized_buffer(buf), + } + } +} + +impl AsyncWrite for IoEither +where + Io1: AsyncWrite, + Io2: AsyncWrite, +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self { + IoEither::A(ref mut io) => io.shutdown(), + IoEither::B(ref mut io) => io.shutdown(), + } + } + + fn poll_write(&mut self, buf: &[u8]) -> Poll { + match self { + IoEither::A(ref mut io) => io.poll_write(buf), + IoEither::B(ref mut io) => io.poll_write(buf), + } + } + + fn poll_flush(&mut self) -> Poll<(), io::Error> { + match self { + IoEither::A(ref mut io) => io.poll_flush(), + IoEither::B(ref mut io) => io.poll_flush(), + } + } +} + +impl io::Write for IoEither +where + Io1: io::Write, + Io2: io::Write, +{ + fn flush(&mut self) -> io::Result<()> { + match self { + IoEither::A(ref mut io) => io.flush(), + IoEither::B(ref mut io) => io.flush(), + } + } + + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + IoEither::A(ref mut io) => io.write(buf), + IoEither::B(ref mut io) => io.write(buf), + } + } +} + +impl fmt::Debug for IoEither +where + Io1: fmt::Debug, + Io2: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + IoEither::A(ref io) => io.fmt(fmt), + IoEither::B(ref io) => io.fmt(fmt), + } + } +} diff --git a/src/client/error.rs b/src/client/error.rs new file mode 100644 index 000000000..ba6407230 --- /dev/null +++ b/src/client/error.rs @@ -0,0 +1,77 @@ +use std::io; + +use trust_dns_resolver::error::ResolveError; + +#[cfg(feature = "ssl")] +use openssl::ssl::Error as SslError; + +#[cfg(all( + feature = "tls", + not(any(feature = "ssl", feature = "rust-tls")) +))] +use native_tls::Error as SslError; + +#[cfg(all( + feature = "rust-tls", + not(any(feature = "tls", feature = "ssl")) +))] +use std::io::Error as SslError; + +/// A set of errors that can occur while connecting to an HTTP host +#[derive(Fail, Debug)] +pub enum ConnectorError { + /// Invalid URL + #[fail(display = "Invalid URL")] + InvalidUrl(InvalidUrlKind), + + /// SSL feature is not enabled + #[fail(display = "SSL is not supported")] + SslIsNotSupported, + + /// SSL error + #[cfg(any(feature = "tls", feature = "ssl", feature = "rust-tls"))] + #[fail(display = "{}", _0)] + SslError(#[cause] SslError), + + /// Failed to resolve the hostname + #[fail(display = "Failed resolving hostname: {}", _0)] + Resolver(ResolveError), + + /// No dns records + #[fail(display = "No dns records found for the input")] + NoRecords, + + /// Connecting took too long + #[fail(display = "Timeout out while establishing connection")] + Timeout, + + /// Connector has been disconnected + #[fail(display = "Internal error: connector has been disconnected")] + Disconnected, + + /// Connection io error + #[fail(display = "{}", _0)] + IoError(io::Error), +} + +#[derive(Fail, Debug)] +pub enum InvalidUrlKind { + #[fail(display = "Missing url scheme")] + MissingScheme, + #[fail(display = "Unknown url scheme")] + UnknownScheme, + #[fail(display = "Missing host name")] + MissingHost, +} + +impl From for ConnectorError { + fn from(err: io::Error) -> ConnectorError { + ConnectorError::IoError(err) + } +} + +impl From for ConnectorError { + fn from(err: ResolveError) -> ConnectorError { + ConnectorError::Resolver(err) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index a5582fea8..714e6c694 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,6 +1,14 @@ //! Http client api +mod connect; +mod connection; +mod connector; +mod error; +mod pool; mod request; mod response; +pub use self::connect::Connect; +pub use self::connector::Connector; +pub use self::error::{ConnectorError, InvalidUrlKind}; pub use self::request::{ClientRequest, ClientRequestBuilder}; pub use self::response::ClientResponse; diff --git a/src/client/pool.rs b/src/client/pool.rs new file mode 100644 index 000000000..6ff8c96ce --- /dev/null +++ b/src/client/pool.rs @@ -0,0 +1,579 @@ +use std::cell::RefCell; +use std::collections::{HashMap, VecDeque}; +use std::io; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use actix_net::service::Service; +use futures::future::{ok, Either, FutureResult}; +use futures::sync::oneshot; +use futures::task::AtomicTask; +use futures::{Async, Future, Poll}; +use http::uri::Authority; +use indexmap::IndexSet; +use slab::Slab; +use tokio_current_thread::spawn; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_timer::{sleep, Delay}; + +use super::connect::Connect; +use super::connection::Connection; +use super::error::ConnectorError; + +#[derive(Hash, Eq, PartialEq, Clone, Debug)] +pub(crate) struct Key { + authority: Authority, +} + +impl From for Key { + fn from(authority: Authority) -> Key { + Key { authority } + } +} + +#[derive(Debug)] +struct AvailableConnection { + io: T, + used: Instant, + created: Instant, +} + +/// Connections pool +pub(crate) struct ConnectionPool( + T, + Rc>>, +); + +impl ConnectionPool +where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, +{ + pub(crate) fn new( + connector: T, + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Option, + limit: usize, + ) -> Self { + ConnectionPool( + connector, + Rc::new(RefCell::new(Inner { + conn_lifetime, + conn_keep_alive, + disconnect_timeout, + limit, + acquired: 0, + waiters: Slab::new(), + waiters_queue: IndexSet::new(), + available: HashMap::new(), + task: AtomicTask::new(), + })), + ) + } +} + +impl Clone for ConnectionPool +where + T: Clone, + Io: AsyncRead + AsyncWrite + 'static, +{ + fn clone(&self) -> Self { + ConnectionPool(self.0.clone(), self.1.clone()) + } +} + +impl Service for ConnectionPool +where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, +{ + type Request = Connect; + type Response = Connection; + type Error = ConnectorError; + type Future = Either< + FutureResult, ConnectorError>, + Either, OpenConnection>, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.0.poll_ready() + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let key = req.key(); + + // acquire connection + match self.1.as_ref().borrow_mut().acquire(&key) { + Acquire::Acquired(io, created) => { + // use existing connection + Either::A(ok(Connection::new( + io, + created, + Acquired(key, Some(self.1.clone())), + ))) + } + Acquire::NotAvailable => { + // connection is not available, wait + let (rx, token) = self.1.as_ref().borrow_mut().wait_for(req); + Either::B(Either::A(WaitForConnection { + rx, + key, + token, + inner: Some(self.1.clone()), + })) + } + Acquire::Available => { + // open new connection + Either::B(Either::B(OpenConnection::new( + key, + self.1.clone(), + self.0.call(req), + ))) + } + } + } +} + +#[doc(hidden)] +pub struct WaitForConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + key: Key, + token: usize, + rx: oneshot::Receiver, ConnectorError>>, + inner: Option>>>, +} + +impl Drop for WaitForConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fn drop(&mut self) { + if let Some(i) = self.inner.take() { + let mut inner = i.as_ref().borrow_mut(); + inner.release_waiter(&self.key, self.token); + inner.check_availibility(); + } + } +} + +impl Future for WaitForConnection +where + Io: AsyncRead + AsyncWrite, +{ + type Item = Connection; + type Error = ConnectorError; + + fn poll(&mut self) -> Poll { + match self.rx.poll() { + Ok(Async::Ready(item)) => match item { + Err(err) => Err(err), + Ok(conn) => { + let _ = self.inner.take(); + Ok(Async::Ready(conn)) + } + }, + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => { + let _ = self.inner.take(); + Err(ConnectorError::Disconnected) + } + } + } +} + +#[doc(hidden)] +pub struct OpenConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fut: F, + key: Key, + inner: Option>>>, +} + +impl OpenConnection +where + F: Future, + Io: AsyncRead + AsyncWrite + 'static, +{ + fn new(key: Key, inner: Rc>>, fut: F) -> Self { + OpenConnection { + key, + fut, + inner: Some(inner), + } + } +} + +impl Drop for OpenConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + let mut inner = inner.as_ref().borrow_mut(); + inner.release(); + inner.check_availibility(); + } + } +} + +impl Future for OpenConnection +where + F: Future, + Io: AsyncRead + AsyncWrite, +{ + type Item = Connection; + type Error = ConnectorError; + + fn poll(&mut self) -> Poll { + match self.fut.poll() { + Err(err) => Err(err.into()), + Ok(Async::Ready((_, io))) => { + let _ = self.inner.take(); + Ok(Async::Ready(Connection::new( + io, + Instant::now(), + Acquired(self.key.clone(), self.inner.clone()), + ))) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + } + } +} + +struct OpenWaitingConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fut: F, + key: Key, + rx: Option, ConnectorError>>>, + inner: Option>>>, +} + +impl OpenWaitingConnection +where + F: Future + 'static, + Io: AsyncRead + AsyncWrite + 'static, +{ + fn spawn( + key: Key, + rx: oneshot::Sender, ConnectorError>>, + inner: Rc>>, + fut: F, + ) { + spawn(OpenWaitingConnection { + key, + fut, + rx: Some(rx), + inner: Some(inner), + }) + } +} + +impl Drop for OpenWaitingConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + let mut inner = inner.as_ref().borrow_mut(); + inner.release(); + inner.check_availibility(); + } + } +} + +impl Future for OpenWaitingConnection +where + F: Future, + Io: AsyncRead + AsyncWrite, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll { + match self.fut.poll() { + Err(err) => { + let _ = self.inner.take(); + if let Some(rx) = self.rx.take() { + let _ = rx.send(Err(err)); + } + Err(()) + } + Ok(Async::Ready((_, io))) => { + let _ = self.inner.take(); + if let Some(rx) = self.rx.take() { + let _ = rx.send(Ok(Connection::new( + io, + Instant::now(), + Acquired(self.key.clone(), self.inner.clone()), + ))); + } + Ok(Async::Ready(())) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + } + } +} + +enum Acquire { + Acquired(T, Instant), + Available, + NotAvailable, +} + +pub(crate) struct Inner +where + Io: AsyncRead + AsyncWrite + 'static, +{ + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Option, + limit: usize, + acquired: usize, + available: HashMap>>, + waiters: Slab<( + Connect, + oneshot::Sender, ConnectorError>>, + )>, + waiters_queue: IndexSet<(Key, usize)>, + task: AtomicTask, +} + +impl Inner +where + Io: AsyncRead + AsyncWrite + 'static, +{ + /// connection is not available, wait + fn wait_for( + &mut self, + connect: Connect, + ) -> ( + oneshot::Receiver, ConnectorError>>, + usize, + ) { + let (tx, rx) = oneshot::channel(); + + let key = connect.key(); + let entry = self.waiters.vacant_entry(); + let token = entry.key(); + entry.insert((connect, tx)); + assert!(!self.waiters_queue.insert((key, token))); + (rx, token) + } + + fn release_waiter(&mut self, key: &Key, token: usize) { + self.waiters.remove(token); + self.waiters_queue.remove(&(key.clone(), token)); + } + + fn acquire(&mut self, key: &Key) -> Acquire { + // check limits + if self.limit > 0 && self.acquired >= self.limit { + return Acquire::NotAvailable; + } + + self.reserve(); + + // check if open connection is available + // cleanup stale connections at the same time + if let Some(ref mut connections) = self.available.get_mut(key) { + let now = Instant::now(); + while let Some(conn) = connections.pop_back() { + // check if it still usable + if (now - conn.used) > self.conn_keep_alive + || (now - conn.created) > self.conn_lifetime + { + if let Some(timeout) = self.disconnect_timeout { + spawn(CloseConnection::new(conn.io, timeout)) + } + } else { + let mut io = conn.io; + let mut buf = [0; 2]; + match io.read(&mut buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Ok(n) if n > 0 => { + if let Some(timeout) = self.disconnect_timeout { + spawn(CloseConnection::new(io, timeout)) + } + continue; + } + Ok(_) | Err(_) => continue, + } + return Acquire::Acquired(io, conn.created); + } + } + } + Acquire::Available + } + + fn reserve(&mut self) { + self.acquired += 1; + } + + fn release(&mut self) { + self.acquired -= 1; + } + + fn release_conn(&mut self, key: &Key, io: Io, created: Instant) { + self.acquired -= 1; + self.available + .entry(key.clone()) + .or_insert_with(VecDeque::new) + .push_back(AvailableConnection { + io, + created, + used: Instant::now(), + }); + } + + fn release_close(&mut self, io: Io) { + self.acquired -= 1; + if let Some(timeout) = self.disconnect_timeout { + spawn(CloseConnection::new(io, timeout)) + } + } + + fn check_availibility(&self) { + if !self.waiters_queue.is_empty() && self.acquired < self.limit { + self.task.notify() + } + } +} + +struct ConnectorPoolSupport +where + Io: AsyncRead + AsyncWrite + 'static, +{ + connector: T, + inner: Rc>>, +} + +impl Future for ConnectorPoolSupport +where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, + T::Future: 'static, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll { + let mut inner = self.inner.as_ref().borrow_mut(); + inner.task.register(); + + // check waiters + loop { + let (key, token) = { + if let Some((key, token)) = inner.waiters_queue.get_index(0) { + (key.clone(), *token) + } else { + break; + } + }; + match inner.acquire(&key) { + Acquire::NotAvailable => break, + Acquire::Acquired(io, created) => { + let (_, tx) = inner.waiters.remove(token); + if let Err(conn) = tx.send(Ok(Connection::new( + io, + created, + Acquired(key.clone(), Some(self.inner.clone())), + ))) { + let (io, created) = conn.unwrap().into_inner(); + inner.release_conn(&key, io, created); + } + } + Acquire::Available => { + let (connect, tx) = inner.waiters.remove(token); + OpenWaitingConnection::spawn( + key.clone(), + tx, + self.inner.clone(), + self.connector.call(connect), + ); + } + } + let _ = inner.waiters_queue.swap_remove_index(0); + } + + Ok(Async::NotReady) + } +} + +struct CloseConnection { + io: T, + timeout: Delay, +} + +impl CloseConnection +where + T: AsyncWrite, +{ + fn new(io: T, timeout: Duration) -> Self { + CloseConnection { + io, + timeout: sleep(timeout), + } + } +} + +impl Future for CloseConnection +where + T: AsyncWrite, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + match self.timeout.poll() { + Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), + Ok(Async::NotReady) => match self.io.shutdown() { + Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), + Ok(Async::NotReady) => Ok(Async::NotReady), + }, + } + } +} + +pub(crate) struct Acquired( + Key, + Option>>>, +); + +impl Acquired +where + T: AsyncRead + AsyncWrite + 'static, +{ + pub(crate) fn close(&mut self, conn: Connection) { + if let Some(inner) = self.1.take() { + let (io, _) = conn.into_inner(); + inner.as_ref().borrow_mut().release_close(io); + } + } + pub(crate) fn release(&mut self, conn: Connection) { + if let Some(inner) = self.1.take() { + let (io, created) = conn.into_inner(); + inner + .as_ref() + .borrow_mut() + .release_conn(&self.0, io, created); + } + } +} + +impl Drop for Acquired +where + T: AsyncRead + AsyncWrite + 'static, +{ + fn drop(&mut self) { + if let Some(inner) = self.1.take() { + inner.as_ref().borrow_mut().release(); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 572c23ae1..32369d167 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,6 +83,7 @@ extern crate cookie; extern crate encoding; extern crate http as modhttp; extern crate httparse; +extern crate indexmap; extern crate mime; extern crate net2; extern crate percent_encoding; @@ -90,18 +91,24 @@ extern crate rand; extern crate serde; extern crate serde_json; extern crate serde_urlencoded; +extern crate slab; extern crate tokio; extern crate tokio_codec; extern crate tokio_current_thread; extern crate tokio_io; extern crate tokio_tcp; extern crate tokio_timer; +extern crate trust_dns_proto; +extern crate trust_dns_resolver; extern crate url as urlcrate; #[cfg(test)] #[macro_use] extern crate serde_derive; +#[cfg(feature = "ssl")] +extern crate openssl; + mod body; pub mod client; mod config;