diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index 5385f5d4d..dfdb538d9 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -116,22 +116,27 @@ pub struct CsrfFilter { origins: HashSet, allow_xhr: bool, allow_missing_origin: bool, + allow_upgrade: bool, } impl CsrfFilter { /// Start building a `CsrfFilter`. pub fn build() -> CsrfFilterBuilder { CsrfFilterBuilder { - cors: CsrfFilter { + csrf: CsrfFilter { origins: HashSet::new(), allow_xhr: false, allow_missing_origin: false, + allow_upgrade: false, } } } fn validate(&self, req: &mut HttpRequest) -> Result<(), CsrfError> { - if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { + let is_upgrade = req.headers().contains_key(header::UPGRADE); + let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade); + + if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { Ok(()) } else if let Some(header) = origin(req.headers()) { match header { @@ -175,14 +180,14 @@ impl Middleware for CsrfFilter { /// .finish(); /// ``` pub struct CsrfFilterBuilder { - cors: CsrfFilter, + csrf: CsrfFilter, } impl CsrfFilterBuilder { /// Add an origin that is allowed to make requests. Will be verified /// against the `Origin` request header. pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder { - self.cors.origins.insert(origin.to_owned()); + self.csrf.origins.insert(origin.to_owned()); self } @@ -198,7 +203,7 @@ impl CsrfFilterBuilder { /// /// Use this method to enable more lax filtering. pub fn allow_xhr(mut self) -> CsrfFilterBuilder { - self.cors.allow_xhr = true; + self.csrf.allow_xhr = true; self } @@ -209,13 +214,19 @@ impl CsrfFilterBuilder { /// missing `Origin` headers because a cross-site attacker cannot prevent /// the browser from sending `Origin` on unsafe requests. pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder { - self.cors.allow_missing_origin = true; + self.csrf.allow_missing_origin = true; + self + } + + /// Allow cross-site upgrade requests (for example to open a WebSocket). + pub fn allow_upgrade(mut self) -> CsrfFilterBuilder { + self.csrf.allow_upgrade = true; self } /// Finishes building the `CsrfFilter` instance. pub fn finish(self) -> CsrfFilter { - self.cors + self.csrf } } @@ -263,4 +274,25 @@ mod tests { assert!(csrf.start(&mut req).is_ok()); } + + #[test] + fn test_upgrade() { + let strict_csrf = CsrfFilter::build() + .allowed_origin("https://www.example.com") + .finish(); + + let lax_csrf = CsrfFilter::build() + .allowed_origin("https://www.example.com") + .allow_upgrade() + .finish(); + + let mut req = TestRequest::with_header("Origin", "https://cswsh.com") + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .method(Method::GET) + .finish(); + + assert!(strict_csrf.start(&mut req).is_err()); + assert!(lax_csrf.start(&mut req).is_ok()); + } }