1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-27 17:22:57 +01:00

flush stream on drain

This commit is contained in:
Nikolay Kim 2018-01-11 16:22:27 -08:00
parent 0a41ecd01d
commit 0707dfe5bb
6 changed files with 91 additions and 9 deletions

View File

@ -33,6 +33,8 @@ pub trait Writer {
fn write_eof(&mut self) -> Result<WriterState, io::Error>;
fn flush(&mut self) -> Poll<(), io::Error>;
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error>;
}
@ -112,10 +114,25 @@ impl<T: AsyncWrite> H1Writer<T> {
impl<T: AsyncWrite> Writer for H1Writer<T> {
#[inline]
fn written(&self) -> u64 {
self.written
}
#[inline]
fn flush(&mut self) -> Poll<(), io::Error> {
match self.stream.flush() {
Ok(_) => Ok(Async::Ready(())),
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
Ok(Async::NotReady)
} else {
Err(e)
}
}
}
}
fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse)
-> Result<WriterState, io::Error>
{
@ -226,6 +243,7 @@ impl<T: AsyncWrite> Writer for H1Writer<T> {
}
}
#[inline]
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
match self.write_to_stream() {
Ok(WriterState::Done) => {

View File

@ -44,7 +44,7 @@ pub(crate) struct Http2<T, H>
enum State<T: AsyncRead + AsyncWrite> {
Handshake(Handshake<T, Bytes>),
Server(Connection<T, Bytes>),
Connection(Connection<T, Bytes>),
Empty,
}
@ -76,7 +76,7 @@ impl<T, H> Http2<T, H>
pub fn poll(&mut self) -> Poll<(), ()> {
// server
if let State::Server(ref mut server) = self.state {
if let State::Connection(ref mut conn) = self.state {
// keep-alive timer
if let Some(ref mut timeout) = self.keepalive_timer {
match timeout.poll() {
@ -144,7 +144,7 @@ impl<T, H> Http2<T, H>
// get request
if !self.flags.contains(Flags::DISCONNECTED) {
match server.poll() {
match conn.poll() {
Ok(Async::Ready(None)) => {
not_ready = false;
self.flags.insert(Flags::DISCONNECTED);
@ -178,7 +178,8 @@ impl<T, H> Http2<T, H>
}
} else {
// keep-alive disable, drop connection
return Ok(Async::Ready(()))
return conn.poll_close().map_err(
|e| error!("Error during connection close: {}", e))
}
} else {
// keep-alive unset, rely on operating system
@ -198,7 +199,8 @@ impl<T, H> Http2<T, H>
if not_ready {
if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) {
return Ok(Async::Ready(()))
return conn.poll_close().map_err(
|e| error!("Error during connection close: {}", e))
} else {
return Ok(Async::NotReady)
}
@ -209,8 +211,8 @@ impl<T, H> Http2<T, H>
// handshake
self.state = if let State::Handshake(ref mut handshake) = self.state {
match handshake.poll() {
Ok(Async::Ready(srv)) => {
State::Server(srv)
Ok(Async::Ready(conn)) => {
State::Connection(conn)
},
Ok(Async::NotReady) =>
return Ok(Async::NotReady),

View File

@ -111,6 +111,11 @@ impl Writer for H2Writer {
self.written
}
#[inline]
fn flush(&mut self) -> Poll<(), io::Error> {
Ok(Async::Ready(()))
}
fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse)
-> Result<WriterState, io::Error>
{

View File

@ -1,11 +1,12 @@
//! Pieces pertaining to the HTTP response.
use std::{mem, str, fmt};
use std::io::Write;
use std::cell::RefCell;
use std::convert::Into;
use std::collections::VecDeque;
use cookie::CookieJar;
use bytes::{Bytes, BytesMut};
use bytes::{Bytes, BytesMut, BufMut};
use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use serde_json;
@ -347,6 +348,14 @@ impl HttpResponseBuilder {
self
}
/// Set content length
#[inline]
pub fn content_length(&mut self, len: u64) -> &mut Self {
let mut wrt = BytesMut::new().writer();
let _ = write!(wrt, "{}", len);
self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze())
}
/// Set a cookie
///
/// ```rust

View File

@ -568,6 +568,16 @@ impl<S: 'static, H> ProcessResponse<S, H> {
if self.running == RunningState::Paused || self.drain.is_some() {
match io.poll_completed(false) {
Ok(Async::Ready(_)) => {
match io.flush() {
Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => return Err(PipelineState::Response(self)),
Err(err) => {
debug!("Error sending data: {}", err);
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
}
}
self.running.resume();
// resolve drain futures

View File

@ -131,6 +131,44 @@ fn test_body_streaming_implicit() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_body_br_streaming() {
let srv = test::TestServer::new(
|app| app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
httpcodes::HTTPOk.build()
.content_encoding(headers::ContentEncoding::Br)
.body(Body::Streaming(Box::new(body)))}));
let mut res = reqwest::get(&srv.url("/")).unwrap();
assert!(res.status().is_success());
let mut bytes = BytesMut::with_capacity(2048).writer();
let _ = res.copy_to(&mut bytes);
let bytes = bytes.into_inner();
let mut e = BrotliDecoder::new(Vec::with_capacity(2048));
e.write_all(bytes.as_ref()).unwrap();
let dec = e.finish().unwrap();
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_body_length() {
let srv = test::TestServer::new(
|app| app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
httpcodes::HTTPOk.build()
.content_length(STR.len() as u64)
.body(Body::Streaming(Box::new(body)))}));
let mut res = reqwest::get(&srv.url("/")).unwrap();
assert!(res.status().is_success());
let mut bytes = BytesMut::with_capacity(2048).writer();
let _ = res.copy_to(&mut bytes);
let bytes = bytes.into_inner();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_body_streaming_explicit() {
let srv = test::TestServer::new(
@ -304,7 +342,7 @@ fn test_h2() {
})
});
let _res = core.run(tcp);
// assert_eq!(res, Bytes::from_static(STR.as_ref()));
// assert_eq!(_res.unwrap(), Bytes::from_static(STR.as_ref()));
}
#[test]