diff --git a/src/h1writer.rs b/src/h1writer.rs index 200ff0529..fd5551724 100644 --- a/src/h1writer.rs +++ b/src/h1writer.rs @@ -33,6 +33,8 @@ pub trait Writer { fn write_eof(&mut self) -> Result; + fn flush(&mut self) -> Poll<(), io::Error>; + fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error>; } @@ -112,10 +114,25 @@ impl H1Writer { impl Writer for H1Writer { + #[inline] fn written(&self) -> u64 { self.written } + #[inline] + fn flush(&mut self) -> Poll<(), io::Error> { + match self.stream.flush() { + Ok(_) => Ok(Async::Ready(())), + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + Ok(Async::NotReady) + } else { + Err(e) + } + } + } + } + fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) -> Result { @@ -226,6 +243,7 @@ impl Writer for H1Writer { } } + #[inline] fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { match self.write_to_stream() { Ok(WriterState::Done) => { diff --git a/src/h2.rs b/src/h2.rs index 446219727..e60d799d6 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -44,7 +44,7 @@ pub(crate) struct Http2 enum State { Handshake(Handshake), - Server(Connection), + Connection(Connection), Empty, } @@ -76,7 +76,7 @@ impl Http2 pub fn poll(&mut self) -> Poll<(), ()> { // server - if let State::Server(ref mut server) = self.state { + if let State::Connection(ref mut conn) = self.state { // keep-alive timer if let Some(ref mut timeout) = self.keepalive_timer { match timeout.poll() { @@ -144,7 +144,7 @@ impl Http2 // get request if !self.flags.contains(Flags::DISCONNECTED) { - match server.poll() { + match conn.poll() { Ok(Async::Ready(None)) => { not_ready = false; self.flags.insert(Flags::DISCONNECTED); @@ -178,7 +178,8 @@ impl Http2 } } else { // keep-alive disable, drop connection - return Ok(Async::Ready(())) + return conn.poll_close().map_err( + |e| error!("Error during connection close: {}", e)) } } else { // keep-alive unset, rely on operating system @@ -198,7 +199,8 @@ impl Http2 if not_ready { if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) { - return Ok(Async::Ready(())) + return conn.poll_close().map_err( + |e| error!("Error during connection close: {}", e)) } else { return Ok(Async::NotReady) } @@ -209,8 +211,8 @@ impl Http2 // handshake self.state = if let State::Handshake(ref mut handshake) = self.state { match handshake.poll() { - Ok(Async::Ready(srv)) => { - State::Server(srv) + Ok(Async::Ready(conn)) => { + State::Connection(conn) }, Ok(Async::NotReady) => return Ok(Async::NotReady), diff --git a/src/h2writer.rs b/src/h2writer.rs index 57c4bd357..4707b8ee5 100644 --- a/src/h2writer.rs +++ b/src/h2writer.rs @@ -111,6 +111,11 @@ impl Writer for H2Writer { self.written } + #[inline] + fn flush(&mut self) -> Poll<(), io::Error> { + Ok(Async::Ready(())) + } + fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) -> Result { diff --git a/src/httpresponse.rs b/src/httpresponse.rs index 63582aeb9..358d0daf9 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -1,11 +1,12 @@ //! Pieces pertaining to the HTTP response. use std::{mem, str, fmt}; +use std::io::Write; use std::cell::RefCell; use std::convert::Into; use std::collections::VecDeque; use cookie::CookieJar; -use bytes::{Bytes, BytesMut}; +use bytes::{Bytes, BytesMut, BufMut}; use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError}; use http::header::{self, HeaderName, HeaderValue}; use serde_json; @@ -347,6 +348,14 @@ impl HttpResponseBuilder { self } + /// Set content length + #[inline] + pub fn content_length(&mut self, len: u64) -> &mut Self { + let mut wrt = BytesMut::new().writer(); + let _ = write!(wrt, "{}", len); + self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + } + /// Set a cookie /// /// ```rust diff --git a/src/pipeline.rs b/src/pipeline.rs index 77ad05e04..195a12d9b 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -568,6 +568,16 @@ impl ProcessResponse { if self.running == RunningState::Paused || self.drain.is_some() { match io.poll_completed(false) { Ok(Async::Ready(_)) => { + match io.flush() { + Ok(Async::Ready(_)) => (), + Ok(Async::NotReady) => return Err(PipelineState::Response(self)), + Err(err) => { + debug!("Error sending data: {}", err); + info.error = Some(err.into()); + return Ok(FinishingMiddlewares::init(info, self.resp)) + } + } + self.running.resume(); // resolve drain futures diff --git a/tests/test_server.rs b/tests/test_server.rs index 72ee2fb4f..b88b25a43 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -131,6 +131,44 @@ fn test_body_streaming_implicit() { assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } +#[test] +fn test_body_br_streaming() { + let srv = test::TestServer::new( + |app| app.handler(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + httpcodes::HTTPOk.build() + .content_encoding(headers::ContentEncoding::Br) + .body(Body::Streaming(Box::new(body)))})); + + let mut res = reqwest::get(&srv.url("/")).unwrap(); + assert!(res.status().is_success()); + let mut bytes = BytesMut::with_capacity(2048).writer(); + let _ = res.copy_to(&mut bytes); + let bytes = bytes.into_inner(); + + let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_body_length() { + let srv = test::TestServer::new( + |app| app.handler(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + httpcodes::HTTPOk.build() + .content_length(STR.len() as u64) + .body(Body::Streaming(Box::new(body)))})); + + let mut res = reqwest::get(&srv.url("/")).unwrap(); + assert!(res.status().is_success()); + let mut bytes = BytesMut::with_capacity(2048).writer(); + let _ = res.copy_to(&mut bytes); + let bytes = bytes.into_inner(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + #[test] fn test_body_streaming_explicit() { let srv = test::TestServer::new( @@ -304,7 +342,7 @@ fn test_h2() { }) }); let _res = core.run(tcp); - // assert_eq!(res, Bytes::from_static(STR.as_ref())); + // assert_eq!(_res.unwrap(), Bytes::from_static(STR.as_ref())); } #[test]