1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-30 18:34:36 +01:00

refactor http/1 dispatcher

This commit is contained in:
Nikolay Kim 2018-10-07 09:48:53 -07:00
parent 9c4a55c95c
commit 13193a0721
6 changed files with 68 additions and 106 deletions

View File

@ -341,15 +341,6 @@ pub enum PayloadError {
/// A payload length is unknown. /// A payload length is unknown.
#[fail(display = "A payload length is unknown.")] #[fail(display = "A payload length is unknown.")]
UnknownLength, UnknownLength,
/// Io error
#[fail(display = "{}", _0)]
Io(#[cause] IoError),
}
impl From<IoError> for PayloadError {
fn from(err: IoError) -> PayloadError {
PayloadError::Io(err)
}
} }
/// `PayloadError` returns two possible results: /// `PayloadError` returns two possible results:
@ -374,7 +365,7 @@ impl ResponseError for cookie::ParseError {
#[derive(Debug)] #[derive(Debug)]
/// A set of errors that can occur during dispatching http requests /// A set of errors that can occur during dispatching http requests
pub enum DispatchError<E: fmt::Display + fmt::Debug> { pub enum DispatchError<E: fmt::Debug> {
/// Service error /// Service error
// #[fail(display = "Application specific error: {}", _0)] // #[fail(display = "Application specific error: {}", _0)]
Service(E), Service(E),
@ -413,13 +404,13 @@ pub enum DispatchError<E: fmt::Display + fmt::Debug> {
Unknown, Unknown,
} }
impl<E: fmt::Display + fmt::Debug> From<ParseError> for DispatchError<E> { impl<E: fmt::Debug> From<ParseError> for DispatchError<E> {
fn from(err: ParseError) -> Self { fn from(err: ParseError) -> Self {
DispatchError::Parse(err) DispatchError::Parse(err)
} }
} }
impl<E: fmt::Display + fmt::Debug> From<io::Error> for DispatchError<E> { impl<E: fmt::Debug> From<io::Error> for DispatchError<E> {
fn from(err: io::Error) -> Self { fn from(err: io::Error) -> Self {
DispatchError::Io(err) DispatchError::Io(err)
} }

View File

@ -219,9 +219,7 @@ impl PayloadDecoder {
} }
pub fn eof() -> PayloadDecoder { pub fn eof() -> PayloadDecoder {
PayloadDecoder { PayloadDecoder { kind: Kind::Eof }
kind: Kind::Eof(false),
}
} }
} }
@ -246,7 +244,7 @@ enum Kind {
/// > the final encoding, the message body length cannot be determined /// > the final encoding, the message body length cannot be determined
/// > reliably; the server MUST respond with the 400 (Bad Request) /// > reliably; the server MUST respond with the 400 (Bad Request)
/// > status code and then close the connection. /// > status code and then close the connection.
Eof(bool), Eof,
} }
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
@ -309,13 +307,11 @@ impl Decoder for PayloadDecoder {
} }
} }
} }
Kind::Eof(ref mut is_eof) => { Kind::Eof => {
if *is_eof { if src.is_empty() {
Ok(Some(PayloadItem::Eof))
} else if !src.is_empty() {
Ok(Some(PayloadItem::Chunk(src.take().freeze())))
} else {
Ok(None) Ok(None)
} else {
Ok(Some(PayloadItem::Chunk(src.take().freeze())))
} }
} }
} }

View File

@ -1,5 +1,5 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::fmt::{Debug, Display}; use std::fmt::Debug;
use std::time::Instant; use std::time::Instant;
use actix_net::codec::Framed; use actix_net::codec::Framed;
@ -27,18 +27,17 @@ bitflags! {
const STARTED = 0b0000_0001; const STARTED = 0b0000_0001;
const KEEPALIVE_ENABLED = 0b0000_0010; const KEEPALIVE_ENABLED = 0b0000_0010;
const KEEPALIVE = 0b0000_0100; const KEEPALIVE = 0b0000_0100;
const SHUTDOWN = 0b0000_1000; const POLLED = 0b0000_1000;
const READ_DISCONNECTED = 0b0001_0000; const FLUSHED = 0b0001_0000;
const WRITE_DISCONNECTED = 0b0010_0000; const SHUTDOWN = 0b0010_0000;
const POLLED = 0b0100_0000; const DISCONNECTED = 0b0100_0000;
const FLUSHED = 0b1000_0000;
} }
} }
/// Dispatcher for HTTP/1.1 protocol /// Dispatcher for HTTP/1.1 protocol
pub struct Dispatcher<T, S: Service> pub struct Dispatcher<T, S: Service>
where where
S::Error: Debug + Display, S::Error: Debug,
{ {
service: S, service: S,
flags: Flags, flags: Flags,
@ -81,7 +80,7 @@ impl<T, S> Dispatcher<T, S>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: Service<Request = Request, Response = Response>, S: Service<Request = Request, Response = Response>,
S::Error: Debug + Display, S::Error: Debug,
{ {
/// Create http/1 dispatcher. /// Create http/1 dispatcher.
pub fn new(stream: T, config: ServiceConfig, service: S) -> Self { pub fn new(stream: T, config: ServiceConfig, service: S) -> Self {
@ -122,9 +121,8 @@ where
} }
} }
#[inline]
fn can_read(&self) -> bool { fn can_read(&self) -> bool {
if self.flags.contains(Flags::READ_DISCONNECTED) { if self.flags.contains(Flags::DISCONNECTED) {
return false; return false;
} }
@ -137,7 +135,7 @@ where
// if checked is set to true, delay disconnect until all tasks have finished. // if checked is set to true, delay disconnect until all tasks have finished.
fn client_disconnected(&mut self) { fn client_disconnected(&mut self) {
self.flags.insert(Flags::READ_DISCONNECTED); self.flags.insert(Flags::DISCONNECTED);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete); payload.set_error(PayloadError::Incomplete);
} }
@ -145,12 +143,11 @@ where
/// Flush stream /// Flush stream
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> { fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
if self.flags.contains(Flags::STARTED) && !self.flags.contains(Flags::FLUSHED) { if !self.flags.contains(Flags::FLUSHED) {
match self.framed.poll_complete() { match self.framed.poll_complete() {
Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => { Err(err) => {
debug!("Error sending data: {}", err); debug!("Error sending data: {}", err);
self.client_disconnected();
Err(err.into()) Err(err.into())
} }
Ok(Async::Ready(_)) => { Ok(Async::Ready(_)) => {
@ -167,8 +164,7 @@ where
} }
} }
pub(self) fn poll_handler(&mut self) -> Result<(), DispatchError<S::Error>> { fn poll_response(&mut self) -> Result<(), DispatchError<S::Error>> {
self.poll_io()?;
let mut retry = self.can_read(); let mut retry = self.can_read();
// process // process
@ -221,7 +217,6 @@ where
return Ok(()); return Ok(());
} }
Err(err) => { Err(err) => {
self.flags.insert(Flags::READ_DISCONNECTED);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete); payload.set_error(PayloadError::Incomplete);
} }
@ -246,7 +241,6 @@ where
return Ok(()); return Ok(());
} }
Err(err) => { Err(err) => {
self.flags.insert(Flags::READ_DISCONNECTED);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete); payload.set_error(PayloadError::Incomplete);
} }
@ -261,7 +255,7 @@ where
None => { None => {
// if read-backpressure is enabled and we consumed some data. // if read-backpressure is enabled and we consumed some data.
// we may read more dataand retry // we may read more dataand retry
if !retry && self.can_read() && self.poll_io()? { if !retry && self.can_read() && self.poll_request()? {
retry = self.can_read(); retry = self.can_read();
continue; continue;
} }
@ -319,7 +313,7 @@ where
payload.feed_data(chunk); payload.feed_data(chunk);
} else { } else {
error!("Internal server error: unexpected payload chunk"); error!("Internal server error: unexpected payload chunk");
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); self.flags.insert(Flags::DISCONNECTED);
self.messages.push_back(Message::Error( self.messages.push_back(Message::Error(
Response::InternalServerError().finish(), Response::InternalServerError().finish(),
)); ));
@ -331,7 +325,7 @@ where
payload.feed_eof(); payload.feed_eof();
} else { } else {
error!("Internal server error: unexpected eof"); error!("Internal server error: unexpected eof");
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); self.flags.insert(Flags::DISCONNECTED);
self.messages.push_back(Message::Error( self.messages.push_back(Message::Error(
Response::InternalServerError().finish(), Response::InternalServerError().finish(),
)); ));
@ -343,7 +337,7 @@ where
Ok(()) Ok(())
} }
pub(self) fn poll_io(&mut self) -> Result<bool, DispatchError<S::Error>> { pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError<S::Error>> {
let mut updated = false; let mut updated = false;
if self.messages.len() < MAX_PIPELINED_MESSAGES { if self.messages.len() < MAX_PIPELINED_MESSAGES {
@ -354,26 +348,25 @@ where
self.one_message(msg)?; self.one_message(msg)?;
} }
Ok(Async::Ready(None)) => { Ok(Async::Ready(None)) => {
if self.flags.contains(Flags::READ_DISCONNECTED) { self.client_disconnected();
self.client_disconnected();
}
break; break;
} }
Ok(Async::NotReady) => break, Ok(Async::NotReady) => break,
Err(ParseError::Io(e)) => {
self.client_disconnected();
self.error = Some(DispatchError::Io(e));
break;
}
Err(e) => { Err(e) => {
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
let e = match e { payload.set_error(PayloadError::EncodingCorrupted);
ParseError::Io(e) => PayloadError::Io(e),
_ => PayloadError::EncodingCorrupted,
};
payload.set_error(e);
} }
// Malformed requests should be responded with 400 // Malformed requests should be responded with 400
self.messages self.messages
.push_back(Message::Error(Response::BadRequest().finish())); .push_back(Message::Error(Response::BadRequest().finish()));
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); self.flags.insert(Flags::DISCONNECTED);
self.error = Some(DispatchError::MalformedRequest); self.error = Some(e.into());
break; break;
} }
} }
@ -402,8 +395,7 @@ where
} else if !self.flags.contains(Flags::STARTED) { } else if !self.flags.contains(Flags::STARTED) {
// timeout on first request (slow request) return 408 // timeout on first request (slow request) return 408
trace!("Slow request timeout"); trace!("Slow request timeout");
self.flags self.flags.insert(Flags::STARTED | Flags::DISCONNECTED);
.insert(Flags::STARTED | Flags::READ_DISCONNECTED);
self.state = self.state =
State::SendResponse(Some(OutMessage::Response( State::SendResponse(Some(OutMessage::Response(
Response::RequestTimeout().finish(), Response::RequestTimeout().finish(),
@ -444,54 +436,43 @@ impl<T, S> Future for Dispatcher<T, S>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: Service<Request = Request, Response = Response>, S: Service<Request = Request, Response = Response>,
S::Error: Debug + Display, S::Error: Debug,
{ {
type Item = (); type Item = ();
type Error = DispatchError<S::Error>; type Error = DispatchError<S::Error>;
#[inline] #[inline]
fn poll(&mut self) -> Poll<(), Self::Error> { fn poll(&mut self) -> Poll<(), Self::Error> {
self.poll_keepalive()?;
// shutdown
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
if self.flags.contains(Flags::WRITE_DISCONNECTED) { self.poll_keepalive()?;
return Ok(Async::Ready(()));
}
try_ready!(self.poll_flush()); try_ready!(self.poll_flush());
return Ok(AsyncWrite::shutdown(self.framed.get_mut())?); Ok(AsyncWrite::shutdown(self.framed.get_mut())?)
} } else {
self.poll_keepalive()?;
// process incoming requests self.poll_request()?;
if !self.flags.contains(Flags::WRITE_DISCONNECTED) { self.poll_response()?;
self.poll_handler()?;
// flush stream
self.poll_flush()?; self.poll_flush()?;
// deal with keep-alive and stream eof (client-side write shutdown) // keep-alive and stream errors
if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) {
// handle stream eof if let Some(err) = self.error.take() {
if self Err(err)
.flags } else if self.flags.contains(Flags::DISCONNECTED) {
.intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) Ok(Async::Ready(()))
{
return Ok(Async::Ready(()));
} }
// no keep-alive // disconnect if keep-alive is not enabled
if self.flags.contains(Flags::STARTED) else if self.flags.contains(Flags::STARTED) && !self
&& (!self.flags.contains(Flags::KEEPALIVE_ENABLED) .flags
|| !self.flags.contains(Flags::KEEPALIVE)) .intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED)
{ {
self.flags.insert(Flags::SHUTDOWN); self.flags.insert(Flags::SHUTDOWN);
return self.poll(); self.poll()
} else {
Ok(Async::NotReady)
} }
} else {
Ok(Async::NotReady)
} }
Ok(Async::NotReady)
} else if let Some(err) = self.error.take() {
Err(err)
} else {
Ok(Async::Ready(()))
} }
} }
} }

View File

@ -1,4 +1,4 @@
use std::fmt::{Debug, Display}; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::net; use std::net;
@ -26,7 +26,7 @@ impl<T, S> H1Service<T, S>
where where
S: NewService, S: NewService,
S::Service: Clone, S::Service: Clone,
S::Error: Debug + Display, S::Error: Debug,
{ {
/// Create new `HttpService` instance. /// Create new `HttpService` instance.
pub fn new<F: IntoNewService<S>>(service: F) -> Self { pub fn new<F: IntoNewService<S>>(service: F) -> Self {
@ -50,7 +50,7 @@ where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: NewService<Request = Request, Response = Response> + Clone, S: NewService<Request = Request, Response = Response> + Clone,
S::Service: Clone, S::Service: Clone,
S::Error: Debug + Display, S::Error: Debug,
{ {
type Request = T; type Request = T;
type Response = (); type Response = ();
@ -86,7 +86,7 @@ impl<T, S> H1ServiceBuilder<T, S>
where where
S: NewService, S: NewService,
S::Service: Clone, S::Service: Clone,
S::Error: Debug + Display, S::Error: Debug,
{ {
/// Create instance of `ServiceConfigBuilder` /// Create instance of `ServiceConfigBuilder`
pub fn new() -> H1ServiceBuilder<T, S> { pub fn new() -> H1ServiceBuilder<T, S> {
@ -203,7 +203,7 @@ where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: NewService<Request = Request, Response = Response>, S: NewService<Request = Request, Response = Response>,
S::Service: Clone, S::Service: Clone,
S::Error: Debug + Display, S::Error: Debug,
{ {
type Item = H1ServiceHandler<T, S::Service>; type Item = H1ServiceHandler<T, S::Service>;
type Error = S::InitError; type Error = S::InitError;
@ -227,7 +227,7 @@ pub struct H1ServiceHandler<T, S> {
impl<T, S> H1ServiceHandler<T, S> impl<T, S> H1ServiceHandler<T, S>
where where
S: Service<Request = Request, Response = Response> + Clone, S: Service<Request = Request, Response = Response> + Clone,
S::Error: Debug + Display, S::Error: Debug,
{ {
fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler<T, S> { fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler<T, S> {
H1ServiceHandler { H1ServiceHandler {
@ -242,7 +242,7 @@ impl<T, S> Service for H1ServiceHandler<T, S>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: Service<Request = Request, Response = Response> + Clone, S: Service<Request = Request, Response = Response> + Clone,
S::Error: Debug + Display, S::Error: Debug,
{ {
type Request = T; type Request = T;
type Response = (); type Response = ();

View File

@ -522,18 +522,11 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use failure::Fail;
use futures::future::{lazy, result}; use futures::future::{lazy, result};
use std::io;
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
#[test] #[test]
fn test_error() { fn test_error() {
let err: PayloadError =
io::Error::new(io::ErrorKind::Other, "ParseError").into();
assert_eq!(format!("{}", err), "ParseError");
assert_eq!(format!("{}", err.cause().unwrap()), "ParseError");
let err = PayloadError::Incomplete; let err = PayloadError::Incomplete;
assert_eq!( assert_eq!(
format!("{}", err), format!("{}", err),

View File

@ -11,7 +11,7 @@ use actix_net::server::Server;
use actix_web::{client, test, HttpMessage}; use actix_web::{client, test, HttpMessage};
use futures::future; use futures::future;
use actix_http::{h1, Error, KeepAlive, Request, Response}; use actix_http::{h1, KeepAlive, Request, Response};
#[test] #[test]
fn test_h1_v2() { fn test_h1_v2() {
@ -25,10 +25,11 @@ fn test_h1_v2() {
.client_disconnect(1000) .client_disconnect(1000)
.server_hostname("localhost") .server_hostname("localhost")
.server_address(addr) .server_address(addr)
.finish(|_| future::ok::<_, Error>(Response::Ok().finish())) .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
}).unwrap() }).unwrap()
.run(); .run();
}); });
thread::sleep(time::Duration::from_millis(100));
let mut sys = System::new("test"); let mut sys = System::new("test");
{ {
@ -48,7 +49,7 @@ fn test_slow_request() {
.bind("test", addr, move || { .bind("test", addr, move || {
h1::H1Service::build() h1::H1Service::build()
.client_timeout(100) .client_timeout(100)
.finish(|_| future::ok::<_, Error>(Response::Ok().finish())) .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
}).unwrap() }).unwrap()
.run(); .run();
}); });
@ -69,7 +70,7 @@ fn test_malformed_request() {
.bind("test", addr, move || { .bind("test", addr, move || {
h1::H1Service::build() h1::H1Service::build()
.client_timeout(100) .client_timeout(100)
.finish(|_| future::ok::<_, Error>(Response::Ok().finish())) .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
}).unwrap() }).unwrap()
.run(); .run();
}); });
@ -103,7 +104,7 @@ fn test_content_length() {
StatusCode::OK, StatusCode::OK,
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
]; ];
future::ok::<_, Error>(Response::new(statuses[indx])) future::ok::<_, ()>(Response::new(statuses[indx]))
}) })
}).unwrap() }).unwrap()
.run(); .run();