diff --git a/src/client/connection.rs b/src/client/connection.rs index eec64267a..363a4ece9 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -5,14 +5,23 @@ use tokio_io::{AsyncRead, AsyncWrite}; use super::pool::Acquired; +pub trait Connection: AsyncRead + AsyncWrite + 'static { + /// Close connection + fn close(&mut self); + + /// Release connection to the connection pool + fn release(&mut self); +} + +#[doc(hidden)] /// HTTP client connection -pub struct Connection { - io: T, +pub struct IoConnection { + io: Option, created: time::Instant, pool: Option>, } -impl fmt::Debug for Connection +impl fmt::Debug for IoConnection where T: fmt::Debug, { @@ -21,59 +30,73 @@ where } } -impl Connection { +impl IoConnection { pub(crate) fn new(io: T, created: time::Instant, pool: Acquired) -> Self { - Connection { - io, + IoConnection { created, + io: Some(io), pool: Some(pool), } } /// Raw IO stream pub fn get_mut(&mut self) -> &mut T { - &mut self.io + self.io.as_mut().unwrap() } + pub(crate) fn into_inner(self) -> (T, time::Instant) { + (self.io.unwrap(), self.created) + } +} + +impl Connection for IoConnection { /// Close connection - pub fn close(mut self) { + fn close(&mut self) { if let Some(mut pool) = self.pool.take() { - pool.close(self) + if let Some(io) = self.io.take() { + pool.close(IoConnection { + io: Some(io), + created: self.created, + pool: None, + }) + } } } /// Release this connection to the connection pool - pub fn release(mut self) { + fn release(&mut self) { if let Some(mut pool) = self.pool.take() { - pool.release(self) + if let Some(io) = self.io.take() { + pool.release(IoConnection { + io: Some(io), + created: self.created, + pool: None, + }) + } } } - - pub(crate) fn into_inner(self) -> (T, time::Instant) { - (self.io, self.created) - } } -impl io::Read for Connection { +impl io::Read for IoConnection { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.io.read(buf) + self.io.as_mut().unwrap().read(buf) } } -impl AsyncRead for Connection {} +impl AsyncRead for IoConnection {} -impl io::Write for Connection { +impl io::Write for IoConnection { fn write(&mut self, buf: &[u8]) -> io::Result { - self.io.write(buf) + self.io.as_mut().unwrap().write(buf) } fn flush(&mut self) -> io::Result<()> { - self.io.flush() + self.io.as_mut().unwrap().flush() } } -impl AsyncWrite for Connection { +impl AsyncWrite for IoConnection { fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.shutdown() + self.io.as_mut().unwrap().shutdown() } } diff --git a/src/client/connector.rs b/src/client/connector.rs index 1eae135f2..818085214 100644 --- a/src/client/connector.rs +++ b/src/client/connector.rs @@ -11,7 +11,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; use super::connect::Connect; -use super::connection::Connection; +use super::connection::{Connection, IoConnection}; use super::error::ConnectorError; use super::pool::ConnectionPool; @@ -130,7 +130,7 @@ impl Connector { self, ) -> impl Service< Request = Connect, - Response = Connection, + Response = impl Connection, Error = ConnectorError, > + Clone { #[cfg(not(feature = "ssl"))] @@ -234,11 +234,11 @@ mod connect_impl { T: Service, { type Request = Connect; - type Response = Connection; + type Response = IoConnection; type Error = ConnectorError; type Future = Either< as Service>::Future, - FutureResult, ConnectorError>, + FutureResult, ConnectorError>, >; fn poll_ready(&mut self) -> Poll<(), Self::Error> { @@ -324,7 +324,7 @@ mod connect_impl { >, { type Request = Connect; - type Response = IoEither, Connection>; + type Response = IoEither, IoConnection>; type Error = ConnectorError; type Future = Either< FutureResult, @@ -342,13 +342,13 @@ mod connect_impl { if let Err(e) = req.validate() { Either::A(err(e)) } else if req.is_secure() { - Either::B(Either::A(InnerConnectorResponseA { - fut: self.tcp_pool.call(req), + Either::B(Either::B(InnerConnectorResponseB { + fut: self.ssl_pool.call(req), _t: PhantomData, })) } else { - Either::B(Either::B(InnerConnectorResponseB { - fut: self.ssl_pool.call(req), + Either::B(Either::A(InnerConnectorResponseA { + fut: self.tcp_pool.call(req), _t: PhantomData, })) } @@ -370,7 +370,7 @@ mod connect_impl { Io1: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + 'static, { - type Item = IoEither, Connection>; + type Item = IoEither, IoConnection>; type Error = ConnectorError; fn poll(&mut self) -> Poll { @@ -396,7 +396,7 @@ mod connect_impl { Io1: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + 'static, { - type Item = IoEither, Connection>; + type Item = IoEither, IoConnection>; type Error = ConnectorError; fn poll(&mut self) -> Poll { @@ -413,10 +413,30 @@ pub(crate) enum IoEither { B(Io2), } +impl Connection for IoEither +where + Io1: Connection, + Io2: Connection, +{ + fn close(&mut self) { + match self { + IoEither::A(ref mut io) => io.close(), + IoEither::B(ref mut io) => io.close(), + } + } + + fn release(&mut self) { + match self { + IoEither::A(ref mut io) => io.release(), + IoEither::B(ref mut io) => io.release(), + } + } +} + impl io::Read for IoEither where - Io1: io::Read, - Io2: io::Read, + Io1: Connection, + Io2: Connection, { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { @@ -428,8 +448,8 @@ where impl AsyncRead for IoEither where - Io1: AsyncRead, - Io2: AsyncRead, + Io1: Connection, + Io2: Connection, { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { match self { @@ -441,8 +461,8 @@ where impl AsyncWrite for IoEither where - Io1: AsyncWrite, - Io2: AsyncWrite, + Io1: Connection, + Io2: Connection, { fn shutdown(&mut self) -> Poll<(), io::Error> { match self { @@ -468,8 +488,8 @@ where impl io::Write for IoEither where - Io1: io::Write, - Io2: io::Write, + Io1: Connection, + Io2: Connection, { fn flush(&mut self) -> io::Result<()> { match self { diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index 24ec8366d..dc6a644dc 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -15,27 +15,32 @@ use body::{BodyType, MessageBody, PayloadStream}; use error::PayloadError; use h1; -pub fn send_request( +pub(crate) fn send_request( head: RequestHead, body: B, connector: &mut T, ) -> impl Future where - T: Service, Error = ConnectorError>, + T: Service, B: MessageBody, - Io: AsyncRead + AsyncWrite + 'static, + I: Connection, { let tp = body.tp(); connector + // connect to the host .call(Connect::new(head.uri.clone())) .from_err() + // create Framed and send reqest .map(|io| Framed::new(io, h1::ClientCodec::default())) .and_then(|framed| framed.send((head, tp).into()).from_err()) + // send request body .and_then(move |framed| match body.tp() { BodyType::None | BodyType::Zero => Either::A(ok(framed)), _ => Either::B(SendBody::new(body, framed)), - }).and_then(|framed| { + }) + // read response and init read body + .and_then(|framed| { framed .into_future() .map_err(|(e, _)| SendRequestError::from(e)) @@ -55,19 +60,20 @@ where }) } -struct SendBody { +/// Future responsible for sending request body to the peer +struct SendBody { body: Option, - framed: Option, h1::ClientCodec>>, + framed: Option>, write_buf: VecDeque>, flushed: bool, } -impl SendBody +impl SendBody where - Io: AsyncRead + AsyncWrite + 'static, + I: AsyncRead + AsyncWrite + 'static, B: MessageBody, { - fn new(body: B, framed: Framed, h1::ClientCodec>) -> Self { + fn new(body: B, framed: Framed) -> Self { SendBody { body: Some(body), framed: Some(framed), @@ -77,12 +83,12 @@ where } } -impl Future for SendBody +impl Future for SendBody where - Io: AsyncRead + AsyncWrite + 'static, + I: Connection, B: MessageBody, { - type Item = Framed, h1::ClientCodec>; + type Item = Framed; type Error = SendRequestError; fn poll(&mut self) -> Poll { @@ -135,7 +141,7 @@ impl Stream for EmptyPayload { } pub(crate) struct Payload { - framed: Option, h1::ClientPayloadCodec>>, + framed: Option>, } impl Payload<()> { @@ -144,15 +150,15 @@ impl Payload<()> { } } -impl Payload { - fn stream(framed: Framed, h1::ClientCodec>) -> PayloadStream { +impl Payload { + fn stream(framed: Framed) -> PayloadStream { Box::new(Payload { framed: Some(framed.map_codec(|codec| codec.into_payload_codec())), }) } } -impl Stream for Payload { +impl Stream for Payload { type Item = Bytes; type Error = PayloadError; @@ -170,11 +176,11 @@ impl Stream for Payload { } } -fn release_connection(framed: Framed, U>) +fn release_connection(framed: Framed) where - T: AsyncRead + AsyncWrite + 'static, + T: Connection, { - let parts = framed.into_parts(); + let mut parts = framed.into_parts(); if parts.read_buf.is_empty() && parts.write_buf.is_empty() { parts.io.release() } else { diff --git a/src/client/pool.rs b/src/client/pool.rs index 25296a6dd..44008f346 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -17,7 +17,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::{sleep, Delay}; use super::connect::Connect; -use super::connection::Connection; +use super::connection::IoConnection; use super::error::ConnectorError; #[derive(Hash, Eq, PartialEq, Clone, Debug)] @@ -89,10 +89,10 @@ where T: Service, { type Request = Connect; - type Response = Connection; + type Response = IoConnection; type Error = ConnectorError; type Future = Either< - FutureResult, ConnectorError>, + FutureResult, ConnectorError>, Either, OpenConnection>, >; @@ -107,7 +107,7 @@ where match self.1.as_ref().borrow_mut().acquire(&key) { Acquire::Acquired(io, created) => { // use existing connection - Either::A(ok(Connection::new( + Either::A(ok(IoConnection::new( io, created, Acquired(key, Some(self.1.clone())), @@ -142,7 +142,7 @@ where { key: Key, token: usize, - rx: oneshot::Receiver, ConnectorError>>, + rx: oneshot::Receiver, ConnectorError>>, inner: Option>>>, } @@ -163,7 +163,7 @@ impl Future for WaitForConnection where Io: AsyncRead + AsyncWrite, { - type Item = Connection; + type Item = IoConnection; type Error = ConnectorError; fn poll(&mut self) -> Poll { @@ -226,7 +226,7 @@ where F: Future, Io: AsyncRead + AsyncWrite, { - type Item = Connection; + type Item = IoConnection; type Error = ConnectorError; fn poll(&mut self) -> Poll { @@ -234,7 +234,7 @@ where Err(err) => Err(err.into()), Ok(Async::Ready((_, io))) => { let _ = self.inner.take(); - Ok(Async::Ready(Connection::new( + Ok(Async::Ready(IoConnection::new( io, Instant::now(), Acquired(self.key.clone(), self.inner.clone()), @@ -251,7 +251,7 @@ where { fut: F, key: Key, - rx: Option, ConnectorError>>>, + rx: Option, ConnectorError>>>, inner: Option>>>, } @@ -262,7 +262,7 @@ where { fn spawn( key: Key, - rx: oneshot::Sender, ConnectorError>>, + rx: oneshot::Sender, ConnectorError>>, inner: Rc>>, fut: F, ) { @@ -308,7 +308,7 @@ where Ok(Async::Ready((_, io))) => { let _ = self.inner.take(); if let Some(rx) = self.rx.take() { - let _ = rx.send(Ok(Connection::new( + let _ = rx.send(Ok(IoConnection::new( io, Instant::now(), Acquired(self.key.clone(), self.inner.clone()), @@ -336,7 +336,7 @@ pub(crate) struct Inner { available: HashMap>>, waiters: Slab<( Connect, - oneshot::Sender, ConnectorError>>, + oneshot::Sender, ConnectorError>>, )>, waiters_queue: IndexSet<(Key, usize)>, task: AtomicTask, @@ -378,7 +378,7 @@ where &mut self, connect: Connect, ) -> ( - oneshot::Receiver, ConnectorError>>, + oneshot::Receiver, ConnectorError>>, usize, ) { let (tx, rx) = oneshot::channel(); @@ -479,7 +479,7 @@ where Acquire::NotAvailable => break, Acquire::Acquired(io, created) => { let (_, tx) = inner.waiters.remove(token); - if let Err(conn) = tx.send(Ok(Connection::new( + if let Err(conn) = tx.send(Ok(IoConnection::new( io, created, Acquired(key.clone(), Some(self.inner.clone())), @@ -546,13 +546,13 @@ impl Acquired where T: AsyncRead + AsyncWrite + 'static, { - pub(crate) fn close(&mut self, conn: Connection) { + pub(crate) fn close(&mut self, conn: IoConnection) { if let Some(inner) = self.1.take() { let (io, _) = conn.into_inner(); inner.as_ref().borrow_mut().release_close(io); } } - pub(crate) fn release(&mut self, conn: Connection) { + pub(crate) fn release(&mut self, conn: IoConnection) { if let Some(inner) = self.1.take() { let (io, created) = conn.into_inner(); inner diff --git a/src/client/request.rs b/src/client/request.rs index 603374135..602abed59 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -7,7 +7,6 @@ use bytes::{BufMut, Bytes, BytesMut}; use cookie::{Cookie, CookieJar}; use futures::{Future, Stream}; use percent_encoding::{percent_encode, USERINFO_ENCODE_SET}; -use tokio_io::{AsyncRead, AsyncWrite}; use urlcrate::Url; use body::{MessageBody, MessageBodyStream}; @@ -176,13 +175,13 @@ where // Send request /// /// This method returns a future that resolves to a ClientResponse - pub fn send( + pub fn send( self, connector: &mut T, ) -> impl Future where - T: Service, Error = ConnectorError>, - Io: AsyncRead + AsyncWrite + 'static, + T: Service, + I: Connection, { pipeline::send_request(self.head, self.body, connector) }