From 7fed50bcaefde84c6e9424112e220fe789f99a57 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 17 Nov 2018 20:21:28 -0800 Subject: [PATCH] refactor response body management --- src/body.rs | 118 +++------------- src/client/pipeline.rs | 4 + src/client/request.rs | 6 - src/error.rs | 9 +- src/h1/client.rs | 8 +- src/h1/codec.rs | 26 ++-- src/h1/dispatcher.rs | 80 ++++++----- src/h1/encoder.rs | 122 +++------------- src/h1/service.rs | 54 ++++--- src/lib.rs | 2 +- src/response.rs | 309 ++++++++++++++++------------------------- src/service.rs | 144 +++++++++---------- 12 files changed, 334 insertions(+), 548 deletions(-) diff --git a/src/body.rs b/src/body.rs index 6e6239c3..3b4e0113 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,5 +1,5 @@ +use std::mem; use std::sync::Arc; -use std::{fmt, mem}; use bytes::{Bytes, BytesMut}; use futures::{Async, Poll, Stream}; @@ -19,7 +19,8 @@ pub enum BodyLength { Zero, Sized(usize), Sized64(u64), - Unsized, + Chunked, + Stream, } /// Type that provides this trait can be streamed to a peer. @@ -39,17 +40,6 @@ impl MessageBody for () { } } -/// Represents various types of http message body. -pub enum Body { - /// Empty response. `Content-Length` header is set to `0` - Empty, - /// Specific response body. - Binary(Binary), - /// Unspecified streaming response. Developer is responsible for setting - /// right `Content-Length` or `Transfer-Encoding` headers. - Streaming(BodyStream), -} - /// Represents various types of binary body. /// `Content-Length` header is set to length of the body. #[derive(Debug, PartialEq)] @@ -65,84 +55,6 @@ pub enum Binary { SharedVec(Arc>), } -impl Body { - /// Does this body streaming. - #[inline] - pub fn is_streaming(&self) -> bool { - match *self { - Body::Streaming(_) => true, - _ => false, - } - } - - /// Is this binary body. - #[inline] - pub fn is_binary(&self) -> bool { - match *self { - Body::Binary(_) => true, - _ => false, - } - } - - /// Is this binary empy. - #[inline] - pub fn is_empty(&self) -> bool { - match *self { - Body::Empty => true, - _ => false, - } - } - - /// Create body from slice (copy) - pub fn from_slice(s: &[u8]) -> Body { - Body::Binary(Binary::Bytes(Bytes::from(s))) - } - - /// Is this binary body. - #[inline] - pub(crate) fn into_binary(self) -> Option { - match self { - Body::Binary(b) => Some(b), - _ => None, - } - } -} - -impl PartialEq for Body { - fn eq(&self, other: &Body) -> bool { - match *self { - Body::Empty => match *other { - Body::Empty => true, - _ => false, - }, - Body::Binary(ref b) => match *other { - Body::Binary(ref b2) => b == b2, - _ => false, - }, - Body::Streaming(_) => false, - } - } -} - -impl fmt::Debug for Body { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Body::Empty => write!(f, "Body::Empty"), - Body::Binary(ref b) => write!(f, "Body::Binary({:?})", b), - Body::Streaming(_) => write!(f, "Body::Streaming(_)"), - } - } -} - -impl From for Body -where - T: Into, -{ - fn from(b: T) -> Body { - Body::Binary(b.into()) - } -} - impl Binary { #[inline] /// Returns `true` if body is empty @@ -286,6 +198,22 @@ impl MessageBody for Bytes { } } +impl MessageBody for BytesMut { + fn length(&self) -> BodyLength { + BodyLength::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some( + mem::replace(self, BytesMut::new()).freeze(), + ))) + } + } +} + impl MessageBody for &'static str { fn length(&self) -> BodyLength { BodyLength::Sized(self.len()) @@ -370,7 +298,7 @@ where S: Stream, { fn length(&self) -> BodyLength { - BodyLength::Unsized + BodyLength::Chunked } fn poll_next(&mut self) -> Poll, Error> { @@ -382,12 +310,6 @@ where mod tests { use super::*; - #[test] - fn test_body_is_streaming() { - assert_eq!(Body::Empty.is_streaming(), false); - assert_eq!(Body::Binary(Binary::from("")).is_streaming(), false); - } - #[test] fn test_is_empty() { assert_eq!(Binary::from("").is_empty(), true); diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index 93b349e9..56c22bd2 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -100,6 +100,10 @@ where { match self.body.as_mut().unwrap().poll_next()? { Async::Ready(item) => { + // check if body is done + if item.is_none() { + let _ = self.body.take(); + } self.flushed = false; self.framed .as_mut() diff --git a/src/client/request.rs b/src/client/request.rs index d3d1544c..f0b76ed0 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -51,12 +51,6 @@ pub struct ClientRequest { body: B, } -impl RequestHead { - pub fn clear(&mut self) { - self.headers.clear() - } -} - impl ClientRequest<()> { /// Create client request builder pub fn build() -> ClientRequestBuilder { diff --git a/src/error.rs b/src/error.rs index 956ec4eb..1f70396c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -100,11 +100,10 @@ impl Error { } /// Converts error to a response instance and set error message as response body - pub fn response_with_message(self) -> Response { + pub fn response_with_message(self) -> Response { let message = format!("{}", self); - let mut resp: Response = self.into(); - resp.set_body(message); - resp + let resp: Response = self.into(); + resp.set_body(message) } } @@ -637,7 +636,7 @@ where InternalErrorType::Status(st) => Response::new(st), InternalErrorType::Response(ref resp) => { if let Some(resp) = resp.lock().unwrap().take() { - Response::from_parts(resp) + Response::<()>::from_parts(resp) } else { Response::new(StatusCode::INTERNAL_SERVER_ERROR) } diff --git a/src/h1/client.rs b/src/h1/client.rs index 8b367acc..ace50466 100644 --- a/src/h1/client.rs +++ b/src/h1/client.rs @@ -7,7 +7,7 @@ use tokio_codec::{Decoder, Encoder}; use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType}; use super::encoder::RequestEncoder; use super::{Message, MessageType}; -use body::{Binary, Body, BodyLength}; +use body::{Binary, BodyLength}; use client::ClientResponse; use config::ServiceConfig; use error::{ParseError, PayloadError}; @@ -164,14 +164,16 @@ impl ClientCodecInner { write!(buffer.writer(), "{}", len)?; buffer.extend_from_slice(b"\r\n"); } - BodyLength::Unsized => { + BodyLength::Chunked => { buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") } BodyLength::Zero => { len_is_set = false; buffer.extend_from_slice(b"\r\n") } - BodyLength::None => buffer.extend_from_slice(b"\r\n"), + BodyLength::None | BodyLength::Stream => { + buffer.extend_from_slice(b"\r\n") + } } let mut has_date = false; diff --git a/src/h1/codec.rs b/src/h1/codec.rs index bb8afa71..6bc20b18 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -8,12 +8,13 @@ use tokio_codec::{Decoder, Encoder}; use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType}; use super::encoder::ResponseEncoder; use super::{Message, MessageType}; -use body::{Binary, Body, BodyLength}; +use body::{Binary, BodyLength}; use config::ServiceConfig; use error::ParseError; use helpers; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::{Method, Version}; +use message::ResponseHead; use request::Request; use response::Response; @@ -98,9 +99,9 @@ impl Codec { } /// prepare transfer encoding - pub fn prepare_te(&mut self, res: &mut Response) { + pub fn prepare_te(&mut self, head: &mut ResponseHead, length: &mut BodyLength) { self.te - .update(res, self.flags.contains(Flags::HEAD), self.version); + .update(head, self.flags.contains(Flags::HEAD), self.version, length); } fn encode_response( @@ -135,17 +136,8 @@ impl Codec { // render message { let reason = msg.reason().as_bytes(); - if let Body::Binary(ref bytes) = msg.body() { - buffer.reserve( - 256 + msg.headers().len() * AVERAGE_HEADER_SIZE - + bytes.len() - + reason.len(), - ); - } else { - buffer.reserve( - 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len(), - ); - } + buffer + .reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len()); // status line helpers::write_status_line(self.version, msg.status().as_u16(), buffer); @@ -154,7 +146,7 @@ impl Codec { // content length let mut len_is_set = true; match self.te.length { - BodyLength::Unsized => { + BodyLength::Chunked => { buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") } BodyLength::Zero => { @@ -167,7 +159,9 @@ impl Codec { write!(buffer.writer(), "{}", len)?; buffer.extend_from_slice(b"\r\n"); } - BodyLength::None => buffer.extend_from_slice(b"\r\n"), + BodyLength::None | BodyLength::Stream => { + buffer.extend_from_slice(b"\r\n") + } } // write headers diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 4a0bce72..508962b4 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -13,7 +13,7 @@ use tokio_timer::Delay; use error::{ParseError, PayloadError}; use payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; -use body::{Body, BodyStream}; +use body::{BodyLength, MessageBody}; use config::ServiceConfig; use error::DispatchError; use request::Request; @@ -37,14 +37,14 @@ bitflags! { } /// Dispatcher for HTTP/1.1 protocol -pub struct Dispatcher +pub struct Dispatcher where S::Error: Debug, { - inner: Option>, + inner: Option>, } -struct InnerDispatcher +struct InnerDispatcher where S::Error: Debug, { @@ -54,7 +54,7 @@ where error: Option>, config: ServiceConfig, - state: State, + state: State, payload: Option, messages: VecDeque, unhandled: Option, @@ -68,13 +68,13 @@ enum DispatcherMessage { Error(Response), } -enum State { +enum State { None, ServiceCall(S::Future), - SendPayload(BodyStream), + SendPayload(B), } -impl State { +impl State { fn is_empty(&self) -> bool { if let State::None = self { true @@ -84,11 +84,12 @@ impl State { } } -impl Dispatcher +impl Dispatcher where T: AsyncRead + AsyncWrite, - S: Service, + S: Service>, S::Error: Debug, + B: MessageBody, { /// Create http/1 dispatcher. pub fn new(stream: T, config: ServiceConfig, service: S) -> Self { @@ -137,11 +138,12 @@ where } } -impl InnerDispatcher +impl InnerDispatcher where T: AsyncRead + AsyncWrite, - S: Service, + S: Service>, S::Error: Debug, + B: MessageBody, { fn can_read(&self) -> bool { if self.flags.contains(Flags::DISCONNECTED) { @@ -186,11 +188,11 @@ where } } - fn send_response( + fn send_response( &mut self, message: Response, - body: Body, - ) -> Result, DispatchError> { + body: B1, + ) -> Result, DispatchError> { self.framed .force_send(Message::Item(message)) .map_err(|err| { @@ -203,15 +205,9 @@ where self.flags .set(Flags::KEEPALIVE, self.framed.get_codec().keepalive()); self.flags.remove(Flags::FLUSHED); - match body { - Body::Empty => Ok(State::None), - Body::Streaming(stream) => Ok(State::SendPayload(stream)), - Body::Binary(mut bin) => { - self.flags.remove(Flags::FLUSHED); - self.framed.force_send(Message::Chunk(Some(bin.take())))?; - self.framed.force_send(Message::Chunk(None))?; - Ok(State::None) - } + match body.length() { + BodyLength::None | BodyLength::Zero => Ok(State::None), + _ => Ok(State::SendPayload(body)), } } @@ -224,15 +220,18 @@ where Some(self.handle_request(req)?) } Some(DispatcherMessage::Error(res)) => { - Some(self.send_response(res, Body::Empty)?) + self.send_response(res, ())?; + None } None => None, }, State::ServiceCall(mut fut) => { match fut.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { - self.framed.get_codec_mut().prepare_te(&mut res); - let body = res.replace_body(Body::Empty); + let (mut res, body) = res.replace_body(()); + self.framed + .get_codec_mut() + .prepare_te(res.head_mut(), &mut body.length()); Some(self.send_response(res, body)?) } Async::NotReady => { @@ -244,7 +243,10 @@ where State::SendPayload(mut stream) => { loop { if !self.framed.is_write_buf_full() { - match stream.poll().map_err(|_| DispatchError::Unknown)? { + match stream + .poll_next() + .map_err(|_| DispatchError::Unknown)? + { Async::Ready(Some(item)) => { self.flags.remove(Flags::FLUSHED); self.framed @@ -290,12 +292,14 @@ where fn handle_request( &mut self, req: Request, - ) -> Result, DispatchError> { + ) -> Result, DispatchError> { let mut task = self.service.call(req); match task.poll().map_err(DispatchError::Service)? { - Async::Ready(mut res) => { - self.framed.get_codec_mut().prepare_te(&mut res); - let body = res.replace_body(Body::Empty); + Async::Ready(res) => { + let (mut res, body) = res.replace_body(()); + self.framed + .get_codec_mut() + .prepare_te(res.head_mut(), &mut body.length()); self.send_response(res, body) } Async::NotReady => Ok(State::ServiceCall(task)), @@ -436,10 +440,9 @@ where // timeout on first request (slow request) return 408 trace!("Slow request timeout"); self.flags.insert(Flags::STARTED | Flags::DISCONNECTED); - self.state = self.send_response( - Response::RequestTimeout().finish(), - Body::Empty, - )?; + let _ = self + .send_response(Response::RequestTimeout().finish(), ()); + self.state = State::None; } } else if let Some(deadline) = self.config.keep_alive_expire() { self.ka_timer.as_mut().map(|timer| { @@ -462,11 +465,12 @@ where } } -impl Future for Dispatcher +impl Future for Dispatcher where T: AsyncRead + AsyncWrite, - S: Service, + S: Service>, S::Error: Debug, + B: MessageBody, { type Item = H1ServiceResult; type Error = DispatchError; diff --git a/src/h1/encoder.rs b/src/h1/encoder.rs index 421d0b96..aee17c1f 100644 --- a/src/h1/encoder.rs +++ b/src/h1/encoder.rs @@ -8,10 +8,10 @@ use bytes::{Bytes, BytesMut}; use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH}; use http::{StatusCode, Version}; -use body::{Binary, Body, BodyLength}; +use body::{Binary, BodyLength}; use header::ContentEncoding; use http::Method; -use message::RequestHead; +use message::{RequestHead, ResponseHead}; use request::Request; use response::Response; @@ -43,116 +43,36 @@ impl ResponseEncoder { self.te.encode_eof(buf) } - pub fn update(&mut self, resp: &mut Response, head: bool, version: Version) { + pub fn update( + &mut self, + resp: &mut ResponseHead, + head: bool, + version: Version, + length: &mut BodyLength, + ) { self.head = head; - let mut len = 0; - - let has_body = match resp.body() { - Body::Empty => false, - Body::Binary(ref bin) => { - len = bin.len(); - true - } - _ => true, - }; - - let has_body = match resp.body() { - Body::Empty => false, - _ => true, - }; - - let transfer = match resp.body() { - Body::Empty => { - self.length = match resp.status() { + let transfer = match length { + BodyLength::Zero => { + match resp.status { StatusCode::NO_CONTENT | StatusCode::CONTINUE | StatusCode::SWITCHING_PROTOCOLS - | StatusCode::PROCESSING => BodyLength::None, - _ => BodyLength::Zero, - }; + | StatusCode::PROCESSING => *length = BodyLength::None, + _ => (), + } TransferEncoding::empty() } - Body::Binary(_) => { - self.length = BodyLength::Sized(len); - TransferEncoding::length(len as u64) - } - Body::Streaming(_) => { - if resp.upgrade() { - self.length = BodyLength::None; - TransferEncoding::eof() - } else { - self.streaming_encoding(version, resp) - } - } + BodyLength::Sized(len) => TransferEncoding::length(*len as u64), + BodyLength::Sized64(len) => TransferEncoding::length(*len), + BodyLength::Chunked => TransferEncoding::chunked(), + BodyLength::Stream => TransferEncoding::eof(), + BodyLength::None => TransferEncoding::length(0), }; // check for head response - if self.head { - resp.set_body(Body::Empty); - } else { + if !self.head { self.te = transfer; } } - - fn streaming_encoding( - &mut self, - version: Version, - resp: &mut Response, - ) -> TransferEncoding { - match resp.chunked() { - Some(true) => { - // Enable transfer encoding - if version == Version::HTTP_2 { - self.length = BodyLength::None; - TransferEncoding::eof() - } else { - self.length = BodyLength::Unsized; - TransferEncoding::chunked() - } - } - Some(false) => TransferEncoding::eof(), - None => { - // if Content-Length is specified, then use it as length hint - let (len, chunked) = - if let Some(len) = resp.headers().get(CONTENT_LENGTH) { - // Content-Length - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - (Some(len), false) - } else { - error!("illegal Content-Length: {:?}", len); - (None, false) - } - } else { - error!("illegal Content-Length: {:?}", len); - (None, false) - } - } else { - (None, true) - }; - - if !chunked { - if let Some(len) = len { - self.length = BodyLength::Sized64(len); - TransferEncoding::length(len) - } else { - TransferEncoding::eof() - } - } else { - // Enable transfer encoding - match version { - Version::HTTP_11 => { - self.length = BodyLength::Unsized; - TransferEncoding::chunked() - } - _ => { - self.length = BodyLength::None; - TransferEncoding::eof() - } - } - } - } - } - } } #[derive(Debug)] diff --git a/src/h1/service.rs b/src/h1/service.rs index 7e5e8c5f..0f0452ee 100644 --- a/src/h1/service.rs +++ b/src/h1/service.rs @@ -8,6 +8,7 @@ use futures::future::{ok, FutureResult}; use futures::{Async, Future, Poll, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; +use body::MessageBody; use config::{KeepAlive, ServiceConfig}; use error::{DispatchError, ParseError}; use request::Request; @@ -18,17 +19,18 @@ use super::dispatcher::Dispatcher; use super::{H1ServiceResult, Message}; /// `NewService` implementation for HTTP1 transport -pub struct H1Service { +pub struct H1Service { srv: S, cfg: ServiceConfig, - _t: PhantomData, + _t: PhantomData<(T, B)>, } -impl H1Service +impl H1Service where - S: NewService + Clone, + S: NewService> + Clone, S::Service: Clone, S::Error: Debug, + B: MessageBody, { /// Create new `HttpService` instance. pub fn new>(service: F) -> Self { @@ -47,19 +49,20 @@ where } } -impl NewService for H1Service +impl NewService for H1Service where T: AsyncRead + AsyncWrite, - S: NewService + Clone, + S: NewService> + Clone, S::Service: Clone, S::Error: Debug, + B: MessageBody, { type Request = T; type Response = H1ServiceResult; type Error = DispatchError; type InitError = S::InitError; - type Service = H1ServiceHandler; - type Future = H1ServiceResponse; + type Service = H1ServiceHandler; + type Future = H1ServiceResponse; fn new_service(&self) -> Self::Future { H1ServiceResponse { @@ -180,7 +183,11 @@ where } /// Finish service configuration and create `H1Service` instance. - pub fn finish>(self, service: F) -> H1Service { + pub fn finish(self, service: F) -> H1Service + where + B: MessageBody, + F: IntoNewService, + { let cfg = ServiceConfig::new( self.keep_alive, self.client_timeout, @@ -195,20 +202,21 @@ where } #[doc(hidden)] -pub struct H1ServiceResponse { +pub struct H1ServiceResponse { fut: S::Future, cfg: Option, - _t: PhantomData, + _t: PhantomData<(T, B)>, } -impl Future for H1ServiceResponse +impl Future for H1ServiceResponse where T: AsyncRead + AsyncWrite, - S: NewService, + S: NewService>, S::Service: Clone, S::Error: Debug, + B: MessageBody, { - type Item = H1ServiceHandler; + type Item = H1ServiceHandler; type Error = S::InitError; fn poll(&mut self) -> Poll { @@ -221,18 +229,19 @@ where } /// `Service` implementation for HTTP1 transport -pub struct H1ServiceHandler { +pub struct H1ServiceHandler { srv: S, cfg: ServiceConfig, - _t: PhantomData, + _t: PhantomData<(T, B)>, } -impl H1ServiceHandler +impl H1ServiceHandler where - S: Service + Clone, + S: Service> + Clone, S::Error: Debug, + B: MessageBody, { - fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler { + fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler { H1ServiceHandler { srv, cfg, @@ -241,16 +250,17 @@ where } } -impl Service for H1ServiceHandler +impl Service for H1ServiceHandler where T: AsyncRead + AsyncWrite, - S: Service + Clone, + S: Service> + Clone, S::Error: Debug, + B: MessageBody, { type Request = T; type Response = H1ServiceResult; type Error = DispatchError; - type Future = Dispatcher; + type Future = Dispatcher; fn poll_ready(&mut self) -> Poll<(), Self::Error> { self.srv.poll_ready().map_err(DispatchError::Service) diff --git a/src/lib.rs b/src/lib.rs index f64876e3..4b43cae4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,7 +129,7 @@ pub mod h1; pub(crate) mod helpers; pub mod test; pub mod ws; -pub use body::{Binary, Body}; +pub use body::{Binary, MessageBody}; pub use error::{Error, ResponseError, Result}; pub use extensions::Extensions; pub use httpmessage::HttpMessage; diff --git a/src/response.rs b/src/response.rs index f96facba..b3e59982 100644 --- a/src/response.rs +++ b/src/response.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::collections::VecDeque; use std::io::Write; -use std::{fmt, mem, str}; +use std::{fmt, str}; use bytes::{BufMut, Bytes, BytesMut}; use cookie::{Cookie, CookieJar}; @@ -12,7 +12,7 @@ use http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode, Version}; use serde::Serialize; use serde_json; -use body::Body; +use body::{MessageBody, MessageBodyStream}; use error::Error; use header::{ContentEncoding, Header, IntoHeaderValue}; use message::{Head, MessageFlags, ResponseHead}; @@ -32,19 +32,9 @@ pub enum ConnectionType { } /// An HTTP Response -pub struct Response(Box); - -impl Response { - #[inline] - fn get_ref(&self) -> &InnerResponse { - self.0.as_ref() - } - - #[inline] - fn get_mut(&mut self) -> &mut InnerResponse { - self.0.as_mut() - } +pub struct Response(Box, B); +impl Response<()> { /// Create http response builder with specific status. #[inline] pub fn build(status: StatusCode) -> ResponseBuilder { @@ -60,13 +50,7 @@ impl Response { /// Constructs a response #[inline] pub fn new(status: StatusCode) -> Response { - ResponsePool::with_body(status, Body::Empty) - } - - /// Constructs a response with body - #[inline] - pub fn with_body>(status: StatusCode, body: B) -> Response { - ResponsePool::with_body(status, body.into()) + ResponsePool::with_body(status, ()) } /// Constructs an error response @@ -98,6 +82,29 @@ impl Response { cookies: jar, } } +} + +impl Response { + #[inline] + fn get_ref(&self) -> &InnerResponse { + self.0.as_ref() + } + + #[inline] + fn get_mut(&mut self) -> &mut InnerResponse { + self.0.as_mut() + } + + #[inline] + pub(crate) fn head_mut(&mut self) -> &mut ResponseHead { + &mut self.0.as_mut().head + } + + /// Constructs a response with body + #[inline] + pub fn with_body(status: StatusCode, body: B) -> Response { + ResponsePool::with_body(status, body.into()) + } /// The source `error` for this response #[inline] @@ -105,6 +112,39 @@ impl Response { self.get_ref().error.as_ref() } + /// Get the response status code + #[inline] + pub fn status(&self) -> StatusCode { + self.get_ref().head.status + } + + /// Set the `StatusCode` for this response + #[inline] + pub fn status_mut(&mut self) -> &mut StatusCode { + &mut self.get_mut().head.status + } + + /// Get custom reason for the response + #[inline] + pub fn reason(&self) -> &str { + if let Some(reason) = self.get_ref().head.reason { + reason + } else { + self.get_ref() + .head + .status + .canonical_reason() + .unwrap_or("") + } + } + + /// Set the custom reason for the response + #[inline] + pub fn set_reason(&mut self, reason: &'static str) -> &mut Self { + self.get_mut().head.reason = Some(reason); + self + } + /// Get the headers from the response #[inline] pub fn headers(&self) -> &HeaderMap { @@ -167,39 +207,6 @@ impl Response { count } - /// Get the response status code - #[inline] - pub fn status(&self) -> StatusCode { - self.get_ref().head.status - } - - /// Set the `StatusCode` for this response - #[inline] - pub fn status_mut(&mut self) -> &mut StatusCode { - &mut self.get_mut().head.status - } - - /// Get custom reason for the response - #[inline] - pub fn reason(&self) -> &str { - if let Some(reason) = self.get_ref().head.reason { - reason - } else { - self.get_ref() - .head - .status - .canonical_reason() - .unwrap_or("") - } - } - - /// Set the custom reason for the response - #[inline] - pub fn set_reason(&mut self, reason: &'static str) -> &mut Self { - self.get_mut().head.reason = Some(reason); - self - } - /// Set connection type pub fn set_connection_type(&mut self, conn: ConnectionType) -> &mut Self { self.get_mut().connection_type = Some(conn); @@ -224,38 +231,20 @@ impl Response { } } - /// is chunked encoding enabled - #[inline] - pub fn chunked(&self) -> Option { - self.get_ref().chunked - } - - /// Content encoding - #[inline] - pub fn content_encoding(&self) -> Option { - self.get_ref().encoding - } - - /// Set content encoding - pub fn set_content_encoding(&mut self, enc: ContentEncoding) -> &mut Self { - self.get_mut().encoding = Some(enc); - self - } - /// Get body os this response #[inline] - pub fn body(&self) -> &Body { - &self.get_ref().body + pub fn body(&self) -> &B { + &self.1 } /// Set a body - pub fn set_body>(&mut self, body: B) { - self.get_mut().body = body.into(); + pub fn set_body(self, body: B2) -> Response { + Response(self.0, body) } /// Set a body and return previous body value - pub fn replace_body>(&mut self, body: B) -> Body { - mem::replace(&mut self.get_mut().body, body.into()) + pub fn replace_body(self, body: B2) -> (Response, B) { + (Response(self.0, body), self.1) } /// Size of response in bytes, excluding HTTP headers @@ -268,16 +257,6 @@ impl Response { self.get_mut().response_size = size; } - /// Set write buffer capacity - pub fn write_buffer_capacity(&self) -> usize { - self.get_ref().write_capacity - } - - /// Set write buffer capacity - pub fn set_write_buffer_capacity(&mut self, cap: usize) { - self.get_mut().write_capacity = cap; - } - pub(crate) fn release(self) { ResponsePool::release(self.0); } @@ -287,7 +266,7 @@ impl Response { } pub(crate) fn from_parts(parts: ResponseParts) -> Response { - Response(Box::new(InnerResponse::from_parts(parts))) + Response(Box::new(InnerResponse::from_parts(parts)), ()) } } @@ -454,24 +433,6 @@ impl ResponseBuilder { self.connection_type(ConnectionType::Close) } - /// Enables automatic chunked transfer encoding - #[inline] - pub fn chunked(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.chunked = Some(true); - } - self - } - - /// Force disable chunked encoding - #[inline] - pub fn no_chunking(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.chunked = Some(false); - } - self - } - /// Set response content type #[inline] pub fn content_type(&mut self, value: V) -> &mut Self @@ -580,63 +541,73 @@ impl ResponseBuilder { self } - /// Set write buffer capacity - /// - /// This parameter makes sense only for streaming response - /// or actor. If write buffer reaches specified capacity, stream or actor - /// get paused. - /// - /// Default write buffer capacity is 64kb - pub fn write_buffer_capacity(&mut self, cap: usize) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.write_capacity = cap; - } - self - } + // /// Set write buffer capacity + // /// + // /// This parameter makes sense only for streaming response + // /// or actor. If write buffer reaches specified capacity, stream or actor + // /// get paused. + // /// + // /// Default write buffer capacity is 64kb + // pub fn write_buffer_capacity(&mut self, cap: usize) -> &mut Self { + // if let Some(parts) = parts(&mut self.response, &self.err) { + // parts.write_capacity = cap; + // } + // self + // } /// Set a body and generate `Response`. /// /// `ResponseBuilder` can not be used after this call. - pub fn body>(&mut self, body: B) -> Response { - if let Some(e) = self.err.take() { - return Error::from(e).into(); - } + pub fn body(&mut self, body: B) -> Response { + let mut error = if let Some(e) = self.err.take() { + Some(Error::from(e)) + } else { + None + }; + let mut response = self.response.take().expect("cannot reuse response builder"); if let Some(ref jar) = self.cookies { for cookie in jar.delta() { match HeaderValue::from_str(&cookie.to_string()) { - Ok(val) => response.head.headers.append(header::SET_COOKIE, val), - Err(e) => return Error::from(e).into(), + Ok(val) => { + let _ = response.head.headers.append(header::SET_COOKIE, val); + } + Err(e) => if error.is_none() { + error = Some(Error::from(e)); + }, }; } } - response.body = body.into(); - Response(response) + if let Some(error) = error { + response.error = Some(error); + } + + Response(response, body) } #[inline] /// Set a streaming body and generate `Response`. /// /// `ResponseBuilder` can not be used after this call. - pub fn streaming(&mut self, stream: S) -> Response + pub fn streaming(&mut self, stream: S) -> Response where S: Stream + 'static, E: Into, { - self.body(Body::Streaming(Box::new(stream.map_err(|e| e.into())))) + self.body(MessageBodyStream::new(stream.map_err(|e| e.into()))) } /// Set a json body and generate `Response` /// /// `ResponseBuilder` can not be used after this call. - pub fn json(&mut self, value: T) -> Response { + pub fn json(&mut self, value: T) -> Response { self.json2(&value) } /// Set a json body and generate `Response` /// /// `ResponseBuilder` can not be used after this call. - pub fn json2(&mut self, value: &T) -> Response { + pub fn json2(&mut self, value: &T) -> Response { match serde_json::to_string(value) { Ok(body) => { let contains = if let Some(parts) = parts(&mut self.response, &self.err) @@ -651,7 +622,10 @@ impl ResponseBuilder { self.body(body) } - Err(e) => Error::from(e).into(), + Err(e) => { + let mut res: Response = Error::from(e).into(); + res.replace_body(String::new()).0 + } } } @@ -659,8 +633,8 @@ impl ResponseBuilder { /// Set an empty body and generate `Response` /// /// `ResponseBuilder` can not be used after this call. - pub fn finish(&mut self) -> Response { - self.body(Body::Empty) + pub fn finish(&mut self) -> Response<()> { + self.body(()) } /// This method construct new `ResponseBuilder` @@ -701,7 +675,7 @@ impl From for Response { } } -impl From<&'static str> for Response { +impl From<&'static str> for Response<&'static str> { fn from(val: &'static str) -> Self { Response::Ok() .content_type("text/plain; charset=utf-8") @@ -709,7 +683,7 @@ impl From<&'static str> for Response { } } -impl From<&'static [u8]> for Response { +impl From<&'static [u8]> for Response<&'static [u8]> { fn from(val: &'static [u8]) -> Self { Response::Ok() .content_type("application/octet-stream") @@ -717,7 +691,7 @@ impl From<&'static [u8]> for Response { } } -impl From for Response { +impl From for Response { fn from(val: String) -> Self { Response::Ok() .content_type("text/plain; charset=utf-8") @@ -725,15 +699,7 @@ impl From for Response { } } -impl<'a> From<&'a String> for Response { - fn from(val: &'a String) -> Self { - Response::build(StatusCode::OK) - .content_type("text/plain; charset=utf-8") - .body(val) - } -} - -impl From for Response { +impl From for Response { fn from(val: Bytes) -> Self { Response::Ok() .content_type("application/octet-stream") @@ -741,7 +707,7 @@ impl From for Response { } } -impl From for Response { +impl From for Response { fn from(val: BytesMut) -> Self { Response::Ok() .content_type("application/octet-stream") @@ -751,8 +717,6 @@ impl From for Response { struct InnerResponse { head: ResponseHead, - body: Body, - chunked: Option, encoding: Option, connection_type: Option, write_capacity: usize, @@ -763,7 +727,6 @@ struct InnerResponse { pub(crate) struct ResponseParts { head: ResponseHead, - body: Option, encoding: Option, connection_type: Option, error: Option, @@ -771,11 +734,7 @@ pub(crate) struct ResponseParts { impl InnerResponse { #[inline] - fn new( - status: StatusCode, - body: Body, - pool: &'static ResponsePool, - ) -> InnerResponse { + fn new(status: StatusCode, pool: &'static ResponsePool) -> InnerResponse { InnerResponse { head: ResponseHead { status, @@ -784,9 +743,7 @@ impl InnerResponse { reason: None, flags: MessageFlags::empty(), }, - body, pool, - chunked: None, encoding: None, connection_type: None, response_size: 0, @@ -796,18 +753,8 @@ impl InnerResponse { } /// This is for failure, we can not have Send + Sync on Streaming and Actor response - fn into_parts(mut self) -> ResponseParts { - let body = match mem::replace(&mut self.body, Body::Empty) { - Body::Empty => None, - Body::Binary(mut bin) => Some(bin.take()), - Body::Streaming(_) => { - error!("Streaming or Actor body is not support by error response"); - None - } - }; - + fn into_parts(self) -> ResponseParts { ResponseParts { - body, head: self.head, encoding: self.encoding, connection_type: self.connection_type, @@ -816,16 +763,8 @@ impl InnerResponse { } fn from_parts(parts: ResponseParts) -> InnerResponse { - let body = if let Some(ref body) = parts.body { - Body::Binary(body.clone().into()) - } else { - Body::Empty - }; - InnerResponse { - body, head: parts.head, - chunked: None, encoding: parts.encoding, connection_type: parts.connection_type, response_size: 0, @@ -864,7 +803,7 @@ impl ResponsePool { cookies: None, } } else { - let msg = Box::new(InnerResponse::new(status, Body::Empty, pool)); + let msg = Box::new(InnerResponse::new(status, pool)); ResponseBuilder { response: Some(msg), err: None, @@ -874,17 +813,16 @@ impl ResponsePool { } #[inline] - pub fn get_response( + pub fn get_response( pool: &'static ResponsePool, status: StatusCode, - body: Body, - ) -> Response { + body: B, + ) -> Response { if let Some(mut msg) = pool.0.borrow_mut().pop_front() { msg.head.status = status; - msg.body = body; - Response(msg) + Response(msg, body) } else { - Response(Box::new(InnerResponse::new(status, body, pool))) + Response(Box::new(InnerResponse::new(status, pool)), body) } } @@ -894,7 +832,7 @@ impl ResponsePool { } #[inline] - fn with_body(status: StatusCode, body: Body) -> Response { + fn with_body(status: StatusCode, body: B) -> Response { POOL.with(|pool| ResponsePool::get_response(pool, status, body)) } @@ -903,7 +841,6 @@ impl ResponsePool { let mut p = inner.pool.0.borrow_mut(); if p.len() < 128 { inner.head.clear(); - inner.chunked = None; inner.encoding = None; inner.connection_type = None; inner.response_size = 0; diff --git a/src/service.rs b/src/service.rs index 3aa5d1e4..ac92a0f7 100644 --- a/src/service.rs +++ b/src/service.rs @@ -6,7 +6,7 @@ use futures::future::{ok, Either, FutureResult}; use futures::{Async, AsyncSink, Future, Poll, Sink}; use tokio_io::{AsyncRead, AsyncWrite}; -use body::Body; +use body::MessageBody; use error::{Error, ResponseError}; use h1::{Codec, Message}; use response::Response; @@ -58,11 +58,11 @@ where match req { Ok(r) => Either::A(ok(r)), Err((e, framed)) => { - let mut resp = e.error_response(); - resp.set_body(format!("{}", e)); + let mut res = e.error_response().set_body(format!("{}", e)); + let (res, _body) = res.replace_body(()); Either::B(SendErrorFut { framed: Some(framed), - res: Some(resp.into()), + res: Some(res.into()), err: Some(e), _t: PhantomData, }) @@ -109,30 +109,30 @@ where } } -pub struct SendResponse(PhantomData<(T,)>); +pub struct SendResponse(PhantomData<(T, B)>); -impl Default for SendResponse -where - T: AsyncRead + AsyncWrite, -{ +impl Default for SendResponse { fn default() -> Self { SendResponse(PhantomData) } } -impl SendResponse +impl SendResponse where T: AsyncRead + AsyncWrite, + B: MessageBody, { pub fn send( mut framed: Framed, - mut res: Response, + res: Response, ) -> impl Future, Error = Error> { - // init codec - framed.get_codec_mut().prepare_te(&mut res); - // extract body from response - let body = res.replace_body(Body::Empty); + let (mut res, body) = res.replace_body(()); + + // init codec + framed + .get_codec_mut() + .prepare_te(&mut res.head_mut(), &mut body.length()); // write response SendResponseFut { @@ -143,15 +143,16 @@ where } } -impl NewService for SendResponse +impl NewService for SendResponse where T: AsyncRead + AsyncWrite, + B: MessageBody, { - type Request = (Response, Framed); + type Request = (Response, Framed); type Response = Framed; type Error = Error; type InitError = (); - type Service = SendResponse; + type Service = SendResponse; type Future = FutureResult; fn new_service(&self) -> Self::Future { @@ -159,22 +160,25 @@ where } } -impl Service for SendResponse +impl Service for SendResponse where T: AsyncRead + AsyncWrite, + B: MessageBody, { - type Request = (Response, Framed); + type Request = (Response, Framed); type Response = Framed; type Error = Error; - type Future = SendResponseFut; + type Future = SendResponseFut; fn poll_ready(&mut self) -> Poll<(), Self::Error> { Ok(Async::Ready(())) } - fn call(&mut self, (mut res, mut framed): Self::Request) -> Self::Future { - framed.get_codec_mut().prepare_te(&mut res); - let body = res.replace_body(Body::Empty); + fn call(&mut self, (res, mut framed): Self::Request) -> Self::Future { + let (mut res, body) = res.replace_body(()); + framed + .get_codec_mut() + .prepare_te(res.head_mut(), &mut body.length()); SendResponseFut { res: Some(Message::Item(res)), body: Some(body), @@ -183,73 +187,69 @@ where } } -pub struct SendResponseFut { +pub struct SendResponseFut { res: Option>, - body: Option, + body: Option, framed: Option>, } -impl Future for SendResponseFut +impl Future for SendResponseFut where T: AsyncRead + AsyncWrite, + B: MessageBody, { type Item = Framed; type Error = Error; fn poll(&mut self) -> Poll { - // send response - if self.res.is_some() { + loop { + let mut body_ready = self.body.is_some(); let framed = self.framed.as_mut().unwrap(); - if !framed.is_write_buf_full() { - if let Some(res) = self.res.take() { - framed.force_send(res)?; - } - } - } - // send body - if self.res.is_none() && self.body.is_some() { - let framed = self.framed.as_mut().unwrap(); - if !framed.is_write_buf_full() { - let body = self.body.take().unwrap(); - match body { - Body::Empty => (), - Body::Streaming(mut stream) => loop { - match stream.poll()? { - Async::Ready(item) => { - let done = item.is_none(); - framed.force_send(Message::Chunk(item.into()))?; - if !done { - if !framed.is_write_buf_full() { - continue; - } else { - self.body = Some(Body::Streaming(stream)); - break; - } - } - } - Async::NotReady => { - self.body = Some(Body::Streaming(stream)); - break; + // send body + if self.res.is_none() && self.body.is_some() { + while body_ready && self.body.is_some() && !framed.is_write_buf_full() { + match self.body.as_mut().unwrap().poll_next()? { + Async::Ready(item) => { + // body is done + if item.is_none() { + let _ = self.body.take(); } + framed.force_send(Message::Chunk(item))?; } - }, - Body::Binary(mut bin) => { - framed.force_send(Message::Chunk(Some(bin.take())))?; - framed.force_send(Message::Chunk(None))?; + Async::NotReady => body_ready = false, } } } - } - // flush - match self.framed.as_mut().unwrap().poll_complete()? { - Async::Ready(_) => if self.res.is_some() || self.body.is_some() { - return self.poll(); - }, - Async::NotReady => return Ok(Async::NotReady), - } + // flush write buffer + if !framed.is_write_buf_empty() { + match framed.poll_complete()? { + Async::Ready(_) => if body_ready { + continue; + } else { + return Ok(Async::NotReady); + }, + Async::NotReady => return Ok(Async::NotReady), + } + } - Ok(Async::Ready(self.framed.take().unwrap())) + // send response + if let Some(res) = self.res.take() { + framed.force_send(res)?; + continue; + } + + if self.body.is_some() { + if body_ready { + continue; + } else { + return Ok(Async::NotReady); + } + } else { + break; + } + } + return Ok(Async::Ready(self.framed.take().unwrap())); } }