1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-17 21:34:32 +01:00

Merge pull request #113 from niklasf/csrf-upgrade

Let CSRF filter catch cross-site upgrades
This commit is contained in:
Nikolay Kim 2018-03-07 09:58:30 -08:00 committed by GitHub
commit 24342fb745
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -116,22 +116,27 @@ pub struct CsrfFilter {
origins: HashSet<String>,
allow_xhr: bool,
allow_missing_origin: bool,
allow_upgrade: bool,
}
impl CsrfFilter {
/// Start building a `CsrfFilter`.
pub fn build() -> CsrfFilterBuilder {
CsrfFilterBuilder {
cors: CsrfFilter {
csrf: CsrfFilter {
origins: HashSet::new(),
allow_xhr: false,
allow_missing_origin: false,
allow_upgrade: false,
}
}
}
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
let is_upgrade = req.headers().contains_key(header::UPGRADE);
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
Ok(())
} else if let Some(header) = origin(req.headers()) {
match header {
@ -175,14 +180,14 @@ impl<S> Middleware<S> for CsrfFilter {
/// .finish();
/// ```
pub struct CsrfFilterBuilder {
cors: CsrfFilter,
csrf: CsrfFilter,
}
impl CsrfFilterBuilder {
/// Add an origin that is allowed to make requests. Will be verified
/// against the `Origin` request header.
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
self.cors.origins.insert(origin.to_owned());
self.csrf.origins.insert(origin.to_owned());
self
}
@ -198,7 +203,7 @@ impl CsrfFilterBuilder {
///
/// Use this method to enable more lax filtering.
pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
self.cors.allow_xhr = true;
self.csrf.allow_xhr = true;
self
}
@ -209,13 +214,19 @@ impl CsrfFilterBuilder {
/// missing `Origin` headers because a cross-site attacker cannot prevent
/// the browser from sending `Origin` on unsafe requests.
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
self.cors.allow_missing_origin = true;
self.csrf.allow_missing_origin = true;
self
}
/// Allow cross-site upgrade requests (for example to open a WebSocket).
pub fn allow_upgrade(mut self) -> CsrfFilterBuilder {
self.csrf.allow_upgrade = true;
self
}
/// Finishes building the `CsrfFilter` instance.
pub fn finish(self) -> CsrfFilter {
self.cors
self.csrf
}
}
@ -263,4 +274,25 @@ mod tests {
assert!(csrf.start(&mut req).is_ok());
}
#[test]
fn test_upgrade() {
let strict_csrf = CsrfFilter::build()
.allowed_origin("https://www.example.com")
.finish();
let lax_csrf = CsrfFilter::build()
.allowed_origin("https://www.example.com")
.allow_upgrade()
.finish();
let mut req = TestRequest::with_header("Origin", "https://cswsh.com")
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.method(Method::GET)
.finish();
assert!(strict_csrf.start(&mut req).is_err());
assert!(lax_csrf.start(&mut req).is_ok());
}
}