diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index 3177cfde..d9d4692b 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -116,6 +116,7 @@ pub struct CsrfFilter { origins: HashSet, allow_xhr: bool, allow_missing_origin: bool, + allow_upgrade: bool, } impl CsrfFilter { @@ -126,12 +127,16 @@ impl CsrfFilter { origins: HashSet::new(), allow_xhr: false, allow_missing_origin: false, + allow_upgrade: false, } } } fn validate(&self, req: &mut HttpRequest) -> Result<(), CsrfError> { - if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { + let is_upgrade = req.headers().contains_key(header::UPGRADE); + let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade); + + if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { Ok(()) } else if let Some(header) = origin(req.headers()) { match header { @@ -213,6 +218,12 @@ impl CsrfFilterBuilder { self } + /// Allow cross-site upgrade requests (for example to open a WebSocket). + pub fn allow_upgrade(mut self) -> CsrfFilterBuilder { + self.csrf.allow_upgrade = true; + self + } + /// Finishes building the `CsrfFilter` instance. pub fn finish(self) -> CsrfFilter { self.csrf