diff --git a/src/application.rs b/src/application.rs index 20dfa7901..4b5747a4d 100644 --- a/src/application.rs +++ b/src/application.rs @@ -9,6 +9,7 @@ use httprequest::HttpRequest; use httpresponse::HttpResponse; use middleware::Middleware; use pipeline::{HandlerType, Pipeline, PipelineHandler}; +use pred::Predicate; use resource::ResourceHandler; use router::{Resource, Router}; use scope::Scope; @@ -34,7 +35,7 @@ pub(crate) struct Inner { enum PrefixHandlerType { Handler(String, Box>), - Scope(Resource, Box>), + Scope(Resource, Box>, Vec>>), } impl PipelineHandler for Inner { @@ -51,7 +52,7 @@ impl PipelineHandler for Inner { } HandlerType::Handler(idx) => match self.handlers[idx] { PrefixHandlerType::Handler(_, ref mut hnd) => hnd.handle(req), - PrefixHandlerType::Scope(_, ref mut hnd) => hnd.handle(req), + PrefixHandlerType::Scope(_, ref mut hnd, _) => hnd.handle(req), }, HandlerType::Default => self.default.handle(req, None), } @@ -73,8 +74,8 @@ impl HttpApplication { let path: &'static str = unsafe { &*(&req.path()[inner.prefix..] as *const _) }; let path_len = path.len(); - for idx in 0..inner.handlers.len() { - match &inner.handlers[idx] { + 'outer: for idx in 0..inner.handlers.len() { + match inner.handlers[idx] { PrefixHandlerType::Handler(ref prefix, _) => { let m = { path.starts_with(prefix) @@ -96,10 +97,16 @@ impl HttpApplication { return HandlerType::Handler(idx); } } - PrefixHandlerType::Scope(ref pattern, _) => { + PrefixHandlerType::Scope(ref pattern, _, ref filters) => { if let Some(prefix_len) = pattern.match_prefix_with_params(path, req.match_info_mut()) { + for filter in filters { + if !filter.check(req) { + continue 'outer; + } + } + let prefix_len = inner.prefix + prefix_len - 1; let path: &'static str = unsafe { &*(&req.path()[prefix_len..] as *const _) }; @@ -361,7 +368,7 @@ where F: FnOnce(Scope) -> Scope, { { - let scope = Box::new(f(Scope::new())); + let mut scope = Box::new(f(Scope::new())); let mut path = path.trim().trim_right_matches('/').to_owned(); if !path.is_empty() && !path.starts_with('/') { @@ -372,9 +379,11 @@ where } let parts = self.parts.as_mut().expect("Use after finish"); + let filters = scope.take_filters(); parts.handlers.push(PrefixHandlerType::Scope( Resource::prefix("", &path), scope, + filters, )); } self diff --git a/src/scope.rs b/src/scope.rs index b9ee0e8e9..74009841f 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,5 +1,6 @@ use std::cell::UnsafeCell; use std::marker::PhantomData; +use std::mem; use std::rc::Rc; use futures::{Async, Future, Poll}; @@ -11,11 +12,13 @@ use httprequest::HttpRequest; use httpresponse::HttpResponse; use middleware::{Finished as MiddlewareFinished, Middleware, Response as MiddlewareResponse, Started as MiddlewareStarted}; +use pred::Predicate; use resource::ResourceHandler; use router::Resource; type Route = UnsafeCell>>; type ScopeResources = Rc>>)>>; +type NestedInfo = (Resource, Route, Vec>>); /// Resources scope /// @@ -25,8 +28,8 @@ type ScopeResources = Rc>>)>> /// Scope prefix is always complete path segment, i.e `/app` would /// be converted to a `/app/` and it would not match `/app` path. /// -/// You can use variable path segments with `Path` extractor, also variable -/// segments are available in `HttpRequest::match_info()`. +/// You can get variable path segments from HttpRequest::match_info()`. +/// `Path` extractor also is able to extract scope level variable segments. /// /// ```rust /// # extern crate actix_web; @@ -48,7 +51,8 @@ type ScopeResources = Rc>>)>> /// * /{project_id}/path3 - `HEAD` requests /// pub struct Scope { - nested: Vec<(Resource, Route)>, + filters: Vec>>, + nested: Vec>, middlewares: Rc>>>, default: Rc>>, resources: ScopeResources, @@ -63,6 +67,7 @@ impl Default for Scope { impl Scope { pub fn new() -> Scope { Scope { + filters: Vec::new(), nested: Vec::new(), resources: Rc::new(Vec::new()), middlewares: Rc::new(Vec::new()), @@ -70,6 +75,36 @@ impl Scope { } } + #[inline] + pub(crate) fn take_filters(&mut self) -> Vec>> { + mem::replace(&mut self.filters, Vec::new()) + } + + /// Add match predicate to scoupe. + /// + /// ```rust + /// # extern crate actix_web; + /// use actix_web::{http, pred, App, HttpRequest, HttpResponse, Path}; + /// + /// fn index(data: Path<(String, String)>) -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new() + /// .scope("/app", |scope| { + /// scope.filter(pred::Header("content-type", "text/plain")) + /// .route("/test1", http::Method::GET, index) + /// .route("/test2", http::Method::POST, + /// |_: HttpRequest| HttpResponse::MethodNotAllowed()) + /// }); + /// } + /// ``` + pub fn filter + 'static>(mut self, p: T) -> Self { + self.filters.push(Box::new(p)); + self + } + /// Create nested scope with new state. /// /// ```rust @@ -96,12 +131,13 @@ impl Scope { F: FnOnce(Scope) -> Scope, { let scope = Scope { + filters: Vec::new(), nested: Vec::new(), resources: Rc::new(Vec::new()), middlewares: Rc::new(Vec::new()), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), }; - let scope = f(scope); + let mut scope = f(scope); let mut path = path.trim().trim_right_matches('/').to_owned(); if !path.is_empty() && !path.starts_with('/') { @@ -111,12 +147,14 @@ impl Scope { path.push('/'); } - let handler = UnsafeCell::new(Box::new(Wrapper { - scope, - state: Rc::new(state), - })); + let state = Rc::new(state); + let filters: Vec>> = vec![Box::new(FiltersWrapper { + state: Rc::clone(&state), + filters: scope.take_filters(), + })]; + let handler = UnsafeCell::new(Box::new(Wrapper { scope, state })); self.nested - .push((Resource::prefix("", &path), handler)); + .push((Resource::prefix("", &path), handler, filters)); self } @@ -147,12 +185,13 @@ impl Scope { F: FnOnce(Scope) -> Scope, { let scope = Scope { + filters: Vec::new(), nested: Vec::new(), resources: Rc::new(Vec::new()), middlewares: Rc::new(Vec::new()), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), }; - let scope = f(scope); + let mut scope = f(scope); let mut path = path.trim().trim_right_matches('/').to_owned(); if !path.is_empty() && !path.starts_with('/') { @@ -162,9 +201,11 @@ impl Scope { path.push('/'); } + let filters = scope.take_filters(); self.nested.push(( Resource::prefix("", &path), UnsafeCell::new(Box::new(scope)), + filters, )); self @@ -315,10 +356,15 @@ impl RouteHandler for Scope { let len = req.prefix_len() as usize; let path: &'static str = unsafe { &*(&req.path()[len..] as *const _) }; - for &(ref prefix, ref handler) in &self.nested { + 'outer: for &(ref prefix, ref handler, ref filters) in &self.nested { if let Some(prefix_len) = prefix.match_prefix_with_params(path, req.match_info_mut()) { + for filter in filters { + if !filter.check(&mut req) { + continue 'outer; + } + } let prefix_len = len + prefix_len - 1; let path: &'static str = unsafe { &*(&req.path()[prefix_len..] as *const _) }; @@ -363,6 +409,23 @@ impl RouteHandler for Wrapper { } } +struct FiltersWrapper { + state: Rc, + filters: Vec>>, +} + +impl Predicate for FiltersWrapper { + fn check(&self, req: &mut HttpRequest) -> bool { + let mut req = req.change_state(Rc::clone(&self.state)); + for filter in &self.filters { + if !filter.check(&mut req) { + return false; + } + } + true + } +} + /// Compose resource level middlewares with route handler. struct Compose { info: ComposeInfo, @@ -713,8 +776,9 @@ mod tests { use application::App; use body::Body; - use http::StatusCode; + use http::{Method, StatusCode}; use httpresponse::HttpResponse; + use pred; use test::TestRequest; #[test] @@ -730,6 +794,29 @@ mod tests { assert_eq!(resp.as_msg().status(), StatusCode::OK); } + #[test] + fn test_scope_filter() { + let mut app = App::new() + .scope("/app", |scope| { + scope + .filter(pred::Get()) + .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) + }) + .finish(); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .finish(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::GET) + .finish(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + } + #[test] fn test_scope_variable_segment() { let mut app = App::new() @@ -748,7 +835,7 @@ mod tests { assert_eq!(resp.as_msg().status(), StatusCode::OK); match resp.as_msg().body() { - Body::Binary(ref b) => { + &Body::Binary(ref b) => { let bytes: Bytes = b.clone().into(); assert_eq!(bytes, Bytes::from_static(b"project: project1")); } @@ -777,6 +864,33 @@ mod tests { assert_eq!(resp.as_msg().status(), StatusCode::CREATED); } + #[test] + fn test_scope_with_state_filter() { + struct State; + + let mut app = App::new() + .scope("/app", |scope| { + scope.with_state("/t1", State, |scope| { + scope + .filter(pred::Get()) + .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) + }) + }) + .finish(); + + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::POST) + .finish(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::GET) + .finish(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + } + #[test] fn test_nested_scope() { let mut app = App::new() @@ -792,6 +906,31 @@ mod tests { assert_eq!(resp.as_msg().status(), StatusCode::CREATED); } + #[test] + fn test_nested_scope_filter() { + let mut app = App::new() + .scope("/app", |scope| { + scope.nested("/t1", |scope| { + scope + .filter(pred::Get()) + .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) + }) + }) + .finish(); + + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::POST) + .finish(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::GET) + .finish(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + } + #[test] fn test_nested_scope_with_variable_segment() { let mut app = App::new() @@ -814,7 +953,7 @@ mod tests { assert_eq!(resp.as_msg().status(), StatusCode::CREATED); match resp.as_msg().body() { - Body::Binary(ref b) => { + &Body::Binary(ref b) => { let bytes: Bytes = b.clone().into(); assert_eq!(bytes, Bytes::from_static(b"project: project_1")); } @@ -847,7 +986,7 @@ mod tests { assert_eq!(resp.as_msg().status(), StatusCode::CREATED); match resp.as_msg().body() { - Body::Binary(ref b) => { + &Body::Binary(ref b) => { let bytes: Bytes = b.clone().into(); assert_eq!(bytes, Bytes::from_static(b"project: test - 1")); }