1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 07:53:00 +01:00

refactor content encoder

This commit is contained in:
Nikolay Kim 2018-06-24 08:54:01 +06:00
parent 348491b18c
commit 45682c04a8
6 changed files with 278 additions and 163 deletions

View File

@ -21,7 +21,8 @@ use tokio_io::AsyncWrite;
use body::{Binary, Body}; use body::{Binary, Body};
use header::ContentEncoding; use header::ContentEncoding;
use server::encoding::{ContentEncoder, TransferEncoding}; use server::encoding::{ContentEncoder, Output, TransferEncoding};
use server::shared::SharedBytes;
use server::WriterState; use server::WriterState;
use client::ClientRequest; use client::ClientRequest;
@ -41,21 +42,18 @@ pub(crate) struct HttpClientWriter {
flags: Flags, flags: Flags,
written: u64, written: u64,
headers_size: u32, headers_size: u32,
buffer: Box<BytesMut>, buffer: Output,
buffer_capacity: usize, buffer_capacity: usize,
encoder: ContentEncoder,
} }
impl HttpClientWriter { impl HttpClientWriter {
pub fn new() -> HttpClientWriter { pub fn new() -> HttpClientWriter {
let encoder = ContentEncoder::Identity(TransferEncoding::eof());
HttpClientWriter { HttpClientWriter {
flags: Flags::empty(), flags: Flags::empty(),
written: 0, written: 0,
headers_size: 0, headers_size: 0,
buffer_capacity: 0, buffer_capacity: 0,
buffer: Box::new(BytesMut::new()), buffer: Output::Buffer(SharedBytes::empty()),
encoder,
} }
} }
@ -75,7 +73,7 @@ impl HttpClientWriter {
&mut self, stream: &mut T, &mut self, stream: &mut T,
) -> io::Result<WriterState> { ) -> io::Result<WriterState> {
while !self.buffer.is_empty() { while !self.buffer.is_empty() {
match stream.write(self.buffer.as_ref()) { match stream.write(self.buffer.as_ref().as_ref()) {
Ok(0) => { Ok(0) => {
self.disconnected(); self.disconnected();
return Ok(WriterState::Done); return Ok(WriterState::Done);
@ -113,16 +111,18 @@ impl HttpClientWriter {
pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> { pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> {
// prepare task // prepare task
self.flags.insert(Flags::STARTED); self.flags.insert(Flags::STARTED);
self.encoder = content_encoder(self.buffer.as_mut(), msg);
if msg.upgrade() { if msg.upgrade() {
self.flags.insert(Flags::UPGRADE); self.flags.insert(Flags::UPGRADE);
} }
// render message // render message
{ {
// output buffer
let buffer = self.buffer.get_mut();
// status line // status line
writeln!( writeln!(
Writer(&mut self.buffer), Writer(buffer),
"{} {} {:?}\r", "{} {} {:?}\r",
msg.method(), msg.method(),
msg.uri() msg.uri()
@ -134,53 +134,49 @@ impl HttpClientWriter {
// write headers // write headers
if let Body::Binary(ref bytes) = *msg.body() { if let Body::Binary(ref bytes) = *msg.body() {
self.buffer buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len());
.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len());
} else { } else {
self.buffer buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE);
.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE);
} }
for (key, value) in msg.headers() { for (key, value) in msg.headers() {
let v = value.as_ref(); let v = value.as_ref();
let k = key.as_str().as_bytes(); let k = key.as_str().as_bytes();
self.buffer.reserve(k.len() + v.len() + 4); buffer.reserve(k.len() + v.len() + 4);
self.buffer.put_slice(k); buffer.put_slice(k);
self.buffer.put_slice(b": "); buffer.put_slice(b": ");
self.buffer.put_slice(v); buffer.put_slice(v);
self.buffer.put_slice(b"\r\n"); buffer.put_slice(b"\r\n");
} }
// set date header // set date header
if !msg.headers().contains_key(DATE) { if !msg.headers().contains_key(DATE) {
self.buffer.extend_from_slice(b"date: "); buffer.extend_from_slice(b"date: ");
set_date(&mut self.buffer); set_date(buffer);
self.buffer.extend_from_slice(b"\r\n\r\n"); buffer.extend_from_slice(b"\r\n\r\n");
} else { } else {
self.buffer.extend_from_slice(b"\r\n"); buffer.extend_from_slice(b"\r\n");
}
} }
self.headers_size = self.buffer.len() as u32; self.headers_size = self.buffer.len() as u32;
self.buffer = content_encoder(self.buffer.take(), msg);
if msg.body().is_binary() { if msg.body().is_binary() {
if let Body::Binary(bytes) = msg.replace_body(Body::Empty) { if let Body::Binary(bytes) = msg.replace_body(Body::Empty) {
self.written += bytes.len() as u64; self.written += bytes.len() as u64;
self.encoder.write(bytes.as_ref())?; self.buffer.write(bytes.as_ref())?;
} }
} else { } else {
self.buffer_capacity = msg.write_buffer_capacity(); self.buffer_capacity = msg.write_buffer_capacity();
} }
}
Ok(()) Ok(())
} }
pub fn write(&mut self, payload: &[u8]) -> io::Result<WriterState> { pub fn write(&mut self, payload: &[u8]) -> io::Result<WriterState> {
self.written += payload.len() as u64; self.written += payload.len() as u64;
if !self.flags.contains(Flags::DISCONNECTED) { if !self.flags.contains(Flags::DISCONNECTED) {
if self.flags.contains(Flags::UPGRADE) { self.buffer.write(payload)?;
self.buffer.extend(payload);
} else {
self.encoder.write(payload)?;
}
} }
if self.buffer.len() > self.buffer_capacity { if self.buffer.len() > self.buffer_capacity {
@ -191,9 +187,7 @@ impl HttpClientWriter {
} }
pub fn write_eof(&mut self) -> io::Result<()> { pub fn write_eof(&mut self) -> io::Result<()> {
self.encoder.write_eof()?; if self.buffer.write_eof()? {
if self.encoder.is_eof() {
Ok(()) Ok(())
} else { } else {
Err(io::Error::new( Err(io::Error::new(
@ -221,21 +215,20 @@ impl HttpClientWriter {
} }
} }
fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncoder { fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> Output {
let version = req.version(); let version = req.version();
let mut body = req.replace_body(Body::Empty); let mut body = req.replace_body(Body::Empty);
let mut encoding = req.content_encoding(); let mut encoding = req.content_encoding();
let mut transfer = match body { let transfer = match body {
Body::Empty => { Body::Empty => {
req.headers_mut().remove(CONTENT_LENGTH); req.headers_mut().remove(CONTENT_LENGTH);
TransferEncoding::length(0) TransferEncoding::length(0, buf)
} }
Body::Binary(ref mut bytes) => { Body::Binary(ref mut bytes) => {
if encoding.is_compression() { if encoding.is_compression() {
let mut tmp = BytesMut::new(); let mut tmp = SharedBytes::empty();
let mut transfer = TransferEncoding::eof(); let mut transfer = TransferEncoding::eof(tmp);
transfer.set_buffer(&mut tmp);
let mut enc = match encoding { let mut enc = match encoding {
#[cfg(feature = "flate2")] #[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate( ContentEncoding::Deflate => ContentEncoder::Deflate(
@ -256,7 +249,7 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
// TODO return error! // TODO return error!
let _ = enc.write(bytes.as_ref()); let _ = enc.write(bytes.as_ref());
let _ = enc.write_eof(); let _ = enc.write_eof();
*bytes = Binary::from(tmp.take()); *bytes = Binary::from(enc.buf_mut().take());
req.headers_mut().insert( req.headers_mut().insert(
CONTENT_ENCODING, CONTENT_ENCODING,
@ -268,7 +261,7 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
let _ = write!(b, "{}", bytes.len()); let _ = write!(b, "{}", bytes.len());
req.headers_mut() req.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); .insert(CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap());
TransferEncoding::eof() TransferEncoding::eof(buf)
} }
Body::Streaming(_) | Body::Actor(_) => { Body::Streaming(_) | Body::Actor(_) => {
if req.upgrade() { if req.upgrade() {
@ -282,9 +275,9 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
encoding = ContentEncoding::Identity; encoding = ContentEncoding::Identity;
req.headers_mut().remove(CONTENT_ENCODING); req.headers_mut().remove(CONTENT_ENCODING);
} }
TransferEncoding::eof() TransferEncoding::eof(buf)
} else { } else {
streaming_encoding(version, req) streaming_encoding(buf, version, req)
} }
} }
}; };
@ -295,10 +288,9 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
HeaderValue::from_static(encoding.as_str()), HeaderValue::from_static(encoding.as_str()),
); );
} }
transfer.set_buffer(buf);
req.replace_body(body); req.replace_body(body);
match encoding { let enc = match encoding {
#[cfg(feature = "flate2")] #[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new( ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
transfer, transfer,
@ -310,23 +302,24 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
} }
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)), ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)),
ContentEncoding::Identity | ContentEncoding::Auto => { ContentEncoding::Identity | ContentEncoding::Auto => return Output::TE(transfer),
ContentEncoder::Identity(transfer) };
} Output::Encoder(enc)
}
} }
fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEncoding { fn streaming_encoding(
buf: SharedBytes, version: Version, req: &mut ClientRequest,
) -> TransferEncoding {
if req.chunked() { if req.chunked() {
// Enable transfer encoding // Enable transfer encoding
req.headers_mut().remove(CONTENT_LENGTH); req.headers_mut().remove(CONTENT_LENGTH);
if version == Version::HTTP_2 { if version == Version::HTTP_2 {
req.headers_mut().remove(TRANSFER_ENCODING); req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof() TransferEncoding::eof(buf)
} else { } else {
req.headers_mut() req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked() TransferEncoding::chunked(buf)
} }
} else { } else {
// if Content-Length is specified, then use it as length hint // if Content-Length is specified, then use it as length hint
@ -349,9 +342,9 @@ fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEnco
if !chunked { if !chunked {
if let Some(len) = len { if let Some(len) = len {
TransferEncoding::length(len) TransferEncoding::length(len, buf)
} else { } else {
TransferEncoding::eof() TransferEncoding::eof(buf)
} }
} else { } else {
// Enable transfer encoding // Enable transfer encoding
@ -359,11 +352,11 @@ fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEnco
Version::HTTP_11 => { Version::HTTP_11 => {
req.headers_mut() req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked() TransferEncoding::chunked(buf)
} }
_ => { _ => {
req.headers_mut().remove(TRANSFER_ENCODING); req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof() TransferEncoding::eof(buf)
} }
} }
} }

View File

@ -1127,6 +1127,7 @@ mod tests {
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
println!("RESP: {:?}", response);
let te = response let te = response
.headers() .headers()
.get(header::TRANSFER_ENCODING) .get(header::TRANSFER_ENCODING)

View File

@ -1,7 +1,7 @@
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::str::FromStr; use std::str::FromStr;
use std::{cmp, io, mem, ptr}; use std::{cmp, io, mem};
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
use brotli2::write::{BrotliDecoder, BrotliEncoder}; use brotli2::write::{BrotliDecoder, BrotliEncoder};
@ -25,6 +25,8 @@ use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use payload::{PayloadSender, PayloadStatus, PayloadWriter}; use payload::{PayloadSender, PayloadStatus, PayloadWriter};
use super::shared::SharedBytes;
pub(crate) enum PayloadType { pub(crate) enum PayloadType {
Sender(PayloadSender), Sender(PayloadSender),
Encoding(Box<EncodedPayload>), Encoding(Box<EncodedPayload>),
@ -368,6 +370,83 @@ impl PayloadStream {
} }
} }
pub(crate) enum Output {
Buffer(SharedBytes),
Encoder(ContentEncoder),
TE(TransferEncoding),
Empty,
}
impl Output {
pub fn take(&mut self) -> SharedBytes {
match mem::replace(self, Output::Empty) {
Output::Buffer(bytes) => bytes,
_ => panic!(),
}
}
pub fn as_ref(&mut self) -> &SharedBytes {
match self {
Output::Buffer(ref mut bytes) => bytes,
Output::Encoder(ref mut enc) => enc.buf_ref(),
Output::TE(ref mut te) => te.buf_ref(),
Output::Empty => panic!(),
}
}
pub fn get_mut(&mut self) -> &mut BytesMut {
match self {
Output::Buffer(ref mut bytes) => bytes.get_mut(),
_ => panic!(),
}
}
pub fn split_to(&mut self, cap: usize) -> BytesMut {
match self {
Output::Buffer(ref mut bytes) => bytes.split_to(cap),
Output::Encoder(ref mut enc) => enc.buf_mut().split_to(cap),
Output::TE(ref mut te) => te.buf_mut().split_to(cap),
Output::Empty => BytesMut::new(),
}
}
pub fn len(&self) -> usize {
match self {
Output::Buffer(ref bytes) => bytes.len(),
Output::Encoder(ref enc) => enc.len(),
Output::TE(ref te) => te.len(),
Output::Empty => 0,
}
}
pub fn is_empty(&self) -> bool {
match self {
Output::Buffer(ref bytes) => bytes.is_empty(),
Output::Encoder(ref enc) => enc.is_empty(),
Output::TE(ref te) => te.is_empty(),
Output::Empty => true,
}
}
pub fn write(&mut self, data: &[u8]) -> Result<(), io::Error> {
match self {
Output::Buffer(ref mut bytes) => {
bytes.extend_from_slice(data);
Ok(())
}
Output::Encoder(ref mut enc) => enc.write(data),
Output::TE(ref mut te) => te.encode(data).map(|_| ()),
Output::Empty => Ok(()),
}
}
pub fn write_eof(&mut self) -> Result<bool, io::Error> {
match self {
Output::Buffer(_) => Ok(true),
Output::Encoder(ref mut enc) => enc.write_eof(),
Output::TE(ref mut te) => Ok(te.encode_eof()),
Output::Empty => Ok(true),
}
}
}
pub(crate) enum ContentEncoder { pub(crate) enum ContentEncoder {
#[cfg(feature = "flate2")] #[cfg(feature = "flate2")]
Deflate(DeflateEncoder<TransferEncoding>), Deflate(DeflateEncoder<TransferEncoding>),
@ -379,14 +458,10 @@ pub(crate) enum ContentEncoder {
} }
impl ContentEncoder { impl ContentEncoder {
pub fn empty() -> ContentEncoder {
ContentEncoder::Identity(TransferEncoding::eof())
}
pub fn for_server( pub fn for_server(
buf: &mut BytesMut, req: &HttpInnerMessage, resp: &mut HttpResponse, buf: SharedBytes, req: &HttpInnerMessage, resp: &mut HttpResponse,
response_encoding: ContentEncoding, response_encoding: ContentEncoding,
) -> ContentEncoder { ) -> Output {
let version = resp.version().unwrap_or_else(|| req.version); let version = resp.version().unwrap_or_else(|| req.version);
let is_head = req.method == Method::HEAD; let is_head = req.method == Method::HEAD;
let mut len = 0; let mut len = 0;
@ -439,7 +514,7 @@ impl ContentEncoder {
if req.method != Method::HEAD { if req.method != Method::HEAD {
resp.headers_mut().remove(CONTENT_LENGTH); resp.headers_mut().remove(CONTENT_LENGTH);
} }
TransferEncoding::length(0) TransferEncoding::length(0, buf)
} }
&Body::Binary(_) => { &Body::Binary(_) => {
#[cfg(any(feature = "brotli", feature = "flate2"))] #[cfg(any(feature = "brotli", feature = "flate2"))]
@ -447,9 +522,8 @@ impl ContentEncoder {
if !(encoding == ContentEncoding::Identity if !(encoding == ContentEncoding::Identity
|| encoding == ContentEncoding::Auto) || encoding == ContentEncoding::Auto)
{ {
let mut tmp = BytesMut::default(); let mut tmp = SharedBytes::empty();
let mut transfer = TransferEncoding::eof(); let mut transfer = TransferEncoding::eof(tmp);
transfer.set_buffer(&mut tmp);
let mut enc = match encoding { let mut enc = match encoding {
#[cfg(feature = "flate2")] #[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate( ContentEncoding::Deflate => ContentEncoder::Deflate(
@ -473,7 +547,7 @@ impl ContentEncoder {
// TODO return error! // TODO return error!
let _ = enc.write(bin.as_ref()); let _ = enc.write(bin.as_ref());
let _ = enc.write_eof(); let _ = enc.write_eof();
let body = tmp.take(); let body = enc.buf_mut().take();
len = body.len(); len = body.len();
encoding = ContentEncoding::Identity; encoding = ContentEncoding::Identity;
@ -491,7 +565,7 @@ impl ContentEncoder {
} else { } else {
// resp.headers_mut().remove(CONTENT_LENGTH); // resp.headers_mut().remove(CONTENT_LENGTH);
} }
TransferEncoding::eof() TransferEncoding::eof(buf)
} }
&Body::Streaming(_) | &Body::Actor(_) => { &Body::Streaming(_) | &Body::Actor(_) => {
if resp.upgrade() { if resp.upgrade() {
@ -502,14 +576,14 @@ impl ContentEncoder {
encoding = ContentEncoding::Identity; encoding = ContentEncoding::Identity;
resp.headers_mut().remove(CONTENT_ENCODING); resp.headers_mut().remove(CONTENT_ENCODING);
} }
TransferEncoding::eof() TransferEncoding::eof(buf)
} else { } else {
if !(encoding == ContentEncoding::Identity if !(encoding == ContentEncoding::Identity
|| encoding == ContentEncoding::Auto) || encoding == ContentEncoding::Auto)
{ {
resp.headers_mut().remove(CONTENT_LENGTH); resp.headers_mut().remove(CONTENT_LENGTH);
} }
ContentEncoder::streaming_encoding(version, resp) ContentEncoder::streaming_encoding(buf, version, resp)
} }
} }
}; };
@ -518,9 +592,8 @@ impl ContentEncoder {
resp.set_body(Body::Empty); resp.set_body(Body::Empty);
transfer.kind = TransferEncodingKind::Length(0); transfer.kind = TransferEncodingKind::Length(0);
} }
transfer.set_buffer(buf);
match encoding { let enc = match encoding {
#[cfg(feature = "flate2")] #[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new( ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
transfer, transfer,
@ -533,13 +606,14 @@ impl ContentEncoder {
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 3)), ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 3)),
ContentEncoding::Identity | ContentEncoding::Auto => { ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer) return Output::TE(transfer)
}
} }
};
Output::Encoder(enc)
} }
fn streaming_encoding( fn streaming_encoding(
version: Version, resp: &mut HttpResponse, buf: SharedBytes, version: Version, resp: &mut HttpResponse,
) -> TransferEncoding { ) -> TransferEncoding {
match resp.chunked() { match resp.chunked() {
Some(true) => { Some(true) => {
@ -547,14 +621,14 @@ impl ContentEncoder {
resp.headers_mut().remove(CONTENT_LENGTH); resp.headers_mut().remove(CONTENT_LENGTH);
if version == Version::HTTP_2 { if version == Version::HTTP_2 {
resp.headers_mut().remove(TRANSFER_ENCODING); resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof() TransferEncoding::eof(buf)
} else { } else {
resp.headers_mut() resp.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked() TransferEncoding::chunked(buf)
} }
} }
Some(false) => TransferEncoding::eof(), Some(false) => TransferEncoding::eof(buf),
None => { None => {
// if Content-Length is specified, then use it as length hint // if Content-Length is specified, then use it as length hint
let (len, chunked) = let (len, chunked) =
@ -577,9 +651,9 @@ impl ContentEncoder {
if !chunked { if !chunked {
if let Some(len) = len { if let Some(len) = len {
TransferEncoding::length(len) TransferEncoding::length(len, buf)
} else { } else {
TransferEncoding::eof() TransferEncoding::eof(buf)
} }
} else { } else {
// Enable transfer encoding // Enable transfer encoding
@ -589,11 +663,11 @@ impl ContentEncoder {
TRANSFER_ENCODING, TRANSFER_ENCODING,
HeaderValue::from_static("chunked"), HeaderValue::from_static("chunked"),
); );
TransferEncoding::chunked() TransferEncoding::chunked(buf)
} }
_ => { _ => {
resp.headers_mut().remove(TRANSFER_ENCODING); resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof() TransferEncoding::eof(buf)
} }
} }
} }
@ -604,15 +678,54 @@ impl ContentEncoder {
impl ContentEncoder { impl ContentEncoder {
#[inline] #[inline]
pub fn is_eof(&self) -> bool { pub fn len(&self) -> usize {
match *self { match *self {
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
ContentEncoder::Br(ref encoder) => encoder.get_ref().is_eof(), ContentEncoder::Br(ref encoder) => encoder.get_ref().len(),
#[cfg(feature = "flate2")] #[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_eof(), ContentEncoder::Deflate(ref encoder) => encoder.get_ref().len(),
#[cfg(feature = "flate2")] #[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_eof(), ContentEncoder::Gzip(ref encoder) => encoder.get_ref().len(),
ContentEncoder::Identity(ref encoder) => encoder.is_eof(), ContentEncoder::Identity(ref encoder) => encoder.len(),
}
}
#[inline]
pub fn is_empty(&self) -> bool {
match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref encoder) => encoder.get_ref().is_empty(),
#[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_empty(),
#[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_empty(),
ContentEncoder::Identity(ref encoder) => encoder.is_empty(),
}
}
#[inline]
pub(crate) fn buf_mut(&mut self) -> &mut BytesMut {
match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_mut(),
#[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_mut(),
#[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_mut(),
ContentEncoder::Identity(ref mut encoder) => encoder.buf_mut(),
}
}
#[inline]
pub(crate) fn buf_ref(&mut self) -> &SharedBytes {
match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_ref(),
#[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_ref(),
#[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_ref(),
ContentEncoder::Identity(ref mut encoder) => encoder.buf_ref(),
} }
} }
@ -620,7 +733,7 @@ impl ContentEncoder {
#[inline(always)] #[inline(always)]
pub fn write_eof(&mut self) -> Result<bool, io::Error> { pub fn write_eof(&mut self) -> Result<bool, io::Error> {
let encoder = let encoder =
mem::replace(self, ContentEncoder::Identity(TransferEncoding::eof())); mem::replace(self, ContentEncoder::Identity(TransferEncoding::empty()));
match encoder { match encoder {
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
@ -695,9 +808,9 @@ impl ContentEncoder {
} }
/// Encoders to handle different Transfer-Encodings. /// Encoders to handle different Transfer-Encodings.
#[derive(Debug, Clone)] #[derive(Debug)]
pub(crate) struct TransferEncoding { pub(crate) struct TransferEncoding {
buf: *mut BytesMut, buf: Option<SharedBytes>,
kind: TransferEncodingKind, kind: TransferEncodingKind,
} }
@ -716,40 +829,55 @@ enum TransferEncodingKind {
} }
impl TransferEncoding { impl TransferEncoding {
pub(crate) fn set_buffer(&mut self, buf: *mut BytesMut) { fn take(self) -> SharedBytes {
self.buf = buf; self.buf.unwrap()
}
fn buf_ref(&mut self) -> &SharedBytes {
self.buf.as_ref().unwrap()
}
fn len(&self) -> usize {
self.buf.as_ref().unwrap().len()
}
fn is_empty(&self) -> bool {
self.buf.as_ref().unwrap().is_empty()
}
fn buf_mut(&mut self) -> &mut BytesMut {
self.buf.as_mut().unwrap().get_mut()
} }
#[inline] #[inline]
pub fn eof() -> TransferEncoding { pub fn empty() -> TransferEncoding {
TransferEncoding { TransferEncoding {
buf: None,
kind: TransferEncodingKind::Eof, kind: TransferEncodingKind::Eof,
buf: ptr::null_mut(),
} }
} }
#[inline] #[inline]
pub fn chunked() -> TransferEncoding { pub fn eof(buf: SharedBytes) -> TransferEncoding {
TransferEncoding { TransferEncoding {
buf: Some(buf),
kind: TransferEncodingKind::Eof,
}
}
#[inline]
pub fn chunked(buf: SharedBytes) -> TransferEncoding {
TransferEncoding {
buf: Some(buf),
kind: TransferEncodingKind::Chunked(false), kind: TransferEncodingKind::Chunked(false),
buf: ptr::null_mut(),
} }
} }
#[inline] #[inline]
pub fn length(len: u64) -> TransferEncoding { pub fn length(len: u64, buf: SharedBytes) -> TransferEncoding {
TransferEncoding { TransferEncoding {
buf: Some(buf),
kind: TransferEncodingKind::Length(len), kind: TransferEncodingKind::Length(len),
buf: ptr::null_mut(),
}
}
#[inline]
pub fn is_eof(&self) -> bool {
match self.kind {
TransferEncodingKind::Eof => true,
TransferEncodingKind::Chunked(ref eof) => *eof,
TransferEncodingKind::Length(ref remaining) => *remaining == 0,
} }
} }
@ -759,9 +887,7 @@ impl TransferEncoding {
match self.kind { match self.kind {
TransferEncodingKind::Eof => { TransferEncodingKind::Eof => {
let eof = msg.is_empty(); let eof = msg.is_empty();
debug_assert!(!self.buf.is_null()); self.buf.as_mut().unwrap().extend_from_slice(msg);
let buf = unsafe { &mut *self.buf };
buf.extend(msg);
Ok(eof) Ok(eof)
} }
TransferEncodingKind::Chunked(ref mut eof) => { TransferEncodingKind::Chunked(ref mut eof) => {
@ -771,16 +897,13 @@ impl TransferEncoding {
if msg.is_empty() { if msg.is_empty() {
*eof = true; *eof = true;
debug_assert!(!self.buf.is_null()); self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n");
let buf = unsafe { &mut *self.buf };
buf.extend_from_slice(b"0\r\n\r\n");
} else { } else {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
writeln!(&mut buf, "{:X}\r", msg.len()) writeln!(&mut buf, "{:X}\r", msg.len())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
debug_assert!(!self.buf.is_null()); let b = self.buf.as_mut().unwrap();
let b = unsafe { &mut *self.buf };
b.reserve(buf.len() + msg.len() + 2); b.reserve(buf.len() + msg.len() + 2);
b.extend_from_slice(buf.as_ref()); b.extend_from_slice(buf.as_ref());
b.extend_from_slice(msg); b.extend_from_slice(msg);
@ -795,8 +918,10 @@ impl TransferEncoding {
} }
let len = cmp::min(*remaining, msg.len() as u64); let len = cmp::min(*remaining, msg.len() as u64);
debug_assert!(!self.buf.is_null()); self.buf
unsafe { &mut *self.buf }.extend(&msg[..len as usize]); .as_mut()
.unwrap()
.extend_from_slice(&msg[..len as usize]);
*remaining -= len as u64; *remaining -= len as u64;
Ok(*remaining == 0) Ok(*remaining == 0)
@ -816,10 +941,7 @@ impl TransferEncoding {
TransferEncodingKind::Chunked(ref mut eof) => { TransferEncodingKind::Chunked(ref mut eof) => {
if !*eof { if !*eof {
*eof = true; *eof = true;
self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n");
debug_assert!(!self.buf.is_null());
let buf = unsafe { &mut *self.buf };
buf.extend_from_slice(b"0\r\n\r\n");
} }
true true
} }
@ -912,15 +1034,14 @@ mod tests {
#[test] #[test]
fn test_chunked_te() { fn test_chunked_te() {
let mut bytes = BytesMut::new(); let bytes = SharedBytes::empty();
let mut enc = TransferEncoding::chunked(); let mut enc = TransferEncoding::chunked(bytes);
{ {
enc.set_buffer(&mut bytes);
assert!(!enc.encode(b"test").ok().unwrap()); assert!(!enc.encode(b"test").ok().unwrap());
assert!(enc.encode(b"").ok().unwrap()); assert!(enc.encode(b"").ok().unwrap());
} }
assert_eq!( assert_eq!(
bytes.take().freeze(), enc.buf_mut().take().freeze(),
Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n") Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
); );
} }

View File

@ -6,7 +6,7 @@ use std::io;
use std::rc::Rc; use std::rc::Rc;
use tokio_io::AsyncWrite; use tokio_io::AsyncWrite;
use super::encoding::ContentEncoder; use super::encoding::{ContentEncoder, Output};
use super::helpers; use super::helpers;
use super::settings::WorkerSettings; use super::settings::WorkerSettings;
use super::shared::SharedBytes; use super::shared::SharedBytes;
@ -32,10 +32,9 @@ bitflags! {
pub(crate) struct H1Writer<T: AsyncWrite, H: 'static> { pub(crate) struct H1Writer<T: AsyncWrite, H: 'static> {
flags: Flags, flags: Flags,
stream: T, stream: T,
encoder: ContentEncoder,
written: u64, written: u64,
headers_size: u32, headers_size: u32,
buffer: SharedBytes, buffer: Output,
buffer_capacity: usize, buffer_capacity: usize,
settings: Rc<WorkerSettings<H>>, settings: Rc<WorkerSettings<H>>,
} }
@ -46,10 +45,9 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
) -> H1Writer<T, H> { ) -> H1Writer<T, H> {
H1Writer { H1Writer {
flags: Flags::KEEPALIVE, flags: Flags::KEEPALIVE,
encoder: ContentEncoder::empty(),
written: 0, written: 0,
headers_size: 0, headers_size: 0,
buffer: buf, buffer: Output::Buffer(buf),
buffer_capacity: 0, buffer_capacity: 0,
stream, stream,
settings, settings,
@ -66,7 +64,7 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
} }
pub fn disconnected(&mut self) { pub fn disconnected(&mut self) {
self.buffer.take(); self.buffer = Output::Empty;
} }
pub fn keepalive(&self) -> bool { pub fn keepalive(&self) -> bool {
@ -106,7 +104,8 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
#[inline] #[inline]
fn buffer(&mut self) -> &mut BytesMut { fn buffer(&mut self) -> &mut BytesMut {
self.buffer.get_mut() //self.buffer.get_mut()
unimplemented!()
} }
fn start( fn start(
@ -114,8 +113,6 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
encoding: ContentEncoding, encoding: ContentEncoding,
) -> io::Result<WriterState> { ) -> io::Result<WriterState> {
// prepare task // prepare task
self.encoder =
ContentEncoder::for_server(self.buffer.get_mut(), req, msg, encoding);
if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) {
self.flags = Flags::STARTED | Flags::KEEPALIVE; self.flags = Flags::STARTED | Flags::KEEPALIVE;
} else { } else {
@ -143,7 +140,9 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
// render message // render message
{ {
// output buffer
let mut buffer = self.buffer.get_mut(); let mut buffer = self.buffer.get_mut();
let reason = msg.reason().as_bytes(); let reason = msg.reason().as_bytes();
let mut is_bin = if let Body::Binary(ref bytes) = body { let mut is_bin = if let Body::Binary(ref bytes) = body {
buffer.reserve( buffer.reserve(
@ -222,9 +221,12 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
self.headers_size = buffer.len() as u32; self.headers_size = buffer.len() as u32;
} }
// output encoding
self.buffer = ContentEncoder::for_server(self.buffer.take(), req, msg, encoding);
if let Body::Binary(bytes) = body { if let Body::Binary(bytes) = body {
self.written = bytes.len() as u64; self.written = bytes.len() as u64;
self.encoder.write(bytes.as_ref())?; self.buffer.write(bytes.as_ref())?;
} else { } else {
// capacity, makes sense only for streaming or actor // capacity, makes sense only for streaming or actor
self.buffer_capacity = msg.write_buffer_capacity(); self.buffer_capacity = msg.write_buffer_capacity();
@ -253,19 +255,19 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
Ok(val) => val, Ok(val) => val,
}; };
if n < pl.len() { if n < pl.len() {
self.buffer.extend_from_slice(&pl[n..]); self.buffer.write(&pl[n..]);
return Ok(WriterState::Done); return Ok(WriterState::Done);
} }
} else { } else {
self.buffer.extend(payload); self.buffer.write(payload.as_ref());
} }
} else { } else {
// TODO: add warning, write after EOF // TODO: add warning, write after EOF
self.encoder.write(payload.as_ref())?; self.buffer.write(payload.as_ref())?;
} }
} else { } else {
// could be response to EXCEPT header // could be response to EXCEPT header
self.buffer.extend_from_slice(payload.as_ref()) self.buffer.write(payload.as_ref());
} }
} }
@ -277,7 +279,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
} }
fn write_eof(&mut self) -> io::Result<WriterState> { fn write_eof(&mut self) -> io::Result<WriterState> {
if !self.encoder.write_eof()? { if !self.buffer.write_eof()? {
Err(io::Error::new( Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"Last payload item, but eof is not reached", "Last payload item, but eof is not reached",
@ -293,7 +295,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
if !self.buffer.is_empty() { if !self.buffer.is_empty() {
let written = { let written = {
match Self::write_data(&mut self.stream, self.buffer.as_ref()) { match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) {
Err(err) => { Err(err) => {
if err.kind() == io::ErrorKind::WriteZero { if err.kind() == io::ErrorKind::WriteZero {
self.disconnected(); self.disconnected();

View File

@ -11,7 +11,7 @@ use std::{cmp, io};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{HttpTryFrom, Version}; use http::{HttpTryFrom, Version};
use super::encoding::ContentEncoder; use super::encoding::{ContentEncoder, Output};
use super::helpers; use super::helpers;
use super::settings::WorkerSettings; use super::settings::WorkerSettings;
use super::shared::SharedBytes; use super::shared::SharedBytes;
@ -35,10 +35,9 @@ bitflags! {
pub(crate) struct H2Writer<H: 'static> { pub(crate) struct H2Writer<H: 'static> {
respond: SendResponse<Bytes>, respond: SendResponse<Bytes>,
stream: Option<SendStream<Bytes>>, stream: Option<SendStream<Bytes>>,
encoder: ContentEncoder,
flags: Flags, flags: Flags,
written: u64, written: u64,
buffer: SharedBytes, buffer: Output,
buffer_capacity: usize, buffer_capacity: usize,
settings: Rc<WorkerSettings<H>>, settings: Rc<WorkerSettings<H>>,
} }
@ -51,10 +50,9 @@ impl<H: 'static> H2Writer<H> {
respond, respond,
settings, settings,
stream: None, stream: None,
encoder: ContentEncoder::empty(),
flags: Flags::empty(), flags: Flags::empty(),
written: 0, written: 0,
buffer: buf, buffer: Output::Buffer(buf),
buffer_capacity: 0, buffer_capacity: 0,
} }
} }
@ -87,8 +85,7 @@ impl<H: 'static> Writer for H2Writer<H> {
) -> io::Result<WriterState> { ) -> io::Result<WriterState> {
// prepare response // prepare response
self.flags.insert(Flags::STARTED); self.flags.insert(Flags::STARTED);
self.encoder = self.buffer = ContentEncoder::for_server(self.buffer.take(), req, msg, encoding);
ContentEncoder::for_server(self.buffer.get_mut(), req, msg, encoding);
// http2 specific // http2 specific
msg.headers_mut().remove(CONNECTION); msg.headers_mut().remove(CONNECTION);
@ -150,7 +147,7 @@ impl<H: 'static> Writer for H2Writer<H> {
} else { } else {
self.flags.insert(Flags::EOF); self.flags.insert(Flags::EOF);
self.written = bytes.len() as u64; self.written = bytes.len() as u64;
self.encoder.write(bytes.as_ref())?; self.buffer.write(bytes.as_ref())?;
if let Some(ref mut stream) = self.stream { if let Some(ref mut stream) = self.stream {
self.flags.insert(Flags::RESERVED); self.flags.insert(Flags::RESERVED);
stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE)); stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE));
@ -170,10 +167,10 @@ impl<H: 'static> Writer for H2Writer<H> {
if !self.flags.contains(Flags::DISCONNECTED) { if !self.flags.contains(Flags::DISCONNECTED) {
if self.flags.contains(Flags::STARTED) { if self.flags.contains(Flags::STARTED) {
// TODO: add warning, write after EOF // TODO: add warning, write after EOF
self.encoder.write(payload.as_ref())?; self.buffer.write(payload.as_ref())?;
} else { } else {
// might be response for EXCEPT // might be response for EXCEPT
self.buffer.extend_from_slice(payload.as_ref()) error!("Not supported");
} }
} }
@ -186,7 +183,7 @@ impl<H: 'static> Writer for H2Writer<H> {
fn write_eof(&mut self) -> io::Result<WriterState> { fn write_eof(&mut self) -> io::Result<WriterState> {
self.flags.insert(Flags::EOF); self.flags.insert(Flags::EOF);
if !self.encoder.write_eof()? { if !self.buffer.write_eof()? {
Err(io::Error::new( Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"Last payload item, but eof is not reached", "Last payload item, but eof is not reached",

View File

@ -5,8 +5,6 @@ use std::rc::Rc;
use bytes::BytesMut; use bytes::BytesMut;
use body::Binary;
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct SharedBytesPool(RefCell<VecDeque<BytesMut>>); pub(crate) struct SharedBytesPool(RefCell<VecDeque<BytesMut>>);
@ -50,6 +48,10 @@ impl SharedBytes {
SharedBytes(Some(bytes), Some(pool)) SharedBytes(Some(bytes), Some(pool))
} }
pub fn empty() -> SharedBytes {
SharedBytes(Some(BytesMut::new()), None)
}
#[inline] #[inline]
pub(crate) fn get_mut(&mut self) -> &mut BytesMut { pub(crate) fn get_mut(&mut self) -> &mut BytesMut {
self.0.as_mut().unwrap() self.0.as_mut().unwrap()
@ -79,9 +81,8 @@ impl SharedBytes {
} }
#[inline] #[inline]
pub fn extend(&mut self, data: &Binary) { pub fn reserve(&mut self, cap: usize) {
let buf = self.get_mut(); self.get_mut().reserve(cap);
buf.extend_from_slice(data.as_ref());
} }
#[inline] #[inline]