From 4309d9b88c90fb2448af127384d9b3cb8b11b932 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 28 Mar 2019 05:04:39 -0700 Subject: [PATCH] port multipart support --- Cargo.toml | 1 + actix-http/src/error.rs | 14 + src/error.rs | 36 ++ src/test.rs | 7 +- src/types/mod.rs | 2 + src/types/multipart.rs | 1137 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 1194 insertions(+), 3 deletions(-) create mode 100644 src/types/multipart.rs diff --git a/Cargo.toml b/Cargo.toml index 363989bf7..63a607e2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,6 +82,7 @@ derive_more = "0.14" encoding = "0.2" futures = "0.1" hashbrown = "0.1.8" +httparse = "1.3" log = "0.4" mime = "0.3" net2 = "0.2.33" diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 23e6728d8..a026fe9db 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -905,6 +905,20 @@ mod tests { assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } + #[test] + fn test_payload_error() { + let err: PayloadError = + io::Error::new(io::ErrorKind::Other, "ParseError").into(); + assert_eq!(format!("{}", err), "ParseError"); + assert_eq!(format!("{}", err.cause().unwrap()), "ParseError"); + + let err = PayloadError::Incomplete; + assert_eq!( + format!("{}", err), + "A payload reached EOF, but is not complete." + ); + } + macro_rules! from { ($from:expr => $error:pat) => { match ParseError::from($from) { diff --git a/src/error.rs b/src/error.rs index fc0f9fdf3..f7610d509 100644 --- a/src/error.rs +++ b/src/error.rs @@ -143,6 +143,36 @@ impl ResponseError for ReadlinesError { } } +/// A set of errors that can occur during parsing multipart streams +#[derive(Debug, Display, From)] +pub enum MultipartError { + /// Content-Type header is not found + #[display(fmt = "No Content-type header found")] + NoContentType, + /// Can not parse Content-Type header + #[display(fmt = "Can not parse Content-Type header")] + ParseContentType, + /// Multipart boundary is not found + #[display(fmt = "Multipart boundary is not found")] + Boundary, + /// Multipart stream is incomplete + #[display(fmt = "Multipart stream is incomplete")] + Incomplete, + /// Error during field parsing + #[display(fmt = "{}", _0)] + Parse(ParseError), + /// Payload error + #[display(fmt = "{}", _0)] + Payload(PayloadError), +} + +/// Return `BadRequest` for `MultipartError` +impl ResponseError for MultipartError { + fn error_response(&self) -> HttpResponse { + HttpResponse::new(StatusCode::BAD_REQUEST) + } +} + #[cfg(test)] mod tests { use super::*; @@ -172,4 +202,10 @@ mod tests { let resp: HttpResponse = ReadlinesError::EncodingError.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } + + #[test] + fn test_multipart_error() { + let resp: HttpResponse = MultipartError::Boundary.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } } diff --git a/src/test.rs b/src/test.rs index 9e1f01f90..4fdb6ea38 100644 --- a/src/test.rs +++ b/src/test.rs @@ -49,11 +49,12 @@ where /// /// Note that this function is intended to be used only for testing purpose. /// This function panics on nested call. -pub fn run_on(f: F) -> Result +pub fn run_on(f: F) -> R where - F: Fn() -> Result, + F: Fn() -> R, { - RT.with(move |rt| rt.borrow_mut().block_on(lazy(f))) + RT.with(move |rt| rt.borrow_mut().block_on(lazy(|| Ok::<_, ()>(f())))) + .unwrap() } pub fn ok_service() -> impl Service< diff --git a/src/types/mod.rs b/src/types/mod.rs index 30ee73091..c9aed94f9 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod form; pub(crate) mod json; +mod multipart; mod path; pub(crate) mod payload; mod query; @@ -9,6 +10,7 @@ pub(crate) mod readlines; pub use self::form::{Form, FormConfig}; pub use self::json::{Json, JsonConfig}; +pub use self::multipart::{Multipart, MultipartItem}; pub use self::path::Path; pub use self::payload::{Payload, PayloadConfig}; pub use self::query::Query; diff --git a/src/types/multipart.rs b/src/types/multipart.rs new file mode 100644 index 000000000..d66053ff5 --- /dev/null +++ b/src/types/multipart.rs @@ -0,0 +1,1137 @@ +//! Multipart payload support +use std::cell::{RefCell, UnsafeCell}; +use std::collections::VecDeque; +use std::marker::PhantomData; +use std::rc::Rc; +use std::{cmp, fmt}; + +use bytes::{Bytes, BytesMut}; +use futures::task::{current as current_task, Task}; +use futures::{Async, Poll, Stream}; +use httparse; +use mime; + +use crate::error::{Error, MultipartError, ParseError, PayloadError}; +use crate::extract::FromRequest; +use crate::http::header::{ + self, ContentDisposition, HeaderMap, HeaderName, HeaderValue, +}; +use crate::http::HttpTryFrom; +use crate::service::ServiceFromRequest; +use crate::HttpMessage; + +const MAX_HEADERS: usize = 32; + +/// The server-side implementation of `multipart/form-data` requests. +/// +/// This will parse the incoming stream into `MultipartItem` instances via its +/// Stream implementation. +/// `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` +/// is used for nested multipart streams. +pub struct Multipart { + safety: Safety, + error: Option, + inner: Option>>, +} + +/// Multipart item +pub enum MultipartItem { + /// Multipart field + Field(Field), + /// Nested multipart stream + Nested(Multipart), +} + +/// Get request's payload as multipart stream +/// +/// Content-type: multipart/form-data; +/// +/// ## Server example +/// +/// ```rust +/// # use futures::{Future, Stream}; +/// # use futures::future::{ok, result, Either}; +/// use actix_web::{web, HttpResponse, Error}; +/// +/// fn index(payload: web::Multipart) -> impl Future { +/// payload.from_err() // <- get multipart stream for current request +/// .and_then(|item| match item { // <- iterate over multipart items +/// web::MultipartItem::Field(field) => { +/// // Field in turn is stream of *Bytes* object +/// Either::A(field.from_err() +/// .fold((), |_, chunk| { +/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk)); +/// Ok::<_, Error>(()) +/// })) +/// }, +/// web::MultipartItem::Nested(mp) => { +/// // Or item could be nested Multipart stream +/// Either::B(ok(())) +/// } +/// }) +/// .fold((), |_, _| Ok::<_, Error>(())) +/// .map(|_| HttpResponse::Ok().into()) +/// } +/// # fn main() {} +/// ``` +impl

FromRequest

for Multipart +where + P: Stream + 'static, +{ + type Error = Error; + type Future = Result; + + #[inline] + fn from_request(req: &mut ServiceFromRequest

) -> Self::Future { + let pl = req.take_payload(); + Ok(Multipart::new(req.headers(), pl)) + } +} + +enum InnerMultipartItem { + None, + Field(Rc>), + Multipart(Rc>), +} + +#[derive(PartialEq, Debug)] +enum InnerState { + /// Stream eof + Eof, + /// Skip data until first boundary + FirstBoundary, + /// Reading boundary + Boundary, + /// Reading Headers, + Headers, +} + +struct InnerMultipart { + payload: PayloadRef, + boundary: String, + state: InnerState, + item: InnerMultipartItem, +} + +impl Multipart { + /// Create multipart instance for boundary. + pub fn new(headers: &HeaderMap, stream: S) -> Multipart + where + S: Stream + 'static, + { + match Self::boundary(headers) { + Ok(boundary) => Multipart { + error: None, + safety: Safety::new(), + inner: Some(Rc::new(RefCell::new(InnerMultipart { + boundary, + payload: PayloadRef::new(PayloadBuffer::new(stream)), + state: InnerState::FirstBoundary, + item: InnerMultipartItem::None, + }))), + }, + Err(err) => Multipart { + error: Some(err), + safety: Safety::new(), + inner: None, + }, + } + } + + /// Extract boundary info from headers. + fn boundary(headers: &HeaderMap) -> Result { + if let Some(content_type) = headers.get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + if let Ok(ct) = content_type.parse::() { + if let Some(boundary) = ct.get_param(mime::BOUNDARY) { + Ok(boundary.as_str().to_owned()) + } else { + Err(MultipartError::Boundary) + } + } else { + Err(MultipartError::ParseContentType) + } + } else { + Err(MultipartError::ParseContentType) + } + } else { + Err(MultipartError::NoContentType) + } + } +} + +impl Stream for Multipart { + type Item = MultipartItem; + type Error = MultipartError; + + fn poll(&mut self) -> Poll, Self::Error> { + if let Some(err) = self.error.take() { + Err(err) + } else if self.safety.current() { + self.inner.as_mut().unwrap().borrow_mut().poll(&self.safety) + } else { + Ok(Async::NotReady) + } + } +} + +impl InnerMultipart { + fn read_headers(payload: &mut PayloadBuffer) -> Poll { + match payload.read_until(b"\r\n\r\n")? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(bytes)) => { + let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS]; + match httparse::parse_headers(&bytes, &mut hdrs) { + Ok(httparse::Status::Complete((_, hdrs))) => { + // convert headers + let mut headers = HeaderMap::with_capacity(hdrs.len()); + for h in hdrs { + if let Ok(name) = HeaderName::try_from(h.name) { + if let Ok(value) = HeaderValue::try_from(h.value) { + headers.append(name, value); + } else { + return Err(ParseError::Header.into()); + } + } else { + return Err(ParseError::Header.into()); + } + } + Ok(Async::Ready(headers)) + } + Ok(httparse::Status::Partial) => Err(ParseError::Header.into()), + Err(err) => Err(ParseError::from(err).into()), + } + } + } + } + + fn read_boundary( + payload: &mut PayloadBuffer, + boundary: &str, + ) -> Poll { + // TODO: need to read epilogue + match payload.readline()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(chunk)) => { + if chunk.len() == boundary.len() + 4 + && &chunk[..2] == b"--" + && &chunk[2..boundary.len() + 2] == boundary.as_bytes() + { + Ok(Async::Ready(false)) + } else if chunk.len() == boundary.len() + 6 + && &chunk[..2] == b"--" + && &chunk[2..boundary.len() + 2] == boundary.as_bytes() + && &chunk[boundary.len() + 2..boundary.len() + 4] == b"--" + { + Ok(Async::Ready(true)) + } else { + Err(MultipartError::Boundary) + } + } + } + } + + fn skip_until_boundary( + payload: &mut PayloadBuffer, + boundary: &str, + ) -> Poll { + let mut eof = false; + loop { + match payload.readline()? { + Async::Ready(Some(chunk)) => { + if chunk.is_empty() { + //ValueError("Could not find starting boundary %r" + //% (self._boundary)) + } + if chunk.len() < boundary.len() { + continue; + } + if &chunk[..2] == b"--" + && &chunk[2..chunk.len() - 2] == boundary.as_bytes() + { + break; + } else { + if chunk.len() < boundary.len() + 2 { + continue; + } + let b: &[u8] = boundary.as_ref(); + if &chunk[..boundary.len()] == b + && &chunk[boundary.len()..boundary.len() + 2] == b"--" + { + eof = true; + break; + } + } + } + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(None) => return Err(MultipartError::Incomplete), + } + } + Ok(Async::Ready(eof)) + } + + fn poll(&mut self, safety: &Safety) -> Poll, MultipartError> { + if self.state == InnerState::Eof { + Ok(Async::Ready(None)) + } else { + // release field + loop { + // Nested multipart streams of fields has to be consumed + // before switching to next + if safety.current() { + let stop = match self.item { + InnerMultipartItem::Field(ref mut field) => { + match field.borrow_mut().poll(safety)? { + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(Some(_)) => continue, + Async::Ready(None) => true, + } + } + InnerMultipartItem::Multipart(ref mut multipart) => { + match multipart.borrow_mut().poll(safety)? { + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(Some(_)) => continue, + Async::Ready(None) => true, + } + } + _ => false, + }; + if stop { + self.item = InnerMultipartItem::None; + } + if let InnerMultipartItem::None = self.item { + break; + } + } + } + + let headers = if let Some(payload) = self.payload.get_mut(safety) { + match self.state { + // read until first boundary + InnerState::FirstBoundary => { + match InnerMultipart::skip_until_boundary( + payload, + &self.boundary, + )? { + Async::Ready(eof) => { + if eof { + self.state = InnerState::Eof; + return Ok(Async::Ready(None)); + } else { + self.state = InnerState::Headers; + } + } + Async::NotReady => return Ok(Async::NotReady), + } + } + // read boundary + InnerState::Boundary => { + match InnerMultipart::read_boundary(payload, &self.boundary)? { + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(eof) => { + if eof { + self.state = InnerState::Eof; + return Ok(Async::Ready(None)); + } else { + self.state = InnerState::Headers; + } + } + } + } + _ => (), + } + + // read field headers for next field + if self.state == InnerState::Headers { + if let Async::Ready(headers) = InnerMultipart::read_headers(payload)? + { + self.state = InnerState::Boundary; + headers + } else { + return Ok(Async::NotReady); + } + } else { + unreachable!() + } + } else { + log::debug!("NotReady: field is in flight"); + return Ok(Async::NotReady); + }; + + // content type + let mut mt = mime::APPLICATION_OCTET_STREAM; + if let Some(content_type) = headers.get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + if let Ok(ct) = content_type.parse::() { + mt = ct; + } + } + } + + self.state = InnerState::Boundary; + + // nested multipart stream + if mt.type_() == mime::MULTIPART { + let inner = if let Some(boundary) = mt.get_param(mime::BOUNDARY) { + Rc::new(RefCell::new(InnerMultipart { + payload: self.payload.clone(), + boundary: boundary.as_str().to_owned(), + state: InnerState::FirstBoundary, + item: InnerMultipartItem::None, + })) + } else { + return Err(MultipartError::Boundary); + }; + + self.item = InnerMultipartItem::Multipart(Rc::clone(&inner)); + + Ok(Async::Ready(Some(MultipartItem::Nested(Multipart { + safety: safety.clone(), + error: None, + inner: Some(inner), + })))) + } else { + let field = Rc::new(RefCell::new(InnerField::new( + self.payload.clone(), + self.boundary.clone(), + &headers, + )?)); + self.item = InnerMultipartItem::Field(Rc::clone(&field)); + + Ok(Async::Ready(Some(MultipartItem::Field(Field::new( + safety.clone(), + headers, + mt, + field, + ))))) + } + } + } +} + +impl Drop for InnerMultipart { + fn drop(&mut self) { + // InnerMultipartItem::Field has to be dropped first because of Safety. + self.item = InnerMultipartItem::None; + } +} + +/// A single field in a multipart stream +pub struct Field { + ct: mime::Mime, + headers: HeaderMap, + inner: Rc>, + safety: Safety, +} + +impl Field { + fn new( + safety: Safety, + headers: HeaderMap, + ct: mime::Mime, + inner: Rc>, + ) -> Self { + Field { + ct, + headers, + inner, + safety, + } + } + + /// Get a map of headers + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Get the content type of the field + pub fn content_type(&self) -> &mime::Mime { + &self.ct + } + + /// Get the content disposition of the field, if it exists + pub fn content_disposition(&self) -> Option { + // RFC 7578: 'Each part MUST contain a Content-Disposition header field + // where the disposition type is "form-data".' + if let Some(content_disposition) = self.headers.get(header::CONTENT_DISPOSITION) + { + ContentDisposition::from_raw(content_disposition).ok() + } else { + None + } + } +} + +impl Stream for Field { + type Item = Bytes; + type Error = MultipartError; + + fn poll(&mut self) -> Poll, Self::Error> { + if self.safety.current() { + self.inner.borrow_mut().poll(&self.safety) + } else { + Ok(Async::NotReady) + } + } +} + +impl fmt::Debug for Field { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "\nMultipartField: {}", self.ct)?; + writeln!(f, " boundary: {}", self.inner.borrow().boundary)?; + writeln!(f, " headers:")?; + for (key, val) in self.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +struct InnerField { + payload: Option, + boundary: String, + eof: bool, + length: Option, +} + +impl InnerField { + fn new( + payload: PayloadRef, + boundary: String, + headers: &HeaderMap, + ) -> Result { + let len = if let Some(len) = headers.get(header::CONTENT_LENGTH) { + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + Some(len) + } else { + return Err(PayloadError::Incomplete(None)); + } + } else { + return Err(PayloadError::Incomplete(None)); + } + } else { + None + }; + + Ok(InnerField { + boundary, + payload: Some(payload), + eof: false, + length: len, + }) + } + + /// Reads body part content chunk of the specified size. + /// The body part must has `Content-Length` header with proper value. + fn read_len( + payload: &mut PayloadBuffer, + size: &mut u64, + ) -> Poll, MultipartError> { + if *size == 0 { + Ok(Async::Ready(None)) + } else { + match payload.readany() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(None)) => Err(MultipartError::Incomplete), + Ok(Async::Ready(Some(mut chunk))) => { + let len = cmp::min(chunk.len() as u64, *size); + *size -= len; + let ch = chunk.split_to(len as usize); + if !chunk.is_empty() { + payload.unprocessed(chunk); + } + Ok(Async::Ready(Some(ch))) + } + Err(err) => Err(err.into()), + } + } + } + + /// Reads content chunk of body part with unknown length. + /// The `Content-Length` header for body part is not necessary. + fn read_stream( + payload: &mut PayloadBuffer, + boundary: &str, + ) -> Poll, MultipartError> { + match payload.read_until(b"\r")? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(mut chunk)) => { + if chunk.len() == 1 { + payload.unprocessed(chunk); + match payload.read_exact(boundary.len() + 4)? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(mut chunk)) => { + if &chunk[..2] == b"\r\n" + && &chunk[2..4] == b"--" + && &chunk[4..] == boundary.as_bytes() + { + payload.unprocessed(chunk); + Ok(Async::Ready(None)) + } else { + // \r might be part of data stream + let ch = chunk.split_to(1); + payload.unprocessed(chunk); + Ok(Async::Ready(Some(ch))) + } + } + } + } else { + let to = chunk.len() - 1; + let ch = chunk.split_to(to); + payload.unprocessed(chunk); + Ok(Async::Ready(Some(ch))) + } + } + } + } + + fn poll(&mut self, s: &Safety) -> Poll, MultipartError> { + if self.payload.is_none() { + return Ok(Async::Ready(None)); + } + + let result = if let Some(payload) = self.payload.as_ref().unwrap().get_mut(s) { + let res = if let Some(ref mut len) = self.length { + InnerField::read_len(payload, len)? + } else { + InnerField::read_stream(payload, &self.boundary)? + }; + + match res { + Async::NotReady => Async::NotReady, + Async::Ready(Some(bytes)) => Async::Ready(Some(bytes)), + Async::Ready(None) => { + self.eof = true; + match payload.readline()? { + Async::NotReady => Async::NotReady, + Async::Ready(None) => Async::Ready(None), + Async::Ready(Some(line)) => { + if line.as_ref() != b"\r\n" { + log::warn!("multipart field did not read all the data or it is malformed"); + } + Async::Ready(None) + } + } + } + } + } else { + Async::NotReady + }; + + if Async::Ready(None) == result { + self.payload.take(); + } + Ok(result) + } +} + +struct PayloadRef { + payload: Rc>, +} + +impl PayloadRef { + fn new(payload: PayloadBuffer) -> PayloadRef { + PayloadRef { + payload: Rc::new(payload.into()), + } + } + + fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<&'a mut PayloadBuffer> + where + 'a: 'b, + { + // Unsafe: Invariant is inforced by Safety Safety is used as ref counter, + // only top most ref can have mutable access to payload. + if s.current() { + let payload: &mut PayloadBuffer = unsafe { &mut *self.payload.get() }; + Some(payload) + } else { + None + } + } +} + +impl Clone for PayloadRef { + fn clone(&self) -> PayloadRef { + PayloadRef { + payload: Rc::clone(&self.payload), + } + } +} + +/// Counter. It tracks of number of clones of payloads and give access to +/// payload only to top most task panics if Safety get destroyed and it not top +/// most task. +#[derive(Debug)] +struct Safety { + task: Option, + level: usize, + payload: Rc>, +} + +impl Safety { + fn new() -> Safety { + let payload = Rc::new(PhantomData); + Safety { + task: None, + level: Rc::strong_count(&payload), + payload, + } + } + + fn current(&self) -> bool { + Rc::strong_count(&self.payload) == self.level + } +} + +impl Clone for Safety { + fn clone(&self) -> Safety { + let payload = Rc::clone(&self.payload); + Safety { + task: Some(current_task()), + level: Rc::strong_count(&payload), + payload, + } + } +} + +impl Drop for Safety { + fn drop(&mut self) { + // parent task is dead + if Rc::strong_count(&self.payload) != self.level { + panic!("Safety get dropped but it is not from top-most task"); + } + if let Some(task) = self.task.take() { + task.notify() + } + } +} + +/// Payload buffer +pub struct PayloadBuffer { + len: usize, + items: VecDeque, + stream: Box>, +} + +impl PayloadBuffer { + /// Create new `PayloadBuffer` instance + pub fn new(stream: S) -> Self + where + S: Stream + 'static, + { + PayloadBuffer { + len: 0, + items: VecDeque::new(), + stream: Box::new(stream), + } + } + + #[inline] + fn poll_stream(&mut self) -> Poll { + self.stream.poll().map(|res| match res { + Async::Ready(Some(data)) => { + self.len += data.len(); + self.items.push_back(data); + Async::Ready(true) + } + Async::Ready(None) => Async::Ready(false), + Async::NotReady => Async::NotReady, + }) + } + + /// Read first available chunk of bytes + #[inline] + pub fn readany(&mut self) -> Poll, PayloadError> { + if let Some(data) = self.items.pop_front() { + self.len -= data.len(); + Ok(Async::Ready(Some(data))) + } else { + match self.poll_stream()? { + Async::Ready(true) => self.readany(), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + } + + /// Read exact number of bytes + #[inline] + pub fn read_exact(&mut self, size: usize) -> Poll, PayloadError> { + if size <= self.len { + self.len -= size; + let mut chunk = self.items.pop_front().unwrap(); + if size < chunk.len() { + let buf = chunk.split_to(size); + self.items.push_front(chunk); + Ok(Async::Ready(Some(buf))) + } else if size == chunk.len() { + Ok(Async::Ready(Some(chunk))) + } else { + let mut buf = BytesMut::with_capacity(size); + buf.extend_from_slice(&chunk); + + while buf.len() < size { + let mut chunk = self.items.pop_front().unwrap(); + let rem = cmp::min(size - buf.len(), chunk.len()); + buf.extend_from_slice(&chunk.split_to(rem)); + if !chunk.is_empty() { + self.items.push_front(chunk); + } + } + Ok(Async::Ready(Some(buf.freeze()))) + } + } else { + match self.poll_stream()? { + Async::Ready(true) => self.read_exact(size), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + } + + /// Read until specified ending + pub fn read_until(&mut self, line: &[u8]) -> Poll, PayloadError> { + let mut idx = 0; + let mut num = 0; + let mut offset = 0; + let mut found = false; + let mut length = 0; + + for no in 0..self.items.len() { + { + let chunk = &self.items[no]; + for (pos, ch) in chunk.iter().enumerate() { + if *ch == line[idx] { + idx += 1; + if idx == line.len() { + num = no; + offset = pos + 1; + length += pos + 1; + found = true; + break; + } + } else { + idx = 0 + } + } + if !found { + length += chunk.len() + } + } + + if found { + let mut buf = BytesMut::with_capacity(length); + if num > 0 { + for _ in 0..num { + buf.extend_from_slice(&self.items.pop_front().unwrap()); + } + } + if offset > 0 { + let mut chunk = self.items.pop_front().unwrap(); + buf.extend_from_slice(&chunk.split_to(offset)); + if !chunk.is_empty() { + self.items.push_front(chunk) + } + } + self.len -= length; + return Ok(Async::Ready(Some(buf.freeze()))); + } + } + + match self.poll_stream()? { + Async::Ready(true) => self.read_until(line), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + + /// Read bytes until new line delimiter + pub fn readline(&mut self) -> Poll, PayloadError> { + self.read_until(b"\n") + } + + /// Put unprocessed data back to the buffer + pub fn unprocessed(&mut self, data: Bytes) { + self.len += data.len(); + self.items.push_front(data); + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use futures::unsync::mpsc; + + use super::*; + use crate::http::header::{DispositionParam, DispositionType}; + use crate::test::run_on; + + #[test] + fn test_boundary() { + let headers = HeaderMap::new(); + match Multipart::boundary(&headers) { + Err(MultipartError::NoContentType) => (), + _ => unreachable!("should not happen"), + } + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("test"), + ); + + match Multipart::boundary(&headers) { + Err(MultipartError::ParseContentType) => (), + _ => unreachable!("should not happen"), + } + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("multipart/mixed"), + ); + match Multipart::boundary(&headers) { + Err(MultipartError::Boundary) => (), + _ => unreachable!("should not happen"), + } + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"", + ), + ); + + assert_eq!( + Multipart::boundary(&headers).unwrap(), + "5c02368e880e436dab70ed54e1c58209" + ); + } + + fn create_stream() -> ( + mpsc::UnboundedSender>, + impl Stream, + ) { + let (tx, rx) = mpsc::unbounded(); + + (tx, rx.map_err(|_| panic!()).and_then(|res| res)) + } + + #[test] + fn test_multipart() { + run_on(|| { + let (sender, payload) = create_stream(); + + let bytes = Bytes::from( + "testasdadsad\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + test\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + data\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n", + ); + sender.unbounded_send(Ok(bytes)).unwrap(); + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"", + ), + ); + + let mut multipart = Multipart::new(&headers, payload); + match multipart.poll() { + Ok(Async::Ready(Some(item))) => match item { + MultipartItem::Field(mut field) => { + { + let cd = field.content_disposition().unwrap(); + assert_eq!(cd.disposition, DispositionType::FormData); + assert_eq!( + cd.parameters[0], + DispositionParam::Name("file".into()) + ); + } + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + match field.poll() { + Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "test"), + _ => unreachable!(), + } + match field.poll() { + Ok(Async::Ready(None)) => (), + _ => unreachable!(), + } + } + _ => unreachable!(), + }, + _ => unreachable!(), + } + + match multipart.poll() { + Ok(Async::Ready(Some(item))) => match item { + MultipartItem::Field(mut field) => { + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + match field.poll() { + Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"), + _ => unreachable!(), + } + match field.poll() { + Ok(Async::Ready(None)) => (), + _ => unreachable!(), + } + } + _ => unreachable!(), + }, + _ => unreachable!(), + } + + match multipart.poll() { + Ok(Async::Ready(None)) => (), + _ => unreachable!(), + } + }); + } + + #[test] + fn test_basic() { + run_on(|| { + let (_sender, payload) = create_stream(); + { + let mut payload = PayloadBuffer::new(payload); + assert_eq!(payload.len, 0); + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); + } + }); + } + + #[test] + fn test_eof() { + run_on(|| { + let (sender, payload) = create_stream(); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); + sender.unbounded_send(Ok(Bytes::from("data"))).unwrap(); + drop(sender); + + assert_eq!( + Async::Ready(Some(Bytes::from("data"))), + payload.readany().ok().unwrap() + ); + assert_eq!(payload.len, 0); + assert_eq!(Async::Ready(None), payload.readany().ok().unwrap()); + }); + } + + #[test] + fn test_err() { + run_on(|| { + let (sender, payload) = create_stream(); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); + + sender + .unbounded_send(Err(PayloadError::Incomplete(None))) + .unwrap(); + payload.readany().err().unwrap(); + }); + } + + #[test] + fn test_readany() { + run_on(|| { + let (sender, payload) = create_stream(); + let mut payload = PayloadBuffer::new(payload); + + sender.unbounded_send(Ok(Bytes::from("line1"))).unwrap(); + sender.unbounded_send(Ok(Bytes::from("line2"))).unwrap(); + + assert_eq!( + Async::Ready(Some(Bytes::from("line1"))), + payload.readany().ok().unwrap() + ); + assert_eq!(payload.len, 0); + + assert_eq!( + Async::Ready(Some(Bytes::from("line2"))), + payload.readany().ok().unwrap() + ); + assert_eq!(payload.len, 0); + }); + } + + #[test] + fn test_readexactly() { + run_on(|| { + let (sender, payload) = create_stream(); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.read_exact(2).ok().unwrap()); + + sender.unbounded_send(Ok(Bytes::from("line1"))).unwrap(); + sender.unbounded_send(Ok(Bytes::from("line2"))).unwrap(); + + assert_eq!( + Async::Ready(Some(Bytes::from_static(b"li"))), + payload.read_exact(2).ok().unwrap() + ); + assert_eq!(payload.len, 3); + + assert_eq!( + Async::Ready(Some(Bytes::from_static(b"ne1l"))), + payload.read_exact(4).ok().unwrap() + ); + assert_eq!(payload.len, 4); + + sender + .unbounded_send(Err(PayloadError::Incomplete(None))) + .unwrap(); + payload.read_exact(10).err().unwrap(); + }); + } + + #[test] + fn test_readuntil() { + run_on(|| { + let (sender, payload) = create_stream(); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.read_until(b"ne").ok().unwrap()); + + sender.unbounded_send(Ok(Bytes::from("line1"))).unwrap(); + sender.unbounded_send(Ok(Bytes::from("line2"))).unwrap(); + + assert_eq!( + Async::Ready(Some(Bytes::from("line"))), + payload.read_until(b"ne").ok().unwrap() + ); + assert_eq!(payload.len, 1); + + assert_eq!( + Async::Ready(Some(Bytes::from("1line2"))), + payload.read_until(b"2").ok().unwrap() + ); + assert_eq!(payload.len, 0); + + sender + .unbounded_send(Err(PayloadError::Incomplete(None))) + .unwrap(); + payload.read_until(b"b").err().unwrap(); + }); + } +}