diff --git a/actix-limitation/src/builder.rs b/actix-limitation/src/builder.rs index d6053f5ef..8b265bd14 100644 --- a/actix-limitation/src/builder.rs +++ b/actix-limitation/src/builder.rs @@ -7,10 +7,19 @@ use redis::Client; use crate::{errors::Error, GetArcBoxKeyFn, Limiter}; +/// [RedisConnectionKind] is used to define which connection parameter for the Redis server will be passed +/// It can be an Url or a Client +/// This is done so we can use the same client for multiple Limiters +#[derive(Debug, Clone)] +pub enum RedisConnectionKind { + Url(String), + Client(Client), +} + /// Rate limiter builder. #[derive(Debug)] pub struct Builder { - pub(crate) redis_url: String, + pub(crate) redis_connection: RedisConnectionKind, pub(crate) limit: usize, pub(crate) period: Duration, pub(crate) get_key_fn: Option, @@ -96,8 +105,12 @@ impl Builder { closure }; + let client = match &self.redis_connection { + RedisConnectionKind::Url(url) => Client::open(url.as_str())?, + RedisConnectionKind::Client(client) => client.clone(), + }; Ok(Limiter { - client: Client::open(self.redis_url.as_str())?, + client, limit: self.limit, period: self.period, get_key_fn: get_key, @@ -109,12 +122,25 @@ impl Builder { mod tests { use super::*; + /// Implementing partial Eq to check if builder assigned the redis connection url correctly + /// We can't / shouldn't compare Redis clients, thus the method panic if we try to compare them + impl PartialEq for RedisConnectionKind { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Url(l_url), Self::Url(r_url)) => l_url == r_url, + _ => { + panic!("RedisConnectionKind PartialEq is only implemented for Url") + } + } + } + } + #[test] fn test_create_builder() { - let redis_url = "redis://127.0.0.1"; + let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string()); let period = Duration::from_secs(10); let builder = Builder { - redis_url: redis_url.to_owned(), + redis_connection: redis_connection.clone(), limit: 100, period, get_key_fn: Some(Arc::new(|_| None)), @@ -123,7 +149,7 @@ mod tests { session_key: Cow::Owned("rate-api".to_string()), }; - assert_eq!(builder.redis_url, redis_url); + assert_eq!(builder.redis_connection, redis_connection); assert_eq!(builder.limit, 100); assert_eq!(builder.period, period); #[cfg(feature = "session")] @@ -133,10 +159,10 @@ mod tests { #[test] fn test_create_limiter() { - let redis_url = "redis://127.0.0.1"; + let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string()); let period = Duration::from_secs(20); let mut builder = Builder { - redis_url: redis_url.to_owned(), + redis_connection, limit: 100, period: Duration::from_secs(10), get_key_fn: Some(Arc::new(|_| None)), @@ -154,10 +180,10 @@ mod tests { #[test] #[should_panic = "Redis URL did not parse"] fn test_create_limiter_error() { - let redis_url = "127.0.0.1"; + let redis_connection = RedisConnectionKind::Url("127.0.0.1".to_string()); let period = Duration::from_secs(20); let mut builder = Builder { - redis_url: redis_url.to_owned(), + redis_connection, limit: 100, period: Duration::from_secs(10), get_key_fn: Some(Arc::new(|_| None)), diff --git a/actix-limitation/src/lib.rs b/actix-limitation/src/lib.rs index 8cb02e6b5..58a3fa5b3 100644 --- a/actix-limitation/src/lib.rs +++ b/actix-limitation/src/lib.rs @@ -53,6 +53,7 @@ use std::{borrow::Cow, fmt, sync::Arc, time::Duration}; use actix_web::dev::ServiceRequest; +use builder::RedisConnectionKind; use redis::Client; mod builder; @@ -106,11 +107,27 @@ impl Limiter { /// Construct rate limiter builder with defaults. /// /// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection - /// parameters for how to set the Redis URL. + /// redis_url parameter is used for connecting to Redis. #[must_use] pub fn builder(redis_url: impl Into) -> Builder { Builder { - redis_url: redis_url.into(), + redis_connection: RedisConnectionKind::Url(redis_url.into()), + limit: DEFAULT_REQUEST_LIMIT, + period: Duration::from_secs(DEFAULT_PERIOD_SECS), + get_key_fn: None, + cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME), + #[cfg(feature = "session")] + session_key: Cow::Borrowed(DEFAULT_SESSION_KEY), + } + } + + /// Construct rate limiter builder with defaults. + /// + /// parameters for how to set the Redis URL. + #[must_use] + pub fn builder_with_redis_client(redis_client: Client) -> Builder { + Builder { + redis_connection: RedisConnectionKind::Client(redis_client), limit: DEFAULT_REQUEST_LIMIT, period: Duration::from_secs(DEFAULT_PERIOD_SECS), get_key_fn: None, @@ -177,40 +194,78 @@ mod tests { assert_eq!(limiter.period, Duration::from_secs(3600)); } + use std::{collections::HashMap, time::Duration}; + + use actix_web::{ + dev::ServiceRequest, get, http::StatusCode, test as actix_test, web, Responder, + }; + #[actix_web::test] async fn test_create_scoped_limiter() { - todo!("finish tests") - // use actix_session::SessionExt as _; - // use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder}; - // use std::time::Duration; + #[get("/")] + async fn index() -> impl Responder { + "index" + } - // #[get("/{id}/{name}")] - // async fn index(info: web::Path<(u32, String)>) -> impl Responder { - // format!("Hello {}! id:{}", info.1, info.0) - // } + let mut limiters = HashMap::new(); + let redis_client = + Client::open("redis://127.0.0.1/2").expect("unable to create redis client"); + limiters.insert( + "default", + Limiter::builder_with_redis_client(redis_client.clone()) + .key_by(|_req: &ServiceRequest| Some("something_default".to_string())) + .limit(5000) + .period(Duration::from_secs(60)) + .build() + .unwrap(), + ); + limiters.insert( + "scoped", + Limiter::builder_with_redis_client(redis_client) + .key_by(|_req: &ServiceRequest| Some("something_scoped".to_string())) + .limit(1) + .period(Duration::from_secs(60)) + .build() + .unwrap(), + ); + let limiters = web::Data::new(limiters); - // let limiter = web::Data::new( - // Limiter::builder("redis://127.0.0.1") - // .key_by(|req: &ServiceRequest| { - // req.get_session() - // .get("session-id") - // .unwrap_or_else(|_| req.cookie("rate-api-id").map(|c| c.to_string())) - // }) - // .limit(5000) - // .period(Duration::from_secs(3600)) // 60 minutes - // .build() - // .unwrap(), - // ); + let app = actix_web::test::init_service( + actix_web::App::new() + .wrap(RateLimiter { + scope: Some("default"), + }) + .app_data(limiters.clone()) + .service( + web::scope("/scoped") + .wrap(RateLimiter { + scope: Some("scoped"), + }) + .service(index), + ) + .service(index), + ) + .await; - // HttpServer::new(move || { - // App::new() - // .wrap(RateLimiter::default()) - // .app_data(limiter.clone()) - // .service(index) - // }) - // .bind(("127.0.0.1", 8080)) - // .expect("test") - // .run() - // .await; + for _ in 0..3 { + let req = actix_test::TestRequest::get().uri("/").to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK, "{:#?}", resp); + } + for request_count in 0..3 { + let req = actix_test::TestRequest::get().uri("/scoped/").to_request(); + let resp = actix_test::call_service(&app, req).await; + + assert_eq!( + resp.status(), + if request_count > 0 { + StatusCode::TOO_MANY_REQUESTS + } else { + StatusCode::OK + }, + "{:#?}", + resp + ); + } } } diff --git a/actix-limitation/src/middleware.rs b/actix-limitation/src/middleware.rs index b9669ec4a..c2fdc71c3 100644 --- a/actix-limitation/src/middleware.rs +++ b/actix-limitation/src/middleware.rs @@ -73,11 +73,11 @@ where .unwrap_or_else(|| panic!("Unable to find defined limiter with scope: {}", scope)) .clone() } else { - let a = req + let limiter = req .app_data::>() .expect("web::Data should be set in app data for RateLimiter middleware"); // Deref to get the Limiter - (***a).clone() + (***limiter).clone() }; let key = (limiter.get_key_fn)(&req);