1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-30 18:34:36 +01:00

flush stream on drain

This commit is contained in:
Nikolay Kim 2018-01-11 16:22:27 -08:00
parent 0a41ecd01d
commit 0707dfe5bb
6 changed files with 91 additions and 9 deletions

View File

@ -33,6 +33,8 @@ pub trait Writer {
fn write_eof(&mut self) -> Result<WriterState, io::Error>; fn write_eof(&mut self) -> Result<WriterState, io::Error>;
fn flush(&mut self) -> Poll<(), io::Error>;
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error>; fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error>;
} }
@ -112,10 +114,25 @@ impl<T: AsyncWrite> H1Writer<T> {
impl<T: AsyncWrite> Writer for H1Writer<T> { impl<T: AsyncWrite> Writer for H1Writer<T> {
#[inline]
fn written(&self) -> u64 { fn written(&self) -> u64 {
self.written self.written
} }
#[inline]
fn flush(&mut self) -> Poll<(), io::Error> {
match self.stream.flush() {
Ok(_) => Ok(Async::Ready(())),
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
Ok(Async::NotReady)
} else {
Err(e)
}
}
}
}
fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse)
-> Result<WriterState, io::Error> -> Result<WriterState, io::Error>
{ {
@ -226,6 +243,7 @@ impl<T: AsyncWrite> Writer for H1Writer<T> {
} }
} }
#[inline]
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
match self.write_to_stream() { match self.write_to_stream() {
Ok(WriterState::Done) => { Ok(WriterState::Done) => {

View File

@ -44,7 +44,7 @@ pub(crate) struct Http2<T, H>
enum State<T: AsyncRead + AsyncWrite> { enum State<T: AsyncRead + AsyncWrite> {
Handshake(Handshake<T, Bytes>), Handshake(Handshake<T, Bytes>),
Server(Connection<T, Bytes>), Connection(Connection<T, Bytes>),
Empty, Empty,
} }
@ -76,7 +76,7 @@ impl<T, H> Http2<T, H>
pub fn poll(&mut self) -> Poll<(), ()> { pub fn poll(&mut self) -> Poll<(), ()> {
// server // server
if let State::Server(ref mut server) = self.state { if let State::Connection(ref mut conn) = self.state {
// keep-alive timer // keep-alive timer
if let Some(ref mut timeout) = self.keepalive_timer { if let Some(ref mut timeout) = self.keepalive_timer {
match timeout.poll() { match timeout.poll() {
@ -144,7 +144,7 @@ impl<T, H> Http2<T, H>
// get request // get request
if !self.flags.contains(Flags::DISCONNECTED) { if !self.flags.contains(Flags::DISCONNECTED) {
match server.poll() { match conn.poll() {
Ok(Async::Ready(None)) => { Ok(Async::Ready(None)) => {
not_ready = false; not_ready = false;
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::DISCONNECTED);
@ -178,7 +178,8 @@ impl<T, H> Http2<T, H>
} }
} else { } else {
// keep-alive disable, drop connection // keep-alive disable, drop connection
return Ok(Async::Ready(())) return conn.poll_close().map_err(
|e| error!("Error during connection close: {}", e))
} }
} else { } else {
// keep-alive unset, rely on operating system // keep-alive unset, rely on operating system
@ -198,7 +199,8 @@ impl<T, H> Http2<T, H>
if not_ready { if not_ready {
if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) { if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) {
return Ok(Async::Ready(())) return conn.poll_close().map_err(
|e| error!("Error during connection close: {}", e))
} else { } else {
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
@ -209,8 +211,8 @@ impl<T, H> Http2<T, H>
// handshake // handshake
self.state = if let State::Handshake(ref mut handshake) = self.state { self.state = if let State::Handshake(ref mut handshake) = self.state {
match handshake.poll() { match handshake.poll() {
Ok(Async::Ready(srv)) => { Ok(Async::Ready(conn)) => {
State::Server(srv) State::Connection(conn)
}, },
Ok(Async::NotReady) => Ok(Async::NotReady) =>
return Ok(Async::NotReady), return Ok(Async::NotReady),

View File

@ -111,6 +111,11 @@ impl Writer for H2Writer {
self.written self.written
} }
#[inline]
fn flush(&mut self) -> Poll<(), io::Error> {
Ok(Async::Ready(()))
}
fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse)
-> Result<WriterState, io::Error> -> Result<WriterState, io::Error>
{ {

View File

@ -1,11 +1,12 @@
//! Pieces pertaining to the HTTP response. //! Pieces pertaining to the HTTP response.
use std::{mem, str, fmt}; use std::{mem, str, fmt};
use std::io::Write;
use std::cell::RefCell; use std::cell::RefCell;
use std::convert::Into; use std::convert::Into;
use std::collections::VecDeque; use std::collections::VecDeque;
use cookie::CookieJar; use cookie::CookieJar;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut, BufMut};
use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError}; use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
use serde_json; use serde_json;
@ -347,6 +348,14 @@ impl HttpResponseBuilder {
self self
} }
/// Set content length
#[inline]
pub fn content_length(&mut self, len: u64) -> &mut Self {
let mut wrt = BytesMut::new().writer();
let _ = write!(wrt, "{}", len);
self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze())
}
/// Set a cookie /// Set a cookie
/// ///
/// ```rust /// ```rust

View File

@ -568,6 +568,16 @@ impl<S: 'static, H> ProcessResponse<S, H> {
if self.running == RunningState::Paused || self.drain.is_some() { if self.running == RunningState::Paused || self.drain.is_some() {
match io.poll_completed(false) { match io.poll_completed(false) {
Ok(Async::Ready(_)) => { Ok(Async::Ready(_)) => {
match io.flush() {
Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => return Err(PipelineState::Response(self)),
Err(err) => {
debug!("Error sending data: {}", err);
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
}
}
self.running.resume(); self.running.resume();
// resolve drain futures // resolve drain futures

View File

@ -131,6 +131,44 @@ fn test_body_streaming_implicit() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[test]
fn test_body_br_streaming() {
let srv = test::TestServer::new(
|app| app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
httpcodes::HTTPOk.build()
.content_encoding(headers::ContentEncoding::Br)
.body(Body::Streaming(Box::new(body)))}));
let mut res = reqwest::get(&srv.url("/")).unwrap();
assert!(res.status().is_success());
let mut bytes = BytesMut::with_capacity(2048).writer();
let _ = res.copy_to(&mut bytes);
let bytes = bytes.into_inner();
let mut e = BrotliDecoder::new(Vec::with_capacity(2048));
e.write_all(bytes.as_ref()).unwrap();
let dec = e.finish().unwrap();
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_body_length() {
let srv = test::TestServer::new(
|app| app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
httpcodes::HTTPOk.build()
.content_length(STR.len() as u64)
.body(Body::Streaming(Box::new(body)))}));
let mut res = reqwest::get(&srv.url("/")).unwrap();
assert!(res.status().is_success());
let mut bytes = BytesMut::with_capacity(2048).writer();
let _ = res.copy_to(&mut bytes);
let bytes = bytes.into_inner();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test] #[test]
fn test_body_streaming_explicit() { fn test_body_streaming_explicit() {
let srv = test::TestServer::new( let srv = test::TestServer::new(
@ -304,7 +342,7 @@ fn test_h2() {
}) })
}); });
let _res = core.run(tcp); let _res = core.run(tcp);
// assert_eq!(res, Bytes::from_static(STR.as_ref())); // assert_eq!(_res.unwrap(), Bytes::from_static(STR.as_ref()));
} }
#[test] #[test]