diff --git a/CHANGES.md b/CHANGES.md index b9ee0470..bfd86a1a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -23,6 +23,8 @@ * Panic during access without routing being set #452 +* Fixed http/2 error handling + ### Deprecated * `HttpServer::no_http2()` is deprecated, use `OpensslAcceptor::with_flags()` or diff --git a/src/pipeline.rs b/src/pipeline.rs index 7c277a58..ca6e974d 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -42,13 +42,6 @@ enum PipelineState { } impl> PipelineState { - fn is_response(&self) -> bool { - match *self { - PipelineState::Response(_) => true, - _ => false, - } - } - fn poll( &mut self, info: &mut PipelineInfo, mws: &[Box>], ) -> Option> { @@ -58,7 +51,8 @@ impl> PipelineState { PipelineState::RunMiddlewares(ref mut state) => state.poll(info, mws), PipelineState::Finishing(ref mut state) => state.poll(info, mws), PipelineState::Completed(ref mut state) => state.poll(info), - PipelineState::Response(_) | PipelineState::None | PipelineState::Error => { + PipelineState::Response(ref mut state) => state.poll(info, mws), + PipelineState::None | PipelineState::Error => { None } } @@ -130,22 +124,20 @@ impl> HttpHandlerTask for Pipeline { let mut state = mem::replace(&mut self.1, PipelineState::None); loop { - if state.is_response() { - if let PipelineState::Response(st) = state { - match st.poll_io(io, &mut self.0, &self.2) { - Ok(state) => { - self.1 = state; - if let Some(error) = self.0.error.take() { - return Err(error); - } else { - return Ok(Async::Ready(self.is_done())); - } - } - Err(state) => { - self.1 = state; - return Ok(Async::NotReady); + if let PipelineState::Response(st) = state { + match st.poll_io(io, &mut self.0, &self.2) { + Ok(state) => { + self.1 = state; + if let Some(error) = self.0.error.take() { + return Err(error); + } else { + return Ok(Async::Ready(self.is_done())); } } + Err(state) => { + self.1 = state; + return Ok(Async::NotReady); + } } } match state { @@ -401,7 +393,7 @@ impl RunMiddlewares { } struct ProcessResponse { - resp: HttpResponse, + resp: Option, iostate: IOState, running: RunningState, drain: Option>, @@ -442,7 +434,7 @@ impl ProcessResponse { #[inline] fn init(resp: HttpResponse) -> PipelineState { PipelineState::Response(ProcessResponse { - resp, + resp: Some(resp), iostate: IOState::Response, running: RunningState::Running, drain: None, @@ -451,6 +443,59 @@ impl ProcessResponse { }) } + fn poll( + &mut self, info: &mut PipelineInfo, mws: &[Box>], + ) -> Option> { + println!("POLL"); + // connection is dead at this point + match mem::replace(&mut self.iostate, IOState::Done) { + IOState::Response => + Some(FinishingMiddlewares::init(info, mws, self.resp.take().unwrap())), + IOState::Payload(_) => + Some(FinishingMiddlewares::init(info, mws, self.resp.take().unwrap())), + IOState::Actor(mut ctx) => { + if info.disconnected.take().is_some() { + ctx.disconnected(); + } + loop { + match ctx.poll() { + Ok(Async::Ready(Some(vec))) => { + if vec.is_empty() { + continue; + } + for frame in vec { + match frame { + Frame::Chunk(None) => { + info.context = Some(ctx); + return Some(FinishingMiddlewares::init( + info, mws, self.resp.take().unwrap(), + )) + } + Frame::Chunk(Some(_)) => (), + Frame::Drain(fut) => {let _ = fut.send(());}, + } + } + } + Ok(Async::Ready(None)) => + return Some(FinishingMiddlewares::init( + info, mws, self.resp.take().unwrap(), + )), + Ok(Async::NotReady) => { + self.iostate = IOState::Actor(ctx); + return None; + } + Err(err) => { + info.context = Some(ctx); + info.error = Some(err); + return Some(FinishingMiddlewares::init(info, mws, self.resp.take().unwrap())); + } + } + } + } + IOState::Done => Some(FinishingMiddlewares::init(info, mws, self.resp.take().unwrap())) + } + } + fn poll_io( mut self, io: &mut Writer, info: &mut PipelineInfo, mws: &[Box>], @@ -462,24 +507,24 @@ impl ProcessResponse { let result = match mem::replace(&mut self.iostate, IOState::Done) { IOState::Response => { let encoding = - self.resp.content_encoding().unwrap_or(info.encoding); + self.resp.as_ref().unwrap().content_encoding().unwrap_or(info.encoding); let result = - match io.start(&info.req, &mut self.resp, encoding) { + match io.start(&info.req, self.resp.as_mut().unwrap(), encoding) { Ok(res) => res, Err(err) => { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, mws, self.resp.take().unwrap(), )); } }; - if let Some(err) = self.resp.error() { - if self.resp.status().is_server_error() { + if let Some(err) = self.resp.as_ref().unwrap().error() { + if self.resp.as_ref().unwrap().status().is_server_error() { error!( "Error occured during request handling, status: {} {}", - self.resp.status(), err + self.resp.as_ref().unwrap().status(), err ); } else { warn!( @@ -493,7 +538,7 @@ impl ProcessResponse { } // always poll stream or actor for the first time - match self.resp.replace_body(Body::Empty) { + match self.resp.as_mut().unwrap().replace_body(Body::Empty) { Body::Streaming(stream) => { self.iostate = IOState::Payload(stream); continue 'inner; @@ -512,7 +557,7 @@ impl ProcessResponse { if let Err(err) = io.write_eof() { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, mws, self.resp.take().unwrap(), )); } break; @@ -523,7 +568,7 @@ impl ProcessResponse { Err(err) => { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, mws, self.resp.take().unwrap(), )); } Ok(result) => result, @@ -536,7 +581,7 @@ impl ProcessResponse { Err(err) => { info.error = Some(err); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, mws, self.resp.take().unwrap(), )); } }, @@ -559,7 +604,7 @@ impl ProcessResponse { info.error = Some(err.into()); return Ok( FinishingMiddlewares::init( - info, mws, self.resp, + info, mws, self.resp.take().unwrap(), ), ); } @@ -572,7 +617,7 @@ impl ProcessResponse { info.error = Some(err.into()); return Ok( FinishingMiddlewares::init( - info, mws, self.resp, + info, mws, self.resp.take().unwrap(), ), ); } @@ -598,7 +643,7 @@ impl ProcessResponse { info.context = Some(ctx); info.error = Some(err); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, mws, self.resp.take().unwrap(), )); } } @@ -638,7 +683,7 @@ impl ProcessResponse { info.context = Some(ctx); } info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, mws, self.resp)); + return Ok(FinishingMiddlewares::init(info, mws, self.resp.take().unwrap())); } } } @@ -652,11 +697,11 @@ impl ProcessResponse { Ok(_) => (), Err(err) => { info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, mws, self.resp)); + return Ok(FinishingMiddlewares::init(info, mws, self.resp.take().unwrap())); } } - self.resp.set_response_size(io.written()); - Ok(FinishingMiddlewares::init(info, mws, self.resp)) + self.resp.as_mut().unwrap().set_response_size(io.written()); + Ok(FinishingMiddlewares::init(info, mws, self.resp.take().unwrap())) } _ => Err(PipelineState::Response(self)), } diff --git a/src/server/h2.rs b/src/server/h2.rs index 9f072502..d52dc74f 100644 --- a/src/server/h2.rs +++ b/src/server/h2.rs @@ -102,13 +102,19 @@ where loop { let mut not_ready = true; + let disconnected = self.flags.contains(Flags::DISCONNECTED); // check in-flight connections for item in &mut self.tasks { // read payload - item.poll_payload(); + if !disconnected { + item.poll_payload(); + } if !item.flags.contains(EntryFlags::EOF) { + if disconnected { + item.flags.insert(EntryFlags::EOF); + } else { let retry = item.payload.need_read() == PayloadStatus::Read; loop { match item.task.poll_io(&mut item.stream) { @@ -141,12 +147,14 @@ where } break; } - } else if !item.flags.contains(EntryFlags::FINISHED) { + } + } + + if item.flags.contains(EntryFlags::EOF) && !item.flags.contains(EntryFlags::FINISHED) { match item.task.poll_completed() { Ok(Async::NotReady) => (), Ok(Async::Ready(_)) => { - not_ready = false; - item.flags.insert(EntryFlags::FINISHED); + item.flags.insert(EntryFlags::FINISHED | EntryFlags::WRITE_DONE); } Err(err) => { item.flags.insert( @@ -161,6 +169,7 @@ where if item.flags.contains(EntryFlags::FINISHED) && !item.flags.contains(EntryFlags::WRITE_DONE) + && !disconnected { match item.stream.poll_completed(false) { Ok(Async::NotReady) => (), @@ -168,7 +177,7 @@ where not_ready = false; item.flags.insert(EntryFlags::WRITE_DONE); } - Err(_err) => { + Err(_) => { item.flags.insert(EntryFlags::ERROR); } } @@ -177,7 +186,7 @@ where // cleanup finished tasks while !self.tasks.is_empty() { - if self.tasks[0].flags.contains(EntryFlags::EOF) + if self.tasks[0].flags.contains(EntryFlags::FINISHED) && self.tasks[0].flags.contains(EntryFlags::WRITE_DONE) || self.tasks[0].flags.contains(EntryFlags::ERROR) { @@ -397,6 +406,7 @@ impl Entry { } Ok(Async::NotReady) => break, Err(err) => { + println!("POLL-PAYLOAD error: {:?}", err); self.payload.set_error(PayloadError::Http2(err)); break; } diff --git a/src/server/h2writer.rs b/src/server/h2writer.rs index 511929fa..ce61b3ed 100644 --- a/src/server/h2writer.rs +++ b/src/server/h2writer.rs @@ -167,7 +167,6 @@ impl Writer for H2Writer { Ok(WriterState::Done) } else { self.flags.insert(Flags::EOF); - self.written = bytes.len() as u64; self.buffer.write(bytes.as_ref())?; if let Some(ref mut stream) = self.stream { self.flags.insert(Flags::RESERVED); @@ -183,8 +182,6 @@ impl Writer for H2Writer { } fn write(&mut self, payload: &Binary) -> io::Result { - self.written = payload.len() as u64; - if !self.flags.contains(Flags::DISCONNECTED) { if self.flags.contains(Flags::STARTED) { // TODO: add warning, write after EOF @@ -253,7 +250,9 @@ impl Writer for H2Writer { return Ok(Async::Ready(())); } } - Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + Err(e) => { + return Err(io::Error::new(io::ErrorKind::Other, e)) + } } } } diff --git a/src/test.rs b/src/test.rs index 244c079a..70de5a16 100644 --- a/src/test.rs +++ b/src/test.rs @@ -15,8 +15,10 @@ use tokio::runtime::current_thread::Runtime; #[cfg(feature = "alpn")] use openssl::ssl::SslAcceptorBuilder; -#[cfg(all(feature = "rust-tls"))] +#[cfg(feature = "rust-tls")] use rustls::ServerConfig; +#[cfg(feature = "rust-tls")] +use server::RustlsAcceptor; use application::{App, HttpApplication}; use body::Binary; @@ -342,7 +344,7 @@ impl TestServerBuilder { let ssl = self.rust_ssl.take(); if let Some(ssl) = ssl { let tcp = net::TcpListener::bind(addr).unwrap(); - srv = srv.listen_rustls(tcp, ssl).unwrap(); + srv = srv.listen_with(tcp, RustlsAcceptor::new(ssl)).unwrap(); } } if !has_ssl { diff --git a/tests/test_ws.rs b/tests/test_ws.rs index aa57faf6..752e88b5 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -210,7 +210,7 @@ impl Ws2 { ctx.drain() .and_then(|_, act, ctx| { act.count += 1; - if act.count != 10_000 { + if act.count != 1_000 { act.send(ctx); } actix::fut::ok(()) @@ -248,7 +248,7 @@ fn test_server_send_text() { }); let (mut reader, _writer) = srv.ws().unwrap(); - for _ in 0..10_000 { + for _ in 0..1_000 { let (item, r) = srv.execute(reader.into_future()).unwrap(); reader = r; assert_eq!(item, data); @@ -272,7 +272,7 @@ fn test_server_send_bin() { }); let (mut reader, _writer) = srv.ws().unwrap(); - for _ in 0..10_000 { + for _ in 0..1_000 { let (item, r) = srv.execute(reader.into_future()).unwrap(); reader = r; assert_eq!(item, data); @@ -308,7 +308,7 @@ fn test_ws_server_ssl() { let (mut reader, _writer) = srv.ws().unwrap(); let data = Some(ws::Message::Text("0".repeat(65_536))); - for _ in 0..10_000 { + for _ in 0..1_000 { let (item, r) = srv.execute(reader.into_future()).unwrap(); reader = r; assert_eq!(item, data); @@ -347,7 +347,7 @@ fn test_ws_server_rust_tls() { let (mut reader, _writer) = srv.ws().unwrap(); let data = Some(ws::Message::Text("0".repeat(65_536))); - for _ in 0..10_000 { + for _ in 0..1_000 { let (item, r) = srv.execute(reader.into_future()).unwrap(); reader = r; assert_eq!(item, data);