diff --git a/src/lib.rs b/src/lib.rs index dc0493a81..18cf93f43 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ pub mod dev { pub use crate::service::{ HttpServiceFactory, ServiceFromRequest, ServiceRequest, ServiceResponse, }; + pub use crate::types::json::JsonBody; pub use actix_http::body::{Body, BodyLength, MessageBody, ResponseBody}; pub use actix_http::dev::ResponseBuilder as HttpResponseBuilder; @@ -95,8 +96,7 @@ pub mod web { pub use crate::data::{Data, RouteData}; pub use crate::request::HttpRequest; - pub use crate::types::{Form, Json, Path, Payload, Query}; - pub use crate::types::{FormConfig, JsonConfig, PayloadConfig}; + pub use crate::types::*; /// Create resource for a specific path. /// diff --git a/src/types/json.rs b/src/types/json.rs index 92b7f20f6..74ee5eb21 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -3,16 +3,15 @@ use std::rc::Rc; use std::{fmt, ops}; -use bytes::Bytes; -use futures::{Future, Stream}; +use bytes::{Bytes, BytesMut}; +use futures::{Future, Poll, Stream}; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json; -use actix_http::dev::JsonBody; -use actix_http::error::{Error, JsonPayloadError}; -use actix_http::http::StatusCode; -use actix_http::Response; +use actix_http::error::{Error, JsonPayloadError, PayloadError}; +use actix_http::http::{header::CONTENT_LENGTH, StatusCode}; +use actix_http::{HttpMessage, Payload, Response}; use crate::extract::FromRequest; use crate::request::HttpRequest; @@ -257,3 +256,187 @@ impl Default for JsonConfig { } } } + +/// Request's payload json parser, it resolves to a deserialized `T` value. +/// This future could be used with `ServiceRequest` and `ServiceFromRequest`. +/// +/// Returns error: +/// +/// * content type is not `application/json` +/// * content length is greater than 256k +pub struct JsonBody { + limit: usize, + length: Option, + stream: Payload, + err: Option, + fut: Option>>, +} + +impl JsonBody +where + T: HttpMessage, + T::Stream: Stream + 'static, + U: DeserializeOwned + 'static, +{ + /// Create `JsonBody` for request. + pub fn new(req: &mut T) -> Self { + // check content-type + let json = if let Ok(Some(mime)) = req.mime_type() { + mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) + } else { + false + }; + if !json { + return JsonBody { + limit: 262_144, + length: None, + stream: Payload::None, + fut: None, + err: Some(JsonPayloadError::ContentType), + }; + } + + let mut len = None; + if let Some(l) = req.headers().get(CONTENT_LENGTH) { + if let Ok(s) = l.to_str() { + if let Ok(l) = s.parse::() { + len = Some(l) + } + } + } + + JsonBody { + limit: 262_144, + length: len, + stream: req.take_payload(), + fut: None, + err: None, + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } +} + +impl Future for JsonBody +where + T: HttpMessage, + T::Stream: Stream + 'static, + U: DeserializeOwned + 'static, +{ + type Item = U; + type Error = JsonPayloadError; + + fn poll(&mut self) -> Poll { + if let Some(ref mut fut) = self.fut { + return fut.poll(); + } + + if let Some(err) = self.err.take() { + return Err(err); + } + + let limit = self.limit; + if let Some(len) = self.length.take() { + if len > limit { + return Err(JsonPayloadError::Overflow); + } + } + + let fut = std::mem::replace(&mut self.stream, Payload::None) + .from_err() + .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { + if (body.len() + chunk.len()) > limit { + Err(JsonPayloadError::Overflow) + } else { + body.extend_from_slice(&chunk); + Ok(body) + } + }) + .and_then(|body| Ok(serde_json::from_slice::(&body)?)); + self.fut = Some(Box::new(fut)); + self.poll() + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use serde_derive::{Deserialize, Serialize}; + + use super::*; + use crate::http::header; + use crate::test::{block_on, TestRequest}; + + fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { + match err { + JsonPayloadError::Overflow => match other { + JsonPayloadError::Overflow => true, + _ => false, + }, + JsonPayloadError::ContentType => match other { + JsonPayloadError::ContentType => true, + _ => false, + }, + _ => false, + } + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct MyObject { + name: String, + } + + #[test] + fn test_json_body() { + let mut req = TestRequest::default().to_request(); + let json = block_on(req.json::()); + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/text"), + ) + .to_request(); + let json = block_on(req.json::()); + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("10000"), + ) + .to_request(); + + let json = block_on(req.json::().limit(100)); + assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); + + let mut req = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .to_request(); + + let json = block_on(req.json::()); + assert_eq!( + json.ok().unwrap(), + MyObject { + name: "test".to_owned() + } + ); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b5f8de603..2fc3ca938 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,7 +1,7 @@ //! Helper types mod form; -mod json; +pub(crate) mod json; mod path; mod payload; mod query;