From 4616ca8ee65ca1e14e784b8b46dc1a1306b422a5 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Tue, 28 Dec 2021 02:37:13 +0000 Subject: [PATCH] rework `Guard` trait (#2552) --- CHANGES.md | 11 + actix-files/src/service.rs | 8 +- actix-http/src/message.rs | 6 +- src/app_service.rs | 13 +- src/dev.rs | 23 +- src/guard.rs | 742 +++++++++++++++++++++--------------- src/middleware/normalize.rs | 12 +- src/route.rs | 38 +- src/scope.rs | 7 +- src/service.rs | 34 +- src/test/test_request.rs | 2 +- 11 files changed, 545 insertions(+), 351 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index e1317a7a..f925f3b9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,17 @@ # Changes ## Unreleased - 2021-xx-xx +### Added +- `guard::GuardContext` for use with the `Guard` trait. [#2552] +- `ServiceRequest::guard_ctx` for obtaining a guard context. [#2552] + +### Changed +- `Guard` trait now receives a `&GuardContext`. [#2552] +- `guard::fn_guard` functions now receives a `&GuardContext`. [#2552] +- Some guards now return `impl Guard` and their concrete types are made private: `guard::{Header}` and all the method guards. [#2552] +- The `Not` guard is now generic over the type of guard it wraps. [#2552] + +[#2552]: https://github.com/actix/actix-web/pull/2552 ## 4.0.0-beta.16 - 2021-12-27 diff --git a/actix-files/src/service.rs b/actix-files/src/service.rs index f6e1c2e1..057dbe5a 100644 --- a/actix-files/src/service.rs +++ b/actix-files/src/service.rs @@ -1,8 +1,8 @@ use std::{fmt, io, ops::Deref, path::PathBuf, rc::Rc}; -use actix_service::Service; use actix_web::{ - dev::{ServiceRequest, ServiceResponse}, + body::BoxBody, + dev::{Service, ServiceRequest, ServiceResponse}, error::Error, guard::Guard, http::{header, Method}, @@ -94,7 +94,7 @@ impl fmt::Debug for FilesService { } impl Service for FilesService { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; @@ -103,7 +103,7 @@ impl Service for FilesService { fn call(&self, req: ServiceRequest) -> Self::Future { let is_method_valid = if let Some(guard) = &self.guards { // execute user defined guards - (**guard).check(req.head()) + (**guard).check(&req.guard_ctx()) } else { // default behavior matches!(*req.method(), Method::HEAD | Method::GET) diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index 34213f68..ecd08fbb 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, ops, rc::Rc}; use bitflags::bitflags; @@ -49,7 +49,7 @@ impl Message { } } -impl std::ops::Deref for Message { +impl ops::Deref for Message { type Target = T; fn deref(&self) -> &Self::Target { @@ -57,7 +57,7 @@ impl std::ops::Deref for Message { } } -impl std::ops::DerefMut for Message { +impl ops::DerefMut for Message { fn deref_mut(&mut self) -> &mut Self::Target { Rc::get_mut(&mut self.head).expect("Multiple copies exist") } diff --git a/src/app_service.rs b/src/app_service.rs index e0d42439..56b24f0d 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,14 +1,16 @@ use std::{cell::RefCell, mem, rc::Rc}; -use actix_http::{Extensions, Request}; +use actix_http::Request; use actix_router::{Path, ResourceDef, Router, Url}; use actix_service::{boxed, fn_service, Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; use crate::{ + body::BoxBody, config::{AppConfig, AppService}, data::FnDataFactory, + dev::Extensions, guard::Guard, request::{HttpRequest, HttpRequestPool}, rmap::ResourceMap, @@ -297,7 +299,7 @@ pub struct AppRouting { } impl Service for AppRouting { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; @@ -306,12 +308,15 @@ impl Service for AppRouting { fn call(&self, mut req: ServiceRequest) -> Self::Future { let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { - for f in guards { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in guards { + if !guard.check(&guard_ctx) { return false; } } } + true }); diff --git a/src/dev.rs b/src/dev.rs index 6e197046..bb1385bd 100644 --- a/src/dev.rs +++ b/src/dev.rs @@ -3,6 +3,16 @@ //! Most users will not have to interact with the types in this module, but it is useful for those //! writing extractors, middleware, libraries, or interacting with the service API directly. +pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead}; +pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; +pub use actix_server::{Server, ServerHandle}; +pub use actix_service::{ + always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform, +}; + +#[cfg(feature = "__compress")] +pub use actix_http::encoding::Decoder as Decompress; + pub use crate::config::{AppConfig, AppService}; #[doc(hidden)] pub use crate::handler::Handler; @@ -14,16 +24,6 @@ pub use crate::types::form::UrlEncoded; pub use crate::types::json::JsonBody; pub use crate::types::readlines::Readlines; -pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead}; -pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; -pub use actix_server::{Server, ServerHandle}; -pub use actix_service::{ - always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform, -}; - -#[cfg(feature = "__compress")] -pub use actix_http::encoding::Decoder as Decompress; - use crate::http::header::ContentEncoding; use actix_router::Patterns; @@ -46,7 +46,6 @@ pub(crate) fn ensure_leading_slash(mut patterns: Patterns) -> Patterns { patterns } -struct Enc(ContentEncoding); /// Helper trait that allows to set specific encoding for response. pub trait BodyEncoding { @@ -70,6 +69,8 @@ impl BodyEncoding for actix_http::ResponseBuilder { } } +struct Enc(ContentEncoding); + impl BodyEncoding for actix_http::Response { fn get_encoding(&self) -> Option { self.extensions().get::().map(|enc| enc.0) diff --git a/src/guard.rs b/src/guard.rs index db7f0698..fb3e4f24 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,160 +1,248 @@ -//! Route match guards. +//! Route guards. //! -//! Guards are one of the ways how actix-web router chooses a -//! handler service. In essence it is just a function that accepts a -//! reference to a `RequestHead` instance and returns a boolean. -//! It is possible to add guards to *scopes*, *resources* -//! and *routes*. Actix provide several guards by default, like various -//! http methods, header, etc. To become a guard, type must implement `Guard` -//! trait. Simple functions could be guards as well. +//! Guards are used during routing to help select a matching service or handler using some aspect of +//! the request; though guards should not be used for path matching since it is a built-in function +//! of the Actix Web router. //! -//! Guards can not modify the request object. But it is possible -//! to store extra attributes on a request by using the `Extensions` container. -//! Extensions containers are available via the `RequestHead::extensions()` method. +//! Guards can be used on [`Scope`]s, [`Resource`]s, [`Route`]s, and other custom services. //! +//! Fundamentally, a guard is a predicate function that receives a reference to a request context +//! object and returns a boolean; true if the request _should_ be handled by the guarded service +//! or handler. This interface is defined by the [`Guard`] trait. +//! +//! Commonly-used guards are provided in this module as well as way of creating a guard from a +//! closure ([`fn_guard`]). The [`Not`], [`Any`], and [`All`] guards are noteworthy, as they can be +//! used to compose other guards in a more flexible and semantic way than calling `.guard(...)` on +//! services multiple times (which might have different combining behavior than you want). +//! +//! Guards can not modify anything about the request. However, it is possible to store extra +//! attributes in the request-local data container obtained with [`GuardContext::req_data_mut`]. +//! +//! Guards can prevent resource definitions from overlapping (resulting in some inaccessible routes) +//! where they otherwise would when only considering paths. See the virtual hosting example below. +//! +//! # Examples +//! In the following code, the `/guarded` resource has one defined route whose handler will only be +//! called if the request method is `POST` and there is a request header with name and value equal +//! to `x-guarded` and `secret`, respectively. //! ``` -//! use actix_web::{web, http, dev, guard, App, HttpResponse}; +//! use actix_web::{web, http::Method, guard, HttpResponse}; //! -//! App::new().service(web::resource("/index.html").route( +//! web::resource("/guarded").route( //! web::route() -//! .guard(guard::Post()) -//! .guard(guard::fn_guard(|head| head.method == http::Method::GET)) -//! .to(|| HttpResponse::MethodNotAllowed())) +//! .guard(guard::Any(guard::Get()).or(guard::Post())) +//! .guard(guard::Header("x-guarded", "secret")) +//! .to(|| HttpResponse::Ok()) //! ); //! ``` +//! +//! Guards can be used to set up some form of [virtual hosting] within a single app. +//! Overlapping scope prefixes are usually discouraged, but when combined with non-overlapping guard +//! definitions they become safe to use in this way. Without these host guards, only routes under +//! the first-to-be-defined scope would be accessible. You can test this locally using `127.0.0.1` +//! and `localhost` as the `Host` guards. +//! ``` +//! use actix_web::{web, http::Method, guard, App, HttpResponse}; +//! +//! App::new() +//! .service( +//! web::scope("") +//! .guard(guard::Host("www.rust-lang.org")) +//! .default_service(web::to(|| HttpResponse::Ok().body("marketing site"))), +//! ) +//! .service( +//! web::scope("") +//! .guard(guard::Host("play.rust-lang.org")) +//! .default_service(web::to(|| HttpResponse::Ok().body("playground frontend"))), +//! ); +//! ``` +//! +//! [`Scope`]: crate::Scope::guard() +//! [`Resource`]: crate::Resource::guard() +//! [`Route`]: crate::Route::guard() +//! [virtual hosting]: https://en.wikipedia.org/wiki/Virtual_hosting -#![allow(non_snake_case)] +use std::{ + cell::{Ref, RefMut}, + convert::TryFrom, + rc::Rc, +}; -use std::rc::Rc; -use std::{convert::TryFrom, ops::Deref}; +use actix_http::{header, uri::Uri, Extensions, Method as HttpMethod, RequestHead}; -use actix_http::{header, uri::Uri, Method as HttpMethod, RequestHead}; +use crate::service::ServiceRequest; -/// Trait defines resource guards. Guards are used for route selection. -/// -/// Guards can not modify the request object. But it is possible -/// to store extra attributes on a request by using the `Extensions` container. -/// Extensions containers are available via the `RequestHead::extensions()` method. -pub trait Guard { - /// Check if request matches predicate - fn check(&self, request: &RequestHead) -> bool; +/// Provides access to request parts that are useful during routing. +#[derive(Debug)] +pub struct GuardContext<'a> { + pub(crate) req: &'a ServiceRequest, } -impl Guard for Rc { - fn check(&self, request: &RequestHead) -> bool { - self.deref().check(request) +impl<'a> GuardContext<'a> { + /// Returns reference to the request head. + #[inline] + pub fn head(&self) -> &RequestHead { + self.req.head() + } + + /// Returns reference to the request-local data container. + #[inline] + pub fn req_data(&self) -> Ref<'a, Extensions> { + self.req.req_data() + } + + /// Returns mutable reference to the request-local data container. + #[inline] + pub fn req_data_mut(&self) -> RefMut<'a, Extensions> { + self.req.req_data_mut() } } -/// Create guard object for supplied function. +/// Interface for routing guards. /// +/// See [module level documentation](self) for more. +pub trait Guard { + /// Returns true if predicate condition is met for a given request. + fn check(&self, ctx: &GuardContext<'_>) -> bool; +} + +impl Guard for Rc { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (**self).check(ctx) + } +} + +/// Creates a guard using the given function. +/// +/// # Examples /// ``` -/// use actix_web::{guard, web, App, HttpResponse}; +/// use actix_web::{guard, web, HttpResponse}; /// -/// App::new().service(web::resource("/index.html").route( -/// web::route() -/// .guard( -/// guard::fn_guard( -/// |req| req.headers() -/// .contains_key("content-type"))) -/// .to(|| HttpResponse::MethodNotAllowed())) -/// ); +/// web::route() +/// .guard(guard::fn_guard(|ctx| { +/// ctx.head().headers().contains_key("content-type") +/// })) +/// .to(|| HttpResponse::Ok()); /// ``` pub fn fn_guard(f: F) -> impl Guard where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { FnGuard(f) } -struct FnGuard bool>(F); +struct FnGuard) -> bool>(F); impl Guard for FnGuard where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { - fn check(&self, head: &RequestHead) -> bool { - (self.0)(head) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (self.0)(ctx) } } impl Guard for F where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { - fn check(&self, head: &RequestHead) -> bool { - (self)(head) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (self)(ctx) } } -/// Return guard that matches if any of supplied guards. +/// Creates a guard that matches if any added guards match. /// +/// # Examples +/// The handler below will be called for either request method `GET` or `POST`. /// ``` -/// use actix_web::{web, guard, App, HttpResponse}; +/// use actix_web::{web, guard, HttpResponse}; /// -/// App::new().service(web::resource("/index.html").route( -/// web::route() -/// .guard(guard::Any(guard::Get()).or(guard::Post())) -/// .to(|| HttpResponse::MethodNotAllowed())) -/// ); +/// web::route() +/// .guard( +/// guard::Any(guard::Get()) +/// .or(guard::Post())) +/// .to(|| HttpResponse::Ok()); /// ``` +#[allow(non_snake_case)] pub fn Any(guard: F) -> AnyGuard { - AnyGuard(vec![Box::new(guard)]) + AnyGuard { + guards: vec![Box::new(guard)], + } } -/// Matches any of supplied guards. -pub struct AnyGuard(Vec>); +/// A collection of guards that match if the disjunction of their `check` outcomes is true. +/// +/// That is, only one contained guard needs to match in order for the aggregate guard to match. +/// +/// Construct an `AnyGuard` using [`Any`]. +pub struct AnyGuard { + guards: Vec>, +} impl AnyGuard { - /// Add guard to a list of guards to check + /// Adds new guard to the collection of guards to check. pub fn or(mut self, guard: F) -> Self { - self.0.push(Box::new(guard)); + self.guards.push(Box::new(guard)); self } } impl Guard for AnyGuard { - fn check(&self, req: &RequestHead) -> bool { - for p in &self.0 { - if p.check(req) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + for guard in &self.guards { + if guard.check(ctx) { return true; } } + false } } -/// Return guard that matches if all of the supplied guards. +/// Creates a guard that matches if all added guards match. /// +/// # Examples +/// The handler below will only be called if the request method is `GET` **and** the specified +/// header name and value match exactly. /// ``` -/// use actix_web::{guard, web, App, HttpResponse}; +/// use actix_web::{guard, web, HttpResponse}; /// -/// App::new().service(web::resource("/index.html").route( -/// web::route() -/// .guard( -/// guard::All(guard::Get()).and(guard::Header("content-type", "text/plain"))) -/// .to(|| HttpResponse::MethodNotAllowed())) -/// ); +/// web::route() +/// .guard( +/// guard::All(guard::Get()) +/// .and(guard::Header("accept", "text/plain")) +/// ) +/// .to(|| HttpResponse::Ok()); /// ``` +#[allow(non_snake_case)] pub fn All(guard: F) -> AllGuard { - AllGuard(vec![Box::new(guard)]) + AllGuard { + guards: vec![Box::new(guard)], + } } -/// Matches if all of supplied guards. -pub struct AllGuard(Vec>); +/// A collection of guards that match if the conjunction of their `check` outcomes is true. +/// +/// That is, **all** contained guard needs to match in order for the aggregate guard to match. +/// +/// Construct an `AllGuard` using [`All`]. +pub struct AllGuard { + guards: Vec>, +} impl AllGuard { - /// Add new guard to the list of guards to check + /// Adds new guard to the collection of guards to check. pub fn and(mut self, guard: F) -> Self { - self.0.push(Box::new(guard)); + self.guards.push(Box::new(guard)); self } } impl Guard for AllGuard { - fn check(&self, request: &RequestHead) -> bool { - for p in &self.0 { - if !p.check(request) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + for guard in &self.guards { + if !guard.check(ctx) { return false; } } @@ -162,159 +250,189 @@ impl Guard for AllGuard { } } -/// Return guard that matches if supplied guard does not match. -pub fn Not(guard: F) -> NotGuard { - NotGuard(Box::new(guard)) -} +/// Wraps a guard and inverts the outcome of it's `Guard` implementation. +/// +/// # Examples +/// The handler below will be called for any request method apart from `GET`. +/// ``` +/// use actix_web::{guard, web, HttpResponse}; +/// +/// web::route() +/// .guard(guard::Not(guard::Get())) +/// .to(|| HttpResponse::Ok()); +/// ``` +pub struct Not(pub G); -#[doc(hidden)] -pub struct NotGuard(Box); - -impl Guard for NotGuard { - fn check(&self, request: &RequestHead) -> bool { - !self.0.check(request) +impl Guard for Not { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + !self.0.check(ctx) } } -/// HTTP method guard. -#[doc(hidden)] -pub struct MethodGuard(HttpMethod); - -impl Guard for MethodGuard { - fn check(&self, request: &RequestHead) -> bool { - request.method == self.0 - } -} - -/// Guard to match *GET* HTTP method. -pub fn Get() -> MethodGuard { - MethodGuard(HttpMethod::GET) -} - -/// Predicate to match *POST* HTTP method. -pub fn Post() -> MethodGuard { - MethodGuard(HttpMethod::POST) -} - -/// Predicate to match *PUT* HTTP method. -pub fn Put() -> MethodGuard { - MethodGuard(HttpMethod::PUT) -} - -/// Predicate to match *DELETE* HTTP method. -pub fn Delete() -> MethodGuard { - MethodGuard(HttpMethod::DELETE) -} - -/// Predicate to match *HEAD* HTTP method. -pub fn Head() -> MethodGuard { - MethodGuard(HttpMethod::HEAD) -} - -/// Predicate to match *OPTIONS* HTTP method. -pub fn Options() -> MethodGuard { - MethodGuard(HttpMethod::OPTIONS) -} - -/// Predicate to match *CONNECT* HTTP method. -pub fn Connect() -> MethodGuard { - MethodGuard(HttpMethod::CONNECT) -} - -/// Predicate to match *PATCH* HTTP method. -pub fn Patch() -> MethodGuard { - MethodGuard(HttpMethod::PATCH) -} - -/// Predicate to match *TRACE* HTTP method. -pub fn Trace() -> MethodGuard { - MethodGuard(HttpMethod::TRACE) -} - -/// Predicate to match specified HTTP method. -pub fn Method(method: HttpMethod) -> MethodGuard { +/// Creates a guard that matches a specified HTTP method. +#[allow(non_snake_case)] +pub fn Method(method: HttpMethod) -> impl Guard { MethodGuard(method) } -/// Return predicate that matches if request contains specified header and -/// value. -pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard { +/// HTTP method guard. +struct MethodGuard(HttpMethod); + +impl Guard for MethodGuard { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + ctx.head().method == self.0 + } +} + +macro_rules! method_guard { + ($method_fn:ident, $method_const:ident) => { + paste::paste! { + #[doc = " Creates a guard that matches the `" $method_const "` request method."] + /// + /// # Examples + #[doc = " The route in this example will only respond to `" $method_const "` requests."] + /// ``` + /// use actix_web::{guard, web, HttpResponse}; + /// + /// web::route() + #[doc = " .guard(guard::" $method_fn "())"] + /// .to(|| HttpResponse::Ok()); + /// ``` + #[allow(non_snake_case)] + pub fn $method_fn() -> impl Guard { + MethodGuard(HttpMethod::$method_const) + } + } + }; +} + +method_guard!(Get, GET); +method_guard!(Post, POST); +method_guard!(Put, PUT); +method_guard!(Delete, DELETE); +method_guard!(Head, HEAD); +method_guard!(Options, OPTIONS); +method_guard!(Connect, CONNECT); +method_guard!(Patch, PATCH); +method_guard!(Trace, TRACE); + +/// Creates a guard that matches if request contains given header name and value. +/// +/// # Examples +/// The handler below will be called when the request contains an `x-guarded` header with value +/// equal to `secret`. +/// ``` +/// use actix_web::{guard, web, HttpResponse}; +/// +/// web::route() +/// .guard(guard::Header("x-guarded", "secret")) +/// .to(|| HttpResponse::Ok()); +/// ``` +#[allow(non_snake_case)] +pub fn Header(name: &'static str, value: &'static str) -> impl Guard { HeaderGuard( header::HeaderName::try_from(name).unwrap(), header::HeaderValue::from_static(value), ) } -#[doc(hidden)] -pub struct HeaderGuard(header::HeaderName, header::HeaderValue); +struct HeaderGuard(header::HeaderName, header::HeaderValue); impl Guard for HeaderGuard { - fn check(&self, req: &RequestHead) -> bool { - if let Some(val) = req.headers.get(&self.0) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + if let Some(val) = ctx.head().headers.get(&self.0) { return val == self.1; } + false } } -/// Return predicate that matches if request contains specified Host name. +/// Creates a guard that matches requests targetting a specific host. /// -/// ``` -/// use actix_web::{web, guard::Host, App, HttpResponse}; +/// # Matching Host +/// This guard will: +/// - match against the `Host` header, if present; +/// - fall-back to matching against the request target's host, if present; +/// - return false if host cannot be determined; /// -/// App::new().service( -/// web::resource("/index.html") -/// .guard(Host("www.rust-lang.org")) -/// .to(|| HttpResponse::MethodNotAllowed()) -/// ); +/// # Matching Scheme +/// Optionally, this guard can match against the host's scheme. Set the scheme for matching using +/// `Host(host).scheme(protocol)`. If the request's scheme cannot be determined, it will not prevent +/// the guard from matching successfully. +/// +/// # Examples +/// The [module-level documentation](self) has an example of virtual hosting using `Host` guards. +/// +/// The example below additionally guards on the host URI's scheme. This could allow routing to +/// different handlers for `http:` vs `https:` visitors; to redirect, for example. /// ``` -pub fn Host>(host: H) -> HostGuard { - HostGuard(host.as_ref().to_string(), None) +/// use actix_web::{web, guard::Host, HttpResponse}; +/// +/// web::scope("/admin") +/// .guard(Host("admin.rust-lang.org").scheme("https")) +/// .default_service(web::to(|| HttpResponse::Ok().body("admin connection is secure"))); +/// ``` +#[allow(non_snake_case)] +pub fn Host(host: impl AsRef) -> HostGuard { + HostGuard { + host: host.as_ref().to_string(), + scheme: None, + } } fn get_host_uri(req: &RequestHead) -> Option { - use core::str::FromStr; req.headers .get(header::HOST) .and_then(|host_value| host_value.to_str().ok()) .or_else(|| req.uri.host()) - .map(|host: &str| Uri::from_str(host).ok()) - .and_then(|host_success| host_success) + .and_then(|host| host.parse().ok()) } #[doc(hidden)] -pub struct HostGuard(String, Option); +pub struct HostGuard { + host: String, + scheme: Option, +} impl HostGuard { /// Set request scheme to match pub fn scheme>(mut self, scheme: H) -> HostGuard { - self.1 = Some(scheme.as_ref().to_string()); + self.scheme = Some(scheme.as_ref().to_string()); self } } impl Guard for HostGuard { - fn check(&self, req: &RequestHead) -> bool { - let req_host_uri = if let Some(uri) = get_host_uri(req) { - uri - } else { - return false; + fn check(&self, ctx: &GuardContext<'_>) -> bool { + // parse host URI from header or request target + let req_host_uri = match get_host_uri(ctx.head()) { + Some(uri) => uri, + + // no match if host cannot be determined + None => return false, }; - if let Some(uri_host) = req_host_uri.host() { - if self.0 != uri_host { - return false; - } - } else { - return false; + match req_host_uri.host() { + // fall through to scheme checks + Some(uri_host) if self.host == uri_host => {} + + // Either: + // - request's host does not match guard's host; + // - It was possible that the parsed URI from request target did not contain a host. + _ => return false, } - if let Some(ref scheme) = self.1 { + if let Some(ref scheme) = self.scheme { if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() { return scheme == req_host_uri_scheme; } + + // TODO: is the the correct behavior? + // falls through if scheme cannot be determined } + // all conditions passed true } } @@ -327,171 +445,201 @@ mod tests { use crate::test::TestRequest; #[test] - fn test_header() { + fn header_match() { let req = TestRequest::default() .insert_header((header::TRANSFER_ENCODING, "chunked")) - .to_http_request(); + .to_srv_request(); - let pred = Header("transfer-encoding", "chunked"); - assert!(pred.check(req.head())); + let hdr = Header("transfer-encoding", "chunked"); + assert!(hdr.check(&req.guard_ctx())); - let pred = Header("transfer-encoding", "other"); - assert!(!pred.check(req.head())); + let hdr = Header("transfer-encoding", "other"); + assert!(!hdr.check(&req.guard_ctx())); - let pred = Header("content-type", "other"); - assert!(!pred.check(req.head())); + let hdr = Header("content-type", "chunked"); + assert!(!hdr.check(&req.guard_ctx())); + + let hdr = Header("content-type", "other"); + assert!(!hdr.check(&req.guard_ctx())); } #[test] - fn test_host() { + fn host_from_header() { let req = TestRequest::default() .insert_header(( header::HOST, header::HeaderValue::from_static("www.rust-lang.org"), )) - .to_http_request(); + .to_srv_request(); - let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org").scheme("https"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org").scheme("https"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("crates.io"); - assert!(!pred.check(req.head())); + let host = Host("crates.io"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("localhost"); - assert!(!pred.check(req.head())); + let host = Host("localhost"); + assert!(!host.check(&req.guard_ctx())); } #[test] - fn test_host_scheme() { + fn host_without_header() { + let req = TestRequest::default() + .uri("www.rust-lang.org") + .to_srv_request(); + + let host = Host("www.rust-lang.org"); + assert!(host.check(&req.guard_ctx())); + + let host = Host("www.rust-lang.org").scheme("https"); + assert!(host.check(&req.guard_ctx())); + + let host = Host("blog.rust-lang.org"); + assert!(!host.check(&req.guard_ctx())); + + let host = Host("blog.rust-lang.org").scheme("https"); + assert!(!host.check(&req.guard_ctx())); + + let host = Host("crates.io"); + assert!(!host.check(&req.guard_ctx())); + + let host = Host("localhost"); + assert!(!host.check(&req.guard_ctx())); + } + + #[test] + fn host_scheme() { let req = TestRequest::default() .insert_header(( header::HOST, header::HeaderValue::from_static("https://www.rust-lang.org"), )) - .to_http_request(); + .to_srv_request(); - let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org").scheme("https"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("www.rust-lang.org").scheme("http"); - assert!(!pred.check(req.head())); + let host = Host("www.rust-lang.org").scheme("http"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org").scheme("https"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("crates.io").scheme("https"); - assert!(!pred.check(req.head())); + let host = Host("crates.io").scheme("https"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("localhost"); - assert!(!pred.check(req.head())); + let host = Host("localhost"); + assert!(!host.check(&req.guard_ctx())); } #[test] - fn test_host_without_header() { + fn method_guards() { + let get_req = TestRequest::get().to_srv_request(); + let post_req = TestRequest::post().to_srv_request(); + + assert!(Get().check(&get_req.guard_ctx())); + assert!(!Get().check(&post_req.guard_ctx())); + + assert!(Post().check(&post_req.guard_ctx())); + assert!(!Post().check(&get_req.guard_ctx())); + + let req = TestRequest::put().to_srv_request(); + assert!(Put().check(&req.guard_ctx())); + assert!(!Put().check(&get_req.guard_ctx())); + + let req = TestRequest::patch().to_srv_request(); + assert!(Patch().check(&req.guard_ctx())); + assert!(!Patch().check(&get_req.guard_ctx())); + + let r = TestRequest::delete().to_srv_request(); + assert!(Delete().check(&r.guard_ctx())); + assert!(!Delete().check(&get_req.guard_ctx())); + + let req = TestRequest::default().method(Method::HEAD).to_srv_request(); + assert!(Head().check(&req.guard_ctx())); + assert!(!Head().check(&get_req.guard_ctx())); + let req = TestRequest::default() - .uri("www.rust-lang.org") - .to_http_request(); - - let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); - - let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); - - let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); - - let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); - - let pred = Host("crates.io"); - assert!(!pred.check(req.head())); - - let pred = Host("localhost"); - assert!(!pred.check(req.head())); - } - - #[test] - fn test_methods() { - let req = TestRequest::default().to_http_request(); - let req2 = TestRequest::default() - .method(Method::POST) - .to_http_request(); - - assert!(Get().check(req.head())); - assert!(!Get().check(req2.head())); - assert!(Post().check(req2.head())); - assert!(!Post().check(req.head())); - - let r = TestRequest::default().method(Method::PUT).to_http_request(); - assert!(Put().check(r.head())); - assert!(!Put().check(req.head())); - - let r = TestRequest::default() - .method(Method::DELETE) - .to_http_request(); - assert!(Delete().check(r.head())); - assert!(!Delete().check(req.head())); - - let r = TestRequest::default() - .method(Method::HEAD) - .to_http_request(); - assert!(Head().check(r.head())); - assert!(!Head().check(req.head())); - - let r = TestRequest::default() .method(Method::OPTIONS) - .to_http_request(); - assert!(Options().check(r.head())); - assert!(!Options().check(req.head())); + .to_srv_request(); + assert!(Options().check(&req.guard_ctx())); + assert!(!Options().check(&get_req.guard_ctx())); - let r = TestRequest::default() + let req = TestRequest::default() .method(Method::CONNECT) - .to_http_request(); - assert!(Connect().check(r.head())); - assert!(!Connect().check(req.head())); + .to_srv_request(); + assert!(Connect().check(&req.guard_ctx())); + assert!(!Connect().check(&get_req.guard_ctx())); - let r = TestRequest::default() - .method(Method::PATCH) - .to_http_request(); - assert!(Patch().check(r.head())); - assert!(!Patch().check(req.head())); - - let r = TestRequest::default() + let req = TestRequest::default() .method(Method::TRACE) - .to_http_request(); - assert!(Trace().check(r.head())); - assert!(!Trace().check(req.head())); + .to_srv_request(); + assert!(Trace().check(&req.guard_ctx())); + assert!(!Trace().check(&get_req.guard_ctx())); } #[test] - fn test_preds() { - let r = TestRequest::default() + fn aggregate_any() { + let req = TestRequest::default() .method(Method::TRACE) - .to_http_request(); + .to_srv_request(); - assert!(Not(Get()).check(r.head())); - assert!(!Not(Trace()).check(r.head())); + assert!(Any(Trace()).check(&req.guard_ctx())); + assert!(Any(Trace()).or(Get()).check(&req.guard_ctx())); + assert!(!Any(Get()).or(Get()).check(&req.guard_ctx())); + } - assert!(All(Trace()).and(Trace()).check(r.head())); - assert!(!All(Get()).and(Trace()).check(r.head())); + #[test] + fn aggregate_all() { + let req = TestRequest::default() + .method(Method::TRACE) + .to_srv_request(); - assert!(Any(Get()).or(Trace()).check(r.head())); - assert!(!Any(Get()).or(Get()).check(r.head())); + assert!(All(Trace()).check(&req.guard_ctx())); + assert!(All(Trace()).and(Trace()).check(&req.guard_ctx())); + assert!(!All(Trace()).and(Get()).check(&req.guard_ctx())); + } + + #[test] + fn nested_not() { + let req = TestRequest::default().to_srv_request(); + + let get = Get(); + assert!(get.check(&req.guard_ctx())); + + let not_get = Not(get); + assert!(!not_get.check(&req.guard_ctx())); + + let not_not_get = Not(not_get); + assert!(not_not_get.check(&req.guard_ctx())); + } + + #[test] + fn function_guard() { + let domain = "rust-lang.org".to_owned(); + let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain)); + + let req = TestRequest::default() + .uri("blog.rust-lang.org") + .to_srv_request(); + assert!(guard.check(&req.guard_ctx())); + + let req = TestRequest::default().uri("crates.io").to_srv_request(); + assert!(!guard.check(&req.guard_ctx())); } } diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index 18dcaeef..3ab90848 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -225,7 +225,7 @@ mod tests { .service(web::resource("/v1/something").to(HttpResponse::Ok)) .service( web::resource("/v2/something") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -261,7 +261,7 @@ mod tests { .service(web::resource("/v1/something").to(HttpResponse::Ok)) .service( web::resource("/v2/something") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -294,7 +294,7 @@ mod tests { let app = init_service( App::new().wrap(NormalizePath(TrailingSlash::Trim)).service( web::resource("/") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -318,7 +318,7 @@ mod tests { .service(web::resource("/v1/something/").to(HttpResponse::Ok)) .service( web::resource("/v2/something/") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -353,7 +353,7 @@ mod tests { .wrap(NormalizePath(TrailingSlash::Always)) .service( web::resource("/") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -378,7 +378,7 @@ mod tests { .service(web::resource("/v1/").to(HttpResponse::Ok)) .service( web::resource("/v2/something") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) diff --git a/src/route.rs b/src/route.rs index 6d6fca4b..0410b99d 100644 --- a/src/route.rs +++ b/src/route.rs @@ -65,9 +65,12 @@ pub struct RouteService { } impl RouteService { + // TODO: does this need to take &mut ? pub fn check(&self, req: &mut ServiceRequest) -> bool { - for f in self.guards.iter() { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in self.guards.iter() { + if !guard.check(&guard_ctx) { return false; } } @@ -90,6 +93,7 @@ impl Service for RouteService { impl Route { /// Add method guard to the route. /// + /// # Examples /// ``` /// # use actix_web::*; /// # fn main() { @@ -110,6 +114,7 @@ impl Route { /// Add guard to the route. /// + /// # Examples /// ``` /// # use actix_web::*; /// # fn main() { @@ -143,16 +148,13 @@ impl Route { /// format!("Welcome {}!", info.username) /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/{username}/index.html") // <- define path parameters - /// .route(web::get().to(index)) // <- register handler - /// ); - /// } + /// let app = App::new().service( + /// web::resource("/{username}/index.html") // <- define path parameters + /// .route(web::get().to(index)) // <- register handler + /// ); /// ``` /// /// It is possible to use multiple extractors for one handler function. - /// /// ``` /// # use std::collections::HashMap; /// # use serde::Deserialize; @@ -164,16 +166,18 @@ impl Route { /// } /// /// /// extract path info using serde - /// async fn index(path: web::Path, query: web::Query>, body: web::Json) -> String { + /// async fn index( + /// path: web::Path, + /// query: web::Query>, + /// body: web::Json + /// ) -> String { /// format!("Welcome {}!", path.username) /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/{username}/index.html") // <- define path parameters - /// .route(web::get().to(index)) - /// ); - /// } + /// let app = App::new().service( + /// web::resource("/{username}/index.html") // <- define path parameters + /// .route(web::get().to(index)) + /// ); /// ``` pub fn to(mut self, handler: F) -> Self where @@ -199,7 +203,7 @@ impl Route { /// type Error = Infallible; /// type Future = LocalBoxFuture<'static, Result>; /// - /// always_ready!(); + /// dev::always_ready!(); /// /// fn call(&self, req: ServiceRequest) -> Self::Future { /// let (req, _) = req.into_parts(); diff --git a/src/scope.rs b/src/scope.rs index 176e0d5a..b4618bb6 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -538,12 +538,15 @@ impl Service for ScopeService { fn call(&self, mut req: ServiceRequest) -> Self::Future { let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { - for f in guards { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in guards { + if !guard.check(&guard_ctx) { return false; } } } + true }); diff --git a/src/service.rs b/src/service.rs index d5c381fa..97555619 100644 --- a/src/service.rs +++ b/src/service.rs @@ -21,7 +21,7 @@ use cookie::{Cookie, ParseError as CookieParseError}; use crate::{ config::{AppConfig, AppService}, dev::ensure_leading_slash, - guard::Guard, + guard::{Guard, GuardContext}, info::ConnectionInfo, rmap::ResourceMap, Error, HttpRequest, HttpResponse, @@ -172,7 +172,7 @@ impl ServiceRequest { self.head().uri.path() } - /// Counterpart to [`HttpRequest::query_string`](super::HttpRequest::query_string()). + /// Counterpart to [`HttpRequest::query_string`]. #[inline] pub fn query_string(&self) -> &str { self.req.query_string() @@ -208,13 +208,13 @@ impl ServiceRequest { self.req.match_info() } - /// Counterpart to [`HttpRequest::match_name`](super::HttpRequest::match_name()). + /// Counterpart to [`HttpRequest::match_name`]. #[inline] pub fn match_name(&self) -> Option<&str> { self.req.match_name() } - /// Counterpart to [`HttpRequest::match_pattern`](super::HttpRequest::match_pattern()). + /// Counterpart to [`HttpRequest::match_pattern`]. #[inline] pub fn match_pattern(&self) -> Option { self.req.match_pattern() @@ -238,7 +238,7 @@ impl ServiceRequest { self.req.app_config() } - /// Counterpart to [`HttpRequest::app_data`](super::HttpRequest::app_data()). + /// Counterpart to [`HttpRequest::app_data`]. #[inline] pub fn app_data(&self) -> Option<&T> { for container in self.req.inner.app_data.iter().rev() { @@ -250,19 +250,33 @@ impl ServiceRequest { None } - /// Counterpart to [`HttpRequest::conn_data`](super::HttpRequest::conn_data()). + /// Counterpart to [`HttpRequest::conn_data`]. #[inline] pub fn conn_data(&self) -> Option<&T> { self.req.conn_data() } + /// Counterpart to [`HttpRequest::req_data`]. + #[inline] + pub fn req_data(&self) -> Ref<'_, Extensions> { + self.req.req_data() + } + + /// Counterpart to [`HttpRequest::req_data_mut`]. + #[inline] + pub fn req_data_mut(&self) -> RefMut<'_, Extensions> { + self.req.req_data_mut() + } + #[cfg(feature = "cookies")] + #[inline] pub fn cookies(&self) -> Result>>, CookieParseError> { self.req.cookies() } /// Return request cookie. #[cfg(feature = "cookies")] + #[inline] pub fn cookie(&self, name: &str) -> Option> { self.req.cookie(name) } @@ -283,6 +297,14 @@ impl ServiceRequest { .app_data .push(extensions); } + + /// Creates a context object for use with a [guard](crate::guard). + /// + /// Useful if you are implementing + #[inline] + pub fn guard_ctx(&self) -> GuardContext<'_> { + GuardContext { req: self } + } } impl Resource for ServiceRequest { diff --git a/src/test/test_request.rs b/src/test/test_request.rs index 5c4de908..fc42253d 100644 --- a/src/test/test_request.rs +++ b/src/test/test_request.rs @@ -124,7 +124,7 @@ impl TestRequest { self } - /// Set HTTP Uri of this request + /// Set HTTP URI of this request pub fn uri(mut self, path: &str) -> Self { self.req.uri(path); self