From 6e138bf373349496d617b5e7a9031df103183a41 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 30 Nov 2017 14:42:20 -0800 Subject: [PATCH] refactor streaming responses --- CHANGES.md | 2 + examples/basic.rs | 15 +- examples/websocket.rs | 2 +- guide/src/qs_3.md | 4 +- src/body.rs | 66 ++++-- src/context.rs | 63 +++--- src/encoding.rs | 38 ++-- src/httpresponse.rs | 14 -- src/lib.rs | 2 +- src/pipeline.rs | 8 +- src/resource.rs | 7 +- src/route.rs | 43 ++-- src/task.rs | 465 +++++++++++++++++++++++++++--------------- src/ws.rs | 2 +- 14 files changed, 436 insertions(+), 295 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 001367014..836b16b3c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,8 @@ * HTTP/2 Support +* Refactor streaming responses + * Refactor error handling * Asynchronous middlewares diff --git a/examples/basic.rs b/examples/basic.rs index 1b850c7b8..932d4fece 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -8,7 +8,7 @@ extern crate futures; use actix_web::*; use actix_web::middlewares::RequestSession; -use futures::stream::{once, Once}; +use futures::future::{FutureResult, result}; /// simple handler fn index(mut req: HttpRequest) -> Result { @@ -31,15 +31,14 @@ fn index(mut req: HttpRequest) -> Result { } /// async handler -fn index_async(req: HttpRequest) -> Once +fn index_async(req: HttpRequest) -> FutureResult { println!("{:?}", req); - once(Ok(HttpResponse::Ok() - .content_type("text/html") - .body(format!("Hello {}!", req.match_info().get("name").unwrap())) - .unwrap() - .into())) + result(HttpResponse::Ok() + .content_type("text/html") + .body(format!("Hello {}!", req.match_info().get("name").unwrap())) + .map_err(|e| e.into())) } /// handler with path parameters like `/user/{name}/` @@ -69,7 +68,7 @@ fn main() { )) // register simple handle r, handle all methods .handler("/index.html", index) - // with path parameters + // with path parameters .resource("/user/{name}/", |r| r.handler(Method::GET, with_param)) // async handler .resource("/async/{name}", |r| r.async(Method::GET, index_async)) diff --git a/examples/websocket.rs b/examples/websocket.rs index 31d439895..24a88d03b 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -56,7 +56,7 @@ impl Handler for MyWebSocket { } fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); + ::std::env::set_var("RUST_LOG", "actix_web=trace"); let _ = env_logger::init(); let sys = actix::System::new("ws-example"); diff --git a/guide/src/qs_3.md b/guide/src/qs_3.md index 89bef34c0..a34a46848 100644 --- a/guide/src/qs_3.md +++ b/guide/src/qs_3.md @@ -1,4 +1,4 @@ -# Overview +# [WIP] Overview Actix web provides some primitives to build web servers and applications with Rust. It provides routing, middlewares, pre-processing of requests, and post-processing of responses, @@ -69,7 +69,7 @@ fn main() { } ``` -## Handler +## [WIP] Handler A request handler can have different forms. diff --git a/src/body.rs b/src/body.rs index 56b1411b0..73bd8920c 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,24 +1,29 @@ +use std::fmt; use std::rc::Rc; use std::sync::Arc; use bytes::{Bytes, BytesMut}; +use futures::Stream; -use route::Frame; +use error::Error; + +pub(crate) type BodyStream = Box>; /// Represents various types of http message body. -#[derive(Debug, PartialEq)] pub enum Body { /// Empty response. `Content-Length` header is set to `0` Empty, /// Specific response body. Binary(Binary), - /// Streaming response body with specified length. - Length(u64), /// Unspecified streaming response. Developer is responsible for setting /// right `Content-Length` or `Transfer-Encoding` headers. - Streaming, + Streaming(BodyStream), /// Upgrade connection. - Upgrade, + Upgrade(BodyStream), + /// Special body type for actor streaming response. + StreamingContext, + /// Special body type for actor upgrade response. + UpgradeContext, } /// Represents various types of binary body. @@ -45,7 +50,8 @@ impl Body { /// Does this body streaming. pub fn is_streaming(&self) -> bool { match *self { - Body::Length(_) | Body::Streaming | Body::Upgrade => true, + Body::Streaming(_) | Body::StreamingContext + | Body::Upgrade(_) | Body::UpgradeContext => true, _ => false } } @@ -64,6 +70,43 @@ impl Body { } } +impl PartialEq for Body { + fn eq(&self, other: &Body) -> bool { + match *self { + Body::Empty => match *other { + Body::Empty => true, + _ => false, + }, + Body::Binary(ref b) => match *other { + Body::Binary(ref b2) => b == b2, + _ => false, + }, + Body::StreamingContext => match *other { + Body::StreamingContext => true, + _ => false, + }, + Body::UpgradeContext => match *other { + Body::UpgradeContext => true, + _ => false, + }, + Body::Streaming(_) | Body::Upgrade(_) => false, + } + } +} + +impl fmt::Debug for Body { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Body::Empty => write!(f, "Body::Empty"), + Body::Binary(ref b) => write!(f, "Body::Binary({:?})", b), + Body::Streaming(_) => write!(f, "Body::Streaming(_)"), + Body::Upgrade(_) => write!(f, "Body::Upgrade(_)"), + Body::StreamingContext => write!(f, "Body::StreamingContext"), + Body::UpgradeContext => write!(f, "Body::UpgradeContext"), + } + } +} + impl From for Body where T: Into{ fn from(b: T) -> Body { Body::Binary(b.into()) @@ -195,12 +238,6 @@ impl AsRef<[u8]> for Binary { } } -impl From for Frame { - fn from(b: Binary) -> Frame { - Frame::Payload(Some(b)) - } -} - #[cfg(test)] mod tests { use super::*; @@ -209,8 +246,7 @@ mod tests { fn test_body_is_streaming() { assert_eq!(Body::Empty.is_streaming(), false); assert_eq!(Body::Binary(Binary::from("")).is_streaming(), false); - assert_eq!(Body::Length(100).is_streaming(), true); - assert_eq!(Body::Streaming.is_streaming(), true); + // assert_eq!(Body::Streaming.is_streaming(), true); } #[test] diff --git a/src/context.rs b/src/context.rs index 8623c7137..ae6211d9e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use std::cell::RefCell; use std::collections::VecDeque; use std::marker::PhantomData; -use futures::{Async, Future, Stream, Poll}; +use futures::{Async, Future, Poll}; use futures::sync::oneshot::Sender; use actix::{Actor, ActorState, ActorContext, AsyncContext, @@ -13,13 +13,19 @@ use actix::dev::{AsyncContextApi, ActorAddressCell, ActorItemsCell, ActorWaitCel Envelope, ToEnvelope, RemoteEnvelope}; use task::{IoContext, DrainFut}; -use body::Binary; +use body::{Body, Binary}; use error::Error; -use route::Frame; use httprequest::HttpRequest; use httpresponse::HttpResponse; +#[derive(Debug)] +pub(crate) enum Frame { + Message(HttpResponse), + Payload(Option), + Drain(Rc>), +} + /// Http actor execution context pub struct HttpContext where A: Actor>, { @@ -31,25 +37,14 @@ pub struct HttpContext where A: Actor>, stream: VecDeque, wait: ActorWaitCell, request: HttpRequest, + streaming: bool, disconnected: bool, } -impl IoContext for HttpContext where A: Actor, S: 'static { - - fn disconnected(&mut self) { - self.items.stop(); - self.disconnected = true; - if self.state == ActorState::Running { - self.state = ActorState::Stopping; - } - } -} - impl ActorContext for HttpContext where A: Actor { /// Stop actor execution fn stop(&mut self) { - self.stream.push_back(Frame::Payload(None)); self.items.stop(); self.address.close(); if self.state == ActorState::Running { @@ -116,6 +111,7 @@ impl HttpContext where A: Actor { wait: ActorWaitCell::default(), stream: VecDeque::new(), request: req, + streaming: false, disconnected: false, } } @@ -133,20 +129,25 @@ impl HttpContext where A: Actor { &mut self.request } - - /// Start response processing + /// Send response to peer pub fn start>(&mut self, response: R) { - self.stream.push_back(Frame::Message(response.into())) + let resp = response.into(); + match *resp.body() { + Body::StreamingContext | Body::UpgradeContext => self.streaming = true, + _ => (), + } + self.stream.push_back(Frame::Message(resp)) } /// Write payload pub fn write>(&mut self, data: B) { - self.stream.push_back(Frame::Payload(Some(data.into()))) - } - - /// Indicate end of streamimng payload. Also this method calls `Self::close`. - pub fn write_eof(&mut self) { - self.stop(); + if self.streaming { + if !self.disconnected { + self.stream.push_back(Frame::Payload(Some(data.into()))) + } + } else { + warn!("Trying to write response body for non-streaming response"); + } } /// Returns drain future @@ -184,11 +185,15 @@ impl HttpContext where A: Actor { } } -#[doc(hidden)] -impl Stream for HttpContext where A: Actor -{ - type Item = Frame; - type Error = Error; +impl IoContext for HttpContext where A: Actor, S: 'static { + + fn disconnected(&mut self) { + self.items.stop(); + self.disconnected = true; + if self.state == ActorState::Running { + self.state = ActorState::Stopping; + } + } fn poll(&mut self) -> Poll, Error> { let act: &mut A = unsafe { diff --git a/src/encoding.rs b/src/encoding.rs index c6fffa8bc..4774b4c7a 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -381,28 +381,6 @@ impl PayloadEncoder { resp.headers.remove(TRANSFER_ENCODING); TransferEncoding::length(0) }, - Body::Length(n) => { - if resp.chunked() { - error!("Chunked transfer is enabled but body with specific length is specified"); - } - if compression { - resp.headers.remove(CONTENT_LENGTH); - if version == Version::HTTP_2 { - resp.headers.remove(TRANSFER_ENCODING); - TransferEncoding::eof() - } else { - resp.headers.insert( - TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked() - } - } else { - resp.headers.insert( - CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); - resp.headers.remove(TRANSFER_ENCODING); - TransferEncoding::length(n) - } - }, Body::Binary(ref mut bytes) => { if compression { let transfer = TransferEncoding::eof(); @@ -435,7 +413,7 @@ impl PayloadEncoder { TransferEncoding::length(bytes.len() as u64) } } - Body::Streaming => { + Body::Streaming(_) | Body::StreamingContext => { if resp.chunked() { resp.headers.remove(CONTENT_LENGTH); if version != Version::HTTP_11 { @@ -449,11 +427,23 @@ impl PayloadEncoder { TRANSFER_ENCODING, HeaderValue::from_static("chunked")); TransferEncoding::chunked() } + } else if let Some(len) = resp.headers().get(CONTENT_LENGTH) { + // Content-Length + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + TransferEncoding::length(len) + } else { + debug!("illegal Content-Length: {:?}", len); + TransferEncoding::eof() + } + } else { + TransferEncoding::eof() + } } else { TransferEncoding::eof() } } - Body::Upgrade => { + Body::Upgrade(_) | Body::UpgradeContext => { if version == Version::HTTP_2 { error!("Connection upgrade is forbidden for HTTP/2"); } else { diff --git a/src/httpresponse.rs b/src/httpresponse.rs index a25b59d9e..75f45c844 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -11,7 +11,6 @@ use serde::Serialize; use Cookie; use body::Body; -use route::Frame; use error::Error; use encoding::ContentEncoding; @@ -223,13 +222,6 @@ impl fmt::Debug for HttpResponse { } } -// TODO: remove -impl From for Frame { - fn from(resp: HttpResponse) -> Frame { - Frame::Message(resp) - } -} - #[derive(Debug)] struct Parts { version: Option, @@ -535,12 +527,6 @@ mod tests { assert_eq!(resp.status(), StatusCode::NO_CONTENT) } - #[test] - fn test_body() { - assert!(Body::Length(10).is_streaming()); - assert!(Body::Streaming.is_streaming()); - } - #[test] fn test_upgrade() { let resp = HttpResponse::build(StatusCode::OK) diff --git a/src/lib.rs b/src/lib.rs index c94e67897..53b30a374 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,7 +82,7 @@ pub use application::Application; pub use httprequest::{HttpRequest, UrlEncoded}; pub use httpresponse::HttpResponse; pub use payload::{Payload, PayloadItem}; -pub use route::{Frame, Reply}; +pub use route::Reply; pub use resource::Resource; pub use recognizer::Params; pub use server::HttpServer; diff --git a/src/pipeline.rs b/src/pipeline.rs index de230aa98..3b6eb7d4c 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -391,11 +391,7 @@ impl MiddlewaresResponse { } } - pub fn poll(&mut self, req: &mut HttpRequest) -> Poll, Error> { - if self.fut.is_none() { - return Ok(Async::Ready(None)) - } - + pub fn poll(&mut self, req: &mut HttpRequest) -> Poll { loop { // poll latest fut let mut resp = match self.fut.as_mut().unwrap().poll() { @@ -410,7 +406,7 @@ impl MiddlewaresResponse { loop { if self.idx == 0 { - return Ok(Async::Ready(Some(resp))) + return Ok(Async::Ready(resp)) } else { match self.middlewares[self.idx].response(req, resp) { Response::Err(err) => diff --git a/src/resource.rs b/src/resource.rs index 4fcd99227..c9dcdef7d 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -2,12 +2,13 @@ use std::marker::PhantomData; use std::collections::HashMap; use http::Method; -use futures::Stream; +use futures::Future; use task::Task; use error::Error; -use route::{Reply, RouteHandler, Frame, WrapHandler, Handler, StreamHandler}; +use route::{Reply, RouteHandler, WrapHandler, Handler, StreamHandler}; use httprequest::HttpRequest; +use httpresponse::HttpResponse; use httpcodes::{HTTPNotFound, HTTPMethodNotAllowed}; /// Http resource @@ -68,7 +69,7 @@ impl Resource where S: 'static { /// Register async handler for specified method. pub fn async(&mut self, method: Method, handler: F) where F: Fn(HttpRequest) -> R + 'static, - R: Stream + 'static, + R: Future + 'static, { self.routes.insert(method, Box::new(StreamHandler::new(handler))); } diff --git a/src/route.rs b/src/route.rs index d23376890..bd89f8b1d 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,31 +1,14 @@ -use std::rc::Rc; -use std::cell::RefCell; use std::marker::PhantomData; use std::result::Result as StdResult; use actix::Actor; -use futures::Stream; +use futures::Future; -use body::Binary; use error::Error; use context::HttpContext; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use task::{Task, DrainFut, IoContext}; - -#[doc(hidden)] -#[derive(Debug)] -pub enum Frame { - Message(HttpResponse), - Payload(Option), - Drain(Rc>), -} - -impl Frame { - pub fn eof() -> Frame { - Frame::Payload(None) - } -} +use task::{Task, IoContext}; /// Trait defines object that could be regestered as route handler #[allow(unused_variables)] @@ -55,8 +38,8 @@ pub struct Reply(ReplyItem); enum ReplyItem { Message(HttpResponse), - Actor(Box>), - Stream(Box>), + Actor(Box), + Future(Box>), } impl Reply { @@ -69,10 +52,10 @@ impl Reply { } /// Create async response - pub fn stream(stream: S) -> Reply - where S: Stream + 'static + pub fn async(fut: F) -> Reply + where F: Future + 'static { - Reply(ReplyItem::Stream(Box::new(stream))) + Reply(ReplyItem::Future(Box::new(fut))) } /// Send response @@ -89,8 +72,8 @@ impl Reply { ReplyItem::Actor(ctx) => { task.context(ctx) } - ReplyItem::Stream(stream) => { - task.stream(stream) + ReplyItem::Future(fut) => { + task.async(fut) } } } @@ -160,7 +143,7 @@ impl RouteHandler for WrapHandler pub(crate) struct StreamHandler where F: Fn(HttpRequest) -> R + 'static, - R: Stream + 'static, + R: Future + 'static, S: 'static, { f: Box, @@ -169,7 +152,7 @@ struct StreamHandler impl StreamHandler where F: Fn(HttpRequest) -> R + 'static, - R: Stream + 'static, + R: Future + 'static, S: 'static, { pub fn new(f: F) -> Self { @@ -179,10 +162,10 @@ impl StreamHandler impl RouteHandler for StreamHandler where F: Fn(HttpRequest) -> R + 'static, - R: Stream + 'static, + R: Future + 'static, S: 'static, { fn handle(&self, req: HttpRequest, task: &mut Task) { - task.stream((self.f)(req)) + task.async((self.f)(req)) } } diff --git a/src/task.rs b/src/task.rs index 2d0545a77..009ec52e4 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,20 +1,18 @@ -use std::mem; +use std::{fmt, mem}; use std::rc::Rc; use std::cell::RefCell; -use std::collections::VecDeque; -use futures::{Async, Future, Poll, Stream}; +use futures::{Async, Future, Poll}; use futures::task::{Task as FutureTask, current as current_task}; +use body::{Body, BodyStream, Binary}; +use context::Frame; use h1writer::{Writer, WriterState}; use error::{Error, UnexpectedTaskFrame}; -use route::Frame; use pipeline::MiddlewaresResponse; use httprequest::HttpRequest; use httpresponse::HttpResponse; -type FrameStream = Stream; - #[derive(PartialEq, Debug)] enum TaskRunningState { Paused, @@ -38,27 +36,69 @@ impl TaskRunningState { } } -#[derive(PartialEq, Debug)] -enum TaskIOState { - ReadingMessage, - ReadingPayload, - Done, +enum ResponseState { + Reading, + Ready(HttpResponse), + Middlewares(MiddlewaresResponse), + Prepared(Option), } -impl TaskIOState { - fn is_done(&self) -> bool { - *self == TaskIOState::Done - } +enum IOState { + Response, + Payload(BodyStream), + Context, + Done, } enum TaskStream { None, - Stream(Box), - Context(Box>), + Context(Box), + Response(Box>), } -pub(crate) trait IoContext: Stream + 'static { +impl IOState { + fn is_done(&self) -> bool { + match *self { + IOState::Done => true, + _ => false + } + } +} + +impl ResponseState { + fn is_reading(&self) -> bool { + match *self { + ResponseState::Reading => true, + _ => false + } + } +} + +impl fmt::Debug for ResponseState { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ResponseState::Reading => write!(f, "ResponseState::Reading"), + ResponseState::Ready(_) => write!(f, "ResponseState::Ready"), + ResponseState::Middlewares(_) => write!(f, "ResponseState::Middlewares"), + ResponseState::Prepared(_) => write!(f, "ResponseState::Prepared"), + } + } +} + +impl fmt::Debug for IOState { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + IOState::Response => write!(f, "IOState::Response"), + IOState::Payload(_) => write!(f, "IOState::Payload"), + IOState::Context => write!(f, "IOState::Context"), + IOState::Done => write!(f, "IOState::Done"), + } + } +} + +pub(crate) trait IoContext: 'static { fn disconnected(&mut self); + fn poll(&mut self) -> Poll, Error>; } /// Future that resolves when all buffered data get sent @@ -104,12 +144,11 @@ impl Future for DrainFut { } pub struct Task { - state: TaskRunningState, - iostate: TaskIOState, - frames: VecDeque, + running: TaskRunningState, + response: ResponseState, + iostate: IOState, stream: TaskStream, drain: Vec>>, - prepared: Option, disconnected: bool, middlewares: Option, } @@ -118,12 +157,11 @@ pub struct Task { impl Default for Task { fn default() -> Task { - Task { state: TaskRunningState::Running, - iostate: TaskIOState::ReadingMessage, - frames: VecDeque::new(), + Task { running: TaskRunningState::Running, + response: ResponseState::Reading, + iostate: IOState::Response, drain: Vec::new(), stream: TaskStream::None, - prepared: None, disconnected: false, middlewares: None } } @@ -132,16 +170,11 @@ impl Default for Task { impl Task { pub(crate) fn from_response>(response: R) -> Task { - let mut frames = VecDeque::new(); - frames.push_back(Frame::Message(response.into())); - frames.push_back(Frame::Payload(None)); - - Task { state: TaskRunningState::Running, - iostate: TaskIOState::Done, - frames: frames, + Task { running: TaskRunningState::Running, + response: ResponseState::Ready(response.into()), + iostate: IOState::Response, drain: Vec::new(), stream: TaskStream::None, - prepared: None, disconnected: false, middlewares: None } } @@ -151,27 +184,33 @@ impl Task { } pub fn reply>(&mut self, response: R) { - self.frames.push_back(Frame::Message(response.into())); - self.frames.push_back(Frame::Payload(None)); - self.iostate = TaskIOState::Done; + let state = &mut self.response; + match *state { + ResponseState::Reading => + *state = ResponseState::Ready(response.into()), + _ => panic!("Internal task state is broken"), + } } pub fn error>(&mut self, err: E) { self.reply(err.into()) } - pub(crate) fn context(&mut self, ctx: Box>) { + pub(crate) fn context(&mut self, ctx: Box) { self.stream = TaskStream::Context(ctx); } - pub fn stream(&mut self, stream: S) - where S: Stream + 'static + pub fn async(&mut self, fut: F) + where F: Future + 'static { - self.stream = TaskStream::Stream(Box::new(stream)); + self.stream = TaskStream::Response(Box::new(fut)); } pub(crate) fn response(&mut self) -> HttpResponse { - self.prepared.take().unwrap() + match self.response { + ResponseState::Prepared(ref mut state) => state.take().unwrap(), + _ => panic!("Internal state is broken"), + } } pub(crate) fn set_middlewares(&mut self, middlewares: MiddlewaresResponse) { @@ -188,97 +227,112 @@ impl Task { pub(crate) fn poll_io(&mut self, io: &mut T, req: &mut HttpRequest) -> Poll where T: Writer { - trace!("POLL-IO frames:{:?}", self.frames.len()); - - // response is completed - if self.frames.is_empty() && self.iostate.is_done() { - return Ok(Async::Ready(self.state.is_done())); - } else if self.drain.is_empty() { - // poll stream - if self.state == TaskRunningState::Running { - match self.poll()? { - Async::Ready(_) => - self.state = TaskRunningState::Done, - Async::NotReady => (), - } - } - - // process middlewares response - if let Some(mut middlewares) = self.middlewares.take() { - match middlewares.poll(req)? { - Async::NotReady => { - self.middlewares = Some(middlewares); - return Ok(Async::NotReady); - } - Async::Ready(None) => { - self.middlewares = Some(middlewares); - } - Async::Ready(Some(mut response)) => { - let result = io.start(req, &mut response)?; - self.prepared = Some(response); - match result { - WriterState::Pause => self.state.pause(), - WriterState::Done => self.state.resume(), - } - }, - } - } + trace!("POLL-IO frames resp: {:?}, io: {:?}, running: {:?}", + self.response, self.iostate, self.running); + if self.iostate.is_done() { // response is completed + return Ok(Async::Ready(self.running.is_done())); + } else if self.drain.is_empty() && self.running != TaskRunningState::Paused { // if task is paused, write buffer is probably full - if self.state != TaskRunningState::Paused { - // process exiting frames - while let Some(frame) = self.frames.pop_front() { - trace!("IO Frame: {:?}", frame); - let res = match frame { - Frame::Message(mut resp) => { - // run middlewares - if let Some(mut middlewares) = self.middlewares.take() { - match middlewares.response(req, resp) { - Ok(Some(mut resp)) => { - let result = io.start(req, &mut resp)?; - self.prepared = Some(resp); - result - } - Ok(None) => { - // middlewares need to run some futures - self.middlewares = Some(middlewares); - return self.poll_io(io, req) - } - Err(err) => return Err(err), - } - } else { + + loop { + let result = match mem::replace(&mut self.iostate, IOState::Done) { + IOState::Response => { + match self.poll_response(req) { + Ok(Async::Ready(mut resp)) => { let result = io.start(req, &mut resp)?; - self.prepared = Some(resp); + + match resp.replace_body(Body::Empty) { + Body::Streaming(stream) | Body::Upgrade(stream) => + self.iostate = IOState::Payload(stream), + Body::StreamingContext | Body::UpgradeContext => + self.iostate = IOState::Context, + _ => (), + } + self.response = ResponseState::Prepared(Some(resp)); + result + }, + Ok(Async::NotReady) => { + self.iostate = IOState::Response; + return Ok(Async::NotReady) + } + Err(err) => { + let mut resp = err.into(); + let result = io.start(req, &mut resp)?; + + match resp.replace_body(Body::Empty) { + Body::Streaming(stream) | Body::Upgrade(stream) => + self.iostate = IOState::Payload(stream), + _ => (), + } + self.response = ResponseState::Prepared(Some(resp)); result } } - Frame::Payload(Some(chunk)) => { - io.write(chunk.as_ref())? - }, - Frame::Payload(None) => { - self.iostate = TaskIOState::Done; - io.write_eof()? - }, - Frame::Drain(fut) => { - self.drain.push(fut); - break + }, + IOState::Payload(mut body) => { + // always poll stream + if self.running == TaskRunningState::Running { + match self.poll()? { + Async::Ready(_) => + self.running = TaskRunningState::Done, + Async::NotReady => (), + } } - }; - match res { - WriterState::Pause => { - self.state.pause(); - break + match body.poll() { + Ok(Async::Ready(None)) => { + self.iostate = IOState::Done; + io.write_eof()?; + break + }, + Ok(Async::Ready(Some(chunk))) => { + self.iostate = IOState::Payload(body); + io.write(chunk.as_ref())? + } + Ok(Async::NotReady) => { + self.iostate = IOState::Payload(body); + break + }, + Err(err) => return Err(err), } - WriterState::Done => self.state.resume(), } + IOState::Context => { + match self.poll_context() { + Ok(Async::Ready(None)) => { + self.iostate = IOState::Done; + self.running = TaskRunningState::Done; + io.write_eof()?; + break + }, + Ok(Async::Ready(Some(chunk))) => { + self.iostate = IOState::Context; + io.write(chunk.as_ref())? + } + Ok(Async::NotReady) => { + self.iostate = IOState::Context; + break + } + Err(err) => return Err(err), + } + } + IOState::Done => break, + }; + + match result { + WriterState::Pause => { + self.running.pause(); + break + } + WriterState::Done => + self.running.resume(), } } } // flush io match io.poll_complete() { - Ok(Async::Ready(_)) => self.state.resume(), + Ok(Async::Ready(_)) => self.running.resume(), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); @@ -296,65 +350,154 @@ impl Task { // response is completed if self.iostate.is_done() { - if let Some(ref mut resp) = self.prepared { - resp.set_response_size(io.written()); + if let ResponseState::Prepared(Some(ref mut resp)) = self.response { + resp.set_response_size(io.written()) } - Ok(Async::Ready(self.state.is_done())) + Ok(Async::Ready(self.running.is_done())) } else { Ok(Async::NotReady) } } - fn poll_stream(&mut self, stream: &mut S) -> Poll<(), Error> - where S: Stream - { + pub(crate) fn poll_response(&mut self, req: &mut HttpRequest) -> Poll { loop { - match stream.poll() { - Ok(Async::Ready(Some(frame))) => { - match frame { - Frame::Message(ref msg) => { - if self.iostate != TaskIOState::ReadingMessage { - error!("Unexpected frame {:?}", frame); - return Err(UnexpectedTaskFrame.into()) + let state = mem::replace(&mut self.response, ResponseState::Prepared(None)); + match state { + ResponseState::Ready(response) => { + // run middlewares + if let Some(mut middlewares) = self.middlewares.take() { + match middlewares.response(req, response) { + Ok(Some(response)) => + return Ok(Async::Ready(response)), + Ok(None) => { + // middlewares need to run some futures + self.response = ResponseState::Middlewares(middlewares); + continue } - let upgrade = msg.upgrade(); - if upgrade || msg.body().is_streaming() { - self.iostate = TaskIOState::ReadingPayload; - } else { - self.iostate = TaskIOState::Done; - } - }, - Frame::Payload(ref chunk) => { - if chunk.is_none() { - self.iostate = TaskIOState::Done; - } else if self.iostate != TaskIOState::ReadingPayload { - error!("Unexpected frame {:?}", self.iostate); - return Err(UnexpectedTaskFrame.into()) - } - }, - _ => (), + Err(err) => return Err(err), + } + } else { + return Ok(Async::Ready(response)) } - self.frames.push_back(frame) - }, - Ok(Async::Ready(None)) => - return Ok(Async::Ready(())), - Ok(Async::NotReady) => + } + ResponseState::Middlewares(mut middlewares) => { + // process middlewares + match middlewares.poll(req) { + Ok(Async::NotReady) => { + self.response = ResponseState::Middlewares(middlewares); + return Ok(Async::NotReady) + }, + Ok(Async::Ready(response)) => + return Ok(Async::Ready(response)), + Err(err) => + return Err(err), + } + } + _ => (), + } + self.response = state; + + match mem::replace(&mut self.stream, TaskStream::None) { + TaskStream::None => return Ok(Async::NotReady), - Err(err) => - return Err(err), + TaskStream::Context(mut context) => { + loop { + match context.poll() { + Ok(Async::Ready(Some(frame))) => { + match frame { + Frame::Message(msg) => { + if !self.response.is_reading() { + error!("Unexpected message frame {:?}", msg); + return Err(UnexpectedTaskFrame.into()) + } + self.stream = TaskStream::Context(context); + self.response = ResponseState::Ready(msg); + break + }, + Frame::Payload(_) => (), + Frame::Drain(fut) => { + self.drain.push(fut); + self.stream = TaskStream::Context(context); + break + } + } + }, + Ok(Async::Ready(None)) => { + error!("Unexpected eof"); + return Err(UnexpectedTaskFrame.into()) + }, + Ok(Async::NotReady) => { + self.stream = TaskStream::Context(context); + return Ok(Async::NotReady) + }, + Err(err) => + return Err(err), + } + } + }, + TaskStream::Response(mut fut) => { + match fut.poll() { + Ok(Async::NotReady) => { + self.stream = TaskStream::Response(fut); + return Ok(Async::NotReady); + }, + Ok(Async::Ready(response)) => { + self.response = ResponseState::Ready(response); + } + Err(err) => + return Err(err) + } + } } } } pub(crate) fn poll(&mut self) -> Poll<(), Error> { - let mut s = mem::replace(&mut self.stream, TaskStream::None); + match self.stream { + TaskStream::None | TaskStream::Response(_) => + Ok(Async::Ready(())), + TaskStream::Context(ref mut context) => { + loop { + match context.poll() { + Ok(Async::Ready(Some(_))) => (), + Ok(Async::Ready(None)) => + return Ok(Async::Ready(())), + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(err) => + return Err(err), + } + } + }, + } + } - let result = match s { - TaskStream::None => Ok(Async::Ready(())), - TaskStream::Stream(ref mut stream) => self.poll_stream(stream), - TaskStream::Context(ref mut context) => self.poll_stream(context), - }; - self.stream = s; - result + fn poll_context(&mut self) -> Poll, Error> { + match self.stream { + TaskStream::None | TaskStream::Response(_) => + Err(UnexpectedTaskFrame.into()), + TaskStream::Context(ref mut context) => { + match context.poll() { + Ok(Async::Ready(Some(frame))) => { + match frame { + Frame::Message(msg) => { + error!("Unexpected message frame {:?}", msg); + Err(UnexpectedTaskFrame.into()) + }, + Frame::Payload(payload) => { + Ok(Async::Ready(payload)) + }, + Frame::Drain(fut) => { + self.drain.push(fut); + Ok(Async::NotReady) + } + } + }, + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(err) => Err(err), + } + }, + } } } diff --git a/src/ws.rs b/src/ws.rs index 32bf49fb2..3bda41d6a 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -169,7 +169,7 @@ pub fn handshake(req: &HttpRequest) -> Result