From 91af3ca148e7be9b48cd1d9bcaa316b442e2457c Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 1 Oct 2018 19:18:24 -0700 Subject: [PATCH] simplify h1 dispatcher --- src/lib.rs | 4 - src/server/error.rs | 12 ++ src/server/h1.rs | 425 ++++++++++++++++++---------------------- src/server/h1decoder.rs | 1 + src/server/handler.rs | 21 +- src/server/http.rs | 4 +- src/server/incoming.rs | 4 +- src/server/message.rs | 21 ++ 8 files changed, 249 insertions(+), 243 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 099b0b16c..df3c3817e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,10 +81,6 @@ specialization, // for impl ErrorResponse for std::error::Error extern_prelude, ))] -#![cfg_attr( - feature = "cargo-clippy", - allow(decimal_literal_representation, suspicious_arithmetic_impl) -)] #![warn(missing_docs)] #[macro_use] diff --git a/src/server/error.rs b/src/server/error.rs index 4396e6a2a..eb3e88478 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -36,10 +36,22 @@ pub enum HttpDispatchError { #[fail(display = "The first request did not complete within the specified timeout")] SlowRequestTimeout, + /// Shutdown timeout + #[fail(display = "Connection shutdown timeout")] + ShutdownTimeout, + /// HTTP2 error #[fail(display = "HTTP2 error: {}", _0)] Http2(http2::Error), + /// Malformed request + #[fail(display = "Malformed request")] + MalformedRequest, + + /// Internal error + #[fail(display = "Internal error")] + InternalError, + /// Unknown error #[fail(display = "Unknown error")] Unknown, diff --git a/src/server/h1.rs b/src/server/h1.rs index a1a6c0af4..f3c71e3c2 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -4,6 +4,7 @@ use std::time::Instant; use bytes::BytesMut; use futures::{Async, Future, Poll}; +use tokio_current_thread::spawn; use tokio_timer::Delay; use error::{Error, PayloadError}; @@ -13,17 +14,16 @@ use payload::{Payload, PayloadStatus, PayloadWriter}; use super::error::{HttpDispatchError, ServerError}; use super::h1decoder::{DecoderError, H1Decoder, Message}; use super::h1writer::H1Writer; +use super::handler::{HttpHandler, HttpHandlerTask, HttpHandlerTaskFut}; use super::input::PayloadType; use super::settings::WorkerSettings; -use super::Writer; -use super::{HttpHandler, HttpHandlerTask, IoStream}; +use super::{IoStream, Writer}; const MAX_PIPELINED_MESSAGES: usize = 16; bitflags! { pub struct Flags: u8 { const STARTED = 0b0000_0001; - const ERROR = 0b0000_0010; const KEEPALIVE = 0b0000_0100; const SHUTDOWN = 0b0000_1000; const READ_DISCONNECTED = 0b0001_0000; @@ -32,14 +32,6 @@ bitflags! { } } -bitflags! { - struct EntryFlags: u8 { - const EOF = 0b0000_0001; - const ERROR = 0b0000_0010; - const FINISHED = 0b0000_0100; - } -} - pub(crate) struct Http1 { flags: Flags, settings: WorkerSettings, @@ -49,39 +41,40 @@ pub(crate) struct Http1 { payload: Option, buf: BytesMut, tasks: VecDeque>, - error: Option, + error: Option, ka_enabled: bool, ka_expire: Instant, ka_timer: Option, } -struct Entry { - pipe: EntryPipe, - flags: EntryFlags, -} - -enum EntryPipe { +enum Entry { Task(H::Task), Error(Box), } -impl EntryPipe { +impl Entry { + fn into_task(self) -> H::Task { + match self { + Entry::Task(task) => task, + Entry::Error(_) => panic!(), + } + } fn disconnected(&mut self) { match *self { - EntryPipe::Task(ref mut task) => task.disconnected(), - EntryPipe::Error(ref mut task) => task.disconnected(), + Entry::Task(ref mut task) => task.disconnected(), + Entry::Error(ref mut task) => task.disconnected(), } } fn poll_io(&mut self, io: &mut Writer) -> Poll { match *self { - EntryPipe::Task(ref mut task) => task.poll_io(io), - EntryPipe::Error(ref mut task) => task.poll_io(io), + Entry::Task(ref mut task) => task.poll_io(io), + Entry::Error(ref mut task) => task.poll_io(io), } } fn poll_completed(&mut self) -> Poll<(), Error> { match *self { - EntryPipe::Task(ref mut task) => task.poll_completed(), - EntryPipe::Error(ref mut task) => task.poll_completed(), + Entry::Task(ref mut task) => task.poll_completed(), + Entry::Error(ref mut task) => task.poll_completed(), } } } @@ -136,10 +129,7 @@ where #[inline] fn can_read(&self) -> bool { - if self - .flags - .intersects(Flags::ERROR | Flags::READ_DISCONNECTED) - { + if self.flags.intersects(Flags::READ_DISCONNECTED) { return false; } @@ -150,41 +140,46 @@ where } } - fn write_disconnected(&mut self) { - self.flags.insert(Flags::WRITE_DISCONNECTED); - - // notify all tasks - self.stream.disconnected(); - for task in &mut self.tasks { - task.pipe.disconnected(); - } - } - - fn read_disconnected(&mut self) { - self.flags.insert( - Flags::READ_DISCONNECTED - // on parse error, stop reading stream but tasks need to be - // completed - | Flags::ERROR, - ); - + // if checked is set to true, delay disconnect until all tasks have finished. + fn client_disconnected(&mut self, checked: bool) { + self.flags.insert(Flags::READ_DISCONNECTED); if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::Incomplete); } + + if !checked || self.tasks.is_empty() { + self.flags.insert(Flags::WRITE_DISCONNECTED); + self.stream.disconnected(); + + // notify all tasks + for mut task in self.tasks.drain(..) { + task.disconnected(); + match task.poll_completed() { + Ok(Async::NotReady) => { + // spawn not completed task, it does not require access to io + // at this point + spawn(HttpHandlerTaskFut::new(task.into_task())); + } + Ok(Async::Ready(_)) => (), + Err(err) => { + error!("Unhandled application error: {}", err); + } + } + } + } } #[inline] pub fn poll(&mut self) -> Poll<(), HttpDispatchError> { // check connection keep-alive - if !self.poll_keep_alive() { - return Ok(Async::Ready(())); - } + self.poll_keep_alive()?; // shutdown if self.flags.contains(Flags::SHUTDOWN) { - if self.flags.intersects( - Flags::ERROR | Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED, - ) { + if self + .flags + .intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) + { return Ok(Async::Ready(())); } match self.stream.poll_completed(true) { @@ -197,44 +192,46 @@ where } } - self.poll_io(); + self.poll_io()?; - loop { + if !self.flags.contains(Flags::WRITE_DISCONNECTED) { match self.poll_handler()? { - Async::Ready(true) => { - self.poll_io(); - } + Async::Ready(true) => self.poll(), Async::Ready(false) => { self.flags.insert(Flags::SHUTDOWN); - return self.poll(); + self.poll() } Async::NotReady => { // deal with keep-alive and steam eof (client-side write shutdown) if self.tasks.is_empty() { // handle stream eof - if self.flags.contains(Flags::READ_DISCONNECTED) { - self.flags.insert(Flags::SHUTDOWN); - return self.poll(); + if self.flags.intersects( + Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED, + ) { + return Ok(Async::Ready(())); } // no keep-alive - if self.flags.contains(Flags::ERROR) - || (!self.flags.contains(Flags::KEEPALIVE) - || !self.ka_enabled) - && self.flags.contains(Flags::STARTED) + if self.flags.contains(Flags::STARTED) + && (!self.ka_enabled + || !self.flags.contains(Flags::KEEPALIVE)) { self.flags.insert(Flags::SHUTDOWN); return self.poll(); } } - return Ok(Async::NotReady); + Ok(Async::NotReady) } } + } else if let Some(err) = self.error.take() { + Err(err) + } else { + Ok(Async::Ready(())) } } /// keep-alive timer. returns `true` is keep-alive, otherwise drop - fn poll_keep_alive(&mut self) -> bool { - let timer = if let Some(ref mut timer) = self.ka_timer { + fn poll_keep_alive(&mut self) -> Result<(), HttpDispatchError> { + if let Some(ref mut timer) = self.ka_timer { match timer.poll() { Ok(Async::Ready(_)) => { if timer.deadline() >= self.ka_expire { @@ -242,43 +239,39 @@ where if self.tasks.is_empty() { // if we get timer during shutdown, just drop connection if self.flags.contains(Flags::SHUTDOWN) { - return false; + return Err(HttpDispatchError::ShutdownTimeout); } else { trace!("Keep-alive timeout, close connection"); self.flags.insert(Flags::SHUTDOWN); - None + // TODO: start shutdown timer + return Ok(()); } - } else { - self.settings.keep_alive_timer() + } else if let Some(deadline) = self.settings.keep_alive_expire() + { + timer.reset(deadline) } } else { - Some(Delay::new(self.ka_expire)) + timer.reset(self.ka_expire) } } - Ok(Async::NotReady) => None, + Ok(Async::NotReady) => (), Err(e) => { error!("Timer error {:?}", e); - return false; + return Err(HttpDispatchError::Unknown); } } - } else { - None - }; - - if let Some(mut timer) = timer { - let _ = timer.poll(); - self.ka_timer = Some(timer); } - true + + Ok(()) } #[inline] /// read data from stream - pub fn poll_io(&mut self) { + pub fn poll_io(&mut self) -> Result<(), HttpDispatchError> { if !self.flags.contains(Flags::POLLED) { - self.parse(); + self.parse()?; self.flags.insert(Flags::POLLED); - return; + return Ok(()); } // read io from socket @@ -286,136 +279,118 @@ where match self.stream.get_mut().read_available(&mut self.buf) { Ok(Async::Ready((read_some, disconnected))) => { if read_some { - self.parse(); + self.parse()?; } if disconnected { - self.read_disconnected(); - // delay disconnect until all tasks have finished. - if self.tasks.is_empty() { - self.write_disconnected(); - } + self.client_disconnected(true); } } Ok(Async::NotReady) => (), - Err(_) => { - self.read_disconnected(); - self.write_disconnected(); + Err(err) => { + self.client_disconnected(false); + return Err(err.into()); } } } + Ok(()) } pub fn poll_handler(&mut self) -> Poll { let retry = self.can_read(); - // check in-flight messages - let mut io = false; - let mut idx = 0; - while idx < self.tasks.len() { - // only one task can do io operation in http/1 - if !io - && !self.tasks[idx].flags.contains(EntryFlags::EOF) - && !self.flags.contains(Flags::WRITE_DISCONNECTED) - { - // io is corrupted, send buffer - if self.tasks[idx].flags.contains(EntryFlags::ERROR) { - if let Ok(Async::NotReady) = self.stream.poll_completed(true) { - return Ok(Async::NotReady); - } - self.flags.insert(Flags::ERROR); - return Err(self - .error - .take() - .map(|e| e.into()) - .unwrap_or(HttpDispatchError::Unknown)); - } - - match self.tasks[idx].pipe.poll_io(&mut self.stream) { - Ok(Async::Ready(ready)) => { - // override keep-alive state - if self.stream.keepalive() { - self.flags.insert(Flags::KEEPALIVE); - } else { - self.flags.remove(Flags::KEEPALIVE); - } - // prepare stream for next response - self.stream.reset(); - - if ready { - self.tasks[idx] - .flags - .insert(EntryFlags::EOF | EntryFlags::FINISHED); - } else { - self.tasks[idx].flags.insert(EntryFlags::EOF); - } - } - // no more IO for this iteration - Ok(Async::NotReady) => { - // check if we need timer - if self.ka_timer.is_some() && self.stream.upgrade() { - self.ka_timer.take(); - } - - // check if previously read backpressure was enabled - if self.can_read() && !retry { - return Ok(Async::Ready(true)); - } - io = true; - } - Err(err) => { - error!("Unhandled error1: {}", err); - // it is not possible to recover from error - // during pipe handling, so just drop connection - self.read_disconnected(); - self.write_disconnected(); - self.tasks[idx].flags.insert(EntryFlags::ERROR); - self.error = Some(err); - continue; - } - } - } else if !self.tasks[idx].flags.contains(EntryFlags::FINISHED) { - match self.tasks[idx].pipe.poll_completed() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => { - self.tasks[idx].flags.insert(EntryFlags::FINISHED) - } - Err(err) => { - error!("Unhandled error: {}", err); - self.read_disconnected(); - self.write_disconnected(); - self.tasks[idx].flags.insert(EntryFlags::ERROR); - self.error = Some(err); - continue; - } - } - } - idx += 1; - } - - // cleanup finished tasks + // process first pipelined response, only one task can do io operation in http/1 while !self.tasks.is_empty() { - if self.tasks[0] - .flags - .contains(EntryFlags::EOF | EntryFlags::FINISHED) - { - self.tasks.pop_front(); - } else { - break; + match self.tasks[0].poll_io(&mut self.stream) { + Ok(Async::Ready(ready)) => { + // override keep-alive state + if self.stream.keepalive() { + self.flags.insert(Flags::KEEPALIVE); + } else { + self.flags.remove(Flags::KEEPALIVE); + } + // prepare stream for next response + self.stream.reset(); + + let task = self.tasks.pop_front().unwrap(); + if !ready { + // task is done with io operations but still needs to do more work + spawn(HttpHandlerTaskFut::new(task.into_task())); + } + } + Ok(Async::NotReady) => { + // check if we need timer + if self.ka_timer.is_some() && self.stream.upgrade() { + self.ka_timer.take(); + } + + // if read-backpressure is enabled and we consumed some data. + // we may read more data + if !retry && self.can_read() { + return Ok(Async::Ready(true)); + } + break; + } + Err(err) => { + error!("Unhandled error1: {}", err); + // it is not possible to recover from error + // during pipe handling, so just drop connection + self.client_disconnected(false); + return Err(err.into()); + } } } - // check stream state + // check in-flight messages. all tasks must be alive, + // they need to produce response. if app returned error + // and we can not continue processing incoming requests. + let mut idx = 1; + while idx < self.tasks.len() { + let stop = match self.tasks[idx].poll_completed() { + Ok(Async::NotReady) => false, + Ok(Async::Ready(_)) => true, + Err(err) => { + self.error = Some(err.into()); + true + } + }; + if stop { + // error in task handling or task is completed, + // so no response for this task which means we can not read more requests + // because pipeline sequence is broken. + // but we can safely complete existing tasks + self.flags.insert(Flags::READ_DISCONNECTED); + + for mut task in self.tasks.drain(idx..) { + task.disconnected(); + match task.poll_completed() { + Ok(Async::NotReady) => { + // spawn not completed task, it does not require access to io + // at this point + spawn(HttpHandlerTaskFut::new(task.into_task())); + } + Ok(Async::Ready(_)) => (), + Err(err) => { + error!("Unhandled application error: {}", err); + } + } + } + break; + } else { + idx += 1; + } + } + + // flush stream if self.flags.contains(Flags::STARTED) { match self.stream.poll_completed(false) { Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); - self.read_disconnected(); - self.write_disconnected(); + self.client_disconnected(false); return Err(err.into()); } Ok(Async::Ready(_)) => { - // non consumed payload in that case close connection + // if payload is not consumed we can not use connection if self.payload.is_some() && self.tasks.is_empty() { return Ok(Async::Ready(false)); } @@ -427,13 +402,11 @@ where } fn push_response_entry(&mut self, status: StatusCode) { - self.tasks.push_back(Entry { - pipe: EntryPipe::Error(ServerError::err(Version::HTTP_11, status)), - flags: EntryFlags::empty(), - }); + self.tasks + .push_back(Entry::Error(ServerError::err(Version::HTTP_11, status))); } - pub fn parse(&mut self) { + pub fn parse(&mut self) -> Result<(), HttpDispatchError> { let mut updated = false; 'outer: loop { @@ -457,9 +430,9 @@ where // search handler for request match self.settings.handler().handle(msg) { - Ok(mut pipe) => { + Ok(mut task) => { if self.tasks.is_empty() { - match pipe.poll_io(&mut self.stream) { + match task.poll_io(&mut self.stream) { Ok(Async::Ready(ready)) => { // override keep-alive state if self.stream.keepalive() { @@ -471,42 +444,28 @@ where self.stream.reset(); if !ready { - let item = Entry { - pipe: EntryPipe::Task(pipe), - flags: EntryFlags::EOF, - }; - self.tasks.push_back(item); + // task is done with io operations + // but still needs to do more work + spawn(HttpHandlerTaskFut::new(task)); } continue 'outer; } Ok(Async::NotReady) => (), Err(err) => { error!("Unhandled error: {}", err); - self.flags.insert(Flags::ERROR); - return; + self.client_disconnected(false); + return Err(err.into()); } } } - self.tasks.push_back(Entry { - pipe: EntryPipe::Task(pipe), - flags: EntryFlags::empty(), - }); + self.tasks.push_back(Entry::Task(task)); continue 'outer; } Err(_) => { // handler is not found - self.tasks.push_back(Entry { - pipe: EntryPipe::Error(ServerError::err( - Version::HTTP_11, - StatusCode::NOT_FOUND, - )), - flags: EntryFlags::empty(), - }); + self.push_response_entry(StatusCode::NOT_FOUND); } } - - // handler is not found - self.push_response_entry(StatusCode::NOT_FOUND); } Ok(Some(Message::Chunk(chunk))) => { updated = true; @@ -514,8 +473,9 @@ where payload.feed_data(chunk); } else { error!("Internal server error: unexpected payload chunk"); - self.flags.insert(Flags::ERROR); + self.flags.insert(Flags::READ_DISCONNECTED); self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); + self.error = Some(HttpDispatchError::InternalError); break; } } @@ -525,23 +485,19 @@ where payload.feed_eof(); } else { error!("Internal server error: unexpected eof"); - self.flags.insert(Flags::ERROR); + self.flags.insert(Flags::READ_DISCONNECTED); self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); + self.error = Some(HttpDispatchError::InternalError); break; } } Ok(None) => { if self.flags.contains(Flags::READ_DISCONNECTED) { - self.read_disconnected(); - if self.tasks.is_empty() { - self.write_disconnected(); - } + self.client_disconnected(true); } break; } Err(e) => { - updated = false; - self.flags.insert(Flags::ERROR); if let Some(mut payload) = self.payload.take() { let e = match e { DecoderError::Io(e) => PayloadError::Io(e), @@ -550,8 +506,10 @@ where payload.set_error(e); } - //Malformed requests should be responded with 400 + // Malformed requests should be responded with 400 self.push_response_entry(StatusCode::BAD_REQUEST); + self.flags.insert(Flags::READ_DISCONNECTED); + self.error = Some(HttpDispatchError::MalformedRequest); break; } } @@ -562,6 +520,7 @@ where self.ka_expire = expire; } } + Ok(()) } } @@ -708,15 +667,15 @@ mod tests { #[test] fn test_req_parse_err() { let mut sys = System::new("test"); - sys.block_on(future::lazy(|| { + let _ = sys.block_on(future::lazy(|| { let buf = Buffer::new("GET /test HTTP/1\r\n\r\n"); let readbuf = BytesMut::new(); let settings = wrk_settings(); let mut h1 = Http1::new(settings.clone(), buf, None, readbuf, false, None); - h1.poll_io(); - h1.poll_io(); - assert!(h1.flags.contains(Flags::ERROR)); + assert!(h1.poll_io().is_ok()); + assert!(h1.poll_io().is_ok()); + assert!(h1.flags.contains(Flags::READ_DISCONNECTED)); assert_eq!(h1.tasks.len(), 1); future::ok::<_, ()>(()) })); diff --git a/src/server/h1decoder.rs b/src/server/h1decoder.rs index 084ae8b2f..a7531bbbd 100644 --- a/src/server/h1decoder.rs +++ b/src/server/h1decoder.rs @@ -18,6 +18,7 @@ pub(crate) struct H1Decoder { decoder: Option, } +#[derive(Debug)] pub(crate) enum Message { Message { msg: Request, payload: bool }, Chunk(Bytes), diff --git a/src/server/handler.rs b/src/server/handler.rs index 0700e1961..33e50ac34 100644 --- a/src/server/handler.rs +++ b/src/server/handler.rs @@ -1,4 +1,4 @@ -use futures::{Async, Poll}; +use futures::{Async, Future, Poll}; use super::message::Request; use super::Writer; @@ -42,6 +42,25 @@ impl HttpHandlerTask for Box { } } +pub(super) struct HttpHandlerTaskFut { + task: T, +} + +impl HttpHandlerTaskFut { + pub(crate) fn new(task: T) -> Self { + Self { task } + } +} + +impl Future for HttpHandlerTaskFut { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + self.task.poll_completed().map_err(|_| ()) + } +} + /// Conversion helper trait pub trait IntoHttpHandler { /// The associated type which is result of conversion. diff --git a/src/server/http.rs b/src/server/http.rs index 311c53cb2..511b1832e 100644 --- a/src/server/http.rs +++ b/src/server/http.rs @@ -485,7 +485,7 @@ impl H + Send + Clone> HttpServer { socket.lst, host, socket.addr, - self.keep_alive.clone(), + self.keep_alive, self.client_timeout, ); } @@ -531,7 +531,7 @@ impl H + Send + Clone> HttpServer { socket.lst, host, socket.addr, - self.keep_alive.clone(), + self.keep_alive, self.client_timeout, ); } diff --git a/src/server/incoming.rs b/src/server/incoming.rs index c77280084..a56ccb80f 100644 --- a/src/server/incoming.rs +++ b/src/server/incoming.rs @@ -41,9 +41,7 @@ where // start server HttpIncoming::create(move |ctx| { - ctx.add_message_stream( - stream.map_err(|_| ()).map(move |t| WrapperStream::new(t)), - ); + ctx.add_message_stream(stream.map_err(|_| ()).map(WrapperStream::new)); HttpIncoming { settings } }); } diff --git a/src/server/message.rs b/src/server/message.rs index 43f7e1425..9c4bc1ec4 100644 --- a/src/server/message.rs +++ b/src/server/message.rs @@ -1,5 +1,6 @@ use std::cell::{Cell, Ref, RefCell, RefMut}; use std::collections::VecDeque; +use std::fmt; use std::net::SocketAddr; use std::rc::Rc; @@ -220,6 +221,26 @@ impl Request { } } +impl fmt::Debug for Request { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nRequest {:?} {}:{}", + self.version(), + self.method(), + self.path() + )?; + if let Some(q) = self.uri().query().as_ref() { + writeln!(f, " query: ?{:?}", q)?; + } + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + pub(crate) struct RequestPool( RefCell>>, RefCell,