mirror of
https://github.com/actix/actix-extras.git
synced 2024-12-01 02:44:37 +01:00
refactor http/1 dispatcher
This commit is contained in:
parent
9c4a55c95c
commit
13193a0721
15
src/error.rs
15
src/error.rs
@ -341,15 +341,6 @@ pub enum PayloadError {
|
||||
/// A payload length is unknown.
|
||||
#[fail(display = "A payload length is unknown.")]
|
||||
UnknownLength,
|
||||
/// Io error
|
||||
#[fail(display = "{}", _0)]
|
||||
Io(#[cause] IoError),
|
||||
}
|
||||
|
||||
impl From<IoError> for PayloadError {
|
||||
fn from(err: IoError) -> PayloadError {
|
||||
PayloadError::Io(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// `PayloadError` returns two possible results:
|
||||
@ -374,7 +365,7 @@ impl ResponseError for cookie::ParseError {
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A set of errors that can occur during dispatching http requests
|
||||
pub enum DispatchError<E: fmt::Display + fmt::Debug> {
|
||||
pub enum DispatchError<E: fmt::Debug> {
|
||||
/// Service error
|
||||
// #[fail(display = "Application specific error: {}", _0)]
|
||||
Service(E),
|
||||
@ -413,13 +404,13 @@ pub enum DispatchError<E: fmt::Display + fmt::Debug> {
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl<E: fmt::Display + fmt::Debug> From<ParseError> for DispatchError<E> {
|
||||
impl<E: fmt::Debug> From<ParseError> for DispatchError<E> {
|
||||
fn from(err: ParseError) -> Self {
|
||||
DispatchError::Parse(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: fmt::Display + fmt::Debug> From<io::Error> for DispatchError<E> {
|
||||
impl<E: fmt::Debug> From<io::Error> for DispatchError<E> {
|
||||
fn from(err: io::Error) -> Self {
|
||||
DispatchError::Io(err)
|
||||
}
|
||||
|
@ -219,9 +219,7 @@ impl PayloadDecoder {
|
||||
}
|
||||
|
||||
pub fn eof() -> PayloadDecoder {
|
||||
PayloadDecoder {
|
||||
kind: Kind::Eof(false),
|
||||
}
|
||||
PayloadDecoder { kind: Kind::Eof }
|
||||
}
|
||||
}
|
||||
|
||||
@ -246,7 +244,7 @@ enum Kind {
|
||||
/// > the final encoding, the message body length cannot be determined
|
||||
/// > reliably; the server MUST respond with the 400 (Bad Request)
|
||||
/// > status code and then close the connection.
|
||||
Eof(bool),
|
||||
Eof,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
@ -309,13 +307,11 @@ impl Decoder for PayloadDecoder {
|
||||
}
|
||||
}
|
||||
}
|
||||
Kind::Eof(ref mut is_eof) => {
|
||||
if *is_eof {
|
||||
Ok(Some(PayloadItem::Eof))
|
||||
} else if !src.is_empty() {
|
||||
Ok(Some(PayloadItem::Chunk(src.take().freeze())))
|
||||
} else {
|
||||
Kind::Eof => {
|
||||
if src.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(PayloadItem::Chunk(src.take().freeze())))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::fmt::Debug;
|
||||
use std::time::Instant;
|
||||
|
||||
use actix_net::codec::Framed;
|
||||
@ -27,18 +27,17 @@ bitflags! {
|
||||
const STARTED = 0b0000_0001;
|
||||
const KEEPALIVE_ENABLED = 0b0000_0010;
|
||||
const KEEPALIVE = 0b0000_0100;
|
||||
const SHUTDOWN = 0b0000_1000;
|
||||
const READ_DISCONNECTED = 0b0001_0000;
|
||||
const WRITE_DISCONNECTED = 0b0010_0000;
|
||||
const POLLED = 0b0100_0000;
|
||||
const FLUSHED = 0b1000_0000;
|
||||
const POLLED = 0b0000_1000;
|
||||
const FLUSHED = 0b0001_0000;
|
||||
const SHUTDOWN = 0b0010_0000;
|
||||
const DISCONNECTED = 0b0100_0000;
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatcher for HTTP/1.1 protocol
|
||||
pub struct Dispatcher<T, S: Service>
|
||||
where
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
service: S,
|
||||
flags: Flags,
|
||||
@ -81,7 +80,7 @@ impl<T, S> Dispatcher<T, S>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
S: Service<Request = Request, Response = Response>,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
/// Create http/1 dispatcher.
|
||||
pub fn new(stream: T, config: ServiceConfig, service: S) -> Self {
|
||||
@ -122,9 +121,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_read(&self) -> bool {
|
||||
if self.flags.contains(Flags::READ_DISCONNECTED) {
|
||||
if self.flags.contains(Flags::DISCONNECTED) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -137,7 +135,7 @@ where
|
||||
|
||||
// if checked is set to true, delay disconnect until all tasks have finished.
|
||||
fn client_disconnected(&mut self) {
|
||||
self.flags.insert(Flags::READ_DISCONNECTED);
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
if let Some(mut payload) = self.payload.take() {
|
||||
payload.set_error(PayloadError::Incomplete);
|
||||
}
|
||||
@ -145,12 +143,11 @@ where
|
||||
|
||||
/// Flush stream
|
||||
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
|
||||
if self.flags.contains(Flags::STARTED) && !self.flags.contains(Flags::FLUSHED) {
|
||||
if !self.flags.contains(Flags::FLUSHED) {
|
||||
match self.framed.poll_complete() {
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(err) => {
|
||||
debug!("Error sending data: {}", err);
|
||||
self.client_disconnected();
|
||||
Err(err.into())
|
||||
}
|
||||
Ok(Async::Ready(_)) => {
|
||||
@ -167,8 +164,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub(self) fn poll_handler(&mut self) -> Result<(), DispatchError<S::Error>> {
|
||||
self.poll_io()?;
|
||||
fn poll_response(&mut self) -> Result<(), DispatchError<S::Error>> {
|
||||
let mut retry = self.can_read();
|
||||
|
||||
// process
|
||||
@ -221,7 +217,6 @@ where
|
||||
return Ok(());
|
||||
}
|
||||
Err(err) => {
|
||||
self.flags.insert(Flags::READ_DISCONNECTED);
|
||||
if let Some(mut payload) = self.payload.take() {
|
||||
payload.set_error(PayloadError::Incomplete);
|
||||
}
|
||||
@ -246,7 +241,6 @@ where
|
||||
return Ok(());
|
||||
}
|
||||
Err(err) => {
|
||||
self.flags.insert(Flags::READ_DISCONNECTED);
|
||||
if let Some(mut payload) = self.payload.take() {
|
||||
payload.set_error(PayloadError::Incomplete);
|
||||
}
|
||||
@ -261,7 +255,7 @@ where
|
||||
None => {
|
||||
// if read-backpressure is enabled and we consumed some data.
|
||||
// we may read more dataand retry
|
||||
if !retry && self.can_read() && self.poll_io()? {
|
||||
if !retry && self.can_read() && self.poll_request()? {
|
||||
retry = self.can_read();
|
||||
continue;
|
||||
}
|
||||
@ -319,7 +313,7 @@ where
|
||||
payload.feed_data(chunk);
|
||||
} else {
|
||||
error!("Internal server error: unexpected payload chunk");
|
||||
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
self.messages.push_back(Message::Error(
|
||||
Response::InternalServerError().finish(),
|
||||
));
|
||||
@ -331,7 +325,7 @@ where
|
||||
payload.feed_eof();
|
||||
} else {
|
||||
error!("Internal server error: unexpected eof");
|
||||
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
self.messages.push_back(Message::Error(
|
||||
Response::InternalServerError().finish(),
|
||||
));
|
||||
@ -343,7 +337,7 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(self) fn poll_io(&mut self) -> Result<bool, DispatchError<S::Error>> {
|
||||
pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError<S::Error>> {
|
||||
let mut updated = false;
|
||||
|
||||
if self.messages.len() < MAX_PIPELINED_MESSAGES {
|
||||
@ -354,26 +348,25 @@ where
|
||||
self.one_message(msg)?;
|
||||
}
|
||||
Ok(Async::Ready(None)) => {
|
||||
if self.flags.contains(Flags::READ_DISCONNECTED) {
|
||||
self.client_disconnected();
|
||||
}
|
||||
self.client_disconnected();
|
||||
break;
|
||||
}
|
||||
Ok(Async::NotReady) => break,
|
||||
Err(ParseError::Io(e)) => {
|
||||
self.client_disconnected();
|
||||
self.error = Some(DispatchError::Io(e));
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(mut payload) = self.payload.take() {
|
||||
let e = match e {
|
||||
ParseError::Io(e) => PayloadError::Io(e),
|
||||
_ => PayloadError::EncodingCorrupted,
|
||||
};
|
||||
payload.set_error(e);
|
||||
payload.set_error(PayloadError::EncodingCorrupted);
|
||||
}
|
||||
|
||||
// Malformed requests should be responded with 400
|
||||
self.messages
|
||||
.push_back(Message::Error(Response::BadRequest().finish()));
|
||||
self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED);
|
||||
self.error = Some(DispatchError::MalformedRequest);
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
self.error = Some(e.into());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -402,8 +395,7 @@ where
|
||||
} else if !self.flags.contains(Flags::STARTED) {
|
||||
// timeout on first request (slow request) return 408
|
||||
trace!("Slow request timeout");
|
||||
self.flags
|
||||
.insert(Flags::STARTED | Flags::READ_DISCONNECTED);
|
||||
self.flags.insert(Flags::STARTED | Flags::DISCONNECTED);
|
||||
self.state =
|
||||
State::SendResponse(Some(OutMessage::Response(
|
||||
Response::RequestTimeout().finish(),
|
||||
@ -444,54 +436,43 @@ impl<T, S> Future for Dispatcher<T, S>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
S: Service<Request = Request, Response = Response>,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
type Item = ();
|
||||
type Error = DispatchError<S::Error>;
|
||||
|
||||
#[inline]
|
||||
fn poll(&mut self) -> Poll<(), Self::Error> {
|
||||
self.poll_keepalive()?;
|
||||
|
||||
// shutdown
|
||||
if self.flags.contains(Flags::SHUTDOWN) {
|
||||
if self.flags.contains(Flags::WRITE_DISCONNECTED) {
|
||||
return Ok(Async::Ready(()));
|
||||
}
|
||||
self.poll_keepalive()?;
|
||||
try_ready!(self.poll_flush());
|
||||
return Ok(AsyncWrite::shutdown(self.framed.get_mut())?);
|
||||
}
|
||||
|
||||
// process incoming requests
|
||||
if !self.flags.contains(Flags::WRITE_DISCONNECTED) {
|
||||
self.poll_handler()?;
|
||||
|
||||
// flush stream
|
||||
Ok(AsyncWrite::shutdown(self.framed.get_mut())?)
|
||||
} else {
|
||||
self.poll_keepalive()?;
|
||||
self.poll_request()?;
|
||||
self.poll_response()?;
|
||||
self.poll_flush()?;
|
||||
|
||||
// deal with keep-alive and stream eof (client-side write shutdown)
|
||||
// keep-alive and stream errors
|
||||
if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) {
|
||||
// handle stream eof
|
||||
if self
|
||||
.flags
|
||||
.intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED)
|
||||
{
|
||||
return Ok(Async::Ready(()));
|
||||
if let Some(err) = self.error.take() {
|
||||
Err(err)
|
||||
} else if self.flags.contains(Flags::DISCONNECTED) {
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
// no keep-alive
|
||||
if self.flags.contains(Flags::STARTED)
|
||||
&& (!self.flags.contains(Flags::KEEPALIVE_ENABLED)
|
||||
|| !self.flags.contains(Flags::KEEPALIVE))
|
||||
// disconnect if keep-alive is not enabled
|
||||
else if self.flags.contains(Flags::STARTED) && !self
|
||||
.flags
|
||||
.intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED)
|
||||
{
|
||||
self.flags.insert(Flags::SHUTDOWN);
|
||||
return self.poll();
|
||||
self.poll()
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
Ok(Async::NotReady)
|
||||
} else if let Some(err) = self.error.take() {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::net;
|
||||
|
||||
@ -26,7 +26,7 @@ impl<T, S> H1Service<T, S>
|
||||
where
|
||||
S: NewService,
|
||||
S::Service: Clone,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
/// Create new `HttpService` instance.
|
||||
pub fn new<F: IntoNewService<S>>(service: F) -> Self {
|
||||
@ -50,7 +50,7 @@ where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
S: NewService<Request = Request, Response = Response> + Clone,
|
||||
S::Service: Clone,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
type Request = T;
|
||||
type Response = ();
|
||||
@ -86,7 +86,7 @@ impl<T, S> H1ServiceBuilder<T, S>
|
||||
where
|
||||
S: NewService,
|
||||
S::Service: Clone,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
/// Create instance of `ServiceConfigBuilder`
|
||||
pub fn new() -> H1ServiceBuilder<T, S> {
|
||||
@ -203,7 +203,7 @@ where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
S: NewService<Request = Request, Response = Response>,
|
||||
S::Service: Clone,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
type Item = H1ServiceHandler<T, S::Service>;
|
||||
type Error = S::InitError;
|
||||
@ -227,7 +227,7 @@ pub struct H1ServiceHandler<T, S> {
|
||||
impl<T, S> H1ServiceHandler<T, S>
|
||||
where
|
||||
S: Service<Request = Request, Response = Response> + Clone,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler<T, S> {
|
||||
H1ServiceHandler {
|
||||
@ -242,7 +242,7 @@ impl<T, S> Service for H1ServiceHandler<T, S>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
S: Service<Request = Request, Response = Response> + Clone,
|
||||
S::Error: Debug + Display,
|
||||
S::Error: Debug,
|
||||
{
|
||||
type Request = T;
|
||||
type Response = ();
|
||||
|
@ -522,18 +522,11 @@ where
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use failure::Fail;
|
||||
use futures::future::{lazy, result};
|
||||
use std::io;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
|
||||
#[test]
|
||||
fn test_error() {
|
||||
let err: PayloadError =
|
||||
io::Error::new(io::ErrorKind::Other, "ParseError").into();
|
||||
assert_eq!(format!("{}", err), "ParseError");
|
||||
assert_eq!(format!("{}", err.cause().unwrap()), "ParseError");
|
||||
|
||||
let err = PayloadError::Incomplete;
|
||||
assert_eq!(
|
||||
format!("{}", err),
|
||||
|
@ -11,7 +11,7 @@ use actix_net::server::Server;
|
||||
use actix_web::{client, test, HttpMessage};
|
||||
use futures::future;
|
||||
|
||||
use actix_http::{h1, Error, KeepAlive, Request, Response};
|
||||
use actix_http::{h1, KeepAlive, Request, Response};
|
||||
|
||||
#[test]
|
||||
fn test_h1_v2() {
|
||||
@ -25,10 +25,11 @@ fn test_h1_v2() {
|
||||
.client_disconnect(1000)
|
||||
.server_hostname("localhost")
|
||||
.server_address(addr)
|
||||
.finish(|_| future::ok::<_, Error>(Response::Ok().finish()))
|
||||
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
thread::sleep(time::Duration::from_millis(100));
|
||||
|
||||
let mut sys = System::new("test");
|
||||
{
|
||||
@ -48,7 +49,7 @@ fn test_slow_request() {
|
||||
.bind("test", addr, move || {
|
||||
h1::H1Service::build()
|
||||
.client_timeout(100)
|
||||
.finish(|_| future::ok::<_, Error>(Response::Ok().finish()))
|
||||
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
@ -69,7 +70,7 @@ fn test_malformed_request() {
|
||||
.bind("test", addr, move || {
|
||||
h1::H1Service::build()
|
||||
.client_timeout(100)
|
||||
.finish(|_| future::ok::<_, Error>(Response::Ok().finish()))
|
||||
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
@ -103,7 +104,7 @@ fn test_content_length() {
|
||||
StatusCode::OK,
|
||||
StatusCode::NOT_FOUND,
|
||||
];
|
||||
future::ok::<_, Error>(Response::new(statuses[indx]))
|
||||
future::ok::<_, ()>(Response::new(statuses[indx]))
|
||||
})
|
||||
}).unwrap()
|
||||
.run();
|
||||
|
Loading…
Reference in New Issue
Block a user