diff --git a/src/server/h1.rs b/src/server/h1.rs index e55596635..de2a6e8c8 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -124,6 +124,14 @@ where } } + fn notify_disconnect(&mut self) { + // notify all tasks + self.stream.disconnected(); + for entry in &mut self.tasks { + entry.pipe.disconnected() + } + } + #[inline] pub fn poll(&mut self) -> Poll<(), ()> { // keep-alive timer @@ -183,10 +191,7 @@ where Ok(Async::Ready(disconnected)) => { if disconnected { // notify all tasks - self.stream.disconnected(); - for entry in &mut self.tasks { - entry.pipe.disconnected() - } + self.notify_disconnect(); // kill keepalive self.keepalive_timer.take(); @@ -204,10 +209,7 @@ where Ok(Async::NotReady) => (), Err(_) => { // notify all tasks - self.stream.disconnected(); - for entry in &mut self.tasks { - entry.pipe.disconnected() - } + self.notify_disconnect(); // kill keepalive self.keepalive_timer.take(); @@ -285,6 +287,7 @@ where Ok(Async::NotReady) => (), Ok(Async::Ready(_)) => item.flags.insert(EntryFlags::FINISHED), Err(err) => { + self.notify_disconnect(); item.flags.insert(EntryFlags::ERROR); error!("Unhandled error: {}", err); } @@ -316,6 +319,7 @@ where Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); + self.notify_disconnect(); return Err(()); } Ok(Async::Ready(_)) => { diff --git a/tests/test_ws.rs b/tests/test_ws.rs index dd65d4a58..96d97b824 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -9,6 +9,7 @@ use bytes::Bytes; use futures::Stream; use rand::distributions::Alphanumeric; use rand::Rng; +use std::time::Duration; #[cfg(feature = "alpn")] extern crate openssl; @@ -120,6 +121,31 @@ fn test_large_bin() { } } +#[test] +fn test_client_frame_size() { + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(131_072) + .collect::(); + + let mut srv = test::TestServer::new(|app| { + app.handler(|req| -> Result { + let mut resp = ws::handshake(req)?; + let stream = ws::WsStream::new(req.payload()).max_size(131_072); + + let body = ws::WebsocketContext::create(req.clone(), Ws, stream); + Ok(resp.body(body)) + }) + }); + let (reader, mut writer) = srv.ws().unwrap(); + + writer.binary(data.clone()); + match srv.execute(reader.into_future()).err().unwrap().0 { + ws::ProtocolError::Overflow => (), + _ => panic!(), + } +} + struct Ws2 { count: usize, bin: bool,