diff --git a/.github/workflows/ci-master.yml b/.github/workflows/ci-master.yml index 1f8e72285..be1ac398f 100644 --- a/.github/workflows/ci-master.yml +++ b/.github/workflows/ci-master.yml @@ -122,7 +122,11 @@ jobs: timeout-minutes: 40 with: command: ci-test - args: --exclude=actix-redis --exclude=actix-session -- --nocapture + args: >- + --exclude=actix-redis + --exclude=actix-session + --exclude=actix-limitation + -- --nocapture - name: Clear the cargo caches run: | diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eef45b3f4..858d44b2b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -131,7 +131,10 @@ jobs: timeout-minutes: 40 with: command: ci-test - args: --exclude=actix-redis --exclude=actix-session -- --nocapture + args: >- + --exclude=actix-redis + --exclude=actix-session + --exclude=actix-limitation - name: Clear the cargo caches run: | diff --git a/actix-limitation/CHANGES.md b/actix-limitation/CHANGES.md index 8601719ae..6224ecbe0 100644 --- a/actix-limitation/CHANGES.md +++ b/actix-limitation/CHANGES.md @@ -3,8 +3,12 @@ ## Unreleased - 2022-xx-xx - Update Actix Web dependency to v4 ecosystem. [#229] - Update Tokio dependencies to v1 ecosystem. [#229] +- Rename `Limiter::{build => builder}()`. [#232] +- Rename `Builder::{finish => build}()`. [#232] +- Exceeding the rate limit now returns a 429 Too Many Requests response. [#232] [#229]: https://github.com/actix/actix-extras/pull/229 +[#232]: https://github.com/actix/actix-extras/pull/232 ## 0.1.4 - 2022-03-18 diff --git a/actix-limitation/Cargo.toml b/actix-limitation/Cargo.toml index 7a2469321..fd314a1c5 100644 --- a/actix-limitation/Cargo.toml +++ b/actix-limitation/Cargo.toml @@ -15,10 +15,12 @@ actix-utils = "3" actix-web = { version = "4", default-features = false } chrono = "0.4" +derive_more = "0.99.5" log = "0.4" -redis = { version = "0.21", default-features = false, features = ["aio", "tokio-comp"] } +redis = { version = "0.21", default-features = false, features = ["tokio-comp"] } time = "0.3" [dev-dependencies] actix-web = "4" uuid = { version = "0.8", features = ["v4"] } +static_assertions = "1" diff --git a/actix-limitation/LICENSE-APACHE b/actix-limitation/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-limitation/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-limitation/LICENSE-MIT b/actix-limitation/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-limitation/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-limitation/README.md b/actix-limitation/README.md index 3be4b0b2e..15168f91e 100644 --- a/actix-limitation/README.md +++ b/actix-limitation/README.md @@ -12,8 +12,8 @@ ```toml [dependencies] -actix-limitation = "0.1.4" actix-web = "4" +actix-limitation = "0.1.4" ``` ```rust diff --git a/actix-limitation/src/builder.rs b/actix-limitation/src/builder.rs new file mode 100644 index 000000000..99072fe2d --- /dev/null +++ b/actix-limitation/src/builder.rs @@ -0,0 +1,121 @@ +use std::{borrow::Cow, time::Duration}; + +use redis::Client; + +use crate::{errors::Error, Limiter}; + +/// Rate limit builder. +#[derive(Debug)] +pub struct Builder<'a> { + pub(crate) redis_url: &'a str, + pub(crate) limit: usize, + pub(crate) period: Duration, + pub(crate) cookie_name: Cow<'static, str>, + pub(crate) session_key: Cow<'static, str>, +} + +impl Builder<'_> { + /// Set upper limit. + pub fn limit(&mut self, limit: usize) -> &mut Self { + self.limit = limit; + self + } + + /// Set limit window/period. + pub fn period(&mut self, period: Duration) -> &mut Self { + self.period = period; + self + } + + /// Set name of cookie to be sent. + pub fn cookie_name(&mut self, cookie_name: impl Into>) -> &mut Self { + self.cookie_name = cookie_name.into(); + self + } + + /// Set session key to be used in backend. + pub fn session_key(&mut self, session_key: impl Into>) -> &mut Self { + self.session_key = session_key.into(); + self + } + + /// Finalizes and returns a `Limiter`. + /// + /// Note that this method will connect to the Redis server to test its connection which is a + /// **synchronous** operation. + pub fn build(&self) -> Result { + Ok(Limiter { + client: Client::open(self.redis_url)?, + limit: self.limit, + period: self.period, + cookie_name: self.cookie_name.clone(), + session_key: self.session_key.clone(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_builder() { + let redis_url = "redis://127.0.0.1"; + let period = Duration::from_secs(10); + let builder = Builder { + redis_url, + limit: 100, + period, + cookie_name: Cow::Owned("session".to_string()), + session_key: Cow::Owned("rate-api".to_string()), + }; + + assert_eq!(builder.redis_url, redis_url); + assert_eq!(builder.limit, 100); + assert_eq!(builder.period, period); + assert_eq!(builder.session_key, "rate-api"); + assert_eq!(builder.cookie_name, "session"); + } + + #[test] + fn test_create_limiter() { + let redis_url = "redis://127.0.0.1"; + let period = Duration::from_secs(20); + let mut builder = Builder { + redis_url, + limit: 100, + period: Duration::from_secs(10), + session_key: Cow::Borrowed("key"), + cookie_name: Cow::Borrowed("sid"), + }; + + let limiter = builder + .limit(200) + .period(period) + .cookie_name("session".to_string()) + .session_key("rate-api".to_string()) + .build() + .unwrap(); + + assert_eq!(limiter.limit, 200); + assert_eq!(limiter.period, period); + assert_eq!(limiter.session_key, "rate-api"); + assert_eq!(limiter.cookie_name, "session"); + } + + #[test] + #[should_panic = "Redis URL did not parse"] + fn test_create_limiter_error() { + let redis_url = "127.0.0.1"; + let period = Duration::from_secs(20); + let mut builder = Builder { + redis_url, + limit: 100, + period: Duration::from_secs(10), + session_key: Cow::Borrowed("key"), + cookie_name: Cow::Borrowed("sid"), + }; + + builder.limit(200).period(period).build().unwrap(); + } +} diff --git a/actix-limitation/src/core/builder/mod.rs b/actix-limitation/src/core/builder/mod.rs deleted file mode 100644 index 98303c4b5..000000000 --- a/actix-limitation/src/core/builder/mod.rs +++ /dev/null @@ -1,51 +0,0 @@ -use redis::Client; -use std::time::Duration; - -use crate::{core::errors::Error, Limiter}; - -pub struct Builder<'builder> { - pub(crate) redis_url: &'builder str, - pub(crate) limit: usize, - pub(crate) period: Duration, - pub(crate) cookie_name: String, - pub(crate) session_key: String, -} - -impl Builder<'_> { - pub fn limit(&mut self, limit: usize) -> &mut Self { - self.limit = limit; - self - } - - pub fn period(&mut self, period: Duration) -> &mut Self { - self.period = period; - self - } - - pub fn cookie_name(&mut self, cookie_name: String) -> &mut Self { - self.cookie_name = cookie_name; - self - } - - pub fn session_key(&mut self, session_key: String) -> &mut Self { - self.session_key = session_key; - self - } - - /// Finializes and returns a `Limiter`. - /// - /// Note that this method will connect to the Redis server to test its connection which is a - /// **synchronous** operation. - pub fn finish(&self) -> Result { - Ok(Limiter { - client: Client::open(self.redis_url)?, - limit: self.limit, - period: self.period, - cookie_name: self.cookie_name.to_string(), - session_key: self.session_key.to_string(), - }) - } -} - -#[cfg(test)] -mod test; diff --git a/actix-limitation/src/core/builder/test.rs b/actix-limitation/src/core/builder/test.rs deleted file mode 100644 index e26b13ab3..000000000 --- a/actix-limitation/src/core/builder/test.rs +++ /dev/null @@ -1,62 +0,0 @@ -use super::*; - -#[test] -fn test_create_builder() { - let redis_url = "redis://127.0.0.1"; - let period = Duration::from_secs(10); - let builder = Builder { - redis_url, - limit: 100, - period, - cookie_name: "session".to_string(), - session_key: "rate-api".to_string(), - }; - - assert_eq!(builder.redis_url, redis_url); - assert_eq!(builder.limit, 100); - assert_eq!(builder.period, period); - assert_eq!(builder.session_key, "rate-api"); - assert_eq!(builder.cookie_name, "session"); -} - -#[test] -fn test_create_limiter() { - let redis_url = "redis://127.0.0.1"; - let period = Duration::from_secs(20); - let mut builder = Builder { - redis_url, - limit: 100, - period: Duration::from_secs(10), - session_key: "key".to_string(), - cookie_name: "sid".to_string(), - }; - - let limiter = builder - .limit(200) - .period(period) - .cookie_name("session".to_string()) - .session_key("rate-api".to_string()) - .finish() - .unwrap(); - - assert_eq!(limiter.limit, 200); - assert_eq!(limiter.period, period); - assert_eq!(limiter.session_key, "rate-api"); - assert_eq!(limiter.cookie_name, "session"); -} - -#[test] -#[should_panic = "Redis URL did not parse"] -fn test_create_limiter_error() { - let redis_url = "127.0.0.1"; - let period = Duration::from_secs(20); - let mut builder = Builder { - redis_url, - limit: 100, - period: Duration::from_secs(10), - session_key: "key".to_string(), - cookie_name: "sid".to_string(), - }; - - builder.limit(200).period(period).finish().unwrap(); -} diff --git a/actix-limitation/src/core/errors/mod.rs b/actix-limitation/src/core/errors/mod.rs deleted file mode 100644 index 095145d69..000000000 --- a/actix-limitation/src/core/errors/mod.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::{error, fmt}; - -use crate::core::status::Status; - -#[derive(Debug)] -pub enum Error { - /// The Redis client failed to connect or run a query. - Client(redis::RedisError), - - /// The limit is exceeded for a key. - LimitExceeded(Status), - - /// A time conversion failed. - Time(time::error::ComponentRange), - - Other(String), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Error::Client(ref err) => write!(f, "Client error ({})", err), - Error::LimitExceeded(ref status) => write!(f, "Rate limit exceeded ({:?})", status), - Error::Time(ref err) => write!(f, "Time conversion error ({})", err), - Error::Other(err) => write!(f, "{}", err), - } - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match self { - Error::Client(ref err) => err.source(), - Error::LimitExceeded(_) => None, - Error::Time(ref err) => err.source(), - Error::Other(_) => None, - } - } -} - -impl From for Error { - fn from(err: redis::RedisError) -> Self { - Error::Client(err) - } -} - -impl From for Error { - fn from(err: time::error::ComponentRange) -> Self { - Error::Time(err) - } -} diff --git a/actix-limitation/src/core/mod.rs b/actix-limitation/src/core/mod.rs deleted file mode 100644 index 423eff697..000000000 --- a/actix-limitation/src/core/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod builder; -pub mod errors; -pub mod status; diff --git a/actix-limitation/src/core/status/mod.rs b/actix-limitation/src/core/status/mod.rs deleted file mode 100644 index 25110c552..000000000 --- a/actix-limitation/src/core/status/mod.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::Error as LimitationError; -use chrono::SubsecRound; -use std::{convert::TryInto, ops::Add, time::Duration}; - -/// A report for a given key containing the limit status. -/// -/// The status contains the following information: -/// -/// - [`limit`]: the maximum number of requests allowed in the current period -/// - [`remaining`]: how many requests are left in the current period -/// - [`reset_epoch_utc`]: a UNIX timestamp in UTC approximately when the next period will begin -#[derive(Clone, Debug)] -pub struct Status { - pub(crate) limit: usize, - pub(crate) remaining: usize, - pub(crate) reset_epoch_utc: usize, -} - -impl Status { - pub fn limit(&self) -> usize { - self.limit - } - - pub fn remaining(&self) -> usize { - self.remaining - } - - pub fn reset_epoch_utc(&self) -> usize { - self.reset_epoch_utc - } - - pub(crate) fn build_status(count: usize, limit: usize, reset_epoch_utc: usize) -> Self { - let remaining = if count >= limit { 0 } else { limit - count }; - - Status { - limit, - remaining, - reset_epoch_utc, - } - } - - pub(crate) fn epoch_utc_plus(duration: Duration) -> Result { - match chrono::Duration::from_std(duration) { - Ok(value) => Ok(chrono::Utc::now() - .add(value) - .round_subsecs(0) - .timestamp() - .try_into() - .unwrap_or(0)), - Err(_) => Err(LimitationError::Other( - "Source duration value is out of range for the target type".to_string(), - )), - } - } -} - -#[cfg(test)] -mod test; diff --git a/actix-limitation/src/core/status/test.rs b/actix-limitation/src/core/status/test.rs deleted file mode 100644 index beb125bc9..000000000 --- a/actix-limitation/src/core/status/test.rs +++ /dev/null @@ -1,54 +0,0 @@ -use super::*; - -#[test] -fn test_create_status() { - let status = Status { - limit: 100, - remaining: 0, - reset_epoch_utc: 1000, - }; - - assert_eq!(status.limit(), 100); - assert_eq!(status.remaining(), 0); - assert_eq!(status.reset_epoch_utc(), 1000); -} - -#[test] -fn test_build_status() { - let count = 200; - let limit = 100; - let status = Status::build_status(count, limit, 2000); - assert_eq!(status.limit(), limit); - assert_eq!(status.remaining(), 0); - assert_eq!(status.reset_epoch_utc(), 2000); -} - -#[test] -fn test_build_status_limit() { - let limit = 100; - let status = Status::build_status(0, limit, 2000); - assert_eq!(status.limit(), limit); - assert_eq!(status.remaining(), limit); - assert_eq!(status.reset_epoch_utc(), 2000); -} - -#[test] -fn test_epoch_utc_plus_zero() { - let duration = Duration::from_secs(0); - let seconds = Status::epoch_utc_plus(duration).unwrap(); - assert!(seconds as u64 >= duration.as_secs()); -} - -#[test] -fn test_epoch_utc_plus() { - let duration = Duration::from_secs(10); - let seconds = Status::epoch_utc_plus(duration).unwrap(); - assert!(seconds as u64 >= duration.as_secs() + 10); -} - -#[test] -#[should_panic = "Source duration value is out of range for the target type"] -fn test_epoch_utc_plus_overflow() { - let duration = Duration::from_secs(10000000000000000000); - Status::epoch_utc_plus(duration).unwrap(); -} diff --git a/actix-limitation/src/errors.rs b/actix-limitation/src/errors.rs new file mode 100644 index 000000000..9f535d64d --- /dev/null +++ b/actix-limitation/src/errors.rs @@ -0,0 +1,42 @@ +use derive_more::{Display, Error, From}; + +use crate::status::Status; + +/// Failure modes of the rate limiter. +#[derive(Debug, Display, Error, From)] +pub enum Error { + /// Redis client failed to connect or run a query. + #[display(fmt = "Redis client failed to connect or run a query")] + Client(redis::RedisError), + + /// Limit is exceeded for a key. + #[display(fmt = "Limit is exceeded for a key")] + #[from(ignore)] + LimitExceeded(#[error(not(source))] Status), + + /// Time conversion failed. + #[display(fmt = "Time conversion failed")] + Time(time::error::ComponentRange), + + /// Generic error. + #[display(fmt = "Generic error")] + #[from(ignore)] + Other(#[error(not(source))] String), +} + +#[cfg(test)] +mod tests { + use super::*; + + static_assertions::assert_impl_all! { + Error: + From, + From, + } + + static_assertions::assert_not_impl_any! { + Error: + From, + From, + } +} diff --git a/actix-limitation/src/lib.rs b/actix-limitation/src/lib.rs index 21d21cbf3..a3cb9363a 100644 --- a/actix-limitation/src/lib.rs +++ b/actix-limitation/src/lib.rs @@ -1,84 +1,107 @@ -/*! -Rate limiter using a fixed window counter for arbitrary keys, backed by Redis for Actix Web +//! Rate limiter using a fixed window counter for arbitrary keys, backed by Redis for Actix Web. +//! +//! ```toml +//! [dependencies] +//! actix-web = "4" +//! actix-limitation = "0.1.4" +//! ``` +//! +//! ```no_run +//! use std::time::Duration; +//! use actix_web::{get, web, App, HttpServer, Responder}; +//! use actix_limitation::{Limiter, RateLimiter}; +//! +//! #[get("/{id}/{name}")] +//! async fn index(info: web::Path<(u32, String)>) -> impl Responder { +//! format!("Hello {}! id:{}", info.1, info.0) +//! } +//! +//! #[actix_web::main] +//! async fn main() -> std::io::Result<()> { +//! let limiter = web::Data::new( +//! Limiter::builder("redis://127.0.0.1") +//! .cookie_name("session-id".to_owned()) +//! .session_key("rate-api-id".to_owned()) +//! .limit(5000) +//! .period(Duration::from_secs(3600)) // 60 minutes +//! .build() +//! .unwrap(), +//! ); +//! +//! HttpServer::new(move || { +//! App::new() +//! .wrap(RateLimiter) +//! .app_data(limiter.clone()) +//! .service(index) +//! }) +//! .bind("127.0.0.1:8080")? +//! .run() +//! .await +//! } +//! ``` -```toml -[dependencies] -actix-limitation = "0.1.4" -actix-web = "4" -``` +#![forbid(unsafe_code)] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible, missing_docs, missing_debug_implementations)] +#![doc(html_logo_url = "https://actix.rs/img/logo.png")] +#![doc(html_favicon_url = "https://actix.rs/favicon.ico")] -```no_run -use std::time::Duration; -use actix_web::{get, web, App, HttpServer, Responder}; -use actix_limitation::{Limiter, RateLimiter}; - -#[get("/{id}/{name}")] -async fn index(info: web::Path<(u32, String)>) -> impl Responder { - format!("Hello {}! id:{}", info.1, info.0) -} - -#[actix_web::main] -async fn main() -> std::io::Result<()> { - let limiter = web::Data::new( - Limiter::build("redis://127.0.0.1") - .cookie_name("session-id".to_owned()) - .session_key("rate-api-id".to_owned()) - .limit(5000) - .period(Duration::from_secs(3600)) // 60 minutes - .finish() - .expect("Can't build actix-limiter"), - ); - - HttpServer::new(move || { - App::new() - .wrap(RateLimiter) - .app_data(limiter.clone()) - .service(index) - }) - .bind("127.0.0.1:8080")? - .run() - .await -} -``` -*/ - -#[macro_use] -extern crate log; +use std::{borrow::Cow, time::Duration}; use redis::Client; -use std::time::Duration; -pub use crate::core::{builder::Builder, errors::Error, status::Status}; -pub use crate::middleware::RateLimiter; +mod builder; +mod errors; +mod middleware; +mod status; +pub use self::builder::Builder; +pub use self::errors::Error; +pub use self::middleware::RateLimiter; +pub use self::status::Status; + +/// Default request limit. pub const DEFAULT_REQUEST_LIMIT: usize = 5000; + +/// Default period (in seconds). pub const DEFAULT_PERIOD_SECS: u64 = 3600; + +/// Default cookie name. pub const DEFAULT_COOKIE_NAME: &str = "sid"; + +/// Default session key. pub const DEFAULT_SESSION_KEY: &str = "rate-api-id"; +/// Rate limiter. #[derive(Clone, Debug)] pub struct Limiter { client: Client, limit: usize, period: Duration, - cookie_name: String, - session_key: String, + cookie_name: Cow<'static, str>, + session_key: Cow<'static, str>, } impl Limiter { - pub fn build(redis_url: &str) -> Builder { + /// Construct rate limiter builder with defaults. + /// + /// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection + /// parameters for how to set the Redis URL. + #[must_use] + pub fn builder(redis_url: &str) -> Builder<'_> { Builder { redis_url, limit: DEFAULT_REQUEST_LIMIT, period: Duration::from_secs(DEFAULT_PERIOD_SECS), - cookie_name: DEFAULT_COOKIE_NAME.to_string(), - session_key: DEFAULT_SESSION_KEY.to_string(), + cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME), + session_key: Cow::Borrowed(DEFAULT_SESSION_KEY), } } - pub async fn count>(&self, key: K) -> Result { + /// Consumes one rate limit unit, returning the status. + pub async fn count(&self, key: impl Into) -> Result { let (count, reset) = self.track(key).await?; - let status = Status::build_status(count, self.limit, reset); + let status = Status::new(count, self.limit, reset); if count > self.limit { Err(Error::LimitExceeded(status)) @@ -88,36 +111,49 @@ impl Limiter { } /// Tracks the given key in a period and returns the count and TTL for the key in seconds. - async fn track>(&self, key: K) -> Result<(usize, usize), Error> { + async fn track(&self, key: impl Into) -> Result<(usize, usize), Error> { let key = key.into(); - let exipres = self.period.as_secs(); + let expires = self.period.as_secs(); let mut connection = self.client.get_tokio_connection().await?; - // The seed of this approach is outlined Atul R in a blog post about rate limiting - // using NodeJS and Redis. For more details, see - // https://blog.atulr.com/rate-limiter/ + // The seed of this approach is outlined Atul R in a blog post about rate limiting using + // NodeJS and Redis. For more details, see https://blog.atulr.com/rate-limiter let mut pipe = redis::pipe(); pipe.atomic() - .cmd("SET") + .cmd("SET") // Set key and value .arg(&key) .arg(0) - .arg("EX") - .arg(exipres) - .arg("NX") - .ignore() - .cmd("INCR") + .arg("EX") // Set the specified expire time, in seconds. + .arg(expires) + .arg("NX") // Only set the key if it does not already exist. + .ignore() // --- ignore returned value of SET command --- + .cmd("INCR") // Increment key .arg(&key) - .cmd("TTL") + .cmd("TTL") // Return time-to-live of key .arg(&key); - let (count, ttl): (usize, u64) = pipe.query_async(&mut connection).await?; + let (count, ttl) = pipe.query_async(&mut connection).await?; let reset = Status::epoch_utc_plus(Duration::from_secs(ttl))?; + Ok((count, reset)) } } -mod core; -mod middleware; #[cfg(test)] -mod test; +mod tests { + use super::*; + + #[test] + fn test_create_limiter() { + let builder = Limiter::builder("redis://127.0.0.1:6379/1"); + let limiter = builder.build(); + assert!(limiter.is_ok()); + + let limiter = limiter.unwrap(); + assert_eq!(limiter.limit, 5000); + assert_eq!(limiter.period, Duration::from_secs(3600)); + assert_eq!(limiter.cookie_name, DEFAULT_COOKIE_NAME); + assert_eq!(limiter.session_key, DEFAULT_SESSION_KEY); + } +} diff --git a/actix-limitation/src/middleware/mod.rs b/actix-limitation/src/middleware.rs similarity index 88% rename from actix-limitation/src/middleware/mod.rs rename to actix-limitation/src/middleware.rs index 52ada6460..cf46cddca 100644 --- a/actix-limitation/src/middleware/mod.rs +++ b/actix-limitation/src/middleware.rs @@ -6,12 +6,14 @@ use actix_web::{ body::EitherBody, cookie::Cookie, dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, - http::header::COOKIE, + http::{header::COOKIE, StatusCode}, web, Error, HttpResponse, }; use crate::Limiter; +/// Rate limit middleware. +#[derive(Debug)] pub struct RateLimiter; impl Transform for RateLimiter @@ -22,8 +24,8 @@ where { type Response = ServiceResponse>; type Error = Error; - type InitError = (); type Transform = RateLimiterMiddleware; + type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { @@ -33,6 +35,8 @@ where } } +/// Rate limit middleware service. +#[derive(Debug)] pub struct RateLimiterMiddleware { service: Rc, } @@ -57,10 +61,10 @@ where .expect("web::Data should be set in app data for RateLimiter middleware") .clone(); - let forbidden = HttpResponse::Forbidden().finish().map_into_right_body(); let (key, fallback) = key(&req, limiter.clone()); let service = Rc::clone(&self.service); + let key = match key { Some(key) => key, None => match fallback { @@ -76,12 +80,15 @@ where }, }; - let service = Rc::clone(&self.service); Box::pin(async move { let status = limiter.count(key.to_string()).await; + if status.is_err() { - warn!("403. Rate limit exceed error for {}", key); - Ok(req.into_response(forbidden)) + log::warn!("Rate limit exceed error for {}", key); + + Ok(req.into_response( + HttpResponse::new(StatusCode::TOO_MANY_REQUESTS).map_into_right_body(), + )) } else { service .call(req) @@ -98,7 +105,7 @@ fn key(req: &ServiceRequest, limiter: web::Data) -> (Option, Op let cookies = req.headers().get_all(COOKIE); let cookie = cookies .filter_map(|i| i.to_str().ok()) - .find(|i| i.contains(&limiter.cookie_name)); + .find(|i| i.contains(limiter.cookie_name.as_ref())); let fallback = match cookie { Some(value) => Cookie::parse(value).ok().map(|i| i.to_string()), diff --git a/actix-limitation/src/status.rs b/actix-limitation/src/status.rs new file mode 100644 index 000000000..15fb2d0c9 --- /dev/null +++ b/actix-limitation/src/status.rs @@ -0,0 +1,118 @@ +use std::{convert::TryInto, ops::Add, time::Duration}; + +use chrono::SubsecRound as _; + +use crate::Error as LimitationError; + +/// A report for a given key containing the limit status. +#[derive(Debug, Clone)] +pub struct Status { + pub(crate) limit: usize, + pub(crate) remaining: usize, + pub(crate) reset_epoch_utc: usize, +} + +impl Status { + /// Constructs status limit status from parts. + #[must_use] + pub(crate) fn new(count: usize, limit: usize, reset_epoch_utc: usize) -> Self { + let remaining = if count >= limit { 0 } else { limit - count }; + + Status { + limit, + remaining, + reset_epoch_utc, + } + } + + /// Returns the maximum number of requests allowed in the current period. + #[must_use] + pub fn limit(&self) -> usize { + self.limit + } + + /// Returns how many requests are left in the current period. + #[must_use] + pub fn remaining(&self) -> usize { + self.remaining + } + + /// Returns a UNIX timestamp in UTC approximately when the next period will begin. + #[must_use] + pub fn reset_epoch_utc(&self) -> usize { + self.reset_epoch_utc + } + + pub(crate) fn epoch_utc_plus(duration: Duration) -> Result { + match chrono::Duration::from_std(duration) { + Ok(value) => Ok(chrono::Utc::now() + .add(value) + .round_subsecs(0) + .timestamp() + .try_into() + .unwrap_or(0)), + + Err(_) => Err(LimitationError::Other( + "Source duration value is out of range for the target type".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_status() { + let status = Status { + limit: 100, + remaining: 0, + reset_epoch_utc: 1000, + }; + + assert_eq!(status.limit(), 100); + assert_eq!(status.remaining(), 0); + assert_eq!(status.reset_epoch_utc(), 1000); + } + + #[test] + fn test_build_status() { + let count = 200; + let limit = 100; + let status = Status::new(count, limit, 2000); + assert_eq!(status.limit(), limit); + assert_eq!(status.remaining(), 0); + assert_eq!(status.reset_epoch_utc(), 2000); + } + + #[test] + fn test_build_status_limit() { + let limit = 100; + let status = Status::new(0, limit, 2000); + assert_eq!(status.limit(), limit); + assert_eq!(status.remaining(), limit); + assert_eq!(status.reset_epoch_utc(), 2000); + } + + #[test] + fn test_epoch_utc_plus_zero() { + let duration = Duration::from_secs(0); + let seconds = Status::epoch_utc_plus(duration).unwrap(); + assert!(seconds as u64 >= duration.as_secs()); + } + + #[test] + fn test_epoch_utc_plus() { + let duration = Duration::from_secs(10); + let seconds = Status::epoch_utc_plus(duration).unwrap(); + assert!(seconds as u64 >= duration.as_secs() + 10); + } + + #[test] + #[should_panic = "Source duration value is out of range for the target type"] + fn test_epoch_utc_plus_overflow() { + let duration = Duration::from_secs(10000000000000000000); + Status::epoch_utc_plus(duration).unwrap(); + } +} diff --git a/actix-limitation/src/test/mod.rs b/actix-limitation/src/test/mod.rs deleted file mode 100644 index 49a320a25..000000000 --- a/actix-limitation/src/test/mod.rs +++ /dev/null @@ -1,65 +0,0 @@ -use uuid::Uuid; - -use super::*; - -#[test] -fn test_create_limiter() { - let builder = Limiter::build("redis://127.0.0.1:6379/1"); - let limiter = builder.finish(); - assert!(limiter.is_ok()); - - let limiter = limiter.unwrap(); - assert_eq!(limiter.limit, 5000); - assert_eq!(limiter.period, Duration::from_secs(3600)); - assert_eq!(limiter.cookie_name, DEFAULT_COOKIE_NAME); - assert_eq!(limiter.session_key, DEFAULT_SESSION_KEY); -} - -#[test] -#[should_panic = "Redis URL did not parse"] -fn test_create_limiter_error() { - Limiter::build("127.0.0.1").finish().unwrap(); -} - -// TODO: figure out whats wrong with this test -#[ignore] -#[actix_web::test] -async fn test_limiter_count() -> Result<(), Error> { - let builder = Limiter::build("redis://127.0.0.1:6379/2"); - let limiter = builder.finish().unwrap(); - let id = Uuid::new_v4(); - - for i in 0..5000 { - let status = limiter.count(id.to_string()).await?; - assert_eq!(5000 - status.remaining(), i + 1); - } - - Ok(()) -} - -// TODO: figure out whats wrong with this test -#[ignore] -#[actix_web::test] -async fn test_limiter_count_error() -> Result<(), Error> { - let builder = Limiter::build("redis://127.0.0.1:6379/3"); - let limiter = builder.finish().unwrap(); - let id = Uuid::new_v4(); - - for i in 0..5000 { - let status = limiter.count(id.to_string()).await?; - assert_eq!(5000 - status.remaining(), i + 1); - } - - match limiter.count(id.to_string()).await.unwrap_err() { - Error::LimitExceeded(status) => assert_eq!(status.remaining(), 0), - _ => panic!("error should be LimitExceeded variant"), - }; - - let id = Uuid::new_v4(); - for i in 0..5000 { - let status = limiter.count(id.to_string()).await?; - assert_eq!(5000 - status.remaining(), i + 1); - } - - Ok(()) -} diff --git a/actix-limitation/tests/tests.rs b/actix-limitation/tests/tests.rs new file mode 100644 index 000000000..110a8060a --- /dev/null +++ b/actix-limitation/tests/tests.rs @@ -0,0 +1,53 @@ +use actix_limitation::{Error, Limiter}; +use uuid::Uuid; + +#[test] +#[should_panic = "Redis URL did not parse"] +fn test_create_limiter_error() { + Limiter::builder("127.0.0.1").build().unwrap(); +} + +#[actix_web::test] +async fn test_limiter_count() -> Result<(), Error> { + let limiter = Limiter::builder("redis://127.0.0.1:6379/2") + .limit(20) + .build() + .unwrap(); + + let id = Uuid::new_v4(); + + for i in 0..20 { + let status = limiter.count(id.to_string()).await?; + println!("status: {:?}", status); + assert_eq!(20 - status.remaining(), i + 1); + } + + Ok(()) +} + +#[actix_web::test] +async fn test_limiter_count_error() -> Result<(), Error> { + let limiter = Limiter::builder("redis://127.0.0.1:6379/3") + .limit(25) + .build() + .unwrap(); + + let id = Uuid::new_v4(); + for i in 0..25 { + let status = limiter.count(id.to_string()).await?; + assert_eq!(25 - status.remaining(), i + 1); + } + + match limiter.count(id.to_string()).await.unwrap_err() { + Error::LimitExceeded(status) => assert_eq!(status.remaining(), 0), + _ => panic!("error should be LimitExceeded variant"), + }; + + let id = Uuid::new_v4(); + for i in 0..25 { + let status = limiter.count(id.to_string()).await?; + assert_eq!(25 - status.remaining(), i + 1); + } + + Ok(()) +}