1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-02-22 18:33:18 +01:00

feat: builder with redis client & tests for scoped

This commit is contained in:
GreeFine 2022-10-04 16:38:47 +02:00
parent 368ada2b75
commit f6802bc960
3 changed files with 124 additions and 43 deletions

View File

@ -7,10 +7,19 @@ use redis::Client;
use crate::{errors::Error, GetArcBoxKeyFn, Limiter}; use crate::{errors::Error, GetArcBoxKeyFn, Limiter};
/// [RedisConnectionKind] is used to define which connection parameter for the Redis server will be passed
/// It can be an Url or a Client
/// This is done so we can use the same client for multiple Limiters
#[derive(Debug, Clone)]
pub enum RedisConnectionKind {
Url(String),
Client(Client),
}
/// Rate limiter builder. /// Rate limiter builder.
#[derive(Debug)] #[derive(Debug)]
pub struct Builder { pub struct Builder {
pub(crate) redis_url: String, pub(crate) redis_connection: RedisConnectionKind,
pub(crate) limit: usize, pub(crate) limit: usize,
pub(crate) period: Duration, pub(crate) period: Duration,
pub(crate) get_key_fn: Option<GetArcBoxKeyFn>, pub(crate) get_key_fn: Option<GetArcBoxKeyFn>,
@ -96,8 +105,12 @@ impl Builder {
closure closure
}; };
let client = match &self.redis_connection {
RedisConnectionKind::Url(url) => Client::open(url.as_str())?,
RedisConnectionKind::Client(client) => client.clone(),
};
Ok(Limiter { Ok(Limiter {
client: Client::open(self.redis_url.as_str())?, client,
limit: self.limit, limit: self.limit,
period: self.period, period: self.period,
get_key_fn: get_key, get_key_fn: get_key,
@ -109,12 +122,25 @@ impl Builder {
mod tests { mod tests {
use super::*; use super::*;
/// Implementing partial Eq to check if builder assigned the redis connection url correctly
/// We can't / shouldn't compare Redis clients, thus the method panic if we try to compare them
impl PartialEq for RedisConnectionKind {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Url(l_url), Self::Url(r_url)) => l_url == r_url,
_ => {
panic!("RedisConnectionKind PartialEq is only implemented for Url")
}
}
}
}
#[test] #[test]
fn test_create_builder() { fn test_create_builder() {
let redis_url = "redis://127.0.0.1"; let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string());
let period = Duration::from_secs(10); let period = Duration::from_secs(10);
let builder = Builder { let builder = Builder {
redis_url: redis_url.to_owned(), redis_connection: redis_connection.clone(),
limit: 100, limit: 100,
period, period,
get_key_fn: Some(Arc::new(|_| None)), get_key_fn: Some(Arc::new(|_| None)),
@ -123,7 +149,7 @@ mod tests {
session_key: Cow::Owned("rate-api".to_string()), session_key: Cow::Owned("rate-api".to_string()),
}; };
assert_eq!(builder.redis_url, redis_url); assert_eq!(builder.redis_connection, redis_connection);
assert_eq!(builder.limit, 100); assert_eq!(builder.limit, 100);
assert_eq!(builder.period, period); assert_eq!(builder.period, period);
#[cfg(feature = "session")] #[cfg(feature = "session")]
@ -133,10 +159,10 @@ mod tests {
#[test] #[test]
fn test_create_limiter() { fn test_create_limiter() {
let redis_url = "redis://127.0.0.1"; let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string());
let period = Duration::from_secs(20); let period = Duration::from_secs(20);
let mut builder = Builder { let mut builder = Builder {
redis_url: redis_url.to_owned(), redis_connection,
limit: 100, limit: 100,
period: Duration::from_secs(10), period: Duration::from_secs(10),
get_key_fn: Some(Arc::new(|_| None)), get_key_fn: Some(Arc::new(|_| None)),
@ -154,10 +180,10 @@ mod tests {
#[test] #[test]
#[should_panic = "Redis URL did not parse"] #[should_panic = "Redis URL did not parse"]
fn test_create_limiter_error() { fn test_create_limiter_error() {
let redis_url = "127.0.0.1"; let redis_connection = RedisConnectionKind::Url("127.0.0.1".to_string());
let period = Duration::from_secs(20); let period = Duration::from_secs(20);
let mut builder = Builder { let mut builder = Builder {
redis_url: redis_url.to_owned(), redis_connection,
limit: 100, limit: 100,
period: Duration::from_secs(10), period: Duration::from_secs(10),
get_key_fn: Some(Arc::new(|_| None)), get_key_fn: Some(Arc::new(|_| None)),

View File

@ -53,6 +53,7 @@
use std::{borrow::Cow, fmt, sync::Arc, time::Duration}; use std::{borrow::Cow, fmt, sync::Arc, time::Duration};
use actix_web::dev::ServiceRequest; use actix_web::dev::ServiceRequest;
use builder::RedisConnectionKind;
use redis::Client; use redis::Client;
mod builder; mod builder;
@ -106,11 +107,27 @@ impl Limiter {
/// Construct rate limiter builder with defaults. /// Construct rate limiter builder with defaults.
/// ///
/// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection /// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection
/// parameters for how to set the Redis URL. /// redis_url parameter is used for connecting to Redis.
#[must_use] #[must_use]
pub fn builder(redis_url: impl Into<String>) -> Builder { pub fn builder(redis_url: impl Into<String>) -> Builder {
Builder { Builder {
redis_url: redis_url.into(), redis_connection: RedisConnectionKind::Url(redis_url.into()),
limit: DEFAULT_REQUEST_LIMIT,
period: Duration::from_secs(DEFAULT_PERIOD_SECS),
get_key_fn: None,
cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
#[cfg(feature = "session")]
session_key: Cow::Borrowed(DEFAULT_SESSION_KEY),
}
}
/// Construct rate limiter builder with defaults.
///
/// parameters for how to set the Redis URL.
#[must_use]
pub fn builder_with_redis_client(redis_client: Client) -> Builder {
Builder {
redis_connection: RedisConnectionKind::Client(redis_client),
limit: DEFAULT_REQUEST_LIMIT, limit: DEFAULT_REQUEST_LIMIT,
period: Duration::from_secs(DEFAULT_PERIOD_SECS), period: Duration::from_secs(DEFAULT_PERIOD_SECS),
get_key_fn: None, get_key_fn: None,
@ -177,40 +194,78 @@ mod tests {
assert_eq!(limiter.period, Duration::from_secs(3600)); assert_eq!(limiter.period, Duration::from_secs(3600));
} }
use std::{collections::HashMap, time::Duration};
use actix_web::{
dev::ServiceRequest, get, http::StatusCode, test as actix_test, web, Responder,
};
#[actix_web::test] #[actix_web::test]
async fn test_create_scoped_limiter() { async fn test_create_scoped_limiter() {
todo!("finish tests") #[get("/")]
// use actix_session::SessionExt as _; async fn index() -> impl Responder {
// use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder}; "index"
// use std::time::Duration; }
// #[get("/{id}/{name}")] let mut limiters = HashMap::new();
// async fn index(info: web::Path<(u32, String)>) -> impl Responder { let redis_client =
// format!("Hello {}! id:{}", info.1, info.0) Client::open("redis://127.0.0.1/2").expect("unable to create redis client");
// } limiters.insert(
"default",
Limiter::builder_with_redis_client(redis_client.clone())
.key_by(|_req: &ServiceRequest| Some("something_default".to_string()))
.limit(5000)
.period(Duration::from_secs(60))
.build()
.unwrap(),
);
limiters.insert(
"scoped",
Limiter::builder_with_redis_client(redis_client)
.key_by(|_req: &ServiceRequest| Some("something_scoped".to_string()))
.limit(1)
.period(Duration::from_secs(60))
.build()
.unwrap(),
);
let limiters = web::Data::new(limiters);
// let limiter = web::Data::new( let app = actix_web::test::init_service(
// Limiter::builder("redis://127.0.0.1") actix_web::App::new()
// .key_by(|req: &ServiceRequest| { .wrap(RateLimiter {
// req.get_session() scope: Some("default"),
// .get("session-id") })
// .unwrap_or_else(|_| req.cookie("rate-api-id").map(|c| c.to_string())) .app_data(limiters.clone())
// }) .service(
// .limit(5000) web::scope("/scoped")
// .period(Duration::from_secs(3600)) // 60 minutes .wrap(RateLimiter {
// .build() scope: Some("scoped"),
// .unwrap(), })
// ); .service(index),
)
.service(index),
)
.await;
// HttpServer::new(move || { for _ in 0..3 {
// App::new() let req = actix_test::TestRequest::get().uri("/").to_request();
// .wrap(RateLimiter::default()) let resp = actix_test::call_service(&app, req).await;
// .app_data(limiter.clone()) assert_eq!(resp.status(), StatusCode::OK, "{:#?}", resp);
// .service(index) }
// }) for request_count in 0..3 {
// .bind(("127.0.0.1", 8080)) let req = actix_test::TestRequest::get().uri("/scoped/").to_request();
// .expect("test") let resp = actix_test::call_service(&app, req).await;
// .run()
// .await; assert_eq!(
resp.status(),
if request_count > 0 {
StatusCode::TOO_MANY_REQUESTS
} else {
StatusCode::OK
},
"{:#?}",
resp
);
}
} }
} }

View File

@ -73,11 +73,11 @@ where
.unwrap_or_else(|| panic!("Unable to find defined limiter with scope: {}", scope)) .unwrap_or_else(|| panic!("Unable to find defined limiter with scope: {}", scope))
.clone() .clone()
} else { } else {
let a = req let limiter = req
.app_data::<web::Data<Limiter>>() .app_data::<web::Data<Limiter>>()
.expect("web::Data<Limiter> should be set in app data for RateLimiter middleware"); .expect("web::Data<Limiter> should be set in app data for RateLimiter middleware");
// Deref to get the Limiter // Deref to get the Limiter
(***a).clone() (***limiter).clone()
}; };
let key = (limiter.get_key_fn)(&req); let key = (limiter.get_key_fn)(&req);