From 13193a0721110f0ad0dd046af76480322ae46430 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 7 Oct 2018 09:48:53 -0700 Subject: [PATCH] refactor http/1 dispatcher --- src/error.rs | 15 ++---- src/h1/decoder.rs | 16 +++---- src/h1/dispatcher.rs | 111 ++++++++++++++++++------------------------- src/h1/service.rs | 14 +++--- src/payload.rs | 7 --- tests/test_server.rs | 11 +++-- 6 files changed, 68 insertions(+), 106 deletions(-) diff --git a/src/error.rs b/src/error.rs index 1e60c3486..465b8ae0a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -341,15 +341,6 @@ pub enum PayloadError { /// A payload length is unknown. #[fail(display = "A payload length is unknown.")] UnknownLength, - /// Io error - #[fail(display = "{}", _0)] - Io(#[cause] IoError), -} - -impl From for PayloadError { - fn from(err: IoError) -> PayloadError { - PayloadError::Io(err) - } } /// `PayloadError` returns two possible results: @@ -374,7 +365,7 @@ impl ResponseError for cookie::ParseError { #[derive(Debug)] /// A set of errors that can occur during dispatching http requests -pub enum DispatchError { +pub enum DispatchError { /// Service error // #[fail(display = "Application specific error: {}", _0)] Service(E), @@ -413,13 +404,13 @@ pub enum DispatchError { Unknown, } -impl From for DispatchError { +impl From for DispatchError { fn from(err: ParseError) -> Self { DispatchError::Parse(err) } } -impl From for DispatchError { +impl From for DispatchError { fn from(err: io::Error) -> Self { DispatchError::Io(err) } diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index fb29f033c..d0c3fa048 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -219,9 +219,7 @@ impl PayloadDecoder { } pub fn eof() -> PayloadDecoder { - PayloadDecoder { - kind: Kind::Eof(false), - } + PayloadDecoder { kind: Kind::Eof } } } @@ -246,7 +244,7 @@ enum Kind { /// > the final encoding, the message body length cannot be determined /// > reliably; the server MUST respond with the 400 (Bad Request) /// > status code and then close the connection. - Eof(bool), + Eof, } #[derive(Debug, PartialEq, Clone)] @@ -309,13 +307,11 @@ impl Decoder for PayloadDecoder { } } } - Kind::Eof(ref mut is_eof) => { - if *is_eof { - Ok(Some(PayloadItem::Eof)) - } else if !src.is_empty() { - Ok(Some(PayloadItem::Chunk(src.take().freeze()))) - } else { + Kind::Eof => { + if src.is_empty() { Ok(None) + } else { + Ok(Some(PayloadItem::Chunk(src.take().freeze()))) } } } diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 3bf17a8b3..92bec3544 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -1,5 +1,5 @@ use std::collections::VecDeque; -use std::fmt::{Debug, Display}; +use std::fmt::Debug; use std::time::Instant; use actix_net::codec::Framed; @@ -27,18 +27,17 @@ bitflags! { const STARTED = 0b0000_0001; const KEEPALIVE_ENABLED = 0b0000_0010; const KEEPALIVE = 0b0000_0100; - const SHUTDOWN = 0b0000_1000; - const READ_DISCONNECTED = 0b0001_0000; - const WRITE_DISCONNECTED = 0b0010_0000; - const POLLED = 0b0100_0000; - const FLUSHED = 0b1000_0000; + const POLLED = 0b0000_1000; + const FLUSHED = 0b0001_0000; + const SHUTDOWN = 0b0010_0000; + const DISCONNECTED = 0b0100_0000; } } /// Dispatcher for HTTP/1.1 protocol pub struct Dispatcher where - S::Error: Debug + Display, + S::Error: Debug, { service: S, flags: Flags, @@ -81,7 +80,7 @@ impl Dispatcher where T: AsyncRead + AsyncWrite, S: Service, - S::Error: Debug + Display, + S::Error: Debug, { /// Create http/1 dispatcher. pub fn new(stream: T, config: ServiceConfig, service: S) -> Self { @@ -122,9 +121,8 @@ where } } - #[inline] fn can_read(&self) -> bool { - if self.flags.contains(Flags::READ_DISCONNECTED) { + if self.flags.contains(Flags::DISCONNECTED) { return false; } @@ -137,7 +135,7 @@ where // if checked is set to true, delay disconnect until all tasks have finished. fn client_disconnected(&mut self) { - self.flags.insert(Flags::READ_DISCONNECTED); + self.flags.insert(Flags::DISCONNECTED); if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::Incomplete); } @@ -145,12 +143,11 @@ where /// Flush stream fn poll_flush(&mut self) -> Poll<(), DispatchError> { - if self.flags.contains(Flags::STARTED) && !self.flags.contains(Flags::FLUSHED) { + if !self.flags.contains(Flags::FLUSHED) { match self.framed.poll_complete() { Ok(Async::NotReady) => Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); - self.client_disconnected(); Err(err.into()) } Ok(Async::Ready(_)) => { @@ -167,8 +164,7 @@ where } } - pub(self) fn poll_handler(&mut self) -> Result<(), DispatchError> { - self.poll_io()?; + fn poll_response(&mut self) -> Result<(), DispatchError> { let mut retry = self.can_read(); // process @@ -221,7 +217,6 @@ where return Ok(()); } Err(err) => { - self.flags.insert(Flags::READ_DISCONNECTED); if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::Incomplete); } @@ -246,7 +241,6 @@ where return Ok(()); } Err(err) => { - self.flags.insert(Flags::READ_DISCONNECTED); if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::Incomplete); } @@ -261,7 +255,7 @@ where None => { // if read-backpressure is enabled and we consumed some data. // we may read more dataand retry - if !retry && self.can_read() && self.poll_io()? { + if !retry && self.can_read() && self.poll_request()? { retry = self.can_read(); continue; } @@ -319,7 +313,7 @@ where payload.feed_data(chunk); } else { error!("Internal server error: unexpected payload chunk"); - self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + self.flags.insert(Flags::DISCONNECTED); self.messages.push_back(Message::Error( Response::InternalServerError().finish(), )); @@ -331,7 +325,7 @@ where payload.feed_eof(); } else { error!("Internal server error: unexpected eof"); - self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + self.flags.insert(Flags::DISCONNECTED); self.messages.push_back(Message::Error( Response::InternalServerError().finish(), )); @@ -343,7 +337,7 @@ where Ok(()) } - pub(self) fn poll_io(&mut self) -> Result> { + pub(self) fn poll_request(&mut self) -> Result> { let mut updated = false; if self.messages.len() < MAX_PIPELINED_MESSAGES { @@ -354,26 +348,25 @@ where self.one_message(msg)?; } Ok(Async::Ready(None)) => { - if self.flags.contains(Flags::READ_DISCONNECTED) { - self.client_disconnected(); - } + self.client_disconnected(); break; } Ok(Async::NotReady) => break, + Err(ParseError::Io(e)) => { + self.client_disconnected(); + self.error = Some(DispatchError::Io(e)); + break; + } Err(e) => { if let Some(mut payload) = self.payload.take() { - let e = match e { - ParseError::Io(e) => PayloadError::Io(e), - _ => PayloadError::EncodingCorrupted, - }; - payload.set_error(e); + payload.set_error(PayloadError::EncodingCorrupted); } // Malformed requests should be responded with 400 self.messages .push_back(Message::Error(Response::BadRequest().finish())); - self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); - self.error = Some(DispatchError::MalformedRequest); + self.flags.insert(Flags::DISCONNECTED); + self.error = Some(e.into()); break; } } @@ -402,8 +395,7 @@ where } else if !self.flags.contains(Flags::STARTED) { // timeout on first request (slow request) return 408 trace!("Slow request timeout"); - self.flags - .insert(Flags::STARTED | Flags::READ_DISCONNECTED); + self.flags.insert(Flags::STARTED | Flags::DISCONNECTED); self.state = State::SendResponse(Some(OutMessage::Response( Response::RequestTimeout().finish(), @@ -444,54 +436,43 @@ impl Future for Dispatcher where T: AsyncRead + AsyncWrite, S: Service, - S::Error: Debug + Display, + S::Error: Debug, { type Item = (); type Error = DispatchError; #[inline] fn poll(&mut self) -> Poll<(), Self::Error> { - self.poll_keepalive()?; - - // shutdown if self.flags.contains(Flags::SHUTDOWN) { - if self.flags.contains(Flags::WRITE_DISCONNECTED) { - return Ok(Async::Ready(())); - } + self.poll_keepalive()?; try_ready!(self.poll_flush()); - return Ok(AsyncWrite::shutdown(self.framed.get_mut())?); - } - - // process incoming requests - if !self.flags.contains(Flags::WRITE_DISCONNECTED) { - self.poll_handler()?; - - // flush stream + Ok(AsyncWrite::shutdown(self.framed.get_mut())?) + } else { + self.poll_keepalive()?; + self.poll_request()?; + self.poll_response()?; self.poll_flush()?; - // deal with keep-alive and stream eof (client-side write shutdown) + // keep-alive and stream errors if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { - // handle stream eof - if self - .flags - .intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) - { - return Ok(Async::Ready(())); + if let Some(err) = self.error.take() { + Err(err) + } else if self.flags.contains(Flags::DISCONNECTED) { + Ok(Async::Ready(())) } - // no keep-alive - if self.flags.contains(Flags::STARTED) - && (!self.flags.contains(Flags::KEEPALIVE_ENABLED) - || !self.flags.contains(Flags::KEEPALIVE)) + // disconnect if keep-alive is not enabled + else if self.flags.contains(Flags::STARTED) && !self + .flags + .intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED) { self.flags.insert(Flags::SHUTDOWN); - return self.poll(); + self.poll() + } else { + Ok(Async::NotReady) } + } else { + Ok(Async::NotReady) } - Ok(Async::NotReady) - } else if let Some(err) = self.error.take() { - Err(err) - } else { - Ok(Async::Ready(())) } } } diff --git a/src/h1/service.rs b/src/h1/service.rs index eea0e6d9f..de24f52c1 100644 --- a/src/h1/service.rs +++ b/src/h1/service.rs @@ -1,4 +1,4 @@ -use std::fmt::{Debug, Display}; +use std::fmt::Debug; use std::marker::PhantomData; use std::net; @@ -26,7 +26,7 @@ impl H1Service where S: NewService, S::Service: Clone, - S::Error: Debug + Display, + S::Error: Debug, { /// Create new `HttpService` instance. pub fn new>(service: F) -> Self { @@ -50,7 +50,7 @@ where T: AsyncRead + AsyncWrite, S: NewService + Clone, S::Service: Clone, - S::Error: Debug + Display, + S::Error: Debug, { type Request = T; type Response = (); @@ -86,7 +86,7 @@ impl H1ServiceBuilder where S: NewService, S::Service: Clone, - S::Error: Debug + Display, + S::Error: Debug, { /// Create instance of `ServiceConfigBuilder` pub fn new() -> H1ServiceBuilder { @@ -203,7 +203,7 @@ where T: AsyncRead + AsyncWrite, S: NewService, S::Service: Clone, - S::Error: Debug + Display, + S::Error: Debug, { type Item = H1ServiceHandler; type Error = S::InitError; @@ -227,7 +227,7 @@ pub struct H1ServiceHandler { impl H1ServiceHandler where S: Service + Clone, - S::Error: Debug + Display, + S::Error: Debug, { fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler { H1ServiceHandler { @@ -242,7 +242,7 @@ impl Service for H1ServiceHandler where T: AsyncRead + AsyncWrite, S: Service + Clone, - S::Error: Debug + Display, + S::Error: Debug, { type Request = T; type Response = (); diff --git a/src/payload.rs b/src/payload.rs index 3f51f6ec0..54539c408 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -522,18 +522,11 @@ where #[cfg(test)] mod tests { use super::*; - use failure::Fail; use futures::future::{lazy, result}; - use std::io; use tokio::runtime::current_thread::Runtime; #[test] fn test_error() { - let err: PayloadError = - io::Error::new(io::ErrorKind::Other, "ParseError").into(); - assert_eq!(format!("{}", err), "ParseError"); - assert_eq!(format!("{}", err.cause().unwrap()), "ParseError"); - let err = PayloadError::Incomplete; assert_eq!( format!("{}", err), diff --git a/tests/test_server.rs b/tests/test_server.rs index 8d682e120..43e3966df 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -11,7 +11,7 @@ use actix_net::server::Server; use actix_web::{client, test, HttpMessage}; use futures::future; -use actix_http::{h1, Error, KeepAlive, Request, Response}; +use actix_http::{h1, KeepAlive, Request, Response}; #[test] fn test_h1_v2() { @@ -25,10 +25,11 @@ fn test_h1_v2() { .client_disconnect(1000) .server_hostname("localhost") .server_address(addr) - .finish(|_| future::ok::<_, Error>(Response::Ok().finish())) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) }).unwrap() .run(); }); + thread::sleep(time::Duration::from_millis(100)); let mut sys = System::new("test"); { @@ -48,7 +49,7 @@ fn test_slow_request() { .bind("test", addr, move || { h1::H1Service::build() .client_timeout(100) - .finish(|_| future::ok::<_, Error>(Response::Ok().finish())) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) }).unwrap() .run(); }); @@ -69,7 +70,7 @@ fn test_malformed_request() { .bind("test", addr, move || { h1::H1Service::build() .client_timeout(100) - .finish(|_| future::ok::<_, Error>(Response::Ok().finish())) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) }).unwrap() .run(); }); @@ -103,7 +104,7 @@ fn test_content_length() { StatusCode::OK, StatusCode::NOT_FOUND, ]; - future::ok::<_, Error>(Response::new(statuses[indx])) + future::ok::<_, ()>(Response::new(statuses[indx])) }) }).unwrap() .run();