1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-23 15:24:36 +01:00

refactor content encoder

This commit is contained in:
Nikolay Kim 2018-06-24 08:54:01 +06:00
parent 348491b18c
commit 45682c04a8
6 changed files with 278 additions and 163 deletions

View File

@ -21,7 +21,8 @@ use tokio_io::AsyncWrite;
use body::{Binary, Body};
use header::ContentEncoding;
use server::encoding::{ContentEncoder, TransferEncoding};
use server::encoding::{ContentEncoder, Output, TransferEncoding};
use server::shared::SharedBytes;
use server::WriterState;
use client::ClientRequest;
@ -41,21 +42,18 @@ pub(crate) struct HttpClientWriter {
flags: Flags,
written: u64,
headers_size: u32,
buffer: Box<BytesMut>,
buffer: Output,
buffer_capacity: usize,
encoder: ContentEncoder,
}
impl HttpClientWriter {
pub fn new() -> HttpClientWriter {
let encoder = ContentEncoder::Identity(TransferEncoding::eof());
HttpClientWriter {
flags: Flags::empty(),
written: 0,
headers_size: 0,
buffer_capacity: 0,
buffer: Box::new(BytesMut::new()),
encoder,
buffer: Output::Buffer(SharedBytes::empty()),
}
}
@ -75,7 +73,7 @@ impl HttpClientWriter {
&mut self, stream: &mut T,
) -> io::Result<WriterState> {
while !self.buffer.is_empty() {
match stream.write(self.buffer.as_ref()) {
match stream.write(self.buffer.as_ref().as_ref()) {
Ok(0) => {
self.disconnected();
return Ok(WriterState::Done);
@ -113,16 +111,18 @@ impl HttpClientWriter {
pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> {
// prepare task
self.flags.insert(Flags::STARTED);
self.encoder = content_encoder(self.buffer.as_mut(), msg);
if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
}
// render message
{
// output buffer
let buffer = self.buffer.get_mut();
// status line
writeln!(
Writer(&mut self.buffer),
Writer(buffer),
"{} {} {:?}\r",
msg.method(),
msg.uri()
@ -134,41 +134,41 @@ impl HttpClientWriter {
// write headers
if let Body::Binary(ref bytes) = *msg.body() {
self.buffer
.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len());
buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len());
} else {
self.buffer
.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE);
buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE);
}
for (key, value) in msg.headers() {
let v = value.as_ref();
let k = key.as_str().as_bytes();
self.buffer.reserve(k.len() + v.len() + 4);
self.buffer.put_slice(k);
self.buffer.put_slice(b": ");
self.buffer.put_slice(v);
self.buffer.put_slice(b"\r\n");
buffer.reserve(k.len() + v.len() + 4);
buffer.put_slice(k);
buffer.put_slice(b": ");
buffer.put_slice(v);
buffer.put_slice(b"\r\n");
}
// set date header
if !msg.headers().contains_key(DATE) {
self.buffer.extend_from_slice(b"date: ");
set_date(&mut self.buffer);
self.buffer.extend_from_slice(b"\r\n\r\n");
buffer.extend_from_slice(b"date: ");
set_date(buffer);
buffer.extend_from_slice(b"\r\n\r\n");
} else {
self.buffer.extend_from_slice(b"\r\n");
buffer.extend_from_slice(b"\r\n");
}
self.headers_size = self.buffer.len() as u32;
}
self.headers_size = self.buffer.len() as u32;
if msg.body().is_binary() {
if let Body::Binary(bytes) = msg.replace_body(Body::Empty) {
self.written += bytes.len() as u64;
self.encoder.write(bytes.as_ref())?;
}
} else {
self.buffer_capacity = msg.write_buffer_capacity();
self.buffer = content_encoder(self.buffer.take(), msg);
if msg.body().is_binary() {
if let Body::Binary(bytes) = msg.replace_body(Body::Empty) {
self.written += bytes.len() as u64;
self.buffer.write(bytes.as_ref())?;
}
} else {
self.buffer_capacity = msg.write_buffer_capacity();
}
Ok(())
}
@ -176,11 +176,7 @@ impl HttpClientWriter {
pub fn write(&mut self, payload: &[u8]) -> io::Result<WriterState> {
self.written += payload.len() as u64;
if !self.flags.contains(Flags::DISCONNECTED) {
if self.flags.contains(Flags::UPGRADE) {
self.buffer.extend(payload);
} else {
self.encoder.write(payload)?;
}
self.buffer.write(payload)?;
}
if self.buffer.len() > self.buffer_capacity {
@ -191,9 +187,7 @@ impl HttpClientWriter {
}
pub fn write_eof(&mut self) -> io::Result<()> {
self.encoder.write_eof()?;
if self.encoder.is_eof() {
if self.buffer.write_eof()? {
Ok(())
} else {
Err(io::Error::new(
@ -221,21 +215,20 @@ impl HttpClientWriter {
}
}
fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncoder {
fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> Output {
let version = req.version();
let mut body = req.replace_body(Body::Empty);
let mut encoding = req.content_encoding();
let mut transfer = match body {
let transfer = match body {
Body::Empty => {
req.headers_mut().remove(CONTENT_LENGTH);
TransferEncoding::length(0)
TransferEncoding::length(0, buf)
}
Body::Binary(ref mut bytes) => {
if encoding.is_compression() {
let mut tmp = BytesMut::new();
let mut transfer = TransferEncoding::eof();
transfer.set_buffer(&mut tmp);
let mut tmp = SharedBytes::empty();
let mut transfer = TransferEncoding::eof(tmp);
let mut enc = match encoding {
#[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate(
@ -256,7 +249,7 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
// TODO return error!
let _ = enc.write(bytes.as_ref());
let _ = enc.write_eof();
*bytes = Binary::from(tmp.take());
*bytes = Binary::from(enc.buf_mut().take());
req.headers_mut().insert(
CONTENT_ENCODING,
@ -268,7 +261,7 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
let _ = write!(b, "{}", bytes.len());
req.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap());
TransferEncoding::eof()
TransferEncoding::eof(buf)
}
Body::Streaming(_) | Body::Actor(_) => {
if req.upgrade() {
@ -282,9 +275,9 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
encoding = ContentEncoding::Identity;
req.headers_mut().remove(CONTENT_ENCODING);
}
TransferEncoding::eof()
TransferEncoding::eof(buf)
} else {
streaming_encoding(version, req)
streaming_encoding(buf, version, req)
}
}
};
@ -295,10 +288,9 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
HeaderValue::from_static(encoding.as_str()),
);
}
transfer.set_buffer(buf);
req.replace_body(body);
match encoding {
let enc = match encoding {
#[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
transfer,
@ -310,23 +302,24 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode
}
#[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)),
ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer)
}
}
ContentEncoding::Identity | ContentEncoding::Auto => return Output::TE(transfer),
};
Output::Encoder(enc)
}
fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEncoding {
fn streaming_encoding(
buf: SharedBytes, version: Version, req: &mut ClientRequest,
) -> TransferEncoding {
if req.chunked() {
// Enable transfer encoding
req.headers_mut().remove(CONTENT_LENGTH);
if version == Version::HTTP_2 {
req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof()
TransferEncoding::eof(buf)
} else {
req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked()
TransferEncoding::chunked(buf)
}
} else {
// if Content-Length is specified, then use it as length hint
@ -349,9 +342,9 @@ fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEnco
if !chunked {
if let Some(len) = len {
TransferEncoding::length(len)
TransferEncoding::length(len, buf)
} else {
TransferEncoding::eof()
TransferEncoding::eof(buf)
}
} else {
// Enable transfer encoding
@ -359,11 +352,11 @@ fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEnco
Version::HTTP_11 => {
req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked()
TransferEncoding::chunked(buf)
}
_ => {
req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof()
TransferEncoding::eof(buf)
}
}
}

View File

@ -1127,6 +1127,7 @@ mod tests {
let response = srv.execute(request.send()).unwrap();
println!("RESP: {:?}", response);
let te = response
.headers()
.get(header::TRANSFER_ENCODING)

View File

@ -1,7 +1,7 @@
use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::str::FromStr;
use std::{cmp, io, mem, ptr};
use std::{cmp, io, mem};
#[cfg(feature = "brotli")]
use brotli2::write::{BrotliDecoder, BrotliEncoder};
@ -25,6 +25,8 @@ use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse;
use payload::{PayloadSender, PayloadStatus, PayloadWriter};
use super::shared::SharedBytes;
pub(crate) enum PayloadType {
Sender(PayloadSender),
Encoding(Box<EncodedPayload>),
@ -368,6 +370,83 @@ impl PayloadStream {
}
}
pub(crate) enum Output {
Buffer(SharedBytes),
Encoder(ContentEncoder),
TE(TransferEncoding),
Empty,
}
impl Output {
pub fn take(&mut self) -> SharedBytes {
match mem::replace(self, Output::Empty) {
Output::Buffer(bytes) => bytes,
_ => panic!(),
}
}
pub fn as_ref(&mut self) -> &SharedBytes {
match self {
Output::Buffer(ref mut bytes) => bytes,
Output::Encoder(ref mut enc) => enc.buf_ref(),
Output::TE(ref mut te) => te.buf_ref(),
Output::Empty => panic!(),
}
}
pub fn get_mut(&mut self) -> &mut BytesMut {
match self {
Output::Buffer(ref mut bytes) => bytes.get_mut(),
_ => panic!(),
}
}
pub fn split_to(&mut self, cap: usize) -> BytesMut {
match self {
Output::Buffer(ref mut bytes) => bytes.split_to(cap),
Output::Encoder(ref mut enc) => enc.buf_mut().split_to(cap),
Output::TE(ref mut te) => te.buf_mut().split_to(cap),
Output::Empty => BytesMut::new(),
}
}
pub fn len(&self) -> usize {
match self {
Output::Buffer(ref bytes) => bytes.len(),
Output::Encoder(ref enc) => enc.len(),
Output::TE(ref te) => te.len(),
Output::Empty => 0,
}
}
pub fn is_empty(&self) -> bool {
match self {
Output::Buffer(ref bytes) => bytes.is_empty(),
Output::Encoder(ref enc) => enc.is_empty(),
Output::TE(ref te) => te.is_empty(),
Output::Empty => true,
}
}
pub fn write(&mut self, data: &[u8]) -> Result<(), io::Error> {
match self {
Output::Buffer(ref mut bytes) => {
bytes.extend_from_slice(data);
Ok(())
}
Output::Encoder(ref mut enc) => enc.write(data),
Output::TE(ref mut te) => te.encode(data).map(|_| ()),
Output::Empty => Ok(()),
}
}
pub fn write_eof(&mut self) -> Result<bool, io::Error> {
match self {
Output::Buffer(_) => Ok(true),
Output::Encoder(ref mut enc) => enc.write_eof(),
Output::TE(ref mut te) => Ok(te.encode_eof()),
Output::Empty => Ok(true),
}
}
}
pub(crate) enum ContentEncoder {
#[cfg(feature = "flate2")]
Deflate(DeflateEncoder<TransferEncoding>),
@ -379,14 +458,10 @@ pub(crate) enum ContentEncoder {
}
impl ContentEncoder {
pub fn empty() -> ContentEncoder {
ContentEncoder::Identity(TransferEncoding::eof())
}
pub fn for_server(
buf: &mut BytesMut, req: &HttpInnerMessage, resp: &mut HttpResponse,
buf: SharedBytes, req: &HttpInnerMessage, resp: &mut HttpResponse,
response_encoding: ContentEncoding,
) -> ContentEncoder {
) -> Output {
let version = resp.version().unwrap_or_else(|| req.version);
let is_head = req.method == Method::HEAD;
let mut len = 0;
@ -439,7 +514,7 @@ impl ContentEncoder {
if req.method != Method::HEAD {
resp.headers_mut().remove(CONTENT_LENGTH);
}
TransferEncoding::length(0)
TransferEncoding::length(0, buf)
}
&Body::Binary(_) => {
#[cfg(any(feature = "brotli", feature = "flate2"))]
@ -447,9 +522,8 @@ impl ContentEncoder {
if !(encoding == ContentEncoding::Identity
|| encoding == ContentEncoding::Auto)
{
let mut tmp = BytesMut::default();
let mut transfer = TransferEncoding::eof();
transfer.set_buffer(&mut tmp);
let mut tmp = SharedBytes::empty();
let mut transfer = TransferEncoding::eof(tmp);
let mut enc = match encoding {
#[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate(
@ -473,7 +547,7 @@ impl ContentEncoder {
// TODO return error!
let _ = enc.write(bin.as_ref());
let _ = enc.write_eof();
let body = tmp.take();
let body = enc.buf_mut().take();
len = body.len();
encoding = ContentEncoding::Identity;
@ -491,7 +565,7 @@ impl ContentEncoder {
} else {
// resp.headers_mut().remove(CONTENT_LENGTH);
}
TransferEncoding::eof()
TransferEncoding::eof(buf)
}
&Body::Streaming(_) | &Body::Actor(_) => {
if resp.upgrade() {
@ -502,14 +576,14 @@ impl ContentEncoder {
encoding = ContentEncoding::Identity;
resp.headers_mut().remove(CONTENT_ENCODING);
}
TransferEncoding::eof()
TransferEncoding::eof(buf)
} else {
if !(encoding == ContentEncoding::Identity
|| encoding == ContentEncoding::Auto)
{
resp.headers_mut().remove(CONTENT_LENGTH);
}
ContentEncoder::streaming_encoding(version, resp)
ContentEncoder::streaming_encoding(buf, version, resp)
}
}
};
@ -518,9 +592,8 @@ impl ContentEncoder {
resp.set_body(Body::Empty);
transfer.kind = TransferEncodingKind::Length(0);
}
transfer.set_buffer(buf);
match encoding {
let enc = match encoding {
#[cfg(feature = "flate2")]
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
transfer,
@ -533,13 +606,14 @@ impl ContentEncoder {
#[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 3)),
ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer)
return Output::TE(transfer)
}
}
};
Output::Encoder(enc)
}
fn streaming_encoding(
version: Version, resp: &mut HttpResponse,
buf: SharedBytes, version: Version, resp: &mut HttpResponse,
) -> TransferEncoding {
match resp.chunked() {
Some(true) => {
@ -547,14 +621,14 @@ impl ContentEncoder {
resp.headers_mut().remove(CONTENT_LENGTH);
if version == Version::HTTP_2 {
resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof()
TransferEncoding::eof(buf)
} else {
resp.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked()
TransferEncoding::chunked(buf)
}
}
Some(false) => TransferEncoding::eof(),
Some(false) => TransferEncoding::eof(buf),
None => {
// if Content-Length is specified, then use it as length hint
let (len, chunked) =
@ -577,9 +651,9 @@ impl ContentEncoder {
if !chunked {
if let Some(len) = len {
TransferEncoding::length(len)
TransferEncoding::length(len, buf)
} else {
TransferEncoding::eof()
TransferEncoding::eof(buf)
}
} else {
// Enable transfer encoding
@ -589,11 +663,11 @@ impl ContentEncoder {
TRANSFER_ENCODING,
HeaderValue::from_static("chunked"),
);
TransferEncoding::chunked()
TransferEncoding::chunked(buf)
}
_ => {
resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof()
TransferEncoding::eof(buf)
}
}
}
@ -604,15 +678,54 @@ impl ContentEncoder {
impl ContentEncoder {
#[inline]
pub fn is_eof(&self) -> bool {
pub fn len(&self) -> usize {
match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref encoder) => encoder.get_ref().is_eof(),
ContentEncoder::Br(ref encoder) => encoder.get_ref().len(),
#[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_eof(),
ContentEncoder::Deflate(ref encoder) => encoder.get_ref().len(),
#[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_eof(),
ContentEncoder::Identity(ref encoder) => encoder.is_eof(),
ContentEncoder::Gzip(ref encoder) => encoder.get_ref().len(),
ContentEncoder::Identity(ref encoder) => encoder.len(),
}
}
#[inline]
pub fn is_empty(&self) -> bool {
match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref encoder) => encoder.get_ref().is_empty(),
#[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_empty(),
#[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_empty(),
ContentEncoder::Identity(ref encoder) => encoder.is_empty(),
}
}
#[inline]
pub(crate) fn buf_mut(&mut self) -> &mut BytesMut {
match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_mut(),
#[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_mut(),
#[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_mut(),
ContentEncoder::Identity(ref mut encoder) => encoder.buf_mut(),
}
}
#[inline]
pub(crate) fn buf_ref(&mut self) -> &SharedBytes {
match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_ref(),
#[cfg(feature = "flate2")]
ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_ref(),
#[cfg(feature = "flate2")]
ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_ref(),
ContentEncoder::Identity(ref mut encoder) => encoder.buf_ref(),
}
}
@ -620,7 +733,7 @@ impl ContentEncoder {
#[inline(always)]
pub fn write_eof(&mut self) -> Result<bool, io::Error> {
let encoder =
mem::replace(self, ContentEncoder::Identity(TransferEncoding::eof()));
mem::replace(self, ContentEncoder::Identity(TransferEncoding::empty()));
match encoder {
#[cfg(feature = "brotli")]
@ -695,9 +808,9 @@ impl ContentEncoder {
}
/// Encoders to handle different Transfer-Encodings.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct TransferEncoding {
buf: *mut BytesMut,
buf: Option<SharedBytes>,
kind: TransferEncodingKind,
}
@ -716,40 +829,55 @@ enum TransferEncodingKind {
}
impl TransferEncoding {
pub(crate) fn set_buffer(&mut self, buf: *mut BytesMut) {
self.buf = buf;
fn take(self) -> SharedBytes {
self.buf.unwrap()
}
fn buf_ref(&mut self) -> &SharedBytes {
self.buf.as_ref().unwrap()
}
fn len(&self) -> usize {
self.buf.as_ref().unwrap().len()
}
fn is_empty(&self) -> bool {
self.buf.as_ref().unwrap().is_empty()
}
fn buf_mut(&mut self) -> &mut BytesMut {
self.buf.as_mut().unwrap().get_mut()
}
#[inline]
pub fn eof() -> TransferEncoding {
pub fn empty() -> TransferEncoding {
TransferEncoding {
buf: None,
kind: TransferEncodingKind::Eof,
buf: ptr::null_mut(),
}
}
#[inline]
pub fn chunked() -> TransferEncoding {
pub fn eof(buf: SharedBytes) -> TransferEncoding {
TransferEncoding {
buf: Some(buf),
kind: TransferEncodingKind::Eof,
}
}
#[inline]
pub fn chunked(buf: SharedBytes) -> TransferEncoding {
TransferEncoding {
buf: Some(buf),
kind: TransferEncodingKind::Chunked(false),
buf: ptr::null_mut(),
}
}
#[inline]
pub fn length(len: u64) -> TransferEncoding {
pub fn length(len: u64, buf: SharedBytes) -> TransferEncoding {
TransferEncoding {
buf: Some(buf),
kind: TransferEncodingKind::Length(len),
buf: ptr::null_mut(),
}
}
#[inline]
pub fn is_eof(&self) -> bool {
match self.kind {
TransferEncodingKind::Eof => true,
TransferEncodingKind::Chunked(ref eof) => *eof,
TransferEncodingKind::Length(ref remaining) => *remaining == 0,
}
}
@ -759,9 +887,7 @@ impl TransferEncoding {
match self.kind {
TransferEncodingKind::Eof => {
let eof = msg.is_empty();
debug_assert!(!self.buf.is_null());
let buf = unsafe { &mut *self.buf };
buf.extend(msg);
self.buf.as_mut().unwrap().extend_from_slice(msg);
Ok(eof)
}
TransferEncodingKind::Chunked(ref mut eof) => {
@ -771,16 +897,13 @@ impl TransferEncoding {
if msg.is_empty() {
*eof = true;
debug_assert!(!self.buf.is_null());
let buf = unsafe { &mut *self.buf };
buf.extend_from_slice(b"0\r\n\r\n");
self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n");
} else {
let mut buf = BytesMut::new();
writeln!(&mut buf, "{:X}\r", msg.len())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
debug_assert!(!self.buf.is_null());
let b = unsafe { &mut *self.buf };
let b = self.buf.as_mut().unwrap();
b.reserve(buf.len() + msg.len() + 2);
b.extend_from_slice(buf.as_ref());
b.extend_from_slice(msg);
@ -795,8 +918,10 @@ impl TransferEncoding {
}
let len = cmp::min(*remaining, msg.len() as u64);
debug_assert!(!self.buf.is_null());
unsafe { &mut *self.buf }.extend(&msg[..len as usize]);
self.buf
.as_mut()
.unwrap()
.extend_from_slice(&msg[..len as usize]);
*remaining -= len as u64;
Ok(*remaining == 0)
@ -816,10 +941,7 @@ impl TransferEncoding {
TransferEncodingKind::Chunked(ref mut eof) => {
if !*eof {
*eof = true;
debug_assert!(!self.buf.is_null());
let buf = unsafe { &mut *self.buf };
buf.extend_from_slice(b"0\r\n\r\n");
self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n");
}
true
}
@ -912,15 +1034,14 @@ mod tests {
#[test]
fn test_chunked_te() {
let mut bytes = BytesMut::new();
let mut enc = TransferEncoding::chunked();
let bytes = SharedBytes::empty();
let mut enc = TransferEncoding::chunked(bytes);
{
enc.set_buffer(&mut bytes);
assert!(!enc.encode(b"test").ok().unwrap());
assert!(enc.encode(b"").ok().unwrap());
}
assert_eq!(
bytes.take().freeze(),
enc.buf_mut().take().freeze(),
Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
);
}

View File

@ -6,7 +6,7 @@ use std::io;
use std::rc::Rc;
use tokio_io::AsyncWrite;
use super::encoding::ContentEncoder;
use super::encoding::{ContentEncoder, Output};
use super::helpers;
use super::settings::WorkerSettings;
use super::shared::SharedBytes;
@ -32,10 +32,9 @@ bitflags! {
pub(crate) struct H1Writer<T: AsyncWrite, H: 'static> {
flags: Flags,
stream: T,
encoder: ContentEncoder,
written: u64,
headers_size: u32,
buffer: SharedBytes,
buffer: Output,
buffer_capacity: usize,
settings: Rc<WorkerSettings<H>>,
}
@ -46,10 +45,9 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
) -> H1Writer<T, H> {
H1Writer {
flags: Flags::KEEPALIVE,
encoder: ContentEncoder::empty(),
written: 0,
headers_size: 0,
buffer: buf,
buffer: Output::Buffer(buf),
buffer_capacity: 0,
stream,
settings,
@ -66,7 +64,7 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
}
pub fn disconnected(&mut self) {
self.buffer.take();
self.buffer = Output::Empty;
}
pub fn keepalive(&self) -> bool {
@ -106,7 +104,8 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
#[inline]
fn buffer(&mut self) -> &mut BytesMut {
self.buffer.get_mut()
//self.buffer.get_mut()
unimplemented!()
}
fn start(
@ -114,8 +113,6 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
encoding: ContentEncoding,
) -> io::Result<WriterState> {
// prepare task
self.encoder =
ContentEncoder::for_server(self.buffer.get_mut(), req, msg, encoding);
if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) {
self.flags = Flags::STARTED | Flags::KEEPALIVE;
} else {
@ -143,7 +140,9 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
// render message
{
// output buffer
let mut buffer = self.buffer.get_mut();
let reason = msg.reason().as_bytes();
let mut is_bin = if let Body::Binary(ref bytes) = body {
buffer.reserve(
@ -222,9 +221,12 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
self.headers_size = buffer.len() as u32;
}
// output encoding
self.buffer = ContentEncoder::for_server(self.buffer.take(), req, msg, encoding);
if let Body::Binary(bytes) = body {
self.written = bytes.len() as u64;
self.encoder.write(bytes.as_ref())?;
self.buffer.write(bytes.as_ref())?;
} else {
// capacity, makes sense only for streaming or actor
self.buffer_capacity = msg.write_buffer_capacity();
@ -253,19 +255,19 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
Ok(val) => val,
};
if n < pl.len() {
self.buffer.extend_from_slice(&pl[n..]);
self.buffer.write(&pl[n..]);
return Ok(WriterState::Done);
}
} else {
self.buffer.extend(payload);
self.buffer.write(payload.as_ref());
}
} else {
// TODO: add warning, write after EOF
self.encoder.write(payload.as_ref())?;
self.buffer.write(payload.as_ref())?;
}
} else {
// could be response to EXCEPT header
self.buffer.extend_from_slice(payload.as_ref())
self.buffer.write(payload.as_ref());
}
}
@ -277,7 +279,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
}
fn write_eof(&mut self) -> io::Result<WriterState> {
if !self.encoder.write_eof()? {
if !self.buffer.write_eof()? {
Err(io::Error::new(
io::ErrorKind::Other,
"Last payload item, but eof is not reached",
@ -293,7 +295,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
if !self.buffer.is_empty() {
let written = {
match Self::write_data(&mut self.stream, self.buffer.as_ref()) {
match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) {
Err(err) => {
if err.kind() == io::ErrorKind::WriteZero {
self.disconnected();

View File

@ -11,7 +11,7 @@ use std::{cmp, io};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{HttpTryFrom, Version};
use super::encoding::ContentEncoder;
use super::encoding::{ContentEncoder, Output};
use super::helpers;
use super::settings::WorkerSettings;
use super::shared::SharedBytes;
@ -35,10 +35,9 @@ bitflags! {
pub(crate) struct H2Writer<H: 'static> {
respond: SendResponse<Bytes>,
stream: Option<SendStream<Bytes>>,
encoder: ContentEncoder,
flags: Flags,
written: u64,
buffer: SharedBytes,
buffer: Output,
buffer_capacity: usize,
settings: Rc<WorkerSettings<H>>,
}
@ -51,10 +50,9 @@ impl<H: 'static> H2Writer<H> {
respond,
settings,
stream: None,
encoder: ContentEncoder::empty(),
flags: Flags::empty(),
written: 0,
buffer: buf,
buffer: Output::Buffer(buf),
buffer_capacity: 0,
}
}
@ -87,8 +85,7 @@ impl<H: 'static> Writer for H2Writer<H> {
) -> io::Result<WriterState> {
// prepare response
self.flags.insert(Flags::STARTED);
self.encoder =
ContentEncoder::for_server(self.buffer.get_mut(), req, msg, encoding);
self.buffer = ContentEncoder::for_server(self.buffer.take(), req, msg, encoding);
// http2 specific
msg.headers_mut().remove(CONNECTION);
@ -150,7 +147,7 @@ impl<H: 'static> Writer for H2Writer<H> {
} else {
self.flags.insert(Flags::EOF);
self.written = bytes.len() as u64;
self.encoder.write(bytes.as_ref())?;
self.buffer.write(bytes.as_ref())?;
if let Some(ref mut stream) = self.stream {
self.flags.insert(Flags::RESERVED);
stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE));
@ -170,10 +167,10 @@ impl<H: 'static> Writer for H2Writer<H> {
if !self.flags.contains(Flags::DISCONNECTED) {
if self.flags.contains(Flags::STARTED) {
// TODO: add warning, write after EOF
self.encoder.write(payload.as_ref())?;
self.buffer.write(payload.as_ref())?;
} else {
// might be response for EXCEPT
self.buffer.extend_from_slice(payload.as_ref())
error!("Not supported");
}
}
@ -186,7 +183,7 @@ impl<H: 'static> Writer for H2Writer<H> {
fn write_eof(&mut self) -> io::Result<WriterState> {
self.flags.insert(Flags::EOF);
if !self.encoder.write_eof()? {
if !self.buffer.write_eof()? {
Err(io::Error::new(
io::ErrorKind::Other,
"Last payload item, but eof is not reached",

View File

@ -5,8 +5,6 @@ use std::rc::Rc;
use bytes::BytesMut;
use body::Binary;
#[derive(Debug)]
pub(crate) struct SharedBytesPool(RefCell<VecDeque<BytesMut>>);
@ -50,6 +48,10 @@ impl SharedBytes {
SharedBytes(Some(bytes), Some(pool))
}
pub fn empty() -> SharedBytes {
SharedBytes(Some(BytesMut::new()), None)
}
#[inline]
pub(crate) fn get_mut(&mut self) -> &mut BytesMut {
self.0.as_mut().unwrap()
@ -79,9 +81,8 @@ impl SharedBytes {
}
#[inline]
pub fn extend(&mut self, data: &Binary) {
let buf = self.get_mut();
buf.extend_from_slice(data.as_ref());
pub fn reserve(&mut self, cap: usize) {
self.get_mut().reserve(cap);
}
#[inline]