mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-28 09:42:40 +01:00
Merge pull request #113 from niklasf/csrf-upgrade
Let CSRF filter catch cross-site upgrades
This commit is contained in:
commit
24342fb745
@ -116,22 +116,27 @@ pub struct CsrfFilter {
|
|||||||
origins: HashSet<String>,
|
origins: HashSet<String>,
|
||||||
allow_xhr: bool,
|
allow_xhr: bool,
|
||||||
allow_missing_origin: bool,
|
allow_missing_origin: bool,
|
||||||
|
allow_upgrade: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CsrfFilter {
|
impl CsrfFilter {
|
||||||
/// Start building a `CsrfFilter`.
|
/// Start building a `CsrfFilter`.
|
||||||
pub fn build() -> CsrfFilterBuilder {
|
pub fn build() -> CsrfFilterBuilder {
|
||||||
CsrfFilterBuilder {
|
CsrfFilterBuilder {
|
||||||
cors: CsrfFilter {
|
csrf: CsrfFilter {
|
||||||
origins: HashSet::new(),
|
origins: HashSet::new(),
|
||||||
allow_xhr: false,
|
allow_xhr: false,
|
||||||
allow_missing_origin: false,
|
allow_missing_origin: false,
|
||||||
|
allow_upgrade: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
|
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
|
||||||
if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
|
let is_upgrade = req.headers().contains_key(header::UPGRADE);
|
||||||
|
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
|
||||||
|
|
||||||
|
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else if let Some(header) = origin(req.headers()) {
|
} else if let Some(header) = origin(req.headers()) {
|
||||||
match header {
|
match header {
|
||||||
@ -175,14 +180,14 @@ impl<S> Middleware<S> for CsrfFilter {
|
|||||||
/// .finish();
|
/// .finish();
|
||||||
/// ```
|
/// ```
|
||||||
pub struct CsrfFilterBuilder {
|
pub struct CsrfFilterBuilder {
|
||||||
cors: CsrfFilter,
|
csrf: CsrfFilter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CsrfFilterBuilder {
|
impl CsrfFilterBuilder {
|
||||||
/// Add an origin that is allowed to make requests. Will be verified
|
/// Add an origin that is allowed to make requests. Will be verified
|
||||||
/// against the `Origin` request header.
|
/// against the `Origin` request header.
|
||||||
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
|
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
|
||||||
self.cors.origins.insert(origin.to_owned());
|
self.csrf.origins.insert(origin.to_owned());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,7 +203,7 @@ impl CsrfFilterBuilder {
|
|||||||
///
|
///
|
||||||
/// Use this method to enable more lax filtering.
|
/// Use this method to enable more lax filtering.
|
||||||
pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
|
pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
|
||||||
self.cors.allow_xhr = true;
|
self.csrf.allow_xhr = true;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,13 +214,19 @@ impl CsrfFilterBuilder {
|
|||||||
/// missing `Origin` headers because a cross-site attacker cannot prevent
|
/// missing `Origin` headers because a cross-site attacker cannot prevent
|
||||||
/// the browser from sending `Origin` on unsafe requests.
|
/// the browser from sending `Origin` on unsafe requests.
|
||||||
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
|
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
|
||||||
self.cors.allow_missing_origin = true;
|
self.csrf.allow_missing_origin = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allow cross-site upgrade requests (for example to open a WebSocket).
|
||||||
|
pub fn allow_upgrade(mut self) -> CsrfFilterBuilder {
|
||||||
|
self.csrf.allow_upgrade = true;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Finishes building the `CsrfFilter` instance.
|
/// Finishes building the `CsrfFilter` instance.
|
||||||
pub fn finish(self) -> CsrfFilter {
|
pub fn finish(self) -> CsrfFilter {
|
||||||
self.cors
|
self.csrf
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,4 +274,25 @@ mod tests {
|
|||||||
|
|
||||||
assert!(csrf.start(&mut req).is_ok());
|
assert!(csrf.start(&mut req).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_upgrade() {
|
||||||
|
let strict_csrf = CsrfFilter::build()
|
||||||
|
.allowed_origin("https://www.example.com")
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
let lax_csrf = CsrfFilter::build()
|
||||||
|
.allowed_origin("https://www.example.com")
|
||||||
|
.allow_upgrade()
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
let mut req = TestRequest::with_header("Origin", "https://cswsh.com")
|
||||||
|
.header("Connection", "Upgrade")
|
||||||
|
.header("Upgrade", "websocket")
|
||||||
|
.method(Method::GET)
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
assert!(strict_csrf.start(&mut req).is_err());
|
||||||
|
assert!(lax_csrf.start(&mut req).is_ok());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user