mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-27 17:22:57 +01:00
180 lines
5.6 KiB
Rust
180 lines
5.6 KiB
Rust
//! Rate limiter using a fixed window counter for arbitrary keys, backed by Redis for Actix Web.
|
|
//!
|
|
//! ```toml
|
|
//! [dependencies]
|
|
//! actix-web = "4"
|
|
#![doc = concat!("actix-limitation = \"", env!("CARGO_PKG_VERSION_MAJOR"), ".", env!("CARGO_PKG_VERSION_MINOR"),"\"")]
|
|
//! ```
|
|
//!
|
|
//! ```no_run
|
|
//! use std::time::Duration;
|
|
//! use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder};
|
|
//! use actix_session::SessionExt as _;
|
|
//! use actix_limitation::{Limiter, RateLimiter};
|
|
//!
|
|
//! #[get("/{id}/{name}")]
|
|
//! async fn index(info: web::Path<(u32, String)>) -> impl Responder {
|
|
//! format!("Hello {}! id:{}", info.1, info.0)
|
|
//! }
|
|
//!
|
|
//! #[actix_web::main]
|
|
//! async fn main() -> std::io::Result<()> {
|
|
//! let limiter = web::Data::new(
|
|
//! Limiter::builder("redis://127.0.0.1")
|
|
//! .key_by(|req: &ServiceRequest| {
|
|
//! req.get_session()
|
|
//! .get(&"session-id")
|
|
//! .unwrap_or_else(|_| req.cookie(&"rate-api-id").map(|c| c.to_string()))
|
|
//! })
|
|
//! .limit(5000)
|
|
//! .period(Duration::from_secs(3600)) // 60 minutes
|
|
//! .build()
|
|
//! .unwrap(),
|
|
//! );
|
|
//!
|
|
//! HttpServer::new(move || {
|
|
//! App::new()
|
|
//! .wrap(RateLimiter::default())
|
|
//! .app_data(limiter.clone())
|
|
//! .service(index)
|
|
//! })
|
|
//! .bind(("127.0.0.1", 8080))?
|
|
//! .run()
|
|
//! .await
|
|
//! }
|
|
//! ```
|
|
|
|
#![forbid(unsafe_code)]
|
|
#![warn(missing_docs, missing_debug_implementations)]
|
|
#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
|
|
#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
|
|
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
|
|
|
|
use std::{borrow::Cow, fmt, sync::Arc, time::Duration};
|
|
|
|
use actix_web::dev::ServiceRequest;
|
|
use redis::Client;
|
|
|
|
mod builder;
|
|
mod errors;
|
|
mod middleware;
|
|
mod status;
|
|
|
|
pub use self::{builder::Builder, errors::Error, middleware::RateLimiter, status::Status};
|
|
|
|
/// Default request limit: how many requests a key may make per window when the builder's
/// limit is left untouched.
pub const DEFAULT_REQUEST_LIMIT: usize = 5_000;

/// Default period (in seconds): a one-hour window.
pub const DEFAULT_PERIOD_SECS: u64 = 60 * 60;

/// Default cookie name, used to initialise the builder's `cookie_name`.
pub const DEFAULT_COOKIE_NAME: &str = "sid";

/// Default session key, used to initialise the builder's `session_key` when the `session`
/// feature is enabled.
#[cfg(feature = "session")]
pub const DEFAULT_SESSION_KEY: &str = "rate-api-id";
|
|
|
|
/// Helper trait to impl Debug on GetKeyFn type
|
|
trait GetKeyFnT: Fn(&ServiceRequest) -> Option<String> {}
|
|
|
|
impl<T> GetKeyFnT for T where T: Fn(&ServiceRequest) -> Option<String> {}
|
|
|
|
/// Get key function type with auto traits
|
|
type GetKeyFn = dyn GetKeyFnT + Send + Sync;
|
|
|
|
/// Get key resolver function type
|
|
impl fmt::Debug for GetKeyFn {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
write!(f, "GetKeyFn")
|
|
}
|
|
}
|
|
|
|
/// Wrapped Get key function Trait
|
|
type GetArcBoxKeyFn = Arc<GetKeyFn>;
|
|
|
|
/// Rate limiter.
|
|
#[derive(Debug, Clone)]
|
|
pub struct Limiter {
|
|
client: Client,
|
|
limit: usize,
|
|
period: Duration,
|
|
get_key_fn: GetArcBoxKeyFn,
|
|
}
|
|
|
|
impl Limiter {
|
|
/// Construct rate limiter builder with defaults.
|
|
///
|
|
/// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection
|
|
/// parameters for how to set the Redis URL.
|
|
#[must_use]
|
|
pub fn builder(redis_url: impl Into<String>) -> Builder {
|
|
Builder {
|
|
redis_url: redis_url.into(),
|
|
limit: DEFAULT_REQUEST_LIMIT,
|
|
period: Duration::from_secs(DEFAULT_PERIOD_SECS),
|
|
get_key_fn: None,
|
|
cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
|
|
#[cfg(feature = "session")]
|
|
session_key: Cow::Borrowed(DEFAULT_SESSION_KEY),
|
|
}
|
|
}
|
|
|
|
/// Consumes one rate limit unit, returning the status.
|
|
pub async fn count(&self, key: impl Into<String>) -> Result<Status, Error> {
|
|
let (count, reset) = self.track(key).await?;
|
|
let status = Status::new(count, self.limit, reset);
|
|
|
|
if count > self.limit {
|
|
Err(Error::LimitExceeded(status))
|
|
} else {
|
|
Ok(status)
|
|
}
|
|
}
|
|
|
|
/// Tracks the given key in a period and returns the count and TTL for the key in seconds.
|
|
async fn track(&self, key: impl Into<String>) -> Result<(usize, usize), Error> {
|
|
let key = key.into();
|
|
let expires = self.period.as_secs();
|
|
|
|
let mut connection = self.client.get_multiplexed_tokio_connection().await?;
|
|
|
|
// The seed of this approach is outlined Atul R in a blog post about rate limiting using
|
|
// NodeJS and Redis. For more details, see https://blog.atulr.com/rate-limiter
|
|
let mut pipe = redis::pipe();
|
|
pipe.atomic()
|
|
.cmd("SET") // Set key and value
|
|
.arg(&key)
|
|
.arg(0)
|
|
.arg("EX") // Set the specified expire time, in seconds.
|
|
.arg(expires)
|
|
.arg("NX") // Only set the key if it does not already exist.
|
|
.ignore() // --- ignore returned value of SET command ---
|
|
.cmd("INCR") // Increment key
|
|
.arg(&key)
|
|
.cmd("TTL") // Return time-to-live of key
|
|
.arg(&key);
|
|
|
|
let (count, ttl) = pipe.query_async(&mut connection).await?;
|
|
let reset = Status::epoch_utc_plus(Duration::from_secs(ttl))?;
|
|
|
|
Ok((count, reset))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_limiter() {
        // Only defaults are used, so the limiter must carry the default limit and a one-hour window.
        let limiter = Limiter::builder("redis://127.0.0.1:6379/1")
            .build()
            .expect("building a limiter with default settings should succeed");

        assert_eq!(limiter.limit, 5000);
        assert_eq!(limiter.period, Duration::from_secs(3600));
    }
}
|