diff --git a/src/body.rs b/src/body.rs index c78ea8172..e001273c4 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,13 +1,39 @@ -use bytes::{Bytes, BytesMut}; -use futures::Stream; use std::sync::Arc; use std::{fmt, mem}; +use bytes::{Bytes, BytesMut}; +use futures::{Async, Poll, Stream}; + use error::Error; /// Type represent streaming body pub type BodyStream = Box>; +/// Different type of bory +pub enum BodyType { + None, + Zero, + Sized(usize), + Unsized, +} + +/// Type that provides this trait can be streamed to a peer. +pub trait MessageBody { + fn tp(&self) -> BodyType; + + fn poll_next(&mut self) -> Poll, Error>; +} + +impl MessageBody for () { + fn tp(&self) -> BodyType { + BodyType::Zero + } + + fn poll_next(&mut self) -> Poll, Error> { + Ok(Async::Ready(None)) + } +} + /// Represents various types of http message body. pub enum Body { /// Empty response. `Content-Length` header is set to `0` @@ -241,6 +267,112 @@ impl AsRef<[u8]> for Binary { } } +impl MessageBody for Bytes { + fn tp(&self) -> BodyType { + BodyType::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(mem::replace(self, Bytes::new())))) + } + } +} + +impl MessageBody for &'static str { + fn tp(&self) -> BodyType { + BodyType::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from_static( + mem::replace(self, "").as_ref(), + )))) + } + } +} + +impl MessageBody for &'static [u8] { + fn tp(&self) -> BodyType { + BodyType::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from_static(mem::replace( + self, b"", + ))))) + } + } +} + +impl MessageBody for Vec { + fn tp(&self) -> BodyType { + BodyType::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from(mem::replace( + self, + Vec::new(), + ))))) + } + } +} + +impl MessageBody for String { + fn tp(&self) -> BodyType { + BodyType::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from( + mem::replace(self, String::new()).into_bytes(), + )))) + } + } +} + +#[doc(hidden)] +pub struct MessageBodyStream { + stream: S, +} + +impl MessageBodyStream +where + S: Stream, +{ + pub fn new(stream: S) -> Self { + MessageBodyStream { stream } + } +} + +impl MessageBody for MessageBodyStream +where + S: Stream, +{ + fn tp(&self) -> BodyType { + BodyType::Unsized + } + + fn poll_next(&mut self) -> Poll, Error> { + self.stream.poll() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/client/connect.rs b/src/client/connect.rs index 40c3e8ec9..a445228e3 100644 --- a/src/client/connect.rs +++ b/src/client/connect.rs @@ -14,8 +14,13 @@ pub struct Connect { } impl Connect { + /// Create `Connect` message for specified `Uri` + pub fn new(uri: Uri) -> Connect { + Connect { uri } + } + /// Construct `Uri` instance and create `Connect` message. - pub fn new(uri: U) -> Result + pub fn try_from(uri: U) -> Result where Uri: HttpTryFrom, { @@ -24,11 +29,6 @@ impl Connect { }) } - /// Create `Connect` message for specified `Uri` - pub fn with(uri: Uri) -> Connect { - Connect { uri } - } - pub(crate) fn is_secure(&self) -> bool { if let Some(scheme) = self.uri.scheme_part() { scheme.as_str() == "https" diff --git a/src/client/connection.rs b/src/client/connection.rs index 294e100c8..eec64267a 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -6,7 +6,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use super::pool::Acquired; /// HTTP client connection -pub struct Connection { +pub struct Connection { io: T, created: time::Instant, pool: Option>, @@ -14,7 +14,7 @@ pub struct Connection { impl fmt::Debug for Connection where - T: AsyncRead + AsyncWrite + fmt::Debug + 'static, + T: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Connection {:?}", self.io) diff --git a/src/client/connector.rs b/src/client/connector.rs index 97da074d1..1eae135f2 100644 --- a/src/client/connector.rs +++ b/src/client/connector.rs @@ -130,7 +130,7 @@ impl Connector { self, ) -> impl Service< Request = Connect, - Response = impl AsyncRead + AsyncWrite + fmt::Debug, + Response = Connection, Error = ConnectorError, > + Clone { #[cfg(not(feature = "ssl"))] diff --git a/src/client/error.rs b/src/client/error.rs index ba6407230..2c4753642 100644 --- a/src/client/error.rs +++ b/src/client/error.rs @@ -17,6 +17,8 @@ use native_tls::Error as SslError; ))] use std::io::Error as SslError; +use error::{Error, ParseError}; + /// A set of errors that can occur while connecting to an HTTP host #[derive(Fail, Debug)] pub enum ConnectorError { @@ -75,3 +77,41 @@ impl From for ConnectorError { ConnectorError::Resolver(err) } } + +/// A set of errors that can occur during request sending and response reading +#[derive(Debug)] +pub enum SendRequestError { + /// Failed to connect to host + // #[fail(display = "Failed to connect to host: {}", _0)] + Connector(ConnectorError), + /// Error sending request + Send(io::Error), + /// Error parsing response + Response(ParseError), + /// Error sending request body + Body(Error), +} + +impl From for SendRequestError { + fn from(err: io::Error) -> SendRequestError { + SendRequestError::Send(err) + } +} + +impl From for SendRequestError { + fn from(err: ConnectorError) -> SendRequestError { + SendRequestError::Connector(err) + } +} + +impl From for SendRequestError { + fn from(err: ParseError) -> SendRequestError { + SendRequestError::Response(err) + } +} + +impl From for SendRequestError { + fn from(err: Error) -> SendRequestError { + SendRequestError::Body(err) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 714e6c694..da0cbc670 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,12 +3,14 @@ mod connect; mod connection; mod connector; mod error; +mod pipeline; mod pool; mod request; mod response; pub use self::connect::Connect; +pub use self::connection::Connection; pub use self::connector::Connector; -pub use self::error::{ConnectorError, InvalidUrlKind}; -pub use self::request::{ClientRequest, ClientRequestBuilder}; +pub use self::error::{ConnectorError, InvalidUrlKind, SendRequestError}; +pub use self::request::{ClientRequest, ClientRequestBuilder, RequestHead}; pub use self::response::ClientResponse; diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs new file mode 100644 index 000000000..fff900131 --- /dev/null +++ b/src/client/pipeline.rs @@ -0,0 +1,174 @@ +use std::collections::VecDeque; + +use actix_net::codec::Framed; +use actix_net::service::Service; +use bytes::Bytes; +use futures::future::{err, ok, Either}; +use futures::{Async, Future, Poll, Sink, Stream}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use super::error::{ConnectorError, SendRequestError}; +use super::request::RequestHead; +use super::response::ClientResponse; +use super::{Connect, Connection}; +use body::{BodyStream, BodyType, MessageBody}; +use error::Error; +use h1; + +pub fn send_request( + head: RequestHead, + body: B, + connector: &mut T, +) -> impl Future +where + T: Service, Error = ConnectorError>, + B: MessageBody, + Io: AsyncRead + AsyncWrite + 'static, +{ + let tp = body.tp(); + + connector + .call(Connect::new(head.uri.clone())) + .from_err() + .map(|io| Framed::new(io, h1::ClientCodec::default())) + .and_then(|framed| framed.send((head, tp).into()).from_err()) + .and_then(move |framed| match body.tp() { + BodyType::None | BodyType::Zero => Either::A(ok(framed)), + _ => Either::B(SendBody::new(body, framed)), + }).and_then(|framed| { + framed + .into_future() + .map_err(|(e, _)| SendRequestError::from(e)) + .and_then(|(item, framed)| { + if let Some(item) = item { + let mut res = item.into_item().unwrap(); + match framed.get_codec().message_type() { + h1::MessageType::None => release_connection(framed), + _ => res.payload = Some(Payload::stream(framed)), + } + ok(res) + } else { + err(ConnectorError::Disconnected.into()) + } + }) + }) +} + +struct SendBody { + body: Option, + framed: Option, h1::ClientCodec>>, + write_buf: VecDeque>, + flushed: bool, +} + +impl SendBody +where + Io: AsyncRead + AsyncWrite + 'static, + B: MessageBody, +{ + fn new(body: B, framed: Framed, h1::ClientCodec>) -> Self { + SendBody { + body: Some(body), + framed: Some(framed), + write_buf: VecDeque::new(), + flushed: true, + } + } +} + +impl Future for SendBody +where + Io: AsyncRead + AsyncWrite + 'static, + B: MessageBody, +{ + type Item = Framed, h1::ClientCodec>; + type Error = SendRequestError; + + fn poll(&mut self) -> Poll { + let mut body_ready = true; + loop { + while body_ready + && self.body.is_some() + && !self.framed.as_ref().unwrap().is_write_buf_full() + { + match self.body.as_mut().unwrap().poll_next()? { + Async::Ready(None) => { + self.flushed = false; + self.framed + .as_mut() + .unwrap() + .start_send(h1::Message::Chunk(None))?; + break; + } + Async::Ready(Some(chunk)) => { + self.flushed = false; + self.framed + .as_mut() + .unwrap() + .start_send(h1::Message::Chunk(Some(chunk)))?; + } + Async::NotReady => body_ready = false, + } + } + + if !self.flushed { + match self.framed.as_mut().unwrap().poll_complete()? { + Async::Ready(_) => { + self.flushed = true; + continue; + } + Async::NotReady => return Ok(Async::NotReady), + } + } + + if self.body.is_none() { + return Ok(Async::Ready(self.framed.take().unwrap())); + } + return Ok(Async::NotReady); + } + } +} + +struct Payload { + framed: Option, h1::ClientCodec>>, +} + +impl Payload { + fn stream(framed: Framed, h1::ClientCodec>) -> BodyStream { + Box::new(Payload { + framed: Some(framed), + }) + } +} + +impl Stream for Payload { + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + match self.framed.as_mut().unwrap().poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(Some(chunk)) => match chunk { + h1::Message::Chunk(Some(chunk)) => Ok(Async::Ready(Some(chunk))), + h1::Message::Chunk(None) => { + release_connection(self.framed.take().unwrap()); + Ok(Async::Ready(None)) + } + h1::Message::Item(_) => unreachable!(), + }, + Async::Ready(None) => Ok(Async::Ready(None)), + } + } +} + +fn release_connection(framed: Framed, h1::ClientCodec>) +where + Io: AsyncRead + AsyncWrite + 'static, +{ + let parts = framed.into_parts(); + if parts.read_buf.is_empty() && parts.write_buf.is_empty() { + parts.io.release() + } else { + parts.io.close() + } +} diff --git a/src/client/pool.rs b/src/client/pool.rs index 6ff8c96ce..25296a6dd 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -327,10 +327,7 @@ enum Acquire { NotAvailable, } -pub(crate) struct Inner -where - Io: AsyncRead + AsyncWrite + 'static, -{ +pub(crate) struct Inner { conn_lifetime: Duration, conn_keep_alive: Duration, disconnect_timeout: Option, @@ -345,6 +342,33 @@ where task: AtomicTask, } +impl Inner { + fn reserve(&mut self) { + self.acquired += 1; + } + + fn release(&mut self) { + self.acquired -= 1; + } + + fn release_waiter(&mut self, key: &Key, token: usize) { + self.waiters.remove(token); + self.waiters_queue.remove(&(key.clone(), token)); + } + + fn release_conn(&mut self, key: &Key, io: Io, created: Instant) { + self.acquired -= 1; + self.available + .entry(key.clone()) + .or_insert_with(VecDeque::new) + .push_back(AvailableConnection { + io, + created, + used: Instant::now(), + }); + } +} + impl Inner where Io: AsyncRead + AsyncWrite + 'static, @@ -367,11 +391,6 @@ where (rx, token) } - fn release_waiter(&mut self, key: &Key, token: usize) { - self.waiters.remove(token); - self.waiters_queue.remove(&(key.clone(), token)); - } - fn acquire(&mut self, key: &Key) -> Acquire { // check limits if self.limit > 0 && self.acquired >= self.limit { @@ -412,26 +431,6 @@ where Acquire::Available } - fn reserve(&mut self) { - self.acquired += 1; - } - - fn release(&mut self) { - self.acquired -= 1; - } - - fn release_conn(&mut self, key: &Key, io: Io, created: Instant) { - self.acquired -= 1; - self.available - .entry(key.clone()) - .or_insert_with(VecDeque::new) - .push_back(AvailableConnection { - io, - created, - used: Instant::now(), - }); - } - fn release_close(&mut self, io: Io) { self.acquired -= 1; if let Some(timeout) = self.disconnect_timeout { @@ -541,10 +540,7 @@ where } } -pub(crate) struct Acquired( - Key, - Option>>>, -); +pub(crate) struct Acquired(Key, Option>>>); impl Acquired where @@ -567,10 +563,7 @@ where } } -impl Drop for Acquired -where - T: AsyncRead + AsyncWrite + 'static, -{ +impl Drop for Acquired { fn drop(&mut self) { if let Some(inner) = self.1.take() { inner.as_ref().borrow_mut().release(); diff --git a/src/client/request.rs b/src/client/request.rs index 8c7949336..603374135 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -2,17 +2,25 @@ use std::fmt; use std::fmt::Write as FmtWrite; use std::io::Write; -use bytes::{BufMut, BytesMut}; +use actix_net::service::Service; +use bytes::{BufMut, Bytes, BytesMut}; use cookie::{Cookie, CookieJar}; +use futures::{Future, Stream}; use percent_encoding::{percent_encode, USERINFO_ENCODE_SET}; +use tokio_io::{AsyncRead, AsyncWrite}; use urlcrate::Url; +use body::{MessageBody, MessageBodyStream}; +use error::Error; use header::{self, Header, IntoHeaderValue}; use http::{ uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method, Uri, Version, }; +use super::response::ClientResponse; +use super::{pipeline, Connect, Connection, ConnectorError, SendRequestError}; + /// An HTTP Client Request /// /// ```rust @@ -38,29 +46,40 @@ use http::{ /// ); /// } /// ``` -pub struct ClientRequest { - uri: Uri, - method: Method, - version: Version, - headers: HeaderMap, - chunked: bool, - upgrade: bool, +pub struct ClientRequest { + head: RequestHead, + body: B, } -impl Default for ClientRequest { - fn default() -> ClientRequest { - ClientRequest { +pub struct RequestHead { + pub uri: Uri, + pub method: Method, + pub version: Version, + pub headers: HeaderMap, +} + +impl Default for RequestHead { + fn default() -> RequestHead { + RequestHead { uri: Uri::default(), method: Method::default(), version: Version::HTTP_11, headers: HeaderMap::with_capacity(16), - chunked: false, - upgrade: false, } } } -impl ClientRequest { +impl ClientRequest<()> { + /// Create client request builder + pub fn build() -> ClientRequestBuilder { + ClientRequestBuilder { + head: Some(RequestHead::default()), + err: None, + cookies: None, + default_headers: true, + } + } + /// Create request builder for `GET` request pub fn get>(uri: U) -> ClientRequestBuilder { let mut builder = ClientRequest::build(); @@ -97,87 +116,90 @@ impl ClientRequest { } } -impl ClientRequest { - /// Create client request builder - pub fn build() -> ClientRequestBuilder { - ClientRequestBuilder { - request: Some(ClientRequest::default()), - err: None, - cookies: None, - default_headers: true, - } - } - +impl ClientRequest +where + B: MessageBody, +{ /// Get the request URI #[inline] pub fn uri(&self) -> &Uri { - &self.uri + &self.head.uri } /// Set client request URI #[inline] pub fn set_uri(&mut self, uri: Uri) { - self.uri = uri + self.head.uri = uri } /// Get the request method #[inline] pub fn method(&self) -> &Method { - &self.method + &self.head.method } /// Set HTTP `Method` for the request #[inline] pub fn set_method(&mut self, method: Method) { - self.method = method + self.head.method = method } /// Get HTTP version for the request #[inline] pub fn version(&self) -> Version { - self.version + self.head.version } /// Set http `Version` for the request #[inline] pub fn set_version(&mut self, version: Version) { - self.version = version + self.head.version = version } /// Get the headers from the request #[inline] pub fn headers(&self) -> &HeaderMap { - &self.headers + &self.head.headers } /// Get a mutable reference to the headers #[inline] pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.headers + &mut self.head.headers } - /// is chunked encoding enabled - #[inline] - pub fn chunked(&self) -> bool { - self.chunked + /// Deconstruct ClientRequest to a RequestHead and body tuple + pub fn into_parts(self) -> (RequestHead, B) { + (self.head, self.body) } - /// is upgrade request - #[inline] - pub fn upgrade(&self) -> bool { - self.upgrade + // Send request + /// + /// This method returns a future that resolves to a ClientResponse + pub fn send( + self, + connector: &mut T, + ) -> impl Future + where + T: Service, Error = ConnectorError>, + Io: AsyncRead + AsyncWrite + 'static, + { + pipeline::send_request(self.head, self.body, connector) } } -impl fmt::Debug for ClientRequest { +impl fmt::Debug for ClientRequest +where + B: MessageBody, +{ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!( f, "\nClientRequest {:?} {}:{}", - self.version, self.method, self.uri + self.head.version, self.head.method, self.head.uri )?; writeln!(f, " headers:")?; - for (key, val) in self.headers.iter() { + for (key, val) in self.head.headers.iter() { writeln!(f, " {:?}: {:?}", key, val)?; } Ok(()) @@ -189,7 +211,7 @@ impl fmt::Debug for ClientRequest { /// This type can be used to construct an instance of `ClientRequest` through a /// builder-like pattern. pub struct ClientRequestBuilder { - request: Option, + head: Option, err: Option, cookies: Option, default_headers: bool, @@ -208,7 +230,7 @@ impl ClientRequestBuilder { fn _uri(&mut self, url: &str) -> &mut Self { match Uri::try_from(url) { Ok(uri) => { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { parts.uri = uri; } } @@ -220,7 +242,7 @@ impl ClientRequestBuilder { /// Set HTTP method of this request. #[inline] pub fn method(&mut self, method: Method) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { parts.method = method; } self @@ -229,7 +251,7 @@ impl ClientRequestBuilder { /// Set HTTP method of this request. #[inline] pub fn get_method(&mut self) -> &Method { - let parts = self.request.as_ref().expect("cannot reuse request builder"); + let parts = self.head.as_ref().expect("cannot reuse request builder"); &parts.method } @@ -238,7 +260,7 @@ impl ClientRequestBuilder { /// By default requests's HTTP version depends on network stream #[inline] pub fn version(&mut self, version: Version) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { parts.version = version; } self @@ -263,7 +285,7 @@ impl ClientRequestBuilder { /// ``` #[doc(hidden)] pub fn set(&mut self, hdr: H) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { match hdr.try_into() { Ok(value) => { parts.headers.insert(H::name(), value); @@ -299,7 +321,7 @@ impl ClientRequestBuilder { HeaderName: HttpTryFrom, V: IntoHeaderValue, { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { match HeaderName::try_from(key) { Ok(key) => match value.try_into() { Ok(value) => { @@ -319,7 +341,7 @@ impl ClientRequestBuilder { HeaderName: HttpTryFrom, V: IntoHeaderValue, { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { match HeaderName::try_from(key) { Ok(key) => match value.try_into() { Ok(value) => { @@ -339,7 +361,7 @@ impl ClientRequestBuilder { HeaderName: HttpTryFrom, V: IntoHeaderValue, { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { match HeaderName::try_from(key) { Ok(key) => if !parts.headers.contains_key(&key) { match value.try_into() { @@ -357,11 +379,12 @@ impl ClientRequestBuilder { /// Enable connection upgrade #[inline] - pub fn upgrade(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.upgrade = true; - } - self + pub fn upgrade(&mut self, value: V) -> &mut Self + where + V: IntoHeaderValue, + { + self.set_header(header::UPGRADE, value) + .set_header(header::CONNECTION, "upgrade") } /// Set request's content type @@ -370,7 +393,7 @@ impl ClientRequestBuilder { where HeaderValue: HttpTryFrom, { - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { match HeaderValue::try_from(value) { Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); @@ -454,14 +477,17 @@ impl ClientRequestBuilder { /// Set a body and generate `ClientRequest`. /// /// `ClientRequestBuilder` can not be used after this call. - pub fn finish(&mut self) -> Result { + pub fn body( + &mut self, + body: B, + ) -> Result, HttpError> { if let Some(e) = self.err.take() { return Err(e); } if self.default_headers { // enable br only for https - let https = if let Some(parts) = parts(&mut self.request, &self.err) { + let https = if let Some(parts) = parts(&mut self.head, &self.err) { parts .uri .scheme_part() @@ -478,7 +504,7 @@ impl ClientRequestBuilder { } // set request host header - if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(parts) = parts(&mut self.head, &self.err) { if let Some(host) = parts.uri.host() { if !parts.headers.contains_key(header::HOST) { let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); @@ -505,7 +531,7 @@ impl ClientRequestBuilder { ); } - let mut request = self.request.take().expect("cannot reuse request builder"); + let mut head = self.head.take().expect("cannot reuse request builder"); // set cookies if let Some(ref mut jar) = self.cookies { @@ -515,18 +541,38 @@ impl ClientRequestBuilder { let value = percent_encode(c.value().as_bytes(), USERINFO_ENCODE_SET); let _ = write!(&mut cookie, "; {}={}", name, value); } - request.headers.insert( + head.headers.insert( header::COOKIE, HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), ); } - Ok(request) + Ok(ClientRequest { head, body }) + } + + /// Set an streaming body and generate `ClientRequest`. + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn stream( + &mut self, + stream: S, + ) -> Result, HttpError> + where + S: Stream, + { + self.body(MessageBodyStream::new(stream)) + } + + /// Set an empty body and generate `ClientRequest`. + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn finish(&mut self) -> Result, HttpError> { + self.body(()) } /// This method construct new `ClientRequestBuilder` pub fn take(&mut self) -> ClientRequestBuilder { ClientRequestBuilder { - request: self.request.take(), + head: self.head.take(), err: self.err.take(), cookies: self.cookies.take(), default_headers: self.default_headers, @@ -536,9 +582,9 @@ impl ClientRequestBuilder { #[inline] fn parts<'a>( - parts: &'a mut Option, + parts: &'a mut Option, err: &Option, -) -> Option<&'a mut ClientRequest> { +) -> Option<&'a mut RequestHead> { if err.is_some() { return None; } @@ -547,7 +593,7 @@ fn parts<'a>( impl fmt::Debug for ClientRequestBuilder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(ref parts) = self.request { + if let Some(ref parts) = self.head { writeln!( f, "\nClientRequestBuilder {:?} {}:{}", diff --git a/src/client/response.rs b/src/client/response.rs index 627a1c78e..0d5a87a0d 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -2,35 +2,38 @@ use std::cell::{Cell, Ref, RefCell, RefMut}; use std::fmt; use std::rc::Rc; +use bytes::Bytes; +use futures::{Async, Poll, Stream}; use http::{HeaderMap, Method, StatusCode, Version}; +use body::BodyStream; +use error::Error; use extensions::Extensions; -use httpmessage::HttpMessage; -use payload::Payload; use request::{Message, MessageFlags, MessagePool}; use uri::Url; /// Client Response pub struct ClientResponse { pub(crate) inner: Rc, + pub(crate) payload: Option, } -impl HttpMessage for ClientResponse { - type Stream = Payload; +// impl HttpMessage for ClientResponse { +// type Stream = Payload; - fn headers(&self) -> &HeaderMap { - &self.inner.headers - } +// fn headers(&self) -> &HeaderMap { +// &self.inner.headers +// } - #[inline] - fn payload(&self) -> Payload { - if let Some(payload) = self.inner.payload.borrow_mut().take() { - payload - } else { - Payload::empty() - } - } -} +// #[inline] +// fn payload(&self) -> Payload { +// if let Some(payload) = self.inner.payload.borrow_mut().take() { +// payload +// } else { +// Payload::empty() +// } +// } +// } impl ClientResponse { /// Create new Request instance @@ -52,6 +55,7 @@ impl ClientResponse { payload: RefCell::new(None), extensions: RefCell::new(Extensions::new()), }), + payload: None, } } @@ -108,6 +112,19 @@ impl ClientResponse { } } +impl Stream for ClientResponse { + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + if let Some(ref mut payload) = self.payload { + payload.poll() + } else { + Ok(Async::Ready(None)) + } + } +} + impl Drop for ClientResponse { fn drop(&mut self) { if Rc::strong_count(&self.inner) == 1 { diff --git a/src/h1/client.rs b/src/h1/client.rs index b55af185f..9ace98e0e 100644 --- a/src/h1/client.rs +++ b/src/h1/client.rs @@ -7,12 +7,14 @@ use tokio_codec::{Decoder, Encoder}; use super::decoder::{PayloadDecoder, PayloadItem, PayloadType, ResponseDecoder}; use super::encoder::{RequestEncoder, ResponseLength}; use super::{Message, MessageType}; -use body::{Binary, Body}; -use client::{ClientRequest, ClientResponse}; +use body::{Binary, Body, BodyType}; +use client::{ClientResponse, RequestHead}; use config::ServiceConfig; use error::ParseError; use helpers; -use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use http::header::{ + HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, +}; use http::{Method, Version}; use request::MessagePool; @@ -22,7 +24,7 @@ bitflags! { const UPGRADE = 0b0000_0010; const KEEPALIVE = 0b0000_0100; const KEEPALIVE_ENABLED = 0b0000_1000; - const UNHANDLED = 0b0001_0000; + const STREAM = 0b0001_0000; } } @@ -86,8 +88,8 @@ impl ClientCodec { /// Check last request's message type pub fn message_type(&self) -> MessageType { - if self.flags.contains(Flags::UNHANDLED) { - MessageType::Unhandled + if self.flags.contains(Flags::STREAM) { + MessageType::Stream } else if self.payload.is_none() { MessageType::None } else { @@ -96,38 +98,31 @@ impl ClientCodec { } /// prepare transfer encoding - pub fn prepare_te(&mut self, res: &mut ClientRequest) { + pub fn prepare_te(&mut self, head: &mut RequestHead, btype: BodyType) { self.te - .update(res, self.flags.contains(Flags::HEAD), self.version); + .update(head, self.flags.contains(Flags::HEAD), self.version); } fn encode_response( &mut self, - msg: ClientRequest, + msg: RequestHead, + btype: BodyType, buffer: &mut BytesMut, ) -> io::Result<()> { - // Connection upgrade - if msg.upgrade() { - self.flags.insert(Flags::UPGRADE); - } - // render message { // status line writeln!( Writer(buffer), "{} {} {:?}\r", - msg.method(), - msg.uri() - .path_and_query() - .map(|u| u.as_str()) - .unwrap_or("/"), - msg.version() + msg.method, + msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), + msg.version ).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; // write headers - buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE); - for (key, value) in msg.headers() { + buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE); + for (key, value) in &msg.headers { let v = value.as_ref(); let k = key.as_str().as_bytes(); buffer.reserve(k.len() + v.len() + 4); @@ -135,10 +130,15 @@ impl ClientCodec { buffer.put_slice(b": "); buffer.put_slice(v); buffer.put_slice(b"\r\n"); + + // Connection upgrade + if key == UPGRADE { + self.flags.insert(Flags::UPGRADE); + } } // set date header - if !msg.headers().contains_key(DATE) { + if !msg.headers.contains_key(DATE) { self.config.set_date(buffer); } else { buffer.extend_from_slice(b"\r\n"); @@ -160,8 +160,6 @@ impl Decoder for ClientCodec { Some(PayloadItem::Eof) => Some(Message::Chunk(None)), None => None, }) - } else if self.flags.contains(Flags::UNHANDLED) { - Ok(None) } else if let Some((req, payload)) = self.decoder.decode(src)? { self.flags .set(Flags::HEAD, req.inner.method == Method::HEAD); @@ -172,9 +170,9 @@ impl Decoder for ClientCodec { match payload { PayloadType::None => self.payload = None, PayloadType::Payload(pl) => self.payload = Some(pl), - PayloadType::Unhandled => { - self.payload = None; - self.flags.insert(Flags::UNHANDLED); + PayloadType::Stream(pl) => { + self.payload = Some(pl); + self.flags.insert(Flags::STREAM); } }; Ok(Some(Message::Item(req))) @@ -185,7 +183,7 @@ impl Decoder for ClientCodec { } impl Encoder for ClientCodec { - type Item = Message; + type Item = Message<(RequestHead, BodyType)>; type Error = io::Error; fn encode( @@ -194,8 +192,8 @@ impl Encoder for ClientCodec { dst: &mut BytesMut, ) -> Result<(), Self::Error> { match item { - Message::Item(res) => { - self.encode_response(res, dst)?; + Message::Item((msg, btype)) => { + self.encode_response(msg, btype, dst)?; } Message::Chunk(Some(bytes)) => { self.te.encode(bytes.as_ref(), dst)?; diff --git a/src/h1/codec.rs b/src/h1/codec.rs index dd1e27e08..44a1b81fc 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -23,7 +23,7 @@ bitflags! { const UPGRADE = 0b0000_0010; const KEEPALIVE = 0b0000_0100; const KEEPALIVE_ENABLED = 0b0000_1000; - const UNHANDLED = 0b0001_0000; + const STREAM = 0b0001_0000; } } @@ -93,8 +93,8 @@ impl Codec { /// Check last request's message type pub fn message_type(&self) -> MessageType { - if self.flags.contains(Flags::UNHANDLED) { - MessageType::Unhandled + if self.flags.contains(Flags::STREAM) { + MessageType::Stream } else if self.payload.is_none() { MessageType::None } else { @@ -259,8 +259,6 @@ impl Decoder for Codec { } None => None, }) - } else if self.flags.contains(Flags::UNHANDLED) { - Ok(None) } else if let Some((req, payload)) = self.decoder.decode(src)? { self.flags .set(Flags::HEAD, req.inner.method == Method::HEAD); @@ -271,9 +269,9 @@ impl Decoder for Codec { match payload { PayloadType::None => self.payload = None, PayloadType::Payload(pl) => self.payload = Some(pl), - PayloadType::Unhandled => { - self.payload = None; - self.flags.insert(Flags::UNHANDLED); + PayloadType::Stream(pl) => { + self.payload = Some(pl); + self.flags.insert(Flags::STREAM); } } Ok(Some(Message::Item(req))) diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index 472e29936..f2a3ee3f7 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -25,7 +25,7 @@ pub struct ResponseDecoder(&'static MessagePool); pub enum PayloadType { None, Payload(PayloadDecoder), - Unhandled, + Stream(PayloadDecoder), } impl RequestDecoder { @@ -174,7 +174,7 @@ impl Decoder for RequestDecoder { PayloadType::Payload(PayloadDecoder::length(len)) } else if has_upgrade || msg.inner.method == Method::CONNECT { // upgrade(websocket) or connect - PayloadType::Unhandled + PayloadType::Stream(PayloadDecoder::eof()) } else if src.len() >= MAX_BUFFER_SIZE { error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(ParseError::TooLarge); @@ -321,7 +321,7 @@ impl Decoder for ResponseDecoder { || msg.inner.method == Method::CONNECT { // switching protocol or connect - PayloadType::Unhandled + PayloadType::Stream(PayloadDecoder::eof()) } else if src.len() >= MAX_BUFFER_SIZE { error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(ParseError::TooLarge); @@ -667,7 +667,7 @@ mod tests { fn is_unhandled(&self) -> bool { match self { - PayloadType::Unhandled => true, + PayloadType::Stream(_) => true, _ => false, } } diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index f4dfdb219..b23af9648 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -344,7 +344,7 @@ where *req.inner.payload.borrow_mut() = Some(pl); self.payload = Some(ps); } - MessageType::Unhandled => { + MessageType::Stream => { self.unhandled = Some(req); return Ok(updated); } diff --git a/src/h1/encoder.rs b/src/h1/encoder.rs index fc0a3a1b2..caad113d9 100644 --- a/src/h1/encoder.rs +++ b/src/h1/encoder.rs @@ -9,7 +9,7 @@ use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH}; use http::{StatusCode, Version}; use body::{Binary, Body}; -use client::ClientRequest; +use client::RequestHead; use header::ContentEncoding; use http::Method; use request::Request; @@ -196,7 +196,7 @@ impl RequestEncoder { self.te.encode_eof(buf) } - pub fn update(&mut self, resp: &mut ClientRequest, head: bool, version: Version) { + pub fn update(&mut self, resp: &mut RequestHead, head: bool, version: Version) { self.head = head; } } diff --git a/src/h1/mod.rs b/src/h1/mod.rs index e73771778..e7e0759b9 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -33,6 +33,15 @@ pub enum Message { Chunk(Option), } +impl Message { + pub fn into_item(self) -> Option { + match self { + Message::Item(item) => Some(item), + _ => None, + } + } +} + impl From for Message { fn from(item: T) -> Self { Message::Item(item) @@ -44,7 +53,7 @@ impl From for Message { pub enum MessageType { None, Payload, - Unhandled, + Stream, } #[cfg(test)] diff --git a/src/request.rs b/src/request.rs index 07632bf05..593080fa6 100644 --- a/src/request.rs +++ b/src/request.rs @@ -240,7 +240,10 @@ impl MessagePool { if let Some(r) = Rc::get_mut(&mut msg) { r.reset(); } - return ClientResponse { inner: msg }; + return ClientResponse { + inner: msg, + payload: None, + }; } ClientResponse::with_pool(pool) } diff --git a/src/ws/client/service.rs b/src/ws/client/service.rs index 485ce5620..80cf98684 100644 --- a/src/ws/client/service.rs +++ b/src/ws/client/service.rs @@ -13,6 +13,7 @@ use rand; use sha1::Sha1; use tokio_io::{AsyncRead, AsyncWrite}; +use body::BodyType; use client::ClientResponse; use h1; use ws::Codec; @@ -89,9 +90,7 @@ where req.request.set_header(header::ORIGIN, origin); } - req.request.upgrade(); - req.request.set_header(header::UPGRADE, "websocket"); - req.request.set_header(header::CONNECTION, "upgrade"); + req.request.upgrade("websocket"); req.request.set_header(header::SEC_WEBSOCKET_VERSION, "13"); if let Some(protocols) = req.protocols.take() { @@ -142,7 +141,7 @@ where // h1 protocol let framed = Framed::new(io, h1::ClientCodec::default()); framed - .send(request.into()) + .send((request.into_parts().0, BodyType::None).into()) .map_err(ClientError::from) .and_then(|framed| { framed diff --git a/tests/test_client.rs b/tests/test_client.rs new file mode 100644 index 000000000..6741a2c61 --- /dev/null +++ b/tests/test_client.rs @@ -0,0 +1,147 @@ +extern crate actix; +extern crate actix_http; +extern crate actix_net; +extern crate bytes; +extern crate futures; + +use std::{thread, time}; + +use actix::System; +use actix_net::server::Server; +use actix_net::service::NewServiceExt; +use futures::future::{self, lazy, ok}; + +use actix_http::{client, h1, test, Request, Response}; + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[test] +fn test_h1_v2() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::build() + .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .map(|_| ()) + }).unwrap() + .run(); + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let mut connector = sys + .block_on(lazy(|| Ok::<_, ()>(client::Connector::default().service()))) + .unwrap(); + + let req = client::ClientRequest::get(format!("http://{}/", addr)) + .finish() + .unwrap(); + + let response = sys.block_on(req.send(&mut connector)).unwrap(); + assert!(response.status().is_success()); + + let request = client::ClientRequest::get(format!("http://{}/", addr)) + .header("x-test", "111") + .finish() + .unwrap(); + let repr = format!("{:?}", request); + assert!(repr.contains("ClientRequest")); + assert!(repr.contains("x-test")); + + let response = sys.block_on(request.send(&mut connector)).unwrap(); + assert!(response.status().is_success()); + + // read response + // let bytes = srv.execute(response.body()).unwrap(); + // assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + let request = client::ClientRequest::post(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(request.send(&mut connector)).unwrap(); + assert!(response.status().is_success()); + + // read response + // let bytes = srv.execute(response.body()).unwrap(); + // assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_connection_close() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map(|_| ()) + }).unwrap() + .run(); + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let mut connector = sys + .block_on(lazy(|| Ok::<_, ()>(client::Connector::default().service()))) + .unwrap(); + + let request = client::ClientRequest::get(format!("http://{}/", addr)) + .header("Connection", "close") + .finish() + .unwrap(); + let response = sys.block_on(request.send(&mut connector)).unwrap(); + assert!(response.status().is_success()); +} + +#[test] +fn test_with_query_parameter() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::build() + .finish(|req: Request| { + if req.uri().query().unwrap().contains("qp=") { + ok::<_, ()>(Response::Ok().finish()) + } else { + ok::<_, ()>(Response::BadRequest().finish()) + } + }).map(|_| ()) + }).unwrap() + .run(); + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let mut connector = sys + .block_on(lazy(|| Ok::<_, ()>(client::Connector::default().service()))) + .unwrap(); + + let request = client::ClientRequest::get(format!("http://{}/?qp=5", addr)) + .finish() + .unwrap(); + + let response = sys.block_on(request.send(&mut connector)).unwrap(); + assert!(response.status().is_success()); +}