1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 16:02:59 +01:00
actix-extras/src/pred.rs

330 lines
9.6 KiB
Rust
Raw Normal View History

2017-12-04 22:32:05 +01:00
//! Route match predicates
// The predicate constructors in this module (`Any`, `All`, `Not`, `Get`,
// `Header`, `Host`, ...) are functions with UpperCamelCase names, which the
// `non_snake_case` lint would otherwise warn about.
#![allow(non_snake_case)]
2018-06-25 06:58:04 +02:00
use std::marker::PhantomData;
2017-12-04 22:32:05 +01:00
use http;
use http::{header, HttpTryFrom};
2018-06-25 06:58:04 +02:00
use server::message::Request;
2017-12-04 22:32:05 +01:00
/// Trait defines resource route predicate.
/// Predicate can modify request object. It is also possible to
2018-03-31 18:18:25 +02:00
/// to store extra attributes on request by using `Extensions` container,
/// Extensions container available via `HttpRequest::extensions()` method.
2017-12-04 22:32:05 +01:00
pub trait Predicate<S> {
/// Check if request matches predicate
2018-06-25 06:58:04 +02:00
fn check(&self, &Request, &S) -> bool;
2017-12-04 22:32:05 +01:00
}
/// Return predicate that matches if any of supplied predicate matches.
2017-12-20 22:23:50 +01:00
///
/// ```rust
/// # extern crate actix_web;
2018-03-31 09:16:55 +02:00
/// use actix_web::{pred, App, HttpResponse};
2017-12-20 22:23:50 +01:00
///
/// fn main() {
2018-06-01 18:37:14 +02:00
/// App::new().resource("/index.html", |r| {
/// r.route()
2018-03-02 03:32:31 +01:00
/// .filter(pred::Any(pred::Get()).or(pred::Post()))
2018-06-01 18:37:14 +02:00
/// .f(|r| HttpResponse::MethodNotAllowed())
/// });
2017-12-20 22:23:50 +01:00
/// }
/// ```
2018-04-14 01:02:01 +02:00
pub fn Any<S: 'static, P: Predicate<S> + 'static>(pred: P) -> AnyPredicate<S> {
2017-12-20 22:23:50 +01:00
AnyPredicate(vec![Box::new(pred)])
2017-12-04 22:32:05 +01:00
}
2017-12-20 22:23:50 +01:00
/// Matches if any of supplied predicate matches.
pub struct AnyPredicate<S>(Vec<Box<Predicate<S>>>);
impl<S> AnyPredicate<S> {
/// Add new predicate to list of predicates to check
pub fn or<P: Predicate<S> + 'static>(mut self, pred: P) -> Self {
self.0.push(Box::new(pred));
self
}
}
2017-12-04 22:32:05 +01:00
impl<S: 'static> Predicate<S> for AnyPredicate<S> {
2018-06-25 06:58:04 +02:00
fn check(&self, req: &Request, state: &S) -> bool {
2017-12-04 22:32:05 +01:00
for p in &self.0 {
2018-06-25 06:58:04 +02:00
if p.check(req, state) {
2018-04-14 01:02:01 +02:00
return true;
2017-12-04 22:32:05 +01:00
}
}
false
}
}
/// Return predicate that matches if all of supplied predicate matches.
2017-12-20 22:23:50 +01:00
///
/// ```rust
/// # extern crate actix_web;
2018-04-09 19:40:12 +02:00
/// use actix_web::{pred, App, HttpResponse};
2017-12-20 22:23:50 +01:00
///
/// fn main() {
2018-06-01 18:37:14 +02:00
/// App::new().resource("/index.html", |r| {
/// r.route()
/// .filter(
/// pred::All(pred::Get())
/// .and(pred::Header("content-type", "text/plain")),
2018-06-01 18:37:14 +02:00
/// )
/// .f(|_| HttpResponse::MethodNotAllowed())
/// });
2017-12-20 22:23:50 +01:00
/// }
/// ```
pub fn All<S: 'static, P: Predicate<S> + 'static>(pred: P) -> AllPredicate<S> {
AllPredicate(vec![Box::new(pred)])
2017-12-04 22:32:05 +01:00
}
2017-12-20 22:23:50 +01:00
/// Matches if all of supplied predicate matches.
pub struct AllPredicate<S>(Vec<Box<Predicate<S>>>);
impl<S> AllPredicate<S> {
/// Add new predicate to list of predicates to check
pub fn and<P: Predicate<S> + 'static>(mut self, pred: P) -> Self {
self.0.push(Box::new(pred));
self
}
}
2017-12-04 22:32:05 +01:00
impl<S: 'static> Predicate<S> for AllPredicate<S> {
2018-06-25 06:58:04 +02:00
fn check(&self, req: &Request, state: &S) -> bool {
2017-12-04 22:32:05 +01:00
for p in &self.0 {
2018-06-25 06:58:04 +02:00
if !p.check(req, state) {
2018-04-14 01:02:01 +02:00
return false;
2017-12-04 22:32:05 +01:00
}
}
true
}
}
/// Return predicate that matches if supplied predicate does not match.
2018-04-14 01:02:01 +02:00
pub fn Not<S: 'static, P: Predicate<S> + 'static>(pred: P) -> NotPredicate<S> {
2017-12-20 22:23:50 +01:00
NotPredicate(Box::new(pred))
2017-12-04 22:32:05 +01:00
}
2017-12-20 22:23:50 +01:00
#[doc(hidden)]
pub struct NotPredicate<S>(Box<Predicate<S>>);
2017-12-04 22:32:05 +01:00
impl<S: 'static> Predicate<S> for NotPredicate<S> {
2018-06-25 06:58:04 +02:00
fn check(&self, req: &Request, state: &S) -> bool {
!self.0.check(req, state)
2017-12-04 22:32:05 +01:00
}
}
/// Http method predicate
2017-12-20 22:23:50 +01:00
#[doc(hidden)]
pub struct MethodPredicate<S>(http::Method, PhantomData<S>);
2017-12-04 22:32:05 +01:00
impl<S: 'static> Predicate<S> for MethodPredicate<S> {
2018-06-25 06:58:04 +02:00
fn check(&self, req: &Request, _: &S) -> bool {
2017-12-04 22:32:05 +01:00
*req.method() == self.0
}
}
/// Predicate to match *GET* http method
2017-12-20 22:23:50 +01:00
pub fn Get<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::GET, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *POST* http method
2017-12-20 22:23:50 +01:00
pub fn Post<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::POST, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *PUT* http method
2017-12-20 22:23:50 +01:00
pub fn Put<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::PUT, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *DELETE* http method
2017-12-20 22:23:50 +01:00
pub fn Delete<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::DELETE, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *HEAD* http method
2017-12-20 22:23:50 +01:00
pub fn Head<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::HEAD, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *OPTIONS* http method
2017-12-20 22:23:50 +01:00
pub fn Options<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::OPTIONS, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *CONNECT* http method
2017-12-20 22:23:50 +01:00
pub fn Connect<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::CONNECT, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *PATCH* http method
2017-12-20 22:23:50 +01:00
pub fn Patch<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::PATCH, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match *TRACE* http method
2017-12-20 22:23:50 +01:00
pub fn Trace<S: 'static>() -> MethodPredicate<S> {
MethodPredicate(http::Method::TRACE, PhantomData)
2017-12-04 22:32:05 +01:00
}
/// Predicate to match specified http method
2017-12-20 22:23:50 +01:00
pub fn Method<S: 'static>(method: http::Method) -> MethodPredicate<S> {
MethodPredicate(method, PhantomData)
2017-12-04 22:32:05 +01:00
}
2018-04-14 01:02:01 +02:00
/// Return predicate that matches if request contains specified header and
/// value.
pub fn Header<S: 'static>(
2018-07-16 07:17:45 +02:00
name: &'static str, value: &'static str,
2018-04-14 01:02:01 +02:00
) -> HeaderPredicate<S> {
HeaderPredicate(
header::HeaderName::try_from(name).unwrap(),
header::HeaderValue::from_static(value),
PhantomData,
)
2017-12-04 22:32:05 +01:00
}
2017-12-20 22:23:50 +01:00
#[doc(hidden)]
2018-05-17 21:20:20 +02:00
pub struct HeaderPredicate<S>(header::HeaderName, header::HeaderValue, PhantomData<S>);
2017-12-04 22:32:05 +01:00
impl<S: 'static> Predicate<S> for HeaderPredicate<S> {
2018-06-25 06:58:04 +02:00
fn check(&self, req: &Request, _: &S) -> bool {
2017-12-04 22:32:05 +01:00
if let Some(val) = req.headers().get(&self.0) {
2018-04-14 01:02:01 +02:00
return val == self.1;
2017-12-04 22:32:05 +01:00
}
false
}
}
2017-12-08 21:51:44 +01:00
2018-06-08 05:00:54 +02:00
/// Returns a predicate that matches when the host reported by the request's
/// connection info equals `host`.
///
/// The returned predicate can additionally be restricted to a scheme via
/// `HostPredicate::scheme`.
///
/// ```rust
/// # extern crate actix_web;
/// use actix_web::{pred, App, HttpResponse};
///
/// fn main() {
///     App::new().resource("/index.html", |r| {
///         r.route()
///             .filter(pred::Host("www.rust-lang.org"))
///             .f(|_| HttpResponse::MethodNotAllowed())
///     });
/// }
/// ```
pub fn Host<S: 'static, H: AsRef<str>>(host: H) -> HostPredicate<S> {
    let expected_host = String::from(host.as_ref());
    HostPredicate(expected_host, None, PhantomData)
}
#[doc(hidden)]
pub struct HostPredicate<S>(String, Option<String>, PhantomData<S>);
impl<S> HostPredicate<S> {
/// Set reuest scheme to match
pub fn scheme<H: AsRef<str>>(&mut self, scheme: H) {
self.1 = Some(scheme.as_ref().to_string())
}
}
impl<S: 'static> Predicate<S> for HostPredicate<S> {
2018-06-25 06:58:04 +02:00
fn check(&self, req: &Request, _: &S) -> bool {
2018-06-08 05:00:54 +02:00
let info = req.connection_info();
if let Some(ref scheme) = self.1 {
self.0 == info.host() && scheme == info.scheme()
} else {
self.0 == info.host()
}
}
}
2017-12-08 21:51:44 +01:00
#[cfg(test)]
mod tests {
    use super::*;
    use http::{header, Method};
    use test::TestRequest;

    #[test]
    fn test_header() {
        let req = TestRequest::with_header(
            header::TRANSFER_ENCODING,
            header::HeaderValue::from_static("chunked"),
        ).finish();
        let state = req.state();

        // Same header, same value: match.
        assert!(Header("transfer-encoding", "chunked").check(&req, state));
        // Same header, different value: no match.
        assert!(!Header("transfer-encoding", "other").check(&req, state));
        // Header absent from the request: no match.
        assert!(!Header("content-type", "other").check(&req, state));
    }

    #[test]
    fn test_host() {
        let req = TestRequest::default()
            .header(
                header::HOST,
                header::HeaderValue::from_static("www.rust-lang.org"),
            )
            .finish();
        let state = req.state();

        assert!(Host("www.rust-lang.org").check(&req, state));
        assert!(!Host("localhost").check(&req, state));
    }

    #[test]
    fn test_methods() {
        let get_req = TestRequest::default().finish();
        let post_req = TestRequest::default().method(Method::POST).finish();

        assert!(Get().check(&get_req, get_req.state()));
        assert!(!Get().check(&post_req, post_req.state()));
        assert!(Post().check(&post_req, post_req.state()));
        assert!(!Post().check(&get_req, get_req.state()));

        // Every remaining constructor must accept its own method and reject GET.
        let cases = vec![
            (Method::PUT, Put()),
            (Method::DELETE, Delete()),
            (Method::HEAD, Head()),
            (Method::OPTIONS, Options()),
            (Method::CONNECT, Connect()),
            (Method::PATCH, Patch()),
            (Method::TRACE, Trace()),
        ];
        for (method, pred) in cases {
            let req = TestRequest::default().method(method).finish();
            assert!(pred.check(&req, req.state()));
            assert!(!pred.check(&get_req, get_req.state()));
        }
    }

    #[test]
    fn test_preds() {
        let trace_req = TestRequest::default().method(Method::TRACE).finish();
        let state = trace_req.state();

        assert!(Not(Get()).check(&trace_req, state));
        assert!(!Not(Trace()).check(&trace_req, state));
        assert!(All(Trace()).and(Trace()).check(&trace_req, state));
        assert!(!All(Get()).and(Trace()).check(&trace_req, state));
        assert!(Any(Get()).or(Trace()).check(&trace_req, state));
        assert!(!Any(Get()).or(Get()).check(&trace_req, state));
    }
}