diff --git a/src/body.rs b/src/body.rs index 3449448e..cc4e77af 100644 --- a/src/body.rs +++ b/src/body.rs @@ -37,6 +37,37 @@ impl MessageBody for () { } } +pub enum ResponseBody { + Body(B), + Other(Body), +} + +impl ResponseBody { + pub fn as_ref(&self) -> Option<&B> { + if let ResponseBody::Body(ref b) = self { + Some(b) + } else { + None + } + } +} + +impl MessageBody for ResponseBody { + fn length(&self) -> BodyLength { + match self { + ResponseBody::Body(ref body) => body.length(), + ResponseBody::Other(ref body) => body.length(), + } + } + + fn poll_next(&mut self) -> Poll, Error> { + match self { + ResponseBody::Body(ref mut body) => body.poll_next(), + ResponseBody::Other(ref mut body) => body.poll_next(), + } + } +} + /// Represents various types of http message body. pub enum Body { /// Empty response. `Content-Length` header is not set. @@ -332,6 +363,15 @@ mod tests { } } + impl ResponseBody { + pub(crate) fn get_ref(&self) -> &[u8] { + match *self { + ResponseBody::Body(ref b) => b.get_ref(), + ResponseBody::Other(ref b) => b.get_ref(), + } + } + } + #[test] fn test_static_str() { assert_eq!(Body::from("").length(), BodyLength::Sized(0)); diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 550d27c2..48c8e710 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -13,7 +13,7 @@ use tokio_timer::Delay; use error::{ParseError, PayloadError}; use payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; -use body::{BodyLength, MessageBody}; +use body::{Body, BodyLength, MessageBody, ResponseBody}; use config::ServiceConfig; use error::DispatchError; use request::Request; @@ -70,7 +70,7 @@ enum DispatcherMessage { enum State { None, ServiceCall(S::Future), - SendPayload(B), + SendPayload(ResponseBody), } impl State { @@ -186,11 +186,11 @@ where } } - fn send_response( + fn send_response( &mut self, message: Response<()>, - body: B1, - ) -> Result, DispatchError> { + body: ResponseBody, + ) -> Result, DispatchError> { self.framed .force_send(Message::Item((message, body.length()))) .map_err(|err| { @@ -217,7 +217,7 @@ where Some(self.handle_request(req)?) } Some(DispatcherMessage::Error(res)) => { - self.send_response(res, ())?; + self.send_response(res, ResponseBody::Other(Body::Empty))?; None } None => None, @@ -431,7 +431,7 @@ where trace!("Slow request timeout"); let _ = self.send_response( Response::RequestTimeout().finish().drop_body(), - (), + ResponseBody::Other(Body::Empty), ); } else { trace!("Keep-alive connection timeout"); diff --git a/src/response.rs b/src/response.rs index bc730718..ae68189d 100644 --- a/src/response.rs +++ b/src/response.rs @@ -12,7 +12,7 @@ use http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode, Version}; use serde::Serialize; use serde_json; -use body::{Body, BodyStream, MessageBody}; +use body::{Body, BodyStream, MessageBody, ResponseBody}; use error::Error; use header::{Header, IntoHeaderValue}; use message::{ConnectionType, Head, ResponseHead}; @@ -21,7 +21,7 @@ use message::{ConnectionType, Head, ResponseHead}; pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536; /// An HTTP Response -pub struct Response(Box, B); +pub struct Response(Box, ResponseBody); impl Response { /// Create http response builder with specific status. @@ -71,6 +71,15 @@ impl Response { cookies: jar, } } + + /// Convert response to response with body + pub fn into_body(self) -> Response { + let b = match self.1 { + ResponseBody::Body(b) => b, + ResponseBody::Other(b) => b, + }; + Response(self.0, ResponseBody::Other(b)) + } } impl Response { @@ -195,23 +204,26 @@ impl Response { /// Get body os this response #[inline] - pub fn body(&self) -> &B { + pub(crate) fn body(&self) -> &ResponseBody { &self.1 } /// Set a body - pub fn set_body(self, body: B2) -> Response { - Response(self.0, body) + pub(crate) fn set_body(self, body: B2) -> Response { + Response(self.0, ResponseBody::Body(body)) } /// Drop request's body - pub fn drop_body(self) -> Response<()> { - Response(self.0, ()) + pub(crate) fn drop_body(self) -> Response<()> { + Response(self.0, ResponseBody::Body(())) } /// Set a body and return previous body value - pub fn replace_body(self, body: B2) -> (Response, B) { - (Response(self.0, body), self.1) + pub(crate) fn replace_body( + self, + body: B2, + ) -> (Response, ResponseBody) { + (Response(self.0, ResponseBody::Body(body)), self.1) } /// Size of response in bytes, excluding HTTP headers @@ -233,7 +245,10 @@ impl Response { } pub(crate) fn from_parts(parts: ResponseParts) -> Response { - Response(Box::new(InnerResponse::from_parts(parts)), Body::Empty) + Response( + Box::new(InnerResponse::from_parts(parts)), + ResponseBody::Body(Body::Empty), + ) } } @@ -250,7 +265,7 @@ impl fmt::Debug for Response { for (key, val) in self.get_ref().head.headers.iter() { let _ = writeln!(f, " {:?}: {:?}", key, val); } - let _ = writeln!(f, " body: {:?}", self.body().length()); + let _ = writeln!(f, " body: {:?}", self.1.length()); res } } @@ -559,11 +574,9 @@ impl ResponseBuilder { /// /// `ResponseBuilder` can not be used after this call. pub fn message_body(&mut self, body: B) -> Response { - let mut error = if let Some(e) = self.err.take() { - Some(Error::from(e)) - } else { - None - }; + if let Some(e) = self.err.take() { + return Response::from(Error::from(e)).into_body(); + } let mut response = self.response.take().expect("cannot reuse response builder"); if let Some(ref jar) = self.cookies { @@ -572,17 +585,12 @@ impl ResponseBuilder { Ok(val) => { let _ = response.head.headers.append(header::SET_COOKIE, val); } - Err(e) => if error.is_none() { - error = Some(Error::from(e)); - }, + Err(e) => return Response::from(Error::from(e)).into_body(), }; } } - if let Some(error) = error { - response.error = Some(error); - } - Response(response, body) + Response(response, ResponseBody::Body(body)) } #[inline] @@ -812,9 +820,12 @@ impl ResponsePool { ) -> Response { if let Some(mut msg) = pool.0.borrow_mut().pop_front() { msg.head.status = status; - Response(msg, body) + Response(msg, ResponseBody::Body(body)) } else { - Response(Box::new(InnerResponse::new(status, pool)), body) + Response( + Box::new(InnerResponse::new(status, pool)), + ResponseBody::Body(body), + ) } } @@ -971,10 +982,7 @@ mod tests { let resp = Response::build(StatusCode::OK).json(vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("application/json")); - assert_eq!( - *resp.body(), - Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")) - ); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); } #[test] @@ -984,10 +992,7 @@ mod tests { .json(vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("text/json")); - assert_eq!( - *resp.body(), - Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")) - ); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); } #[test] @@ -995,10 +1000,7 @@ mod tests { let resp = Response::build(StatusCode::OK).json2(&vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("application/json")); - assert_eq!( - *resp.body(), - Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")) - ); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); } #[test] @@ -1008,10 +1010,7 @@ mod tests { .json2(&vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("text/json")); - assert_eq!( - *resp.body(), - Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")) - ); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); } #[test] diff --git a/src/service.rs b/src/service.rs index f934305c..6a31b6bb 100644 --- a/src/service.rs +++ b/src/service.rs @@ -6,7 +6,7 @@ use futures::future::{ok, Either, FutureResult}; use futures::{Async, Future, Poll, Sink}; use tokio_io::{AsyncRead, AsyncWrite}; -use body::{BodyLength, MessageBody}; +use body::{BodyLength, MessageBody, ResponseBody}; use error::{Error, ResponseError}; use h1::{Codec, Message}; use response::Response; @@ -174,7 +174,7 @@ where pub struct SendResponseFut { res: Option, BodyLength)>>, - body: Option, + body: Option>, framed: Option>, } diff --git a/tests/test_server.rs b/tests/test_server.rs index a153e584..300b38a8 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -445,3 +445,25 @@ fn test_body_chunked_implicit() { let bytes = srv.block_on(response.body()).unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } + +#[test] +fn test_response_http_error_handling() { + let mut srv = test::TestServer::with_factory(|| { + h1::H1Service::new(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + }).map(|_| ()) + }); + + let req = srv.get().finish().unwrap(); + let response = srv.send_request(req).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +}