1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-23 15:24:36 +01:00

refactor http/1 dispatcher

This commit is contained in:
Nikolay Kim 2018-10-07 09:48:53 -07:00
parent 9c4a55c95c
commit 13193a0721
6 changed files with 68 additions and 106 deletions

View File

@ -341,15 +341,6 @@ pub enum PayloadError {
/// A payload length is unknown.
#[fail(display = "A payload length is unknown.")]
UnknownLength,
/// Io error
#[fail(display = "{}", _0)]
Io(#[cause] IoError),
}
impl From<IoError> for PayloadError {
fn from(err: IoError) -> PayloadError {
PayloadError::Io(err)
}
}
/// `PayloadError` returns two possible results:
@ -374,7 +365,7 @@ impl ResponseError for cookie::ParseError {
#[derive(Debug)]
/// A set of errors that can occur during dispatching http requests
pub enum DispatchError<E: fmt::Display + fmt::Debug> {
pub enum DispatchError<E: fmt::Debug> {
/// Service error
// #[fail(display = "Application specific error: {}", _0)]
Service(E),
@ -413,13 +404,13 @@ pub enum DispatchError<E: fmt::Display + fmt::Debug> {
Unknown,
}
impl<E: fmt::Display + fmt::Debug> From<ParseError> for DispatchError<E> {
impl<E: fmt::Debug> From<ParseError> for DispatchError<E> {
fn from(err: ParseError) -> Self {
DispatchError::Parse(err)
}
}
impl<E: fmt::Display + fmt::Debug> From<io::Error> for DispatchError<E> {
impl<E: fmt::Debug> From<io::Error> for DispatchError<E> {
fn from(err: io::Error) -> Self {
DispatchError::Io(err)
}

View File

@ -219,9 +219,7 @@ impl PayloadDecoder {
}
pub fn eof() -> PayloadDecoder {
PayloadDecoder {
kind: Kind::Eof(false),
}
PayloadDecoder { kind: Kind::Eof }
}
}
@ -246,7 +244,7 @@ enum Kind {
/// > the final encoding, the message body length cannot be determined
/// > reliably; the server MUST respond with the 400 (Bad Request)
/// > status code and then close the connection.
Eof(bool),
Eof,
}
#[derive(Debug, PartialEq, Clone)]
@ -309,13 +307,11 @@ impl Decoder for PayloadDecoder {
}
}
}
Kind::Eof(ref mut is_eof) => {
if *is_eof {
Ok(Some(PayloadItem::Eof))
} else if !src.is_empty() {
Ok(Some(PayloadItem::Chunk(src.take().freeze())))
} else {
Kind::Eof => {
if src.is_empty() {
Ok(None)
} else {
Ok(Some(PayloadItem::Chunk(src.take().freeze())))
}
}
}

View File

@ -1,5 +1,5 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Display};
use std::fmt::Debug;
use std::time::Instant;
use actix_net::codec::Framed;
@ -27,18 +27,17 @@ bitflags! {
const STARTED = 0b0000_0001;
const KEEPALIVE_ENABLED = 0b0000_0010;
const KEEPALIVE = 0b0000_0100;
const SHUTDOWN = 0b0000_1000;
const READ_DISCONNECTED = 0b0001_0000;
const WRITE_DISCONNECTED = 0b0010_0000;
const POLLED = 0b0100_0000;
const FLUSHED = 0b1000_0000;
const POLLED = 0b0000_1000;
const FLUSHED = 0b0001_0000;
const SHUTDOWN = 0b0010_0000;
const DISCONNECTED = 0b0100_0000;
}
}
/// Dispatcher for HTTP/1.1 protocol
pub struct Dispatcher<T, S: Service>
where
S::Error: Debug + Display,
S::Error: Debug,
{
service: S,
flags: Flags,
@ -81,7 +80,7 @@ impl<T, S> Dispatcher<T, S>
where
T: AsyncRead + AsyncWrite,
S: Service<Request = Request, Response = Response>,
S::Error: Debug + Display,
S::Error: Debug,
{
/// Create http/1 dispatcher.
pub fn new(stream: T, config: ServiceConfig, service: S) -> Self {
@ -122,9 +121,8 @@ where
}
}
#[inline]
fn can_read(&self) -> bool {
if self.flags.contains(Flags::READ_DISCONNECTED) {
if self.flags.contains(Flags::DISCONNECTED) {
return false;
}
@ -137,7 +135,7 @@ where
// if checked is set to true, delay disconnect until all tasks have finished.
fn client_disconnected(&mut self) {
self.flags.insert(Flags::READ_DISCONNECTED);
self.flags.insert(Flags::DISCONNECTED);
if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete);
}
@ -145,12 +143,11 @@ where
/// Flush stream
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
if self.flags.contains(Flags::STARTED) && !self.flags.contains(Flags::FLUSHED) {
if !self.flags.contains(Flags::FLUSHED) {
match self.framed.poll_complete() {
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => {
debug!("Error sending data: {}", err);
self.client_disconnected();
Err(err.into())
}
Ok(Async::Ready(_)) => {
@ -167,8 +164,7 @@ where
}
}
pub(self) fn poll_handler(&mut self) -> Result<(), DispatchError<S::Error>> {
self.poll_io()?;
fn poll_response(&mut self) -> Result<(), DispatchError<S::Error>> {
let mut retry = self.can_read();
// process
@ -221,7 +217,6 @@ where
return Ok(());
}
Err(err) => {
self.flags.insert(Flags::READ_DISCONNECTED);
if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete);
}
@ -246,7 +241,6 @@ where
return Ok(());
}
Err(err) => {
self.flags.insert(Flags::READ_DISCONNECTED);
if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete);
}
@ -261,7 +255,7 @@ where
None => {
// if read-backpressure is enabled and we consumed some data.
// we may read more dataand retry
if !retry && self.can_read() && self.poll_io()? {
if !retry && self.can_read() && self.poll_request()? {
retry = self.can_read();
continue;
}
@ -319,7 +313,7 @@ where
payload.feed_data(chunk);
} else {
error!("Internal server error: unexpected payload chunk");
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.flags.insert(Flags::DISCONNECTED);
self.messages.push_back(Message::Error(
Response::InternalServerError().finish(),
));
@ -331,7 +325,7 @@ where
payload.feed_eof();
} else {
error!("Internal server error: unexpected eof");
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.flags.insert(Flags::DISCONNECTED);
self.messages.push_back(Message::Error(
Response::InternalServerError().finish(),
));
@ -343,7 +337,7 @@ where
Ok(())
}
pub(self) fn poll_io(&mut self) -> Result<bool, DispatchError<S::Error>> {
pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError<S::Error>> {
let mut updated = false;
if self.messages.len() < MAX_PIPELINED_MESSAGES {
@ -354,26 +348,25 @@ where
self.one_message(msg)?;
}
Ok(Async::Ready(None)) => {
if self.flags.contains(Flags::READ_DISCONNECTED) {
self.client_disconnected();
}
self.client_disconnected();
break;
}
Ok(Async::NotReady) => break,
Err(ParseError::Io(e)) => {
self.client_disconnected();
self.error = Some(DispatchError::Io(e));
break;
}
Err(e) => {
if let Some(mut payload) = self.payload.take() {
let e = match e {
ParseError::Io(e) => PayloadError::Io(e),
_ => PayloadError::EncodingCorrupted,
};
payload.set_error(e);
payload.set_error(PayloadError::EncodingCorrupted);
}
// Malformed requests should be responded with 400
self.messages
.push_back(Message::Error(Response::BadRequest().finish()));
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
self.error = Some(DispatchError::MalformedRequest);
self.flags.insert(Flags::DISCONNECTED);
self.error = Some(e.into());
break;
}
}
@ -402,8 +395,7 @@ where
} else if !self.flags.contains(Flags::STARTED) {
// timeout on first request (slow request) return 408
trace!("Slow request timeout");
self.flags
.insert(Flags::STARTED | Flags::READ_DISCONNECTED);
self.flags.insert(Flags::STARTED | Flags::DISCONNECTED);
self.state =
State::SendResponse(Some(OutMessage::Response(
Response::RequestTimeout().finish(),
@ -444,54 +436,43 @@ impl<T, S> Future for Dispatcher<T, S>
where
T: AsyncRead + AsyncWrite,
S: Service<Request = Request, Response = Response>,
S::Error: Debug + Display,
S::Error: Debug,
{
type Item = ();
type Error = DispatchError<S::Error>;
#[inline]
fn poll(&mut self) -> Poll<(), Self::Error> {
self.poll_keepalive()?;
// shutdown
if self.flags.contains(Flags::SHUTDOWN) {
if self.flags.contains(Flags::WRITE_DISCONNECTED) {
return Ok(Async::Ready(()));
}
self.poll_keepalive()?;
try_ready!(self.poll_flush());
return Ok(AsyncWrite::shutdown(self.framed.get_mut())?);
}
// process incoming requests
if !self.flags.contains(Flags::WRITE_DISCONNECTED) {
self.poll_handler()?;
// flush stream
Ok(AsyncWrite::shutdown(self.framed.get_mut())?)
} else {
self.poll_keepalive()?;
self.poll_request()?;
self.poll_response()?;
self.poll_flush()?;
// deal with keep-alive and stream eof (client-side write shutdown)
// keep-alive and stream errors
if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) {
// handle stream eof
if self
.flags
.intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED)
{
return Ok(Async::Ready(()));
if let Some(err) = self.error.take() {
Err(err)
} else if self.flags.contains(Flags::DISCONNECTED) {
Ok(Async::Ready(()))
}
// no keep-alive
if self.flags.contains(Flags::STARTED)
&& (!self.flags.contains(Flags::KEEPALIVE_ENABLED)
|| !self.flags.contains(Flags::KEEPALIVE))
// disconnect if keep-alive is not enabled
else if self.flags.contains(Flags::STARTED) && !self
.flags
.intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED)
{
self.flags.insert(Flags::SHUTDOWN);
return self.poll();
self.poll()
} else {
Ok(Async::NotReady)
}
} else {
Ok(Async::NotReady)
}
Ok(Async::NotReady)
} else if let Some(err) = self.error.take() {
Err(err)
} else {
Ok(Async::Ready(()))
}
}
}

View File

@ -1,4 +1,4 @@
use std::fmt::{Debug, Display};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::net;
@ -26,7 +26,7 @@ impl<T, S> H1Service<T, S>
where
S: NewService,
S::Service: Clone,
S::Error: Debug + Display,
S::Error: Debug,
{
/// Create new `HttpService` instance.
pub fn new<F: IntoNewService<S>>(service: F) -> Self {
@ -50,7 +50,7 @@ where
T: AsyncRead + AsyncWrite,
S: NewService<Request = Request, Response = Response> + Clone,
S::Service: Clone,
S::Error: Debug + Display,
S::Error: Debug,
{
type Request = T;
type Response = ();
@ -86,7 +86,7 @@ impl<T, S> H1ServiceBuilder<T, S>
where
S: NewService,
S::Service: Clone,
S::Error: Debug + Display,
S::Error: Debug,
{
/// Create instance of `ServiceConfigBuilder`
pub fn new() -> H1ServiceBuilder<T, S> {
@ -203,7 +203,7 @@ where
T: AsyncRead + AsyncWrite,
S: NewService<Request = Request, Response = Response>,
S::Service: Clone,
S::Error: Debug + Display,
S::Error: Debug,
{
type Item = H1ServiceHandler<T, S::Service>;
type Error = S::InitError;
@ -227,7 +227,7 @@ pub struct H1ServiceHandler<T, S> {
impl<T, S> H1ServiceHandler<T, S>
where
S: Service<Request = Request, Response = Response> + Clone,
S::Error: Debug + Display,
S::Error: Debug,
{
fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler<T, S> {
H1ServiceHandler {
@ -242,7 +242,7 @@ impl<T, S> Service for H1ServiceHandler<T, S>
where
T: AsyncRead + AsyncWrite,
S: Service<Request = Request, Response = Response> + Clone,
S::Error: Debug + Display,
S::Error: Debug,
{
type Request = T;
type Response = ();

View File

@ -522,18 +522,11 @@ where
#[cfg(test)]
mod tests {
use super::*;
use failure::Fail;
use futures::future::{lazy, result};
use std::io;
use tokio::runtime::current_thread::Runtime;
#[test]
fn test_error() {
let err: PayloadError =
io::Error::new(io::ErrorKind::Other, "ParseError").into();
assert_eq!(format!("{}", err), "ParseError");
assert_eq!(format!("{}", err.cause().unwrap()), "ParseError");
let err = PayloadError::Incomplete;
assert_eq!(
format!("{}", err),

View File

@ -11,7 +11,7 @@ use actix_net::server::Server;
use actix_web::{client, test, HttpMessage};
use futures::future;
use actix_http::{h1, Error, KeepAlive, Request, Response};
use actix_http::{h1, KeepAlive, Request, Response};
#[test]
fn test_h1_v2() {
@ -25,10 +25,11 @@ fn test_h1_v2() {
.client_disconnect(1000)
.server_hostname("localhost")
.server_address(addr)
.finish(|_| future::ok::<_, Error>(Response::Ok().finish()))
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
}).unwrap()
.run();
});
thread::sleep(time::Duration::from_millis(100));
let mut sys = System::new("test");
{
@ -48,7 +49,7 @@ fn test_slow_request() {
.bind("test", addr, move || {
h1::H1Service::build()
.client_timeout(100)
.finish(|_| future::ok::<_, Error>(Response::Ok().finish()))
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
}).unwrap()
.run();
});
@ -69,7 +70,7 @@ fn test_malformed_request() {
.bind("test", addr, move || {
h1::H1Service::build()
.client_timeout(100)
.finish(|_| future::ok::<_, Error>(Response::Ok().finish()))
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
}).unwrap()
.run();
});
@ -103,7 +104,7 @@ fn test_content_length() {
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, Error>(Response::new(statuses[indx]))
future::ok::<_, ()>(Response::new(statuses[indx]))
})
}).unwrap()
.run();