From 984791187a1b8149124995a837f304a18de5b440 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 4 Jun 2018 13:42:47 -0700 Subject: [PATCH] Middleware::response is not invoked if error result was returned by another Middleware::start #255 --- src/middleware/errhandlers.rs | 27 +++++ src/pipeline.rs | 12 +- src/route.rs | 18 +-- src/scope.rs | 14 ++- tests/test_middleware.rs | 217 +++++++++++++++++++++++++++++++++- 5 files changed, 267 insertions(+), 21 deletions(-) diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index 7e56b368e..e1b484182 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -84,8 +84,12 @@ impl Middleware for ErrorHandlers { #[cfg(test)] mod tests { use super::*; + use error::{Error, ErrorInternalServerError}; use http::header::CONTENT_TYPE; use http::StatusCode; + use httpmessage::HttpMessage; + use middleware::Started; + use test; fn render_500(_: &mut HttpRequest, resp: HttpResponse) -> Result { let mut builder = resp.into_builder(); @@ -113,4 +117,27 @@ mod tests { }; assert!(!resp.headers().contains_key(CONTENT_TYPE)); } + + struct MiddlewareOne; + + impl Middleware for MiddlewareOne { + fn start(&mut self, _req: &mut HttpRequest) -> Result { + Err(ErrorInternalServerError("middleware error")) + } + } + + #[test] + fn test_middleware_start_error() { + let mut srv = test::TestServer::new(move |app| { + app.middleware( + ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), + ).middleware(MiddlewareOne) + .handler(|_| HttpResponse::Ok()) + }); + + let request = srv.get().finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } } diff --git a/src/pipeline.rs b/src/pipeline.rs index f9ec97137..8be7ee838 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -249,7 +249,8 @@ impl> StartMiddlewares { let reply = unsafe { &mut *hnd.get() }.handle(info.req().clone(), htype); return WaitingResponse::init(info, reply); } else { - let state = info.mws.borrow_mut()[info.count as usize].start(info.req_mut()); + let state = + info.mws.borrow_mut()[info.count as usize].start(info.req_mut()); match state { Ok(Started::Done) => info.count += 1, Ok(Started::Response(resp)) => { @@ -263,7 +264,7 @@ impl> StartMiddlewares { _s: PhantomData, }) } - Err(err) => return ProcessResponse::init(err.into()), + Err(err) => return RunMiddlewares::init(info, err.into()), } } } @@ -297,13 +298,13 @@ impl> StartMiddlewares { continue 'outer; } Err(err) => { - return Some(ProcessResponse::init(err.into())) + return Some(RunMiddlewares::init(info, err.into())) } } } } } - Err(err) => return Some(ProcessResponse::init(err.into())), + Err(err) => return Some(RunMiddlewares::init(info, err.into())), } } } @@ -403,7 +404,8 @@ impl RunMiddlewares { if self.curr == len { return Some(ProcessResponse::init(resp)); } else { - let state = info.mws.borrow_mut()[self.curr].response(info.req_mut(), resp); + let state = + info.mws.borrow_mut()[self.curr].response(info.req_mut(), resp); match state { Err(err) => return Some(ProcessResponse::init(err.into())), Ok(Response::Done(r)) => { diff --git a/src/route.rs b/src/route.rs index 27aae8df6..44ac82807 100644 --- a/src/route.rs +++ b/src/route.rs @@ -289,7 +289,8 @@ impl ComposeState { impl Compose { fn new( - req: HttpRequest, mws: Rc>>>>, handler: InnerHandler, + req: HttpRequest, mws: Rc>>>>, + handler: InnerHandler, ) -> Self { let mut info = ComposeInfo { count: 0, @@ -350,7 +351,7 @@ impl StartMiddlewares { _s: PhantomData, }) } - Err(err) => return FinishingMiddlewares::init(info, err.into()), + Err(err) => return RunMiddlewares::init(info, err.into()), } } } @@ -371,7 +372,8 @@ impl StartMiddlewares { let reply = info.handler.handle(info.req.clone()); return Some(WaitingResponse::init(info, reply)); } else { - let state = info.mws.borrow_mut()[info.count].start(&mut info.req); + let state = + info.mws.borrow_mut()[info.count].start(&mut info.req); match state { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { @@ -382,16 +384,13 @@ impl StartMiddlewares { continue 'outer; } Err(err) => { - return Some(FinishingMiddlewares::init( - info, - err.into(), - )) + return Some(RunMiddlewares::init(info, err.into())) } } } } } - Err(err) => return Some(FinishingMiddlewares::init(info, err.into())), + Err(err) => return Some(RunMiddlewares::init(info, err.into())), } } } @@ -483,7 +482,8 @@ impl RunMiddlewares { if self.curr == len { return Some(FinishingMiddlewares::init(info, resp)); } else { - let state = info.mws.borrow_mut()[self.curr].response(&mut info.req, resp); + let state = + info.mws.borrow_mut()[self.curr].response(&mut info.req, resp); match state { Err(err) => { return Some(FinishingMiddlewares::init(info, err.into())) diff --git a/src/scope.rs b/src/scope.rs index 69c46e484..a40113023 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -519,7 +519,7 @@ impl StartMiddlewares { _s: PhantomData, }) } - Err(err) => return Response::init(err.into()), + Err(err) => return RunMiddlewares::init(info, err.into()), } } } @@ -546,7 +546,8 @@ impl StartMiddlewares { }; return Some(WaitingResponse::init(info, reply)); } else { - let state = info.mws.borrow_mut()[info.count].start(&mut info.req); + let state = + info.mws.borrow_mut()[info.count].start(&mut info.req); match state { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { @@ -556,12 +557,14 @@ impl StartMiddlewares { self.fut = Some(fut); continue 'outer; } - Err(err) => return Some(Response::init(err.into())), + Err(err) => { + return Some(RunMiddlewares::init(info, err.into())) + } } } } } - Err(err) => return Some(Response::init(err.into())), + Err(err) => return Some(RunMiddlewares::init(info, err.into())), } } } @@ -653,7 +656,8 @@ impl RunMiddlewares { if self.curr == len { return Some(FinishingMiddlewares::init(info, resp)); } else { - let state = info.mws.borrow_mut()[self.curr].response(&mut info.req, resp); + let state = + info.mws.borrow_mut()[self.curr].response(&mut info.req, resp); match state { Err(err) => { return Some(FinishingMiddlewares::init(info, err.into())) diff --git a/tests/test_middleware.rs b/tests/test_middleware.rs index 3f802b3ae..d64e4feed 100644 --- a/tests/test_middleware.rs +++ b/tests/test_middleware.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use std::thread; use std::time::{Duration, Instant}; +use actix_web::error::{Error, ErrorInternalServerError}; use actix_web::*; use futures::{future, Future}; use tokio_timer::Delay; @@ -33,7 +34,9 @@ impl middleware::Middleware for MiddlewareTest { Ok(middleware::Response::Done(resp)) } - fn finish(&mut self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { + fn finish( + &mut self, _: &mut HttpRequest, _: &HttpResponse, + ) -> middleware::Finished { self.finish .store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed); middleware::Finished::Done @@ -457,7 +460,9 @@ impl middleware::Middleware for MiddlewareAsyncTest { ))) } - fn finish(&mut self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { + fn finish( + &mut self, _: &mut HttpRequest, _: &HttpResponse, + ) -> middleware::Finished { let to = Delay::new(Instant::now() + Duration::from_millis(10)); let finish = Arc::clone(&self.finish); @@ -788,3 +793,211 @@ fn test_async_sync_resource_middleware_multiple() { thread::sleep(Duration::from_millis(40)); assert_eq!(num3.load(Ordering::Relaxed), 2); } + +struct MiddlewareWithErr; + +impl middleware::Middleware for MiddlewareWithErr { + fn start( + &mut self, _req: &mut HttpRequest, + ) -> Result { + Err(ErrorInternalServerError("middleware error")) + } +} + +struct MiddlewareAsyncWithErr; + +impl middleware::Middleware for MiddlewareAsyncWithErr { + fn start( + &mut self, _req: &mut HttpRequest, + ) -> Result { + Ok(middleware::Started::Future(Box::new(future::err( + ErrorInternalServerError("middleware error"), + )))) + } +} + +#[test] +fn test_middleware_chain_with_error() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new() + .middleware(mw1) + .middleware(MiddlewareWithErr) + .resource("/test", |r| r.h(|_| HttpResponse::Ok())) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + + assert_eq!(num1.load(Ordering::Relaxed), 1); + assert_eq!(num2.load(Ordering::Relaxed), 1); + assert_eq!(num3.load(Ordering::Relaxed), 1); +} + +#[test] +fn test_middleware_async_chain_with_error() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new() + .middleware(mw1) + .middleware(MiddlewareAsyncWithErr) + .resource("/test", |r| r.h(|_| HttpResponse::Ok())) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + + assert_eq!(num1.load(Ordering::Relaxed), 1); + assert_eq!(num2.load(Ordering::Relaxed), 1); + assert_eq!(num3.load(Ordering::Relaxed), 1); +} + +#[test] +fn test_scope_middleware_chain_with_error() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new().scope("/scope", |scope| { + scope + .middleware(mw1) + .middleware(MiddlewareWithErr) + .resource("/test", |r| r.h(|_| HttpResponse::Ok())) + }) + }); + + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + + assert_eq!(num1.load(Ordering::Relaxed), 1); + assert_eq!(num2.load(Ordering::Relaxed), 1); + assert_eq!(num3.load(Ordering::Relaxed), 1); +} + +#[test] +fn test_scope_middleware_async_chain_with_error() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new().scope("/scope", |scope| { + scope + .middleware(mw1) + .middleware(MiddlewareAsyncWithErr) + .resource("/test", |r| r.h(|_| HttpResponse::Ok())) + }) + }); + + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + + assert_eq!(num1.load(Ordering::Relaxed), 1); + assert_eq!(num2.load(Ordering::Relaxed), 1); + assert_eq!(num3.load(Ordering::Relaxed), 1); +} + +#[test] +fn test_resource_middleware_chain_with_error() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new().resource("/test", move |r| { + r.middleware(mw1); + r.middleware(MiddlewareWithErr); + r.h(|_| HttpResponse::Ok()); + }) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + + assert_eq!(num1.load(Ordering::Relaxed), 1); + assert_eq!(num2.load(Ordering::Relaxed), 1); + assert_eq!(num3.load(Ordering::Relaxed), 1); +} + +#[test] +fn test_resource_middleware_async_chain_with_error() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new().resource("/test", move |r| { + r.middleware(mw1); + r.middleware(MiddlewareAsyncWithErr); + r.h(|_| HttpResponse::Ok()); + }) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + + assert_eq!(num1.load(Ordering::Relaxed), 1); + assert_eq!(num2.load(Ordering::Relaxed), 1); + assert_eq!(num3.load(Ordering::Relaxed), 1); +}