From 32313c0af620ee1639a5273fc83e2be18b58eefd Mon Sep 17 00:00:00 2001 From: Raphael C Date: Sun, 11 Sep 2022 01:02:54 +0200 Subject: [PATCH] Limitation: custom key from closure (#281) Co-authored-by: Rob Ede --- actix-limitation/CHANGES.md | 1 + actix-limitation/Cargo.toml | 10 +++- actix-limitation/README.md | 23 ++++---- actix-limitation/src/builder.rs | 86 +++++++++++++++++++++++------- actix-limitation/src/lib.rs | 43 +++++++++++---- actix-limitation/src/middleware.rs | 22 +++----- actix-limitation/tests/tests.rs | 43 ++++++++++++++- 7 files changed, 171 insertions(+), 57 deletions(-) diff --git a/actix-limitation/CHANGES.md b/actix-limitation/CHANGES.md index 73a7b4092..af34874de 100644 --- a/actix-limitation/CHANGES.md +++ b/actix-limitation/CHANGES.md @@ -1,6 +1,7 @@ # Changes ## Unreleased - 2022-xx-xx +- Add `Builder::key_by` for setting a custom rate limit key function. - Implement `Default` for `RateLimiter`. - `RateLimiter` is marked `#[non_exhaustive]`; use `RateLimiter::default()` instead. - In the middleware errors from the count function are matched and respond with `INTERNAL_SERVER_ERROR` if it's an unexpected error, instead of the default `TOO_MANY_REQUESTS`. diff --git a/actix-limitation/Cargo.toml b/actix-limitation/Cargo.toml index ce5da8197..6e7db3ebc 100644 --- a/actix-limitation/Cargo.toml +++ b/actix-limitation/Cargo.toml @@ -12,10 +12,13 @@ repository = "https://github.com/actix/actix-extras.git" license = "MIT OR Apache-2.0" edition = "2018" +[features] +default = ["session"] +session = ["actix-session"] + [dependencies] -actix-session = "0.7" actix-utils = "3" -actix-web = { version = "4", default-features = false } +actix-web = { version = "4", features = ["cookies"] } chrono = "0.4" derive_more = "0.99.5" @@ -23,6 +26,9 @@ log = "0.4" redis = { version = "0.21", default-features = false, features = ["tokio-comp"] } time = "0.3" +# session +actix-session = { version = "0.7", optional = true } + [dev-dependencies] actix-web = "4" static_assertions = "1" diff --git a/actix-limitation/README.md b/actix-limitation/README.md index 007ef4634..17ff67b93 100644 --- a/actix-limitation/README.md +++ b/actix-limitation/README.md @@ -17,9 +17,10 @@ actix-limitation = "0.3" ``` ```rust -use std::time::Duration; -use actix_web::{get, web, App, HttpServer, Responder}; use actix_limitation::{Limiter, RateLimiter}; +use actix_session::SessionExt as _; +use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder}; +use std::{sync::Arc, time::Duration}; #[get("/{id}/{name}")] async fn index(info: web::Path<(u32, String)>) -> impl Responder { @@ -29,22 +30,24 @@ async fn index(info: web::Path<(u32, String)>) -> impl Responder { #[actix_web::main] async fn main() -> std::io::Result<()> { let limiter = web::Data::new( - Limiter::build("redis://127.0.0.1") - .cookie_name("session-id".to_owned()) - .session_key("rate-api-id".to_owned()) + Limiter::builder("redis://127.0.0.1") + .key_by(|req: &ServiceRequest| { + req.get_session() + .get(&"session-id") + .unwrap_or_else(|_| req.cookie(&"rate-api-id").map(|c| c.to_string())) + }) .limit(5000) .period(Duration::from_secs(3600)) // 60 minutes - .finish() - .expect("Can't build actix-limiter"), + .build() + .unwrap(), ); - HttpServer::new(move || { App::new() - .wrap(RateLimiter) + .wrap(RateLimiter::default()) .app_data(limiter.clone()) .service(index) }) - .bind("127.0.0.1:8080")? + .bind(("127.0.0.1", 8080))? .run() .await } diff --git a/actix-limitation/src/builder.rs b/actix-limitation/src/builder.rs index 3bc316037..d6053f5ef 100644 --- a/actix-limitation/src/builder.rs +++ b/actix-limitation/src/builder.rs @@ -1,8 +1,11 @@ -use std::{borrow::Cow, time::Duration}; +use std::{borrow::Cow, sync::Arc, time::Duration}; +#[cfg(feature = "session")] +use actix_session::SessionExt as _; +use actix_web::dev::ServiceRequest; use redis::Client; -use crate::{errors::Error, Limiter}; +use crate::{errors::Error, GetArcBoxKeyFn, Limiter}; /// Rate limiter builder. #[derive(Debug)] @@ -10,7 +13,9 @@ pub struct Builder { pub(crate) redis_url: String, pub(crate) limit: usize, pub(crate) period: Duration, + pub(crate) get_key_fn: Option, pub(crate) cookie_name: Cow<'static, str>, + #[cfg(feature = "session")] pub(crate) session_key: Cow<'static, str>, } @@ -27,14 +32,38 @@ impl Builder { self } - /// Set name of cookie to be sent. + /// Sets rate limit key derivation function. + /// + /// Should not be used in combination with `cookie_name` or `session_key` as they conflict. + pub fn key_by(&mut self, resolver: F) -> &mut Self + where + F: Fn(&ServiceRequest) -> Option + Send + Sync + 'static, + { + self.get_key_fn = Some(Arc::new(resolver)); + self + } + + /// Sets name of cookie to be sent. + /// + /// This method should not be used in combination of `key_by` as they conflict. + #[deprecated = "Prefer `key_by`."] pub fn cookie_name(&mut self, cookie_name: impl Into>) -> &mut Self { + if self.get_key_fn.is_some() { + panic!("This method should not be used in combination of get_key as they overwrite each other") + } self.cookie_name = cookie_name.into(); self } - /// Set session key to be used in backend. + /// Sets session key to be used in backend. + /// + /// This method should not be used in combination of `key_by` as they conflict. + #[deprecated = "Prefer `key_by`."] + #[cfg(feature = "session")] pub fn session_key(&mut self, session_key: impl Into>) -> &mut Self { + if self.get_key_fn.is_some() { + panic!("This method should not be used in combination of get_key as they overwrite each other") + } self.session_key = session_key.into(); self } @@ -43,13 +72,35 @@ impl Builder { /// /// Note that this method will connect to the Redis server to test its connection which is a /// **synchronous** operation. - pub fn build(&self) -> Result { + pub fn build(&mut self) -> Result { + let get_key = if let Some(resolver) = self.get_key_fn.clone() { + resolver + } else { + let cookie_name = self.cookie_name.clone(); + + #[cfg(feature = "session")] + let session_key = self.session_key.clone(); + + let closure: GetArcBoxKeyFn = Arc::new(Box::new(move |req: &ServiceRequest| { + #[cfg(feature = "session")] + let res = req + .get_session() + .get(&session_key) + .unwrap_or_else(|_| req.cookie(&cookie_name).map(|c| c.to_string())); + + #[cfg(not(feature = "session"))] + let res = req.cookie(&cookie_name).map(|c| c.to_string()); + + res + })); + closure + }; + Ok(Limiter { client: Client::open(self.redis_url.as_str())?, limit: self.limit, period: self.period, - cookie_name: self.cookie_name.clone(), - session_key: self.session_key.clone(), + get_key_fn: get_key, }) } } @@ -66,13 +117,16 @@ mod tests { redis_url: redis_url.to_owned(), limit: 100, period, + get_key_fn: Some(Arc::new(|_| None)), cookie_name: Cow::Owned("session".to_string()), + #[cfg(feature = "session")] session_key: Cow::Owned("rate-api".to_string()), }; assert_eq!(builder.redis_url, redis_url); assert_eq!(builder.limit, 100); assert_eq!(builder.period, period); + #[cfg(feature = "session")] assert_eq!(builder.session_key, "rate-api"); assert_eq!(builder.cookie_name, "session"); } @@ -85,22 +139,16 @@ mod tests { redis_url: redis_url.to_owned(), limit: 100, period: Duration::from_secs(10), - session_key: Cow::Borrowed("key"), + get_key_fn: Some(Arc::new(|_| None)), cookie_name: Cow::Borrowed("sid"), + #[cfg(feature = "session")] + session_key: Cow::Borrowed("key"), }; - let limiter = builder - .limit(200) - .period(period) - .cookie_name("session".to_string()) - .session_key("rate-api".to_string()) - .build() - .unwrap(); + let limiter = builder.limit(200).period(period).build().unwrap(); assert_eq!(limiter.limit, 200); assert_eq!(limiter.period, period); - assert_eq!(limiter.session_key, "rate-api"); - assert_eq!(limiter.cookie_name, "session"); } #[test] @@ -112,8 +160,10 @@ mod tests { redis_url: redis_url.to_owned(), limit: 100, period: Duration::from_secs(10), - session_key: Cow::Borrowed("key"), + get_key_fn: Some(Arc::new(|_| None)), cookie_name: Cow::Borrowed("sid"), + #[cfg(feature = "session")] + session_key: Cow::Borrowed("key"), }; builder.limit(200).period(period).build().unwrap(); diff --git a/actix-limitation/src/lib.rs b/actix-limitation/src/lib.rs index 8124fe632..6d2ab63db 100644 --- a/actix-limitation/src/lib.rs +++ b/actix-limitation/src/lib.rs @@ -7,8 +7,9 @@ //! ``` //! //! ```no_run -//! use std::time::Duration; -//! use actix_web::{get, web, App, HttpServer, Responder}; +//! use std::{sync::Arc, time::Duration}; +//! use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder}; +//! use actix_session::SessionExt as _; //! use actix_limitation::{Limiter, RateLimiter}; //! //! #[get("/{id}/{name}")] @@ -20,8 +21,11 @@ //! async fn main() -> std::io::Result<()> { //! let limiter = web::Data::new( //! Limiter::builder("redis://127.0.0.1") -//! .cookie_name("session-id".to_owned()) -//! .session_key("rate-api-id".to_owned()) +//! .key_by(|req: &ServiceRequest| { +//! req.get_session() +//! .get(&"session-id") +//! .unwrap_or_else(|_| req.cookie(&"rate-api-id").map(|c| c.to_string())) +//! }) //! .limit(5000) //! .period(Duration::from_secs(3600)) // 60 minutes //! .build() @@ -46,8 +50,9 @@ #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] -use std::{borrow::Cow, time::Duration}; +use std::{borrow::Cow, fmt, sync::Arc, time::Duration}; +use actix_web::dev::ServiceRequest; use redis::Client; mod builder; @@ -70,16 +75,34 @@ pub const DEFAULT_PERIOD_SECS: u64 = 3600; pub const DEFAULT_COOKIE_NAME: &str = "sid"; /// Default session key. +#[cfg(feature = "session")] pub const DEFAULT_SESSION_KEY: &str = "rate-api-id"; +/// Helper trait to impl Debug on GetKeyFn type +trait GetKeyFnT: Fn(&ServiceRequest) -> Option {} + +impl GetKeyFnT for T where T: Fn(&ServiceRequest) -> Option {} + +/// Get key function type with auto traits +type GetKeyFn = dyn GetKeyFnT + Send + Sync; + +/// Get key resolver function type +impl fmt::Debug for GetKeyFn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "GetKeyFn") + } +} + +/// Wrapped Get key function Trait +type GetArcBoxKeyFn = Arc; + /// Rate limiter. #[derive(Debug, Clone)] pub struct Limiter { client: Client, limit: usize, period: Duration, - cookie_name: Cow<'static, str>, - session_key: Cow<'static, str>, + get_key_fn: GetArcBoxKeyFn, } impl Limiter { @@ -93,7 +116,9 @@ impl Limiter { redis_url: redis_url.into(), limit: DEFAULT_REQUEST_LIMIT, period: Duration::from_secs(DEFAULT_PERIOD_SECS), + get_key_fn: None, cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME), + #[cfg(feature = "session")] session_key: Cow::Borrowed(DEFAULT_SESSION_KEY), } } @@ -146,14 +171,12 @@ mod tests { #[test] fn test_create_limiter() { - let builder = Limiter::builder("redis://127.0.0.1:6379/1"); + let mut builder = Limiter::builder("redis://127.0.0.1:6379/1"); let limiter = builder.build(); assert!(limiter.is_ok()); let limiter = limiter.unwrap(); assert_eq!(limiter.limit, 5000); assert_eq!(limiter.period, Duration::from_secs(3600)); - assert_eq!(limiter.cookie_name, DEFAULT_COOKIE_NAME); - assert_eq!(limiter.session_key, DEFAULT_SESSION_KEY); } } diff --git a/actix-limitation/src/middleware.rs b/actix-limitation/src/middleware.rs index d387e6d62..290f075c1 100644 --- a/actix-limitation/src/middleware.rs +++ b/actix-limitation/src/middleware.rs @@ -1,6 +1,5 @@ use std::{future::Future, pin::Pin, rc::Rc}; -use actix_session::SessionExt as _; use actix_utils::future::{ok, Ready}; use actix_web::{ body::EitherBody, @@ -61,25 +60,18 @@ where .expect("web::Data should be set in app data for RateLimiter middleware") .clone(); - let key = req.get_session().get(&limiter.session_key).unwrap_or(None); + let key = (limiter.get_key_fn)(&req); let service = Rc::clone(&self.service); let key = match key { Some(key) => key, None => { - let fallback = req.cookie(&limiter.cookie_name).map(|c| c.to_string()); - - match fallback { - Some(key) => key, - None => { - return Box::pin(async move { - service - .call(req) - .await - .map(ServiceResponse::map_into_left_body) - }); - } - } + return Box::pin(async move { + service + .call(req) + .await + .map(ServiceResponse::map_into_left_body) + }); } }; diff --git a/actix-limitation/tests/tests.rs b/actix-limitation/tests/tests.rs index 110a8060a..07df6eb1a 100644 --- a/actix-limitation/tests/tests.rs +++ b/actix-limitation/tests/tests.rs @@ -1,9 +1,12 @@ -use actix_limitation::{Error, Limiter}; +use std::time::Duration; + +use actix_limitation::{Error, Limiter, RateLimiter}; +use actix_web::{dev::ServiceRequest, http::StatusCode, test, web, App, HttpRequest, HttpResponse}; use uuid::Uuid; #[test] #[should_panic = "Redis URL did not parse"] -fn test_create_limiter_error() { +async fn test_create_limiter_error() { Limiter::builder("127.0.0.1").build().unwrap(); } @@ -51,3 +54,39 @@ async fn test_limiter_count_error() -> Result<(), Error> { Ok(()) } + +#[actix_web::test] +async fn test_limiter_key_by() -> Result<(), Error> { + let cooldown_period = Duration::from_secs(1); + let limiter = Limiter::builder("redis://127.0.0.1:6379/3") + .limit(2) + .period(cooldown_period) + .key_by(|_: &ServiceRequest| Some("fix_key".to_string())) + .build() + .unwrap(); + + let app = test::init_service( + App::new() + .wrap(RateLimiter::default()) + .app_data(web::Data::new(limiter)) + .route( + "/", + web::get().to(|_: HttpRequest| async { HttpResponse::Ok().body("ok") }), + ), + ) + .await; + for _ in 1..2 { + for index in 1..4 { + let req = test::TestRequest::default().to_request(); + let resp = test::call_service(&app, req).await; + if index <= 2 { + assert!(resp.status().is_success()); + } else { + assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS); + } + } + std::thread::sleep(cooldown_period); + } + + Ok(()) +}