1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 09:42:40 +01:00

Merge pull request #113 from niklasf/csrf-upgrade

Let CSRF filter catch cross-site upgrades
This commit is contained in:
Nikolay Kim 2018-03-07 09:58:30 -08:00 committed by GitHub
commit 24342fb745
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -116,22 +116,27 @@ pub struct CsrfFilter {
origins: HashSet<String>, origins: HashSet<String>,
allow_xhr: bool, allow_xhr: bool,
allow_missing_origin: bool, allow_missing_origin: bool,
allow_upgrade: bool,
} }
impl CsrfFilter { impl CsrfFilter {
/// Start building a `CsrfFilter`. /// Start building a `CsrfFilter`.
pub fn build() -> CsrfFilterBuilder { pub fn build() -> CsrfFilterBuilder {
CsrfFilterBuilder { CsrfFilterBuilder {
cors: CsrfFilter { csrf: CsrfFilter {
origins: HashSet::new(), origins: HashSet::new(),
allow_xhr: false, allow_xhr: false,
allow_missing_origin: false, allow_missing_origin: false,
allow_upgrade: false,
} }
} }
} }
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> { fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { let is_upgrade = req.headers().contains_key(header::UPGRADE);
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
Ok(()) Ok(())
} else if let Some(header) = origin(req.headers()) { } else if let Some(header) = origin(req.headers()) {
match header { match header {
@ -175,14 +180,14 @@ impl<S> Middleware<S> for CsrfFilter {
/// .finish(); /// .finish();
/// ``` /// ```
pub struct CsrfFilterBuilder { pub struct CsrfFilterBuilder {
cors: CsrfFilter, csrf: CsrfFilter,
} }
impl CsrfFilterBuilder { impl CsrfFilterBuilder {
/// Add an origin that is allowed to make requests. Will be verified /// Add an origin that is allowed to make requests. Will be verified
/// against the `Origin` request header. /// against the `Origin` request header.
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder { pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
self.cors.origins.insert(origin.to_owned()); self.csrf.origins.insert(origin.to_owned());
self self
} }
@ -198,7 +203,7 @@ impl CsrfFilterBuilder {
/// ///
/// Use this method to enable more lax filtering. /// Use this method to enable more lax filtering.
pub fn allow_xhr(mut self) -> CsrfFilterBuilder { pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
self.cors.allow_xhr = true; self.csrf.allow_xhr = true;
self self
} }
@ -209,13 +214,19 @@ impl CsrfFilterBuilder {
/// missing `Origin` headers because a cross-site attacker cannot prevent /// missing `Origin` headers because a cross-site attacker cannot prevent
/// the browser from sending `Origin` on unsafe requests. /// the browser from sending `Origin` on unsafe requests.
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder { pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
self.cors.allow_missing_origin = true; self.csrf.allow_missing_origin = true;
self
}
/// Allow cross-site upgrade requests (for example to open a WebSocket).
pub fn allow_upgrade(mut self) -> CsrfFilterBuilder {
self.csrf.allow_upgrade = true;
self self
} }
/// Finishes building the `CsrfFilter` instance. /// Finishes building the `CsrfFilter` instance.
pub fn finish(self) -> CsrfFilter { pub fn finish(self) -> CsrfFilter {
self.cors self.csrf
} }
} }
@ -263,4 +274,25 @@ mod tests {
assert!(csrf.start(&mut req).is_ok()); assert!(csrf.start(&mut req).is_ok());
} }
#[test]
fn test_upgrade() {
let strict_csrf = CsrfFilter::build()
.allowed_origin("https://www.example.com")
.finish();
let lax_csrf = CsrfFilter::build()
.allowed_origin("https://www.example.com")
.allow_upgrade()
.finish();
let mut req = TestRequest::with_header("Origin", "https://cswsh.com")
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.method(Method::GET)
.finish();
assert!(strict_csrf.start(&mut req).is_err());
assert!(lax_csrf.start(&mut req).is_ok());
}
} }