1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
//! Rate limiter using a fixed window counter for arbitrary keys, backed by Redis for Actix Web.
//!
//! ```toml
//! [dependencies]
//! actix-web = "4"
#![doc = concat!("actix-limitation = \"", env!("CARGO_PKG_VERSION_MAJOR"), ".", env!("CARGO_PKG_VERSION_MINOR"),"\"")]
//! ```
//!
//! ```no_run
//! use std::{sync::Arc, time::Duration};
//! use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder};
//! use actix_session::SessionExt as _;
//! use actix_limitation::{Limiter, RateLimiter};
//!
//! #[get("/{id}/{name}")]
//! async fn index(info: web::Path<(u32, String)>) -> impl Responder {
//! format!("Hello {}! id:{}", info.1, info.0)
//! }
//!
//! #[actix_web::main]
//! async fn main() -> std::io::Result<()> {
//! let limiter = web::Data::new(
//! Limiter::builder("redis://127.0.0.1")
//! .key_by(|req: &ServiceRequest| {
//! req.get_session()
//! .get(&"session-id")
//! .unwrap_or_else(|_| req.cookie(&"rate-api-id").map(|c| c.to_string()))
//! })
//! .limit(5000)
//! .period(Duration::from_secs(3600)) // 60 minutes
//! .build()
//! .unwrap(),
//! );
//!
//! HttpServer::new(move || {
//! App::new()
//! .wrap(RateLimiter::default())
//! .app_data(limiter.clone())
//! .service(index)
//! })
//! .bind(("127.0.0.1", 8080))?
//! .run()
//! .await
//! }
//! ```
#![forbid(unsafe_code)]
#![deny(rust_2018_idioms, nonstandard_style)]
#![warn(future_incompatible, missing_docs, missing_debug_implementations)]
#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
use std::{borrow::Cow, fmt, sync::Arc, time::Duration};
use actix_web::dev::ServiceRequest;
use redis::Client;
mod builder;
mod errors;
mod middleware;
mod status;
pub use self::{builder::Builder, errors::Error, middleware::RateLimiter, status::Status};
/// Default maximum number of requests allowed per key in one period.
pub const DEFAULT_REQUEST_LIMIT: usize = 5_000;

/// Default length of the rate-limiting period, in seconds (one hour).
pub const DEFAULT_PERIOD_SECS: u64 = 60 * 60;

/// Default value of the builder's cookie name setting.
pub const DEFAULT_COOKIE_NAME: &str = "sid";

/// Default value of the builder's session key setting.
///
/// Only available with the `session` crate feature.
#[cfg(feature = "session")]
pub const DEFAULT_SESSION_KEY: &str = "rate-api-id";
/// Local alias for the key-extraction callback signature.
///
/// A crate-local trait is required so that `Debug` can be implemented for the trait object
/// below; the orphan rule forbids implementing `Debug` directly for a `dyn Fn(..)` type.
trait GetKeyFnT: Fn(&ServiceRequest) -> Option<String> {}

/// Every closure or function with the matching signature is a `GetKeyFnT`.
impl<F: Fn(&ServiceRequest) -> Option<String>> GetKeyFnT for F {}

/// Thread-safe trait object for a key-extraction callback.
type GetKeyFn = dyn GetKeyFnT + Send + Sync;

/// Shared, reference-counted handle to a key-extraction callback.
type GetArcBoxKeyFn = Arc<GetKeyFn>;

/// Opaque `Debug` output for key-extraction callbacks.
///
/// Closures have nothing printable, so only a fixed label is written. This is what allows
/// types holding a `GetArcBoxKeyFn` to `#[derive(Debug)]`.
impl fmt::Debug for GetKeyFn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("GetKeyFn")
    }
}
/// Rate limiter.
///
/// Keeps a per-key request counter in Redis over fixed-length windows. Construct one with
/// [`Limiter::builder`].
#[derive(Clone, Debug)]
pub struct Limiter {
    /// Client used to open a Redis connection for each tracked request.
    client: Client,

    /// Maximum number of requests allowed per key within one period.
    limit: usize,

    /// Length of one fixed rate-limiting window.
    period: Duration,

    /// Callback mapping an incoming request to an optional key string.
    get_key_fn: GetArcBoxKeyFn,
}
impl Limiter {
    /// Construct rate limiter builder with defaults.
    ///
    /// The builder starts from [`DEFAULT_REQUEST_LIMIT`], [`DEFAULT_PERIOD_SECS`] and
    /// [`DEFAULT_COOKIE_NAME`] (plus `DEFAULT_SESSION_KEY` with the `session` feature), and
    /// has no key-extraction function set.
    ///
    /// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection
    /// parameters for how to set the Redis URL.
    #[must_use]
    pub fn builder(redis_url: impl Into<String>) -> Builder {
        Builder {
            redis_url: redis_url.into(),
            limit: DEFAULT_REQUEST_LIMIT,
            period: Duration::from_secs(DEFAULT_PERIOD_SECS),
            get_key_fn: None,
            cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
            #[cfg(feature = "session")]
            session_key: Cow::Borrowed(DEFAULT_SESSION_KEY),
        }
    }

    /// Consumes one rate limit unit for `key`, returning the resulting status.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LimitExceeded`] (carrying the status) once the request count for `key`
    /// in the current window exceeds the configured limit. Also propagates any error raised
    /// while connecting to Redis, running the commands, or computing the reset time.
    pub async fn count(&self, key: impl Into<String>) -> Result<Status, Error> {
        let (count, reset) = self.track(key).await?;
        let status = Status::new(count, self.limit, reset);

        if count > self.limit {
            Err(Error::LimitExceeded(status))
        } else {
            Ok(status)
        }
    }

    /// Redis key expiry for one rate-limiting window, in whole seconds.
    ///
    /// Redis `EX` only accepts whole seconds, so any fractional part of the period is dropped.
    /// A period shorter than one second would become `EX 0`, which Redis rejects with
    /// "invalid expire time" and which would make every request fail. The value is therefore
    /// clamped to at least one second.
    fn expiry_secs(&self) -> u64 {
        self.period.as_secs().max(1)
    }

    /// Counts one request for `key` in the current window.
    ///
    /// Returns the updated request count and the window's reset time. The reset time is
    /// computed by `Status::epoch_utc_plus` from the key's remaining TTL.
    async fn track(&self, key: impl Into<String>) -> Result<(usize, usize), Error> {
        let key = key.into();
        let expires = self.expiry_secs();

        let mut connection = self.client.get_tokio_connection().await?;

        // The seed of this approach is outlined by Atul R in a blog post about rate limiting
        // using NodeJS and Redis. For more details, see https://blog.atulr.com/rate-limiter
        let mut pipe = redis::pipe();
        pipe.atomic()
            .cmd("SET") // Create the counter at 0 ...
            .arg(&key)
            .arg(0)
            .arg("EX") // ... expiring after one window, in seconds ...
            .arg(expires)
            .arg("NX") // ... but only if the key does not already exist.
            .ignore() // --- ignore returned value of SET command ---
            .cmd("INCR") // Count this request; INCR leaves an existing TTL untouched.
            .arg(&key)
            .cmd("TTL") // Seconds left until the window resets.
            .arg(&key);

        // TTL is signed. Redis reports -1 for a key without an expiry, so decode it as i64
        // rather than letting a negative value wrap into an enormous unsigned duration.
        let (count, ttl): (usize, i64) = pipe.query_async(&mut connection).await?;

        let ttl_secs = if ttl >= 0 {
            // Lossless: `ttl` was just checked to be non-negative.
            ttl as u64
        } else {
            // The key exists but has no expiry (presumably written by something other than
            // this limiter), so its counter would never reset. Re-arm the expiry so the
            // window closes, and report a full window until then.
            let _: () = redis::cmd("EXPIRE")
                .arg(&key)
                .arg(expires)
                .query_async(&mut connection)
                .await?;
            expires
        };

        let reset = Status::epoch_utc_plus(Duration::from_secs(ttl_secs))?;
        Ok((count, reset))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder configured with only a Redis URL yields a limiter carrying the default
    /// request limit and period.
    #[test]
    fn test_create_limiter() {
        let limiter = Limiter::builder("redis://127.0.0.1:6379/1")
            .build()
            .expect("building a limiter with default settings should succeed");

        assert_eq!(limiter.limit, DEFAULT_REQUEST_LIMIT);
        assert_eq!(limiter.period, Duration::from_secs(DEFAULT_PERIOD_SECS));
    }
}