1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-30 18:34:36 +01:00

filter cross-site upgrades in csrf middleware

This commit is contained in:
Niklas Fiekas 2018-03-07 17:49:30 +01:00
parent 5816ecd1bc
commit b9d6bbd357

View File

@ -116,6 +116,7 @@ pub struct CsrfFilter {
origins: HashSet<String>, origins: HashSet<String>,
allow_xhr: bool, allow_xhr: bool,
allow_missing_origin: bool, allow_missing_origin: bool,
allow_upgrade: bool,
} }
impl CsrfFilter { impl CsrfFilter {
@ -126,12 +127,16 @@ impl CsrfFilter {
origins: HashSet::new(), origins: HashSet::new(),
allow_xhr: false, allow_xhr: false,
allow_missing_origin: false, allow_missing_origin: false,
allow_upgrade: false,
} }
} }
} }
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> { fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { let is_upgrade = req.headers().contains_key(header::UPGRADE);
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
Ok(()) Ok(())
} else if let Some(header) = origin(req.headers()) { } else if let Some(header) = origin(req.headers()) {
match header { match header {
@ -213,6 +218,12 @@ impl CsrfFilterBuilder {
self self
} }
/// Allow cross-site upgrade requests (for example to open a WebSocket).
pub fn allow_upgrade(mut self) -> CsrfFilterBuilder {
self.csrf.allow_upgrade = true;
self
}
/// Finishes building the `CsrfFilter` instance. /// Finishes building the `CsrfFilter` instance.
pub fn finish(self) -> CsrfFilter { pub fn finish(self) -> CsrfFilter {
self.csrf self.csrf