From 9c9eb6203155f62b181ff8ad1d1e3c4a529ddec4 Mon Sep 17 00:00:00 2001 From: Josh Leeb-du Toit Date: Sat, 2 Jun 2018 13:37:29 +1000 Subject: [PATCH] Update Middleware trait to use `&mut self` --- CHANGES.md | 2 ++ src/application.rs | 6 ++--- src/middleware/cors.rs | 18 ++++++------- src/middleware/csrf.rs | 12 ++++----- src/middleware/defaultheaders.rs | 4 +-- src/middleware/errhandlers.rs | 4 +-- src/middleware/identity.rs | 4 +-- src/middleware/logger.rs | 6 ++--- src/middleware/mod.rs | 6 ++--- src/middleware/session.rs | 4 +-- src/pipeline.rs | 33 ++++++++++++++--------- src/resource.rs | 10 ++++--- src/route.rs | 35 ++++++++++++++----------- src/scope.rs | 45 ++++++++++++++++++-------------- tests/test_middleware.rs | 12 ++++----- 15 files changed, 111 insertions(+), 90 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 78db1afec..5b9e16bd7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,8 @@ * Min rustc version is 1.26 +* Use `&mut self` instead of `&self` for Middleware trait + ### Removed * Remove `Route::with2()` and `Route::with3()` use tuple of extractors instead. diff --git a/src/application.rs b/src/application.rs index fffa5d839..90c70bd18 100644 --- a/src/application.rs +++ b/src/application.rs @@ -1,4 +1,4 @@ -use std::cell::UnsafeCell; +use std::cell::{RefCell, UnsafeCell}; use std::collections::HashMap; use std::rc::Rc; @@ -22,7 +22,7 @@ pub struct HttpApplication { prefix_len: usize, router: Router, inner: Rc>>, - middlewares: Rc>>>, + middlewares: Rc>>>>, } pub(crate) struct Inner { @@ -612,7 +612,7 @@ where HttpApplication { state: Rc::new(parts.state), router: router.clone(), - middlewares: Rc::new(parts.middlewares), + middlewares: Rc::new(RefCell::new(parts.middlewares)), prefix, prefix_len, inner, diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 93c8aeb56..cc705c554 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -355,7 +355,7 @@ impl Cors { } impl Middleware for Cors { - fn start(&self, req: &mut HttpRequest) -> Result { + fn start(&mut self, req: &mut HttpRequest) -> Result { if self.inner.preflight && Method::OPTIONS == *req.method() { self.validate_origin(req)?; self.validate_allowed_method(req)?; @@ -430,7 +430,7 @@ impl Middleware for Cors { } fn response( - &self, req: &mut HttpRequest, mut resp: HttpResponse, + &mut self, req: &mut HttpRequest, mut resp: HttpResponse, ) -> Result { match self.inner.origins { AllOrSome::All => { @@ -940,7 +940,7 @@ mod tests { #[test] fn validate_origin_allows_all_origins() { - let cors = Cors::default(); + let mut cors = Cors::default(); let mut req = TestRequest::with_header("Origin", "https://www.example.com").finish(); @@ -1009,7 +1009,7 @@ mod tests { #[test] #[should_panic(expected = "MissingOrigin")] fn test_validate_missing_origin() { - let cors = Cors::build() + let mut cors = Cors::build() .allowed_origin("https://www.example.com") .finish(); @@ -1020,7 +1020,7 @@ mod tests { #[test] #[should_panic(expected = "OriginNotAllowed")] fn test_validate_not_allowed_origin() { - let cors = Cors::build() + let mut cors = Cors::build() .allowed_origin("https://www.example.com") .finish(); @@ -1032,7 +1032,7 @@ mod tests { #[test] fn test_validate_origin() { - let cors = Cors::build() + let mut cors = Cors::build() .allowed_origin("https://www.example.com") .finish(); @@ -1045,7 +1045,7 @@ mod tests { #[test] fn test_no_origin_response() { - let cors = Cors::build().finish(); + let mut cors = Cors::build().finish(); let mut req = TestRequest::default().method(Method::GET).finish(); let resp: HttpResponse = HttpResponse::Ok().into(); @@ -1071,7 +1071,7 @@ mod tests { #[test] fn test_response() { - let cors = Cors::build() + let mut cors = Cors::build() .send_wildcard() .disable_preflight() .max_age(3600) @@ -1106,7 +1106,7 @@ mod tests { resp.headers().get(header::VARY).unwrap().as_bytes() ); - let cors = Cors::build() + let mut cors = Cors::build() .disable_vary_header() .allowed_origin("https://www.example.com") .finish(); diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index 670ec1c1f..8c2b06e72 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -209,7 +209,7 @@ impl CsrfFilter { } impl Middleware for CsrfFilter { - fn start(&self, req: &mut HttpRequest) -> Result { + fn start(&mut self, req: &mut HttpRequest) -> Result { self.validate(req)?; Ok(Started::Done) } @@ -223,7 +223,7 @@ mod tests { #[test] fn test_safe() { - let csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); + let mut csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); let mut req = TestRequest::with_header("Origin", "https://www.w3.org") .method(Method::HEAD) @@ -234,7 +234,7 @@ mod tests { #[test] fn test_csrf() { - let csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); + let mut csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); let mut req = TestRequest::with_header("Origin", "https://www.w3.org") .method(Method::POST) @@ -245,7 +245,7 @@ mod tests { #[test] fn test_referer() { - let csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); + let mut csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); let mut req = TestRequest::with_header( "Referer", @@ -258,9 +258,9 @@ mod tests { #[test] fn test_upgrade() { - let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); + let mut strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); - let lax_csrf = CsrfFilter::new() + let mut lax_csrf = CsrfFilter::new() .allowed_origin("https://www.example.com") .allow_upgrade(); diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index dca8dfbe1..acccc552f 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -75,7 +75,7 @@ impl DefaultHeaders { impl Middleware for DefaultHeaders { fn response( - &self, _: &mut HttpRequest, mut resp: HttpResponse, + &mut self, _: &mut HttpRequest, mut resp: HttpResponse, ) -> Result { for (key, value) in self.headers.iter() { if !resp.headers().contains_key(key) { @@ -100,7 +100,7 @@ mod tests { #[test] fn test_default_headers() { - let mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001"); + let mut mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001"); let mut req = HttpRequest::default(); diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index 42f75a3de..7e56b368e 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -71,7 +71,7 @@ impl ErrorHandlers { impl Middleware for ErrorHandlers { fn response( - &self, req: &mut HttpRequest, resp: HttpResponse, + &mut self, req: &mut HttpRequest, resp: HttpResponse, ) -> Result { if let Some(handler) = self.handlers.get(&resp.status()) { handler(req, resp) @@ -95,7 +95,7 @@ mod tests { #[test] fn test_handler() { - let mw = + let mut mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mut req = HttpRequest::default(); diff --git a/src/middleware/identity.rs b/src/middleware/identity.rs index c8505d686..76d1894e4 100644 --- a/src/middleware/identity.rs +++ b/src/middleware/identity.rs @@ -175,7 +175,7 @@ unsafe impl Send for IdentityBox {} unsafe impl Sync for IdentityBox {} impl> Middleware for IdentityService { - fn start(&self, req: &mut HttpRequest) -> Result { + fn start(&mut self, req: &mut HttpRequest) -> Result { let mut req = req.clone(); let fut = self @@ -192,7 +192,7 @@ impl> Middleware for IdentityService { } fn response( - &self, req: &mut HttpRequest, resp: HttpResponse, + &mut self, req: &mut HttpRequest, resp: HttpResponse, ) -> Result { if let Some(mut id) = req.extensions_mut().remove::() { id.0.write(resp) diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index a731d6955..ab9ae4a02 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -124,14 +124,14 @@ impl Logger { } impl Middleware for Logger { - fn start(&self, req: &mut HttpRequest) -> Result { + fn start(&mut self, req: &mut HttpRequest) -> Result { if !self.exclude.contains(req.path()) { req.extensions_mut().insert(StartTime(time::now())); } Ok(Started::Done) } - fn finish(&self, req: &mut HttpRequest, resp: &HttpResponse) -> Finished { + fn finish(&mut self, req: &mut HttpRequest, resp: &HttpResponse) -> Finished { self.log(req, resp); Finished::Done } @@ -322,7 +322,7 @@ mod tests { #[test] fn test_logger() { - let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); + let mut logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); let mut headers = HeaderMap::new(); headers.insert( diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 2551ded15..7fd339327 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -51,20 +51,20 @@ pub enum Finished { pub trait Middleware: 'static { /// Method is called when request is ready. It may return /// future, which should resolve before next middleware get called. - fn start(&self, req: &mut HttpRequest) -> Result { + fn start(&mut self, req: &mut HttpRequest) -> Result { Ok(Started::Done) } /// Method is called when handler returns response, /// but before sending http message to peer. fn response( - &self, req: &mut HttpRequest, resp: HttpResponse, + &mut self, req: &mut HttpRequest, resp: HttpResponse, ) -> Result { Ok(Response::Done(resp)) } /// Method is called after body stream get sent to peer. - fn finish(&self, req: &mut HttpRequest, resp: &HttpResponse) -> Finished { + fn finish(&mut self, req: &mut HttpRequest, resp: &HttpResponse) -> Finished { Finished::Done } } diff --git a/src/middleware/session.rs b/src/middleware/session.rs index bfc40dd5d..5c9867677 100644 --- a/src/middleware/session.rs +++ b/src/middleware/session.rs @@ -250,7 +250,7 @@ impl> SessionStorage { } impl> Middleware for SessionStorage { - fn start(&self, req: &mut HttpRequest) -> Result { + fn start(&mut self, req: &mut HttpRequest) -> Result { let mut req = req.clone(); let fut = self.0.from_request(&mut req).then(move |res| match res { @@ -265,7 +265,7 @@ impl> Middleware for SessionStorage { } fn response( - &self, req: &mut HttpRequest, resp: HttpResponse, + &mut self, req: &mut HttpRequest, resp: HttpResponse, ) -> Result { if let Some(s_box) = req.extensions_mut().remove::>() { s_box.0.borrow_mut().write(resp) diff --git a/src/pipeline.rs b/src/pipeline.rs index 289c5fcbb..f9ec97137 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -1,4 +1,4 @@ -use std::cell::UnsafeCell; +use std::cell::{RefCell, UnsafeCell}; use std::marker::PhantomData; use std::rc::Rc; use std::{io, mem}; @@ -71,7 +71,7 @@ impl> PipelineState { struct PipelineInfo { req: UnsafeCell>, count: u16, - mws: Rc>>>, + mws: Rc>>>>, context: Option>, error: Option, disconnected: Option, @@ -83,7 +83,7 @@ impl PipelineInfo { PipelineInfo { req: UnsafeCell::new(req), count: 0, - mws: Rc::new(Vec::new()), + mws: Rc::new(RefCell::new(Vec::new())), error: None, context: None, disconnected: None, @@ -120,7 +120,7 @@ impl PipelineInfo { impl> Pipeline { pub fn new( - req: HttpRequest, mws: Rc>>>, + req: HttpRequest, mws: Rc>>>>, handler: Rc>, htype: HandlerType, ) -> Pipeline { let mut info = PipelineInfo { @@ -243,13 +243,14 @@ impl> StartMiddlewares { ) -> PipelineState { // execute middlewares, we need this stage because middlewares could be // non-async and we can move to next state immediately - let len = info.mws.len() as u16; + let len = info.mws.borrow().len() as u16; loop { if info.count == len { let reply = unsafe { &mut *hnd.get() }.handle(info.req().clone(), htype); return WaitingResponse::init(info, reply); } else { - match info.mws[info.count as usize].start(info.req_mut()) { + let state = info.mws.borrow_mut()[info.count as usize].start(info.req_mut()); + match state { Ok(Started::Done) => info.count += 1, Ok(Started::Response(resp)) => { return RunMiddlewares::init(info, resp) @@ -269,7 +270,7 @@ impl> StartMiddlewares { } fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - let len = info.mws.len() as u16; + let len = info.mws.borrow().len() as u16; 'outer: loop { match self.fut.as_mut().unwrap().poll() { Ok(Async::NotReady) => return None, @@ -284,7 +285,9 @@ impl> StartMiddlewares { .handle(info.req().clone(), self.htype); return Some(WaitingResponse::init(info, reply)); } else { - match info.mws[info.count as usize].start(info.req_mut()) { + let state = info.mws.borrow_mut()[info.count as usize] + .start(info.req_mut()); + match state { Ok(Started::Done) => info.count += 1, Ok(Started::Response(resp)) => { return Some(RunMiddlewares::init(info, resp)); @@ -353,10 +356,11 @@ impl RunMiddlewares { return ProcessResponse::init(resp); } let mut curr = 0; - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { - resp = match info.mws[curr].response(info.req_mut(), resp) { + let state = info.mws.borrow_mut()[curr].response(info.req_mut(), resp); + resp = match state { Err(err) => { info.count = (curr + 1) as u16; return ProcessResponse::init(err.into()); @@ -382,7 +386,7 @@ impl RunMiddlewares { } fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { // poll latest fut @@ -399,7 +403,8 @@ impl RunMiddlewares { if self.curr == len { return Some(ProcessResponse::init(resp)); } else { - match info.mws[self.curr].response(info.req_mut(), resp) { + let state = info.mws.borrow_mut()[self.curr].response(info.req_mut(), resp); + match state { Err(err) => return Some(ProcessResponse::init(err.into())), Ok(Response::Done(r)) => { self.curr += 1; @@ -723,7 +728,9 @@ impl FinishingMiddlewares { } info.count -= 1; - match info.mws[info.count as usize].finish(info.req_mut(), &self.resp) { + let state = info.mws.borrow_mut()[info.count as usize] + .finish(info.req_mut(), &self.resp); + match state { Finished::Done => { if info.count == 0 { return Some(Completed::init(info)); diff --git a/src/resource.rs b/src/resource.rs index af3a9e673..2b9c8538b 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; use std::rc::Rc; +use std::cell::RefCell; use futures::Future; use http::{Method, StatusCode}; @@ -37,7 +38,7 @@ pub struct ResourceHandler { name: String, state: PhantomData, routes: SmallVec<[Route; 3]>, - middlewares: Rc>>>, + middlewares: Rc>>>>, } impl Default for ResourceHandler { @@ -46,7 +47,7 @@ impl Default for ResourceHandler { name: String::new(), state: PhantomData, routes: SmallVec::new(), - middlewares: Rc::new(Vec::new()), + middlewares: Rc::new(RefCell::new(Vec::new())), } } } @@ -57,7 +58,7 @@ impl ResourceHandler { name: String::new(), state: PhantomData, routes: SmallVec::new(), - middlewares: Rc::new(Vec::new()), + middlewares: Rc::new(RefCell::new(Vec::new())), } } @@ -276,6 +277,7 @@ impl ResourceHandler { pub fn middleware>(&mut self, mw: M) { Rc::get_mut(&mut self.middlewares) .unwrap() + .borrow_mut() .push(Box::new(mw)); } @@ -284,7 +286,7 @@ impl ResourceHandler { ) -> AsyncResult { for route in &mut self.routes { if route.check(&mut req) { - return if self.middlewares.is_empty() { + return if self.middlewares.borrow().is_empty() { route.handle(req) } else { route.compose(req, Rc::clone(&self.middlewares)) diff --git a/src/route.rs b/src/route.rs index 3039d0896..27aae8df6 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,4 +1,4 @@ -use std::cell::UnsafeCell; +use std::cell::{RefCell, UnsafeCell}; use std::marker::PhantomData; use std::rc::Rc; @@ -55,7 +55,7 @@ impl Route { #[inline] pub(crate) fn compose( - &mut self, req: HttpRequest, mws: Rc>>>, + &mut self, req: HttpRequest, mws: Rc>>>>, ) -> AsyncResult { AsyncResult::async(Box::new(Compose::new(req, mws, self.handler.clone()))) } @@ -263,7 +263,7 @@ struct Compose { struct ComposeInfo { count: usize, req: HttpRequest, - mws: Rc>>>, + mws: Rc>>>>, handler: InnerHandler, } @@ -289,7 +289,7 @@ impl ComposeState { impl Compose { fn new( - req: HttpRequest, mws: Rc>>>, handler: InnerHandler, + req: HttpRequest, mws: Rc>>>>, handler: InnerHandler, ) -> Self { let mut info = ComposeInfo { count: 0, @@ -332,13 +332,14 @@ type Fut = Box, Error = Error>>; impl StartMiddlewares { fn init(info: &mut ComposeInfo) -> ComposeState { - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { if info.count == len { let reply = info.handler.handle(info.req.clone()); return WaitingResponse::init(info, reply); } else { - match info.mws[info.count].start(&mut info.req) { + let state = info.mws.borrow_mut()[info.count].start(&mut info.req); + match state { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { return RunMiddlewares::init(info, resp) @@ -356,7 +357,7 @@ impl StartMiddlewares { } fn poll(&mut self, info: &mut ComposeInfo) -> Option> { - let len = info.mws.len(); + let len = info.mws.borrow().len(); 'outer: loop { match self.fut.as_mut().unwrap().poll() { Ok(Async::NotReady) => return None, @@ -370,7 +371,8 @@ impl StartMiddlewares { let reply = info.handler.handle(info.req.clone()); return Some(WaitingResponse::init(info, reply)); } else { - match info.mws[info.count].start(&mut info.req) { + let state = info.mws.borrow_mut()[info.count].start(&mut info.req); + match state { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { return Some(RunMiddlewares::init(info, resp)); @@ -435,10 +437,11 @@ struct RunMiddlewares { impl RunMiddlewares { fn init(info: &mut ComposeInfo, mut resp: HttpResponse) -> ComposeState { let mut curr = 0; - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { - resp = match info.mws[curr].response(&mut info.req, resp) { + let state = info.mws.borrow_mut()[curr].response(&mut info.req, resp); + resp = match state { Err(err) => { info.count = curr + 1; return FinishingMiddlewares::init(info, err.into()); @@ -463,7 +466,7 @@ impl RunMiddlewares { } fn poll(&mut self, info: &mut ComposeInfo) -> Option> { - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { // poll latest fut @@ -480,7 +483,8 @@ impl RunMiddlewares { if self.curr == len { return Some(FinishingMiddlewares::init(info, resp)); } else { - match info.mws[self.curr].response(&mut info.req, resp) { + let state = info.mws.borrow_mut()[self.curr].response(&mut info.req, resp); + match state { Err(err) => { return Some(FinishingMiddlewares::init(info, err.into())) } @@ -548,9 +552,10 @@ impl FinishingMiddlewares { } info.count -= 1; - match info.mws[info.count as usize] - .finish(&mut info.req, self.resp.as_ref().unwrap()) - { + + let state = info.mws.borrow_mut()[info.count as usize] + .finish(&mut info.req, self.resp.as_ref().unwrap()); + match state { MiddlewareFinished::Done => { if info.count == 0 { return Some(Response::init(self.resp.take().unwrap())); diff --git a/src/scope.rs b/src/scope.rs index 9452191f5..8f8d69952 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,4 +1,4 @@ -use std::cell::UnsafeCell; +use std::cell::{RefCell, UnsafeCell}; use std::marker::PhantomData; use std::mem; use std::rc::Rc; @@ -56,7 +56,7 @@ type NestedInfo = (Resource, Route, Vec>>); pub struct Scope { filters: Vec>>, nested: Vec>, - middlewares: Rc>>>, + middlewares: Rc>>>>, default: Rc>>, resources: ScopeResources, } @@ -68,7 +68,7 @@ impl Scope { filters: Vec::new(), nested: Vec::new(), resources: Rc::new(Vec::new()), - middlewares: Rc::new(Vec::new()), + middlewares: Rc::new(RefCell::new(Vec::new())), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), } } @@ -132,7 +132,7 @@ impl Scope { filters: Vec::new(), nested: Vec::new(), resources: Rc::new(Vec::new()), - middlewares: Rc::new(Vec::new()), + middlewares: Rc::new(RefCell::new(Vec::new())), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), }; let mut scope = f(scope); @@ -175,7 +175,7 @@ impl Scope { filters: Vec::new(), nested: Vec::new(), resources: Rc::new(Vec::new()), - middlewares: Rc::new(Vec::new()), + middlewares: Rc::new(RefCell::new(Vec::new())), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), }; let mut scope = f(scope); @@ -312,6 +312,7 @@ impl Scope { pub fn middleware>(mut self, mw: M) -> Scope { Rc::get_mut(&mut self.middlewares) .expect("Can not use after configuration") + .borrow_mut() .push(Box::new(mw)); self } @@ -327,7 +328,7 @@ impl RouteHandler for Scope { let default = unsafe { &mut *self.default.as_ref().get() }; req.match_info_mut().remove("tail"); - if self.middlewares.is_empty() { + if self.middlewares.borrow().is_empty() { let resource = unsafe { &mut *resource.get() }; return resource.handle(req, Some(default)); } else { @@ -369,7 +370,7 @@ impl RouteHandler for Scope { // default handler let default = unsafe { &mut *self.default.as_ref().get() }; - if self.middlewares.is_empty() { + if self.middlewares.borrow().is_empty() { default.handle(req, None) } else { AsyncResult::async(Box::new(Compose::new( @@ -419,7 +420,7 @@ struct Compose { struct ComposeInfo { count: usize, req: HttpRequest, - mws: Rc>>>, + mws: Rc>>>>, default: Option>>>, resource: Rc>>, } @@ -446,7 +447,7 @@ impl ComposeState { impl Compose { fn new( - req: HttpRequest, mws: Rc>>>, + req: HttpRequest, mws: Rc>>>>, resource: Rc>>, default: Option>>>, ) -> Self { @@ -492,7 +493,7 @@ type Fut = Box, Error = Error>>; impl StartMiddlewares { fn init(info: &mut ComposeInfo) -> ComposeState { - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { if info.count == len { let resource = unsafe { &mut *info.resource.get() }; @@ -504,7 +505,8 @@ impl StartMiddlewares { }; return WaitingResponse::init(info, reply); } else { - match info.mws[info.count].start(&mut info.req) { + let state = info.mws.borrow_mut()[info.count].start(&mut info.req); + match state { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { return RunMiddlewares::init(info, resp) @@ -522,7 +524,7 @@ impl StartMiddlewares { } fn poll(&mut self, info: &mut ComposeInfo) -> Option> { - let len = info.mws.len(); + let len = info.mws.borrow().len(); 'outer: loop { match self.fut.as_mut().unwrap().poll() { Ok(Async::NotReady) => return None, @@ -542,7 +544,8 @@ impl StartMiddlewares { }; return Some(WaitingResponse::init(info, reply)); } else { - match info.mws[info.count].start(&mut info.req) { + let state = info.mws.borrow_mut()[info.count].start(&mut info.req); + match state { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { return Some(RunMiddlewares::init(info, resp)); @@ -602,10 +605,11 @@ struct RunMiddlewares { impl RunMiddlewares { fn init(info: &mut ComposeInfo, mut resp: HttpResponse) -> ComposeState { let mut curr = 0; - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { - resp = match info.mws[curr].response(&mut info.req, resp) { + let state = info.mws.borrow_mut()[curr].response(&mut info.req, resp); + resp = match state { Err(err) => { info.count = curr + 1; return FinishingMiddlewares::init(info, err.into()); @@ -630,7 +634,7 @@ impl RunMiddlewares { } fn poll(&mut self, info: &mut ComposeInfo) -> Option> { - let len = info.mws.len(); + let len = info.mws.borrow().len(); loop { // poll latest fut @@ -647,7 +651,8 @@ impl RunMiddlewares { if self.curr == len { return Some(FinishingMiddlewares::init(info, resp)); } else { - match info.mws[self.curr].response(&mut info.req, resp) { + let state = info.mws.borrow_mut()[self.curr].response(&mut info.req, resp); + match state { Err(err) => { return Some(FinishingMiddlewares::init(info, err.into())) } @@ -715,9 +720,9 @@ impl FinishingMiddlewares { } info.count -= 1; - match info.mws[info.count as usize] - .finish(&mut info.req, self.resp.as_ref().unwrap()) - { + let state = info.mws.borrow_mut()[info.count as usize] + .finish(&mut info.req, self.resp.as_ref().unwrap()); + match state { MiddlewareFinished::Done => { if info.count == 0 { return Some(Response::init(self.resp.take().unwrap())); diff --git a/tests/test_middleware.rs b/tests/test_middleware.rs index 4996542e2..3f802b3ae 100644 --- a/tests/test_middleware.rs +++ b/tests/test_middleware.rs @@ -19,21 +19,21 @@ struct MiddlewareTest { } impl middleware::Middleware for MiddlewareTest { - fn start(&self, _: &mut HttpRequest) -> Result { + fn start(&mut self, _: &mut HttpRequest) -> Result { self.start .store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed); Ok(middleware::Started::Done) } fn response( - &self, _: &mut HttpRequest, resp: HttpResponse, + &mut self, _: &mut HttpRequest, resp: HttpResponse, ) -> Result { self.response .store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed); Ok(middleware::Response::Done(resp)) } - fn finish(&self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { + fn finish(&mut self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { self.finish .store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed); middleware::Finished::Done @@ -431,7 +431,7 @@ struct MiddlewareAsyncTest { } impl middleware::Middleware for MiddlewareAsyncTest { - fn start(&self, _: &mut HttpRequest) -> Result { + fn start(&mut self, _: &mut HttpRequest) -> Result { let to = Delay::new(Instant::now() + Duration::from_millis(10)); let start = Arc::clone(&self.start); @@ -444,7 +444,7 @@ impl middleware::Middleware for MiddlewareAsyncTest { } fn response( - &self, _: &mut HttpRequest, resp: HttpResponse, + &mut self, _: &mut HttpRequest, resp: HttpResponse, ) -> Result { let to = Delay::new(Instant::now() + Duration::from_millis(10)); @@ -457,7 +457,7 @@ impl middleware::Middleware for MiddlewareAsyncTest { ))) } - fn finish(&self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { + fn finish(&mut self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { let to = Delay::new(Instant::now() + Duration::from_millis(10)); let finish = Arc::clone(&self.finish);