diff --git a/src/guard.rs b/src/guard.rs index 6522a9845..6fd6d1d2c 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -26,7 +26,7 @@ //! ``` #![allow(non_snake_case)] -use actix_http::http::{self, header, HttpTryFrom}; +use actix_http::http::{self, header, HttpTryFrom, uri::Uri}; use actix_http::RequestHead; /// Trait defines resource guards. Guards are used for routes selection. @@ -256,45 +256,68 @@ impl Guard for HeaderGuard { } } -// /// Return predicate that matches if request contains specified Host name. -// /// -// /// ```rust,ignore -// /// # extern crate actix_web; -// /// use actix_web::{pred, App, HttpResponse}; -// /// -// /// fn main() { -// /// App::new().resource("/index.html", |r| { -// /// r.route() -// /// .guard(pred::Host("www.rust-lang.org")) -// /// .f(|_| HttpResponse::MethodNotAllowed()) -// /// }); -// /// } -// /// ``` -// pub fn Host>(host: H) -> HostGuard { -// HostGuard(host.as_ref().to_string(), None) -// } +/// Return predicate that matches if request contains specified Host name. +/// +/// ```rust,ignore +/// # extern crate actix_web; +/// use actix_web::{guard::Host, App, HttpResponse}; +/// +/// fn main() { +/// App::new().resource("/index.html", |r| { +/// r.route() +/// .guard(Host("www.rust-lang.org")) +/// .f(|_| HttpResponse::MethodNotAllowed()) +/// }); +/// } +/// ``` +pub fn Host>(host: H) -> HostGuard { + HostGuard(host.as_ref().to_string(), None) +} -// #[doc(hidden)] -// pub struct HostGuard(String, Option); +fn get_host_uri(req: &RequestHead) -> Option { + use core::str::FromStr; + let host_value = req.headers.get(header::HOST)?; + let host = host_value.to_str().ok()?; + let uri = Uri::from_str(host).ok()?; + Some(uri) +} -// impl HostGuard { -// /// Set reuest scheme to match -// pub fn scheme>(&mut self, scheme: H) { -// self.1 = Some(scheme.as_ref().to_string()) -// } -// } +#[doc(hidden)] +pub struct HostGuard(String, Option); -// impl Guard for HostGuard { -// fn check(&self, _req: &RequestHead) -> bool { -// // let info = req.connection_info(); -// // if let Some(ref scheme) = self.1 { -// // self.0 == info.host() && scheme == info.scheme() -// // } else { -// // self.0 == info.host() -// // } -// false -// } -// } +impl HostGuard { + /// Set request scheme to match + pub fn scheme>(mut self, scheme: H) -> HostGuard { + self.1 = Some(scheme.as_ref().to_string()); + self + } +} + +impl Guard for HostGuard { + fn check(&self, req: &RequestHead) -> bool { + let req_host_uri = if let Some(uri) = get_host_uri(req) { + uri + } else { + return false; + }; + + if let Some(uri_host) = req_host_uri.host() { + if self.0 != uri_host { + return false; + } + } else { + return false; + } + + if let Some(ref scheme) = self.1 { + if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() { + return scheme == req_host_uri_scheme; + } + } + + true + } +} #[cfg(test)] mod tests { @@ -318,21 +341,64 @@ mod tests { assert!(!pred.check(req.head())); } - // #[test] - // fn test_host() { - // let req = TestServiceRequest::default() - // .header( - // header::HOST, - // header::HeaderValue::from_static("www.rust-lang.org"), - // ) - // .request(); + #[test] + fn test_host() { + let req = TestRequest::default() + .header( + header::HOST, + header::HeaderValue::from_static("www.rust-lang.org"), + ) + .to_http_request(); - // let pred = Host("www.rust-lang.org"); - // assert!(pred.check(&req)); + let pred = Host("www.rust-lang.org"); + assert!(pred.check(req.head())); - // let pred = Host("localhost"); - // assert!(!pred.check(&req)); - // } + let pred = Host("www.rust-lang.org").scheme("https"); + assert!(pred.check(req.head())); + + let pred = Host("blog.rust-lang.org"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("crates.io"); + assert!(!pred.check(req.head())); + + let pred = Host("localhost"); + assert!(!pred.check(req.head())); + } + + #[test] + fn test_host_scheme() { + let req = TestRequest::default() + .header( + header::HOST, + header::HeaderValue::from_static("https://www.rust-lang.org"), + ) + .to_http_request(); + + let pred = Host("www.rust-lang.org").scheme("https"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org").scheme("http"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("crates.io").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("localhost"); + assert!(!pred.check(req.head())); + } #[test] fn test_methods() {