From 3516f02e4fdb4f91f100d2d3e3e22f6485937e00 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 13 Oct 2017 16:33:23 -0700 Subject: [PATCH] keep-alive support --- README.md | 3 +- src/context.rs | 6 +-- src/error.rs | 41 +++++++++---------- src/httpmessage.rs | 4 +- src/lib.rs | 2 +- src/main.rs | 3 ++ src/payload.rs | 45 ++++++++++++++++++--- src/reader.rs | 70 +++++++++++++++++++++++++------- src/server.rs | 99 +++++++++++++++++++++++++++++++++++++++------- src/task.rs | 26 +++++++++--- src/ws.rs | 53 +++++++++++++++++-------- 11 files changed, 264 insertions(+), 88 deletions(-) diff --git a/README.md b/README.md index f9d43475..022bd00e 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Actix http is licensed under the [Apache-2.0 license](http://opensource.org/lice * HTTP 1.1 and 1.0 support * Streaming and pipelining support + * Keep-alive and slow requests support * [WebSockets support](https://fafhrd91.github.io/actix-http/actix_http/ws/index.html) * [Configurable request routing](https://fafhrd91.github.io/actix-http/actix_http/struct.RoutingMap.html) @@ -50,7 +51,7 @@ impl Route for MyRoute { fn request(req: HttpRequest, payload: Payload, ctx: &mut HttpContext) -> Reply { - Reply::reply(req, httpcodes::HTTPOk) + Reply::reply(httpcodes::HTTPOk) } } diff --git a/src/context.rs b/src/context.rs index 7a82bf84..46a68fd6 100644 --- a/src/context.rs +++ b/src/context.rs @@ -33,6 +33,7 @@ impl ActorContext for HttpContext where A: Actor + Route if self.state == ActorState::Running { self.state = ActorState::Stopping; } + self.write_eof(); } /// Terminate actor execution @@ -151,9 +152,8 @@ impl Stream for HttpContext where A: Actor + Route if self.wait.is_some() && self.act.is_some() { if let Some(ref mut act) = self.act { if let Some(ref mut fut) = self.wait { - match fut.poll(act, ctx) { - Ok(Async::NotReady) => return Ok(Async::NotReady), - _ => (), + if let Ok(Async::NotReady) = fut.poll(act, ctx) { + return Ok(Async::NotReady); } } } diff --git a/src/error.rs b/src/error.rs index 57badbf2..145bd34d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,7 +11,7 @@ use http::{StatusCode, Error as HttpError}; use httpmessage::{Body, HttpResponse}; -/// A set of errors that can occur parsing HTTP streams. +/// A set of errors that can occur during parsing HTTP streams. #[derive(Debug)] pub enum ParseError { /// An invalid `Method`, such as `GE,T`. @@ -31,8 +31,6 @@ pub enum ParseError { /// A timeout occurred waiting for an IO event. #[allow(dead_code)] Timeout, - /// Unexpected EOF during parsing - Eof, /// An `io::Error` that occurred while trying to read or write to a network stream. Io(IoError), /// Parsing a field as string failed @@ -60,7 +58,6 @@ impl StdError for ParseError { ParseError::Incomplete => "Message is incomplete", ParseError::Timeout => "Timeout", ParseError::Uri => "Uri error", - ParseError::Eof => "Unexpected eof during parse", ParseError::Io(ref e) => e.description(), ParseError::Utf8(ref e) => e.description(), } @@ -107,7 +104,7 @@ impl From for ParseError { } } -/// Return BadRequest for ParseError +/// Return `BadRequest` for `ParseError` impl From for HttpResponse { fn from(err: ParseError) -> Self { HttpResponse::new(StatusCode::BAD_REQUEST, @@ -115,8 +112,8 @@ impl From for HttpResponse { } } -/// Return InternalServerError for HttpError, -/// Response generation can return HttpError, so it is internal error +/// Return `InternalServerError` for `HttpError`, +/// Response generation can return `HttpError`, so it is internal error impl From for HttpResponse { fn from(err: HttpError) -> Self { HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR, @@ -124,7 +121,7 @@ impl From for HttpResponse { } } -/// Return BadRequest for cookie::ParseError +/// Return `BadRequest` for `cookie::ParseError` impl From for HttpResponse { fn from(err: cookie::ParseError) -> Self { HttpResponse::new(StatusCode::BAD_REQUEST, @@ -137,20 +134,19 @@ mod tests { use std::error::Error as StdError; use std::io; use httparse; - use super::Error; - use super::Error::*; + use super::ParseError; #[test] fn test_cause() { let orig = io::Error::new(io::ErrorKind::Other, "other"); let desc = orig.description().to_owned(); - let e = Io(orig); + let e = ParseError::Io(orig); assert_eq!(e.cause().unwrap().description(), desc); } macro_rules! from { ($from:expr => $error:pat) => { - match Error::from($from) { + match ParseError::from($from) { e @ $error => { assert!(e.description().len() >= 5); } , @@ -161,7 +157,7 @@ mod tests { macro_rules! from_and_cause { ($from:expr => $error:pat) => { - match Error::from($from) { + match ParseError::from($from) { e @ $error => { let desc = e.cause().unwrap().description(); assert_eq!(desc, $from.description().to_owned()); @@ -174,16 +170,15 @@ mod tests { #[test] fn test_from() { + from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => ParseError::Io(..)); - from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => Io(..)); - - from!(httparse::Error::HeaderName => Header); - from!(httparse::Error::HeaderName => Header); - from!(httparse::Error::HeaderValue => Header); - from!(httparse::Error::NewLine => Header); - from!(httparse::Error::Status => Status); - from!(httparse::Error::Token => Header); - from!(httparse::Error::TooManyHeaders => TooLarge); - from!(httparse::Error::Version => Version); + from!(httparse::Error::HeaderName => ParseError::Header); + from!(httparse::Error::HeaderName => ParseError::Header); + from!(httparse::Error::HeaderValue => ParseError::Header); + from!(httparse::Error::NewLine => ParseError::Header); + from!(httparse::Error::Status => ParseError::Status); + from!(httparse::Error::Token => ParseError::Header); + from!(httparse::Error::TooManyHeaders => ParseError::TooLarge); + from!(httparse::Error::Version => ParseError::Version); } } diff --git a/src/httpmessage.rs b/src/httpmessage.rs index ad2a09ac..5b76eb02 100644 --- a/src/httpmessage.rs +++ b/src/httpmessage.rs @@ -82,8 +82,8 @@ impl HttpRequest { pub fn cookie(&self) -> Result, cookie::ParseError> { if let Some(val) = self.headers.get(header::COOKIE) { let s = str::from_utf8(val.as_bytes()) - .map_err(|e| cookie::ParseError::from(e))?; - cookie::Cookie::parse(s).map(|c| Some(c)) + .map_err(cookie::ParseError::from)?; + cookie::Cookie::parse(s).map(Some) } else { Ok(None) } diff --git a/src/lib.rs b/src/lib.rs index 47d12b7a..956346c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,7 +43,7 @@ pub mod dev; pub mod httpcodes; pub use application::Application; pub use httpmessage::{Body, Builder, HttpRequest, HttpResponse}; -pub use payload::{Payload, PayloadItem}; +pub use payload::{Payload, PayloadItem, PayloadError}; pub use router::RoutingMap; pub use resource::{Reply, Resource}; pub use route::{Route, RouteFactory, RouteHandler}; diff --git a/src/main.rs b/src/main.rs index 32153b8e..538961b6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -84,6 +84,9 @@ impl Handler for MyWS { ws::Message::Ping(msg) => ws::WsWriter::pong(ctx, msg), ws::Message::Text(text) => ws::WsWriter::text(ctx, text), ws::Message::Binary(bin) => ws::WsWriter::binary(ctx, bin), + ws::Message::Closed | ws::Message::Error => { + ctx.stop(); + } _ => (), } Self::empty() diff --git a/src/payload.rs b/src/payload.rs index d4de1859..3f56da10 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -1,15 +1,31 @@ use std::rc::{Rc, Weak}; use std::cell::RefCell; +use std::convert::From; use std::collections::VecDeque; +use std::io::Error as IoError; use bytes::Bytes; use futures::{Async, Poll, Stream}; use futures::task::{Task, current as current_task}; -/// Just Bytes object -pub type PayloadItem = Bytes; - const MAX_PAYLOAD_SIZE: usize = 65_536; // max buffer size 64k +/// Just Bytes object +pub type PayloadItem = Result; + +#[derive(Debug)] +/// A set of error that can occur during payload parsing. +pub enum PayloadError { + /// A payload reached EOF, but is not complete. + Incomplete, + /// Parse error + ParseError(IoError), +} + +impl From for PayloadError { + fn from(err: IoError) -> PayloadError { + PayloadError::ParseError(err) + } +} /// Stream of byte chunks /// @@ -55,7 +71,7 @@ impl Payload { } /// Put unused data back to payload - pub fn unread_data(&mut self, data: PayloadItem) { + pub fn unread_data(&mut self, data: Bytes) { self.inner.borrow_mut().unread_data(data); } } @@ -75,6 +91,12 @@ pub(crate) struct PayloadSender { } impl PayloadSender { + pub(crate) fn set_error(&mut self, err: PayloadError) { + if let Some(shared) = self.inner.upgrade() { + shared.borrow_mut().set_error(err) + } + } + pub(crate) fn feed_eof(&mut self) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().feed_eof() @@ -112,6 +134,7 @@ struct Inner { len: usize, eof: bool, paused: bool, + err: Option, task: Option, items: VecDeque, } @@ -123,6 +146,7 @@ impl Inner { len: 0, eof: eof, paused: false, + err: None, task: None, items: VecDeque::new(), } @@ -140,6 +164,13 @@ impl Inner { self.paused = false; } + fn set_error(&mut self, err: PayloadError) { + self.err = Some(err); + if let Some(task) = self.task.take() { + task.notify() + } + } + fn feed_eof(&mut self) { self.eof = true; if let Some(task) = self.task.take() { @@ -163,12 +194,14 @@ impl Inner { self.len } - fn readany(&mut self) -> Async> { + fn readany(&mut self) -> Async> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); - Async::Ready(Some(data)) + Async::Ready(Some(Ok(data))) } else if self.eof { Async::Ready(None) + } else if let Some(err) = self.err.take() { + Async::Ready(Some(Err(err))) } else { self.task = Some(current_task()); Async::NotReady diff --git a/src/reader.rs b/src/reader.rs index ea06b3ba..50ef8d0f 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -10,7 +10,7 @@ use tokio_io::AsyncRead; use error::ParseError; use decode::Decoder; use httpmessage::HttpRequest; -use payload::{Payload, PayloadSender}; +use payload::{Payload, PayloadError, PayloadSender}; const MAX_HEADERS: usize = 100; const INIT_BUFFER_SIZE: usize = 8192; @@ -21,7 +21,7 @@ struct PayloadInfo { decoder: Decoder, } -pub struct Reader { +pub(crate) struct Reader { read_buf: BytesMut, payload: Option, } @@ -32,6 +32,11 @@ enum Decoding { NotReady, } +pub(crate) enum ReaderError { + Payload, + Error(ParseError), +} + impl Reader { pub fn new() -> Reader { Reader { @@ -53,7 +58,7 @@ impl Reader { } } - fn decode(&mut self) -> std::result::Result + fn decode(&mut self) -> std::result::Result { if let Some(ref mut payload) = self.payload { if payload.tx.maybe_paused() { @@ -69,7 +74,10 @@ impl Reader { return Ok(Decoding::Ready) }, Ok(Async::NotReady) => return Ok(Decoding::NotReady), - Err(_) => return Err(ParseError::Incomplete), + Err(err) => { + payload.tx.set_error(err.into()); + return Err(ReaderError::Payload) + } } } } else { @@ -77,7 +85,7 @@ impl Reader { } } - pub fn parse(&mut self, io: &mut T) -> Poll<(HttpRequest, Payload), ParseError> + pub fn parse(&mut self, io: &mut T) -> Poll<(HttpRequest, Payload), ReaderError> where T: AsyncRead { loop { @@ -88,15 +96,32 @@ impl Reader { break }, Decoding::NotReady => { - if 0 == try_ready!(self.read_from_io(io)) { - return Err(ParseError::Eof) + match self.read_from_io(io) { + Ok(Async::Ready(0)) => { + if let Some(ref mut payload) = self.payload { + payload.tx.set_error(PayloadError::Incomplete); + } + // http channel should deal with payload errors + return Err(ReaderError::Payload) + } + Ok(Async::Ready(_)) => { + continue + } + Ok(Async::NotReady) => break, + Err(err) => { + if let Some(ref mut payload) = self.payload { + payload.tx.set_error(err.into()); + } + // http channel should deal with payload errors + return Err(ReaderError::Payload) + } } } } } loop { - match try!(parse(&mut self.read_buf)) { + match try!(parse(&mut self.read_buf).map_err(ReaderError::Error)) { Some((msg, decoder)) => { let payload = if let Some(decoder) = decoder { let (tx, rx) = Payload::new(false); @@ -118,13 +143,23 @@ impl Reader { match self.read_from_io(io) { Ok(Async::Ready(0)) => { trace!("parse eof"); - return Err(ParseError::Eof); + if let Some(ref mut payload) = self.payload { + payload.tx.set_error(PayloadError::Incomplete); + } + // http channel should deal with payload errors + return Err(ReaderError::Payload) } Ok(Async::Ready(_)) => { continue } Ok(Async::NotReady) => break, - Err(err) => return Err(err.into()), + Err(err) => { + if let Some(ref mut payload) = self.payload { + payload.tx.set_error(err.into()); + } + // http channel should deal with payload errors + return Err(ReaderError::Payload) + } } } } @@ -139,13 +174,20 @@ impl Reader { None => { if self.read_buf.capacity() >= MAX_BUFFER_SIZE { debug!("MAX_BUFFER_SIZE reached, closing"); - return Err(ParseError::TooLarge); + return Err(ReaderError::Error(ParseError::TooLarge)); } }, } - if 0 == try_ready!(self.read_from_io(io)) { - trace!("parse eof"); - return Err(ParseError::Eof); + match self.read_from_io(io) { + Ok(Async::Ready(0)) => { + trace!("parse eof"); + return Err(ReaderError::Error(ParseError::Incomplete)); + }, + Ok(Async::Ready(_)) => (), + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(err) => + return Err(ReaderError::Error(err.into())) } } } diff --git a/src/server.rs b/src/server.rs index b2243488..c0ec2f3e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,13 +1,15 @@ use std::{io, net, mem}; use std::rc::Rc; +use std::time::Duration; use std::collections::VecDeque; use actix::dev::*; use futures::{Future, Poll, Async}; +use tokio_core::reactor::Timeout; use tokio_core::net::{TcpListener, TcpStream}; use task::{Task, RequestInfo}; -use reader::Reader; +use reader::{Reader, ReaderError}; use router::{Router, RoutingMap}; /// An HTTP Server @@ -58,6 +60,8 @@ impl Handler<(TcpStream, net::SocketAddr), io::Error> for HttpServer { error: false, items: VecDeque::new(), inactive: Vec::new(), + keepalive: true, + keepalive_timer: None, }); Self::empty() } @@ -72,6 +76,9 @@ struct Entry { finished: bool, } +const KEEPALIVE_PERIOD: u64 = 15; // seconds +const MAX_PIPELINED_MESSAGES: usize = 16; + pub struct HttpChannel { router: Rc, #[allow(dead_code)] @@ -81,6 +88,14 @@ pub struct HttpChannel { error: bool, items: VecDeque, inactive: Vec, + keepalive: bool, + keepalive_timer: Option, +} + +impl Drop for HttpChannel { + fn drop(&mut self) { + println!("Drop http channel"); + } } impl Actor for HttpChannel { @@ -92,6 +107,16 @@ impl Future for HttpChannel { type Error = (); fn poll(&mut self) -> Poll { + // keep-alive timer + if let Some(ref mut timeout) = self.keepalive_timer { + match timeout.poll() { + Ok(Async::Ready(_)) => + return Ok(Async::Ready(())), + Ok(Async::NotReady) => (), + Err(_) => unreachable!(), + } + } + loop { // check in-flight messages let mut idx = 0; @@ -109,10 +134,20 @@ impl Future for HttpChannel { { Ok(Async::Ready(val)) => { let mut item = self.items.pop_front().unwrap(); + + // overide keep-alive state + if self.keepalive { + self.keepalive = item.task.keepalive(); + } if !val { item.eof = true; self.inactive.push(item); } + + // no keep-alive + if !self.keepalive && self.items.is_empty() { + return Ok(Async::Ready(())) + } continue }, Ok(Async::NotReady) => (), @@ -134,15 +169,14 @@ impl Future for HttpChannel { idx += 1; } - // check for parse error - if self.items.is_empty() && self.error { - - } - // read incoming data - if !self.error { + if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES { match self.reader.parse(&mut self.stream) { Ok(Async::Ready((req, payload))) => { + // stop keepalive timer + self.keepalive_timer.take(); + + // start request processing let info = RequestInfo::new(&req); self.items.push_back( Entry {task: self.router.call(req, payload), @@ -151,16 +185,51 @@ impl Future for HttpChannel { error: false, finished: false}); } - Ok(Async::NotReady) => - return Ok(Async::NotReady), - Err(err) => return Err(()) - //self.items.push_back( - // Entry {task: Task::reply(err), - // eof: false, - // error: false, - // finished: false}) + Err(err) => { + // kill keepalive + self.keepalive = false; + self.keepalive_timer.take(); + + // on parse error, stop reading stream but + // complete tasks + self.error = true; + + if let ReaderError::Error(err) = err { + self.items.push_back( + Entry {task: Task::reply(err), + req: RequestInfo::for_error(), + eof: false, + error: false, + finished: false}); + } + } + Ok(Async::NotReady) => { + // start keep-alive timer, this is also slow request timeout + if self.items.is_empty() { + if self.keepalive { + if self.keepalive_timer.is_none() { + trace!("Start keep-alive timer"); + let mut timeout = Timeout::new( + Duration::new(KEEPALIVE_PERIOD, 0), + Arbiter::handle()).unwrap(); + // register timeout + let _ = timeout.poll(); + self.keepalive_timer = Some(timeout); + } + } else { + // keep-alive disable, drop connection + return Ok(Async::Ready(())) + } + } + return Ok(Async::NotReady) + } } } + + // check for parse error + if self.items.is_empty() && self.error { + return Ok(Async::Ready(())) + } } } } diff --git a/src/task.rs b/src/task.rs index 846f01da..d2291428 100644 --- a/src/task.rs +++ b/src/task.rs @@ -56,6 +56,12 @@ impl RequestInfo { keep_alive: req.keep_alive(), } } + pub fn for_error() -> Self { + RequestInfo { + version: Version::HTTP_11, + keep_alive: false, + } + } } pub struct Task { @@ -65,7 +71,8 @@ pub struct Task { stream: Option>, encoder: Encoder, buffer: BytesMut, - upgraded: bool, + upgrade: bool, + keepalive: bool, } impl Task { @@ -82,7 +89,8 @@ impl Task { stream: None, encoder: Encoder::length(0), buffer: BytesMut::new(), - upgraded: false, + upgrade: false, + keepalive: false, } } @@ -96,10 +104,15 @@ impl Task { stream: Some(Box::new(stream)), encoder: Encoder::length(0), buffer: BytesMut::new(), - upgraded: false, + upgrade: false, + keepalive: false, } } + pub(crate) fn keepalive(&self) -> bool { + self.keepalive && !self.upgrade + } + fn prepare(&mut self, req: &RequestInfo, mut msg: HttpResponse) { trace!("Prepare message status={:?}", msg.status); @@ -107,6 +120,7 @@ impl Task { let mut extra = 0; let body = msg.replace_body(Body::Empty); let version = msg.version().unwrap_or_else(|| req.version); + self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive); match body { Body::Empty => { @@ -158,7 +172,7 @@ impl Task { msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); } // keep-alive - else if msg.keep_alive().unwrap_or_else(|| req.keep_alive) { + else if self.keepalive { if version < Version::HTTP_11 { msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive")); } @@ -296,8 +310,8 @@ impl Future for Task { error!("Non expected frame {:?}", frame); return Err(()) } - self.upgraded = msg.upgrade(); - if self.upgraded || msg.body().has_body() { + self.upgrade = msg.upgrade(); + if self.upgrade || msg.body().has_body() { self.iostate = TaskIOState::ReadingPayload; } else { self.iostate = TaskIOState::Done; diff --git a/src/ws.rs b/src/ws.rs index f125865e..5698e496 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -27,14 +27,14 @@ //! match ws::handshake(&req) { //! Ok(resp) => { //! // Send handshake response to peer -//! ctx.start(req, resp); +//! ctx.start(resp); //! // Map Payload into WsStream //! ctx.add_stream(ws::WsStream::new(payload)); //! // Start ws messages processing //! Reply::stream(WsRoute) //! }, //! Err(err) => -//! Reply::reply(req, err) +//! Reply::reply(err) //! } //! } //! } @@ -172,11 +172,13 @@ pub fn handshake(req: &HttpRequest) -> Result { pub struct WsStream { rx: Payload, buf: BytesMut, + closed: bool, + error_sent: bool, } impl WsStream { pub fn new(rx: Payload) -> WsStream { - WsStream { rx: rx, buf: BytesMut::new() } + WsStream { rx: rx, buf: BytesMut::new(), closed: false, error_sent: false } } } @@ -187,15 +189,20 @@ impl Stream for WsStream { fn poll(&mut self) -> Poll, Self::Error> { let mut done = false; - loop { - match self.rx.readany() { - Async::Ready(Some(chunk)) => { - self.buf.extend(chunk) + if !self.closed { + loop { + match self.rx.readany() { + Async::Ready(Some(Ok(chunk))) => { + self.buf.extend(chunk) + } + Async::Ready(Some(Err(_))) => { + self.closed = true; + } + Async::Ready(None) => { + done = true; + } + Async::NotReady => break, } - Async::Ready(None) => { - done = true; - } - Async::NotReady => break, } } @@ -229,13 +236,25 @@ impl Stream for WsStream { } } } - Ok(None) => if done { - return Ok(Async::Ready(None)) - } else { - return Ok(Async::NotReady) + Ok(None) => { + if done { + return Ok(Async::Ready(None)) + } else if self.closed { + if !self.error_sent { + self.error_sent = true; + return Ok(Async::Ready(Some(Message::Closed))) + } else { + return Ok(Async::Ready(None)) + } + } else { + return Ok(Async::NotReady) + } }, - Err(_) => - return Err(()), + Err(_) => { + self.closed = true; + self.error_sent = true; + return Ok(Async::Ready(Some(Message::Error))); + } } } }