2022-10-03 18:26:08 +02:00
|
|
|
use std::{collections::HashMap, future::Future, pin::Pin, rc::Rc};
|
2022-03-18 17:00:33 +00:00
|
|
|
|
|
|
|
use actix_utils::future::{ok, Ready};
|
|
|
|
use actix_web::{
|
|
|
|
body::EitherBody,
|
|
|
|
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
|
2022-07-31 03:03:43 +01:00
|
|
|
http::StatusCode,
|
2022-03-18 17:00:33 +00:00
|
|
|
web, Error, HttpResponse,
|
|
|
|
};
|
|
|
|
|
2022-08-28 21:49:14 +02:00
|
|
|
use crate::{Error as LimitationError, Limiter};
|
2022-03-18 17:00:33 +00:00
|
|
|
|
2022-03-20 00:40:34 +00:00
|
|
|
/// Rate limit middleware.
///
/// With `scope` left as `None`, the middleware uses the single `web::Data<Limiter>` found in the
/// actix-web app data. Setting `scope` lets several differently-configured limiters coexist:
/// each middleware instance picks its own limiter by name.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it with a struct
/// literal. Start from `RateLimiter::default()` and assign the public `scope` field instead.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct RateLimiter {
    /// Name of the limiter to use when several limiters are configured.
    ///
    /// - `None` (the default): the middleware expects a `web::Data<Limiter>` in the actix-web
    ///   `app_data`.
    /// - `Some(name)`: the middleware expects a `web::Data<HashMap<&'static str, Limiter>>` in
    ///   the actix-web `app_data`, containing an entry keyed by `name`.
    ///
    /// WARNING: if the expected app data entry (or the `name` key inside the map) is missing,
    /// the middleware panics when it handles a request.
    pub scope: Option<&'static str>,
}
|
2022-03-18 17:00:33 +00:00
|
|
|
|
|
|
|
impl<S, B> Transform<S, ServiceRequest> for RateLimiter
|
|
|
|
where
|
|
|
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
|
|
|
|
S::Future: 'static,
|
|
|
|
B: 'static,
|
|
|
|
{
|
|
|
|
type Response = ServiceResponse<EitherBody<B>>;
|
|
|
|
type Error = Error;
|
|
|
|
type Transform = RateLimiterMiddleware<S>;
|
2022-03-20 00:40:34 +00:00
|
|
|
type InitError = ();
|
2022-03-18 17:00:33 +00:00
|
|
|
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
|
|
|
|
|
|
|
fn new_transform(&self, service: S) -> Self::Future {
|
|
|
|
ok(RateLimiterMiddleware {
|
|
|
|
service: Rc::new(service),
|
2022-10-03 18:26:08 +02:00
|
|
|
scope: self.scope,
|
2022-03-18 17:00:33 +00:00
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-03-20 00:40:34 +00:00
|
|
|
/// Rate limit middleware service.
///
/// Built by `RateLimiter`'s `Transform::new_transform`. It wraps the next service in the chain
/// and holds the limiter scope that was configured on the `RateLimiter`.
#[derive(Debug)]
pub struct RateLimiterMiddleware<S> {
    /// Wrapped service. It sits behind an `Rc` so a handle can be moved into each request future.
    service: Rc<S>,
    /// Copy of `RateLimiter::scope`. `None` selects the single app-wide limiter.
    scope: Option<&'static str>,
}
|
|
|
|
|
|
|
|
impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
|
|
|
|
where
|
|
|
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
|
|
|
|
S::Future: 'static,
|
|
|
|
B: 'static,
|
|
|
|
{
|
|
|
|
type Response = ServiceResponse<EitherBody<B>>;
|
|
|
|
type Error = Error;
|
|
|
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
|
|
|
|
|
|
|
|
forward_ready!(service);
|
|
|
|
|
|
|
|
fn call(&self, req: ServiceRequest) -> Self::Future {
|
|
|
|
// A mis-configuration of the Actix App will result in a **runtime** failure, so the expect
|
|
|
|
// method description is important context for the developer.
|
2022-10-03 18:26:08 +02:00
|
|
|
let limiter = if let Some(scope) = self.scope {
|
|
|
|
let limiters = req.app_data::<web::Data<HashMap<&str, Limiter>>>().expect(
|
|
|
|
"web::Data<HashMap<Limiter>> should be set in app data for RateLimiter middleware",
|
|
|
|
);
|
|
|
|
limiters
|
|
|
|
.get(scope)
|
|
|
|
.unwrap_or_else(|| panic!("Unable to find defined limiter with scope: {}", scope))
|
|
|
|
.clone()
|
|
|
|
} else {
|
2022-10-04 16:38:47 +02:00
|
|
|
let limiter = req
|
2022-10-03 18:26:08 +02:00
|
|
|
.app_data::<web::Data<Limiter>>()
|
|
|
|
.expect("web::Data<Limiter> should be set in app data for RateLimiter middleware");
|
|
|
|
// Deref to get the Limiter
|
2022-10-04 16:38:47 +02:00
|
|
|
(***limiter).clone()
|
2022-10-03 18:26:08 +02:00
|
|
|
};
|
2022-03-18 17:00:33 +00:00
|
|
|
|
2022-09-11 01:02:54 +02:00
|
|
|
let key = (limiter.get_key_fn)(&req);
|
2022-03-18 17:00:33 +00:00
|
|
|
let service = Rc::clone(&self.service);
|
2022-03-20 00:40:34 +00:00
|
|
|
|
2022-03-18 17:00:33 +00:00
|
|
|
let key = match key {
|
|
|
|
Some(key) => key,
|
2022-07-31 03:03:43 +01:00
|
|
|
None => {
|
2022-09-11 01:02:54 +02:00
|
|
|
return Box::pin(async move {
|
|
|
|
service
|
|
|
|
.call(req)
|
|
|
|
.await
|
|
|
|
.map(ServiceResponse::map_into_left_body)
|
|
|
|
});
|
2022-07-31 03:03:43 +01:00
|
|
|
}
|
2022-03-18 17:00:33 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
Box::pin(async move {
|
|
|
|
let status = limiter.count(key.to_string()).await;
|
2022-03-20 00:40:34 +00:00
|
|
|
|
2022-08-28 21:49:14 +02:00
|
|
|
if let Err(err) = status {
|
|
|
|
match err {
|
|
|
|
LimitationError::LimitExceeded(_) => {
|
|
|
|
log::warn!("Rate limit exceed error for {}", key);
|
2022-03-20 00:40:34 +00:00
|
|
|
|
2022-08-28 21:49:14 +02:00
|
|
|
Ok(req.into_response(
|
|
|
|
HttpResponse::new(StatusCode::TOO_MANY_REQUESTS).map_into_right_body(),
|
|
|
|
))
|
|
|
|
}
|
|
|
|
LimitationError::Client(e) => {
|
|
|
|
log::error!("Client request failed, redis error: {}", e);
|
|
|
|
|
|
|
|
Ok(req.into_response(
|
|
|
|
HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
|
|
|
|
.map_into_right_body(),
|
|
|
|
))
|
|
|
|
}
|
|
|
|
_ => {
|
|
|
|
log::error!("Count failed: {}", err);
|
|
|
|
|
|
|
|
Ok(req.into_response(
|
|
|
|
HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
|
|
|
|
.map_into_right_body(),
|
|
|
|
))
|
|
|
|
}
|
|
|
|
}
|
2022-03-18 17:00:33 +00:00
|
|
|
} else {
|
|
|
|
service
|
|
|
|
.call(req)
|
|
|
|
.await
|
|
|
|
.map(ServiceResponse::map_into_left_body)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|