From 368ada2b758061bfe2e1d91f6246ab6f0accd530 Mon Sep 17 00:00:00 2001 From: GreeFine Date: Mon, 3 Oct 2022 18:26:08 +0200 Subject: [PATCH] feat: optional scopes for middleware --- actix-limitation/src/lib.rs | 43 +++++++++++++++++++++++++++--- actix-limitation/src/middleware.rs | 32 +++++++++++++++++----- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/actix-limitation/src/lib.rs b/actix-limitation/src/lib.rs index 384f51dad..8cb02e6b5 100644 --- a/actix-limitation/src/lib.rs +++ b/actix-limitation/src/lib.rs @@ -7,7 +7,7 @@ //! ``` //! //! ```no_run -//! use std::{sync::Arc, time::Duration}; +//! use std::{time::Duration}; //! use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder}; //! use actix_session::SessionExt as _; //! use actix_limitation::{Limiter, RateLimiter}; @@ -23,8 +23,8 @@ //! Limiter::builder("redis://127.0.0.1") //! .key_by(|req: &ServiceRequest| { //! req.get_session() -//! .get(&"session-id") -//! .unwrap_or_else(|_| req.cookie(&"rate-api-id").map(|c| c.to_string())) +//! .get("session-id") +//! .unwrap_or_else(|_| req.cookie("rate-api-id").map(|c| c.to_string())) //! }) //! .limit(5000) //! .period(Duration::from_secs(3600)) // 60 minutes @@ -176,4 +176,41 @@ mod tests { assert_eq!(limiter.limit, 5000); assert_eq!(limiter.period, Duration::from_secs(3600)); } + + #[actix_web::test] + async fn test_create_scoped_limiter() { + todo!("finish tests") + // use actix_session::SessionExt as _; + // use actix_web::{dev::ServiceRequest, get, web, App, HttpServer, Responder}; + // use std::time::Duration; + + // #[get("/{id}/{name}")] + // async fn index(info: web::Path<(u32, String)>) -> impl Responder { + // format!("Hello {}! id:{}", info.1, info.0) + // } + + // let limiter = web::Data::new( + // Limiter::builder("redis://127.0.0.1") + // .key_by(|req: &ServiceRequest| { + // req.get_session() + // .get("session-id") + // .unwrap_or_else(|_| req.cookie("rate-api-id").map(|c| c.to_string())) + // }) + // .limit(5000) + // .period(Duration::from_secs(3600)) // 60 minutes + // .build() + // .unwrap(), + // ); + + // HttpServer::new(move || { + // App::new() + // .wrap(RateLimiter::default()) + // .app_data(limiter.clone()) + // .service(index) + // }) + // .bind(("127.0.0.1", 8080)) + // .expect("test") + // .run() + // .await; + } } diff --git a/actix-limitation/src/middleware.rs b/actix-limitation/src/middleware.rs index 290f075c1..b9669ec4a 100644 --- a/actix-limitation/src/middleware.rs +++ b/actix-limitation/src/middleware.rs @@ -1,4 +1,4 @@ -use std::{future::Future, pin::Pin, rc::Rc}; +use std::{collections::HashMap, future::Future, pin::Pin, rc::Rc}; use actix_utils::future::{ok, Ready}; use actix_web::{ @@ -11,9 +11,16 @@ use actix_web::{ use crate::{Error as LimitationError, Limiter}; /// Rate limit middleware. +/// +/// Use the `scope` variable to define multiple limiter #[derive(Debug, Default)] #[non_exhaustive] -pub struct RateLimiter; +pub struct RateLimiter { + /// Used to define multiple limiter, with different configurations + /// + /// WARNING: When used (not None) the middleware will expect a `HashMap` in the actix-web `app_data` + pub scope: Option<&'static str>, +} impl Transform for RateLimiter where @@ -30,6 +37,7 @@ where fn new_transform(&self, service: S) -> Self::Future { ok(RateLimiterMiddleware { service: Rc::new(service), + scope: self.scope, }) } } @@ -38,6 +46,7 @@ where #[derive(Debug)] pub struct RateLimiterMiddleware { service: Rc, + scope: Option<&'static str>, } impl Service for RateLimiterMiddleware @@ -55,10 +64,21 @@ where fn call(&self, req: ServiceRequest) -> Self::Future { // A mis-configuration of the Actix App will result in a **runtime** failure, so the expect // method description is important context for the developer. - let limiter = req - .app_data::>() - .expect("web::Data should be set in app data for RateLimiter middleware") - .clone(); + let limiter = if let Some(scope) = self.scope { + let limiters = req.app_data::>>().expect( + "web::Data> should be set in app data for RateLimiter middleware", + ); + limiters + .get(scope) + .unwrap_or_else(|| panic!("Unable to find defined limiter with scope: {}", scope)) + .clone() + } else { + let a = req + .app_data::>() + .expect("web::Data should be set in app data for RateLimiter middleware"); + // Deref to get the Limiter + (***a).clone() + }; let key = (limiter.get_key_fn)(&req); let service = Rc::clone(&self.service);