use std::{borrow::Cow, sync::Arc, time::Duration};
#[cfg(feature = "session")]
use actix_session::SessionExt as _;
use actix_web::dev::ServiceRequest;
use redis::Client;
use crate::{errors::Error, GetArcBoxKeyFn, Limiter};
#[derive(Debug)]
pub struct Builder {
pub(crate) redis_url: String,
pub(crate) limit: usize,
pub(crate) period: Duration,
pub(crate) get_key_fn: Option<GetArcBoxKeyFn>,
pub(crate) cookie_name: Cow<'static, str>,
#[cfg(feature = "session")]
pub(crate) session_key: Cow<'static, str>,
}
impl Builder {
pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit;
self
}
pub fn period(&mut self, period: Duration) -> &mut Self {
self.period = period;
self
}
pub fn key_by<F>(&mut self, resolver: F) -> &mut Self
where
F: Fn(&ServiceRequest) -> Option<String> + Send + Sync + 'static,
{
self.get_key_fn = Some(Arc::new(resolver));
self
}
#[deprecated = "Prefer `key_by`."]
pub fn cookie_name(&mut self, cookie_name: impl Into<Cow<'static, str>>) -> &mut Self {
if self.get_key_fn.is_some() {
panic!("This method should not be used in combination of get_key as they overwrite each other")
}
self.cookie_name = cookie_name.into();
self
}
#[deprecated = "Prefer `key_by`."]
#[cfg(feature = "session")]
pub fn session_key(&mut self, session_key: impl Into<Cow<'static, str>>) -> &mut Self {
if self.get_key_fn.is_some() {
panic!("This method should not be used in combination of get_key as they overwrite each other")
}
self.session_key = session_key.into();
self
}
pub fn build(&mut self) -> Result<Limiter, Error> {
let get_key = if let Some(resolver) = self.get_key_fn.clone() {
resolver
} else {
let cookie_name = self.cookie_name.clone();
#[cfg(feature = "session")]
let session_key = self.session_key.clone();
let closure: GetArcBoxKeyFn = Arc::new(Box::new(move |req: &ServiceRequest| {
#[cfg(feature = "session")]
let res = req
.get_session()
.get(&session_key)
.unwrap_or_else(|_| req.cookie(&cookie_name).map(|c| c.to_string()));
#[cfg(not(feature = "session"))]
let res = req.cookie(&cookie_name).map(|c| c.to_string());
res
}));
closure
};
Ok(Limiter {
client: Client::open(self.redis_url.as_str())?,
limit: self.limit,
period: self.period,
get_key_fn: get_key,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_builder() {
let redis_url = "redis://127.0.0.1";
let period = Duration::from_secs(10);
let builder = Builder {
redis_url: redis_url.to_owned(),
limit: 100,
period,
get_key_fn: Some(Arc::new(|_| None)),
cookie_name: Cow::Owned("session".to_string()),
#[cfg(feature = "session")]
session_key: Cow::Owned("rate-api".to_string()),
};
assert_eq!(builder.redis_url, redis_url);
assert_eq!(builder.limit, 100);
assert_eq!(builder.period, period);
#[cfg(feature = "session")]
assert_eq!(builder.session_key, "rate-api");
assert_eq!(builder.cookie_name, "session");
}
#[test]
fn test_create_limiter() {
let redis_url = "redis://127.0.0.1";
let period = Duration::from_secs(20);
let mut builder = Builder {
redis_url: redis_url.to_owned(),
limit: 100,
period: Duration::from_secs(10),
get_key_fn: Some(Arc::new(|_| None)),
cookie_name: Cow::Borrowed("sid"),
#[cfg(feature = "session")]
session_key: Cow::Borrowed("key"),
};
let limiter = builder.limit(200).period(period).build().unwrap();
assert_eq!(limiter.limit, 200);
assert_eq!(limiter.period, period);
}
#[test]
#[should_panic = "Redis URL did not parse"]
fn test_create_limiter_error() {
let redis_url = "127.0.0.1";
let period = Duration::from_secs(20);
let mut builder = Builder {
redis_url: redis_url.to_owned(),
limit: 100,
period: Duration::from_secs(10),
get_key_fn: Some(Arc::new(|_| None)),
cookie_name: Cow::Borrowed("sid"),
#[cfg(feature = "session")]
session_key: Cow::Borrowed("key"),
};
builder.limit(200).period(period).build().unwrap();
}
}