1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-22 23:05:56 +01:00

fix stream flushing

This commit is contained in:
Nikolay Kim 2018-10-02 10:43:23 -07:00
parent f8b176de9e
commit 61c7534e03
5 changed files with 131 additions and 85 deletions

View File

@ -44,6 +44,10 @@ pub enum HttpDispatchError {
#[fail(display = "HTTP2 error: {}", _0)]
Http2(http2::Error),
/// Payload is not consumed
#[fail(display = "Task is completed but request's payload is not consumed")]
PayloadIsNotConsumed,
/// Malformed request
#[fail(display = "Malformed request")]
MalformedRequest,

View File

@ -30,7 +30,7 @@ bitflags! {
const READ_DISCONNECTED = 0b0001_0000;
const WRITE_DISCONNECTED = 0b0010_0000;
const POLLED = 0b0100_0000;
const FLUSHED = 0b1000_0000;
}
}
@ -99,9 +99,9 @@ where
};
let flags = if is_eof {
Flags::READ_DISCONNECTED
Flags::READ_DISCONNECTED | Flags::FLUSHED
} else if settings.keep_alive_enabled() {
Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED
Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED
} else {
Flags::empty()
};
@ -130,7 +130,7 @@ where
}
let mut disp = Http1Dispatcher {
flags: Flags::STARTED | Flags::READ_DISCONNECTED,
flags: Flags::STARTED | Flags::READ_DISCONNECTED | Flags::FLUSHED,
stream: H1Writer::new(stream, settings.clone()),
decoder: H1Decoder::new(),
payload: None,
@ -177,7 +177,8 @@ where
}
if !checked || self.tasks.is_empty() {
self.flags.insert(Flags::WRITE_DISCONNECTED);
self.flags
.insert(Flags::WRITE_DISCONNECTED | Flags::FLUSHED);
self.stream.disconnected();
// notify all tasks
@ -205,54 +206,70 @@ where
// shutdown
if self.flags.contains(Flags::SHUTDOWN) {
if self
.flags
.intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED)
{
if self.flags.intersects(Flags::WRITE_DISCONNECTED) {
return Ok(Async::Ready(()));
}
match self.stream.poll_completed(true) {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(_)) => return Ok(Async::Ready(())),
Err(err) => {
debug!("Error sending data: {}", err);
return Err(err.into());
}
}
return self.poll_flush(true);
}
self.poll_io()?;
// process incoming requests
if !self.flags.contains(Flags::WRITE_DISCONNECTED) {
match self.poll_handler()? {
Async::Ready(true) => self.poll(),
Async::Ready(false) => {
self.flags.insert(Flags::SHUTDOWN);
self.poll()
self.poll_handler()?;
// flush stream
self.poll_flush(false)?;
// deal with keep-alive and stream eof (client-side write shutdown)
if self.tasks.is_empty() && self.flags.intersects(Flags::FLUSHED) {
// handle stream eof
if self
.flags
.intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED)
{
return Ok(Async::Ready(()));
}
Async::NotReady => {
// deal with keep-alive and steam eof (client-side write shutdown)
if self.tasks.is_empty() {
// handle stream eof
if self.flags.intersects(
Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED,
) {
return Ok(Async::Ready(()));
}
// no keep-alive
if self.flags.contains(Flags::STARTED)
&& (!self.flags.contains(Flags::KEEPALIVE_ENABLED)
|| !self.flags.contains(Flags::KEEPALIVE))
{
self.flags.insert(Flags::SHUTDOWN);
return self.poll();
}
// no keep-alive
if self.flags.contains(Flags::STARTED)
&& (!self.flags.contains(Flags::KEEPALIVE_ENABLED)
|| !self.flags.contains(Flags::KEEPALIVE))
{
self.flags.insert(Flags::SHUTDOWN);
return self.poll();
}
}
Ok(Async::NotReady)
} else if let Some(err) = self.error.take() {
Err(err)
} else {
Ok(Async::Ready(()))
}
}
/// Flush stream
fn poll_flush(&mut self, shutdown: bool) -> Poll<(), HttpDispatchError> {
if shutdown || self.flags.contains(Flags::STARTED) {
match self.stream.poll_completed(shutdown) {
Ok(Async::NotReady) => {
// mark stream
if !self.stream.flushed() {
self.flags.remove(Flags::FLUSHED);
}
Ok(Async::NotReady)
}
Err(err) => {
debug!("Error sending data: {}", err);
self.client_disconnected(false);
return Err(err.into());
}
Ok(Async::Ready(_)) => {
// if payload is not consumed we can not use connection
if self.payload.is_some() && self.tasks.is_empty() {
return Err(HttpDispatchError::PayloadIsNotConsumed);
}
self.flags.insert(Flags::FLUSHED);
Ok(Async::Ready(()))
}
}
} else if let Some(err) = self.error.take() {
Err(err)
} else {
Ok(Async::Ready(()))
}
@ -317,20 +334,23 @@ where
}
#[inline]
/// read data from stream
pub(self) fn poll_io(&mut self) -> Result<(), HttpDispatchError> {
/// read data from the stream
pub(self) fn poll_io(&mut self) -> Result<bool, HttpDispatchError> {
if !self.flags.contains(Flags::POLLED) {
self.parse()?;
let updated = self.parse()?;
self.flags.insert(Flags::POLLED);
return Ok(());
return Ok(updated);
}
// read io from socket
let mut updated = false;
if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES {
match self.stream.get_mut().read_available(&mut self.buf) {
Ok(Async::Ready((read_some, disconnected))) => {
if read_some {
self.parse()?;
if self.parse()? {
updated = true;
}
}
if disconnected {
self.client_disconnected(true);
@ -343,13 +363,14 @@ where
}
}
}
Ok(())
Ok(updated)
}
pub(self) fn poll_handler(&mut self) -> Poll<bool, HttpDispatchError> {
let retry = self.can_read();
pub(self) fn poll_handler(&mut self) -> Result<(), HttpDispatchError> {
self.poll_io()?;
let mut retry = self.can_read();
// process first pipelined response, only one task can do io operation in http/1
// process first pipelined response, only first task can do io operation in http/1
while !self.tasks.is_empty() {
match self.tasks[0].poll_io(&mut self.stream) {
Ok(Async::Ready(ready)) => {
@ -375,9 +396,12 @@ where
}
// if read-backpressure is enabled and we consumed some data.
// we may read more data
// we may read more dataand retry
if !retry && self.can_read() {
return Ok(Async::Ready(true));
if self.poll_io()? {
retry = self.can_read();
continue;
}
}
break;
}
@ -431,25 +455,7 @@ where
}
}
// flush stream
if self.flags.contains(Flags::STARTED) {
match self.stream.poll_completed(false) {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => {
debug!("Error sending data: {}", err);
self.client_disconnected(false);
return Err(err.into());
}
Ok(Async::Ready(_)) => {
// if payload is not consumed we can not use connection
if self.payload.is_some() && self.tasks.is_empty() {
return Ok(Async::Ready(false));
}
}
}
}
Ok(Async::NotReady)
Ok(())
}
fn push_response_entry(&mut self, status: StatusCode) {
@ -457,7 +463,7 @@ where
.push_back(Entry::Error(ServerError::err(Version::HTTP_11, status)));
}
pub(self) fn parse(&mut self) -> Result<(), HttpDispatchError> {
pub(self) fn parse(&mut self) -> Result<bool, HttpDispatchError> {
let mut updated = false;
'outer: loop {
@ -524,7 +530,7 @@ where
payload.feed_data(chunk);
} else {
error!("Internal server error: unexpected payload chunk");
self.flags.insert(Flags::READ_DISCONNECTED);
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR);
self.error = Some(HttpDispatchError::InternalError);
break;
@ -536,7 +542,7 @@ where
payload.feed_eof();
} else {
error!("Internal server error: unexpected eof");
self.flags.insert(Flags::READ_DISCONNECTED);
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR);
self.error = Some(HttpDispatchError::InternalError);
break;
@ -559,7 +565,7 @@ where
// Malformed requests should be responded with 400
self.push_response_entry(StatusCode::BAD_REQUEST);
self.flags.insert(Flags::READ_DISCONNECTED);
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.error = Some(HttpDispatchError::MalformedRequest);
break;
}
@ -571,7 +577,7 @@ where
self.ka_expire = expire;
}
}
Ok(())
Ok(updated)
}
}

View File

@ -62,6 +62,10 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
self.flags = Flags::KEEPALIVE;
}
pub fn flushed(&mut self) -> bool {
self.buffer.is_empty()
}
pub fn disconnected(&mut self) {
self.flags.insert(Flags::DISCONNECTED);
}

View File

@ -1094,3 +1094,35 @@ fn test_slow_request() {
sys.stop();
}
#[test]
fn test_malformed_request() {
use actix::System;
use std::net;
use std::sync::mpsc;
let (tx, rx) = mpsc::channel();
let addr = test::TestServer::unused_addr();
thread::spawn(move || {
System::run(move || {
let srv = server::new(|| {
vec![App::new().resource("/", |r| {
r.method(http::Method::GET).f(|_| HttpResponse::Ok())
})]
});
let _ = srv.bind(addr).unwrap().start();
let _ = tx.send(System::current());
});
});
let sys = rx.recv().unwrap();
thread::sleep(time::Duration::from_millis(200));
let mut stream = net::TcpStream::connect(addr).unwrap();
let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n");
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 400 Bad Request"));
sys.stop();
}

View File

@ -7,7 +7,7 @@ extern crate rand;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::{thread, time};
use bytes::Bytes;
use futures::Stream;
@ -380,17 +380,17 @@ fn test_ws_stopped() {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let _ = thread::spawn(move || {
let mut srv = test::TestServer::new(move |app| {
let num3 = num2.clone();
let mut srv = test::TestServer::new(move |app| {
let num4 = num3.clone();
app.handler(move |req| ws::start(req, WsStopped(num4.clone())))
});
app.handler(move |req| ws::start(req, WsStopped(num3.clone())))
});
{
let (reader, mut writer) = srv.ws().unwrap();
writer.text("text");
let (item, _) = srv.execute(reader.into_future()).unwrap();
assert_eq!(item, Some(ws::Message::Text("text".to_owned())));
}).join();
}
thread::sleep(time::Duration::from_millis(1000));
assert_eq!(num.load(Ordering::Relaxed), 1);
}