1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-27 17:52:56 +01:00

fix stream flushing

This commit is contained in:
Nikolay Kim 2018-10-02 10:43:23 -07:00
parent f8b176de9e
commit 61c7534e03
5 changed files with 131 additions and 85 deletions

View File

@ -44,6 +44,10 @@ pub enum HttpDispatchError {
#[fail(display = "HTTP2 error: {}", _0)] #[fail(display = "HTTP2 error: {}", _0)]
Http2(http2::Error), Http2(http2::Error),
/// Payload is not consumed
#[fail(display = "Task is completed but request's payload is not consumed")]
PayloadIsNotConsumed,
/// Malformed request /// Malformed request
#[fail(display = "Malformed request")] #[fail(display = "Malformed request")]
MalformedRequest, MalformedRequest,

View File

@ -30,7 +30,7 @@ bitflags! {
const READ_DISCONNECTED = 0b0001_0000; const READ_DISCONNECTED = 0b0001_0000;
const WRITE_DISCONNECTED = 0b0010_0000; const WRITE_DISCONNECTED = 0b0010_0000;
const POLLED = 0b0100_0000; const POLLED = 0b0100_0000;
const FLUSHED = 0b1000_0000;
} }
} }
@ -99,9 +99,9 @@ where
}; };
let flags = if is_eof { let flags = if is_eof {
Flags::READ_DISCONNECTED Flags::READ_DISCONNECTED | Flags::FLUSHED
} else if settings.keep_alive_enabled() { } else if settings.keep_alive_enabled() {
Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED
} else { } else {
Flags::empty() Flags::empty()
}; };
@ -130,7 +130,7 @@ where
} }
let mut disp = Http1Dispatcher { let mut disp = Http1Dispatcher {
flags: Flags::STARTED | Flags::READ_DISCONNECTED, flags: Flags::STARTED | Flags::READ_DISCONNECTED | Flags::FLUSHED,
stream: H1Writer::new(stream, settings.clone()), stream: H1Writer::new(stream, settings.clone()),
decoder: H1Decoder::new(), decoder: H1Decoder::new(),
payload: None, payload: None,
@ -177,7 +177,8 @@ where
} }
if !checked || self.tasks.is_empty() { if !checked || self.tasks.is_empty() {
self.flags.insert(Flags::WRITE_DISCONNECTED); self.flags
.insert(Flags::WRITE_DISCONNECTED | Flags::FLUSHED);
self.stream.disconnected(); self.stream.disconnected();
// notify all tasks // notify all tasks
@ -205,40 +206,28 @@ where
// shutdown // shutdown
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
if self.flags.intersects(Flags::WRITE_DISCONNECTED) {
return Ok(Async::Ready(()));
}
return self.poll_flush(true);
}
// process incoming requests
if !self.flags.contains(Flags::WRITE_DISCONNECTED) {
self.poll_handler()?;
// flush stream
self.poll_flush(false)?;
// deal with keep-alive and stream eof (client-side write shutdown)
if self.tasks.is_empty() && self.flags.intersects(Flags::FLUSHED) {
// handle stream eof
if self if self
.flags .flags
.intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) .intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED)
{ {
return Ok(Async::Ready(())); return Ok(Async::Ready(()));
} }
match self.stream.poll_completed(true) {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(_)) => return Ok(Async::Ready(())),
Err(err) => {
debug!("Error sending data: {}", err);
return Err(err.into());
}
}
}
self.poll_io()?;
if !self.flags.contains(Flags::WRITE_DISCONNECTED) {
match self.poll_handler()? {
Async::Ready(true) => self.poll(),
Async::Ready(false) => {
self.flags.insert(Flags::SHUTDOWN);
self.poll()
}
Async::NotReady => {
// deal with keep-alive and steam eof (client-side write shutdown)
if self.tasks.is_empty() {
// handle stream eof
if self.flags.intersects(
Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED,
) {
return Ok(Async::Ready(()));
}
// no keep-alive // no keep-alive
if self.flags.contains(Flags::STARTED) if self.flags.contains(Flags::STARTED)
&& (!self.flags.contains(Flags::KEEPALIVE_ENABLED) && (!self.flags.contains(Flags::KEEPALIVE_ENABLED)
@ -249,8 +238,6 @@ where
} }
} }
Ok(Async::NotReady) Ok(Async::NotReady)
}
}
} else if let Some(err) = self.error.take() { } else if let Some(err) = self.error.take() {
Err(err) Err(err)
} else { } else {
@ -258,6 +245,36 @@ where
} }
} }
/// Flush stream
fn poll_flush(&mut self, shutdown: bool) -> Poll<(), HttpDispatchError> {
if shutdown || self.flags.contains(Flags::STARTED) {
match self.stream.poll_completed(shutdown) {
Ok(Async::NotReady) => {
// mark stream
if !self.stream.flushed() {
self.flags.remove(Flags::FLUSHED);
}
Ok(Async::NotReady)
}
Err(err) => {
debug!("Error sending data: {}", err);
self.client_disconnected(false);
return Err(err.into());
}
Ok(Async::Ready(_)) => {
// if payload is not consumed we can not use connection
if self.payload.is_some() && self.tasks.is_empty() {
return Err(HttpDispatchError::PayloadIsNotConsumed);
}
self.flags.insert(Flags::FLUSHED);
Ok(Async::Ready(()))
}
}
} else {
Ok(Async::Ready(()))
}
}
/// keep-alive timer. returns `true` is keep-alive, otherwise drop /// keep-alive timer. returns `true` is keep-alive, otherwise drop
fn poll_keep_alive(&mut self) -> Result<(), HttpDispatchError> { fn poll_keep_alive(&mut self) -> Result<(), HttpDispatchError> {
if let Some(ref mut timer) = self.ka_timer { if let Some(ref mut timer) = self.ka_timer {
@ -317,20 +334,23 @@ where
} }
#[inline] #[inline]
/// read data from stream /// read data from the stream
pub(self) fn poll_io(&mut self) -> Result<(), HttpDispatchError> { pub(self) fn poll_io(&mut self) -> Result<bool, HttpDispatchError> {
if !self.flags.contains(Flags::POLLED) { if !self.flags.contains(Flags::POLLED) {
self.parse()?; let updated = self.parse()?;
self.flags.insert(Flags::POLLED); self.flags.insert(Flags::POLLED);
return Ok(()); return Ok(updated);
} }
// read io from socket // read io from socket
let mut updated = false;
if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES { if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES {
match self.stream.get_mut().read_available(&mut self.buf) { match self.stream.get_mut().read_available(&mut self.buf) {
Ok(Async::Ready((read_some, disconnected))) => { Ok(Async::Ready((read_some, disconnected))) => {
if read_some { if read_some {
self.parse()?; if self.parse()? {
updated = true;
}
} }
if disconnected { if disconnected {
self.client_disconnected(true); self.client_disconnected(true);
@ -343,13 +363,14 @@ where
} }
} }
} }
Ok(()) Ok(updated)
} }
pub(self) fn poll_handler(&mut self) -> Poll<bool, HttpDispatchError> { pub(self) fn poll_handler(&mut self) -> Result<(), HttpDispatchError> {
let retry = self.can_read(); self.poll_io()?;
let mut retry = self.can_read();
// process first pipelined response, only one task can do io operation in http/1 // process first pipelined response, only first task can do io operation in http/1
while !self.tasks.is_empty() { while !self.tasks.is_empty() {
match self.tasks[0].poll_io(&mut self.stream) { match self.tasks[0].poll_io(&mut self.stream) {
Ok(Async::Ready(ready)) => { Ok(Async::Ready(ready)) => {
@ -375,9 +396,12 @@ where
} }
// if read-backpressure is enabled and we consumed some data. // if read-backpressure is enabled and we consumed some data.
// we may read more data // we may read more dataand retry
if !retry && self.can_read() { if !retry && self.can_read() {
return Ok(Async::Ready(true)); if self.poll_io()? {
retry = self.can_read();
continue;
}
} }
break; break;
} }
@ -431,25 +455,7 @@ where
} }
} }
// flush stream Ok(())
if self.flags.contains(Flags::STARTED) {
match self.stream.poll_completed(false) {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => {
debug!("Error sending data: {}", err);
self.client_disconnected(false);
return Err(err.into());
}
Ok(Async::Ready(_)) => {
// if payload is not consumed we can not use connection
if self.payload.is_some() && self.tasks.is_empty() {
return Ok(Async::Ready(false));
}
}
}
}
Ok(Async::NotReady)
} }
fn push_response_entry(&mut self, status: StatusCode) { fn push_response_entry(&mut self, status: StatusCode) {
@ -457,7 +463,7 @@ where
.push_back(Entry::Error(ServerError::err(Version::HTTP_11, status))); .push_back(Entry::Error(ServerError::err(Version::HTTP_11, status)));
} }
pub(self) fn parse(&mut self) -> Result<(), HttpDispatchError> { pub(self) fn parse(&mut self) -> Result<bool, HttpDispatchError> {
let mut updated = false; let mut updated = false;
'outer: loop { 'outer: loop {
@ -524,7 +530,7 @@ where
payload.feed_data(chunk); payload.feed_data(chunk);
} else { } else {
error!("Internal server error: unexpected payload chunk"); error!("Internal server error: unexpected payload chunk");
self.flags.insert(Flags::READ_DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR);
self.error = Some(HttpDispatchError::InternalError); self.error = Some(HttpDispatchError::InternalError);
break; break;
@ -536,7 +542,7 @@ where
payload.feed_eof(); payload.feed_eof();
} else { } else {
error!("Internal server error: unexpected eof"); error!("Internal server error: unexpected eof");
self.flags.insert(Flags::READ_DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR);
self.error = Some(HttpDispatchError::InternalError); self.error = Some(HttpDispatchError::InternalError);
break; break;
@ -559,7 +565,7 @@ where
// Malformed requests should be responded with 400 // Malformed requests should be responded with 400
self.push_response_entry(StatusCode::BAD_REQUEST); self.push_response_entry(StatusCode::BAD_REQUEST);
self.flags.insert(Flags::READ_DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.error = Some(HttpDispatchError::MalformedRequest); self.error = Some(HttpDispatchError::MalformedRequest);
break; break;
} }
@ -571,7 +577,7 @@ where
self.ka_expire = expire; self.ka_expire = expire;
} }
} }
Ok(()) Ok(updated)
} }
} }

View File

@ -62,6 +62,10 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
self.flags = Flags::KEEPALIVE; self.flags = Flags::KEEPALIVE;
} }
pub fn flushed(&mut self) -> bool {
self.buffer.is_empty()
}
pub fn disconnected(&mut self) { pub fn disconnected(&mut self) {
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::DISCONNECTED);
} }

View File

@ -1094,3 +1094,35 @@ fn test_slow_request() {
sys.stop(); sys.stop();
} }
#[test]
fn test_malformed_request() {
use actix::System;
use std::net;
use std::sync::mpsc;
let (tx, rx) = mpsc::channel();
let addr = test::TestServer::unused_addr();
thread::spawn(move || {
System::run(move || {
let srv = server::new(|| {
vec![App::new().resource("/", |r| {
r.method(http::Method::GET).f(|_| HttpResponse::Ok())
})]
});
let _ = srv.bind(addr).unwrap().start();
let _ = tx.send(System::current());
});
});
let sys = rx.recv().unwrap();
thread::sleep(time::Duration::from_millis(200));
let mut stream = net::TcpStream::connect(addr).unwrap();
let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n");
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 400 Bad Request"));
sys.stop();
}

View File

@ -7,7 +7,7 @@ extern crate rand;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::{thread, time};
use bytes::Bytes; use bytes::Bytes;
use futures::Stream; use futures::Stream;
@ -380,17 +380,17 @@ fn test_ws_stopped() {
let num = Arc::new(AtomicUsize::new(0)); let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone(); let num2 = num.clone();
let _ = thread::spawn(move || {
let num3 = num2.clone();
let mut srv = test::TestServer::new(move |app| { let mut srv = test::TestServer::new(move |app| {
let num4 = num3.clone(); let num3 = num2.clone();
app.handler(move |req| ws::start(req, WsStopped(num4.clone()))) app.handler(move |req| ws::start(req, WsStopped(num3.clone())))
}); });
{
let (reader, mut writer) = srv.ws().unwrap(); let (reader, mut writer) = srv.ws().unwrap();
writer.text("text"); writer.text("text");
let (item, _) = srv.execute(reader.into_future()).unwrap(); let (item, _) = srv.execute(reader.into_future()).unwrap();
assert_eq!(item, Some(ws::Message::Text("text".to_owned()))); assert_eq!(item, Some(ws::Message::Text("text".to_owned())));
}).join(); }
thread::sleep(time::Duration::from_millis(1000));
assert_eq!(num.load(Ordering::Relaxed), 1); assert_eq!(num.load(Ordering::Relaxed), 1);
} }