From b393ddf879b0d955b3bbea74291534680519850b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 16 May 2018 11:00:29 -0700 Subject: [PATCH] fix panic during middleware execution #226 --- CHANGES.md | 4 +- src/pipeline.rs | 12 +-- src/route.rs | 47 ++++++----- src/scope.rs | 35 ++++---- tests/test_middleware.rs | 169 ++++++++++++++++++++++++++++++--------- 5 files changed, 185 insertions(+), 82 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index f450c70d6..aa4d7c455 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,8 +1,8 @@ # Changes -## 0.6.6 (2018-05-xx) +## 0.6.6 (2018-05-16) -.. +* Panic during middleware execution #226 ## 0.6.5 (2018-05-15) diff --git a/src/pipeline.rs b/src/pipeline.rs index 4d5d405c0..82ec45a74 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -284,12 +284,12 @@ impl> StartMiddlewares { if let Some(resp) = resp { return Some(RunMiddlewares::init(info, resp)); } - if info.count == len { - let reply = unsafe { &mut *self.hnd.get() } - .handle(info.req().clone(), self.htype); - return Some(WaitingResponse::init(info, reply)); - } else { - loop { + loop { + if info.count == len { + let reply = unsafe { &mut *self.hnd.get() } + .handle(info.req().clone(), self.htype); + return Some(WaitingResponse::init(info, reply)); + } else { match info.mws[info.count as usize].start(info.req_mut()) { Ok(Started::Done) => info.count += 1, Ok(Started::Response(resp)) => { diff --git a/src/route.rs b/src/route.rs index 1322d1087..b109fd609 100644 --- a/src/route.rs +++ b/src/route.rs @@ -5,13 +5,17 @@ use std::rc::Rc; use futures::{Async, Future, Poll}; use error::Error; -use handler::{AsyncHandler, AsyncResult, AsyncResultItem, FromRequest, Handler, - Responder, RouteHandler, WrapHandler}; +use handler::{ + AsyncHandler, AsyncResult, AsyncResultItem, FromRequest, Handler, Responder, + RouteHandler, WrapHandler, +}; use http::StatusCode; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use middleware::{Finished as MiddlewareFinished, Middleware, - Response as MiddlewareResponse, Started as MiddlewareStarted}; +use middleware::{ + Finished as MiddlewareFinished, Middleware, Response as MiddlewareResponse, + Started as MiddlewareStarted, +}; use pred::Predicate; use with::{ExtractorConfig, With, With2, With3, WithAsync}; @@ -51,7 +55,9 @@ impl Route { #[inline] pub(crate) fn compose( - &mut self, req: HttpRequest, mws: Rc>>>, + &mut self, + req: HttpRequest, + mws: Rc>>>, ) -> AsyncResult { AsyncResult::async(Box::new(Compose::new(req, mws, self.handler.clone()))) } @@ -242,7 +248,8 @@ impl Route { /// } /// ``` pub fn with2( - &mut self, handler: F, + &mut self, + handler: F, ) -> (ExtractorConfig, ExtractorConfig) where F: Fn(T1, T2) -> R + 'static, @@ -263,7 +270,8 @@ impl Route { #[doc(hidden)] /// Set handler function, use request extractor for all parameters. pub fn with3( - &mut self, handler: F, + &mut self, + handler: F, ) -> ( ExtractorConfig, ExtractorConfig, @@ -296,9 +304,7 @@ struct InnerHandler(Rc>>>); impl InnerHandler { #[inline] fn new>(h: H) -> Self { - InnerHandler(Rc::new(UnsafeCell::new(Box::new(WrapHandler::new( - h, - ))))) + InnerHandler(Rc::new(UnsafeCell::new(Box::new(WrapHandler::new(h))))) } #[inline] @@ -309,9 +315,7 @@ impl InnerHandler { R: Responder + 'static, E: Into + 'static, { - InnerHandler(Rc::new(UnsafeCell::new(Box::new(AsyncHandler::new( - h, - ))))) + InnerHandler(Rc::new(UnsafeCell::new(Box::new(AsyncHandler::new(h))))) } #[inline] @@ -364,7 +368,9 @@ impl ComposeState { impl Compose { fn new( - req: HttpRequest, mws: Rc>>>, handler: InnerHandler, + req: HttpRequest, + mws: Rc>>>, + handler: InnerHandler, ) -> Self { let mut info = ComposeInfo { count: 0, @@ -440,11 +446,11 @@ impl StartMiddlewares { if let Some(resp) = resp { return Some(RunMiddlewares::init(info, resp)); } - if info.count == len { - let reply = info.handler.handle(info.req.clone()); - return Some(WaitingResponse::init(info, reply)); - } else { - loop { + loop { + if info.count == len { + let reply = info.handler.handle(info.req.clone()); + return Some(WaitingResponse::init(info, reply)); + } else { match info.mws[info.count].start(&mut info.req) { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { @@ -479,7 +485,8 @@ struct WaitingResponse { impl WaitingResponse { #[inline] fn init( - info: &mut ComposeInfo, reply: AsyncResult, + info: &mut ComposeInfo, + reply: AsyncResult, ) -> ComposeState { match reply.into() { AsyncResultItem::Err(err) => RunMiddlewares::init(info, err.into()), diff --git a/src/scope.rs b/src/scope.rs index aecfd6bfa..00bcadad9 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -10,8 +10,10 @@ use handler::{AsyncResult, AsyncResultItem, FromRequest, Responder, RouteHandler use http::Method; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use middleware::{Finished as MiddlewareFinished, Middleware, - Response as MiddlewareResponse, Started as MiddlewareStarted}; +use middleware::{ + Finished as MiddlewareFinished, Middleware, Response as MiddlewareResponse, + Started as MiddlewareStarted, +}; use pred::Predicate; use resource::ResourceHandler; use router::Resource; @@ -400,8 +402,7 @@ struct Wrapper { impl RouteHandler for Wrapper { fn handle(&mut self, req: HttpRequest) -> AsyncResult { - self.scope - .handle(req.change_state(Rc::clone(&self.state))) + self.scope.handle(req.change_state(Rc::clone(&self.state))) } } @@ -458,7 +459,8 @@ impl ComposeState { impl Compose { fn new( - req: HttpRequest, mws: Rc>>>, + req: HttpRequest, + mws: Rc>>>, resource: Rc>>, default: Option>>>, ) -> Self { @@ -543,17 +545,17 @@ impl StartMiddlewares { if let Some(resp) = resp { return Some(RunMiddlewares::init(info, resp)); } - if info.count == len { - let resource = unsafe { &mut *info.resource.get() }; - let reply = if let Some(ref default) = info.default { - let d = unsafe { &mut *default.as_ref().get() }; - resource.handle(info.req.clone(), Some(d)) + loop { + if info.count == len { + let resource = unsafe { &mut *info.resource.get() }; + let reply = if let Some(ref default) = info.default { + let d = unsafe { &mut *default.as_ref().get() }; + resource.handle(info.req.clone(), Some(d)) + } else { + resource.handle(info.req.clone(), None) + }; + return Some(WaitingResponse::init(info, reply)); } else { - resource.handle(info.req.clone(), None) - }; - return Some(WaitingResponse::init(info, reply)); - } else { - loop { match info.mws[info.count].start(&mut info.req) { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { @@ -583,7 +585,8 @@ struct WaitingResponse { impl WaitingResponse { #[inline] fn init( - info: &mut ComposeInfo, reply: AsyncResult, + info: &mut ComposeInfo, + reply: AsyncResult, ) -> ComposeState { match reply.into() { AsyncResultItem::Ok(resp) => RunMiddlewares::init(info, resp), diff --git a/tests/test_middleware.rs b/tests/test_middleware.rs index 99151afd7..2c9160b61 100644 --- a/tests/test_middleware.rs +++ b/tests/test_middleware.rs @@ -21,28 +21,24 @@ struct MiddlewareTest { impl middleware::Middleware for MiddlewareTest { fn start(&self, _: &mut HttpRequest) -> Result { - self.start.store( - self.start.load(Ordering::Relaxed) + 1, - Ordering::Relaxed, - ); + self.start + .store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed); Ok(middleware::Started::Done) } fn response( - &self, _: &mut HttpRequest, resp: HttpResponse, + &self, + _: &mut HttpRequest, + resp: HttpResponse, ) -> Result { - self.response.store( - self.response.load(Ordering::Relaxed) + 1, - Ordering::Relaxed, - ); + self.response + .store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed); Ok(middleware::Response::Done(resp)) } fn finish(&self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { - self.finish.store( - self.finish.load(Ordering::Relaxed) + 1, - Ordering::Relaxed, - ); + self.finish + .store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed); middleware::Finished::Done } } @@ -187,10 +183,7 @@ fn test_scope_middleware() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -226,10 +219,7 @@ fn test_scope_middleware_multiple() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -337,10 +327,7 @@ fn test_scope_middleware_async_handler() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -402,10 +389,7 @@ fn test_scope_middleware_async_error() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); @@ -466,7 +450,9 @@ impl middleware::Middleware for MiddlewareAsyncTest { } fn response( - &self, _: &mut HttpRequest, resp: HttpResponse, + &self, + _: &mut HttpRequest, + resp: HttpResponse, ) -> Result { let to = Timeout::new(Duration::from_millis(10), &Arbiter::handle()).unwrap(); @@ -555,6 +541,42 @@ fn test_async_middleware_multiple() { assert_eq!(num3.load(Ordering::Relaxed), 2); } +#[test] +fn test_async_sync_middleware_multiple() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + App::new() + .middleware(MiddlewareAsyncTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .middleware(MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + assert_eq!(num1.load(Ordering::Relaxed), 2); + assert_eq!(num2.load(Ordering::Relaxed), 2); + + thread::sleep(Duration::from_millis(50)); + assert_eq!(num3.load(Ordering::Relaxed), 2); +} + #[test] fn test_async_scope_middleware() { let num1 = Arc::new(AtomicUsize::new(0)); @@ -577,10 +599,7 @@ fn test_async_scope_middleware() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -618,10 +637,45 @@ fn test_async_scope_middleware_multiple() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + assert_eq!(num1.load(Ordering::Relaxed), 2); + assert_eq!(num2.load(Ordering::Relaxed), 2); + + thread::sleep(Duration::from_millis(20)); + assert_eq!(num3.load(Ordering::Relaxed), 2); +} + +#[test] +fn test_async_async_scope_middleware_multiple() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + App::new().scope("/scope", |scope| { + scope + .middleware(MiddlewareAsyncTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .middleware(MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }) + }); + + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -703,3 +757,42 @@ fn test_async_resource_middleware_multiple() { thread::sleep(Duration::from_millis(40)); assert_eq!(num3.load(Ordering::Relaxed), 2); } + +#[test] +fn test_async_sync_resource_middleware_multiple() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareAsyncTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + let mw2 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new().resource("/test", move |r| { + r.middleware(mw1); + r.middleware(mw2); + r.h(|_| HttpResponse::Ok()); + }) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + assert_eq!(num1.load(Ordering::Relaxed), 2); + assert_eq!(num2.load(Ordering::Relaxed), 2); + + thread::sleep(Duration::from_millis(40)); + assert_eq!(num3.load(Ordering::Relaxed), 2); +}