From eb66685d1ab88b6258ea22970ab5f91dc8cfc5e7 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 9 Apr 2018 09:49:07 -0700 Subject: [PATCH] simplify csrf middleware --- src/middleware/csrf.rs | 172 +++++++++++++++++++---------------------- 1 file changed, 78 insertions(+), 94 deletions(-) diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index c2003ae35..a80b17cb6 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -110,6 +110,28 @@ fn origin(headers: &HeaderMap) -> Option, CsrfError>> { } /// A middleware that filters cross-site requests. +/// +/// To construct a CSRF filter: +/// +/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to +/// start building. +/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed +/// origins. +/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve +/// the constructed filter. +/// +/// # Example +/// +/// ``` +/// use actix_web::App; +/// use actix_web::middleware::csrf; +/// +/// # fn main() { +/// let app = App::new().middleware( +/// csrf::CsrfFilter::new() +/// .allowed_origin("https://www.example.com")); +/// # } +/// ``` pub struct CsrfFilter { origins: HashSet, allow_xhr: bool, @@ -119,17 +141,55 @@ pub struct CsrfFilter { impl CsrfFilter { /// Start building a `CsrfFilter`. - pub fn build() -> CsrfFilterBuilder { - CsrfFilterBuilder { - csrf: CsrfFilter { - origins: HashSet::new(), - allow_xhr: false, - allow_missing_origin: false, - allow_upgrade: false, - } + pub fn new() -> CsrfFilter { + CsrfFilter { + origins: HashSet::new(), + allow_xhr: false, + allow_missing_origin: false, + allow_upgrade: false, } } + /// Add an origin that is allowed to make requests. Will be verified + /// against the `Origin` request header. + pub fn allowed_origin(mut self, origin: &str) -> CsrfFilter { + self.origins.insert(origin.to_owned()); + self + } + + /// Allow all requests with an `X-Requested-With` header. + /// + /// A cross-site attacker should not be able to send requests with custom + /// headers unless a CORS policy whitelists them. Therefore it should be + /// safe to allow requests with an `X-Requested-With` header (added + /// automatically by many JavaScript libraries). + /// + /// This is disabled by default, because in Safari it is possible to + /// circumvent this using redirects and Flash. + /// + /// Use this method to enable more lax filtering. + pub fn allow_xhr(mut self) -> CsrfFilter { + self.allow_xhr = true; + self + } + + /// Allow requests if the expected `Origin` header is missing (and + /// there is no `Referer` to fall back on). + /// + /// The filter is conservative by default, but it should be safe to allow + /// missing `Origin` headers because a cross-site attacker cannot prevent + /// the browser from sending `Origin` on unsafe requests. + pub fn allow_missing_origin(mut self) -> CsrfFilter { + self.allow_missing_origin = true; + self + } + + /// Allow cross-site upgrade requests (for example to open a WebSocket). + pub fn allow_upgrade(mut self) -> CsrfFilter { + self.allow_upgrade = true; + self + } + fn validate(&self, req: &mut HttpRequest) -> Result<(), CsrfError> { let is_upgrade = req.headers().contains_key(header::UPGRADE); let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade); @@ -157,77 +217,6 @@ impl Middleware for CsrfFilter { } } -/// Used to build a `CsrfFilter`. -/// -/// To construct a CSRF filter: -/// -/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to -/// start building. -/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed -/// origins. -/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve -/// the constructed filter. -/// -/// # Example -/// -/// ``` -/// use actix_web::middleware::csrf; -/// -/// let csrf = csrf::CsrfFilter::build() -/// .allowed_origin("https://www.example.com") -/// .finish(); -/// ``` -pub struct CsrfFilterBuilder { - csrf: CsrfFilter, -} - -impl CsrfFilterBuilder { - /// Add an origin that is allowed to make requests. Will be verified - /// against the `Origin` request header. - pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder { - self.csrf.origins.insert(origin.to_owned()); - self - } - - /// Allow all requests with an `X-Requested-With` header. - /// - /// A cross-site attacker should not be able to send requests with custom - /// headers unless a CORS policy whitelists them. Therefore it should be - /// safe to allow requests with an `X-Requested-With` header (added - /// automatically by many JavaScript libraries). - /// - /// This is disabled by default, because in Safari it is possible to - /// circumvent this using redirects and Flash. - /// - /// Use this method to enable more lax filtering. - pub fn allow_xhr(mut self) -> CsrfFilterBuilder { - self.csrf.allow_xhr = true; - self - } - - /// Allow requests if the expected `Origin` header is missing (and - /// there is no `Referer` to fall back on). - /// - /// The filter is conservative by default, but it should be safe to allow - /// missing `Origin` headers because a cross-site attacker cannot prevent - /// the browser from sending `Origin` on unsafe requests. - pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder { - self.csrf.allow_missing_origin = true; - self - } - - /// Allow cross-site upgrade requests (for example to open a WebSocket). - pub fn allow_upgrade(mut self) -> CsrfFilterBuilder { - self.csrf.allow_upgrade = true; - self - } - - /// Finishes building the `CsrfFilter` instance. - pub fn finish(self) -> CsrfFilter { - self.csrf - } -} - #[cfg(test)] mod tests { use super::*; @@ -236,9 +225,8 @@ mod tests { #[test] fn test_safe() { - let csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); + let csrf = CsrfFilter::new() + .allowed_origin("https://www.example.com"); let mut req = TestRequest::with_header("Origin", "https://www.w3.org") .method(Method::HEAD) @@ -249,9 +237,8 @@ mod tests { #[test] fn test_csrf() { - let csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); + let csrf = CsrfFilter::new() + .allowed_origin("https://www.example.com"); let mut req = TestRequest::with_header("Origin", "https://www.w3.org") .method(Method::POST) @@ -262,9 +249,8 @@ mod tests { #[test] fn test_referer() { - let csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); + let csrf = CsrfFilter::new() + .allowed_origin("https://www.example.com"); let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param") .method(Method::POST) @@ -275,14 +261,12 @@ mod tests { #[test] fn test_upgrade() { - let strict_csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); + let strict_csrf = CsrfFilter::new() + .allowed_origin("https://www.example.com"); - let lax_csrf = CsrfFilter::build() + let lax_csrf = CsrfFilter::new() .allowed_origin("https://www.example.com") - .allow_upgrade() - .finish(); + .allow_upgrade(); let mut req = TestRequest::with_header("Origin", "https://cswsh.com") .header("Connection", "Upgrade")