use std::collections::VecDeque; use std::fmt::Debug; use std::time::Instant; use actix_net::codec::Framed; use actix_net::service::Service; use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::Delay; use error::{ParseError, PayloadError}; use payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; use body::{Body, BodyStream}; use config::ServiceConfig; use error::DispatchError; use request::Request; use response::Response; use super::codec::Codec; use super::{H1ServiceResult, Message, MessageType}; const MAX_PIPELINED_MESSAGES: usize = 16; bitflags! { pub struct Flags: u8 { const STARTED = 0b0000_0001; const KEEPALIVE_ENABLED = 0b0000_0010; const KEEPALIVE = 0b0000_0100; const POLLED = 0b0000_1000; const FLUSHED = 0b0001_0000; const SHUTDOWN = 0b0010_0000; const DISCONNECTED = 0b0100_0000; } } /// Dispatcher for HTTP/1.1 protocol pub struct Dispatcher where S::Error: Debug, { service: S, flags: Flags, framed: Option>, error: Option>, config: ServiceConfig, state: State, payload: Option, messages: VecDeque, unhandled: Option, ka_expire: Instant, ka_timer: Option, } enum DispatcherMessage { Item(Request), Error(Response), } enum State { None, ServiceCall(S::Future), SendResponse(Option<(Message, Body)>), SendPayload(BodyStream), } impl State { fn is_empty(&self) -> bool { if let State::None = self { true } else { false } } } impl Dispatcher where T: AsyncRead + AsyncWrite, S: Service, S::Error: Debug, { /// Create http/1 dispatcher. pub fn new(stream: T, config: ServiceConfig, service: S) -> Self { Dispatcher::with_timeout(stream, config, None, service) } /// Create http/1 dispatcher with slow request timeout. pub fn with_timeout( stream: T, config: ServiceConfig, timeout: Option, service: S, ) -> Self { let keepalive = config.keep_alive_enabled(); let flags = if keepalive { Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED } else { Flags::FLUSHED }; let framed = Framed::new(stream, Codec::new(config.clone())); // keep-alive timer let (ka_expire, ka_timer) = if let Some(delay) = timeout { (delay.deadline(), Some(delay)) } else if let Some(delay) = config.keep_alive_timer() { (delay.deadline(), Some(delay)) } else { (config.now(), None) }; Dispatcher { payload: None, state: State::None, error: None, messages: VecDeque::new(), framed: Some(framed), unhandled: None, service, flags, config, ka_expire, ka_timer, } } fn can_read(&self) -> bool { if self.flags.contains(Flags::DISCONNECTED) { return false; } if let Some(ref info) = self.payload { info.need_read() == PayloadStatus::Read } else { true } } // if checked is set to true, delay disconnect until all tasks have finished. fn client_disconnected(&mut self) { self.flags.insert(Flags::DISCONNECTED); if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::Incomplete(None)); } } /// Flush stream fn poll_flush(&mut self) -> Poll<(), DispatchError> { if !self.flags.contains(Flags::FLUSHED) { match self.framed.as_mut().unwrap().poll_complete() { Ok(Async::NotReady) => Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); Err(err.into()) } Ok(Async::Ready(_)) => { // if payload is not consumed we can not use connection if self.payload.is_some() && self.state.is_empty() { return Err(DispatchError::PayloadIsNotConsumed); } self.flags.insert(Flags::FLUSHED); Ok(Async::Ready(())) } } } else { Ok(Async::Ready(())) } } fn poll_response(&mut self) -> Result<(), DispatchError> { let mut retry = self.can_read(); // process loop { let state = match self.state { State::None => if let Some(msg) = self.messages.pop_front() { match msg { DispatcherMessage::Item(req) => Some(self.handle_request(req)?), DispatcherMessage::Error(res) => Some(State::SendResponse( Some((Message::Item(res), Body::Empty)), )), } } else { None }, // call inner service State::ServiceCall(ref mut fut) => { match fut.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { self.framed .as_mut() .unwrap() .get_codec_mut() .prepare_te(&mut res); let body = res.replace_body(Body::Empty); Some(State::SendResponse(Some((Message::Item(res), body)))) } Async::NotReady => None, } } // send respons State::SendResponse(ref mut item) => { let (msg, body) = item.take().expect("SendResponse is empty"); let framed = self.framed.as_mut().unwrap(); match framed.start_send(msg) { Ok(AsyncSink::Ready) => { self.flags .set(Flags::KEEPALIVE, framed.get_codec().keepalive()); self.flags.remove(Flags::FLUSHED); match body { Body::Empty => Some(State::None), Body::Streaming(stream) => { Some(State::SendPayload(stream)) } Body::Binary(mut bin) => { self.flags.remove(Flags::FLUSHED); framed .force_send(Message::Chunk(Some(bin.take())))?; framed.force_send(Message::Chunk(None))?; Some(State::None) } } } Ok(AsyncSink::NotReady(msg)) => { *item = Some((msg, body)); return Ok(()); } Err(err) => { if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::Incomplete(None)); } return Err(DispatchError::Io(err)); } } } // Send payload State::SendPayload(ref mut stream) => { let mut framed = self.framed.as_mut().unwrap(); loop { if !framed.is_write_buf_full() { match stream.poll().map_err(|_| DispatchError::Unknown)? { Async::Ready(Some(item)) => { self.flags.remove(Flags::FLUSHED); framed.force_send(Message::Chunk(Some(item)))?; continue; } Async::Ready(None) => { self.flags.remove(Flags::FLUSHED); framed.force_send(Message::Chunk(None))?; } Async::NotReady => return Ok(()), } } else { return Ok(()); } break; } None } }; match state { Some(state) => self.state = state, None => { // if read-backpressure is enabled and we consumed some data. // we may read more dataand retry if !retry && self.can_read() && self.poll_request()? { retry = self.can_read(); continue; } break; } } } Ok(()) } fn handle_request( &mut self, req: Request, ) -> Result, DispatchError> { let mut task = self.service.call(req); match task.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { self.framed .as_mut() .unwrap() .get_codec_mut() .prepare_te(&mut res); let body = res.replace_body(Body::Empty); Ok(State::SendResponse(Some((Message::Item(res), body)))) } Async::NotReady => Ok(State::ServiceCall(task)), } } /// Process one incoming requests pub(self) fn poll_request(&mut self) -> Result> { // limit a mount of non processed requests if self.messages.len() >= MAX_PIPELINED_MESSAGES { return Ok(false); } let mut updated = false; loop { match self.framed.as_mut().unwrap().poll() { Ok(Async::Ready(Some(msg))) => { updated = true; self.flags.insert(Flags::STARTED); match msg { Message::Item(req) => { match self .framed .as_ref() .unwrap() .get_codec() .message_type() { MessageType::Payload => { let (ps, pl) = Payload::new(false); *req.inner.payload.borrow_mut() = Some(pl); self.payload = Some(ps); } MessageType::Stream => { self.unhandled = Some(req); return Ok(updated); } _ => (), } // handle request early if self.state.is_empty() { self.state = self.handle_request(req)?; } else { self.messages.push_back(DispatcherMessage::Item(req)); } } Message::Chunk(Some(chunk)) => { if let Some(ref mut payload) = self.payload { payload.feed_data(chunk); } else { error!( "Internal server error: unexpected payload chunk" ); self.flags.insert(Flags::DISCONNECTED); self.messages.push_back(DispatcherMessage::Error( Response::InternalServerError().finish(), )); self.error = Some(DispatchError::InternalError); break; } } Message::Chunk(None) => { if let Some(mut payload) = self.payload.take() { payload.feed_eof(); } else { error!("Internal server error: unexpected eof"); self.flags.insert(Flags::DISCONNECTED); self.messages.push_back(DispatcherMessage::Error( Response::InternalServerError().finish(), )); self.error = Some(DispatchError::InternalError); break; } } } } Ok(Async::Ready(None)) => { self.client_disconnected(); break; } Ok(Async::NotReady) => break, Err(ParseError::Io(e)) => { self.client_disconnected(); self.error = Some(DispatchError::Io(e)); break; } Err(e) => { if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::EncodingCorrupted); } // Malformed requests should be responded with 400 self.messages.push_back(DispatcherMessage::Error( Response::BadRequest().finish(), )); self.flags.insert(Flags::DISCONNECTED); self.error = Some(e.into()); break; } } } if self.ka_timer.is_some() && updated { if let Some(expire) = self.config.keep_alive_expire() { self.ka_expire = expire; } } Ok(updated) } /// keep-alive timer fn poll_keepalive(&mut self) -> Result<(), DispatchError> { if let Some(ref mut timer) = self.ka_timer { match timer.poll() { Ok(Async::Ready(_)) => { // if we get timer during shutdown, just drop connection if self.flags.contains(Flags::SHUTDOWN) { return Err(DispatchError::DisconnectTimeout); } else if timer.deadline() >= self.ka_expire { // check for any outstanding response processing if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { if self.flags.contains(Flags::STARTED) { trace!("Keep-alive timeout, close connection"); self.flags.insert(Flags::SHUTDOWN); // start shutdown timer if let Some(deadline) = self.config.client_disconnect_timer() { timer.reset(deadline); let _ = timer.poll(); } else { return Ok(()); } } else { // timeout on first request (slow request) return 408 trace!("Slow request timeout"); self.flags.insert(Flags::STARTED | Flags::DISCONNECTED); self.state = State::SendResponse(Some(( Message::Item(Response::RequestTimeout().finish()), Body::Empty, ))); } } else if let Some(deadline) = self.config.keep_alive_expire() { timer.reset(deadline); let _ = timer.poll(); } } else { timer.reset(self.ka_expire); let _ = timer.poll(); } } Ok(Async::NotReady) => (), Err(e) => { error!("Timer error {:?}", e); return Err(DispatchError::Unknown); } } } Ok(()) } } impl Future for Dispatcher where T: AsyncRead + AsyncWrite, S: Service, S::Error: Debug, { type Item = H1ServiceResult; type Error = DispatchError; #[inline] fn poll(&mut self) -> Poll { if self.flags.contains(Flags::SHUTDOWN) { self.poll_keepalive()?; try_ready!(self.poll_flush()); let io = self.framed.take().unwrap().into_inner(); Ok(Async::Ready(H1ServiceResult::Shutdown(io))) } else { self.poll_keepalive()?; self.poll_request()?; self.poll_response()?; self.poll_flush()?; // keep-alive and stream errors if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { if let Some(err) = self.error.take() { Err(err) } else if self.flags.contains(Flags::DISCONNECTED) { Ok(Async::Ready(H1ServiceResult::Disconnected)) } // unhandled request (upgrade or connect) else if self.unhandled.is_some() { let req = self.unhandled.take().unwrap(); let framed = self.framed.take().unwrap(); Ok(Async::Ready(H1ServiceResult::Unhandled(req, framed))) } // disconnect if keep-alive is not enabled else if self.flags.contains(Flags::STARTED) && !self .flags .intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED) { let io = self.framed.take().unwrap().into_inner(); Ok(Async::Ready(H1ServiceResult::Shutdown(io))) } else { Ok(Async::NotReady) } } else { Ok(Async::NotReady) } } } }