diff --git a/src/client/parser.rs b/src/client/parser.rs index f5390cc34..dd4e60bc5 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -41,10 +41,10 @@ impl HttpResponseParser { // if buf is empty parse_message will always return NotReady, let's avoid that if buf.is_empty() { match io.read_available(buf) { - Ok(Async::Ready(true)) => { + Ok(Async::Ready((_, true))) => { return Err(HttpResponseParserError::Disconnect) } - Ok(Async::Ready(false)) => (), + Ok(Async::Ready((_, false))) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => return Err(HttpResponseParserError::Error(err.into())), } @@ -63,10 +63,10 @@ impl HttpResponseParser { return Err(HttpResponseParserError::Error(ParseError::TooLarge)); } match io.read_available(buf) { - Ok(Async::Ready(true)) => { + Ok(Async::Ready((_, true))) => { return Err(HttpResponseParserError::Disconnect) } - Ok(Async::Ready(false)) => (), + Ok(Async::Ready((_, false))) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { return Err(HttpResponseParserError::Error(err.into())) @@ -87,8 +87,8 @@ impl HttpResponseParser { loop { // read payload let (not_ready, stream_finished) = match io.read_available(buf) { - Ok(Async::Ready(true)) => (false, true), - Ok(Async::Ready(false)) => (false, false), + Ok(Async::Ready((_, true))) => (false, true), + Ok(Async::Ready((_, false))) => (false, false), Ok(Async::NotReady) => (true, false), Err(err) => return Err(err.into()), }; diff --git a/src/server/channel.rs b/src/server/channel.rs index 7de561c6b..84f301513 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -94,6 +94,7 @@ where }; } + let mut is_eof = false; let kind = match self.proto { Some(HttpProtocol::H1(ref mut h1)) => { let result = h1.poll(); @@ -120,16 +121,27 @@ where return result; } Some(HttpProtocol::Unknown(_, _, ref mut io, ref mut buf)) => { + let mut disconnect = false; match io.read_available(buf) { - Ok(Async::Ready(true)) | Err(_) => { - debug!("Ignored premature client disconnection"); - if let Some(n) = self.node.as_mut() { - n.remove() - }; - return Err(()); + Ok(Async::Ready((read_some, stream_closed))) => { + is_eof = stream_closed; + // Only disconnect if no data was read. + if is_eof && !read_some { + disconnect = true; + } + } + Err(_) => { + disconnect = true; } _ => (), } + if disconnect { + debug!("Ignored premature client disconnection"); + if let Some(n) = self.node.as_mut() { + n.remove() + }; + return Err(()); + } if buf.len() >= 14 { if buf[..14] == HTTP2_PREFACE[..] { @@ -149,7 +161,7 @@ where match kind { ProtocolKind::Http1 => { self.proto = - Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf))); + Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf, is_eof))); return self.poll(); } ProtocolKind::Http2 => { diff --git a/src/server/h1.rs b/src/server/h1.rs index 808dc11a1..ae5dd4655 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -90,10 +90,10 @@ where { pub fn new( settings: Rc>, stream: T, addr: Option, - buf: BytesMut, + buf: BytesMut, is_eof: bool, ) -> Self { Http1 { - flags: Flags::KEEPALIVE, + flags: Flags::KEEPALIVE | if is_eof { Flags::DISCONNECTED } else { Flags::empty() }, stream: H1Writer::new(stream, Rc::clone(&settings)), decoder: H1Decoder::new(), payload: None, @@ -132,6 +132,21 @@ where } } + fn client_disconnect(&mut self) { + // notify all tasks + self.notify_disconnect(); + // kill keepalive + self.keepalive_timer.take(); + + // on parse error, stop reading stream but tasks need to be + // completed + self.flags.insert(Flags::ERROR); + + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete); + } + } + #[inline] pub fn poll(&mut self) -> Poll<(), ()> { // keep-alive timer @@ -188,38 +203,21 @@ where && self.can_read() { match self.stream.get_mut().read_available(&mut self.buf) { - Ok(Async::Ready(disconnected)) => { - if disconnected { - // notify all tasks - self.notify_disconnect(); - // kill keepalive - self.keepalive_timer.take(); - - // on parse error, stop reading stream but tasks need to be - // completed - self.flags.insert(Flags::ERROR); - - if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); - } - } else { + Ok(Async::Ready((read_some, disconnected))) => { + if read_some { self.parse(); } + if disconnected { + // delay disconnect until all tasks have finished. + self.flags.insert(Flags::DISCONNECTED); + if self.tasks.is_empty() { + self.client_disconnect(); + } + } } Ok(Async::NotReady) => (), Err(_) => { - // notify all tasks - self.notify_disconnect(); - // kill keepalive - self.keepalive_timer.take(); - - // on parse error, stop reading stream but tasks need to be - // completed - self.flags.insert(Flags::ERROR); - - if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); - } + self.client_disconnect(); } } } @@ -331,8 +329,13 @@ where } } - // deal with keep-alive + // deal with keep-alive and steam eof (client-side write shutdown) if self.tasks.is_empty() { + // handle stream eof + if self.flags.contains(Flags::DISCONNECTED) { + self.client_disconnect(); + return Ok(Async::Ready(false)); + } // no keep-alive if self.flags.contains(Flags::ERROR) || (!self.flags.contains(Flags::KEEPALIVE) @@ -608,7 +611,7 @@ mod tests { let readbuf = BytesMut::new(); let settings = Rc::new(wrk_settings()); - let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); + let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true); h1.poll_io(); h1.poll_io(); assert_eq!(h1.tasks.len(), 1); @@ -620,7 +623,7 @@ mod tests { let readbuf = BytesMut::new(); let settings = Rc::new(wrk_settings()); - let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); + let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true); h1.poll_io(); h1.poll_io(); assert!(h1.flags.contains(Flags::ERROR)); diff --git a/src/server/mod.rs b/src/server/mod.rs index 36d85a787..009e06ccd 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -390,7 +390,7 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static { fn set_linger(&mut self, dur: Option) -> io::Result<()>; - fn read_available(&mut self, buf: &mut BytesMut) -> Poll { + fn read_available(&mut self, buf: &mut BytesMut) -> Poll<(bool, bool), io::Error> { let mut read_some = false; loop { if buf.remaining_mut() < LW_BUFFER_SIZE { @@ -400,7 +400,7 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static { match self.read(buf.bytes_mut()) { Ok(n) => { if n == 0 { - return Ok(Async::Ready(!read_some)); + return Ok(Async::Ready((read_some, true))); } else { read_some = true; buf.advance_mut(n); @@ -409,7 +409,7 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static { Err(e) => { return if e.kind() == io::ErrorKind::WouldBlock { if read_some { - Ok(Async::Ready(false)) + Ok(Async::Ready((read_some, false))) } else { Ok(Async::NotReady) }