diff --git a/src/client/parser.rs b/src/client/parser.rs index 1638d8eb..4668f58a 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -40,8 +40,10 @@ impl HttpResponseParser { // if buf is empty parse_message will always return NotReady, let's avoid that if buf.is_empty() { match io.read_available(buf) { - Ok(Async::Ready(0)) => return Err(HttpResponseParserError::Disconnect), - Ok(Async::Ready(_)) => (), + Ok(Async::Ready(true)) => { + return Err(HttpResponseParserError::Disconnect) + } + Ok(Async::Ready(false)) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => return Err(HttpResponseParserError::Error(err.into())), } @@ -60,10 +62,10 @@ impl HttpResponseParser { return Err(HttpResponseParserError::Error(ParseError::TooLarge)); } match io.read_available(buf) { - Ok(Async::Ready(0)) => { + Ok(Async::Ready(true)) => { return Err(HttpResponseParserError::Disconnect) } - Ok(Async::Ready(_)) => (), + Ok(Async::Ready(false)) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { return Err(HttpResponseParserError::Error(err.into())) @@ -84,10 +86,10 @@ impl HttpResponseParser { loop { // read payload let (not_ready, stream_finished) = match io.read_available(buf) { - Ok(Async::Ready(0)) => (false, true), - Err(err) => return Err(err.into()), + Ok(Async::Ready(true)) => (false, true), + Ok(Async::Ready(false)) => (false, false), Ok(Async::NotReady) => (true, false), - _ => (false, false), + Err(err) => return Err(err.into()), }; match self.decoder.as_mut().unwrap().decode(buf) { diff --git a/src/server/channel.rs b/src/server/channel.rs index 26061352..1439ddcb 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -140,7 +140,7 @@ where ref mut buf, )) => { match io.read_available(buf) { - Ok(Async::Ready(0)) | Err(_) => { + Ok(Async::Ready(true)) | Err(_) => { debug!("Ignored premature client disconnection"); settings.remove_channel(); if let Some(n) = self.node.as_mut() { diff --git a/src/server/h1.rs b/src/server/h1.rs index ababda6b..87eeccb0 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -1,10 +1,9 @@ use std::collections::VecDeque; -use std::io; use std::net::SocketAddr; use std::rc::Rc; use std::time::{Duration, Instant}; -use bytes::{BufMut, BytesMut}; +use bytes::BytesMut; use futures::{Async, Future, Poll}; use tokio_timer::Delay; @@ -22,8 +21,6 @@ use super::Writer; use super::{HttpHandler, HttpHandlerTask, IoStream}; const MAX_PIPELINED_MESSAGES: usize = 16; -const LW_BUFFER_SIZE: usize = 4096; -const HW_BUFFER_SIZE: usize = 32_768; bitflags! { struct Flags: u8 { @@ -32,6 +29,7 @@ bitflags! { const KEEPALIVE = 0b0000_0100; const SHUTDOWN = 0b0000_1000; const DISCONNECTED = 0b0001_0000; + const POLLED = 0b0010_0000; } } @@ -173,29 +171,58 @@ where #[inline] /// read data from stream pub fn poll_io(&mut self) { + if !self.flags.contains(Flags::POLLED) { + self.parse(); + self.flags.insert(Flags::POLLED); + return; + } // read io from socket if !self.flags.intersects(Flags::ERROR) && self.tasks.len() < MAX_PIPELINED_MESSAGES && self.can_read() { - if self.read() { - // notify all tasks - self.stream.disconnected(); - for entry in &mut self.tasks { - entry.pipe.disconnected() - } - // kill keepalive - self.keepalive_timer.take(); + let res = self.stream.get_mut().read_available(&mut self.buf); + match res { + //self.stream.get_mut().read_available(&mut self.buf) { + Ok(Async::Ready(disconnected)) => { + if disconnected { + // notify all tasks + self.stream.disconnected(); + for entry in &mut self.tasks { + entry.pipe.disconnected() + } + // kill keepalive + self.keepalive_timer.take(); - // on parse error, stop reading stream but tasks need to be - // completed - self.flags.insert(Flags::ERROR); + // on parse error, stop reading stream but tasks need to be + // completed + self.flags.insert(Flags::ERROR); - if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete); + } + } else { + self.parse(); + } + } + Ok(Async::NotReady) => (), + Err(_) => { + // notify all tasks + self.stream.disconnected(); + for entry in &mut self.tasks { + entry.pipe.disconnected() + } + // kill keepalive + self.keepalive_timer.take(); + + // on parse error, stop reading stream but tasks need to be + // completed + self.flags.insert(Flags::ERROR); + + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete); + } } - } else { - self.parse(); } } } @@ -434,35 +461,12 @@ where } } } - - #[inline] - fn read(&mut self) -> bool { - loop { - unsafe { - if self.buf.remaining_mut() < LW_BUFFER_SIZE { - self.buf.reserve(HW_BUFFER_SIZE); - } - match self.stream.get_mut().read(self.buf.bytes_mut()) { - Ok(n) => { - if n == 0 { - return true; - } else { - self.buf.advance_mut(n); - } - } - Err(e) => { - return e.kind() != io::ErrorKind::WouldBlock; - } - } - } - } - } } #[cfg(test)] mod tests { use std::net::Shutdown; - use std::{cmp, time}; + use std::{cmp, io, time}; use bytes::{Buf, Bytes, BytesMut}; use http::{Method, Version}; @@ -606,7 +610,7 @@ mod tests { let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); h1.poll_io(); - h1.parse(); + h1.poll_io(); assert_eq!(h1.tasks.len(), 1); } @@ -621,7 +625,7 @@ mod tests { let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); h1.poll_io(); - h1.parse(); + h1.poll_io(); assert!(h1.flags.contains(Flags::ERROR)); } diff --git a/src/server/mod.rs b/src/server/mod.rs index c98579d0..6ecc75d1 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -216,21 +216,32 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static { fn set_linger(&mut self, dur: Option) -> io::Result<()>; - fn read_available(&mut self, buf: &mut BytesMut) -> Poll { - unsafe { - if buf.remaining_mut() < LW_BUFFER_SIZE { - buf.reserve(HW_BUFFER_SIZE); - } - match self.read(buf.bytes_mut()) { - Ok(n) => { - buf.advance_mut(n); - Ok(Async::Ready(n)) + fn read_available(&mut self, buf: &mut BytesMut) -> Poll { + let mut read_some = false; + loop { + unsafe { + if buf.remaining_mut() < LW_BUFFER_SIZE { + buf.reserve(HW_BUFFER_SIZE); } - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - Ok(Async::NotReady) - } else { - Err(e) + match self.read(buf.bytes_mut()) { + Ok(n) => { + if n == 0 { + return Ok(Async::Ready(!read_some)); + } else { + read_some = true; + buf.advance_mut(n); + } + } + Err(e) => { + return if e.kind() == io::ErrorKind::WouldBlock { + if read_some { + Ok(Async::Ready(false)) + } else { + Ok(Async::NotReady) + } + } else { + Err(e) + }; } } }