1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use std::{future::Future, pin::Pin, rc::Rc};
use actix_session::SessionExt as _;
use actix_utils::future::{ok, Ready};
use actix_web::{
body::EitherBody,
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
http::StatusCode,
web, Error, HttpResponse,
};
use crate::{Error as LimitationError, Limiter};
/// Rate-limiting middleware factory.
///
/// Its [`Transform`] implementation wraps a service in a [`RateLimiterMiddleware`].
/// The resulting middleware looks up a `web::Data<Limiter>` in the request's app data
/// on every call and panics if none has been registered.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct RateLimiter;
/// Builds a [`RateLimiterMiddleware`] around the next service in the chain.
impl<S, B> Transform<S, ServiceRequest> for RateLimiter
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type Transform = RateLimiterMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a reference-counted pointer and returns the middleware
    /// as an already-resolved future; construction cannot fail.
    fn new_transform(&self, service: S) -> Self::Future {
        let middleware = RateLimiterMiddleware {
            service: Rc::new(service),
        };
        ok(middleware)
    }
}
/// Service wrapper created by [`RateLimiter`].
///
/// Consults the configured `Limiter` before forwarding a request to the wrapped service.
#[derive(Debug)]
pub struct RateLimiterMiddleware<S> {
    /// The wrapped service; reference-counted so each request's future can hold its own handle.
    service: Rc<S>,
}
impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let limiter = req
.app_data::<web::Data<Limiter>>()
.expect("web::Data<Limiter> should be set in app data for RateLimiter middleware")
.clone();
let key = req.get_session().get(&limiter.session_key).unwrap_or(None);
let service = Rc::clone(&self.service);
let key = match key {
Some(key) => key,
None => {
let fallback = req.cookie(&limiter.cookie_name).map(|c| c.to_string());
match fallback {
Some(key) => key,
None => {
return Box::pin(async move {
service
.call(req)
.await
.map(ServiceResponse::map_into_left_body)
});
}
}
}
};
Box::pin(async move {
let status = limiter.count(key.to_string()).await;
if let Err(err) = status {
match err {
LimitationError::LimitExceeded(_) => {
log::warn!("Rate limit exceed error for {}", key);
Ok(req.into_response(
HttpResponse::new(StatusCode::TOO_MANY_REQUESTS).map_into_right_body(),
))
}
LimitationError::Client(e) => {
log::error!("Client request failed, redis error: {}", e);
Ok(req.into_response(
HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
.map_into_right_body(),
))
}
_ => {
log::error!("Count failed: {}", err);
Ok(req.into_response(
HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
.map_into_right_body(),
))
}
}
} else {
service
.call(req)
.await
.map(ServiceResponse::map_into_left_body)
}
})
}
}