From c2978a6eead1f0c0fa7e55dac80ef87578c78a72 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 6 Nov 2017 01:24:49 -0800 Subject: [PATCH] add content encoding decompression --- CHANGES.md | 8 + Cargo.toml | 2 + examples/basic.rs | 7 +- src/h1.rs | 60 +++++-- src/h2.rs | 5 +- src/httpresponse.rs | 46 +++++- src/lib.rs | 2 + src/multipart.rs | 2 +- src/payload.rs | 382 +++++++++++++++++++++++++++++++++++--------- 9 files changed, 421 insertions(+), 93 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 8821796a..5ee633eb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,13 @@ # Changes +## 0.3.0 (2017-xx-xx) + +* HTTP/2 Support + +* Content compression/decompression + + ## 0.2.1 (2017-11-03) * Allow to start tls server with `HttpServer::serve_tls` @@ -9,6 +16,7 @@ * Add conversion impl from `HttpResponse` and `BinaryBody` to a `Frame` + ## 0.2.0 (2017-10-30) * Do not use `http::Uri` as it can not parse some valid paths diff --git a/Cargo.toml b/Cargo.toml index 41a3568b..1efbe1ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,8 @@ cookie = { version="0.10", features=["percent-encode"] } regex = "0.2" sha1 = "0.2" url = "1.5" +flate2 = "0.2" +brotli2 = "0.3" percent-encoding = "1.0" # tokio diff --git a/examples/basic.rs b/examples/basic.rs index 0a6aecb5..991b75dd 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -8,8 +8,13 @@ use actix_web::*; use futures::stream::{once, Once}; /// somple handle -fn index(req: &mut HttpRequest, _payload: Payload, state: &()) -> HttpResponse { +fn index(req: &mut HttpRequest, mut _payload: Payload, state: &()) -> HttpResponse { println!("{:?}", req); + if let Ok(ch) = _payload.readany() { + if let futures::Async::Ready(Some(d)) = ch { + println!("{}", String::from_utf8_lossy(d.0.as_ref())); + } + } httpcodes::HTTPOk.into() } diff --git a/src/h1.rs b/src/h1.rs index 5d478f01..33a5313d 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -7,7 +7,7 @@ use std::collections::VecDeque; use actix::Arbiter; use httparse; use http::{Method, Version, HttpTryFrom, HeaderMap}; -use http::header::{self, HeaderName, HeaderValue}; +use http::header::{self, HeaderName, HeaderValue, CONTENT_ENCODING}; use bytes::{Bytes, BytesMut, BufMut}; use futures::{Future, Poll, Async}; use tokio_io::{AsyncRead, AsyncWrite}; @@ -17,10 +17,12 @@ use percent_encoding; use task::Task; use channel::HttpHandler; use error::ParseError; +use h1writer::H1Writer; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; -use payload::{Payload, PayloadError, PayloadSender}; -use h1writer::H1Writer; +use httpresponse::ContentEncoding; +use payload::{Payload, PayloadError, PayloadSender, + PayloadWriter, EncodedPayload, DEFAULT_BUFFER_SIZE}; const KEEPALIVE_PERIOD: u64 = 15; // seconds const INIT_BUFFER_SIZE: usize = 8192; @@ -284,10 +286,25 @@ enum Decoding { } struct PayloadInfo { - tx: PayloadSender, + tx: PayloadInfoItem, decoder: Decoder, } +enum PayloadInfoItem { + Sender(PayloadSender), + Encoding(EncodedPayload), +} + +impl PayloadInfo { + + fn as_mut(&mut self) -> &mut PayloadWriter { + match self.tx { + PayloadInfoItem::Sender(ref mut sender) => sender, + PayloadInfoItem::Encoding(ref mut enc) => enc, + } + } +} + #[derive(Debug)] enum ReaderError { Disconnect, @@ -313,21 +330,21 @@ impl Reader { fn decode(&mut self, buf: &mut BytesMut) -> std::result::Result { if let Some(ref mut payload) = self.payload { - if payload.tx.maybe_paused() { + if payload.as_mut().capacity() > DEFAULT_BUFFER_SIZE { return Ok(Decoding::Paused) } loop { match payload.decoder.decode(buf) { 
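// Decoded chunks below are routed through the `PayloadWriter` trait object
// returned by `as_mut()`, so this loop works the same whether the receiver is
// a plain `PayloadSender` or an `EncodedPayload` that transparently
// decompresses the body. Back-pressure is applied above by comparing
// `capacity()` against `DEFAULT_BUFFER_SIZE`.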
Ok(Async::Ready(Some(bytes))) => { - payload.tx.feed_data(bytes) + payload.as_mut().feed_data(bytes) }, Ok(Async::Ready(None)) => { - payload.tx.feed_eof(); + payload.as_mut().feed_eof(); return Ok(Decoding::Ready) }, Ok(Async::NotReady) => return Ok(Decoding::NotReady), Err(err) => { - payload.tx.set_error(err.into()); + payload.as_mut().set_error(err.into()); return Err(ReaderError::Payload) } } @@ -351,7 +368,7 @@ impl Reader { match self.read_from_io(io, buf) { Ok(Async::Ready(0)) => { if let Some(ref mut payload) = self.payload { - payload.tx.set_error(PayloadError::Incomplete); + payload.as_mut().set_error(PayloadError::Incomplete); } // http channel should not deal with payload errors return Err(ReaderError::Payload) @@ -362,7 +379,7 @@ impl Reader { Ok(Async::NotReady) => break, Err(err) => { if let Some(ref mut payload) = self.payload { - payload.tx.set_error(err.into()); + payload.as_mut().set_error(err.into()); } // http channel should not deal with payload errors return Err(ReaderError::Payload) @@ -377,6 +394,22 @@ impl Reader { Message::Http1(msg, decoder) => { let payload = if let Some(decoder) = decoder { let (tx, rx) = Payload::new(false); + + let enc = if let Some(enc) = msg.headers().get(CONTENT_ENCODING) { + if let Ok(enc) = enc.to_str() { + ContentEncoding::from(enc) + } else { + ContentEncoding::Auto + } + } else { + ContentEncoding::Auto + }; + + let tx = match enc { + ContentEncoding::Auto => PayloadInfoItem::Sender(tx), + _ => PayloadInfoItem::Encoding(EncodedPayload::new(tx, enc)), + }; + let payload = PayloadInfo { tx: tx, decoder: decoder, @@ -396,7 +429,8 @@ impl Reader { Ok(Async::Ready(0)) => { trace!("parse eof"); if let Some(ref mut payload) = self.payload { - payload.tx.set_error(PayloadError::Incomplete); + payload.as_mut().set_error( + PayloadError::Incomplete); } // http channel should deal with payload errors return Err(ReaderError::Payload) @@ -407,7 +441,7 @@ impl Reader { Ok(Async::NotReady) => break, Err(err) => { if let Some(ref mut payload) = self.payload { - payload.tx.set_error(err.into()); + payload.as_mut().set_error(err.into()); } // http channel should deal with payload errors return Err(ReaderError::Payload) @@ -964,7 +998,7 @@ mod tests { ReaderError::Error(_) => (), _ => panic!("Parse error expected"), }, - val => { + _ => { panic!("Error expected") } }} diff --git a/src/h2.rs b/src/h2.rs index 73235e98..51f03ef8 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -15,15 +15,14 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tokio_core::reactor::Timeout; use task::Task; +use h2writer::H2Writer; use channel::HttpHandler; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; -use payload::{Payload, PayloadError, PayloadSender}; -use h2writer::H2Writer; +use payload::{Payload, PayloadError, PayloadSender, PayloadWriter}; const KEEPALIVE_PERIOD: u64 = 15; // seconds - pub(crate) struct Http2 where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static { diff --git a/src/httpresponse.rs b/src/httpresponse.rs index 46074e89..97cff474 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -23,6 +23,33 @@ pub enum ConnectionType { Upgrade, } +/// Represents various types of connection +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum ContentEncoding { + /// Auto + Auto, + /// Brotli + Br, + /// Deflate + Deflate, + /// Gzip + Gzip, + /// Identity + Identity, +} + +impl<'a> From<&'a str> for ContentEncoding { + fn from(s: &'a str) -> ContentEncoding { + match s.trim().to_lowercase().as_ref() { + "br" => ContentEncoding::Br, + "gzip" => 
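// Matching is done on the trimmed, lower-cased header value; any
// unrecognised token falls back to `ContentEncoding::Auto`.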
ContentEncoding::Gzip, + "deflate" => ContentEncoding::Deflate, + "identity" => ContentEncoding::Identity, + _ => ContentEncoding::Auto, + } + } +} + #[derive(Debug)] /// An HTTP Response pub struct HttpResponse { @@ -32,6 +59,7 @@ pub struct HttpResponse { reason: Option<&'static str>, body: Body, chunked: bool, + encoding: ContentEncoding, connection_type: Option, error: Option>, } @@ -56,7 +84,7 @@ impl HttpResponse { reason: None, body: body, chunked: false, - // compression: None, + encoding: ContentEncoding::Auto, connection_type: None, error: None, } @@ -72,7 +100,7 @@ impl HttpResponse { reason: None, body: Body::from_slice(error.description().as_ref()), chunked: false, - // compression: None, + encoding: ContentEncoding::Auto, connection_type: None, error: Some(Box::new(error)), } @@ -210,6 +238,7 @@ struct Parts { status: StatusCode, reason: Option<&'static str>, chunked: bool, + encoding: ContentEncoding, connection_type: Option, cookies: CookieJar, } @@ -222,6 +251,7 @@ impl Parts { status: status, reason: None, chunked: false, + encoding: ContentEncoding::Auto, connection_type: None, cookies: CookieJar::new(), } @@ -287,6 +317,17 @@ impl HttpResponseBuilder { self } + /// Set content encoding. + /// + /// By default `ContentEncoding::Auto` is used, which automatically + /// determine content encoding based on request `Accept-Encoding` headers. + pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self { + if let Some(parts) = parts(&mut self.parts, &self.err) { + parts.encoding = enc; + } + self + } + /// Set connection type pub fn connection_type(&mut self, conn: ConnectionType) -> &mut Self { if let Some(parts) = parts(&mut self.parts, &self.err) { @@ -384,6 +425,7 @@ impl HttpResponseBuilder { reason: parts.reason, body: body.into(), chunked: parts.chunked, + encoding: parts.encoding, connection_type: parts.connection_type, error: None, }) diff --git a/src/lib.rs b/src/lib.rs index a4f87599..d1e6450f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,8 @@ extern crate http_range; extern crate mime; extern crate mime_guess; extern crate url; +extern crate flate2; +extern crate brotli2; extern crate percent_encoding; extern crate actix; extern crate h2 as http2; diff --git a/src/multipart.rs b/src/multipart.rs index d0a6c128..be272777 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -707,7 +707,7 @@ mod tests { use bytes::Bytes; use futures::future::{lazy, result}; use tokio_core::reactor::Core; - use payload::Payload; + use payload::{Payload, PayloadWriter}; #[test] fn test_boundary() { diff --git a/src/payload.rs b/src/payload.rs index 86398494..0f539c8c 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -1,17 +1,21 @@ -use std::{fmt, cmp}; +use std::{io, fmt, cmp}; use std::rc::{Rc, Weak}; use std::cell::RefCell; use std::collections::VecDeque; use std::error::Error; -use std::io::Error as IoError; -use bytes::{Bytes, BytesMut}; +use std::io::{Read, Write, Error as IoError}; +use bytes::{Bytes, BytesMut, BufMut, Writer}; use http2::Error as Http2Error; use futures::{Async, Poll, Stream}; use futures::task::{Task, current as current_task}; +use flate2::{FlateReadExt, Flush, Decompress, Status as DecompressStatus}; +use flate2::read::GzDecoder; +use brotli2::write::BrotliDecoder; use actix::ResponseType; +use httpresponse::ContentEncoding; -const DEFAULT_BUFFER_SIZE: usize = 65_536; // max buffer size 64k +pub(crate) const DEFAULT_BUFFER_SIZE: usize = 65_536; // max buffer size 64k /// Just Bytes object pub struct PayloadItem(pub Bytes); @@ -21,11 
+25,19 @@ impl ResponseType for PayloadItem { type Error = (); } +impl fmt::Debug for PayloadItem { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self.0, f) + } +} + #[derive(Debug)] /// A set of error that can occur during payload parsing. pub enum PayloadError { /// A payload reached EOF, but is not complete. Incomplete, + /// Content encoding stream corruption + EncodingCorrupted, /// Parse error ParseError(IoError), /// Http2 error @@ -45,6 +57,7 @@ impl Error for PayloadError { fn description(&self) -> &str { match *self { PayloadError::Incomplete => "A payload reached EOF, but is not complete.", + PayloadError::EncodingCorrupted => "Can not decode content-encoding.", PayloadError::ParseError(ref e) => e.description(), PayloadError::Http2(ref e) => e.description(), } @@ -80,12 +93,6 @@ impl Payload { (PayloadSender{inner: Rc::downgrade(&shared)}, Payload{inner: shared}) } - /// Indicates paused state of the payload. If payload data is not consumed - /// it get paused. Max size of not consumed data is 64k - pub fn paused(&self) -> bool { - self.inner.borrow().paused() - } - /// Indicates EOF of payload pub fn eof(&self) -> bool { self.inner.borrow().eof() @@ -146,7 +153,6 @@ impl Payload { } } - impl Stream for Payload { type Item = PayloadItem; type Error = PayloadError; @@ -156,50 +162,41 @@ impl Stream for Payload { } } +pub(crate) trait PayloadWriter { + fn set_error(&mut self, err: PayloadError); + + fn feed_eof(&mut self); + + fn feed_data(&mut self, data: Bytes); + + fn capacity(&self) -> usize; +} + pub(crate) struct PayloadSender { inner: Weak>, } -impl PayloadSender { - pub fn set_error(&mut self, err: PayloadError) { +impl PayloadWriter for PayloadSender { + + fn set_error(&mut self, err: PayloadError) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().set_error(err) } } - pub fn feed_eof(&mut self) { + fn feed_eof(&mut self) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().feed_eof() } } - pub fn feed_data(&mut self, data: Bytes) { + fn feed_data(&mut self, data: Bytes) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().feed_data(data) } } - pub fn maybe_paused(&self) -> bool { - match self.inner.upgrade() { - Some(shared) => { - let inner = shared.borrow(); - if inner.paused() && inner.len() < inner.buffer_size() { - drop(inner); - shared.borrow_mut().resume(); - false - } else if !inner.paused() && inner.len() > inner.buffer_size() { - drop(inner); - shared.borrow_mut().pause(); - true - } else { - inner.paused() - } - } - None => false, - } - } - - pub fn capacity(&self) -> usize { + fn capacity(&self) -> usize { if let Some(shared) = self.inner.upgrade() { shared.borrow().capacity() } else { @@ -208,11 +205,286 @@ impl PayloadSender { } } +enum Decoder { + Zlib(Decompress), + Gzip(Option>), + Br(Rc>, BrotliDecoder), + Identity, +} + +#[derive(Debug)] +struct Wrapper { + buf: BytesMut +} + +impl io::Read for Wrapper { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let len = cmp::min(buf.len(), self.buf.len()); + buf[..len].copy_from_slice(&self.buf[..len]); + self.buf.split_to(len); + Ok(len) + } +} + +#[derive(Debug)] +struct WrapperRc { + buf: Rc>, +} + +impl io::Write for WrapperRc { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.buf.borrow_mut().extend(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +pub(crate) struct EncodedPayload { + inner: PayloadSender, + decoder: Decoder, + dst: Writer, + buffer: BytesMut, + error: bool, +} + 
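// `EncodedPayload` wraps the original `PayloadSender`. `feed_data()` pushes
// incoming bytes through the selected decoder (deflate via
// `flate2::Decompress`, gzip via `flate2::read::GzDecoder`, brotli via
// `brotli2::write::BrotliDecoder`) and forwards the decompressed chunks to the
// inner sender; `feed_eof()` flushes whatever the decoder still buffers. On a
// decoding failure the `error` flag is set, the decoder is downgraded to
// `Identity`, and the payload receives `PayloadError::EncodingCorrupted` (or
// `Incomplete`/`ParseError` when the failure happens at EOF).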
+impl EncodedPayload { + pub fn new(inner: PayloadSender, enc: ContentEncoding) -> EncodedPayload { + let dec = match enc { + ContentEncoding::Deflate => Decoder::Zlib(Decompress::new(false)), + ContentEncoding::Gzip => Decoder::Gzip(None), + ContentEncoding::Br => { + let buf = Rc::new(RefCell::new(BytesMut::new())); + let buf2 = Rc::clone(&buf); + Decoder::Br(buf, BrotliDecoder::new(WrapperRc{buf: buf2})) + } + _ => Decoder::Identity, + }; + EncodedPayload { + inner: inner, + decoder: dec, + error: false, + dst: BytesMut::new().writer(), + buffer: BytesMut::new(), + } + } +} + +impl PayloadWriter for EncodedPayload { + + fn set_error(&mut self, err: PayloadError) { + self.inner.set_error(err) + } + + fn feed_eof(&mut self) { + if self.error { + return + } + let err = match self.decoder { + Decoder::Br(ref mut buf, ref mut decoder) => { + match decoder.flush() { + Ok(_) => { + let b = buf.borrow_mut().take().freeze(); + if !b.is_empty() { + self.inner.feed_data(b); + } + self.inner.feed_eof(); + return + }, + Err(err) => Some(err), + } + } + + Decoder::Gzip(ref mut decoder) => { + if decoder.is_none() { + self.inner.feed_eof(); + return + } + loop { + let len = self.dst.get_ref().len(); + let len_buf = decoder.as_mut().unwrap().get_mut().buf.len(); + + if len < len_buf * 2 { + self.dst.get_mut().reserve(len_buf * 2 - len); + unsafe{self.dst.get_mut().set_len(len_buf * 2)}; + } + match decoder.as_mut().unwrap().read(&mut self.dst.get_mut()) { + Ok(n) => { + if n == 0 { + self.inner.feed_eof(); + return + } else { + self.inner.feed_data(self.dst.get_mut().split_to(n).freeze()); + } + } + Err(err) => break Some(err) + } + } + } + Decoder::Zlib(ref mut decoder) => { + let len = self.dst.get_ref().len(); + if len < self.buffer.len() * 2 { + self.dst.get_mut().reserve(self.buffer.len() * 2 - len); + unsafe{self.dst.get_mut().set_len(self.buffer.len() * 2)}; + } + + let len = self.dst.get_ref().len(); + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let ret = decoder.decompress( + self.buffer.as_ref(), &mut self.dst.get_mut()[..len], Flush::Finish); + let read = (decoder.total_out() - before_out) as usize; + let consumed = (decoder.total_in() - before_in) as usize; + + let ch = self.dst.get_mut().split_to(read).freeze(); + if !ch.is_empty() { + self.inner.feed_data(ch); + } + self.buffer.split_to(consumed); + + match ret { + Ok(DecompressStatus::Ok) | Ok(DecompressStatus::StreamEnd) => { + self.inner.feed_eof(); + return + }, + _ => None, + } + }, + Decoder::Identity => { + self.inner.feed_eof(); + return + } + }; + + self.error = true; + self.decoder = Decoder::Identity; + if let Some(err) = err { + self.set_error(PayloadError::ParseError(err)); + } else { + self.set_error(PayloadError::Incomplete); + } + } + + fn feed_data(&mut self, data: Bytes) { + if self.error { + return + } + match self.decoder { + Decoder::Br(ref mut buf, ref mut decoder) => { + match decoder.write(&data) { + Ok(_) => { + let b = buf.borrow_mut().take().freeze(); + if !b.is_empty() { + self.inner.feed_data(b); + } + return + }, + Err(err) => { + trace!("Error decoding br encoding: {}", err); + }, + } + } + + Decoder::Gzip(ref mut decoder) => { + if decoder.is_none() { + let mut buf = BytesMut::new(); + buf.extend(data); + *decoder = Some(Wrapper{buf: buf}.gz_decode().unwrap()); + } else { + decoder.as_mut().unwrap().get_mut().buf.extend(data); + } + + loop { + let len_buf = decoder.as_mut().unwrap().get_mut().buf.len(); + if len_buf == 0 { + return + } + + let len = 
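// The scratch buffer `dst` is grown to at least twice the amount of buffered
// compressed input before each read, so the gzip decoder always has spare
// room to decompress into.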
self.dst.get_ref().len(); + if len < len_buf * 2 { + self.dst.get_mut().reserve(len_buf * 2 - len); + unsafe{self.dst.get_mut().set_len(len_buf * 2)}; + } + match decoder.as_mut().unwrap().read(&mut self.dst.get_mut()) { + Ok(n) => { + if n == 0 { + return + } else { + self.inner.feed_data(self.dst.get_mut().split_to(n).freeze()); + } + } + Err(_) => break + } + } + } + + Decoder::Zlib(ref mut decoder) => { + self.buffer.extend(data); + + loop { + if self.buffer.is_empty() { + return + } + + let ret = { + let len = self.dst.get_ref().len(); + if len < self.buffer.len() * 2 { + self.dst.get_mut().reserve(self.buffer.len() * 2 - len); + unsafe{self.dst.get_mut().set_len(self.buffer.len() * 2)}; + } + let before_out = decoder.total_out(); + let before_in = decoder.total_in(); + + let len = self.dst.get_ref().len(); + let ret = decoder.decompress( + self.buffer.as_ref(), &mut self.dst.get_mut()[..len], Flush::None); + let read = (decoder.total_out() - before_out) as usize; + let consumed = (decoder.total_in() - before_in) as usize; + + let ch = self.dst.get_mut().split_to(read).freeze(); + if !ch.is_empty() { + self.inner.feed_data(ch); + } + if self.buffer.len() > consumed { + self.buffer.split_to(consumed); + } + ret + }; + + match ret { + Ok(DecompressStatus::Ok) => continue, + _ => break, + } + } + } + Decoder::Identity => { + self.inner.feed_data(data); + return + } + }; + + self.error = true; + self.decoder = Decoder::Identity; + self.set_error(PayloadError::EncodingCorrupted); + } + + fn capacity(&self) -> usize { + match self.decoder { + Decoder::Br(ref buf, _) => { + buf.borrow().len() + self.inner.capacity() + } + _ => { + self.inner.capacity() + } + } + } +} + #[derive(Debug)] struct Inner { len: usize, eof: bool, - paused: bool, err: Option, task: Option, items: VecDeque, @@ -225,7 +497,6 @@ impl Inner { Inner { len: 0, eof: eof, - paused: false, err: None, task: None, items: VecDeque::new(), @@ -233,18 +504,6 @@ impl Inner { } } - fn paused(&self) -> bool { - self.paused - } - - fn pause(&mut self) { - self.paused = true; - } - - fn resume(&mut self) { - self.paused = false; - } - fn set_error(&mut self, err: PayloadError) { self.err = Some(err); if let Some(task) = self.task.take() { @@ -600,29 +859,6 @@ mod tests { })).unwrap(); } - #[test] - fn test_pause() { - Core::new().unwrap().run(lazy(|| { - let (mut sender, mut payload) = Payload::new(false); - - assert!(!payload.paused()); - assert!(!sender.maybe_paused()); - - for _ in 0..DEFAULT_BUFFER_SIZE+1 { - sender.feed_data(Bytes::from("1")); - } - assert!(sender.maybe_paused()); - assert!(payload.paused()); - - payload.readexactly(10).unwrap(); - assert!(!sender.maybe_paused()); - assert!(!payload.paused()); - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } - #[test] fn test_unread_data() { Core::new().unwrap().run(lazy(|| {
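A minimal handler sketch showing the net effect of this patch on application code. It closely follows the updated `examples/basic.rs`: the chunks returned by `Payload::readany()` arrive already decompressed when the request carries a supported `Content-Encoding` (`gzip`, `deflate` or `br`), while requests without one go through the untouched `PayloadSender` path. The loop body is illustrative only and is not part of the patch; on the response side, the new `HttpResponseBuilder::content_encoding()` method (default `ContentEncoding::Auto`) records the preferred encoding.

```rust
extern crate actix_web;
extern crate futures;

use actix_web::*;
use futures::Async;

/// Illustrative handler (not part of the patch): payload chunks arrive
/// already decompressed when the request had a supported Content-Encoding.
fn index(req: &mut HttpRequest, mut payload: Payload, _state: &()) -> HttpResponse {
    println!("{:?}", req);
    // Drain whatever is currently buffered; a real handler would poll the
    // payload as a `Stream` instead of doing a single non-blocking read.
    while let Ok(Async::Ready(Some(chunk))) = payload.readany() {
        println!("{}", String::from_utf8_lossy(chunk.0.as_ref()));
    }
    httpcodes::HTTPOk.into()
}
```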