use std::{future::Future, pin::Pin, rc::Rc}; use actix_session::UserSession; use actix_utils::future::{ok, Ready}; use actix_web::{ body::EitherBody, cookie::Cookie, dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, http::header::COOKIE, web, Error, HttpResponse, }; use crate::Limiter; pub struct RateLimiter; impl Transform for RateLimiter where S: Service, Error = Error> + 'static, S::Future: 'static, B: 'static, { type Response = ServiceResponse>; type Error = Error; type InitError = (); type Transform = RateLimiterMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(RateLimiterMiddleware { service: Rc::new(service), }) } } pub struct RateLimiterMiddleware { service: Rc, } impl Service for RateLimiterMiddleware where S: Service, Error = Error> + 'static, S::Future: 'static, B: 'static, { type Response = ServiceResponse>; type Error = Error; type Future = Pin>>>; forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { // A mis-configuration of the Actix App will result in a **runtime** failure, so the expect // method description is important context for the developer. let limiter = req .app_data::>() .expect("web::Data should be set in app data for RateLimiter middleware") .clone(); let forbidden = HttpResponse::Forbidden().finish().map_into_right_body(); let (key, fallback) = key(&req, limiter.clone()); let service = Rc::clone(&self.service); let key = match key { Some(key) => key, None => match fallback { Some(key) => key, None => { return Box::pin(async move { service .call(req) .await .map(ServiceResponse::map_into_left_body) }); } }, }; let service = Rc::clone(&self.service); Box::pin(async move { let status = limiter.count(key.to_string()).await; if status.is_err() { warn!("403. Rate limit exceed error for {}", key); Ok(req.into_response(forbidden)) } else { service .call(req) .await .map(ServiceResponse::map_into_left_body) } }) } } fn key(req: &ServiceRequest, limiter: web::Data) -> (Option, Option) { let session = req.get_session(); let result: Option = session.get(&limiter.session_key).unwrap_or(None); let cookies = req.headers().get_all(COOKIE); let cookie = cookies .filter_map(|i| i.to_str().ok()) .find(|i| i.contains(&limiter.cookie_name)); let fallback = match cookie { Some(value) => Cookie::parse(value).ok().map(|i| i.to_string()), None => None, }; (result, fallback) }