From 5816ecd1bc7f256107a34e2c7b2e6edc595e500e Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 7 Mar 2018 17:44:19 +0100 Subject: [PATCH 1/3] fix variable name: cors -> csrf --- src/middleware/csrf.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index 5385f5d4d..3177cfded 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -122,7 +122,7 @@ impl CsrfFilter { /// Start building a `CsrfFilter`. pub fn build() -> CsrfFilterBuilder { CsrfFilterBuilder { - cors: CsrfFilter { + csrf: CsrfFilter { origins: HashSet::new(), allow_xhr: false, allow_missing_origin: false, @@ -175,14 +175,14 @@ impl Middleware for CsrfFilter { /// .finish(); /// ``` pub struct CsrfFilterBuilder { - cors: CsrfFilter, + csrf: CsrfFilter, } impl CsrfFilterBuilder { /// Add an origin that is allowed to make requests. Will be verified /// against the `Origin` request header. pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder { - self.cors.origins.insert(origin.to_owned()); + self.csrf.origins.insert(origin.to_owned()); self } @@ -198,7 +198,7 @@ impl CsrfFilterBuilder { /// /// Use this method to enable more lax filtering. pub fn allow_xhr(mut self) -> CsrfFilterBuilder { - self.cors.allow_xhr = true; + self.csrf.allow_xhr = true; self } @@ -209,13 +209,13 @@ impl CsrfFilterBuilder { /// missing `Origin` headers because a cross-site attacker cannot prevent /// the browser from sending `Origin` on unsafe requests. pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder { - self.cors.allow_missing_origin = true; + self.csrf.allow_missing_origin = true; self } /// Finishes building the `CsrfFilter` instance. pub fn finish(self) -> CsrfFilter { - self.cors + self.csrf } } From b9d6bbd35752cd301726650c4cb6166d130bff2f Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 7 Mar 2018 17:49:30 +0100 Subject: [PATCH 2/3] filter cross-site upgrades in csrf middleware --- src/middleware/csrf.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index 3177cfded..d9d4692b4 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -116,6 +116,7 @@ pub struct CsrfFilter { origins: HashSet, allow_xhr: bool, allow_missing_origin: bool, + allow_upgrade: bool, } impl CsrfFilter { @@ -126,12 +127,16 @@ impl CsrfFilter { origins: HashSet::new(), allow_xhr: false, allow_missing_origin: false, + allow_upgrade: false, } } } fn validate(&self, req: &mut HttpRequest) -> Result<(), CsrfError> { - if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { + let is_upgrade = req.headers().contains_key(header::UPGRADE); + let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade); + + if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { Ok(()) } else if let Some(header) = origin(req.headers()) { match header { @@ -213,6 +218,12 @@ impl CsrfFilterBuilder { self } + /// Allow cross-site upgrade requests (for example to open a WebSocket). + pub fn allow_upgrade(mut self) -> CsrfFilterBuilder { + self.csrf.allow_upgrade = true; + self + } + /// Finishes building the `CsrfFilter` instance. pub fn finish(self) -> CsrfFilter { self.csrf From 0278e364ece20aca847c4084c24c7253e695ebb0 Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 7 Mar 2018 18:42:21 +0100 Subject: [PATCH 3/3] add tests for csrf upgrade filter --- src/middleware/csrf.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index d9d4692b4..dfdb538d9 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -274,4 +274,25 @@ mod tests { assert!(csrf.start(&mut req).is_ok()); } + + #[test] + fn test_upgrade() { + let strict_csrf = CsrfFilter::build() + .allowed_origin("https://www.example.com") + .finish(); + + let lax_csrf = CsrfFilter::build() + .allowed_origin("https://www.example.com") + .allow_upgrade() + .finish(); + + let mut req = TestRequest::with_header("Origin", "https://cswsh.com") + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .method(Method::GET) + .finish(); + + assert!(strict_csrf.start(&mut req).is_err()); + assert!(lax_csrf.start(&mut req).is_ok()); + } }