diff --git a/awc/src/error.rs b/awc/src/error.rs index bbfd9b97..20654bdf 100644 --- a/awc/src/error.rs +++ b/awc/src/error.rs @@ -54,9 +54,6 @@ impl From for WsClientError { /// A set of errors that can occur during parsing json payloads #[derive(Debug, Display, From)] pub enum JsonPayloadError { - /// Payload size is bigger than allowed. (default: 32kB) - #[display(fmt = "Json payload size is bigger than allowed.")] - Overflow, /// Content type error #[display(fmt = "Content type error")] ContentType, diff --git a/awc/src/response.rs b/awc/src/response.rs index b9173520..a4719a9a 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -1,8 +1,9 @@ use std::cell::{Ref, RefMut}; use std::fmt; +use std::marker::PhantomData; use bytes::{Bytes, BytesMut}; -use futures::{Future, Poll, Stream}; +use futures::{Async, Future, Poll, Stream}; use actix_http::cookie::Cookie; use actix_http::error::{CookieParseError, PayloadError}; @@ -103,7 +104,7 @@ impl ClientResponse { impl ClientResponse where - S: Stream + 'static, + S: Stream, { /// Loads http response's body. pub fn body(&mut self) -> MessageBody { @@ -147,16 +148,14 @@ impl fmt::Debug for ClientResponse { /// Future that resolves to a complete http message body. pub struct MessageBody { - limit: usize, length: Option, - stream: Option>, err: Option, - fut: Option>>, + fut: Option>, } impl MessageBody where - S: Stream + 'static, + S: Stream, { /// Create `MessageBody` for request. pub fn new(res: &mut ClientResponse) -> MessageBody { @@ -174,24 +173,22 @@ where } MessageBody { - limit: 262_144, length: len, - stream: Some(res.take_payload()), - fut: None, err: None, + fut: Some(ReadBody::new(res.take_payload(), 262_144)), } } /// Change max size of payload. By default max size is 256Kb pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; + if let Some(ref mut fut) = self.fut { + fut.limit = limit; + } self } fn err(e: PayloadError) -> Self { MessageBody { - stream: None, - limit: 262_144, fut: None, err: Some(e), length: None, @@ -201,44 +198,23 @@ where impl Future for MessageBody where - S: Stream + 'static, + S: Stream, { type Item = Bytes; type Error = PayloadError; fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut { - return fut.poll(); - } - if let Some(err) = self.err.take() { return Err(err); } if let Some(len) = self.length.take() { - if len > self.limit { + if len > self.fut.as_ref().unwrap().limit { return Err(PayloadError::Overflow); } } - // future - let limit = self.limit; - self.fut = Some(Box::new( - self.stream - .take() - .expect("Can not be used second time") - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(PayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .map(|body| body.freeze()), - )); - self.poll() + self.fut.as_mut().unwrap().poll() } } @@ -249,16 +225,15 @@ where /// * content type is not `application/json` /// * content length is greater than 64k pub struct JsonBody { - limit: usize, length: Option, - stream: Payload, err: Option, - fut: Option>>, + fut: Option>, + _t: PhantomData, } impl JsonBody where - S: Stream + 'static, + S: Stream, U: DeserializeOwned, { /// Create `JsonBody` for request. @@ -271,11 +246,10 @@ where }; if !json { return JsonBody { - limit: 65536, length: None, - stream: Payload::None, fut: None, err: Some(JsonPayloadError::ContentType), + _t: PhantomData, }; } @@ -289,58 +263,84 @@ where } JsonBody { - limit: 65536, length: len, - stream: req.take_payload(), - fut: None, err: None, + fut: Some(ReadBody::new(req.take_payload(), 65536)), + _t: PhantomData, } } /// Change max size of payload. By default max size is 64Kb pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; + if let Some(ref mut fut) = self.fut { + fut.limit = limit; + } self } } impl Future for JsonBody where - T: Stream + 'static, + T: Stream, U: DeserializeOwned + 'static, { type Item = U; type Error = JsonPayloadError; fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut { - return fut.poll(); - } - if let Some(err) = self.err.take() { return Err(err); } - let limit = self.limit; if let Some(len) = self.length.take() { - if len > limit { - return Err(JsonPayloadError::Overflow); + if len > self.fut.as_ref().unwrap().limit { + return Err(JsonPayloadError::Payload(PayloadError::Overflow)); } } - let fut = std::mem::replace(&mut self.stream, Payload::None) - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(JsonPayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) + let body = futures::try_ready!(self.fut.as_mut().unwrap().poll()); + Ok(Async::Ready(serde_json::from_slice::(&body)?)) + } +} + +struct ReadBody { + stream: Payload, + buf: BytesMut, + limit: usize, +} + +impl ReadBody { + fn new(stream: Payload, limit: usize) -> Self { + Self { + stream, + buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)), + limit, + } + } +} + +impl Future for ReadBody +where + S: Stream, +{ + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll { + loop { + return match self.stream.poll()? { + Async::Ready(Some(chunk)) => { + if (self.buf.len() + chunk.len()) > self.limit { + Err(PayloadError::Overflow) + } else { + self.buf.extend_from_slice(&chunk); + continue; + } } - }) - .and_then(|body| Ok(serde_json::from_slice::(&body)?)); - self.fut = Some(Box::new(fut)); - self.poll() + Async::Ready(None) => Ok(Async::Ready(self.buf.take().freeze())), + Async::NotReady => Ok(Async::NotReady), + }; + } } } @@ -391,8 +391,8 @@ mod tests { fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { match err { - JsonPayloadError::Overflow => match other { - JsonPayloadError::Overflow => true, + JsonPayloadError::Payload(PayloadError::Overflow) => match other { + JsonPayloadError::Payload(PayloadError::Overflow) => true, _ => false, }, JsonPayloadError::ContentType => match other { @@ -430,7 +430,10 @@ mod tests { .finish(); let json = block_on(JsonBody::<_, MyObject>::new(&mut req).limit(100)); - assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); + assert!(json_eq( + json.err().unwrap(), + JsonPayloadError::Payload(PayloadError::Overflow) + )); let mut req = TestResponse::default() .header(