diff --git a/src/application.rs b/src/application.rs index c710142d4..b25ede793 100644 --- a/src/application.rs +++ b/src/application.rs @@ -1,3 +1,5 @@ +#![allow(unused_imports, dead_code)] + use std::rc::Rc; use std::string::ToString; use std::collections::HashMap; @@ -10,6 +12,7 @@ use recognizer::{RouteRecognizer, check_pattern}; use httprequest::HttpRequest; use httpresponse::HttpResponse; use channel::HttpHandler; +use pipeline::Pipeline; use middlewares::Middleware; @@ -48,14 +51,9 @@ impl HttpHandler for Application { &self.prefix } - fn handle(&self, req: &mut HttpRequest, payload: Payload) -> Task { - let mut task = self.run(req, payload); - - // init middlewares - if !self.middlewares.is_empty() { - task.set_middlewares(Rc::clone(&self.middlewares)); - } - task + fn handle(&self, req: HttpRequest, payload: Payload) -> Pipeline { + Pipeline::new(req, payload, Rc::clone(&self.middlewares), + &|req: &mut HttpRequest, payload: Payload| {self.run(req, payload)}) } } diff --git a/src/channel.rs b/src/channel.rs index 2403b46c2..9d595ec10 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -8,8 +8,8 @@ use tokio_io::{AsyncRead, AsyncWrite}; use h1; use h2; -use task::Task; use payload::Payload; +use pipeline::Pipeline; use httprequest::HttpRequest; /// Low level http request handler @@ -17,7 +17,7 @@ pub trait HttpHandler: 'static { /// Http handler prefix fn prefix(&self) -> &str; /// Handle request - fn handle(&self, req: &mut HttpRequest, payload: Payload) -> Task; + fn handle(&self, req: HttpRequest, payload: Payload) -> Pipeline; } enum HttpProtocol diff --git a/src/dev.rs b/src/dev.rs index 7e269e979..dd23a6066 100644 --- a/src/dev.rs +++ b/src/dev.rs @@ -12,4 +12,5 @@ pub use super::*; // dev specific pub use task::Task; +pub use pipeline::Pipeline; pub use recognizer::RouteRecognizer; diff --git a/src/error.rs b/src/error.rs index 4fca84181..fa06e5d40 100644 --- a/src/error.rs +++ b/src/error.rs @@ -86,6 +86,13 @@ default impl ErrorResponse for T { /// `InternalServerError` for `JsonError` impl ErrorResponse for JsonError {} +/// Internal error +#[derive(Fail, Debug)] +#[fail(display="Unexpected task frame")] +pub struct UnexpectedTaskFrame; + +impl ErrorResponse for UnexpectedTaskFrame {} + /// A set of errors that can occur during parsing HTTP streams #[derive(Fail, Debug)] pub enum ParseError { diff --git a/src/h1.rs b/src/h1.rs index ea6fda6ab..2d75c95a5 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -1,6 +1,5 @@ use std::{self, io, ptr}; use std::rc::Rc; -use std::cell::UnsafeCell; use std::net::SocketAddr; use std::time::Duration; use std::collections::VecDeque; @@ -15,13 +14,13 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tokio_core::reactor::Timeout; use percent_encoding; -use task::Task; +use pipeline::Pipeline; +use encoding::PayloadType; use channel::HttpHandler; -use error::{ParseError, PayloadError, ErrorResponse}; use h1writer::H1Writer; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; -use encoding::PayloadType; +use error::{ParseError, PayloadError, ErrorResponse}; use payload::{Payload, PayloadWriter, DEFAULT_BUFFER_SIZE}; const KEEPALIVE_PERIOD: u64 = 15; // seconds @@ -51,8 +50,8 @@ pub(crate) struct Http1 { } struct Entry { - task: Task, - req: UnsafeCell, + task: Pipeline, + //req: UnsafeCell, eof: bool, error: bool, finished: bool, @@ -107,8 +106,7 @@ impl Http1 } // this is anoying - let req = unsafe {item.req.get().as_mut().unwrap()}; - match item.task.poll_io(&mut self.stream, req) + match item.task.poll_io(&mut self.stream) { Ok(Async::Ready(ready)) => { not_ready = false; @@ -182,14 +180,14 @@ impl Http1 let mut task = None; for h in self.router.iter() { if req.path().starts_with(h.prefix()) { - task = Some(h.handle(&mut req, payload)); + task = Some(h.handle(req, payload)); break } } self.tasks.push_back( - Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)), - req: UnsafeCell::new(req), + Entry {task: task.unwrap_or_else(|| Pipeline::error(HTTPNotFound)), + //req: UnsafeCell::new(req), eof: false, error: false, finished: false}); @@ -217,15 +215,13 @@ impl Http1 self.keepalive = false; self.keepalive_timer.take(); - // on parse error, stop reading stream but - // tasks need to be completed + // on parse error, stop reading stream but tasks need to be completed self.error = true; if self.tasks.is_empty() { if let ReaderError::Error(err) = err { self.tasks.push_back( - Entry {task: Task::reply(err.error_response()), - req: UnsafeCell::new(HttpRequest::for_error()), + Entry {task: Pipeline::error(err.error_response()), eof: false, error: false, finished: false}); diff --git a/src/h2.rs b/src/h2.rs index baa3f452e..cad667641 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -1,7 +1,6 @@ use std::{io, cmp, mem}; use std::rc::Rc; use std::io::{Read, Write}; -use std::cell::UnsafeCell; use std::time::Duration; use std::net::SocketAddr; use std::collections::VecDeque; @@ -15,13 +14,13 @@ use futures::{Async, Poll, Future, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_core::reactor::Timeout; -use task::Task; +use pipeline::Pipeline; use h2writer::H2Writer; use channel::HttpHandler; -use httpcodes::HTTPNotFound; -use httprequest::HttpRequest; use error::PayloadError; use encoding::PayloadType; +use httpcodes::HTTPNotFound; +use httprequest::HttpRequest; use payload::{Payload, PayloadWriter}; const KEEPALIVE_PERIOD: u64 = 15; // seconds @@ -83,8 +82,8 @@ impl Http2 item.poll_payload(); if !item.eof { - let req = unsafe {item.req.get().as_mut().unwrap()}; - match item.task.poll_io(&mut item.stream, req) { + //let req = unsafe {item.req.get().as_mut().unwrap()}; + match item.task.poll_io(&mut item.stream) { Ok(Async::Ready(ready)) => { item.eof = true; if ready { @@ -198,8 +197,7 @@ impl Http2 } struct Entry { - task: Task, - req: UnsafeCell, + task: Pipeline, payload: PayloadType, recv: RecvStream, stream: H2Writer, @@ -230,18 +228,19 @@ impl Entry { // Payload and Content-Encoding let (psender, payload) = Payload::new(false); + // Payload sender + let psender = PayloadType::new(req.headers(), psender); + // start request processing let mut task = None; for h in router.iter() { if req.path().starts_with(h.prefix()) { - task = Some(h.handle(&mut req, payload)); + task = Some(h.handle(req, payload)); break } } - let psender = PayloadType::new(req.headers(), psender); - Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)), - req: UnsafeCell::new(req), + Entry {task: task.unwrap_or_else(|| Pipeline::error(HTTPNotFound)), payload: psender, recv: recv, stream: H2Writer::new(resp), diff --git a/src/lib.rs b/src/lib.rs index a063bc5aa..9543ad21a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ mod resource; mod recognizer; mod route; mod task; +mod pipeline; mod staticfiles; mod server; mod channel; diff --git a/src/middlewares/logger.rs b/src/middlewares/logger.rs index 0b62b093b..b70b34fe4 100644 --- a/src/middlewares/logger.rs +++ b/src/middlewares/logger.rs @@ -101,9 +101,9 @@ impl Logger { impl Middleware for Logger { - fn start(&self, req: &mut HttpRequest) -> Started { + fn start(&self, mut req: HttpRequest) -> Started { req.extensions().insert(StartTime(time::now())); - Started::Done + Started::Done(req) } fn finish(&self, req: &mut HttpRequest, resp: &HttpResponse) -> Finished { @@ -298,16 +298,16 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB")); - let mut req = HttpRequest::new( + let req = HttpRequest::new( Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new()); let resp = HttpResponse::builder(StatusCode::OK) .header("X-Test", "ttt") .force_close().body(Body::Empty).unwrap(); - match logger.start(&mut req) { - Started::Done => (), + let mut req = match logger.start(req) { + Started::Done(req) => req, _ => panic!(), - } + }; match logger.finish(&mut req, &resp) { Finished::Done => (), _ => panic!(), diff --git a/src/middlewares/mod.rs b/src/middlewares/mod.rs index 04790bc85..742cef1ca 100644 --- a/src/middlewares/mod.rs +++ b/src/middlewares/mod.rs @@ -1,7 +1,10 @@ //! Middlewares +#![allow(unused_imports, dead_code)] + use std::rc::Rc; -use std::error::Error; use futures::{Async, Future, Poll}; + +use error::Error; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -11,12 +14,12 @@ pub use self::logger::Logger; /// Middleware start result pub enum Started { /// Execution completed - Done, + Done(HttpRequest), /// New http response got generated. If middleware generates response /// handler execution halts. - Response(HttpResponse), - /// Execution completed, but run future to completion. - Future(Box>), + Response(HttpRequest, HttpResponse), + /// Execution completed, runs future to completion. + Future(Box), Error=(HttpRequest, HttpResponse)>>), } /// Middleware execution result @@ -32,7 +35,7 @@ pub enum Finished { /// Execution completed Done, /// Execution completed, but run future to completion - Future(Box>>), + Future(Box>), } /// Middleware definition @@ -41,8 +44,8 @@ pub trait Middleware { /// Method is called when request is ready. It may return /// future, which should resolve before next middleware get called. - fn start(&self, req: &mut HttpRequest) -> Started { - Started::Done + fn start(&self, req: HttpRequest) -> Started { + Started::Done(req) } /// Method is called when handler returns response, @@ -56,192 +59,3 @@ pub trait Middleware { Finished::Done } } - -/// Middlewares executor -pub(crate) struct MiddlewaresExecutor { - state: ExecutorState, - fut: Option>>, - started: Option>>, - finished: Option>>>, - middlewares: Option>>>, -} - -enum ExecutorState { - None, - Starting(usize), - Started(usize), - Processing(usize, usize), - Finishing(usize), -} - -impl Default for MiddlewaresExecutor { - - fn default() -> MiddlewaresExecutor { - MiddlewaresExecutor { - fut: None, - started: None, - finished: None, - state: ExecutorState::None, - middlewares: None, - } - } -} - -impl MiddlewaresExecutor { - - pub fn start(&mut self, mw: Rc>>) { - self.state = ExecutorState::Starting(0); - self.middlewares = Some(mw); - } - - pub fn starting(&mut self, req: &mut HttpRequest) -> Poll, ()> { - if let Some(ref middlewares) = self.middlewares { - let state = &mut self.state; - if let ExecutorState::Starting(mut idx) = *state { - loop { - // poll latest fut - if let Some(ref mut fut) = self.started { - match fut.poll() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(())) => idx += 1, - Err(response) => { - *state = ExecutorState::Started(idx); - return Ok(Async::Ready(Some(response))) - } - } - } - self.started = None; - - if idx >= middlewares.len() { - *state = ExecutorState::Started(idx-1); - return Ok(Async::Ready(None)) - } else { - match middlewares[idx].start(req) { - Started::Done => idx += 1, - Started::Response(resp) => { - *state = ExecutorState::Started(idx); - return Ok(Async::Ready(Some(resp))) - }, - Started::Future(fut) => { - self.started = Some(fut); - }, - } - } - } - } - } - Ok(Async::Ready(None)) - } - - pub fn processing(&mut self, req: &mut HttpRequest) -> Poll, ()> { - if let Some(ref middlewares) = self.middlewares { - let state = &mut self.state; - match *state { - ExecutorState::Processing(mut idx, total) => { - loop { - // poll latest fut - let mut resp = match self.fut.as_mut().unwrap().poll() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(response)) | Err(response) => { - idx += 1; - response - } - }; - self.fut = None; - - loop { - if idx == 0 { - *state = ExecutorState::Finishing(total); - return Ok(Async::Ready(Some(resp))) - } else { - match middlewares[idx].response(req, resp) { - Response::Response(r) => { - idx -= 1; - resp = r - }, - Response::Future(fut) => { - self.fut = Some(fut); - break - }, - } - } - } - } - } - _ => Ok(Async::Ready(None)) - } - } else { - Ok(Async::Ready(None)) - } - } - - pub fn finishing(&mut self, req: &mut HttpRequest, resp: &HttpResponse) -> Poll<(), ()> { - if let Some(ref middlewares) = self.middlewares { - let state = &mut self.state; - if let ExecutorState::Finishing(mut idx) = *state { - loop { - // poll latest fut - if let Some(ref mut fut) = self.finished { - match fut.poll() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(())) => idx -= 1, - Err(err) => { - error!("Middleware finish error: {}", err); - } - } - } - self.finished = None; - - match middlewares[idx].finish(req, resp) { - Finished::Done => { - if idx == 0 { - return Ok(Async::Ready(())) - } else { - idx -= 1 - } - } - Finished::Future(fut) => { - self.finished = Some(fut); - }, - } - } - } - } - Ok(Async::Ready(())) - } - - pub fn response(&mut self, req: &mut HttpRequest, resp: HttpResponse) - -> Option - { - if let Some(ref middlewares) = self.middlewares { - let mut resp = resp; - let state = &mut self.state; - match *state { - ExecutorState::Started(mut idx) => { - let total = idx; - loop { - resp = match middlewares[idx].response(req, resp) { - Response::Response(r) => { - if idx == 0 { - *state = ExecutorState::Finishing(total); - return Some(r) - } else { - idx -= 1; - r - } - }, - Response::Future(fut) => { - *state = ExecutorState::Processing(idx, total); - self.fut = Some(fut); - return None - }, - }; - } - } - _ => Some(resp) - } - } else { - Some(resp) - } - } -} diff --git a/src/pipeline.rs b/src/pipeline.rs new file mode 100644 index 000000000..93c082a05 --- /dev/null +++ b/src/pipeline.rs @@ -0,0 +1,437 @@ +use std::mem; +use std::rc::Rc; + +use futures::{Async, Poll, Future}; + +use task::Task; +use error::Error; +use payload::Payload; +use middlewares::{Middleware, Finished, Started, Response}; +use h1writer::Writer; +use httprequest::HttpRequest; +use httpresponse::HttpResponse; + +type Handler = Fn(&mut HttpRequest, Payload) -> Task; +pub(crate) type PipelineHandler<'a> = &'a Fn(&mut HttpRequest, Payload) -> Task; + +pub struct Pipeline(PipelineState); + +enum PipelineState { + None, + Starting(Start), + Handle(Box), + Finishing(Box), + Error(Box<(Task, HttpRequest)>), + Task(Box<(Task, HttpRequest)>), +} + +impl Pipeline { + + pub fn new(mut req: HttpRequest, payload: Payload, + mw: Rc>>, handler: PipelineHandler) -> Pipeline { + if mw.is_empty() { + let task = (handler)(&mut req, payload); + Pipeline(PipelineState::Task(Box::new((task, req)))) + } else { + match Start::init(mw, req, handler, payload) { + StartResult::Ready(res) => { + Pipeline(PipelineState::Handle(res)) + }, + StartResult::NotReady(res) => { + Pipeline(PipelineState::Starting(res)) + }, + } + } + } + + pub fn error>(resp: R) -> Self { + Pipeline(PipelineState::Error(Box::new((Task::reply(resp), HttpRequest::for_error())))) + } + + pub(crate) fn disconnected(&mut self) { + match self.0 { + PipelineState::Starting(ref mut st) => + st.disconnected(), + PipelineState::Handle(ref mut st) => + st.task.disconnected(), + _ =>(), + } + } + + pub(crate) fn poll_io(&mut self, io: &mut T) -> Poll { + loop { + let state = mem::replace(&mut self.0, PipelineState::None); + match state { + PipelineState::Task(mut st) => { + let req:&mut HttpRequest = unsafe{mem::transmute(&mut st.1)}; + let res = st.0.poll_io(io, req); + self.0 = PipelineState::Task(st); + return res + } + PipelineState::Starting(mut st) => { + match st.poll() { + Async::NotReady => { + self.0 = PipelineState::Starting(st); + return Ok(Async::NotReady) + } + Async::Ready(h) => + self.0 = PipelineState::Handle(h), + } + } + PipelineState::Handle(mut st) => { + let res = st.poll_io(io); + if let Ok(Async::Ready(r)) = res { + if r { + self.0 = PipelineState::Finishing(st.finish()); + return Ok(Async::Ready(false)) + } else { + self.0 = PipelineState::Handle(st); + return res + } + } else { + self.0 = PipelineState::Handle(st); + return res + } + } + PipelineState::Error(mut st) => { + let req:&mut HttpRequest = unsafe{mem::transmute(&mut st.1)}; + let res = st.0.poll_io(io, req); + self.0 = PipelineState::Error(st); + return res + } + PipelineState::Finishing(_) | PipelineState::None => unreachable!(), + } + } + } + + pub(crate) fn poll(&mut self) -> Poll<(), Error> { + loop { + let state = mem::replace(&mut self.0, PipelineState::None); + match state { + PipelineState::Handle(mut st) => { + let res = st.poll(); + match res { + Ok(Async::NotReady) => { + self.0 = PipelineState::Handle(st); + return Ok(Async::NotReady) + } + Ok(Async::Ready(())) | Err(_) => { + self.0 = PipelineState::Finishing(st.finish()); + } + } + } + PipelineState::Finishing(mut st) => { + let res = st.poll(); + self.0 = PipelineState::Finishing(st); + return Ok(res) + } + PipelineState::Error(mut st) => { + let res = st.0.poll(); + self.0 = PipelineState::Error(st); + return res + } + PipelineState::Task(mut st) => { + let res = st.0.poll(); + self.0 = PipelineState::Task(st); + return res + } + _ => { + self.0 = state; + return Ok(Async::Ready(())) + } + } + } + } +} + +struct Handle { + idx: usize, + req: HttpRequest, + task: Task, + middlewares: Rc>>, +} + +impl Handle { + fn new(idx: usize, + req: HttpRequest, + task: Task, + mw: Rc>>) -> Handle + { + Handle { + idx: idx, req: req, task:task, middlewares: mw } + } + + fn poll_io(&mut self, io: &mut T) -> Poll { + self.task.poll_io(io, &mut self.req) + } + + fn poll(&mut self) -> Poll<(), Error> { + self.task.poll() + } + + fn finish(mut self) -> Box { + Box::new(Finish { + idx: self.idx, + req: self.req, + fut: None, + resp: self.task.response(), + middlewares: self.middlewares + }) + } +} + +/// Middlewares start executor +struct Finish { + idx: usize, + req: HttpRequest, + resp: HttpResponse, + fut: Option>>, + middlewares: Rc>>, +} + +impl Finish { + + pub fn poll(&mut self) -> Async<()> { + loop { + // poll latest fut + if let Some(ref mut fut) = self.fut { + match fut.poll() { + Ok(Async::NotReady) => return Async::NotReady, + Ok(Async::Ready(())) => self.idx -= 1, + Err(err) => { + error!("Middleware finish error: {}", err); + self.idx -= 1; + } + } + } + self.fut = None; + + match self.middlewares[self.idx].finish(&mut self.req, &self.resp) { + Finished::Done => { + if self.idx == 0 { + return Async::Ready(()) + } else { + self.idx -= 1 + } + } + Finished::Future(fut) => { + self.fut = Some(fut); + }, + } + } + } +} + +type Fut = Box), Error=(HttpRequest, HttpResponse)>>; + +/// Middlewares start executor +struct Start { + idx: usize, + hnd: *mut Handler, + disconnected: bool, + payload: Option, + fut: Option, + middlewares: Rc>>, +} + +enum StartResult { + Ready(Box), + NotReady(Start), +} + +impl Start { + + fn init(mw: Rc>>, + req: HttpRequest, handler: PipelineHandler, payload: Payload) -> StartResult + { + Start { + idx: 0, + fut: None, + disconnected: false, + hnd: handler as *const _ as *mut _, + payload: Some(payload), + middlewares: mw, + }.start(req) + } + + fn disconnected(&mut self) { + self.disconnected = true; + } + + fn prepare(&self, mut task: Task) -> Task { + if self.disconnected { + task.disconnected() + } + task.set_middlewares( + MiddlewaresResponse::new(self.idx, Rc::clone(&self.middlewares))); + task + } + + fn start(mut self, mut req: HttpRequest) -> StartResult { + loop { + if self.idx >= self.middlewares.len() { + let task = (unsafe{&*self.hnd})( + &mut req, self.payload.take().expect("Something is completlywrong")); + return StartResult::Ready( + Box::new(Handle::new(self.idx-1, req, self.prepare(task), self.middlewares))) + } else { + req = match self.middlewares[self.idx].start(req) { + Started::Done(req) => { + self.idx += 1; + req + } + Started::Response(req, resp) => { + return StartResult::Ready( + Box::new(Handle::new( + self.idx, req, self.prepare(Task::reply(resp)), self.middlewares))) + }, + Started::Future(mut fut) => { + match fut.poll() { + Ok(Async::NotReady) => { + self.fut = Some(fut); + return StartResult::NotReady(self) + } + Ok(Async::Ready((req, resp))) => { + self.idx += 1; + if let Some(resp) = resp { + return StartResult::Ready( + Box::new(Handle::new( + self.idx, req, + self.prepare(Task::reply(resp)), self.middlewares))) + } + req + } + Err((req, resp)) => { + return StartResult::Ready(Box::new(Handle::new( + self.idx, req, + self.prepare(Task::reply(resp)), self.middlewares))) + } + } + }, + } + } + } + } + + fn poll(&mut self) -> Async> { + 'outer: loop { + match self.fut.as_mut().unwrap().poll() { + Ok(Async::NotReady) => return Async::NotReady, + Ok(Async::Ready((mut req, resp))) => { + self.idx += 1; + if let Some(resp) = resp { + return Async::Ready(Box::new(Handle::new( + self.idx, req, + self.prepare(Task::reply(resp)), Rc::clone(&self.middlewares)))) + } + if self.idx >= self.middlewares.len() { + let task = (unsafe{&*self.hnd})( + &mut req, self.payload.take().expect("Something is completlywrong")); + return Async::Ready(Box::new(Handle::new( + self.idx-1, req, + self.prepare(task), Rc::clone(&self.middlewares)))) + } else { + loop { + req = match self.middlewares[self.idx].start(req) { + Started::Done(req) => { + self.idx += 1; + req + } + Started::Response(req, resp) => { + return Async::Ready(Box::new(Handle::new( + self.idx, req, + self.prepare(Task::reply(resp)), + Rc::clone(&self.middlewares)))) + }, + Started::Future(mut fut) => { + self.fut = Some(fut); + continue 'outer + }, + } + } + } + } + Err((req, resp)) => { + return Async::Ready(Box::new(Handle::new( + self.idx, req, + self.prepare(Task::reply(resp)), + Rc::clone(&self.middlewares)))) + } + } + } + } +} + + +/// Middlewares response executor +pub(crate) struct MiddlewaresResponse { + idx: usize, + fut: Option>>, + middlewares: Rc>>, +} + +impl MiddlewaresResponse { + + fn new(idx: usize, mw: Rc>>) -> MiddlewaresResponse { + let idx = if idx == 0 { 0 } else { idx - 1 }; + MiddlewaresResponse { + idx: idx, + fut: None, + middlewares: mw } + } + + pub fn response(&mut self, req: &mut HttpRequest, mut resp: HttpResponse) + -> Option + { + loop { + resp = match self.middlewares[self.idx].response(req, resp) { + Response::Response(r) => { + if self.idx == 0 { + return Some(r) + } else { + self.idx -= 1; + r + } + }, + Response::Future(fut) => { + self.fut = Some(fut); + return None + }, + }; + } + } + + pub fn poll(&mut self, req: &mut HttpRequest) -> Poll, ()> { + if self.fut.is_none() { + return Ok(Async::Ready(None)) + } + + loop { + // poll latest fut + let mut resp = match self.fut.as_mut().unwrap().poll() { + Ok(Async::NotReady) => return Ok(Async::NotReady), + Ok(Async::Ready(resp)) | Err(resp) => { + self.idx += 1; + resp + } + }; + + loop { + if self.idx == 0 { + return Ok(Async::Ready(Some(resp))) + } else { + match self.middlewares[self.idx].response(req, resp) { + Response::Response(r) => { + self.idx -= 1; + resp = r + }, + Response::Future(fut) => { + self.fut = Some(fut); + break + }, + } + } + } + } + } +} diff --git a/src/task.rs b/src/task.rs index 42387d580..40b2442dd 100644 --- a/src/task.rs +++ b/src/task.rs @@ -7,9 +7,9 @@ use futures::{Async, Future, Poll, Stream}; use futures::task::{Task as FutureTask, current as current_task}; use h1writer::{Writer, WriterState}; -use error::Error; +use error::{Error, UnexpectedTaskFrame}; use route::Frame; -use middlewares::{Middleware, MiddlewaresExecutor}; +use pipeline::MiddlewaresResponse; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -111,7 +111,7 @@ pub struct Task { drain: Vec>>, prepared: Option, disconnected: bool, - middlewares: MiddlewaresExecutor, + middlewares: Option, } impl Task { @@ -128,7 +128,7 @@ impl Task { stream: TaskStream::None, prepared: None, disconnected: false, - middlewares: MiddlewaresExecutor::default() } + middlewares: None } } pub(crate) fn with_context(ctx: C) -> Self { @@ -139,7 +139,7 @@ impl Task { drain: Vec::new(), prepared: None, disconnected: false, - middlewares: MiddlewaresExecutor::default() } + middlewares: None } } pub(crate) fn with_stream(stream: S) -> Self @@ -152,11 +152,15 @@ impl Task { drain: Vec::new(), prepared: None, disconnected: false, - middlewares: MiddlewaresExecutor::default() } + middlewares: None } } - pub(crate) fn set_middlewares(&mut self, middlewares: Rc>>) { - self.middlewares.start(middlewares) + pub(crate) fn response(&mut self) -> HttpResponse { + self.prepared.take().unwrap() + } + + pub(crate) fn set_middlewares(&mut self, middlewares: MiddlewaresResponse) { + self.middlewares = Some(middlewares) } pub(crate) fn disconnected(&mut self) { @@ -171,16 +175,6 @@ impl Task { { trace!("POLL-IO frames:{:?}", self.frames.len()); - // start middlewares - match self.middlewares.starting(req) { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(None)) | Err(_) => (), - Ok(Async::Ready(Some(response))) => { - self.frames.clear(); - self.frames.push_front(Frame::Message(response)); - }, - } - // response is completed if self.frames.is_empty() && self.iostate.is_done() { return Ok(Async::Ready(self.state.is_done())); @@ -197,21 +191,28 @@ impl Task { } // process middlewares response - match self.middlewares.processing(req) { - Err(_) => return Err(()), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(None)) => (), - Ok(Async::Ready(Some(mut response))) => { - let result = io.start(req, &mut response); - self.prepared = Some(response); - match result { - Ok(WriterState::Pause) => { - self.state.pause(); - } - Ok(WriterState::Done) => self.state.resume(), - Err(_) => return Err(()) + if let Some(mut middlewares) = self.middlewares.take() { + match middlewares.poll(req) { + Err(_) => return Err(()), + Ok(Async::NotReady) => { + self.middlewares = Some(middlewares); + return Ok(Async::NotReady); } - }, + Ok(Async::Ready(None)) => { + self.middlewares = Some(middlewares); + } + Ok(Async::Ready(Some(mut response))) => { + let result = io.start(req, &mut response); + self.prepared = Some(response); + match result { + Ok(WriterState::Pause) => { + self.state.pause(); + } + Ok(WriterState::Done) => self.state.resume(), + Err(_) => return Err(()) + } + }, + } } // if task is paused, write buffer probably is full @@ -220,15 +221,22 @@ impl Task { while let Some(frame) = self.frames.pop_front() { trace!("IO Frame: {:?}", frame); let res = match frame { - Frame::Message(resp) => { + Frame::Message(mut resp) => { // run middlewares - if let Some(mut resp) = self.middlewares.response(req, resp) { + if let Some(mut middlewares) = self.middlewares.take() { + if let Some(mut resp) = middlewares.response(req, resp) { + let result = io.start(req, &mut resp); + self.prepared = Some(resp); + result + } else { + // middlewares need to run some futures + self.middlewares = Some(middlewares); + return self.poll_io(io, req) + } + } else { let result = io.start(req, &mut resp); self.prepared = Some(resp); result - } else { - // middlewares need to run some futures - return self.poll_io(io, req) } } Frame::Payload(Some(chunk)) => { @@ -278,12 +286,8 @@ impl Task { // response is completed if self.iostate.is_done() { - // finish middlewares if let Some(ref mut resp) = self.prepared { resp.set_response_size(io.written()); - if let Ok(Async::NotReady) = self.middlewares.finishing(req, resp) { - return Ok(Async::NotReady) - } } Ok(Async::Ready(self.state.is_done())) } else { @@ -291,8 +295,9 @@ impl Task { } } - fn poll_stream(&mut self, stream: &mut S) -> Poll<(), ()> - where S: Stream { + fn poll_stream(&mut self, stream: &mut S) -> Poll<(), Error> + where S: Stream + { loop { match stream.poll() { Ok(Async::Ready(Some(frame))) => { @@ -300,7 +305,7 @@ impl Task { Frame::Message(ref msg) => { if self.iostate != TaskIOState::ReadingMessage { error!("Unexpected frame {:?}", frame); - return Err(()) + return Err(UnexpectedTaskFrame.into()) } let upgrade = msg.upgrade(); if upgrade || msg.body().is_streaming() { @@ -314,7 +319,7 @@ impl Task { self.iostate = TaskIOState::Done; } else if self.iostate != TaskIOState::ReadingPayload { error!("Unexpected frame {:?}", self.iostate); - return Err(()) + return Err(UnexpectedTaskFrame.into()) } }, _ => (), @@ -325,8 +330,8 @@ impl Task { return Ok(Async::Ready(())), Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(_) => - return Err(()), + Err(err) => + return Err(err), } } } @@ -334,7 +339,7 @@ impl Task { impl Future for Task { type Item = (); - type Error = (); + type Error = Error; fn poll(&mut self) -> Poll { let mut s = mem::replace(&mut self.stream, TaskStream::None); diff --git a/tests/test_server.rs b/tests/test_server.rs index eec293033..d3e9d035a 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -61,9 +61,9 @@ struct MiddlewareTest { } impl middlewares::Middleware for MiddlewareTest { - fn start(&self, _: &mut HttpRequest) -> middlewares::Started { + fn start(&self, req: HttpRequest) -> middlewares::Started { self.start.store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed); - middlewares::Started::Done + middlewares::Started::Done(req) } fn response(&self, _: &mut HttpRequest, resp: HttpResponse) -> middlewares::Response {