diff --git a/src/error.rs b/src/error.rs index f3b1507d..13a54608 100644 --- a/src/error.rs +++ b/src/error.rs @@ -344,6 +344,31 @@ impl ErrorResponse for WsHandshakeError { } } +/// A set of errors that can occur during parsing urlencoded payloads +#[derive(Fail, Debug, PartialEq)] +pub enum UrlencodedError { + /// Can not decode chunked transfer encoding + #[fail(display="Can not decode chunked transfer encoding")] + Chunked, + /// Payload size is bigger than 256k + #[fail(display="Payload size is bigger than 256k")] + Overflow, + /// Payload size is now known + #[fail(display="Payload size is now known")] + UnknownLength, + /// Content type error + #[fail(display="Content type error")] + ContentType, +} + +/// Return `BadRequest` for `UrlencodedError` +impl ErrorResponse for UrlencodedError { + + fn error_response(&self) -> HttpResponse { + HttpResponse::new(StatusCode::BAD_REQUEST, Body::Empty) + } +} + #[cfg(test)] mod tests { use std::error::Error as StdError; diff --git a/src/httprequest.rs b/src/httprequest.rs index 25abc16c..f2c1d601 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -12,7 +12,8 @@ use {Cookie, HttpRange}; use recognizer::Params; use payload::Payload; use multipart::Multipart; -use error::{ParseError, PayloadError, MultipartError, CookieParseError, HttpRangeError}; +use error::{ParseError, PayloadError, + MultipartError, CookieParseError, HttpRangeError, UrlencodedError}; struct HttpMessage { version: Version, @@ -306,36 +307,41 @@ impl HttpRequest { /// * content type is not `application/x-www-form-urlencoded` /// * transfer encoding is `chunked`. /// * content-length is greater than 256k - pub fn urlencoded(&self, payload: Payload) -> Result { - if let Ok(chunked) = self.chunked() { - if chunked { - return Err(payload) - } + pub fn urlencoded(&mut self) -> Result { + if let Ok(true) = self.chunked() { + return Err(UrlencodedError::Chunked) } if let Some(len) = self.headers().get(header::CONTENT_LENGTH) { if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::() { if len > 262_144 { - return Err(payload) + return Err(UrlencodedError::Overflow) } } else { - return Err(payload) + return Err(UrlencodedError::UnknownLength) } } else { - return Err(payload) + return Err(UrlencodedError::UnknownLength) } } - if let Some(content_type) = self.0.headers.get(header::CONTENT_TYPE) { + // check content type + let t = if let Some(content_type) = self.0.headers.get(header::CONTENT_TYPE) { if let Ok(content_type) = content_type.to_str() { - if content_type.to_lowercase() == "application/x-www-form-urlencoded" { - return Ok(UrlEncoded{pl: payload, body: BytesMut::new()}) - } + content_type.to_lowercase() == "application/x-www-form-urlencoded" + } else { + false } - } + } else { + false + }; - Err(payload) + if t { + Ok(UrlEncoded{pl: self.take_payload(), body: BytesMut::new()}) + } else { + Err(UrlencodedError::ContentType) + } } } @@ -409,43 +415,39 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert(header::TRANSFER_ENCODING, header::HeaderValue::from_static("chunked")); - let req = HttpRequest::new( + let mut req = HttpRequest::new( Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new(), Payload::empty()); - let (_, payload) = Payload::new(false); - assert!(req.urlencoded(payload).is_err()); + assert_eq!(req.urlencoded().err().unwrap(), UrlencodedError::Chunked); let mut headers = HeaderMap::new(); headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/x-www-form-urlencoded")); headers.insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("xxxx")); - let req = HttpRequest::new( + let mut req = HttpRequest::new( Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new(), Payload::empty()); - let (_, payload) = Payload::new(false); - assert!(req.urlencoded(payload).is_err()); + assert_eq!(req.urlencoded().err().unwrap(), UrlencodedError::UnknownLength); let mut headers = HeaderMap::new(); headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/x-www-form-urlencoded")); headers.insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("1000000")); - let req = HttpRequest::new( + let mut req = HttpRequest::new( Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new(), Payload::empty()); - let (_, payload) = Payload::new(false); - assert!(req.urlencoded(payload).is_err()); + assert_eq!(req.urlencoded().err().unwrap(), UrlencodedError::Overflow); let mut headers = HeaderMap::new(); headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("text/plain")); headers.insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("10")); - let req = HttpRequest::new( + let mut req = HttpRequest::new( Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new(), Payload::empty()); - let (_, payload) = Payload::new(false); - assert!(req.urlencoded(payload).is_err()); + assert_eq!(req.urlencoded().err().unwrap(), UrlencodedError::ContentType); } }