1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 09:42:40 +01:00

simplify csrf middleware

This commit is contained in:
Nikolay Kim 2018-04-09 09:49:07 -07:00
parent b505e682d4
commit eb66685d1a

View File

@ -110,6 +110,28 @@ fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
} }
/// A middleware that filters cross-site requests. /// A middleware that filters cross-site requests.
///
/// To construct a CSRF filter:
///
/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to
/// start building.
/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed
/// origins.
/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve
/// the constructed filter.
///
/// # Example
///
/// ```
/// use actix_web::App;
/// use actix_web::middleware::csrf;
///
/// # fn main() {
/// let app = App::new().middleware(
/// csrf::CsrfFilter::new()
/// .allowed_origin("https://www.example.com"));
/// # }
/// ```
pub struct CsrfFilter { pub struct CsrfFilter {
origins: HashSet<String>, origins: HashSet<String>,
allow_xhr: bool, allow_xhr: bool,
@ -119,17 +141,55 @@ pub struct CsrfFilter {
impl CsrfFilter { impl CsrfFilter {
/// Start building a `CsrfFilter`. /// Start building a `CsrfFilter`.
pub fn build() -> CsrfFilterBuilder { pub fn new() -> CsrfFilter {
CsrfFilterBuilder { CsrfFilter {
csrf: CsrfFilter { origins: HashSet::new(),
origins: HashSet::new(), allow_xhr: false,
allow_xhr: false, allow_missing_origin: false,
allow_missing_origin: false, allow_upgrade: false,
allow_upgrade: false,
}
} }
} }
/// Add an origin that is allowed to make requests. Will be verified
/// against the `Origin` request header.
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilter {
self.origins.insert(origin.to_owned());
self
}
/// Allow all requests with an `X-Requested-With` header.
///
/// A cross-site attacker should not be able to send requests with custom
/// headers unless a CORS policy whitelists them. Therefore it should be
/// safe to allow requests with an `X-Requested-With` header (added
/// automatically by many JavaScript libraries).
///
/// This is disabled by default, because in Safari it is possible to
/// circumvent this using redirects and Flash.
///
/// Use this method to enable more lax filtering.
pub fn allow_xhr(mut self) -> CsrfFilter {
self.allow_xhr = true;
self
}
/// Allow requests if the expected `Origin` header is missing (and
/// there is no `Referer` to fall back on).
///
/// The filter is conservative by default, but it should be safe to allow
/// missing `Origin` headers because a cross-site attacker cannot prevent
/// the browser from sending `Origin` on unsafe requests.
pub fn allow_missing_origin(mut self) -> CsrfFilter {
self.allow_missing_origin = true;
self
}
/// Allow cross-site upgrade requests (for example to open a WebSocket).
pub fn allow_upgrade(mut self) -> CsrfFilter {
self.allow_upgrade = true;
self
}
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> { fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
let is_upgrade = req.headers().contains_key(header::UPGRADE); let is_upgrade = req.headers().contains_key(header::UPGRADE);
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade); let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
@ -157,77 +217,6 @@ impl<S> Middleware<S> for CsrfFilter {
} }
} }
/// Used to build a `CsrfFilter`.
///
/// To construct a CSRF filter:
///
/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to
/// start building.
/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed
/// origins.
/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve
/// the constructed filter.
///
/// # Example
///
/// ```
/// use actix_web::middleware::csrf;
///
/// let csrf = csrf::CsrfFilter::build()
/// .allowed_origin("https://www.example.com")
/// .finish();
/// ```
pub struct CsrfFilterBuilder {
csrf: CsrfFilter,
}
impl CsrfFilterBuilder {
/// Add an origin that is allowed to make requests. Will be verified
/// against the `Origin` request header.
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
self.csrf.origins.insert(origin.to_owned());
self
}
/// Allow all requests with an `X-Requested-With` header.
///
/// A cross-site attacker should not be able to send requests with custom
/// headers unless a CORS policy whitelists them. Therefore it should be
/// safe to allow requests with an `X-Requested-With` header (added
/// automatically by many JavaScript libraries).
///
/// This is disabled by default, because in Safari it is possible to
/// circumvent this using redirects and Flash.
///
/// Use this method to enable more lax filtering.
pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
self.csrf.allow_xhr = true;
self
}
/// Allow requests if the expected `Origin` header is missing (and
/// there is no `Referer` to fall back on).
///
/// The filter is conservative by default, but it should be safe to allow
/// missing `Origin` headers because a cross-site attacker cannot prevent
/// the browser from sending `Origin` on unsafe requests.
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
self.csrf.allow_missing_origin = true;
self
}
/// Allow cross-site upgrade requests (for example to open a WebSocket).
pub fn allow_upgrade(mut self) -> CsrfFilterBuilder {
self.csrf.allow_upgrade = true;
self
}
/// Finishes building the `CsrfFilter` instance.
pub fn finish(self) -> CsrfFilter {
self.csrf
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -236,9 +225,8 @@ mod tests {
#[test] #[test]
fn test_safe() { fn test_safe() {
let csrf = CsrfFilter::build() let csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com");
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.w3.org") let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::HEAD) .method(Method::HEAD)
@ -249,9 +237,8 @@ mod tests {
#[test] #[test]
fn test_csrf() { fn test_csrf() {
let csrf = CsrfFilter::build() let csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com");
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.w3.org") let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::POST) .method(Method::POST)
@ -262,9 +249,8 @@ mod tests {
#[test] #[test]
fn test_referer() { fn test_referer() {
let csrf = CsrfFilter::build() let csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com");
.finish();
let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param") let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param")
.method(Method::POST) .method(Method::POST)
@ -275,14 +261,12 @@ mod tests {
#[test] #[test]
fn test_upgrade() { fn test_upgrade() {
let strict_csrf = CsrfFilter::build() let strict_csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com");
.finish();
let lax_csrf = CsrfFilter::build() let lax_csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allow_upgrade() .allow_upgrade();
.finish();
let mut req = TestRequest::with_header("Origin", "https://cswsh.com") let mut req = TestRequest::with_header("Origin", "https://cswsh.com")
.header("Connection", "Upgrade") .header("Connection", "Upgrade")