1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-23 16:21:06 +01:00

rework Guard trait (#2552)

This commit is contained in:
Rob Ede 2021-12-28 02:37:13 +00:00 committed by GitHub
parent 36193b0a50
commit 4616ca8ee6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 545 additions and 351 deletions

View File

@ -1,6 +1,17 @@
# Changes
## Unreleased - 2021-xx-xx
### Added
- `guard::GuardContext` for use with the `Guard` trait. [#2552]
- `ServiceRequest::guard_ctx` for obtaining a guard context. [#2552]
### Changed
- `Guard` trait now receives a `&GuardContext`. [#2552]
- `guard::fn_guard` functions now receives a `&GuardContext`. [#2552]
- Some guards now return `impl Guard` and their concrete types are made private: `guard::{Header}` and all the method guards. [#2552]
- The `Not` guard is now generic over the type of guard it wraps. [#2552]
[#2552]: https://github.com/actix/actix-web/pull/2552
## 4.0.0-beta.16 - 2021-12-27

View File

@ -1,8 +1,8 @@
use std::{fmt, io, ops::Deref, path::PathBuf, rc::Rc};
use actix_service::Service;
use actix_web::{
dev::{ServiceRequest, ServiceResponse},
body::BoxBody,
dev::{Service, ServiceRequest, ServiceResponse},
error::Error,
guard::Guard,
http::{header, Method},
@ -94,7 +94,7 @@ impl fmt::Debug for FilesService {
}
impl Service<ServiceRequest> for FilesService {
type Response = ServiceResponse;
type Response = ServiceResponse<BoxBody>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
@ -103,7 +103,7 @@ impl Service<ServiceRequest> for FilesService {
fn call(&self, req: ServiceRequest) -> Self::Future {
let is_method_valid = if let Some(guard) = &self.guards {
// execute user defined guards
(**guard).check(req.head())
(**guard).check(&req.guard_ctx())
} else {
// default behavior
matches!(*req.method(), Method::HEAD | Method::GET)

View File

@ -1,4 +1,4 @@
use std::{cell::RefCell, rc::Rc};
use std::{cell::RefCell, ops, rc::Rc};
use bitflags::bitflags;
@ -49,7 +49,7 @@ impl<T: Head> Message<T> {
}
}
impl<T: Head> std::ops::Deref for Message<T> {
impl<T: Head> ops::Deref for Message<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
@ -57,7 +57,7 @@ impl<T: Head> std::ops::Deref for Message<T> {
}
}
impl<T: Head> std::ops::DerefMut for Message<T> {
impl<T: Head> ops::DerefMut for Message<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
Rc::get_mut(&mut self.head).expect("Multiple copies exist")
}

View File

@ -1,14 +1,16 @@
use std::{cell::RefCell, mem, rc::Rc};
use actix_http::{Extensions, Request};
use actix_http::Request;
use actix_router::{Path, ResourceDef, Router, Url};
use actix_service::{boxed, fn_service, Service, ServiceFactory};
use futures_core::future::LocalBoxFuture;
use futures_util::future::join_all;
use crate::{
body::BoxBody,
config::{AppConfig, AppService},
data::FnDataFactory,
dev::Extensions,
guard::Guard,
request::{HttpRequest, HttpRequestPool},
rmap::ResourceMap,
@ -297,7 +299,7 @@ pub struct AppRouting {
}
impl Service<ServiceRequest> for AppRouting {
type Response = ServiceResponse;
type Response = ServiceResponse<BoxBody>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
@ -306,12 +308,15 @@ impl Service<ServiceRequest> for AppRouting {
fn call(&self, mut req: ServiceRequest) -> Self::Future {
let res = self.router.recognize_fn(&mut req, |req, guards| {
if let Some(ref guards) = guards {
for f in guards {
if !f.check(req.head()) {
let guard_ctx = req.guard_ctx();
for guard in guards {
if !guard.check(&guard_ctx) {
return false;
}
}
}
true
});

View File

@ -3,6 +3,16 @@
//! Most users will not have to interact with the types in this module, but it is useful for those
//! writing extractors, middleware, libraries, or interacting with the service API directly.
pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead};
pub use actix_router::{Path, ResourceDef, ResourcePath, Url};
pub use actix_server::{Server, ServerHandle};
pub use actix_service::{
always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform,
};
#[cfg(feature = "__compress")]
pub use actix_http::encoding::Decoder as Decompress;
pub use crate::config::{AppConfig, AppService};
#[doc(hidden)]
pub use crate::handler::Handler;
@ -14,16 +24,6 @@ pub use crate::types::form::UrlEncoded;
pub use crate::types::json::JsonBody;
pub use crate::types::readlines::Readlines;
pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead};
pub use actix_router::{Path, ResourceDef, ResourcePath, Url};
pub use actix_server::{Server, ServerHandle};
pub use actix_service::{
always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform,
};
#[cfg(feature = "__compress")]
pub use actix_http::encoding::Decoder as Decompress;
use crate::http::header::ContentEncoding;
use actix_router::Patterns;
@ -46,7 +46,6 @@ pub(crate) fn ensure_leading_slash(mut patterns: Patterns) -> Patterns {
patterns
}
struct Enc(ContentEncoding);
/// Helper trait that allows to set specific encoding for response.
pub trait BodyEncoding {
@ -70,6 +69,8 @@ impl BodyEncoding for actix_http::ResponseBuilder {
}
}
struct Enc(ContentEncoding);
impl<B> BodyEncoding for actix_http::Response<B> {
fn get_encoding(&self) -> Option<ContentEncoding> {
self.extensions().get::<Enc>().map(|enc| enc.0)

View File

@ -1,160 +1,248 @@
//! Route match guards.
//! Route guards.
//!
//! Guards are one of the ways how actix-web router chooses a
//! handler service. In essence it is just a function that accepts a
//! reference to a `RequestHead` instance and returns a boolean.
//! It is possible to add guards to *scopes*, *resources*
//! and *routes*. Actix provide several guards by default, like various
//! http methods, header, etc. To become a guard, type must implement `Guard`
//! trait. Simple functions could be guards as well.
//! Guards are used during routing to help select a matching service or handler using some aspect of
//! the request; though guards should not be used for path matching since it is a built-in function
//! of the Actix Web router.
//!
//! Guards can not modify the request object. But it is possible
//! to store extra attributes on a request by using the `Extensions` container.
//! Extensions containers are available via the `RequestHead::extensions()` method.
//! Guards can be used on [`Scope`]s, [`Resource`]s, [`Route`]s, and other custom services.
//!
//! Fundamentally, a guard is a predicate function that receives a reference to a request context
//! object and returns a boolean; true if the request _should_ be handled by the guarded service
//! or handler. This interface is defined by the [`Guard`] trait.
//!
//! Commonly-used guards are provided in this module as well as way of creating a guard from a
//! closure ([`fn_guard`]). The [`Not`], [`Any`], and [`All`] guards are noteworthy, as they can be
//! used to compose other guards in a more flexible and semantic way than calling `.guard(...)` on
//! services multiple times (which might have different combining behavior than you want).
//!
//! Guards can not modify anything about the request. However, it is possible to store extra
//! attributes in the request-local data container obtained with [`GuardContext::req_data_mut`].
//!
//! Guards can prevent resource definitions from overlapping (resulting in some inaccessible routes)
//! where they otherwise would when only considering paths. See the virtual hosting example below.
//!
//! # Examples
//! In the following code, the `/guarded` resource has one defined route whose handler will only be
//! called if the request method is `POST` and there is a request header with name and value equal
//! to `x-guarded` and `secret`, respectively.
//! ```
//! use actix_web::{web, http, dev, guard, App, HttpResponse};
//! use actix_web::{web, http::Method, guard, HttpResponse};
//!
//! App::new().service(web::resource("/index.html").route(
//! web::resource("/guarded").route(
//! web::route()
//! .guard(guard::Post())
//! .guard(guard::fn_guard(|head| head.method == http::Method::GET))
//! .to(|| HttpResponse::MethodNotAllowed()))
//! .guard(guard::Any(guard::Get()).or(guard::Post()))
//! .guard(guard::Header("x-guarded", "secret"))
//! .to(|| HttpResponse::Ok())
//! );
//! ```
//!
//! Guards can be used to set up some form of [virtual hosting] within a single app.
//! Overlapping scope prefixes are usually discouraged, but when combined with non-overlapping guard
//! definitions they become safe to use in this way. Without these host guards, only routes under
//! the first-to-be-defined scope would be accessible. You can test this locally using `127.0.0.1`
//! and `localhost` as the `Host` guards.
//! ```
//! use actix_web::{web, http::Method, guard, App, HttpResponse};
//!
//! App::new()
//! .service(
//! web::scope("")
//! .guard(guard::Host("www.rust-lang.org"))
//! .default_service(web::to(|| HttpResponse::Ok().body("marketing site"))),
//! )
//! .service(
//! web::scope("")
//! .guard(guard::Host("play.rust-lang.org"))
//! .default_service(web::to(|| HttpResponse::Ok().body("playground frontend"))),
//! );
//! ```
//!
//! [`Scope`]: crate::Scope::guard()
//! [`Resource`]: crate::Resource::guard()
//! [`Route`]: crate::Route::guard()
//! [virtual hosting]: https://en.wikipedia.org/wiki/Virtual_hosting
#![allow(non_snake_case)]
use std::{
cell::{Ref, RefMut},
convert::TryFrom,
rc::Rc,
};
use std::rc::Rc;
use std::{convert::TryFrom, ops::Deref};
use actix_http::{header, uri::Uri, Extensions, Method as HttpMethod, RequestHead};
use actix_http::{header, uri::Uri, Method as HttpMethod, RequestHead};
use crate::service::ServiceRequest;
/// Trait defines resource guards. Guards are used for route selection.
///
/// Guards can not modify the request object. But it is possible
/// to store extra attributes on a request by using the `Extensions` container.
/// Extensions containers are available via the `RequestHead::extensions()` method.
pub trait Guard {
/// Check if request matches predicate
fn check(&self, request: &RequestHead) -> bool;
/// Provides access to request parts that are useful during routing.
#[derive(Debug)]
pub struct GuardContext<'a> {
pub(crate) req: &'a ServiceRequest,
}
impl Guard for Rc<dyn Guard> {
fn check(&self, request: &RequestHead) -> bool {
self.deref().check(request)
impl<'a> GuardContext<'a> {
/// Returns reference to the request head.
#[inline]
pub fn head(&self) -> &RequestHead {
self.req.head()
}
/// Returns reference to the request-local data container.
#[inline]
pub fn req_data(&self) -> Ref<'a, Extensions> {
self.req.req_data()
}
/// Returns mutable reference to the request-local data container.
#[inline]
pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
self.req.req_data_mut()
}
}
/// Create guard object for supplied function.
/// Interface for routing guards.
///
/// See [module level documentation](self) for more.
pub trait Guard {
/// Returns true if predicate condition is met for a given request.
fn check(&self, ctx: &GuardContext<'_>) -> bool;
}
impl Guard for Rc<dyn Guard> {
fn check(&self, ctx: &GuardContext<'_>) -> bool {
(**self).check(ctx)
}
}
/// Creates a guard using the given function.
///
/// # Examples
/// ```
/// use actix_web::{guard, web, App, HttpResponse};
/// use actix_web::{guard, web, HttpResponse};
///
/// App::new().service(web::resource("/index.html").route(
/// web::route()
/// .guard(
/// guard::fn_guard(
/// |req| req.headers()
/// .contains_key("content-type")))
/// .to(|| HttpResponse::MethodNotAllowed()))
/// );
/// .guard(guard::fn_guard(|ctx| {
/// ctx.head().headers().contains_key("content-type")
/// }))
/// .to(|| HttpResponse::Ok());
/// ```
pub fn fn_guard<F>(f: F) -> impl Guard
where
F: Fn(&RequestHead) -> bool,
F: Fn(&GuardContext<'_>) -> bool,
{
FnGuard(f)
}
struct FnGuard<F: Fn(&RequestHead) -> bool>(F);
struct FnGuard<F: Fn(&GuardContext<'_>) -> bool>(F);
impl<F> Guard for FnGuard<F>
where
F: Fn(&RequestHead) -> bool,
F: Fn(&GuardContext<'_>) -> bool,
{
fn check(&self, head: &RequestHead) -> bool {
(self.0)(head)
fn check(&self, ctx: &GuardContext<'_>) -> bool {
(self.0)(ctx)
}
}
impl<F> Guard for F
where
F: Fn(&RequestHead) -> bool,
F: Fn(&GuardContext<'_>) -> bool,
{
fn check(&self, head: &RequestHead) -> bool {
(self)(head)
fn check(&self, ctx: &GuardContext<'_>) -> bool {
(self)(ctx)
}
}
/// Return guard that matches if any of supplied guards.
/// Creates a guard that matches if any added guards match.
///
/// # Examples
/// The handler below will be called for either request method `GET` or `POST`.
/// ```
/// use actix_web::{web, guard, App, HttpResponse};
/// use actix_web::{web, guard, HttpResponse};
///
/// App::new().service(web::resource("/index.html").route(
/// web::route()
/// .guard(guard::Any(guard::Get()).or(guard::Post()))
/// .to(|| HttpResponse::MethodNotAllowed()))
/// );
/// .guard(
/// guard::Any(guard::Get())
/// .or(guard::Post()))
/// .to(|| HttpResponse::Ok());
/// ```
#[allow(non_snake_case)]
pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
AnyGuard(vec![Box::new(guard)])
AnyGuard {
guards: vec![Box::new(guard)],
}
}
/// Matches any of supplied guards.
pub struct AnyGuard(Vec<Box<dyn Guard>>);
/// A collection of guards that match if the disjunction of their `check` outcomes is true.
///
/// That is, only one contained guard needs to match in order for the aggregate guard to match.
///
/// Construct an `AnyGuard` using [`Any`].
pub struct AnyGuard {
guards: Vec<Box<dyn Guard>>,
}
impl AnyGuard {
/// Add guard to a list of guards to check
/// Adds new guard to the collection of guards to check.
pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
self.0.push(Box::new(guard));
self.guards.push(Box::new(guard));
self
}
}
impl Guard for AnyGuard {
fn check(&self, req: &RequestHead) -> bool {
for p in &self.0 {
if p.check(req) {
fn check(&self, ctx: &GuardContext<'_>) -> bool {
for guard in &self.guards {
if guard.check(ctx) {
return true;
}
}
false
}
}
/// Return guard that matches if all of the supplied guards.
/// Creates a guard that matches if all added guards match.
///
/// # Examples
/// The handler below will only be called if the request method is `GET` **and** the specified
/// header name and value match exactly.
/// ```
/// use actix_web::{guard, web, App, HttpResponse};
/// use actix_web::{guard, web, HttpResponse};
///
/// App::new().service(web::resource("/index.html").route(
/// web::route()
/// .guard(
/// guard::All(guard::Get()).and(guard::Header("content-type", "text/plain")))
/// .to(|| HttpResponse::MethodNotAllowed()))
/// );
/// guard::All(guard::Get())
/// .and(guard::Header("accept", "text/plain"))
/// )
/// .to(|| HttpResponse::Ok());
/// ```
#[allow(non_snake_case)]
pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
AllGuard(vec![Box::new(guard)])
AllGuard {
guards: vec![Box::new(guard)],
}
}
/// Matches if all of supplied guards.
pub struct AllGuard(Vec<Box<dyn Guard>>);
/// A collection of guards that match if the conjunction of their `check` outcomes is true.
///
/// That is, **all** contained guard needs to match in order for the aggregate guard to match.
///
/// Construct an `AllGuard` using [`All`].
pub struct AllGuard {
guards: Vec<Box<dyn Guard>>,
}
impl AllGuard {
/// Add new guard to the list of guards to check
/// Adds new guard to the collection of guards to check.
pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
self.0.push(Box::new(guard));
self.guards.push(Box::new(guard));
self
}
}
impl Guard for AllGuard {
fn check(&self, request: &RequestHead) -> bool {
for p in &self.0 {
if !p.check(request) {
fn check(&self, ctx: &GuardContext<'_>) -> bool {
for guard in &self.guards {
if !guard.check(ctx) {
return false;
}
}
@ -162,159 +250,189 @@ impl Guard for AllGuard {
}
}
/// Return guard that matches if supplied guard does not match.
pub fn Not<F: Guard + 'static>(guard: F) -> NotGuard {
NotGuard(Box::new(guard))
}
/// Wraps a guard and inverts the outcome of it's `Guard` implementation.
///
/// # Examples
/// The handler below will be called for any request method apart from `GET`.
/// ```
/// use actix_web::{guard, web, HttpResponse};
///
/// web::route()
/// .guard(guard::Not(guard::Get()))
/// .to(|| HttpResponse::Ok());
/// ```
pub struct Not<G>(pub G);
#[doc(hidden)]
pub struct NotGuard(Box<dyn Guard>);
impl Guard for NotGuard {
fn check(&self, request: &RequestHead) -> bool {
!self.0.check(request)
impl<G: Guard> Guard for Not<G> {
fn check(&self, ctx: &GuardContext<'_>) -> bool {
!self.0.check(ctx)
}
}
/// HTTP method guard.
#[doc(hidden)]
pub struct MethodGuard(HttpMethod);
impl Guard for MethodGuard {
fn check(&self, request: &RequestHead) -> bool {
request.method == self.0
}
}
/// Guard to match *GET* HTTP method.
pub fn Get() -> MethodGuard {
MethodGuard(HttpMethod::GET)
}
/// Predicate to match *POST* HTTP method.
pub fn Post() -> MethodGuard {
MethodGuard(HttpMethod::POST)
}
/// Predicate to match *PUT* HTTP method.
pub fn Put() -> MethodGuard {
MethodGuard(HttpMethod::PUT)
}
/// Predicate to match *DELETE* HTTP method.
pub fn Delete() -> MethodGuard {
MethodGuard(HttpMethod::DELETE)
}
/// Predicate to match *HEAD* HTTP method.
pub fn Head() -> MethodGuard {
MethodGuard(HttpMethod::HEAD)
}
/// Predicate to match *OPTIONS* HTTP method.
pub fn Options() -> MethodGuard {
MethodGuard(HttpMethod::OPTIONS)
}
/// Predicate to match *CONNECT* HTTP method.
pub fn Connect() -> MethodGuard {
MethodGuard(HttpMethod::CONNECT)
}
/// Predicate to match *PATCH* HTTP method.
pub fn Patch() -> MethodGuard {
MethodGuard(HttpMethod::PATCH)
}
/// Predicate to match *TRACE* HTTP method.
pub fn Trace() -> MethodGuard {
MethodGuard(HttpMethod::TRACE)
}
/// Predicate to match specified HTTP method.
pub fn Method(method: HttpMethod) -> MethodGuard {
/// Creates a guard that matches a specified HTTP method.
#[allow(non_snake_case)]
pub fn Method(method: HttpMethod) -> impl Guard {
MethodGuard(method)
}
/// Return predicate that matches if request contains specified header and
/// value.
pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard {
/// HTTP method guard.
struct MethodGuard(HttpMethod);
impl Guard for MethodGuard {
fn check(&self, ctx: &GuardContext<'_>) -> bool {
ctx.head().method == self.0
}
}
macro_rules! method_guard {
($method_fn:ident, $method_const:ident) => {
paste::paste! {
#[doc = " Creates a guard that matches the `" $method_const "` request method."]
///
/// # Examples
#[doc = " The route in this example will only respond to `" $method_const "` requests."]
/// ```
/// use actix_web::{guard, web, HttpResponse};
///
/// web::route()
#[doc = " .guard(guard::" $method_fn "())"]
/// .to(|| HttpResponse::Ok());
/// ```
#[allow(non_snake_case)]
pub fn $method_fn() -> impl Guard {
MethodGuard(HttpMethod::$method_const)
}
}
};
}
method_guard!(Get, GET);
method_guard!(Post, POST);
method_guard!(Put, PUT);
method_guard!(Delete, DELETE);
method_guard!(Head, HEAD);
method_guard!(Options, OPTIONS);
method_guard!(Connect, CONNECT);
method_guard!(Patch, PATCH);
method_guard!(Trace, TRACE);
/// Creates a guard that matches if request contains given header name and value.
///
/// # Examples
/// The handler below will be called when the request contains an `x-guarded` header with value
/// equal to `secret`.
/// ```
/// use actix_web::{guard, web, HttpResponse};
///
/// web::route()
/// .guard(guard::Header("x-guarded", "secret"))
/// .to(|| HttpResponse::Ok());
/// ```
#[allow(non_snake_case)]
pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
HeaderGuard(
header::HeaderName::try_from(name).unwrap(),
header::HeaderValue::from_static(value),
)
}
#[doc(hidden)]
pub struct HeaderGuard(header::HeaderName, header::HeaderValue);
struct HeaderGuard(header::HeaderName, header::HeaderValue);
impl Guard for HeaderGuard {
fn check(&self, req: &RequestHead) -> bool {
if let Some(val) = req.headers.get(&self.0) {
fn check(&self, ctx: &GuardContext<'_>) -> bool {
if let Some(val) = ctx.head().headers.get(&self.0) {
return val == self.1;
}
false
}
}
/// Return predicate that matches if request contains specified Host name.
/// Creates a guard that matches requests targetting a specific host.
///
/// ```
/// use actix_web::{web, guard::Host, App, HttpResponse};
/// # Matching Host
/// This guard will:
/// - match against the `Host` header, if present;
/// - fall-back to matching against the request target's host, if present;
/// - return false if host cannot be determined;
///
/// App::new().service(
/// web::resource("/index.html")
/// .guard(Host("www.rust-lang.org"))
/// .to(|| HttpResponse::MethodNotAllowed())
/// );
/// # Matching Scheme
/// Optionally, this guard can match against the host's scheme. Set the scheme for matching using
/// `Host(host).scheme(protocol)`. If the request's scheme cannot be determined, it will not prevent
/// the guard from matching successfully.
///
/// # Examples
/// The [module-level documentation](self) has an example of virtual hosting using `Host` guards.
///
/// The example below additionally guards on the host URI's scheme. This could allow routing to
/// different handlers for `http:` vs `https:` visitors; to redirect, for example.
/// ```
pub fn Host<H: AsRef<str>>(host: H) -> HostGuard {
HostGuard(host.as_ref().to_string(), None)
/// use actix_web::{web, guard::Host, HttpResponse};
///
/// web::scope("/admin")
/// .guard(Host("admin.rust-lang.org").scheme("https"))
/// .default_service(web::to(|| HttpResponse::Ok().body("admin connection is secure")));
/// ```
#[allow(non_snake_case)]
pub fn Host(host: impl AsRef<str>) -> HostGuard {
HostGuard {
host: host.as_ref().to_string(),
scheme: None,
}
}
fn get_host_uri(req: &RequestHead) -> Option<Uri> {
use core::str::FromStr;
req.headers
.get(header::HOST)
.and_then(|host_value| host_value.to_str().ok())
.or_else(|| req.uri.host())
.map(|host: &str| Uri::from_str(host).ok())
.and_then(|host_success| host_success)
.and_then(|host| host.parse().ok())
}
#[doc(hidden)]
pub struct HostGuard(String, Option<String>);
pub struct HostGuard {
host: String,
scheme: Option<String>,
}
impl HostGuard {
/// Set request scheme to match
pub fn scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard {
self.1 = Some(scheme.as_ref().to_string());
self.scheme = Some(scheme.as_ref().to_string());
self
}
}
impl Guard for HostGuard {
fn check(&self, req: &RequestHead) -> bool {
let req_host_uri = if let Some(uri) = get_host_uri(req) {
uri
} else {
return false;
fn check(&self, ctx: &GuardContext<'_>) -> bool {
// parse host URI from header or request target
let req_host_uri = match get_host_uri(ctx.head()) {
Some(uri) => uri,
// no match if host cannot be determined
None => return false,
};
if let Some(uri_host) = req_host_uri.host() {
if self.0 != uri_host {
return false;
}
} else {
return false;
match req_host_uri.host() {
// fall through to scheme checks
Some(uri_host) if self.host == uri_host => {}
// Either:
// - request's host does not match guard's host;
// - It was possible that the parsed URI from request target did not contain a host.
_ => return false,
}
if let Some(ref scheme) = self.1 {
if let Some(ref scheme) = self.scheme {
if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() {
return scheme == req_host_uri_scheme;
}
// TODO: is the the correct behavior?
// falls through if scheme cannot be determined
}
// all conditions passed
true
}
}
@ -327,171 +445,201 @@ mod tests {
use crate::test::TestRequest;
#[test]
fn test_header() {
fn header_match() {
let req = TestRequest::default()
.insert_header((header::TRANSFER_ENCODING, "chunked"))
.to_http_request();
.to_srv_request();
let pred = Header("transfer-encoding", "chunked");
assert!(pred.check(req.head()));
let hdr = Header("transfer-encoding", "chunked");
assert!(hdr.check(&req.guard_ctx()));
let pred = Header("transfer-encoding", "other");
assert!(!pred.check(req.head()));
let hdr = Header("transfer-encoding", "other");
assert!(!hdr.check(&req.guard_ctx()));
let pred = Header("content-type", "other");
assert!(!pred.check(req.head()));
let hdr = Header("content-type", "chunked");
assert!(!hdr.check(&req.guard_ctx()));
let hdr = Header("content-type", "other");
assert!(!hdr.check(&req.guard_ctx()));
}
#[test]
fn test_host() {
fn host_from_header() {
let req = TestRequest::default()
.insert_header((
header::HOST,
header::HeaderValue::from_static("www.rust-lang.org"),
))
.to_http_request();
.to_srv_request();
let pred = Host("www.rust-lang.org");
assert!(pred.check(req.head()));
let host = Host("www.rust-lang.org");
assert!(host.check(&req.guard_ctx()));
let pred = Host("www.rust-lang.org").scheme("https");
assert!(pred.check(req.head()));
let host = Host("www.rust-lang.org").scheme("https");
assert!(host.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org");
assert!(!pred.check(req.head()));
let host = Host("blog.rust-lang.org");
assert!(!host.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org").scheme("https");
assert!(!pred.check(req.head()));
let host = Host("blog.rust-lang.org").scheme("https");
assert!(!host.check(&req.guard_ctx()));
let pred = Host("crates.io");
assert!(!pred.check(req.head()));
let host = Host("crates.io");
assert!(!host.check(&req.guard_ctx()));
let pred = Host("localhost");
assert!(!pred.check(req.head()));
let host = Host("localhost");
assert!(!host.check(&req.guard_ctx()));
}
#[test]
fn test_host_scheme() {
fn host_without_header() {
let req = TestRequest::default()
.uri("www.rust-lang.org")
.to_srv_request();
let host = Host("www.rust-lang.org");
assert!(host.check(&req.guard_ctx()));
let host = Host("www.rust-lang.org").scheme("https");
assert!(host.check(&req.guard_ctx()));
let host = Host("blog.rust-lang.org");
assert!(!host.check(&req.guard_ctx()));
let host = Host("blog.rust-lang.org").scheme("https");
assert!(!host.check(&req.guard_ctx()));
let host = Host("crates.io");
assert!(!host.check(&req.guard_ctx()));
let host = Host("localhost");
assert!(!host.check(&req.guard_ctx()));
}
#[test]
fn host_scheme() {
let req = TestRequest::default()
.insert_header((
header::HOST,
header::HeaderValue::from_static("https://www.rust-lang.org"),
))
.to_http_request();
.to_srv_request();
let pred = Host("www.rust-lang.org").scheme("https");
assert!(pred.check(req.head()));
let host = Host("www.rust-lang.org").scheme("https");
assert!(host.check(&req.guard_ctx()));
let pred = Host("www.rust-lang.org");
assert!(pred.check(req.head()));
let host = Host("www.rust-lang.org");
assert!(host.check(&req.guard_ctx()));
let pred = Host("www.rust-lang.org").scheme("http");
assert!(!pred.check(req.head()));
let host = Host("www.rust-lang.org").scheme("http");
assert!(!host.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org");
assert!(!pred.check(req.head()));
let host = Host("blog.rust-lang.org");
assert!(!host.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org").scheme("https");
assert!(!pred.check(req.head()));
let host = Host("blog.rust-lang.org").scheme("https");
assert!(!host.check(&req.guard_ctx()));
let pred = Host("crates.io").scheme("https");
assert!(!pred.check(req.head()));
let host = Host("crates.io").scheme("https");
assert!(!host.check(&req.guard_ctx()));
let pred = Host("localhost");
assert!(!pred.check(req.head()));
let host = Host("localhost");
assert!(!host.check(&req.guard_ctx()));
}
#[test]
fn test_host_without_header() {
fn method_guards() {
let get_req = TestRequest::get().to_srv_request();
let post_req = TestRequest::post().to_srv_request();
assert!(Get().check(&get_req.guard_ctx()));
assert!(!Get().check(&post_req.guard_ctx()));
assert!(Post().check(&post_req.guard_ctx()));
assert!(!Post().check(&get_req.guard_ctx()));
let req = TestRequest::put().to_srv_request();
assert!(Put().check(&req.guard_ctx()));
assert!(!Put().check(&get_req.guard_ctx()));
let req = TestRequest::patch().to_srv_request();
assert!(Patch().check(&req.guard_ctx()));
assert!(!Patch().check(&get_req.guard_ctx()));
let r = TestRequest::delete().to_srv_request();
assert!(Delete().check(&r.guard_ctx()));
assert!(!Delete().check(&get_req.guard_ctx()));
let req = TestRequest::default().method(Method::HEAD).to_srv_request();
assert!(Head().check(&req.guard_ctx()));
assert!(!Head().check(&get_req.guard_ctx()));
let req = TestRequest::default()
.uri("www.rust-lang.org")
.to_http_request();
let pred = Host("www.rust-lang.org");
assert!(pred.check(req.head()));
let pred = Host("www.rust-lang.org").scheme("https");
assert!(pred.check(req.head()));
let pred = Host("blog.rust-lang.org");
assert!(!pred.check(req.head()));
let pred = Host("blog.rust-lang.org").scheme("https");
assert!(!pred.check(req.head()));
let pred = Host("crates.io");
assert!(!pred.check(req.head()));
let pred = Host("localhost");
assert!(!pred.check(req.head()));
}
#[test]
fn test_methods() {
let req = TestRequest::default().to_http_request();
let req2 = TestRequest::default()
.method(Method::POST)
.to_http_request();
assert!(Get().check(req.head()));
assert!(!Get().check(req2.head()));
assert!(Post().check(req2.head()));
assert!(!Post().check(req.head()));
let r = TestRequest::default().method(Method::PUT).to_http_request();
assert!(Put().check(r.head()));
assert!(!Put().check(req.head()));
let r = TestRequest::default()
.method(Method::DELETE)
.to_http_request();
assert!(Delete().check(r.head()));
assert!(!Delete().check(req.head()));
let r = TestRequest::default()
.method(Method::HEAD)
.to_http_request();
assert!(Head().check(r.head()));
assert!(!Head().check(req.head()));
let r = TestRequest::default()
.method(Method::OPTIONS)
.to_http_request();
assert!(Options().check(r.head()));
assert!(!Options().check(req.head()));
.to_srv_request();
assert!(Options().check(&req.guard_ctx()));
assert!(!Options().check(&get_req.guard_ctx()));
let r = TestRequest::default()
let req = TestRequest::default()
.method(Method::CONNECT)
.to_http_request();
assert!(Connect().check(r.head()));
assert!(!Connect().check(req.head()));
.to_srv_request();
assert!(Connect().check(&req.guard_ctx()));
assert!(!Connect().check(&get_req.guard_ctx()));
let r = TestRequest::default()
.method(Method::PATCH)
.to_http_request();
assert!(Patch().check(r.head()));
assert!(!Patch().check(req.head()));
let r = TestRequest::default()
let req = TestRequest::default()
.method(Method::TRACE)
.to_http_request();
assert!(Trace().check(r.head()));
assert!(!Trace().check(req.head()));
.to_srv_request();
assert!(Trace().check(&req.guard_ctx()));
assert!(!Trace().check(&get_req.guard_ctx()));
}
#[test]
fn test_preds() {
let r = TestRequest::default()
fn aggregate_any() {
let req = TestRequest::default()
.method(Method::TRACE)
.to_http_request();
.to_srv_request();
assert!(Not(Get()).check(r.head()));
assert!(!Not(Trace()).check(r.head()));
assert!(Any(Trace()).check(&req.guard_ctx()));
assert!(Any(Trace()).or(Get()).check(&req.guard_ctx()));
assert!(!Any(Get()).or(Get()).check(&req.guard_ctx()));
}
assert!(All(Trace()).and(Trace()).check(r.head()));
assert!(!All(Get()).and(Trace()).check(r.head()));
#[test]
fn aggregate_all() {
let req = TestRequest::default()
.method(Method::TRACE)
.to_srv_request();
assert!(Any(Get()).or(Trace()).check(r.head()));
assert!(!Any(Get()).or(Get()).check(r.head()));
assert!(All(Trace()).check(&req.guard_ctx()));
assert!(All(Trace()).and(Trace()).check(&req.guard_ctx()));
assert!(!All(Trace()).and(Get()).check(&req.guard_ctx()));
}
#[test]
fn nested_not() {
let req = TestRequest::default().to_srv_request();
let get = Get();
assert!(get.check(&req.guard_ctx()));
let not_get = Not(get);
assert!(!not_get.check(&req.guard_ctx()));
let not_not_get = Not(not_get);
assert!(not_not_get.check(&req.guard_ctx()));
}
#[test]
fn function_guard() {
let domain = "rust-lang.org".to_owned();
let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain));
let req = TestRequest::default()
.uri("blog.rust-lang.org")
.to_srv_request();
assert!(guard.check(&req.guard_ctx()));
let req = TestRequest::default().uri("crates.io").to_srv_request();
assert!(!guard.check(&req.guard_ctx()));
}
}

View File

@ -225,7 +225,7 @@ mod tests {
.service(web::resource("/v1/something").to(HttpResponse::Ok))
.service(
web::resource("/v2/something")
.guard(fn_guard(|req| req.uri.query() == Some("query=test")))
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
)
@ -261,7 +261,7 @@ mod tests {
.service(web::resource("/v1/something").to(HttpResponse::Ok))
.service(
web::resource("/v2/something")
.guard(fn_guard(|req| req.uri.query() == Some("query=test")))
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
)
@ -294,7 +294,7 @@ mod tests {
let app = init_service(
App::new().wrap(NormalizePath(TrailingSlash::Trim)).service(
web::resource("/")
.guard(fn_guard(|req| req.uri.query() == Some("query=test")))
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
)
@ -318,7 +318,7 @@ mod tests {
.service(web::resource("/v1/something/").to(HttpResponse::Ok))
.service(
web::resource("/v2/something/")
.guard(fn_guard(|req| req.uri.query() == Some("query=test")))
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
)
@ -353,7 +353,7 @@ mod tests {
.wrap(NormalizePath(TrailingSlash::Always))
.service(
web::resource("/")
.guard(fn_guard(|req| req.uri.query() == Some("query=test")))
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
)
@ -378,7 +378,7 @@ mod tests {
.service(web::resource("/v1/").to(HttpResponse::Ok))
.service(
web::resource("/v2/something")
.guard(fn_guard(|req| req.uri.query() == Some("query=test")))
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
)

View File

@ -65,9 +65,12 @@ pub struct RouteService {
}
impl RouteService {
// TODO: does this need to take &mut ?
pub fn check(&self, req: &mut ServiceRequest) -> bool {
for f in self.guards.iter() {
if !f.check(req.head()) {
let guard_ctx = req.guard_ctx();
for guard in self.guards.iter() {
if !guard.check(&guard_ctx) {
return false;
}
}
@ -90,6 +93,7 @@ impl Service<ServiceRequest> for RouteService {
impl Route {
/// Add method guard to the route.
///
/// # Examples
/// ```
/// # use actix_web::*;
/// # fn main() {
@ -110,6 +114,7 @@ impl Route {
/// Add guard to the route.
///
/// # Examples
/// ```
/// # use actix_web::*;
/// # fn main() {
@ -143,16 +148,13 @@ impl Route {
/// format!("Welcome {}!", info.username)
/// }
///
/// fn main() {
/// let app = App::new().service(
/// web::resource("/{username}/index.html") // <- define path parameters
/// .route(web::get().to(index)) // <- register handler
/// );
/// }
/// ```
///
/// It is possible to use multiple extractors for one handler function.
///
/// ```
/// # use std::collections::HashMap;
/// # use serde::Deserialize;
@ -164,16 +166,18 @@ impl Route {
/// }
///
/// /// extract path info using serde
/// async fn index(path: web::Path<Info>, query: web::Query<HashMap<String, String>>, body: web::Json<Info>) -> String {
/// async fn index(
/// path: web::Path<Info>,
/// query: web::Query<HashMap<String, String>>,
/// body: web::Json<Info>
/// ) -> String {
/// format!("Welcome {}!", path.username)
/// }
///
/// fn main() {
/// let app = App::new().service(
/// web::resource("/{username}/index.html") // <- define path parameters
/// .route(web::get().to(index))
/// );
/// }
/// ```
pub fn to<F, Args>(mut self, handler: F) -> Self
where
@ -199,7 +203,7 @@ impl Route {
/// type Error = Infallible;
/// type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
///
/// always_ready!();
/// dev::always_ready!();
///
/// fn call(&self, req: ServiceRequest) -> Self::Future {
/// let (req, _) = req.into_parts();

View File

@ -538,12 +538,15 @@ impl Service<ServiceRequest> for ScopeService {
fn call(&self, mut req: ServiceRequest) -> Self::Future {
let res = self.router.recognize_fn(&mut req, |req, guards| {
if let Some(ref guards) = guards {
for f in guards {
if !f.check(req.head()) {
let guard_ctx = req.guard_ctx();
for guard in guards {
if !guard.check(&guard_ctx) {
return false;
}
}
}
true
});

View File

@ -21,7 +21,7 @@ use cookie::{Cookie, ParseError as CookieParseError};
use crate::{
config::{AppConfig, AppService},
dev::ensure_leading_slash,
guard::Guard,
guard::{Guard, GuardContext},
info::ConnectionInfo,
rmap::ResourceMap,
Error, HttpRequest, HttpResponse,
@ -172,7 +172,7 @@ impl ServiceRequest {
self.head().uri.path()
}
/// Counterpart to [`HttpRequest::query_string`](super::HttpRequest::query_string()).
/// Counterpart to [`HttpRequest::query_string`].
#[inline]
pub fn query_string(&self) -> &str {
self.req.query_string()
@ -208,13 +208,13 @@ impl ServiceRequest {
self.req.match_info()
}
/// Counterpart to [`HttpRequest::match_name`](super::HttpRequest::match_name()).
/// Counterpart to [`HttpRequest::match_name`].
#[inline]
pub fn match_name(&self) -> Option<&str> {
self.req.match_name()
}
/// Counterpart to [`HttpRequest::match_pattern`](super::HttpRequest::match_pattern()).
/// Counterpart to [`HttpRequest::match_pattern`].
#[inline]
pub fn match_pattern(&self) -> Option<String> {
self.req.match_pattern()
@ -238,7 +238,7 @@ impl ServiceRequest {
self.req.app_config()
}
/// Counterpart to [`HttpRequest::app_data`](super::HttpRequest::app_data()).
/// Counterpart to [`HttpRequest::app_data`].
#[inline]
pub fn app_data<T: 'static>(&self) -> Option<&T> {
for container in self.req.inner.app_data.iter().rev() {
@ -250,19 +250,33 @@ impl ServiceRequest {
None
}
/// Counterpart to [`HttpRequest::conn_data`](super::HttpRequest::conn_data()).
/// Counterpart to [`HttpRequest::conn_data`].
#[inline]
pub fn conn_data<T: 'static>(&self) -> Option<&T> {
self.req.conn_data()
}
/// Counterpart to [`HttpRequest::req_data`].
#[inline]
pub fn req_data(&self) -> Ref<'_, Extensions> {
self.req.req_data()
}
/// Counterpart to [`HttpRequest::req_data_mut`].
#[inline]
pub fn req_data_mut(&self) -> RefMut<'_, Extensions> {
self.req.req_data_mut()
}
#[cfg(feature = "cookies")]
#[inline]
pub fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
self.req.cookies()
}
/// Return request cookie.
#[cfg(feature = "cookies")]
#[inline]
pub fn cookie(&self, name: &str) -> Option<Cookie<'static>> {
self.req.cookie(name)
}
@ -283,6 +297,14 @@ impl ServiceRequest {
.app_data
.push(extensions);
}
/// Creates a context object for use with a [guard](crate::guard).
///
/// Useful if you are implementing
#[inline]
pub fn guard_ctx(&self) -> GuardContext<'_> {
GuardContext { req: self }
}
}
impl Resource<Url> for ServiceRequest {

View File

@ -124,7 +124,7 @@ impl TestRequest {
self
}
/// Set HTTP Uri of this request
/// Set HTTP URI of this request
pub fn uri(mut self, path: &str) -> Self {
self.req.uri(path);
self