1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-02-22 10:23:18 +01:00

feat: builder with redis client & tests for scoped

This commit is contained in:
GreeFine 2022-10-04 16:38:47 +02:00
parent 368ada2b75
commit f6802bc960
3 changed files with 124 additions and 43 deletions

View File

@ -7,10 +7,19 @@ use redis::Client;
use crate::{errors::Error, GetArcBoxKeyFn, Limiter};
/// [RedisConnectionKind] is used to define which connection parameter for the Redis server will be passed
/// It can be an Url or a Client
/// This is done so we can use the same client for multiple Limiters
#[derive(Debug, Clone)]
pub enum RedisConnectionKind {
Url(String),
Client(Client),
}
/// Rate limiter builder.
#[derive(Debug)]
pub struct Builder {
pub(crate) redis_url: String,
pub(crate) redis_connection: RedisConnectionKind,
pub(crate) limit: usize,
pub(crate) period: Duration,
pub(crate) get_key_fn: Option<GetArcBoxKeyFn>,
@ -96,8 +105,12 @@ impl Builder {
closure
};
let client = match &self.redis_connection {
RedisConnectionKind::Url(url) => Client::open(url.as_str())?,
RedisConnectionKind::Client(client) => client.clone(),
};
Ok(Limiter {
client: Client::open(self.redis_url.as_str())?,
client,
limit: self.limit,
period: self.period,
get_key_fn: get_key,
@ -109,12 +122,25 @@ impl Builder {
mod tests {
use super::*;
/// Implementing partial Eq to check if builder assigned the redis connection url correctly
/// We can't / shouldn't compare Redis clients, thus the method panic if we try to compare them
impl PartialEq for RedisConnectionKind {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Url(l_url), Self::Url(r_url)) => l_url == r_url,
_ => {
panic!("RedisConnectionKind PartialEq is only implemented for Url")
}
}
}
}
#[test]
fn test_create_builder() {
let redis_url = "redis://127.0.0.1";
let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string());
let period = Duration::from_secs(10);
let builder = Builder {
redis_url: redis_url.to_owned(),
redis_connection: redis_connection.clone(),
limit: 100,
period,
get_key_fn: Some(Arc::new(|_| None)),
@ -123,7 +149,7 @@ mod tests {
session_key: Cow::Owned("rate-api".to_string()),
};
assert_eq!(builder.redis_url, redis_url);
assert_eq!(builder.redis_connection, redis_connection);
assert_eq!(builder.limit, 100);
assert_eq!(builder.period, period);
#[cfg(feature = "session")]
@ -133,10 +159,10 @@ mod tests {
#[test]
fn test_create_limiter() {
let redis_url = "redis://127.0.0.1";
let redis_connection = RedisConnectionKind::Url("redis://127.0.0.1".to_string());
let period = Duration::from_secs(20);
let mut builder = Builder {
redis_url: redis_url.to_owned(),
redis_connection,
limit: 100,
period: Duration::from_secs(10),
get_key_fn: Some(Arc::new(|_| None)),
@ -154,10 +180,10 @@ mod tests {
#[test]
#[should_panic = "Redis URL did not parse"]
fn test_create_limiter_error() {
let redis_url = "127.0.0.1";
let redis_connection = RedisConnectionKind::Url("127.0.0.1".to_string());
let period = Duration::from_secs(20);
let mut builder = Builder {
redis_url: redis_url.to_owned(),
redis_connection,
limit: 100,
period: Duration::from_secs(10),
get_key_fn: Some(Arc::new(|_| None)),

View File

@ -53,6 +53,7 @@
use std::{borrow::Cow, fmt, sync::Arc, time::Duration};
use actix_web::dev::ServiceRequest;
use builder::RedisConnectionKind;
use redis::Client;
mod builder;
@ -106,11 +107,27 @@ impl Limiter {
/// Construct rate limiter builder with defaults.
///
/// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection
/// parameters for how to set the Redis URL.
/// redis_url parameter is used for connecting to Redis.
#[must_use]
pub fn builder(redis_url: impl Into<String>) -> Builder {
Builder {
redis_url: redis_url.into(),
redis_connection: RedisConnectionKind::Url(redis_url.into()),
limit: DEFAULT_REQUEST_LIMIT,
period: Duration::from_secs(DEFAULT_PERIOD_SECS),
get_key_fn: None,
cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
#[cfg(feature = "session")]
session_key: Cow::Borrowed(DEFAULT_SESSION_KEY),
}
}
/// Construct rate limiter builder with defaults.
///
/// parameters for how to set the Redis URL.
#[must_use]
pub fn builder_with_redis_client(redis_client: Client) -> Builder {
Builder {
redis_connection: RedisConnectionKind::Client(redis_client),
limit: DEFAULT_REQUEST_LIMIT,
period: Duration::from_secs(DEFAULT_PERIOD_SECS),
get_key_fn: None,
@ -177,40 +194,78 @@ mod tests {
assert_eq!(limiter.period, Duration::from_secs(3600));
}
use std::{collections::HashMap, time::Duration};
use actix_web::{
dev::ServiceRequest, get, http::StatusCode, test as actix_test, web, Responder,
};
#[actix_web::test]
async fn test_create_scoped_limiter() {
todo!("finish tests")
// use actix_session::SessionExt as _;
// use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder};
// use std::time::Duration;
#[get("/")]
async fn index() -> impl Responder {
"index"
}
// #[get("/{id}/{name}")]
// async fn index(info: web::Path<(u32, String)>) -> impl Responder {
// format!("Hello {}! id:{}", info.1, info.0)
// }
let mut limiters = HashMap::new();
let redis_client =
Client::open("redis://127.0.0.1/2").expect("unable to create redis client");
limiters.insert(
"default",
Limiter::builder_with_redis_client(redis_client.clone())
.key_by(|_req: &ServiceRequest| Some("something_default".to_string()))
.limit(5000)
.period(Duration::from_secs(60))
.build()
.unwrap(),
);
limiters.insert(
"scoped",
Limiter::builder_with_redis_client(redis_client)
.key_by(|_req: &ServiceRequest| Some("something_scoped".to_string()))
.limit(1)
.period(Duration::from_secs(60))
.build()
.unwrap(),
);
let limiters = web::Data::new(limiters);
// let limiter = web::Data::new(
// Limiter::builder("redis://127.0.0.1")
// .key_by(|req: &ServiceRequest| {
// req.get_session()
// .get("session-id")
// .unwrap_or_else(|_| req.cookie("rate-api-id").map(|c| c.to_string()))
// })
// .limit(5000)
// .period(Duration::from_secs(3600)) // 60 minutes
// .build()
// .unwrap(),
// );
let app = actix_web::test::init_service(
actix_web::App::new()
.wrap(RateLimiter {
scope: Some("default"),
})
.app_data(limiters.clone())
.service(
web::scope("/scoped")
.wrap(RateLimiter {
scope: Some("scoped"),
})
.service(index),
)
.service(index),
)
.await;
// HttpServer::new(move || {
// App::new()
// .wrap(RateLimiter::default())
// .app_data(limiter.clone())
// .service(index)
// })
// .bind(("127.0.0.1", 8080))
// .expect("test")
// .run()
// .await;
for _ in 0..3 {
let req = actix_test::TestRequest::get().uri("/").to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK, "{:#?}", resp);
}
for request_count in 0..3 {
let req = actix_test::TestRequest::get().uri("/scoped/").to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(
resp.status(),
if request_count > 0 {
StatusCode::TOO_MANY_REQUESTS
} else {
StatusCode::OK
},
"{:#?}",
resp
);
}
}
}

View File

@ -73,11 +73,11 @@ where
.unwrap_or_else(|| panic!("Unable to find defined limiter with scope: {}", scope))
.clone()
} else {
let a = req
let limiter = req
.app_data::<web::Data<Limiter>>()
.expect("web::Data<Limiter> should be set in app data for RateLimiter middleware");
// Deref to get the Limiter
(***a).clone()
(***limiter).clone()
};
let key = (limiter.get_key_fn)(&req);