diff --git a/src/extractor.rs b/src/extractor.rs index 6f0e5b334..3ada8d5d5 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,6 +1,7 @@ use std::str; use std::ops::{Deref, DerefMut}; +use mime::Mime; use bytes::Bytes; use serde_urlencoded; use serde::de::{self, DeserializeOwned}; @@ -301,12 +302,20 @@ impl Default for FormConfig { /// ``` impl FromRequest for Bytes { - type Config = (); - type Result = Box>; + type Config = PayloadConfig; + type Result = Either, + Box>>; #[inline] - fn from_request(req: &HttpRequest, _: &Self::Config) -> Self::Result { - Box::new(MessageBody::new(req.clone()).from_err()) + fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { + // check content-type + if let Err(e) = cfg.check_mimetype(req) { + return Either::A(result(Err(e))); + } + + Either::B(Box::new(MessageBody::new(req.clone()) + .limit(cfg.limit) + .from_err())) } } @@ -328,12 +337,18 @@ impl FromRequest for Bytes /// ``` impl FromRequest for String { - type Config = (); + type Config = PayloadConfig; type Result = Either, Box>>; #[inline] - fn from_request(req: &HttpRequest, _: &Self::Config) -> Self::Result { + fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { + // check content-type + if let Err(e) = cfg.check_mimetype(req) { + return Either::A(result(Err(e))); + } + + // check charset let encoding = match req.encoding() { Err(_) => return Either::A( result(Err(ErrorBadRequest("Unknown request charset")))), @@ -342,6 +357,7 @@ impl FromRequest for String Either::B(Box::new( MessageBody::new(req.clone()) + .limit(cfg.limit) .from_err() .and_then(move |body| { let enc: *const Encoding = encoding as *const Encoding; @@ -357,9 +373,57 @@ impl FromRequest for String } } +/// Payload configuration for request's payload. +pub struct PayloadConfig { + limit: usize, + mimetype: Option, +} + +impl PayloadConfig { + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(&mut self, limit: usize) -> &mut Self { + self.limit = limit; + self + } + + /// Set required mime-type of the request. By default mime type is not enforced. + pub fn mimetype(&mut self, mt: Mime) -> &mut Self { + self.mimetype = Some(mt); + self + } + + fn check_mimetype(&self, req: &HttpRequest) -> Result<(), Error> { + // check content-type + if let Some(ref mt) = self.mimetype { + match req.mime_type() { + Ok(Some(ref req_mt)) => { + if mt != req_mt { + return Err(ErrorBadRequest("Unexpected Content-Type")); + } + }, + Ok(None) => { + return Err(ErrorBadRequest("Content-Type is expected")); + }, + Err(err) => { + return Err(err.into()); + }, + } + } + Ok(()) + } +} + +impl Default for PayloadConfig { + fn default() -> Self { + PayloadConfig{limit: 262_144, mimetype: None} + } +} + #[cfg(test)] mod tests { use super::*; + use mime; use bytes::Bytes; use futures::{Async, Future}; use http::header; @@ -375,10 +439,11 @@ mod tests { #[test] fn test_bytes() { + let cfg = PayloadConfig::default(); let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); - match Bytes::from_request(&req, &()).poll().unwrap() { + match Bytes::from_request(&req, &cfg).poll().unwrap() { Async::Ready(s) => { assert_eq!(s, Bytes::from_static(b"hello=world")); }, @@ -388,10 +453,11 @@ mod tests { #[test] fn test_string() { + let cfg = PayloadConfig::default(); let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); - match String::from_request(&req, &()).poll().unwrap() { + match String::from_request(&req, &cfg).poll().unwrap() { Async::Ready(s) => { assert_eq!(s, "hello=world"); }, @@ -417,6 +483,21 @@ mod tests { } } + #[test] + fn test_payload_config() { + let req = HttpRequest::default(); + let mut cfg = PayloadConfig::default(); + cfg.mimetype(mime::APPLICATION_JSON); + assert!(cfg.check_mimetype(&req).is_err()); + + let req = TestRequest::with_header( + header::CONTENT_TYPE, "application/x-www-form-urlencoded").finish(); + assert!(cfg.check_mimetype(&req).is_err()); + + let req = TestRequest::with_header(header::CONTENT_TYPE, "application/json").finish(); + assert!(cfg.check_mimetype(&req).is_ok()); + } + #[derive(Deserialize)] struct MyStruct { key: String, diff --git a/src/lib.rs b/src/lib.rs index bf536e2d0..3e21b3973 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -180,7 +180,7 @@ pub mod dev { pub use json::{JsonBody, JsonConfig}; pub use info::ConnectionInfo; pub use handler::{Handler, Reply}; - pub use extractor::{FormConfig}; + pub use extractor::{FormConfig, PayloadConfig}; pub use route::Route; pub use router::{Router, Resource, ResourceType}; pub use resource::ResourceHandler;