diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md new file mode 100644 index 000000000..8022ea4e8 --- /dev/null +++ b/actix-cors/CHANGES.md @@ -0,0 +1,15 @@ +# Changes + +## [0.2.0] - 2019-12-20 + +* Release + +## [0.2.0-alpha.3] - 2019-12-07 + +* Migrate to actix-web 2.0.0 + +* Bump `derive_more` crate version to 0.99.0 + +## [0.1.0] - 2019-06-15 + +* Move cors middleware to separate crate diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml new file mode 100644 index 000000000..3fcd92f4f --- /dev/null +++ b/actix-cors/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "actix-cors" +version = "0.2.0" +authors = ["Nikolay Kim "] +description = "Cross-origin resource sharing (CORS) for Actix applications." +readme = "README.md" +keywords = ["web", "framework"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-cors/" +license = "MIT/Apache-2.0" +edition = "2018" +workspace = ".." + +[lib] +name = "actix_cors" +path = "src/lib.rs" + +[dependencies] +actix-web = "2.0.0-rc" +actix-service = "1.0.1" +derive_more = "0.99.2" +futures = "0.3.1" + +[dev-dependencies] +actix-rt = "1.0.0" diff --git a/actix-cors/LICENSE-APACHE b/actix-cors/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-cors/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-cors/LICENSE-MIT b/actix-cors/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-cors/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-cors/README.md b/actix-cors/README.md new file mode 100644 index 000000000..a77f6c6d3 --- /dev/null +++ b/actix-cors/README.md @@ -0,0 +1,9 @@ +# Cors Middleware for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-cors)](https://crates.io/crates/actix-cors) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-cors/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-cors](https://crates.io/crates/actix-cors) +* Minimum supported Rust version: 1.34 or later diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs new file mode 100644 index 000000000..429fe9eab --- /dev/null +++ b/actix-cors/src/lib.rs @@ -0,0 +1,1204 @@ +#![allow(clippy::borrow_interior_mutable_const, clippy::type_complexity)] +//! Cross-origin resource sharing (CORS) for Actix applications +//! +//! CORS middleware could be used with application and with resource. +//! Cors middleware could be used as parameter for `App::wrap()`, +//! `Resource::wrap()` or `Scope::wrap()` methods. +//! +//! # Example +//! +//! ```rust +//! use actix_cors::Cors; +//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; +//! +//! async fn index(req: HttpRequest) -> &'static str { +//! "Hello world" +//! } +//! +//! fn main() -> std::io::Result<()> { +//! HttpServer::new(|| App::new() +//! .wrap( +//! Cors::new() // <- Construct CORS middleware builder +//! .allowed_origin("https://www.rust-lang.org/") +//! .allowed_methods(vec!["GET", "POST"]) +//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) +//! .allowed_header(http::header::CONTENT_TYPE) +//! .max_age(3600) +//! .finish()) +//! .service( +//! web::resource("/index.html") +//! .route(web::get().to(index)) +//! .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +//! )) +//! .bind("127.0.0.1:8080")?; +//! +//! Ok(()) +//! } +//! ``` +//! In this example custom *CORS* middleware get registered for "/index.html" +//! endpoint. +//! +//! Cors middleware automatically handle *OPTIONS* preflight request. +use std::collections::HashSet; +use std::convert::TryFrom; +use std::iter::FromIterator; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_service::{Service, Transform}; +use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; +use actix_web::error::{Error, ResponseError, Result}; +use actix_web::http::header::{self, HeaderName, HeaderValue}; +use actix_web::http::{self, Error as HttpError, Method, StatusCode, Uri}; +use actix_web::HttpResponse; +use derive_more::Display; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; + +/// A set of errors that can occur during processing CORS +#[derive(Debug, Display)] +pub enum CorsError { + /// The HTTP request header `Origin` is required but was not provided + #[display( + fmt = "The HTTP request header `Origin` is required but was not provided" + )] + MissingOrigin, + /// The HTTP request header `Origin` could not be parsed correctly. + #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")] + BadOrigin, + /// The request header `Access-Control-Request-Method` is required but is + /// missing + #[display( + fmt = "The request header `Access-Control-Request-Method` is required but is missing" + )] + MissingRequestMethod, + /// The request header `Access-Control-Request-Method` has an invalid value + #[display( + fmt = "The request header `Access-Control-Request-Method` has an invalid value" + )] + BadRequestMethod, + /// The request header `Access-Control-Request-Headers` has an invalid + /// value + #[display( + fmt = "The request header `Access-Control-Request-Headers` has an invalid value" + )] + BadRequestHeaders, + /// Origin is not allowed to make this request + #[display(fmt = "Origin is not allowed to make this request")] + OriginNotAllowed, + /// Requested method is not allowed + #[display(fmt = "Requested method is not allowed")] + MethodNotAllowed, + /// One or more headers requested are not allowed + #[display(fmt = "One or more headers requested are not allowed")] + HeadersNotAllowed, +} + +impl ResponseError for CorsError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into()) + } +} + +/// An enum signifying that some of type T is allowed, or `All` (everything is +/// allowed). +/// +/// `Default` is implemented for this enum and is `All`. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AllOrSome { + /// Everything is allowed. Usually equivalent to the "*" value. + All, + /// Only some of `T` is allowed + Some(T), +} + +impl Default for AllOrSome { + fn default() -> Self { + AllOrSome::All + } +} + +impl AllOrSome { + /// Returns whether this is an `All` variant + pub fn is_all(&self) -> bool { + match *self { + AllOrSome::All => true, + AllOrSome::Some(_) => false, + } + } + + /// Returns whether this is a `Some` variant + pub fn is_some(&self) -> bool { + !self.is_all() + } + + /// Returns &T + pub fn as_ref(&self) -> Option<&T> { + match *self { + AllOrSome::All => None, + AllOrSome::Some(ref t) => Some(t), + } + } +} + +/// Structure that follows the builder pattern for building `Cors` middleware +/// structs. +/// +/// To construct a cors: +/// +/// 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building. +/// 2. Use any of the builder methods to set fields in the backend. +/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the +/// constructed backend. +/// +/// # Example +/// +/// ```rust +/// use actix_cors::Cors; +/// use actix_web::http::header; +/// +/// # fn main() { +/// let cors = Cors::new() +/// .allowed_origin("https://www.rust-lang.org/") +/// .allowed_methods(vec!["GET", "POST"]) +/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) +/// .allowed_header(header::CONTENT_TYPE) +/// .max_age(3600); +/// # } +/// ``` +#[derive(Default)] +pub struct Cors { + cors: Option, + methods: bool, + error: Option, + expose_hdrs: HashSet, +} + +impl Cors { + /// Build a new CORS middleware instance + pub fn new() -> Cors { + Cors { + cors: Some(Inner { + origins: AllOrSome::All, + origins_str: None, + methods: HashSet::new(), + headers: AllOrSome::All, + expose_hdrs: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }), + methods: false, + error: None, + expose_hdrs: HashSet::new(), + } + } + + /// Build a new CORS default middleware + pub fn default() -> CorsFactory { + let inner = Inner { + origins: AllOrSome::default(), + origins_str: None, + methods: HashSet::from_iter( + vec![ + Method::GET, + Method::HEAD, + Method::POST, + Method::OPTIONS, + Method::PUT, + Method::PATCH, + Method::DELETE, + ] + .into_iter(), + ), + headers: AllOrSome::All, + expose_hdrs: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }; + CorsFactory { + inner: Rc::new(inner), + } + } + + /// Add an origin that are allowed to make requests. + /// Will be verified against the `Origin` request header. + /// + /// When `All` is set, and `send_wildcard` is set, "*" will be sent in + /// the `Access-Control-Allow-Origin` response header. Otherwise, the + /// client's `Origin` request header will be echoed back in the + /// `Access-Control-Allow-Origin` response header. + /// + /// When `Some` is set, the client's `Origin` request header will be + /// checked in a case-sensitive manner. + /// + /// This is the `list of origins` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `All`. + /// + /// Builder panics if supplied origin is not valid uri. + pub fn allowed_origin(mut self, origin: &str) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + match Uri::try_from(origin) { + Ok(_) => { + if cors.origins.is_all() { + cors.origins = AllOrSome::Some(HashSet::new()); + } + if let AllOrSome::Some(ref mut origins) = cors.origins { + origins.insert(origin.to_owned()); + } + } + Err(e) => { + self.error = Some(e.into()); + } + } + } + self + } + + /// Set a list of methods which the allowed origins are allowed to access + /// for requests. + /// + /// This is the `list of methods` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` + pub fn allowed_methods(mut self, methods: U) -> Cors + where + U: IntoIterator, + Method: TryFrom, + >::Error: Into, + { + self.methods = true; + if let Some(cors) = cors(&mut self.cors, &self.error) { + for m in methods { + match Method::try_from(m) { + Ok(method) => { + cors.methods.insert(method); + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + } + self + } + + /// Set an allowed header + pub fn allowed_header(mut self, header: H) -> Cors + where + HeaderName: TryFrom, + >::Error: Into, + { + if let Some(cors) = cors(&mut self.cors, &self.error) { + match HeaderName::try_from(header) { + Ok(method) => { + if cors.headers.is_all() { + cors.headers = AllOrSome::Some(HashSet::new()); + } + if let AllOrSome::Some(ref mut headers) = cors.headers { + headers.insert(method); + } + } + Err(e) => self.error = Some(e.into()), + } + } + self + } + + /// Set a list of header field names which can be used when + /// this resource is accessed by allowed origins. + /// + /// If `All` is set, whatever is requested by the client in + /// `Access-Control-Request-Headers` will be echoed back in the + /// `Access-Control-Allow-Headers` header. + /// + /// This is the `list of headers` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `All`. + pub fn allowed_headers(mut self, headers: U) -> Cors + where + U: IntoIterator, + HeaderName: TryFrom, + >::Error: Into, + { + if let Some(cors) = cors(&mut self.cors, &self.error) { + for h in headers { + match HeaderName::try_from(h) { + Ok(method) => { + if cors.headers.is_all() { + cors.headers = AllOrSome::Some(HashSet::new()); + } + if let AllOrSome::Some(ref mut headers) = cors.headers { + headers.insert(method); + } + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + } + self + } + + /// Set a list of headers which are safe to expose to the API of a CORS API + /// specification. This corresponds to the + /// `Access-Control-Expose-Headers` response header. + /// + /// This is the `list of exposed headers` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// This defaults to an empty set. + pub fn expose_headers(mut self, headers: U) -> Cors + where + U: IntoIterator, + HeaderName: TryFrom, + >::Error: Into, + { + for h in headers { + match HeaderName::try_from(h) { + Ok(method) => { + self.expose_hdrs.insert(method); + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + self + } + + /// Set a maximum time for which this CORS request maybe cached. + /// This value is set as the `Access-Control-Max-Age` header. + /// + /// This defaults to `None` (unset). + pub fn max_age(mut self, max_age: usize) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.max_age = Some(max_age) + } + self + } + + /// Set a wildcard origins + /// + /// If send wildcard is set and the `allowed_origins` parameter is `All`, a + /// wildcard `Access-Control-Allow-Origin` response header is sent, + /// rather than the request’s `Origin` header. + /// + /// This is the `supports credentials flag` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// This **CANNOT** be used in conjunction with `allowed_origins` set to + /// `All` and `allow_credentials` set to `true`. Depending on the mode + /// of usage, this will either result in an `Error:: + /// CredentialsWithWildcardOrigin` error during actix launch or runtime. + /// + /// Defaults to `false`. + pub fn send_wildcard(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.send_wildcard = true + } + self + } + + /// Allows users to make authenticated requests + /// + /// If true, injects the `Access-Control-Allow-Credentials` header in + /// responses. This allows cookies and credentials to be submitted + /// across domains. + /// + /// This option cannot be used in conjunction with an `allowed_origin` set + /// to `All` and `send_wildcards` set to `true`. + /// + /// Defaults to `false`. + /// + /// Builder panics if credentials are allowed, but the Origin is set to "*". + /// This is not allowed by W3C + pub fn supports_credentials(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.supports_credentials = true + } + self + } + + /// Disable `Vary` header support. + /// + /// When enabled the header `Vary: Origin` will be returned as per the W3 + /// implementation guidelines. + /// + /// Setting this header when the `Access-Control-Allow-Origin` is + /// dynamically generated (e.g. when there is more than one allowed + /// origin, and an Origin than '*' is returned) informs CDNs and other + /// caches that the CORS headers are dynamic, and cannot be cached. + /// + /// By default `vary` header support is enabled. + pub fn disable_vary_header(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.vary_header = false + } + self + } + + /// Disable *preflight* request support. + /// + /// When enabled cors middleware automatically handles *OPTIONS* request. + /// This is useful application level middleware. + /// + /// By default *preflight* support is enabled. + pub fn disable_preflight(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.preflight = false + } + self + } + + /// Construct cors middleware + pub fn finish(self) -> CorsFactory { + let mut slf = if !self.methods { + self.allowed_methods(vec![ + Method::GET, + Method::HEAD, + Method::POST, + Method::OPTIONS, + Method::PUT, + Method::PATCH, + Method::DELETE, + ]) + } else { + self + }; + + if let Some(e) = slf.error.take() { + panic!("{}", e); + } + + let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder"); + + if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { + panic!("Credentials are allowed, but the Origin is set to \"*\""); + } + + if let AllOrSome::Some(ref origins) = cors.origins { + let s = origins + .iter() + .fold(String::new(), |s, v| format!("{}, {}", s, v)); + cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap()); + } + + if !slf.expose_hdrs.is_empty() { + cors.expose_hdrs = Some( + slf.expose_hdrs + .iter() + .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..] + .to_owned(), + ); + } + + CorsFactory { + inner: Rc::new(cors), + } + } +} + +fn cors<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + parts.as_mut() +} + +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +pub struct CorsFactory { + inner: Rc, +} + +impl Transform for CorsFactory +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = CorsMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CorsMiddleware { + service, + inner: self.inner.clone(), + }) + } +} + +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +#[derive(Clone)] +pub struct CorsMiddleware { + service: S, + inner: Rc, +} + +struct Inner { + methods: HashSet, + origins: AllOrSome>, + origins_str: Option, + headers: AllOrSome>, + expose_hdrs: Option, + max_age: Option, + preflight: bool, + send_wildcard: bool, + supports_credentials: bool, + vary_header: bool, +} + +impl Inner { + fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(&header::ORIGIN) { + if let Ok(origin) = hdr.to_str() { + return match self.origins { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_origins) => allowed_origins + .get(origin) + .and_then(|_| Some(())) + .ok_or_else(|| CorsError::OriginNotAllowed), + }; + } + Err(CorsError::BadOrigin) + } else { + match self.origins { + AllOrSome::All => Ok(()), + _ => Err(CorsError::MissingOrigin), + } + } + } + + fn access_control_allow_origin(&self, req: &RequestHead) -> Option { + match self.origins { + AllOrSome::All => { + if self.send_wildcard { + Some(HeaderValue::from_static("*")) + } else if let Some(origin) = req.headers().get(&header::ORIGIN) { + Some(origin.clone()) + } else { + None + } + } + AllOrSome::Some(ref origins) => { + if let Some(origin) = + req.headers() + .get(&header::ORIGIN) + .filter(|o| match o.to_str() { + Ok(os) => origins.contains(os), + _ => false, + }) + { + Some(origin.clone()) + } else { + Some(self.origins_str.as_ref().unwrap().clone()) + } + } + } + } + + fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) { + if let Ok(meth) = hdr.to_str() { + if let Ok(method) = Method::try_from(meth) { + return self + .methods + .get(&method) + .and_then(|_| Some(())) + .ok_or_else(|| CorsError::MethodNotAllowed); + } + } + Err(CorsError::BadRequestMethod) + } else { + Err(CorsError::MissingRequestMethod) + } + } + + fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> { + match self.headers { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_headers) => { + if let Some(hdr) = + req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) + { + if let Ok(headers) = hdr.to_str() { + let mut hdrs = HashSet::new(); + for hdr in headers.split(',') { + match HeaderName::try_from(hdr.trim()) { + Ok(hdr) => hdrs.insert(hdr), + Err(_) => return Err(CorsError::BadRequestHeaders), + }; + } + // `Access-Control-Request-Headers` must contain 1 or more + // `field-name`. + if !hdrs.is_empty() { + if !hdrs.is_subset(allowed_headers) { + return Err(CorsError::HeadersNotAllowed); + } + return Ok(()); + } + } + Err(CorsError::BadRequestHeaders) + } else { + Ok(()) + } + } + } + } +} + +impl Service for CorsMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = Either< + Ready>, + LocalBoxFuture<'static, Result>, + >; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + if self.inner.preflight && Method::OPTIONS == *req.method() { + if let Err(e) = self + .inner + .validate_origin(req.head()) + .and_then(|_| self.inner.validate_allowed_method(req.head())) + .and_then(|_| self.inner.validate_allowed_headers(req.head())) + { + return Either::Left(ok(req.error_response(e))); + } + + // allowed headers + let headers = if let Some(headers) = self.inner.headers.as_ref() { + Some( + HeaderValue::try_from( + &headers + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .unwrap(), + ) + } else if let Some(hdr) = + req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) + { + Some(hdr.clone()) + } else { + None + }; + + let res = HttpResponse::Ok() + .if_some(self.inner.max_age.as_ref(), |max_age, resp| { + let _ = resp.header( + header::ACCESS_CONTROL_MAX_AGE, + format!("{}", max_age).as_str(), + ); + }) + .if_some(headers, |headers, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); + }) + .if_some( + self.inner.access_control_allow_origin(req.head()), + |origin, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }, + ) + .if_true(self.inner.supports_credentials, |resp| { + resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + }) + .header( + header::ACCESS_CONTROL_ALLOW_METHODS, + &self + .inner + .methods + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .finish() + .into_body(); + + Either::Left(ok(req.into_response(res))) + } else { + if req.headers().contains_key(&header::ORIGIN) { + // Only check requests with a origin header. + if let Err(e) = self.inner.validate_origin(req.head()) { + return Either::Left(ok(req.error_response(e))); + } + } + + let inner = self.inner.clone(); + let has_origin = req.headers().contains_key(&header::ORIGIN); + let fut = self.service.call(req); + + Either::Right( + async move { + let res = fut.await; + + if has_origin { + let mut res = res?; + if let Some(origin) = + inner.access_control_allow_origin(res.request().head()) + { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.clone(), + ); + }; + + if let Some(ref expose) = inner.expose_hdrs { + res.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap(), + ); + } + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { + let value = if let Some(hdr) = + res.headers_mut().get(&header::VARY) + { + let mut val: Vec = + Vec::with_capacity(hdr.as_bytes().len() + 8); + val.extend(hdr.as_bytes()); + val.extend(b", Origin"); + HeaderValue::try_from(&val[..]).unwrap() + } else { + HeaderValue::from_static("Origin") + }; + res.headers_mut().insert(header::VARY, value); + } + Ok(res) + } else { + res + } + } + .boxed_local(), + ) + } + } +} + +#[cfg(test)] +mod tests { + use actix_service::{fn_service, Transform}; + use actix_web::test::{self, TestRequest}; + + use super::*; + + #[actix_rt::test] + #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] + async fn cors_validates_illegal_allow_credentials() { + let _cors = Cors::new().supports_credentials().send_wildcard().finish(); + } + + #[actix_rt::test] + async fn validate_origin_allows_all_origins() { + let mut cors = Cors::new() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn default() { + let mut cors = Cors::default() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_preflight() { + let mut cors = Cors::new() + .send_wildcard() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") + .to_srv_request(); + + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") + .method(Method::OPTIONS) + .to_srv_request(); + + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"3600"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_MAX_AGE) + .unwrap() + .as_bytes() + ); + let hdr = resp + .headers() + .get(&header::ACCESS_CONTROL_ALLOW_HEADERS) + .unwrap() + .to_str() + .unwrap(); + assert!(hdr.contains("authorization")); + assert!(hdr.contains("accept")); + assert!(hdr.contains("content-type")); + + let methods = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_METHODS) + .unwrap() + .to_str() + .unwrap(); + assert!(methods.contains("POST")); + assert!(methods.contains("GET")); + assert!(methods.contains("OPTIONS")); + + Rc::get_mut(&mut cors.inner).unwrap().preflight = false; + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + // #[actix_rt::test] + // #[should_panic(expected = "MissingOrigin")] + // async fn test_validate_missing_origin() { + // let cors = Cors::build() + // .allowed_origin("https://www.example.com") + // .finish(); + // let mut req = HttpRequest::default(); + // cors.start(&req).unwrap(); + // } + + #[actix_rt::test] + #[should_panic(expected = "OriginNotAllowed")] + async fn test_validate_not_allowed_origin() { + let cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + cors.inner.validate_origin(req.head()).unwrap(); + cors.inner.validate_allowed_method(req.head()).unwrap(); + cors.inner.validate_allowed_headers(req.head()).unwrap(); + } + + #[actix_rt::test] + async fn test_validate_origin() { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_no_origin_response() { + let mut cors = Cors::new() + .disable_preflight() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::default().method(Method::GET).to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert!(resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://www.example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } + + #[actix_rt::test] + async fn test_response() { + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + { + let headers = resp + .headers() + .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) + .unwrap() + .to_str() + .unwrap() + .split(',') + .map(|s| s.trim()) + .collect::>(); + + for h in exposed_headers { + assert!(headers.contains(&h.as_str())); + } + } + + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(fn_service(|req: ServiceRequest| { + ok(req.into_response( + HttpResponse::Ok().header(header::VARY, "Accept").finish(), + )) + })) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"Accept, Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + let mut cors = Cors::new() + .disable_vary_header() + .allowed_origin("https://www.example.com") + .allowed_origin("https://www.google.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + + let origins_str = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap(); + + assert_eq!("https://www.example.com", origins_str); + } + + #[actix_rt::test] + async fn test_multiple_origins() { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } + + #[actix_rt::test] + async fn test_multiple_origins_preflight() { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } +} diff --git a/actix-identity/CHANGES.md b/actix-identity/CHANGES.md new file mode 100644 index 000000000..0c9809ea1 --- /dev/null +++ b/actix-identity/CHANGES.md @@ -0,0 +1,17 @@ +# Changes + +## [Unreleased] - 2020-xx-xx + +* Update the `time` dependency to 0.2.5 + +## [0.2.1] - 2020-01-10 + +* Fix panic with already borrowed: BorrowMutError #1263 + +## [0.2.0] - 2019-12-20 + +* Use actix-web 2.0 + +## [0.1.0] - 2019-06-xx + +* Move identity middleware to separate crate diff --git a/actix-identity/Cargo.toml b/actix-identity/Cargo.toml new file mode 100644 index 000000000..efeb24bda --- /dev/null +++ b/actix-identity/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "actix-identity" +version = "0.2.1" +authors = ["Nikolay Kim "] +description = "Identity service for actix web framework." +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-identity/" +license = "MIT/Apache-2.0" +edition = "2018" + +[lib] +name = "actix_identity" +path = "src/lib.rs" + +[dependencies] +actix-web = { version = "2.0.0", default-features = false, features = ["secure-cookies"] } +actix-service = "1.0.2" +futures = "0.3.1" +serde = "1.0" +serde_json = "1.0" +time = { version = "0.2.5", default-features = false, features = ["std"] } + +[dev-dependencies] +actix-rt = "1.0.0" +actix-http = "1.0.1" +bytes = "0.5.3" diff --git a/actix-identity/LICENSE-APACHE b/actix-identity/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-identity/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-identity/LICENSE-MIT b/actix-identity/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-identity/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-identity/README.md b/actix-identity/README.md new file mode 100644 index 000000000..60b615c76 --- /dev/null +++ b/actix-identity/README.md @@ -0,0 +1,9 @@ +# Identity service for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-identity)](https://crates.io/crates/actix-identity) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-identity/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-session](https://crates.io/crates/actix-identity) +* Minimum supported Rust version: 1.34 or later diff --git a/actix-identity/src/lib.rs b/actix-identity/src/lib.rs new file mode 100644 index 000000000..b584b1af7 --- /dev/null +++ b/actix-identity/src/lib.rs @@ -0,0 +1,1128 @@ +//! Request identity service for Actix applications. +//! +//! [**IdentityService**](struct.IdentityService.html) middleware can be +//! used with different policies types to store identity information. +//! +//! By default, only cookie identity policy is implemented. Other backend +//! implementations can be added separately. +//! +//! [**CookieIdentityPolicy**](struct.CookieIdentityPolicy.html) +//! uses cookies as identity storage. +//! +//! To access current request identity +//! [**Identity**](struct.Identity.html) extractor should be used. +//! +//! ```rust +//! use actix_web::*; +//! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; +//! +//! async fn index(id: Identity) -> String { +//! // access request identity +//! if let Some(id) = id.identity() { +//! format!("Welcome! {}", id) +//! } else { +//! "Welcome Anonymous!".to_owned() +//! } +//! } +//! +//! async fn login(id: Identity) -> HttpResponse { +//! id.remember("User1".to_owned()); // <- remember identity +//! HttpResponse::Ok().finish() +//! } +//! +//! async fn logout(id: Identity) -> HttpResponse { +//! id.forget(); // <- remove identity +//! HttpResponse::Ok().finish() +//! } +//! +//! fn main() { +//! let app = App::new().wrap(IdentityService::new( +//! // <- create identity middleware +//! CookieIdentityPolicy::new(&[0; 32]) // <- create cookie identity policy +//! .name("auth-cookie") +//! .secure(false))) +//! .service(web::resource("/index.html").to(index)) +//! .service(web::resource("/login.html").to(login)) +//! .service(web::resource("/logout.html").to(logout)); +//! } +//! ``` +use std::cell::RefCell; +use std::future::Future; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::time::SystemTime; + +use actix_service::{Service, Transform}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use serde::{Deserialize, Serialize}; +use time::Duration; + +use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; +use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse}; +use actix_web::error::{Error, Result}; +use actix_web::http::header::{self, HeaderValue}; +use actix_web::{FromRequest, HttpMessage, HttpRequest}; + +/// The extractor type to obtain your identity from a request. +/// +/// ```rust +/// use actix_web::*; +/// use actix_identity::Identity; +/// +/// fn index(id: Identity) -> Result { +/// // access request identity +/// if let Some(id) = id.identity() { +/// Ok(format!("Welcome! {}", id)) +/// } else { +/// Ok("Welcome Anonymous!".to_owned()) +/// } +/// } +/// +/// fn login(id: Identity) -> HttpResponse { +/// id.remember("User1".to_owned()); // <- remember identity +/// HttpResponse::Ok().finish() +/// } +/// +/// fn logout(id: Identity) -> HttpResponse { +/// id.forget(); // <- remove identity +/// HttpResponse::Ok().finish() +/// } +/// # fn main() {} +/// ``` +#[derive(Clone)] +pub struct Identity(HttpRequest); + +impl Identity { + /// Return the claimed identity of the user associated request or + /// ``None`` if no identity can be found associated with the request. + pub fn identity(&self) -> Option { + Identity::get_identity(&self.0.extensions()) + } + + /// Remember identity. + pub fn remember(&self, identity: String) { + if let Some(id) = self.0.extensions_mut().get_mut::() { + id.id = Some(identity); + id.changed = true; + } + } + + /// This method is used to 'forget' the current identity on subsequent + /// requests. + pub fn forget(&self) { + if let Some(id) = self.0.extensions_mut().get_mut::() { + id.id = None; + id.changed = true; + } + } + + fn get_identity(extensions: &Extensions) -> Option { + if let Some(id) = extensions.get::() { + id.id.clone() + } else { + None + } + } +} + +struct IdentityItem { + id: Option, + changed: bool, +} + +/// Helper trait that allows to get Identity. +/// +/// It could be used in middleware but identity policy must be set before any other middleware that needs identity +/// RequestIdentity is implemented both for `ServiceRequest` and `HttpRequest`. +pub trait RequestIdentity { + fn get_identity(&self) -> Option; +} + +impl RequestIdentity for T +where + T: HttpMessage, +{ + fn get_identity(&self) -> Option { + Identity::get_identity(&self.extensions()) + } +} + +/// Extractor implementation for Identity type. +/// +/// ```rust +/// # use actix_web::*; +/// use actix_identity::Identity; +/// +/// fn index(id: Identity) -> String { +/// // access request identity +/// if let Some(id) = id.identity() { +/// format!("Welcome! {}", id) +/// } else { +/// "Welcome Anonymous!".to_owned() +/// } +/// } +/// # fn main() {} +/// ``` +impl FromRequest for Identity { + type Config = (); + type Error = Error; + type Future = Ready>; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(Identity(req.clone())) + } +} + +/// Identity policy definition. +pub trait IdentityPolicy: Sized + 'static { + /// The return type of the middleware + type Future: Future, Error>>; + + /// The return type of the middleware + type ResponseFuture: Future>; + + /// Parse the session from request and load data from a service identity. + fn from_request(&self, request: &mut ServiceRequest) -> Self::Future; + + /// Write changes to response + fn to_response( + &self, + identity: Option, + changed: bool, + response: &mut ServiceResponse, + ) -> Self::ResponseFuture; +} + +/// Request identity middleware +/// +/// ```rust +/// use actix_web::App; +/// use actix_identity::{CookieIdentityPolicy, IdentityService}; +/// +/// fn main() { +/// let app = App::new().wrap(IdentityService::new( +/// // <- create identity middleware +/// CookieIdentityPolicy::new(&[0; 32]) // <- create cookie session backend +/// .name("auth-cookie") +/// .secure(false), +/// )); +/// } +/// ``` +pub struct IdentityService { + backend: Rc, +} + +impl IdentityService { + /// Create new identity service with specified backend. + pub fn new(backend: T) -> Self { + IdentityService { + backend: Rc::new(backend), + } + } +} + +impl Transform for IdentityService +where + S: Service, Error = Error> + + 'static, + S::Future: 'static, + T: IdentityPolicy, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = IdentityServiceMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(IdentityServiceMiddleware { + backend: self.backend.clone(), + service: Rc::new(RefCell::new(service)), + }) + } +} + +#[doc(hidden)] +pub struct IdentityServiceMiddleware { + backend: Rc, + service: Rc>, +} + +impl Clone for IdentityServiceMiddleware { + fn clone(&self) -> Self { + Self { + backend: self.backend.clone(), + service: self.service.clone(), + } + } +} + +impl Service for IdentityServiceMiddleware +where + B: 'static, + S: Service, Error = Error> + + 'static, + S::Future: 'static, + T: IdentityPolicy, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.borrow_mut().poll_ready(cx) + } + + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + let srv = self.service.clone(); + let backend = self.backend.clone(); + let fut = self.backend.from_request(&mut req); + + async move { + match fut.await { + Ok(id) => { + req.extensions_mut() + .insert(IdentityItem { id, changed: false }); + + // https://github.com/actix/actix-web/issues/1263 + let fut = { srv.borrow_mut().call(req) }; + let mut res = fut.await?; + let id = res.request().extensions_mut().remove::(); + + if let Some(id) = id { + match backend.to_response(id.id, id.changed, &mut res).await { + Ok(_) => Ok(res), + Err(e) => Ok(res.error_response(e)), + } + } else { + Ok(res) + } + } + Err(err) => Ok(req.error_response(err)), + } + } + .boxed_local() + } +} + +struct CookieIdentityInner { + key: Key, + key_v2: Key, + name: String, + path: String, + domain: Option, + secure: bool, + max_age: Option, + same_site: Option, + visit_deadline: Option, + login_deadline: Option, +} + +#[derive(Deserialize, Serialize, Debug)] +struct CookieValue { + identity: String, + #[serde(skip_serializing_if = "Option::is_none")] + login_timestamp: Option, + #[serde(skip_serializing_if = "Option::is_none")] + visit_timestamp: Option, +} + +#[derive(Debug)] +struct CookieIdentityExtention { + login_timestamp: Option, +} + +impl CookieIdentityInner { + fn new(key: &[u8]) -> CookieIdentityInner { + let key_v2: Vec = key.iter().chain([1, 0, 0, 0].iter()).cloned().collect(); + CookieIdentityInner { + key: Key::from_master(key), + key_v2: Key::from_master(&key_v2), + name: "actix-identity".to_owned(), + path: "/".to_owned(), + domain: None, + secure: true, + max_age: None, + same_site: None, + visit_deadline: None, + login_deadline: None, + } + } + + fn set_cookie( + &self, + resp: &mut ServiceResponse, + value: Option, + ) -> Result<()> { + let add_cookie = value.is_some(); + let val = value.map(|val| { + if !self.legacy_supported() { + serde_json::to_string(&val) + } else { + Ok(val.identity) + } + }); + let mut cookie = + Cookie::new(self.name.clone(), val.unwrap_or_else(|| Ok(String::new()))?); + cookie.set_path(self.path.clone()); + cookie.set_secure(self.secure); + cookie.set_http_only(true); + + if let Some(ref domain) = self.domain { + cookie.set_domain(domain.clone()); + } + + if let Some(max_age) = self.max_age { + cookie.set_max_age(max_age); + } + + if let Some(same_site) = self.same_site { + cookie.set_same_site(same_site); + } + + let mut jar = CookieJar::new(); + let key = if self.legacy_supported() { + &self.key + } else { + &self.key_v2 + }; + if add_cookie { + jar.private(&key).add(cookie); + } else { + jar.add_original(cookie.clone()); + jar.private(&key).remove(cookie); + } + for cookie in jar.delta() { + let val = HeaderValue::from_str(&cookie.to_string())?; + resp.headers_mut().append(header::SET_COOKIE, val); + } + Ok(()) + } + + fn load(&self, req: &ServiceRequest) -> Option { + let cookie = req.cookie(&self.name)?; + let mut jar = CookieJar::new(); + jar.add_original(cookie.clone()); + let res = if self.legacy_supported() { + jar.private(&self.key).get(&self.name).map(|n| CookieValue { + identity: n.value().to_string(), + login_timestamp: None, + visit_timestamp: None, + }) + } else { + None + }; + res.or_else(|| { + jar.private(&self.key_v2) + .get(&self.name) + .and_then(|c| self.parse(c)) + }) + } + + fn parse(&self, cookie: Cookie) -> Option { + let value: CookieValue = serde_json::from_str(cookie.value()).ok()?; + let now = SystemTime::now(); + if let Some(visit_deadline) = self.visit_deadline { + if now.duration_since(value.visit_timestamp?).ok()? + > visit_deadline + { + return None; + } + } + if let Some(login_deadline) = self.login_deadline { + if now.duration_since(value.login_timestamp?).ok()? + > login_deadline + { + return None; + } + } + Some(value) + } + + fn legacy_supported(&self) -> bool { + self.visit_deadline.is_none() && self.login_deadline.is_none() + } + + fn always_update_cookie(&self) -> bool { + self.visit_deadline.is_some() + } + + fn requires_oob_data(&self) -> bool { + self.login_deadline.is_some() + } +} + +/// Use cookies for request identity storage. +/// +/// The constructors take a key as an argument. +/// This is the private key for cookie - when this value is changed, +/// all identities are lost. The constructors will panic if the key is less +/// than 32 bytes in length. +/// +/// # Example +/// +/// ```rust +/// use actix_web::App; +/// use actix_identity::{CookieIdentityPolicy, IdentityService}; +/// +/// fn main() { +/// let app = App::new().wrap(IdentityService::new( +/// // <- create identity middleware +/// CookieIdentityPolicy::new(&[0; 32]) // <- construct cookie policy +/// .domain("www.rust-lang.org") +/// .name("actix_auth") +/// .path("/") +/// .secure(true), +/// )); +/// } +/// ``` +pub struct CookieIdentityPolicy(Rc); + +impl CookieIdentityPolicy { + /// Construct new `CookieIdentityPolicy` instance. + /// + /// Panics if key length is less than 32 bytes. + pub fn new(key: &[u8]) -> CookieIdentityPolicy { + CookieIdentityPolicy(Rc::new(CookieIdentityInner::new(key))) + } + + /// Sets the `path` field in the session cookie being built. + pub fn path>(mut self, value: S) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().path = value.into(); + self + } + + /// Sets the `name` field in the session cookie being built. + pub fn name>(mut self, value: S) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().name = value.into(); + self + } + + /// Sets the `domain` field in the session cookie being built. + pub fn domain>(mut self, value: S) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().domain = Some(value.into()); + self + } + + /// Sets the `secure` field in the session cookie being built. + /// + /// If the `secure` field is set, a cookie will only be transmitted when the + /// connection is secure - i.e. `https` + pub fn secure(mut self, value: bool) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().secure = value; + self + } + + /// Sets the `max-age` field in the session cookie being built with given number of seconds. + pub fn max_age(self, seconds: i64) -> CookieIdentityPolicy { + self.max_age_time(Duration::seconds(seconds)) + } + + /// Sets the `max-age` field in the session cookie being built with `chrono::Duration`. + pub fn max_age_time(mut self, value: Duration) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().max_age = Some(value); + self + } + + /// Sets the `same_site` field in the session cookie being built. + pub fn same_site(mut self, same_site: SameSite) -> Self { + Rc::get_mut(&mut self.0).unwrap().same_site = Some(same_site); + self + } + + /// Accepts only users whose cookie has been seen before the given deadline + /// + /// By default visit deadline is disabled. + pub fn visit_deadline(mut self, value: Duration) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().visit_deadline = Some(value); + self + } + + /// Accepts only users which has been authenticated before the given deadline + /// + /// By default login deadline is disabled. + pub fn login_deadline(mut self, value: Duration) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().login_deadline = Some(value); + self + } +} + +impl IdentityPolicy for CookieIdentityPolicy { + type Future = Ready, Error>>; + type ResponseFuture = Ready>; + + fn from_request(&self, req: &mut ServiceRequest) -> Self::Future { + ok(self.0.load(req).map( + |CookieValue { + identity, + login_timestamp, + .. + }| { + if self.0.requires_oob_data() { + req.extensions_mut() + .insert(CookieIdentityExtention { login_timestamp }); + } + identity + }, + )) + } + + fn to_response( + &self, + id: Option, + changed: bool, + res: &mut ServiceResponse, + ) -> Self::ResponseFuture { + let _ = if changed { + let login_timestamp = SystemTime::now(); + self.0.set_cookie( + res, + id.map(|identity| CookieValue { + identity, + login_timestamp: self.0.login_deadline.map(|_| login_timestamp), + visit_timestamp: self.0.visit_deadline.map(|_| login_timestamp), + }), + ) + } else if self.0.always_update_cookie() && id.is_some() { + let visit_timestamp = SystemTime::now(); + let login_timestamp = if self.0.requires_oob_data() { + let CookieIdentityExtention { + login_timestamp: lt, + } = res.request().extensions_mut().remove().unwrap(); + lt + } else { + None + }; + self.0.set_cookie( + res, + Some(CookieValue { + identity: id.unwrap(), + login_timestamp, + visit_timestamp: self.0.visit_deadline.map(|_| visit_timestamp), + }), + ) + } else { + Ok(()) + }; + ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Borrow; + + use super::*; + use actix_service::into_service; + use actix_web::http::StatusCode; + use actix_web::test::{self, TestRequest}; + use actix_web::{error, web, App, Error, HttpResponse}; + + const COOKIE_KEY_MASTER: [u8; 32] = [0; 32]; + const COOKIE_NAME: &'static str = "actix_auth"; + const COOKIE_LOGIN: &'static str = "test"; + + #[actix_rt::test] + async fn test_identity() { + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .secure(true), + )) + .service(web::resource("/index").to(|id: Identity| { + if id.identity().is_some() { + HttpResponse::Created() + } else { + HttpResponse::Ok() + } + })) + .service(web::resource("/login").to(|id: Identity| { + id.remember(COOKIE_LOGIN.to_string()); + HttpResponse::Ok() + })) + .service(web::resource("/logout").to(|id: Identity| { + if id.identity().is_some() { + id.forget(); + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + })), + ) + .await; + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + let c = resp.response().cookies().next().unwrap().to_owned(); + + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/index") + .cookie(c.clone()) + .to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/logout") + .cookie(c.clone()) + .to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)) + } + + #[actix_rt::test] + async fn test_identity_max_age_time() { + let duration = Duration::days(1); + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .max_age_time(duration) + .secure(true), + )) + .service(web::resource("/login").to(|id: Identity| { + id.remember("test".to_string()); + HttpResponse::Ok() + })), + ) + .await; + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)); + let c = resp.response().cookies().next().unwrap().to_owned(); + assert_eq!(duration, c.max_age().unwrap()); + } + + #[actix_rt::test] + async fn test_identity_max_age() { + let seconds = 60; + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .max_age(seconds) + .secure(true), + )) + .service(web::resource("/login").to(|id: Identity| { + id.remember("test".to_string()); + HttpResponse::Ok() + })), + ) + .await; + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)); + let c = resp.response().cookies().next().unwrap().to_owned(); + assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap()); + } + + async fn create_identity_server< + F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static, + >( + f: F, + ) -> impl actix_service::Service< + Request = actix_http::Request, + Response = ServiceResponse, + Error = Error, + > { + test::init_service( + App::new() + .wrap(IdentityService::new(f(CookieIdentityPolicy::new( + &COOKIE_KEY_MASTER, + ) + .secure(false) + .name(COOKIE_NAME)))) + .service(web::resource("/").to(|id: Identity| { + async move { + let identity = id.identity(); + if identity.is_none() { + id.remember(COOKIE_LOGIN.to_string()) + } + web::Json(identity) + } + })), + ) + .await + } + + fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { + let mut jar = CookieJar::new(); + jar.private(&Key::from_master(&COOKIE_KEY_MASTER)) + .add(Cookie::new(COOKIE_NAME, identity)); + jar.get(COOKIE_NAME).unwrap().clone() + } + + fn login_cookie( + identity: &'static str, + login_timestamp: Option, + visit_timestamp: Option, + ) -> Cookie<'static> { + let mut jar = CookieJar::new(); + let key: Vec = COOKIE_KEY_MASTER + .iter() + .chain([1, 0, 0, 0].iter()) + .map(|e| *e) + .collect(); + jar.private(&Key::from_master(&key)).add(Cookie::new( + COOKIE_NAME, + serde_json::to_string(&CookieValue { + identity: identity.to_string(), + login_timestamp, + visit_timestamp, + }) + .unwrap(), + )); + jar.get(COOKIE_NAME).unwrap().clone() + } + + async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) { + let bytes = test::read_body(response).await; + let resp: Option = serde_json::from_slice(&bytes[..]).unwrap(); + assert_eq!(resp.as_ref().map(|s| s.borrow()), identity); + } + + fn assert_legacy_login_cookie(response: &mut ServiceResponse, identity: &str) { + let mut cookies = CookieJar::new(); + for cookie in response.headers().get_all(header::SET_COOKIE) { + cookies.add(Cookie::parse(cookie.to_str().unwrap().to_string()).unwrap()); + } + let cookie = cookies + .private(&Key::from_master(&COOKIE_KEY_MASTER)) + .get(COOKIE_NAME) + .unwrap(); + assert_eq!(cookie.value(), identity); + } + + enum LoginTimestampCheck { + NoTimestamp, + NewTimestamp, + OldTimestamp(SystemTime), + } + + enum VisitTimeStampCheck { + NoTimestamp, + NewTimestamp, + } + + fn assert_login_cookie( + response: &mut ServiceResponse, + identity: &str, + login_timestamp: LoginTimestampCheck, + visit_timestamp: VisitTimeStampCheck, + ) { + let mut cookies = CookieJar::new(); + for cookie in response.headers().get_all(header::SET_COOKIE) { + cookies.add(Cookie::parse(cookie.to_str().unwrap().to_string()).unwrap()); + } + let key: Vec = COOKIE_KEY_MASTER + .iter() + .chain([1, 0, 0, 0].iter()) + .map(|e| *e) + .collect(); + let cookie = cookies + .private(&Key::from_master(&key)) + .get(COOKIE_NAME) + .unwrap(); + let cv: CookieValue = serde_json::from_str(cookie.value()).unwrap(); + assert_eq!(cv.identity, identity); + let now = SystemTime::now(); + let t30sec_ago = now - Duration::seconds(30); + match login_timestamp { + LoginTimestampCheck::NoTimestamp => assert_eq!(cv.login_timestamp, None), + LoginTimestampCheck::NewTimestamp => assert!( + t30sec_ago <= cv.login_timestamp.unwrap() + && cv.login_timestamp.unwrap() <= now + ), + LoginTimestampCheck::OldTimestamp(old_timestamp) => { + assert_eq!(cv.login_timestamp, Some(old_timestamp)) + } + } + match visit_timestamp { + VisitTimeStampCheck::NoTimestamp => assert_eq!(cv.visit_timestamp, None), + VisitTimeStampCheck::NewTimestamp => assert!( + t30sec_ago <= cv.visit_timestamp.unwrap() + && cv.visit_timestamp.unwrap() <= now + ), + } + } + + fn assert_no_login_cookie(response: &mut ServiceResponse) { + let mut cookies = CookieJar::new(); + for cookie in response.headers().get_all(header::SET_COOKIE) { + cookies.add(Cookie::parse(cookie.to_str().unwrap().to_string()).unwrap()); + } + assert!(cookies.get(COOKIE_NAME).is_none()); + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_is_set() { + let mut srv = create_identity_server(|c| c).await; + let mut resp = + test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await; + assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_works() { + let mut srv = create_identity_server(|c| c).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_login_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_visit_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_login_timestamp_too_old() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie( + COOKIE_LOGIN, + Some(SystemTime::now() - Duration::days(180)), + None, + ); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = login_cookie( + COOKIE_LOGIN, + None, + Some(SystemTime::now() - Duration::days(180)), + ); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_not_updated_on_login_deadline() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + } + + // https://github.com/actix/actix-web/issues/1263 + #[actix_rt::test] + async fn test_identity_cookie_updated_on_visit_deadline() { + let mut srv = create_identity_server(|c| { + c.visit_deadline(Duration::days(90)) + .login_deadline(Duration::days(90)) + }) + .await; + let timestamp = SystemTime::now() - Duration::days(1); + let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::OldTimestamp(timestamp), + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + } + + #[actix_rt::test] + async fn test_borrowed_mut_error() { + use futures::future::{lazy, ok, Ready}; + + struct Ident; + impl IdentityPolicy for Ident { + type Future = Ready, Error>>; + type ResponseFuture = Ready>; + + fn from_request(&self, _: &mut ServiceRequest) -> Self::Future { + ok(Some("test".to_string())) + } + + fn to_response( + &self, + _: Option, + _: bool, + _: &mut ServiceResponse, + ) -> Self::ResponseFuture { + ok(()) + } + } + + let mut srv = IdentityServiceMiddleware { + backend: Rc::new(Ident), + service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| { + async move { + actix_rt::time::delay_for(std::time::Duration::from_secs(100)).await; + Err::(error::ErrorBadRequest("error")) + } + }))), + }; + + let mut srv2 = srv.clone(); + let req = TestRequest::default().to_srv_request(); + actix_rt::spawn(async move { + let _ = srv2.call(req).await; + }); + actix_rt::time::delay_for(std::time::Duration::from_millis(50)).await; + + let _ = lazy(|cx| srv.poll_ready(cx)).await; + } +} diff --git a/actix-session/CHANGES.md b/actix-session/CHANGES.md new file mode 100644 index 000000000..f6753ae58 --- /dev/null +++ b/actix-session/CHANGES.md @@ -0,0 +1,73 @@ +# Changes + +## [Unreleased] - 2020-01-xx + +* Update the `time` dependency to 0.2.5 +* [#1292](https://github.com/actix/actix-web/pull/1292) Long lasting auto-prolonged session + +## [0.3.0] - 2019-12-20 + +* Release + +## [0.3.0-alpha.4] - 2019-12-xx + +* Allow access to sessions also from not mutable references to the request + +## [0.3.0-alpha.3] - 2019-12-xx + +* Add access to the session from RequestHead for use of session from guard methods + +* Migrate to `std::future` + +* Migrate to `actix-web` 2.0 + +## [0.2.0] - 2019-07-08 + +* Enhanced ``actix-session`` to facilitate state changes. Use ``Session.renew()`` + at successful login to cycle a session (new key/cookie but keeps state). + Use ``Session.purge()`` at logout to invalid a session cookie (and remove + from redis cache, if applicable). + +## [0.1.1] - 2019-06-03 + +* Fix optional cookie session support + +## [0.1.0] - 2019-05-18 + +* Use actix-web 1.0.0-rc + +## [0.1.0-beta.4] - 2019-05-12 + +* Use actix-web 1.0.0-beta.4 + +## [0.1.0-beta.2] - 2019-04-28 + +* Add helper trait `UserSession` which allows to get session for ServiceRequest and HttpRequest + +## [0.1.0-beta.1] - 2019-04-20 + +* Update actix-web to beta.1 + +* `CookieSession::max_age()` accepts value in seconds + +## [0.1.0-alpha.6] - 2019-04-14 + +* Update actix-web alpha.6 + +## [0.1.0-alpha.4] - 2019-04-08 + +* Update actix-web + +## [0.1.0-alpha.3] - 2019-04-02 + +* Update actix-web + +## [0.1.0-alpha.2] - 2019-03-29 + +* Update actix-web + +* Use new feature name for secure cookies + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/actix-session/Cargo.toml b/actix-session/Cargo.toml new file mode 100644 index 000000000..b279c9d89 --- /dev/null +++ b/actix-session/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "actix-session" +version = "0.3.0" +authors = ["Nikolay Kim "] +description = "Session for actix web framework." +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-session/" +license = "MIT/Apache-2.0" +edition = "2018" + +[lib] +name = "actix_session" +path = "src/lib.rs" + +[features] +default = ["cookie-session"] + +# sessions feature, session require "ring" crate and c compiler +cookie-session = ["actix-web/secure-cookies"] + +[dependencies] +actix-web = "2.0.0-rc" +actix-service = "1.0.1" +bytes = "0.5.3" +derive_more = "0.99.2" +futures = "0.3.1" +serde = "1.0" +serde_json = "1.0" +time = { version = "0.2.5", default-features = false, features = ["std"] } + +[dev-dependencies] +actix-rt = "1.0.0" diff --git a/actix-session/LICENSE-APACHE b/actix-session/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-session/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-session/LICENSE-MIT b/actix-session/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-session/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-session/README.md b/actix-session/README.md new file mode 100644 index 000000000..0aee756fd --- /dev/null +++ b/actix-session/README.md @@ -0,0 +1,9 @@ +# Session for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-session)](https://crates.io/crates/actix-session) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-session/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-session](https://crates.io/crates/actix-session) +* Minimum supported Rust version: 1.34 or later diff --git a/actix-session/src/cookie.rs b/actix-session/src/cookie.rs new file mode 100644 index 000000000..b5297f561 --- /dev/null +++ b/actix-session/src/cookie.rs @@ -0,0 +1,545 @@ +//! Cookie session. +//! +//! [**CookieSession**](struct.CookieSession.html) +//! uses cookies as session storage. `CookieSession` creates sessions +//! which are limited to storing fewer than 4000 bytes of data, as the payload +//! must fit into a single cookie. An internal server error is generated if a +//! session contains more than 4000 bytes. +//! +//! A cookie may have a security policy of *signed* or *private*. Each has +//! a respective `CookieSession` constructor. +//! +//! A *signed* cookie may be viewed but not modified by the client. A *private* +//! cookie may neither be viewed nor modified by the client. +//! +//! The constructors take a key as an argument. This is the private key +//! for cookie session - when this value is changed, all session data is lost. + +use std::collections::HashMap; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_service::{Service, Transform}; +use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; +use actix_web::dev::{ServiceRequest, ServiceResponse}; +use actix_web::http::{header::SET_COOKIE, HeaderValue}; +use actix_web::{Error, HttpMessage, ResponseError}; +use derive_more::{Display, From}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use serde_json::error::Error as JsonError; +use time::{Duration, OffsetDateTime}; + +use crate::{Session, SessionStatus}; + +/// Errors that can occur during handling cookie session +#[derive(Debug, From, Display)] +pub enum CookieSessionError { + /// Size of the serialized session is greater than 4000 bytes. + #[display(fmt = "Size of the serialized session is greater than 4000 bytes.")] + Overflow, + /// Fail to serialize session. + #[display(fmt = "Fail to serialize session")] + Serialize(JsonError), +} + +impl ResponseError for CookieSessionError {} + +enum CookieSecurity { + Signed, + Private, +} + +struct CookieSessionInner { + key: Key, + security: CookieSecurity, + name: String, + path: String, + domain: Option, + secure: bool, + http_only: bool, + max_age: Option, + expires_in: Option, + same_site: Option, +} + +impl CookieSessionInner { + fn new(key: &[u8], security: CookieSecurity) -> CookieSessionInner { + CookieSessionInner { + security, + key: Key::from_master(key), + name: "actix-session".to_owned(), + path: "/".to_owned(), + domain: None, + secure: true, + http_only: true, + max_age: None, + expires_in: None, + same_site: None, + } + } + + fn set_cookie( + &self, + res: &mut ServiceResponse, + state: impl Iterator, + ) -> Result<(), Error> { + let state: HashMap = state.collect(); + let value = + serde_json::to_string(&state).map_err(CookieSessionError::Serialize)?; + if value.len() > 4064 { + return Err(CookieSessionError::Overflow.into()); + } + + let mut cookie = Cookie::new(self.name.clone(), value); + cookie.set_path(self.path.clone()); + cookie.set_secure(self.secure); + cookie.set_http_only(self.http_only); + + if let Some(ref domain) = self.domain { + cookie.set_domain(domain.clone()); + } + + if let Some(expires_in) = self.expires_in { + cookie.set_expires(OffsetDateTime::now() + expires_in); + } + + if let Some(max_age) = self.max_age { + cookie.set_max_age(max_age); + } + + if let Some(same_site) = self.same_site { + cookie.set_same_site(same_site); + } + + let mut jar = CookieJar::new(); + + match self.security { + CookieSecurity::Signed => jar.signed(&self.key).add(cookie), + CookieSecurity::Private => jar.private(&self.key).add(cookie), + } + + for cookie in jar.delta() { + let val = HeaderValue::from_str(&cookie.encoded().to_string())?; + res.headers_mut().append(SET_COOKIE, val); + } + + Ok(()) + } + + /// invalidates session cookie + fn remove_cookie(&self, res: &mut ServiceResponse) -> Result<(), Error> { + let mut cookie = Cookie::named(self.name.clone()); + cookie.set_value(""); + cookie.set_max_age(Duration::zero()); + cookie.set_expires(OffsetDateTime::now() - Duration::days(365)); + + let val = HeaderValue::from_str(&cookie.to_string())?; + res.headers_mut().append(SET_COOKIE, val); + + Ok(()) + } + + fn load(&self, req: &ServiceRequest) -> (bool, HashMap) { + if let Ok(cookies) = req.cookies() { + for cookie in cookies.iter() { + if cookie.name() == self.name { + let mut jar = CookieJar::new(); + jar.add_original(cookie.clone()); + + let cookie_opt = match self.security { + CookieSecurity::Signed => jar.signed(&self.key).get(&self.name), + CookieSecurity::Private => { + jar.private(&self.key).get(&self.name) + } + }; + if let Some(cookie) = cookie_opt { + if let Ok(val) = serde_json::from_str(cookie.value()) { + return (false, val); + } + } + } + } + } + (true, HashMap::new()) + } +} + +/// Use cookies for session storage. +/// +/// `CookieSession` creates sessions which are limited to storing +/// fewer than 4000 bytes of data (as the payload must fit into a single +/// cookie). An Internal Server Error is generated if the session contains more +/// than 4000 bytes. +/// +/// A cookie may have a security policy of *signed* or *private*. Each has a +/// respective `CookieSessionBackend` constructor. +/// +/// A *signed* cookie is stored on the client as plaintext alongside +/// a signature such that the cookie may be viewed but not modified by the +/// client. +/// +/// A *private* cookie is stored on the client as encrypted text +/// such that it may neither be viewed nor modified by the client. +/// +/// The constructors take a key as an argument. +/// This is the private key for cookie session - when this value is changed, +/// all session data is lost. The constructors will panic if the key is less +/// than 32 bytes in length. +/// +/// The backend relies on `cookie` crate to create and read cookies. +/// By default all cookies are percent encoded, but certain symbols may +/// cause troubles when reading cookie, if they are not properly percent encoded. +/// +/// # Example +/// +/// ```rust +/// use actix_session::CookieSession; +/// use actix_web::{web, App, HttpResponse, HttpServer}; +/// +/// fn main() { +/// let app = App::new().wrap( +/// CookieSession::signed(&[0; 32]) +/// .domain("www.rust-lang.org") +/// .name("actix_session") +/// .path("/") +/// .secure(true)) +/// .service(web::resource("/").to(|| HttpResponse::Ok())); +/// } +/// ``` +pub struct CookieSession(Rc); + +impl CookieSession { + /// Construct new *signed* `CookieSessionBackend` instance. + /// + /// Panics if key length is less than 32 bytes. + pub fn signed(key: &[u8]) -> CookieSession { + CookieSession(Rc::new(CookieSessionInner::new( + key, + CookieSecurity::Signed, + ))) + } + + /// Construct new *private* `CookieSessionBackend` instance. + /// + /// Panics if key length is less than 32 bytes. + pub fn private(key: &[u8]) -> CookieSession { + CookieSession(Rc::new(CookieSessionInner::new( + key, + CookieSecurity::Private, + ))) + } + + /// Sets the `path` field in the session cookie being built. + pub fn path>(mut self, value: S) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().path = value.into(); + self + } + + /// Sets the `name` field in the session cookie being built. + pub fn name>(mut self, value: S) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().name = value.into(); + self + } + + /// Sets the `domain` field in the session cookie being built. + pub fn domain>(mut self, value: S) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().domain = Some(value.into()); + self + } + + /// Sets the `secure` field in the session cookie being built. + /// + /// If the `secure` field is set, a cookie will only be transmitted when the + /// connection is secure - i.e. `https` + pub fn secure(mut self, value: bool) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().secure = value; + self + } + + /// Sets the `http_only` field in the session cookie being built. + pub fn http_only(mut self, value: bool) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().http_only = value; + self + } + + /// Sets the `same_site` field in the session cookie being built. + pub fn same_site(mut self, value: SameSite) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().same_site = Some(value); + self + } + + /// Sets the `max-age` field in the session cookie being built. + pub fn max_age(self, seconds: i64) -> CookieSession { + self.max_age_time(Duration::seconds(seconds)) + } + + /// Sets the `max-age` field in the session cookie being built. + pub fn max_age_time(mut self, value: time::Duration) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().max_age = Some(value); + self + } + + /// Sets the `expires` field in the session cookie being built. + pub fn expires_in(self, seconds: i64) -> CookieSession { + self.expires_in_time(Duration::seconds(seconds)) + } + + /// Sets the `expires` field in the session cookie being built. + pub fn expires_in_time(mut self, value: Duration) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().expires_in = Some(value); + self + } +} + +impl Transform for CookieSession +where + S: Service>, + S::Future: 'static, + S::Error: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = S::Error; + type InitError = (); + type Transform = CookieSessionMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CookieSessionMiddleware { + service, + inner: self.0.clone(), + }) + } +} + +/// Cookie session middleware +pub struct CookieSessionMiddleware { + service: S, + inner: Rc, +} + +impl Service for CookieSessionMiddleware +where + S: Service>, + S::Future: 'static, + S::Error: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = S::Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + /// On first request, a new session cookie is returned in response, regardless + /// of whether any session state is set. With subsequent requests, if the + /// session state changes, then set-cookie is returned in response. As + /// a user logs out, call session.purge() to set SessionStatus accordingly + /// and this will trigger removal of the session cookie in the response. + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + let inner = self.inner.clone(); + let (is_new, state) = self.inner.load(&req); + let prolong_expiration = self.inner.expires_in.is_some(); + Session::set_session(state.into_iter(), &mut req); + + let fut = self.service.call(req); + + async move { + fut.await.map(|mut res| { + match Session::get_changes(&mut res) { + (SessionStatus::Changed, Some(state)) + | (SessionStatus::Renewed, Some(state)) => { + res.checked_expr(|res| inner.set_cookie(res, state)) + } + (SessionStatus::Unchanged, Some(state)) if prolong_expiration => { + res.checked_expr(|res| inner.set_cookie(res, state)) + } + (SessionStatus::Unchanged, _) => + // set a new session cookie upon first request (new client) + { + if is_new { + let state: HashMap = HashMap::new(); + res.checked_expr(|res| { + inner.set_cookie(res, state.into_iter()) + }) + } else { + res + } + } + (SessionStatus::Purged, _) => { + let _ = inner.remove_cookie(&mut res); + res + } + _ => res, + } + }) + } + .boxed_local() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::{test, web, App}; + use bytes::Bytes; + + #[actix_rt::test] + async fn cookie_session() { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::signed(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + } + + #[actix_rt::test] + async fn private_cookie() { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::private(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + } + + #[actix_rt::test] + async fn cookie_session_extractor() { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::signed(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + } + + #[actix_rt::test] + async fn basics() { + let mut app = test::init_service( + App::new() + .wrap( + CookieSession::signed(&[0; 32]) + .path("/test/") + .name("actix-test") + .domain("localhost") + .http_only(true) + .same_site(SameSite::Lax) + .max_age(100), + ) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })) + .service(web::resource("/test/").to(|ses: Session| { + async move { + let val: usize = ses.get("counter").unwrap().unwrap(); + format!("counter: {}", val) + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + let cookie = response + .response() + .cookies() + .find(|c| c.name() == "actix-test") + .unwrap() + .clone(); + assert_eq!(cookie.path().unwrap(), "/test/"); + + let request = test::TestRequest::with_uri("/test/") + .cookie(cookie) + .to_request(); + let body = test::read_response(&mut app, request).await; + assert_eq!(body, Bytes::from_static(b"counter: 100")); + } + + #[actix_rt::test] + async fn prolong_expiration() { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::signed(&[0; 32]).secure(false).expires_in(60)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })) + .service( + web::resource("/test/") + .to(|| async move { "no-changes-in-session" }), + ), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + let expires_1 = response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .expect("Cookie is set") + .expires() + .expect("Expiration is set"); + + actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await; + + let request = test::TestRequest::with_uri("/test/").to_request(); + let response = app.call(request).await.unwrap(); + let expires_2 = response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .expect("Cookie is set") + .expires() + .expect("Expiration is set"); + + assert!(expires_2 - expires_1 >= Duration::seconds(1)); + } +} diff --git a/actix-session/src/lib.rs b/actix-session/src/lib.rs new file mode 100644 index 000000000..b6e5dd331 --- /dev/null +++ b/actix-session/src/lib.rs @@ -0,0 +1,322 @@ +//! User sessions. +//! +//! Actix provides a general solution for session management. Session +//! middlewares could provide different implementations which could +//! be accessed via general session api. +//! +//! By default, only cookie session backend is implemented. Other +//! backend implementations can be added. +//! +//! In general, you insert a *session* middleware and initialize it +//! , such as a `CookieSessionBackend`. To access session data, +//! [*Session*](struct.Session.html) extractor must be used. Session +//! extractor allows us to get or set session data. +//! +//! ```rust,no_run +//! use actix_web::{web, App, HttpServer, HttpResponse, Error}; +//! use actix_session::{Session, CookieSession}; +//! +//! fn index(session: Session) -> Result<&'static str, Error> { +//! // access session data +//! if let Some(count) = session.get::("counter")? { +//! println!("SESSION value: {}", count); +//! session.set("counter", count+1)?; +//! } else { +//! session.set("counter", 1)?; +//! } +//! +//! Ok("Welcome!") +//! } +//! +//! #[actix_rt::main] +//! async fn main() -> std::io::Result<()> { +//! HttpServer::new( +//! || App::new().wrap( +//! CookieSession::signed(&[0; 32]) // <- create cookie based session middleware +//! .secure(false) +//! ) +//! .service(web::resource("/").to(|| HttpResponse::Ok()))) +//! .bind("127.0.0.1:59880")? +//! .run() +//! .await +//! } +//! ``` +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; + +use actix_web::dev::{ + Extensions, Payload, RequestHead, ServiceRequest, ServiceResponse, +}; +use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; +use futures::future::{ok, Ready}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use serde_json; + +#[cfg(feature = "cookie-session")] +mod cookie; +#[cfg(feature = "cookie-session")] +pub use crate::cookie::CookieSession; + +/// The high-level interface you use to modify session data. +/// +/// Session object could be obtained with +/// [`RequestSession::session`](trait.RequestSession.html#tymethod.session) +/// method. `RequestSession` trait is implemented for `HttpRequest`. +/// +/// ```rust +/// use actix_session::Session; +/// use actix_web::*; +/// +/// fn index(session: Session) -> Result<&'static str> { +/// // access session data +/// if let Some(count) = session.get::("counter")? { +/// session.set("counter", count + 1)?; +/// } else { +/// session.set("counter", 1)?; +/// } +/// +/// Ok("Welcome!") +/// } +/// # fn main() {} +/// ``` +pub struct Session(Rc>); + +/// Helper trait that allows to get session +pub trait UserSession { + fn get_session(&self) -> Session; +} + +impl UserSession for HttpRequest { + fn get_session(&self) -> Session { + Session::get_session(&mut *self.extensions_mut()) + } +} + +impl UserSession for ServiceRequest { + fn get_session(&self) -> Session { + Session::get_session(&mut *self.extensions_mut()) + } +} + +impl UserSession for RequestHead { + fn get_session(&self) -> Session { + Session::get_session(&mut *self.extensions_mut()) + } +} + +#[derive(PartialEq, Clone, Debug)] +pub enum SessionStatus { + Changed, + Purged, + Renewed, + Unchanged, +} +impl Default for SessionStatus { + fn default() -> SessionStatus { + SessionStatus::Unchanged + } +} + +#[derive(Default)] +struct SessionInner { + state: HashMap, + pub status: SessionStatus, +} + +impl Session { + /// Get a `value` from the session. + pub fn get(&self, key: &str) -> Result, Error> { + if let Some(s) = self.0.borrow().state.get(key) { + Ok(Some(serde_json::from_str(s)?)) + } else { + Ok(None) + } + } + + /// Set a `value` from the session. + pub fn set(&self, key: &str, value: T) -> Result<(), Error> { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Changed; + inner + .state + .insert(key.to_owned(), serde_json::to_string(&value)?); + } + Ok(()) + } + + /// Remove value from the session. + pub fn remove(&self, key: &str) { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Changed; + inner.state.remove(key); + } + } + + /// Clear the session. + pub fn clear(&self) { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Changed; + inner.state.clear() + } + } + + /// Removes session, both client and server side. + pub fn purge(&self) { + let mut inner = self.0.borrow_mut(); + inner.status = SessionStatus::Purged; + inner.state.clear(); + } + + /// Renews the session key, assigning existing session state to new key. + pub fn renew(&self) { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Renewed; + } + } + + pub fn set_session( + data: impl Iterator, + req: &mut ServiceRequest, + ) { + let session = Session::get_session(&mut *req.extensions_mut()); + let mut inner = session.0.borrow_mut(); + inner.state.extend(data); + } + + pub fn get_changes( + res: &mut ServiceResponse, + ) -> ( + SessionStatus, + Option>, + ) { + if let Some(s_impl) = res + .request() + .extensions() + .get::>>() + { + let state = + std::mem::replace(&mut s_impl.borrow_mut().state, HashMap::new()); + (s_impl.borrow().status.clone(), Some(state.into_iter())) + } else { + (SessionStatus::Unchanged, None) + } + } + + fn get_session(extensions: &mut Extensions) -> Session { + if let Some(s_impl) = extensions.get::>>() { + return Session(Rc::clone(&s_impl)); + } + let inner = Rc::new(RefCell::new(SessionInner::default())); + extensions.insert(inner.clone()); + Session(inner) + } +} + +/// Extractor implementation for Session type. +/// +/// ```rust +/// # use actix_web::*; +/// use actix_session::Session; +/// +/// fn index(session: Session) -> Result<&'static str> { +/// // access session data +/// if let Some(count) = session.get::("counter")? { +/// session.set("counter", count + 1)?; +/// } else { +/// session.set("counter", 1)?; +/// } +/// +/// Ok("Welcome!") +/// } +/// # fn main() {} +/// ``` +impl FromRequest for Session { + type Error = Error; + type Future = Ready>; + type Config = (); + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(Session::get_session(&mut *req.extensions_mut())) + } +} + +#[cfg(test)] +mod tests { + use actix_web::{test, HttpResponse}; + + use super::*; + + #[test] + fn session() { + let mut req = test::TestRequest::default().to_srv_request(); + + Session::set_session( + vec![("key".to_string(), "\"value\"".to_string())].into_iter(), + &mut req, + ); + let session = Session::get_session(&mut *req.extensions_mut()); + let res = session.get::("key").unwrap(); + assert_eq!(res, Some("value".to_string())); + + session.set("key2", "value2".to_string()).unwrap(); + session.remove("key"); + + let mut res = req.into_response(HttpResponse::Ok().finish()); + let (_status, state) = Session::get_changes(&mut res); + let changes: Vec<_> = state.unwrap().collect(); + assert_eq!(changes, [("key2".to_string(), "\"value2\"".to_string())]); + } + + #[test] + fn get_session() { + let mut req = test::TestRequest::default().to_srv_request(); + + Session::set_session( + vec![("key".to_string(), "\"value\"".to_string())].into_iter(), + &mut req, + ); + + let session = req.get_session(); + let res = session.get::("key").unwrap(); + assert_eq!(res, Some("value".to_string())); + } + + #[test] + fn get_session_from_request_head() { + let mut req = test::TestRequest::default().to_srv_request(); + + Session::set_session( + vec![("key".to_string(), "\"value\"".to_string())].into_iter(), + &mut req, + ); + + let session = req.head_mut().get_session(); + let res = session.get::("key").unwrap(); + assert_eq!(res, Some("value".to_string())); + } + + #[test] + fn purge_session() { + let req = test::TestRequest::default().to_srv_request(); + let session = Session::get_session(&mut *req.extensions_mut()); + assert_eq!(session.0.borrow().status, SessionStatus::Unchanged); + session.purge(); + assert_eq!(session.0.borrow().status, SessionStatus::Purged); + } + + #[test] + fn renew_session() { + let req = test::TestRequest::default().to_srv_request(); + let session = Session::get_session(&mut *req.extensions_mut()); + assert_eq!(session.0.borrow().status, SessionStatus::Unchanged); + session.renew(); + assert_eq!(session.0.borrow().status, SessionStatus::Renewed); + } +}