diff --git a/src/client/writer.rs b/src/client/writer.rs index f9961a79..bf626513 100644 --- a/src/client/writer.rs +++ b/src/client/writer.rs @@ -21,7 +21,8 @@ use tokio_io::AsyncWrite; use body::{Binary, Body}; use header::ContentEncoding; -use server::encoding::{ContentEncoder, TransferEncoding}; +use server::encoding::{ContentEncoder, Output, TransferEncoding}; +use server::shared::SharedBytes; use server::WriterState; use client::ClientRequest; @@ -41,21 +42,18 @@ pub(crate) struct HttpClientWriter { flags: Flags, written: u64, headers_size: u32, - buffer: Box, + buffer: Output, buffer_capacity: usize, - encoder: ContentEncoder, } impl HttpClientWriter { pub fn new() -> HttpClientWriter { - let encoder = ContentEncoder::Identity(TransferEncoding::eof()); HttpClientWriter { flags: Flags::empty(), written: 0, headers_size: 0, buffer_capacity: 0, - buffer: Box::new(BytesMut::new()), - encoder, + buffer: Output::Buffer(SharedBytes::empty()), } } @@ -75,7 +73,7 @@ impl HttpClientWriter { &mut self, stream: &mut T, ) -> io::Result { while !self.buffer.is_empty() { - match stream.write(self.buffer.as_ref()) { + match stream.write(self.buffer.as_ref().as_ref()) { Ok(0) => { self.disconnected(); return Ok(WriterState::Done); @@ -113,16 +111,18 @@ impl HttpClientWriter { pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> { // prepare task self.flags.insert(Flags::STARTED); - self.encoder = content_encoder(self.buffer.as_mut(), msg); if msg.upgrade() { self.flags.insert(Flags::UPGRADE); } // render message { + // output buffer + let buffer = self.buffer.get_mut(); + // status line writeln!( - Writer(&mut self.buffer), + Writer(buffer), "{} {} {:?}\r", msg.method(), msg.uri() @@ -134,41 +134,41 @@ impl HttpClientWriter { // write headers if let Body::Binary(ref bytes) = *msg.body() { - self.buffer - .reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len()); + buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len()); } else { - self.buffer - .reserve(msg.headers().len() * AVERAGE_HEADER_SIZE); + buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE); } for (key, value) in msg.headers() { let v = value.as_ref(); let k = key.as_str().as_bytes(); - self.buffer.reserve(k.len() + v.len() + 4); - self.buffer.put_slice(k); - self.buffer.put_slice(b": "); - self.buffer.put_slice(v); - self.buffer.put_slice(b"\r\n"); + buffer.reserve(k.len() + v.len() + 4); + buffer.put_slice(k); + buffer.put_slice(b": "); + buffer.put_slice(v); + buffer.put_slice(b"\r\n"); } // set date header if !msg.headers().contains_key(DATE) { - self.buffer.extend_from_slice(b"date: "); - set_date(&mut self.buffer); - self.buffer.extend_from_slice(b"\r\n\r\n"); + buffer.extend_from_slice(b"date: "); + set_date(buffer); + buffer.extend_from_slice(b"\r\n\r\n"); } else { - self.buffer.extend_from_slice(b"\r\n"); + buffer.extend_from_slice(b"\r\n"); } - self.headers_size = self.buffer.len() as u32; + } + self.headers_size = self.buffer.len() as u32; - if msg.body().is_binary() { - if let Body::Binary(bytes) = msg.replace_body(Body::Empty) { - self.written += bytes.len() as u64; - self.encoder.write(bytes.as_ref())?; - } - } else { - self.buffer_capacity = msg.write_buffer_capacity(); + self.buffer = content_encoder(self.buffer.take(), msg); + + if msg.body().is_binary() { + if let Body::Binary(bytes) = msg.replace_body(Body::Empty) { + self.written += bytes.len() as u64; + self.buffer.write(bytes.as_ref())?; } + } else { + self.buffer_capacity = msg.write_buffer_capacity(); } Ok(()) } @@ -176,11 +176,7 @@ impl HttpClientWriter { pub fn write(&mut self, payload: &[u8]) -> io::Result { self.written += payload.len() as u64; if !self.flags.contains(Flags::DISCONNECTED) { - if self.flags.contains(Flags::UPGRADE) { - self.buffer.extend(payload); - } else { - self.encoder.write(payload)?; - } + self.buffer.write(payload)?; } if self.buffer.len() > self.buffer_capacity { @@ -191,9 +187,7 @@ impl HttpClientWriter { } pub fn write_eof(&mut self) -> io::Result<()> { - self.encoder.write_eof()?; - - if self.encoder.is_eof() { + if self.buffer.write_eof()? { Ok(()) } else { Err(io::Error::new( @@ -221,21 +215,20 @@ impl HttpClientWriter { } } -fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncoder { +fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> Output { let version = req.version(); let mut body = req.replace_body(Body::Empty); let mut encoding = req.content_encoding(); - let mut transfer = match body { + let transfer = match body { Body::Empty => { req.headers_mut().remove(CONTENT_LENGTH); - TransferEncoding::length(0) + TransferEncoding::length(0, buf) } Body::Binary(ref mut bytes) => { if encoding.is_compression() { - let mut tmp = BytesMut::new(); - let mut transfer = TransferEncoding::eof(); - transfer.set_buffer(&mut tmp); + let mut tmp = SharedBytes::empty(); + let mut transfer = TransferEncoding::eof(tmp); let mut enc = match encoding { #[cfg(feature = "flate2")] ContentEncoding::Deflate => ContentEncoder::Deflate( @@ -256,7 +249,7 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode // TODO return error! let _ = enc.write(bytes.as_ref()); let _ = enc.write_eof(); - *bytes = Binary::from(tmp.take()); + *bytes = Binary::from(enc.buf_mut().take()); req.headers_mut().insert( CONTENT_ENCODING, @@ -268,7 +261,7 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode let _ = write!(b, "{}", bytes.len()); req.headers_mut() .insert(CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); - TransferEncoding::eof() + TransferEncoding::eof(buf) } Body::Streaming(_) | Body::Actor(_) => { if req.upgrade() { @@ -282,9 +275,9 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode encoding = ContentEncoding::Identity; req.headers_mut().remove(CONTENT_ENCODING); } - TransferEncoding::eof() + TransferEncoding::eof(buf) } else { - streaming_encoding(version, req) + streaming_encoding(buf, version, req) } } }; @@ -295,10 +288,9 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode HeaderValue::from_static(encoding.as_str()), ); } - transfer.set_buffer(buf); req.replace_body(body); - match encoding { + let enc = match encoding { #[cfg(feature = "flate2")] ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new( transfer, @@ -310,23 +302,24 @@ fn content_encoder(buf: &mut BytesMut, req: &mut ClientRequest) -> ContentEncode } #[cfg(feature = "brotli")] ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)), - ContentEncoding::Identity | ContentEncoding::Auto => { - ContentEncoder::Identity(transfer) - } - } + ContentEncoding::Identity | ContentEncoding::Auto => return Output::TE(transfer), + }; + Output::Encoder(enc) } -fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEncoding { +fn streaming_encoding( + buf: SharedBytes, version: Version, req: &mut ClientRequest, +) -> TransferEncoding { if req.chunked() { // Enable transfer encoding req.headers_mut().remove(CONTENT_LENGTH); if version == Version::HTTP_2 { req.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof() + TransferEncoding::eof(buf) } else { req.headers_mut() .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked() + TransferEncoding::chunked(buf) } } else { // if Content-Length is specified, then use it as length hint @@ -349,9 +342,9 @@ fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEnco if !chunked { if let Some(len) = len { - TransferEncoding::length(len) + TransferEncoding::length(len, buf) } else { - TransferEncoding::eof() + TransferEncoding::eof(buf) } } else { // Enable transfer encoding @@ -359,11 +352,11 @@ fn streaming_encoding(version: Version, req: &mut ClientRequest) -> TransferEnco Version::HTTP_11 => { req.headers_mut() .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked() + TransferEncoding::chunked(buf) } _ => { req.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof() + TransferEncoding::eof(buf) } } } diff --git a/src/fs.rs b/src/fs.rs index c5a7de61..bf9079cc 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -1127,6 +1127,7 @@ mod tests { let response = srv.execute(request.send()).unwrap(); + println!("RESP: {:?}", response); let te = response .headers() .get(header::TRANSFER_ENCODING) diff --git a/src/server/encoding.rs b/src/server/encoding.rs index db0fd0c3..e7dc7a1a 100644 --- a/src/server/encoding.rs +++ b/src/server/encoding.rs @@ -1,7 +1,7 @@ use std::fmt::Write as FmtWrite; use std::io::{Read, Write}; use std::str::FromStr; -use std::{cmp, io, mem, ptr}; +use std::{cmp, io, mem}; #[cfg(feature = "brotli")] use brotli2::write::{BrotliDecoder, BrotliEncoder}; @@ -25,6 +25,8 @@ use httprequest::HttpInnerMessage; use httpresponse::HttpResponse; use payload::{PayloadSender, PayloadStatus, PayloadWriter}; +use super::shared::SharedBytes; + pub(crate) enum PayloadType { Sender(PayloadSender), Encoding(Box), @@ -368,6 +370,83 @@ impl PayloadStream { } } +pub(crate) enum Output { + Buffer(SharedBytes), + Encoder(ContentEncoder), + TE(TransferEncoding), + Empty, +} + +impl Output { + pub fn take(&mut self) -> SharedBytes { + match mem::replace(self, Output::Empty) { + Output::Buffer(bytes) => bytes, + _ => panic!(), + } + } + pub fn as_ref(&mut self) -> &SharedBytes { + match self { + Output::Buffer(ref mut bytes) => bytes, + Output::Encoder(ref mut enc) => enc.buf_ref(), + Output::TE(ref mut te) => te.buf_ref(), + Output::Empty => panic!(), + } + } + pub fn get_mut(&mut self) -> &mut BytesMut { + match self { + Output::Buffer(ref mut bytes) => bytes.get_mut(), + _ => panic!(), + } + } + pub fn split_to(&mut self, cap: usize) -> BytesMut { + match self { + Output::Buffer(ref mut bytes) => bytes.split_to(cap), + Output::Encoder(ref mut enc) => enc.buf_mut().split_to(cap), + Output::TE(ref mut te) => te.buf_mut().split_to(cap), + Output::Empty => BytesMut::new(), + } + } + + pub fn len(&self) -> usize { + match self { + Output::Buffer(ref bytes) => bytes.len(), + Output::Encoder(ref enc) => enc.len(), + Output::TE(ref te) => te.len(), + Output::Empty => 0, + } + } + + pub fn is_empty(&self) -> bool { + match self { + Output::Buffer(ref bytes) => bytes.is_empty(), + Output::Encoder(ref enc) => enc.is_empty(), + Output::TE(ref te) => te.is_empty(), + Output::Empty => true, + } + } + + pub fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { + match self { + Output::Buffer(ref mut bytes) => { + bytes.extend_from_slice(data); + Ok(()) + } + Output::Encoder(ref mut enc) => enc.write(data), + Output::TE(ref mut te) => te.encode(data).map(|_| ()), + Output::Empty => Ok(()), + } + } + + pub fn write_eof(&mut self) -> Result { + match self { + Output::Buffer(_) => Ok(true), + Output::Encoder(ref mut enc) => enc.write_eof(), + Output::TE(ref mut te) => Ok(te.encode_eof()), + Output::Empty => Ok(true), + } + } +} + pub(crate) enum ContentEncoder { #[cfg(feature = "flate2")] Deflate(DeflateEncoder), @@ -379,14 +458,10 @@ pub(crate) enum ContentEncoder { } impl ContentEncoder { - pub fn empty() -> ContentEncoder { - ContentEncoder::Identity(TransferEncoding::eof()) - } - pub fn for_server( - buf: &mut BytesMut, req: &HttpInnerMessage, resp: &mut HttpResponse, + buf: SharedBytes, req: &HttpInnerMessage, resp: &mut HttpResponse, response_encoding: ContentEncoding, - ) -> ContentEncoder { + ) -> Output { let version = resp.version().unwrap_or_else(|| req.version); let is_head = req.method == Method::HEAD; let mut len = 0; @@ -439,7 +514,7 @@ impl ContentEncoder { if req.method != Method::HEAD { resp.headers_mut().remove(CONTENT_LENGTH); } - TransferEncoding::length(0) + TransferEncoding::length(0, buf) } &Body::Binary(_) => { #[cfg(any(feature = "brotli", feature = "flate2"))] @@ -447,9 +522,8 @@ impl ContentEncoder { if !(encoding == ContentEncoding::Identity || encoding == ContentEncoding::Auto) { - let mut tmp = BytesMut::default(); - let mut transfer = TransferEncoding::eof(); - transfer.set_buffer(&mut tmp); + let mut tmp = SharedBytes::empty(); + let mut transfer = TransferEncoding::eof(tmp); let mut enc = match encoding { #[cfg(feature = "flate2")] ContentEncoding::Deflate => ContentEncoder::Deflate( @@ -473,7 +547,7 @@ impl ContentEncoder { // TODO return error! let _ = enc.write(bin.as_ref()); let _ = enc.write_eof(); - let body = tmp.take(); + let body = enc.buf_mut().take(); len = body.len(); encoding = ContentEncoding::Identity; @@ -491,7 +565,7 @@ impl ContentEncoder { } else { // resp.headers_mut().remove(CONTENT_LENGTH); } - TransferEncoding::eof() + TransferEncoding::eof(buf) } &Body::Streaming(_) | &Body::Actor(_) => { if resp.upgrade() { @@ -502,14 +576,14 @@ impl ContentEncoder { encoding = ContentEncoding::Identity; resp.headers_mut().remove(CONTENT_ENCODING); } - TransferEncoding::eof() + TransferEncoding::eof(buf) } else { if !(encoding == ContentEncoding::Identity || encoding == ContentEncoding::Auto) { resp.headers_mut().remove(CONTENT_LENGTH); } - ContentEncoder::streaming_encoding(version, resp) + ContentEncoder::streaming_encoding(buf, version, resp) } } }; @@ -518,9 +592,8 @@ impl ContentEncoder { resp.set_body(Body::Empty); transfer.kind = TransferEncodingKind::Length(0); } - transfer.set_buffer(buf); - match encoding { + let enc = match encoding { #[cfg(feature = "flate2")] ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new( transfer, @@ -533,13 +606,14 @@ impl ContentEncoder { #[cfg(feature = "brotli")] ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 3)), ContentEncoding::Identity | ContentEncoding::Auto => { - ContentEncoder::Identity(transfer) + return Output::TE(transfer) } - } + }; + Output::Encoder(enc) } fn streaming_encoding( - version: Version, resp: &mut HttpResponse, + buf: SharedBytes, version: Version, resp: &mut HttpResponse, ) -> TransferEncoding { match resp.chunked() { Some(true) => { @@ -547,14 +621,14 @@ impl ContentEncoder { resp.headers_mut().remove(CONTENT_LENGTH); if version == Version::HTTP_2 { resp.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof() + TransferEncoding::eof(buf) } else { resp.headers_mut() .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked() + TransferEncoding::chunked(buf) } } - Some(false) => TransferEncoding::eof(), + Some(false) => TransferEncoding::eof(buf), None => { // if Content-Length is specified, then use it as length hint let (len, chunked) = @@ -577,9 +651,9 @@ impl ContentEncoder { if !chunked { if let Some(len) = len { - TransferEncoding::length(len) + TransferEncoding::length(len, buf) } else { - TransferEncoding::eof() + TransferEncoding::eof(buf) } } else { // Enable transfer encoding @@ -589,11 +663,11 @@ impl ContentEncoder { TRANSFER_ENCODING, HeaderValue::from_static("chunked"), ); - TransferEncoding::chunked() + TransferEncoding::chunked(buf) } _ => { resp.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof() + TransferEncoding::eof(buf) } } } @@ -604,15 +678,54 @@ impl ContentEncoder { impl ContentEncoder { #[inline] - pub fn is_eof(&self) -> bool { + pub fn len(&self) -> usize { match *self { #[cfg(feature = "brotli")] - ContentEncoder::Br(ref encoder) => encoder.get_ref().is_eof(), + ContentEncoder::Br(ref encoder) => encoder.get_ref().len(), #[cfg(feature = "flate2")] - ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_eof(), + ContentEncoder::Deflate(ref encoder) => encoder.get_ref().len(), #[cfg(feature = "flate2")] - ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_eof(), - ContentEncoder::Identity(ref encoder) => encoder.is_eof(), + ContentEncoder::Gzip(ref encoder) => encoder.get_ref().len(), + ContentEncoder::Identity(ref encoder) => encoder.len(), + } + } + + #[inline] + pub fn is_empty(&self) -> bool { + match *self { + #[cfg(feature = "brotli")] + ContentEncoder::Br(ref encoder) => encoder.get_ref().is_empty(), + #[cfg(feature = "flate2")] + ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_empty(), + #[cfg(feature = "flate2")] + ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_empty(), + ContentEncoder::Identity(ref encoder) => encoder.is_empty(), + } + } + + #[inline] + pub(crate) fn buf_mut(&mut self) -> &mut BytesMut { + match *self { + #[cfg(feature = "brotli")] + ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_mut(), + #[cfg(feature = "flate2")] + ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_mut(), + #[cfg(feature = "flate2")] + ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_mut(), + ContentEncoder::Identity(ref mut encoder) => encoder.buf_mut(), + } + } + + #[inline] + pub(crate) fn buf_ref(&mut self) -> &SharedBytes { + match *self { + #[cfg(feature = "brotli")] + ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_ref(), + #[cfg(feature = "flate2")] + ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_ref(), + #[cfg(feature = "flate2")] + ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_ref(), + ContentEncoder::Identity(ref mut encoder) => encoder.buf_ref(), } } @@ -620,7 +733,7 @@ impl ContentEncoder { #[inline(always)] pub fn write_eof(&mut self) -> Result { let encoder = - mem::replace(self, ContentEncoder::Identity(TransferEncoding::eof())); + mem::replace(self, ContentEncoder::Identity(TransferEncoding::empty())); match encoder { #[cfg(feature = "brotli")] @@ -695,9 +808,9 @@ impl ContentEncoder { } /// Encoders to handle different Transfer-Encodings. -#[derive(Debug, Clone)] +#[derive(Debug)] pub(crate) struct TransferEncoding { - buf: *mut BytesMut, + buf: Option, kind: TransferEncodingKind, } @@ -716,40 +829,55 @@ enum TransferEncodingKind { } impl TransferEncoding { - pub(crate) fn set_buffer(&mut self, buf: *mut BytesMut) { - self.buf = buf; + fn take(self) -> SharedBytes { + self.buf.unwrap() + } + + fn buf_ref(&mut self) -> &SharedBytes { + self.buf.as_ref().unwrap() + } + + fn len(&self) -> usize { + self.buf.as_ref().unwrap().len() + } + + fn is_empty(&self) -> bool { + self.buf.as_ref().unwrap().is_empty() + } + + fn buf_mut(&mut self) -> &mut BytesMut { + self.buf.as_mut().unwrap().get_mut() } #[inline] - pub fn eof() -> TransferEncoding { + pub fn empty() -> TransferEncoding { TransferEncoding { + buf: None, kind: TransferEncodingKind::Eof, - buf: ptr::null_mut(), } } #[inline] - pub fn chunked() -> TransferEncoding { + pub fn eof(buf: SharedBytes) -> TransferEncoding { TransferEncoding { + buf: Some(buf), + kind: TransferEncodingKind::Eof, + } + } + + #[inline] + pub fn chunked(buf: SharedBytes) -> TransferEncoding { + TransferEncoding { + buf: Some(buf), kind: TransferEncodingKind::Chunked(false), - buf: ptr::null_mut(), } } #[inline] - pub fn length(len: u64) -> TransferEncoding { + pub fn length(len: u64, buf: SharedBytes) -> TransferEncoding { TransferEncoding { + buf: Some(buf), kind: TransferEncodingKind::Length(len), - buf: ptr::null_mut(), - } - } - - #[inline] - pub fn is_eof(&self) -> bool { - match self.kind { - TransferEncodingKind::Eof => true, - TransferEncodingKind::Chunked(ref eof) => *eof, - TransferEncodingKind::Length(ref remaining) => *remaining == 0, } } @@ -759,9 +887,7 @@ impl TransferEncoding { match self.kind { TransferEncodingKind::Eof => { let eof = msg.is_empty(); - debug_assert!(!self.buf.is_null()); - let buf = unsafe { &mut *self.buf }; - buf.extend(msg); + self.buf.as_mut().unwrap().extend_from_slice(msg); Ok(eof) } TransferEncodingKind::Chunked(ref mut eof) => { @@ -771,16 +897,13 @@ impl TransferEncoding { if msg.is_empty() { *eof = true; - debug_assert!(!self.buf.is_null()); - let buf = unsafe { &mut *self.buf }; - buf.extend_from_slice(b"0\r\n\r\n"); + self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n"); } else { let mut buf = BytesMut::new(); writeln!(&mut buf, "{:X}\r", msg.len()) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - debug_assert!(!self.buf.is_null()); - let b = unsafe { &mut *self.buf }; + let b = self.buf.as_mut().unwrap(); b.reserve(buf.len() + msg.len() + 2); b.extend_from_slice(buf.as_ref()); b.extend_from_slice(msg); @@ -795,8 +918,10 @@ impl TransferEncoding { } let len = cmp::min(*remaining, msg.len() as u64); - debug_assert!(!self.buf.is_null()); - unsafe { &mut *self.buf }.extend(&msg[..len as usize]); + self.buf + .as_mut() + .unwrap() + .extend_from_slice(&msg[..len as usize]); *remaining -= len as u64; Ok(*remaining == 0) @@ -816,10 +941,7 @@ impl TransferEncoding { TransferEncodingKind::Chunked(ref mut eof) => { if !*eof { *eof = true; - - debug_assert!(!self.buf.is_null()); - let buf = unsafe { &mut *self.buf }; - buf.extend_from_slice(b"0\r\n\r\n"); + self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n"); } true } @@ -912,15 +1034,14 @@ mod tests { #[test] fn test_chunked_te() { - let mut bytes = BytesMut::new(); - let mut enc = TransferEncoding::chunked(); + let bytes = SharedBytes::empty(); + let mut enc = TransferEncoding::chunked(bytes); { - enc.set_buffer(&mut bytes); assert!(!enc.encode(b"test").ok().unwrap()); assert!(enc.encode(b"").ok().unwrap()); } assert_eq!( - bytes.take().freeze(), + enc.buf_mut().take().freeze(), Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n") ); } diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index bf91a030..36571d88 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -6,7 +6,7 @@ use std::io; use std::rc::Rc; use tokio_io::AsyncWrite; -use super::encoding::ContentEncoder; +use super::encoding::{ContentEncoder, Output}; use super::helpers; use super::settings::WorkerSettings; use super::shared::SharedBytes; @@ -32,10 +32,9 @@ bitflags! { pub(crate) struct H1Writer { flags: Flags, stream: T, - encoder: ContentEncoder, written: u64, headers_size: u32, - buffer: SharedBytes, + buffer: Output, buffer_capacity: usize, settings: Rc>, } @@ -46,10 +45,9 @@ impl H1Writer { ) -> H1Writer { H1Writer { flags: Flags::KEEPALIVE, - encoder: ContentEncoder::empty(), written: 0, headers_size: 0, - buffer: buf, + buffer: Output::Buffer(buf), buffer_capacity: 0, stream, settings, @@ -66,7 +64,7 @@ impl H1Writer { } pub fn disconnected(&mut self) { - self.buffer.take(); + self.buffer = Output::Empty; } pub fn keepalive(&self) -> bool { @@ -106,7 +104,8 @@ impl Writer for H1Writer { #[inline] fn buffer(&mut self) -> &mut BytesMut { - self.buffer.get_mut() + //self.buffer.get_mut() + unimplemented!() } fn start( @@ -114,8 +113,6 @@ impl Writer for H1Writer { encoding: ContentEncoding, ) -> io::Result { // prepare task - self.encoder = - ContentEncoder::for_server(self.buffer.get_mut(), req, msg, encoding); if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { self.flags = Flags::STARTED | Flags::KEEPALIVE; } else { @@ -143,7 +140,9 @@ impl Writer for H1Writer { // render message { + // output buffer let mut buffer = self.buffer.get_mut(); + let reason = msg.reason().as_bytes(); let mut is_bin = if let Body::Binary(ref bytes) = body { buffer.reserve( @@ -222,9 +221,12 @@ impl Writer for H1Writer { self.headers_size = buffer.len() as u32; } + // output encoding + self.buffer = ContentEncoder::for_server(self.buffer.take(), req, msg, encoding); + if let Body::Binary(bytes) = body { self.written = bytes.len() as u64; - self.encoder.write(bytes.as_ref())?; + self.buffer.write(bytes.as_ref())?; } else { // capacity, makes sense only for streaming or actor self.buffer_capacity = msg.write_buffer_capacity(); @@ -253,19 +255,19 @@ impl Writer for H1Writer { Ok(val) => val, }; if n < pl.len() { - self.buffer.extend_from_slice(&pl[n..]); + self.buffer.write(&pl[n..]); return Ok(WriterState::Done); } } else { - self.buffer.extend(payload); + self.buffer.write(payload.as_ref()); } } else { // TODO: add warning, write after EOF - self.encoder.write(payload.as_ref())?; + self.buffer.write(payload.as_ref())?; } } else { // could be response to EXCEPT header - self.buffer.extend_from_slice(payload.as_ref()) + self.buffer.write(payload.as_ref()); } } @@ -277,7 +279,7 @@ impl Writer for H1Writer { } fn write_eof(&mut self) -> io::Result { - if !self.encoder.write_eof()? { + if !self.buffer.write_eof()? { Err(io::Error::new( io::ErrorKind::Other, "Last payload item, but eof is not reached", @@ -293,7 +295,7 @@ impl Writer for H1Writer { fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { if !self.buffer.is_empty() { let written = { - match Self::write_data(&mut self.stream, self.buffer.as_ref()) { + match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) { Err(err) => { if err.kind() == io::ErrorKind::WriteZero { self.disconnected(); diff --git a/src/server/h2writer.rs b/src/server/h2writer.rs index c816c12a..e5f579b2 100644 --- a/src/server/h2writer.rs +++ b/src/server/h2writer.rs @@ -11,7 +11,7 @@ use std::{cmp, io}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::{HttpTryFrom, Version}; -use super::encoding::ContentEncoder; +use super::encoding::{ContentEncoder, Output}; use super::helpers; use super::settings::WorkerSettings; use super::shared::SharedBytes; @@ -35,10 +35,9 @@ bitflags! { pub(crate) struct H2Writer { respond: SendResponse, stream: Option>, - encoder: ContentEncoder, flags: Flags, written: u64, - buffer: SharedBytes, + buffer: Output, buffer_capacity: usize, settings: Rc>, } @@ -51,10 +50,9 @@ impl H2Writer { respond, settings, stream: None, - encoder: ContentEncoder::empty(), flags: Flags::empty(), written: 0, - buffer: buf, + buffer: Output::Buffer(buf), buffer_capacity: 0, } } @@ -87,8 +85,7 @@ impl Writer for H2Writer { ) -> io::Result { // prepare response self.flags.insert(Flags::STARTED); - self.encoder = - ContentEncoder::for_server(self.buffer.get_mut(), req, msg, encoding); + self.buffer = ContentEncoder::for_server(self.buffer.take(), req, msg, encoding); // http2 specific msg.headers_mut().remove(CONNECTION); @@ -150,7 +147,7 @@ impl Writer for H2Writer { } else { self.flags.insert(Flags::EOF); self.written = bytes.len() as u64; - self.encoder.write(bytes.as_ref())?; + self.buffer.write(bytes.as_ref())?; if let Some(ref mut stream) = self.stream { self.flags.insert(Flags::RESERVED); stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE)); @@ -170,10 +167,10 @@ impl Writer for H2Writer { if !self.flags.contains(Flags::DISCONNECTED) { if self.flags.contains(Flags::STARTED) { // TODO: add warning, write after EOF - self.encoder.write(payload.as_ref())?; + self.buffer.write(payload.as_ref())?; } else { // might be response for EXCEPT - self.buffer.extend_from_slice(payload.as_ref()) + error!("Not supported"); } } @@ -186,7 +183,7 @@ impl Writer for H2Writer { fn write_eof(&mut self) -> io::Result { self.flags.insert(Flags::EOF); - if !self.encoder.write_eof()? { + if !self.buffer.write_eof()? { Err(io::Error::new( io::ErrorKind::Other, "Last payload item, but eof is not reached", diff --git a/src/server/shared.rs b/src/server/shared.rs index 064130fb..4d6a59d8 100644 --- a/src/server/shared.rs +++ b/src/server/shared.rs @@ -5,8 +5,6 @@ use std::rc::Rc; use bytes::BytesMut; -use body::Binary; - #[derive(Debug)] pub(crate) struct SharedBytesPool(RefCell>); @@ -50,6 +48,10 @@ impl SharedBytes { SharedBytes(Some(bytes), Some(pool)) } + pub fn empty() -> SharedBytes { + SharedBytes(Some(BytesMut::new()), None) + } + #[inline] pub(crate) fn get_mut(&mut self) -> &mut BytesMut { self.0.as_mut().unwrap() @@ -79,9 +81,8 @@ impl SharedBytes { } #[inline] - pub fn extend(&mut self, data: &Binary) { - let buf = self.get_mut(); - buf.extend_from_slice(data.as_ref()); + pub fn reserve(&mut self, cap: usize) { + self.get_mut().reserve(cap); } #[inline]