1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 09:42:40 +01:00

Merge pull request #483 from Neopallium/master

Fix bug with client disconnect immediately after receiving http request.
This commit is contained in:
Nikolay Kim 2018-08-26 10:15:25 -07:00 committed by GitHub
commit 5906971b6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 48 deletions

View File

@ -41,10 +41,10 @@ impl HttpResponseParser {
// if buf is empty parse_message will always return NotReady, let's avoid that // if buf is empty parse_message will always return NotReady, let's avoid that
if buf.is_empty() { if buf.is_empty() {
match io.read_available(buf) { match io.read_available(buf) {
Ok(Async::Ready(true)) => { Ok(Async::Ready((_, true))) => {
return Err(HttpResponseParserError::Disconnect) return Err(HttpResponseParserError::Disconnect)
} }
Ok(Async::Ready(false)) => (), Ok(Async::Ready((_, false))) => (),
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => return Err(HttpResponseParserError::Error(err.into())), Err(err) => return Err(HttpResponseParserError::Error(err.into())),
} }
@ -63,10 +63,10 @@ impl HttpResponseParser {
return Err(HttpResponseParserError::Error(ParseError::TooLarge)); return Err(HttpResponseParserError::Error(ParseError::TooLarge));
} }
match io.read_available(buf) { match io.read_available(buf) {
Ok(Async::Ready(true)) => { Ok(Async::Ready((_, true))) => {
return Err(HttpResponseParserError::Disconnect) return Err(HttpResponseParserError::Disconnect)
} }
Ok(Async::Ready(false)) => (), Ok(Async::Ready((_, false))) => (),
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => { Err(err) => {
return Err(HttpResponseParserError::Error(err.into())) return Err(HttpResponseParserError::Error(err.into()))
@ -87,8 +87,8 @@ impl HttpResponseParser {
loop { loop {
// read payload // read payload
let (not_ready, stream_finished) = match io.read_available(buf) { let (not_ready, stream_finished) = match io.read_available(buf) {
Ok(Async::Ready(true)) => (false, true), Ok(Async::Ready((_, true))) => (false, true),
Ok(Async::Ready(false)) => (false, false), Ok(Async::Ready((_, false))) => (false, false),
Ok(Async::NotReady) => (true, false), Ok(Async::NotReady) => (true, false),
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
}; };

View File

@ -94,6 +94,7 @@ where
}; };
} }
let mut is_eof = false;
let kind = match self.proto { let kind = match self.proto {
Some(HttpProtocol::H1(ref mut h1)) => { Some(HttpProtocol::H1(ref mut h1)) => {
let result = h1.poll(); let result = h1.poll();
@ -120,16 +121,27 @@ where
return result; return result;
} }
Some(HttpProtocol::Unknown(_, _, ref mut io, ref mut buf)) => { Some(HttpProtocol::Unknown(_, _, ref mut io, ref mut buf)) => {
let mut disconnect = false;
match io.read_available(buf) { match io.read_available(buf) {
Ok(Async::Ready(true)) | Err(_) => { Ok(Async::Ready((read_some, stream_closed))) => {
debug!("Ignored premature client disconnection"); is_eof = stream_closed;
if let Some(n) = self.node.as_mut() { // Only disconnect if no data was read.
n.remove() if is_eof && !read_some {
}; disconnect = true;
return Err(()); }
}
Err(_) => {
disconnect = true;
} }
_ => (), _ => (),
} }
if disconnect {
debug!("Ignored premature client disconnection");
if let Some(n) = self.node.as_mut() {
n.remove()
};
return Err(());
}
if buf.len() >= 14 { if buf.len() >= 14 {
if buf[..14] == HTTP2_PREFACE[..] { if buf[..14] == HTTP2_PREFACE[..] {
@ -149,7 +161,7 @@ where
match kind { match kind {
ProtocolKind::Http1 => { ProtocolKind::Http1 => {
self.proto = self.proto =
Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf))); Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf, is_eof)));
return self.poll(); return self.poll();
} }
ProtocolKind::Http2 => { ProtocolKind::Http2 => {

View File

@ -90,10 +90,10 @@ where
{ {
pub fn new( pub fn new(
settings: Rc<WorkerSettings<H>>, stream: T, addr: Option<SocketAddr>, settings: Rc<WorkerSettings<H>>, stream: T, addr: Option<SocketAddr>,
buf: BytesMut, buf: BytesMut, is_eof: bool,
) -> Self { ) -> Self {
Http1 { Http1 {
flags: Flags::KEEPALIVE, flags: Flags::KEEPALIVE | if is_eof { Flags::DISCONNECTED } else { Flags::empty() },
stream: H1Writer::new(stream, Rc::clone(&settings)), stream: H1Writer::new(stream, Rc::clone(&settings)),
decoder: H1Decoder::new(), decoder: H1Decoder::new(),
payload: None, payload: None,
@ -132,6 +132,21 @@ where
} }
} }
fn client_disconnect(&mut self) {
// notify all tasks
self.notify_disconnect();
// kill keepalive
self.keepalive_timer.take();
// on parse error, stop reading stream but tasks need to be
// completed
self.flags.insert(Flags::ERROR);
if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete);
}
}
#[inline] #[inline]
pub fn poll(&mut self) -> Poll<(), ()> { pub fn poll(&mut self) -> Poll<(), ()> {
// keep-alive timer // keep-alive timer
@ -188,38 +203,21 @@ where
&& self.can_read() && self.can_read()
{ {
match self.stream.get_mut().read_available(&mut self.buf) { match self.stream.get_mut().read_available(&mut self.buf) {
Ok(Async::Ready(disconnected)) => { Ok(Async::Ready((read_some, disconnected))) => {
if disconnected { if read_some {
// notify all tasks
self.notify_disconnect();
// kill keepalive
self.keepalive_timer.take();
// on parse error, stop reading stream but tasks need to be
// completed
self.flags.insert(Flags::ERROR);
if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete);
}
} else {
self.parse(); self.parse();
} }
if disconnected {
// delay disconnect until all tasks have finished.
self.flags.insert(Flags::DISCONNECTED);
if self.tasks.is_empty() {
self.client_disconnect();
}
}
} }
Ok(Async::NotReady) => (), Ok(Async::NotReady) => (),
Err(_) => { Err(_) => {
// notify all tasks self.client_disconnect();
self.notify_disconnect();
// kill keepalive
self.keepalive_timer.take();
// on parse error, stop reading stream but tasks need to be
// completed
self.flags.insert(Flags::ERROR);
if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete);
}
} }
} }
} }
@ -331,8 +329,13 @@ where
} }
} }
// deal with keep-alive // deal with keep-alive and steam eof (client-side write shutdown)
if self.tasks.is_empty() { if self.tasks.is_empty() {
// handle stream eof
if self.flags.contains(Flags::DISCONNECTED) {
self.client_disconnect();
return Ok(Async::Ready(false));
}
// no keep-alive // no keep-alive
if self.flags.contains(Flags::ERROR) if self.flags.contains(Flags::ERROR)
|| (!self.flags.contains(Flags::KEEPALIVE) || (!self.flags.contains(Flags::KEEPALIVE)
@ -608,7 +611,7 @@ mod tests {
let readbuf = BytesMut::new(); let readbuf = BytesMut::new();
let settings = Rc::new(wrk_settings()); let settings = Rc::new(wrk_settings());
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true);
h1.poll_io(); h1.poll_io();
h1.poll_io(); h1.poll_io();
assert_eq!(h1.tasks.len(), 1); assert_eq!(h1.tasks.len(), 1);
@ -620,7 +623,7 @@ mod tests {
let readbuf = BytesMut::new(); let readbuf = BytesMut::new();
let settings = Rc::new(wrk_settings()); let settings = Rc::new(wrk_settings());
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true);
h1.poll_io(); h1.poll_io();
h1.poll_io(); h1.poll_io();
assert!(h1.flags.contains(Flags::ERROR)); assert!(h1.flags.contains(Flags::ERROR));

View File

@ -390,7 +390,7 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static {
fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()>; fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()>;
fn read_available(&mut self, buf: &mut BytesMut) -> Poll<bool, io::Error> { fn read_available(&mut self, buf: &mut BytesMut) -> Poll<(bool, bool), io::Error> {
let mut read_some = false; let mut read_some = false;
loop { loop {
if buf.remaining_mut() < LW_BUFFER_SIZE { if buf.remaining_mut() < LW_BUFFER_SIZE {
@ -400,7 +400,7 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static {
match self.read(buf.bytes_mut()) { match self.read(buf.bytes_mut()) {
Ok(n) => { Ok(n) => {
if n == 0 { if n == 0 {
return Ok(Async::Ready(!read_some)); return Ok(Async::Ready((read_some, true)));
} else { } else {
read_some = true; read_some = true;
buf.advance_mut(n); buf.advance_mut(n);
@ -409,7 +409,7 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static {
Err(e) => { Err(e) => {
return if e.kind() == io::ErrorKind::WouldBlock { return if e.kind() == io::ErrorKind::WouldBlock {
if read_some { if read_some {
Ok(Async::Ready(false)) Ok(Async::Ready((read_some, false)))
} else { } else {
Ok(Async::NotReady) Ok(Async::NotReady)
} }