From e558414867c6b29dd0ba11cbaea23f1f7d445102 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 8 Nov 2017 16:44:23 -0800 Subject: [PATCH] add response content encoding --- Cargo.toml | 4 +- README.md | 1 + examples/basic.rs | 2 +- src/body.rs | 8 + src/date.rs | 11 +- src/encoding.rs | 532 +++++++++++++++++++++++++++++++++++++++++-- src/h1writer.rs | 265 ++++++--------------- src/h2writer.rs | 172 ++++---------- tests/test_server.rs | 1 - 9 files changed, 635 insertions(+), 361 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 763d646da..5d6b90ea5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.3.0" authors = ["Nikolay Kim "] description = "Actix web framework" readme = "README.md" -keywords = ["actix", "actor", "http", "web", "async", "tokio", "futures", "web"] +keywords = ["http", "web", "async", "tokio", "futures"] homepage = "https://github.com/actix/actix-web" repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-web/" @@ -16,7 +16,7 @@ build = "build.rs" [badges] travis-ci = { repository = "actix/actix-web", branch = "master" } -appveyor = { repository = "fafhrd91/actix-web-hdy9d" } +appveyor = { repository = "actix/actix-web" } codecov = { repository = "actix/actix-web", branch = "master", service = "github" } [lib] diff --git a/README.md b/README.md index 9098cf0a0..15d7b2f53 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Actix web is licensed under the [Apache-2.0 license](http://opensource.org/licen * Streaming and pipelining * Keep-alive and slow requests handling * [WebSockets](https://actix.github.io/actix-web/actix_web/ws/index.html) + * Transparent content compression/decompression * Configurable request routing * Multipart streams * Middlewares diff --git a/examples/basic.rs b/examples/basic.rs index 991b75ddc..191ea683b 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -30,7 +30,7 @@ fn index_async(req: &mut HttpRequest, _payload: Payload, state: &()) -> Once HandlerResult { diff --git a/src/body.rs b/src/body.rs index 3454a6cfc..8b1d18e7e 100644 --- a/src/body.rs +++ b/src/body.rs @@ -50,6 +50,14 @@ impl Body { } } + /// Is this binary body. + pub fn is_binary(&self) -> bool { + match *self { + Body::Binary(_) => true, + _ => false + } + } + /// Create body from slice (copy) pub fn from_slice(s: &[u8]) -> Body { Body::Binary(BinaryBody::Bytes(Bytes::from(s))) diff --git a/src/date.rs b/src/date.rs index a5f456f96..27ae1db22 100644 --- a/src/date.rs +++ b/src/date.rs @@ -1,21 +1,20 @@ use std::cell::RefCell; use std::fmt::{self, Write}; use std::str; - use time::{self, Duration}; -use bytes::BytesMut; // "Sun, 06 Nov 1994 08:49:37 GMT".len() pub const DATE_VALUE_LENGTH: usize = 29; -pub fn extend(dst: &mut BytesMut) { +pub fn extend(dst: &mut [u8]) { CACHED.with(|cache| { let mut cache = cache.borrow_mut(); let now = time::get_time(); if now > cache.next_update { cache.update(now); } - dst.extend_from_slice(cache.buffer()); + + dst.copy_from_slice(cache.buffer()); }) } @@ -61,9 +60,9 @@ fn test_date_len() { #[test] fn test_date() { - let mut buf1 = BytesMut::new(); + let mut buf1 = [0u8; 29]; extend(&mut buf1); - let mut buf2 = BytesMut::new(); + let mut buf2 = [0u8; 29]; extend(&mut buf2); assert_eq!(buf1, buf2); } diff --git a/src/encoding.rs b/src/encoding.rs index e4545b418..2d6d95b10 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -1,15 +1,24 @@ -use std::{io, cmp}; +use std::{io, cmp, mem}; use std::io::{Read, Write}; +use std::fmt::Write as FmtWrite; +use std::str::FromStr; -use http::header::{HeaderMap, CONTENT_ENCODING}; +use http::Version; +use http::header::{HeaderMap, HeaderValue, + ACCEPT_ENCODING, CONNECTION, + CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; +use flate2::Compression; use flate2::read::{GzDecoder}; -use flate2::write::{DeflateDecoder}; -use brotli2::write::BrotliDecoder; +use flate2::write::{GzEncoder, DeflateDecoder, DeflateEncoder}; +use brotli2::write::{BrotliDecoder, BrotliEncoder}; use bytes::{Bytes, BytesMut, BufMut, Writer}; +use body::Body; +use httprequest::HttpRequest; +use httpresponse::HttpResponse; use payload::{PayloadSender, PayloadWriter, PayloadError}; -/// Represents various types of connection +/// Represents supported types of content encodings #[derive(Copy, Clone, PartialEq, Debug)] pub enum ContentEncoding { /// Automatically select encoding based on encoding negotiation @@ -24,6 +33,17 @@ pub enum ContentEncoding { Identity, } +impl ContentEncoding { + fn as_str(&self) -> &'static str { + match *self { + ContentEncoding::Br => "br", + ContentEncoding::Gzip => "gzip", + ContentEncoding::Deflate => "deflate", + ContentEncoding::Identity | ContentEncoding::Auto => "identity", + } + } +} + impl<'a> From<&'a str> for ContentEncoding { fn from(s: &'a str) -> ContentEncoding { match s.trim().to_lowercase().as_ref() { @@ -136,6 +156,7 @@ impl io::Write for BytesWriter { } } +/// Payload wrapper with content decompression support pub(crate) struct EncodedPayload { inner: PayloadSender, decoder: Decoder, @@ -246,14 +267,12 @@ impl PayloadWriter for EncodedPayload { } match self.decoder { Decoder::Br(ref mut decoder) => { - if decoder.write(&data).is_ok() { - if decoder.flush().is_ok() { - let b = decoder.get_mut().buf.take().freeze(); - if !b.is_empty() { - self.inner.feed_data(b); - } - return + if decoder.write(&data).is_ok() && decoder.flush().is_ok() { + let b = decoder.get_mut().buf.take().freeze(); + if !b.is_empty() { + self.inner.feed_data(b); } + return } trace!("Error decoding br encoding"); } @@ -292,14 +311,12 @@ impl PayloadWriter for EncodedPayload { } Decoder::Deflate(ref mut decoder) => { - if decoder.write(&data).is_ok() { - if decoder.flush().is_ok() { - let b = decoder.get_mut().buf.take().freeze(); - if !b.is_empty() { - self.inner.feed_data(b); - } - return + if decoder.write(&data).is_ok() && decoder.flush().is_ok() { + let b = decoder.get_mut().buf.take().freeze(); + if !b.is_empty() { + self.inner.feed_data(b); } + return } trace!("Error decoding deflate encoding"); } @@ -318,3 +335,480 @@ impl PayloadWriter for EncodedPayload { self.inner.capacity() } } + +pub(crate) struct PayloadEncoder(ContentEncoder); + +impl Default for PayloadEncoder { + fn default() -> PayloadEncoder { + PayloadEncoder(ContentEncoder::Identity(TransferEncoding::eof())) + } +} + +impl PayloadEncoder { + + pub fn new(req: &HttpRequest, resp: &mut HttpResponse) -> PayloadEncoder { + let version = resp.version().unwrap_or_else(|| req.version()); + let body = resp.replace_body(Body::Empty); + let has_body = if let Body::Empty = body { false } else { true }; + + // Enable content encoding only if response does not contain Content-Encoding header + let encoding = if has_body && !resp.headers.contains_key(CONTENT_ENCODING) { + let encoding = match *resp.content_encoding() { + ContentEncoding::Auto => { + // negotiate content-encoding + if let Some(val) = req.headers().get(ACCEPT_ENCODING) { + if let Ok(enc) = val.to_str() { + AcceptEncoding::parse(enc) + } else { + ContentEncoding::Identity + } + } else { + ContentEncoding::Identity + } + } + encoding => encoding, + }; + resp.headers.insert(CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str())); + encoding + } else { + ContentEncoding::Identity + }; + + // in general case it is very expensive to get compressed payload length, + // just switch to chunked encoding + let compression = encoding != ContentEncoding::Identity; + + let transfer = match body { + Body::Empty => { + if resp.chunked() { + error!("Chunked transfer is enabled but body is set to Empty"); + } + resp.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); + resp.headers.remove(TRANSFER_ENCODING); + TransferEncoding::length(0) + }, + Body::Length(n) => { + if resp.chunked() { + error!("Chunked transfer is enabled but body with specific length is specified"); + } + if compression { + resp.headers.remove(CONTENT_LENGTH); + if version == Version::HTTP_2 { + resp.headers.remove(TRANSFER_ENCODING); + TransferEncoding::eof() + } else { + resp.headers.insert( + TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + TransferEncoding::chunked() + } + } else { + resp.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); + resp.headers.remove(TRANSFER_ENCODING); + TransferEncoding::length(n) + } + }, + Body::Binary(ref bytes) => { + if compression { + resp.headers.remove(CONTENT_LENGTH); + if version == Version::HTTP_2 { + resp.headers.remove(TRANSFER_ENCODING); + TransferEncoding::eof() + } else { + resp.headers.insert( + TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + TransferEncoding::chunked() + } + } else { + resp.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); + resp.headers.remove(TRANSFER_ENCODING); + TransferEncoding::length(bytes.len() as u64) + } + } + Body::Streaming => { + if resp.chunked() { + resp.headers.remove(CONTENT_LENGTH); + if version != Version::HTTP_11 { + error!("Chunked transfer encoding is forbidden for {:?}", version); + } + if version == Version::HTTP_2 { + resp.headers.remove(TRANSFER_ENCODING); + TransferEncoding::eof() + } else { + resp.headers.insert( + TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + TransferEncoding::chunked() + } + } else { + TransferEncoding::eof() + } + } + Body::Upgrade => { + if version == Version::HTTP_2 { + error!("Connection upgrade is forbidden for HTTP/2"); + } else { + resp.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); + } + TransferEncoding::eof() + } + }; + resp.replace_body(body); + + PayloadEncoder( + match encoding { + ContentEncoding::Deflate => ContentEncoder::Deflate( + DeflateEncoder::new(transfer, Compression::Default)), + ContentEncoding::Gzip => ContentEncoder::Gzip( + GzEncoder::new(transfer, Compression::Default)), + ContentEncoding::Br => ContentEncoder::Br( + BrotliEncoder::new(transfer, 6)), + ContentEncoding::Identity => ContentEncoder::Identity(transfer), + ContentEncoding::Auto => + unreachable!() + } + ) + } +} + +impl PayloadEncoder { + + pub fn len(&self) -> usize { + self.0.get_ref().len() + } + + pub fn get_mut(&mut self) -> &mut BytesMut { + self.0.get_mut() + } + + pub fn is_eof(&self) -> bool { + self.0.is_eof() + } + + pub fn write(&mut self, payload: &[u8]) -> Result<(), io::Error> { + self.0.write(payload) + } + + pub fn write_eof(&mut self) -> Result<(), io::Error> { + self.0.write_eof() + } +} + +enum ContentEncoder { + Deflate(DeflateEncoder), + Gzip(GzEncoder), + Br(BrotliEncoder), + Identity(TransferEncoding), +} + +impl ContentEncoder { + + pub fn is_eof(&self) -> bool { + match *self { + ContentEncoder::Br(ref encoder) => + encoder.get_ref().is_eof(), + ContentEncoder::Deflate(ref encoder) => + encoder.get_ref().is_eof(), + ContentEncoder::Gzip(ref encoder) => + encoder.get_ref().is_eof(), + ContentEncoder::Identity(ref encoder) => + encoder.is_eof(), + } + } + + pub fn get_ref(&self) -> &BytesMut { + match *self { + ContentEncoder::Br(ref encoder) => + &encoder.get_ref().buffer, + ContentEncoder::Deflate(ref encoder) => + &encoder.get_ref().buffer, + ContentEncoder::Gzip(ref encoder) => + &encoder.get_ref().buffer, + ContentEncoder::Identity(ref encoder) => + &encoder.buffer, + } + } + + pub fn get_mut(&mut self) -> &mut BytesMut { + match *self { + ContentEncoder::Br(ref mut encoder) => + &mut encoder.get_mut().buffer, + ContentEncoder::Deflate(ref mut encoder) => + &mut encoder.get_mut().buffer, + ContentEncoder::Gzip(ref mut encoder) => + &mut encoder.get_mut().buffer, + ContentEncoder::Identity(ref mut encoder) => + &mut encoder.buffer, + } + } + + pub fn write_eof(&mut self) -> Result<(), io::Error> { + let encoder = mem::replace(self, ContentEncoder::Identity(TransferEncoding::eof())); + + match encoder { + ContentEncoder::Br(encoder) => { + match encoder.finish() { + Ok(mut writer) => { + writer.encode_eof(); + *self = ContentEncoder::Identity(writer); + Ok(()) + }, + Err(err) => Err(err), + } + } + ContentEncoder::Gzip(encoder) => { + match encoder.finish() { + Ok(mut writer) => { + writer.encode_eof(); + *self = ContentEncoder::Identity(writer); + Ok(()) + }, + Err(err) => Err(err), + } + }, + ContentEncoder::Deflate(encoder) => { + match encoder.finish() { + Ok(mut writer) => { + writer.encode_eof(); + *self = ContentEncoder::Identity(writer); + Ok(()) + }, + Err(err) => Err(err), + } + }, + ContentEncoder::Identity(mut writer) => { + writer.encode_eof(); + *self = ContentEncoder::Identity(writer); + Ok(()) + } + } + } + + pub fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { + match *self { + ContentEncoder::Br(ref mut encoder) => { + match encoder.write(data) { + Ok(_) => { + encoder.flush() + }, + Err(err) => { + trace!("Error decoding br encoding: {}", err); + Err(err) + }, + } + }, + ContentEncoder::Gzip(ref mut encoder) => { + match encoder.write(data) { + Ok(_) => { + encoder.flush() + }, + Err(err) => { + trace!("Error decoding br encoding: {}", err); + Err(err) + }, + } + } + ContentEncoder::Deflate(ref mut encoder) => { + match encoder.write(data) { + Ok(_) => { + encoder.flush() + }, + Err(err) => { + trace!("Error decoding deflate encoding: {}", err); + Err(err) + }, + } + } + ContentEncoder::Identity(ref mut encoder) => { + encoder.write_all(data)?; + Ok(()) + } + } + } +} + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone)] +pub(crate) struct TransferEncoding { + kind: TransferEncodingKind, + buffer: BytesMut, +} + +#[derive(Debug, PartialEq, Clone)] +enum TransferEncodingKind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked(bool), + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when Content-Length is not known. + /// + /// Appliction decides when to stop writing. + Eof, +} + +impl TransferEncoding { + + pub fn eof() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Eof, + buffer: BytesMut::new(), + } + } + + pub fn chunked() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Chunked(false), + buffer: BytesMut::new(), + } + } + + pub fn length(len: u64) -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Length(len), + buffer: BytesMut::new(), + } + } + + pub fn is_eof(&self) -> bool { + match self.kind { + TransferEncodingKind::Eof => true, + TransferEncodingKind::Chunked(ref eof) => + *eof, + TransferEncodingKind::Length(ref remaining) => + *remaining == 0, + } + } + + /// Encode message. Return `EOF` state of encoder + pub fn encode(&mut self, msg: &[u8]) -> bool { + match self.kind { + TransferEncodingKind::Eof => { + self.buffer.extend(msg); + msg.is_empty() + }, + TransferEncodingKind::Chunked(ref mut eof) => { + if *eof { + return true; + } + + if msg.is_empty() { + *eof = true; + self.buffer.extend(b"0\r\n\r\n"); + } else { + write!(self.buffer, "{:X}\r\n", msg.len()).unwrap(); + self.buffer.extend(msg); + self.buffer.extend(b"\r\n"); + } + *eof + }, + TransferEncodingKind::Length(ref mut remaining) => { + if msg.is_empty() { + return *remaining == 0 + } + let max = cmp::min(*remaining, msg.len() as u64); + trace!("sized write = {}", max); + self.buffer.extend(msg[..max as usize].as_ref()); + + *remaining -= max as u64; + trace!("encoded {} bytes, remaining = {}", max, remaining); + *remaining == 0 + }, + } + } + + /// Encode eof. Return `EOF` state of encoder + pub fn encode_eof(&mut self) { + match self.kind { + TransferEncodingKind::Eof | TransferEncodingKind::Length(_) => (), + TransferEncodingKind::Chunked(ref mut eof) => { + if !*eof { + *eof = true; + self.buffer.extend(b"0\r\n\r\n"); + } + }, + } + } +} + +impl io::Write for TransferEncoding { + + fn write(&mut self, buf: &[u8]) -> io::Result { + self.encode(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + + +struct AcceptEncoding { + encoding: ContentEncoding, + quality: f64, +} + +impl Eq for AcceptEncoding {} + +impl Ord for AcceptEncoding { + fn cmp(&self, other: &AcceptEncoding) -> cmp::Ordering { + if self.quality > other.quality { + cmp::Ordering::Less + } else if self.quality < other.quality { + cmp::Ordering::Greater + } else { + cmp::Ordering::Equal + } + } +} + +impl PartialOrd for AcceptEncoding { + fn partial_cmp(&self, other: &AcceptEncoding) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for AcceptEncoding { + fn eq(&self, other: &AcceptEncoding) -> bool { + self.quality == other.quality + } +} + +impl AcceptEncoding { + fn new(tag: &str) -> Option { + let parts: Vec<&str> = tag.split(';').collect(); + let encoding = match parts.len() { + 0 => return None, + _ => ContentEncoding::from(parts[0]), + }; + let quality = match parts.len() { + 1 => 1.0, + _ => match f64::from_str(parts[1]) { + Ok(q) => q, + Err(_) => 0.0, + } + }; + Some(AcceptEncoding { + encoding: encoding, + quality: quality, + }) + } + + /// Parse a raw Accept-Encoding header value into an ordered list. + pub fn parse(raw: &str) -> ContentEncoding { + let mut encodings: Vec<_> = + raw.replace(' ', "").split(',').map(|l| AcceptEncoding::new(l)).collect(); + encodings.sort(); + + for enc in encodings { + if let Some(enc) = enc { + return enc.encoding + } + } + ContentEncoding::Identity + } +} diff --git a/src/h1writer.rs b/src/h1writer.rs index 7a7d08001..dccf07033 100644 --- a/src/h1writer.rs +++ b/src/h1writer.rs @@ -1,14 +1,13 @@ -use std::{cmp, io}; +use std::io; use std::fmt::Write; -use bytes::BytesMut; use futures::{Async, Poll}; use tokio_io::AsyncWrite; use http::{Version, StatusCode}; -use http::header::{HeaderValue, - CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE}; +use http::header::{HeaderValue, CONNECTION, CONTENT_TYPE, DATE}; use date; use body::Body; +use encoding::PayloadEncoder; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -37,9 +36,8 @@ pub(crate) trait Writer { pub(crate) struct H1Writer { stream: Option, - buffer: BytesMut, started: bool, - encoder: Encoder, + encoder: PayloadEncoder, upgrade: bool, keepalive: bool, disconnected: bool, @@ -50,9 +48,8 @@ impl H1Writer { pub fn new(stream: T) -> H1Writer { H1Writer { stream: Some(stream), - buffer: BytesMut::new(), started: false, - encoder: Encoder::length(0), + encoder: PayloadEncoder::default(), upgrade: false, keepalive: false, disconnected: false, @@ -68,8 +65,7 @@ impl H1Writer { } pub fn disconnected(&mut self) { - let len = self.buffer.len(); - self.buffer.split_to(len); + self.encoder.get_mut().take(); } pub fn keepalive(&self) -> bool { @@ -77,14 +73,16 @@ impl H1Writer { } fn write_to_stream(&mut self) -> Result { + let buffer = self.encoder.get_mut(); + if let Some(ref mut stream) = self.stream { - while !self.buffer.is_empty() { - match stream.write(self.buffer.as_ref()) { + while !buffer.is_empty() { + match stream.write(buffer.as_ref()) { Ok(n) => { - self.buffer.split_to(n); + buffer.split_to(n); }, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + if buffer.len() > MAX_WRITE_BUFFER_SIZE { return Ok(WriterState::Pause) } else { return Ok(WriterState::Done) @@ -106,58 +104,12 @@ impl Writer for H1Writer { trace!("Prepare message status={:?}", msg.status); // prepare task - let mut extra = 0; - let body = msg.replace_body(Body::Empty); - let version = msg.version().unwrap_or_else(|| req.version()); self.started = true; + self.encoder = PayloadEncoder::new(req, msg); self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive()); - match body { - Body::Empty => { - if msg.chunked() { - error!("Chunked transfer is enabled but body is set to Empty"); - } - msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); - msg.headers.remove(TRANSFER_ENCODING); - self.encoder = Encoder::length(0); - }, - Body::Length(n) => { - if msg.chunked() { - error!("Chunked transfer is enabled but body with specific length is specified"); - } - msg.headers.insert( - CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); - msg.headers.remove(TRANSFER_ENCODING); - self.encoder = Encoder::length(n); - }, - Body::Binary(ref bytes) => { - extra = bytes.len(); - msg.headers.insert( - CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); - msg.headers.remove(TRANSFER_ENCODING); - self.encoder = Encoder::length(0); - } - Body::Streaming => { - if msg.chunked() { - if version < Version::HTTP_11 { - error!("Chunked transfer encoding is forbidden for {:?}", version); - } - msg.headers.remove(CONTENT_LENGTH); - msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - self.encoder = Encoder::chunked(); - } else { - self.encoder = Encoder::eof(); - } - } - Body::Upgrade => { - msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); - self.encoder = Encoder::eof(); - } - } - // Connection upgrade + let version = msg.version().unwrap_or_else(|| req.version()); if msg.upgrade() { msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); } @@ -171,44 +123,54 @@ impl Writer for H1Writer { } // render message - let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra; - self.buffer.reserve(init_cap); + { + let buffer = self.encoder.get_mut(); + if let Body::Binary(ref bytes) = *msg.body() { + buffer.reserve(100 + msg.headers.len() * AVERAGE_HEADER_SIZE + bytes.len()); + } else { + buffer.reserve(100 + msg.headers.len() * AVERAGE_HEADER_SIZE); + } - if version == Version::HTTP_11 && msg.status == StatusCode::OK { - self.buffer.extend(b"HTTP/1.1 200 OK\r\n"); - } else { - let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status); - } - for (key, value) in &msg.headers { - let t: &[u8] = key.as_ref(); - self.buffer.extend(t); - self.buffer.extend(b": "); - self.buffer.extend(value.as_ref()); - self.buffer.extend(b"\r\n"); + if version == Version::HTTP_11 && msg.status == StatusCode::OK { + buffer.extend(b"HTTP/1.1 200 OK\r\n"); + } else { + let _ = write!(buffer, "{:?} {}\r\n", version, msg.status); + } + for (key, value) in &msg.headers { + let t: &[u8] = key.as_ref(); + buffer.extend(t); + buffer.extend(b": "); + buffer.extend(value.as_ref()); + buffer.extend(b"\r\n"); + } + + // using http::h1::date is quite a lot faster than generating + // a unique Date header each time like req/s goes up about 10% + if !msg.headers.contains_key(DATE) { + buffer.reserve(date::DATE_VALUE_LENGTH + 8); + buffer.extend(b"Date: "); + let mut bytes = [0u8; 29]; + date::extend(&mut bytes[..]); + buffer.extend(&bytes); + buffer.extend(b"\r\n"); + } + + // default content-type + if !msg.headers.contains_key(CONTENT_TYPE) { + buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref()); + } + + // msg eof + buffer.extend(b"\r\n"); } - // using http::h1::date is quite a lot faster than generating - // a unique Date header each time like req/s goes up about 10% - if !msg.headers.contains_key(DATE) { - self.buffer.reserve(date::DATE_VALUE_LENGTH + 8); - self.buffer.extend(b"Date: "); - date::extend(&mut self.buffer); - self.buffer.extend(b"\r\n"); + if msg.body().is_binary() { + let body = msg.replace_body(Body::Empty); + if let Body::Binary(bytes) = body { + self.encoder.write(bytes.as_ref())?; + return Ok(WriterState::Done) + } } - - // default content-type - if !msg.headers.contains_key(CONTENT_TYPE) { - self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref()); - } - - self.buffer.extend(b"\r\n"); - - if let Body::Binary(ref bytes) = body { - self.buffer.extend_from_slice(bytes.as_ref()); - return Ok(WriterState::Done) - } - msg.replace_body(body); - Ok(WriterState::Done) } @@ -216,14 +178,14 @@ impl Writer for H1Writer { if !self.disconnected { if self.started { // TODO: add warning, write after EOF - self.encoder.encode(&mut self.buffer, payload); + self.encoder.write(payload)?; } else { - // might be response for EXCEPT - self.buffer.extend_from_slice(payload) + // might be response to EXCEPT + self.encoder.get_mut().extend_from_slice(payload) } } - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + if self.encoder.len() > MAX_WRITE_BUFFER_SIZE { Ok(WriterState::Pause) } else { Ok(WriterState::Done) @@ -231,11 +193,13 @@ impl Writer for H1Writer { } fn write_eof(&mut self) -> Result { - if !self.encoder.encode_eof(&mut self.buffer) { + self.encoder.write_eof()?; + + if !self.encoder.is_eof() { //debug!("last payload item, but it is not EOF "); Err(io::Error::new(io::ErrorKind::Other, "Last payload item, but eof is not reached")) - } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + } else if self.encoder.len() > MAX_WRITE_BUFFER_SIZE { Ok(WriterState::Pause) } else { Ok(WriterState::Done) @@ -250,100 +214,3 @@ impl Writer for H1Writer { } } } - -/// Encoders to handle different Transfer-Encodings. -#[derive(Debug, Clone)] -pub(crate) struct Encoder { - kind: Kind, -} - -#[derive(Debug, PartialEq, Clone)] -enum Kind { - /// An Encoder for when Transfer-Encoding includes `chunked`. - Chunked(bool), - /// An Encoder for when Content-Length is set. - /// - /// Enforces that the body is not longer than the Content-Length header. - Length(u64), - /// An Encoder for when Content-Length is not known. - /// - /// Appliction decides when to stop writing. - Eof, -} - -impl Encoder { - - pub fn eof() -> Encoder { - Encoder { - kind: Kind::Eof, - } - } - - pub fn chunked() -> Encoder { - Encoder { - kind: Kind::Chunked(false), - } - } - - pub fn length(len: u64) -> Encoder { - Encoder { - kind: Kind::Length(len), - } - } - - /// Encode message. Return `EOF` state of encoder - pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool { - match self.kind { - Kind::Eof => { - dst.extend(msg); - msg.is_empty() - }, - Kind::Chunked(ref mut eof) => { - if *eof { - return true; - } - - if msg.is_empty() { - *eof = true; - dst.extend(b"0\r\n\r\n"); - } else { - write!(dst, "{:X}\r\n", msg.len()).unwrap(); - dst.extend(msg); - dst.extend(b"\r\n"); - } - *eof - }, - Kind::Length(ref mut remaining) => { - if msg.is_empty() { - return *remaining == 0 - } - let max = cmp::min(*remaining, msg.len() as u64); - trace!("sized write = {}", max); - dst.extend(msg[..max as usize].as_ref()); - - *remaining -= max as u64; - trace!("encoded {} bytes, remaining = {}", max, remaining); - *remaining == 0 - }, - } - } - - /// Encode eof. Return `EOF` state of encoder - pub fn encode_eof(&mut self, dst: &mut BytesMut) -> bool { - match self.kind { - Kind::Eof => true, - Kind::Chunked(ref mut eof) => { - if *eof { - return true; - } - - *eof = true; - dst.extend(b"0\r\n\r\n"); - true - }, - Kind::Length(ref mut remaining) => { - *remaining == 0 - }, - } - } -} diff --git a/src/h2writer.rs b/src/h2writer.rs index d97f0a542..a4adba227 100644 --- a/src/h2writer.rs +++ b/src/h2writer.rs @@ -1,14 +1,14 @@ use std::{io, cmp}; -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use futures::{Async, Poll}; use http2::{Reason, SendStream}; use http2::server::Respond; use http::{Version, HttpTryFrom, Response}; -use http::header::{HeaderValue, CONNECTION, CONTENT_TYPE, - CONTENT_LENGTH, TRANSFER_ENCODING, DATE}; +use http::header::{HeaderValue, CONNECTION, CONTENT_TYPE, TRANSFER_ENCODING, DATE}; use date; use body::Body; +use encoding::PayloadEncoder; use httprequest::HttpRequest; use httpresponse::HttpResponse; use h1writer::{Writer, WriterState}; @@ -20,9 +20,8 @@ const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k pub(crate) struct H2Writer { respond: Respond, stream: Option>, - buffer: BytesMut, started: bool, - encoder: Encoder, + encoder: PayloadEncoder, disconnected: bool, eof: bool, } @@ -33,9 +32,8 @@ impl H2Writer { H2Writer { respond: respond, stream: None, - buffer: BytesMut::new(), started: false, - encoder: Encoder::length(0), + encoder: PayloadEncoder::default(), disconnected: false, eof: true, } @@ -53,7 +51,9 @@ impl H2Writer { } if let Some(ref mut stream) = self.stream { - if self.buffer.is_empty() { + let buffer = self.encoder.get_mut(); + + if buffer.is_empty() { if self.eof { let _ = stream.send_data(Bytes::new(), true); } @@ -63,7 +63,7 @@ impl H2Writer { loop { match stream.poll_capacity() { Ok(Async::NotReady) => { - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + if buffer.len() > MAX_WRITE_BUFFER_SIZE { return Ok(WriterState::Pause) } else { return Ok(WriterState::Done) @@ -73,14 +73,14 @@ impl H2Writer { return Ok(WriterState::Done) } Ok(Async::Ready(Some(cap))) => { - let len = self.buffer.len(); - let bytes = self.buffer.split_to(cmp::min(cap, len)); - let eof = self.buffer.is_empty() && self.eof; + let len = buffer.len(); + let bytes = buffer.split_to(cmp::min(cap, len)); + let eof = buffer.is_empty() && self.eof; if let Err(err) = stream.send_data(bytes.freeze(), eof) { return Err(io::Error::new(io::ErrorKind::Other, err)) - } else if !self.buffer.is_empty() { - let cap = cmp::min(self.buffer.len(), CHUNK_SIZE); + } else if !buffer.is_empty() { + let cap = cmp::min(buffer.len(), CHUNK_SIZE); stream.reserve_capacity(cap); } else { return Ok(WriterState::Done) @@ -98,57 +98,26 @@ impl H2Writer { impl Writer for H2Writer { - fn start(&mut self, _: &mut HttpRequest, msg: &mut HttpResponse) + fn start(&mut self, req: &mut HttpRequest, msg: &mut HttpResponse) -> Result { trace!("Prepare message status={:?}", msg); // prepare response self.started = true; - let body = msg.replace_body(Body::Empty); + self.encoder = PayloadEncoder::new(req, msg); + self.eof = if let Body::Empty = *msg.body() { true } else { false }; // http2 specific msg.headers.remove(CONNECTION); msg.headers.remove(TRANSFER_ENCODING); - match body { - Body::Empty => { - if msg.chunked() { - error!("Chunked transfer is enabled but body is set to Empty"); - } - msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); - self.encoder = Encoder::length(0); - }, - Body::Length(n) => { - if msg.chunked() { - error!("Chunked transfer is enabled but body with specific length is specified"); - } - self.eof = false; - msg.headers.insert( - CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); - self.encoder = Encoder::length(n); - }, - Body::Binary(ref bytes) => { - self.eof = false; - msg.headers.insert( - CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); - self.encoder = Encoder::length(0); - } - _ => { - msg.headers.remove(CONTENT_LENGTH); - self.eof = false; - self.encoder = Encoder::eof(); - } - } - // using http::h1::date is quite a lot faster than generating // a unique Date header each time like req/s goes up about 10% if !msg.headers.contains_key(DATE) { - let mut bytes = BytesMut::with_capacity(29); - date::extend(&mut bytes); - msg.headers.insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); + let mut bytes = [0u8; 29]; + date::extend(&mut bytes[..]); + msg.headers.insert(DATE, HeaderValue::try_from(&bytes[..]).unwrap()); } // default content-type @@ -165,23 +134,22 @@ impl Writer for H2Writer { } match self.respond.send_response(resp, self.eof) { - Ok(stream) => { - self.stream = Some(stream); - } - Err(_) => { - return Err(io::Error::new(io::ErrorKind::Other, "err")) - } + Ok(stream) => + self.stream = Some(stream), + Err(_) => + return Err(io::Error::new(io::ErrorKind::Other, "err")), } - if let Body::Binary(ref bytes) = body { - self.eof = true; - self.buffer.extend_from_slice(bytes.as_ref()); - if let Some(ref mut stream) = self.stream { - stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE)); + if msg.body().is_binary() { + if let Body::Binary(bytes) = msg.replace_body(Body::Empty) { + self.eof = true; + self.encoder.write(bytes.as_ref())?; + if let Some(ref mut stream) = self.stream { + stream.reserve_capacity(cmp::min(self.encoder.len(), CHUNK_SIZE)); + } + return Ok(WriterState::Done) } - return Ok(WriterState::Done) } - msg.replace_body(body); Ok(WriterState::Done) } @@ -190,14 +158,14 @@ impl Writer for H2Writer { if !self.disconnected { if self.started { // TODO: add warning, write after EOF - self.encoder.encode(&mut self.buffer, payload); + self.encoder.write(payload)?; } else { // might be response for EXCEPT - self.buffer.extend_from_slice(payload) + self.encoder.get_mut().extend_from_slice(payload) } } - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + if self.encoder.len() > MAX_WRITE_BUFFER_SIZE { Ok(WriterState::Pause) } else { Ok(WriterState::Done) @@ -205,11 +173,13 @@ impl Writer for H2Writer { } fn write_eof(&mut self) -> Result { + self.encoder.write_eof()?; + self.eof = true; - if !self.encoder.encode_eof(&mut self.buffer) { + if !self.encoder.is_eof() { Err(io::Error::new(io::ErrorKind::Other, "Last payload item, but eof is not reached")) - } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + } else if self.encoder.len() > MAX_WRITE_BUFFER_SIZE { Ok(WriterState::Pause) } else { Ok(WriterState::Done) @@ -224,67 +194,3 @@ impl Writer for H2Writer { } } } - - -/// Encoders to handle different Transfer-Encodings. -#[derive(Debug, Clone)] -pub(crate) struct Encoder { - kind: Kind, -} - -#[derive(Debug, PartialEq, Clone)] -enum Kind { - /// An Encoder for when Content-Length is set. - /// - /// Enforces that the body is not longer than the Content-Length header. - Length(u64), - /// An Encoder for when Content-Length is not known. - /// - /// Appliction decides when to stop writing. - Eof, -} - -impl Encoder { - - pub fn eof() -> Encoder { - Encoder { - kind: Kind::Eof, - } - } - - pub fn length(len: u64) -> Encoder { - Encoder { - kind: Kind::Length(len), - } - } - - /// Encode message. Return `EOF` state of encoder - pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool { - match self.kind { - Kind::Eof => { - dst.extend(msg); - msg.is_empty() - }, - Kind::Length(ref mut remaining) => { - if msg.is_empty() { - return *remaining == 0 - } - let max = cmp::min(*remaining, msg.len() as u64); - trace!("sized write = {}", max); - dst.extend(msg[..max as usize].as_ref()); - - *remaining -= max as u64; - trace!("encoded {} bytes, remaining = {}", max, remaining); - *remaining == 0 - }, - } - } - - /// Encode eof. Return `EOF` state of encoder - pub fn encode_eof(&mut self, _dst: &mut BytesMut) -> bool { - match self.kind { - Kind::Eof => true, - Kind::Length(ref mut remaining) => *remaining == 0 - } - } -} diff --git a/tests/test_server.rs b/tests/test_server.rs index 085576897..b489fd8c4 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -47,7 +47,6 @@ fn test_serve_incoming() { let tcp = TcpListener::from_listener(tcp, &addr2, Arbiter::handle()).unwrap(); srv.serve_incoming::<_, ()>(tcp.incoming()).unwrap(); sys.run(); - }); assert!(reqwest::get(&format!("http://{}/", addr1))