mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-24 07:53:00 +01:00
Merge pull request #113 from niklasf/csrf-upgrade
Let CSRF filter catch cross-site upgrades
This commit is contained in:
commit
24342fb745
@ -116,22 +116,27 @@ pub struct CsrfFilter {
|
||||
origins: HashSet<String>,
|
||||
allow_xhr: bool,
|
||||
allow_missing_origin: bool,
|
||||
allow_upgrade: bool,
|
||||
}
|
||||
|
||||
impl CsrfFilter {
|
||||
/// Start building a `CsrfFilter`.
|
||||
pub fn build() -> CsrfFilterBuilder {
|
||||
CsrfFilterBuilder {
|
||||
cors: CsrfFilter {
|
||||
csrf: CsrfFilter {
|
||||
origins: HashSet::new(),
|
||||
allow_xhr: false,
|
||||
allow_missing_origin: false,
|
||||
allow_upgrade: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
|
||||
if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
|
||||
let is_upgrade = req.headers().contains_key(header::UPGRADE);
|
||||
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
|
||||
|
||||
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
|
||||
Ok(())
|
||||
} else if let Some(header) = origin(req.headers()) {
|
||||
match header {
|
||||
@ -175,14 +180,14 @@ impl<S> Middleware<S> for CsrfFilter {
|
||||
/// .finish();
|
||||
/// ```
|
||||
pub struct CsrfFilterBuilder {
|
||||
cors: CsrfFilter,
|
||||
csrf: CsrfFilter,
|
||||
}
|
||||
|
||||
impl CsrfFilterBuilder {
|
||||
/// Add an origin that is allowed to make requests. Will be verified
|
||||
/// against the `Origin` request header.
|
||||
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
|
||||
self.cors.origins.insert(origin.to_owned());
|
||||
self.csrf.origins.insert(origin.to_owned());
|
||||
self
|
||||
}
|
||||
|
||||
@ -198,7 +203,7 @@ impl CsrfFilterBuilder {
|
||||
///
|
||||
/// Use this method to enable more lax filtering.
|
||||
pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
|
||||
self.cors.allow_xhr = true;
|
||||
self.csrf.allow_xhr = true;
|
||||
self
|
||||
}
|
||||
|
||||
@ -209,13 +214,19 @@ impl CsrfFilterBuilder {
|
||||
/// missing `Origin` headers because a cross-site attacker cannot prevent
|
||||
/// the browser from sending `Origin` on unsafe requests.
|
||||
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
|
||||
self.cors.allow_missing_origin = true;
|
||||
self.csrf.allow_missing_origin = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Allow cross-site upgrade requests (for example to open a WebSocket).
|
||||
pub fn allow_upgrade(mut self) -> CsrfFilterBuilder {
|
||||
self.csrf.allow_upgrade = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Finishes building the `CsrfFilter` instance.
|
||||
pub fn finish(self) -> CsrfFilter {
|
||||
self.cors
|
||||
self.csrf
|
||||
}
|
||||
}
|
||||
|
||||
@ -263,4 +274,25 @@ mod tests {
|
||||
|
||||
assert!(csrf.start(&mut req).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upgrade() {
|
||||
let strict_csrf = CsrfFilter::build()
|
||||
.allowed_origin("https://www.example.com")
|
||||
.finish();
|
||||
|
||||
let lax_csrf = CsrfFilter::build()
|
||||
.allowed_origin("https://www.example.com")
|
||||
.allow_upgrade()
|
||||
.finish();
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://cswsh.com")
|
||||
.header("Connection", "Upgrade")
|
||||
.header("Upgrade", "websocket")
|
||||
.method(Method::GET)
|
||||
.finish();
|
||||
|
||||
assert!(strict_csrf.start(&mut req).is_err());
|
||||
assert!(lax_csrf.start(&mut req).is_ok());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user