mirror of
https://github.com/actix/actix-extras.git
synced 2025-01-22 23:05:56 +01:00
httpauth: Refactor out Mutex (#69)
This commit is contained in:
parent
1d32248844
commit
70df190e0b
@ -1,15 +1,16 @@
|
||||
//! HTTP Authentication middleware.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::future::Future;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_service::{Service, Transform};
|
||||
use actix_web::dev::{ServiceRequest, ServiceResponse};
|
||||
use actix_web::Error;
|
||||
use futures_util::future::{self, FutureExt, LocalBoxFuture, TryFutureExt};
|
||||
use futures_util::lock::Mutex;
|
||||
use futures_util::task::{Context, Poll};
|
||||
|
||||
use crate::extractors::{basic, bearer, AuthExtractor};
|
||||
@ -142,7 +143,7 @@ where
|
||||
|
||||
fn new_transform(&self, service: S) -> Self::Future {
|
||||
future::ok(AuthenticationMiddleware {
|
||||
service: Arc::new(Mutex::new(service)),
|
||||
service: Rc::new(RefCell::new(service)),
|
||||
process_fn: self.process_fn.clone(),
|
||||
_extractor: PhantomData,
|
||||
})
|
||||
@ -154,7 +155,7 @@ pub struct AuthenticationMiddleware<S, F, T>
|
||||
where
|
||||
T: AuthExtractor,
|
||||
{
|
||||
service: Arc<Mutex<S>>,
|
||||
service: Rc<RefCell<S>>,
|
||||
process_fn: Arc<F>,
|
||||
_extractor: PhantomData<T>,
|
||||
}
|
||||
@ -181,21 +182,22 @@ where
|
||||
ctx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.service
|
||||
.try_lock()
|
||||
.expect("AuthenticationMiddleware was called already")
|
||||
.borrow_mut()
|
||||
.poll_ready(ctx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Self::Request) -> Self::Future {
|
||||
let process_fn = self.process_fn.clone();
|
||||
// Note: cloning the mutex, not the service itself
|
||||
let inner = self.service.clone();
|
||||
|
||||
let service = Rc::clone(&self.service);
|
||||
|
||||
async move {
|
||||
let (req, credentials) = Extract::<T>::new(req).await?;
|
||||
let req = process_fn(req, credentials).await?;
|
||||
let mut service = inner.lock().await;
|
||||
service.call(req).await
|
||||
// It is important that `borrow_mut()` and `.await` are on
|
||||
// separate lines, or else a panic occurs.
|
||||
let fut = service.borrow_mut().call(req);
|
||||
fut.await
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
@ -246,3 +248,73 @@ where
|
||||
Poll::Ready(Ok((req, credentials)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use actix_web::test::TestRequest;
|
||||
use actix_service::{into_service, Service};
|
||||
use futures_util::join;
|
||||
use crate::extractors::bearer::BearerAuth;
|
||||
use actix_web::error;
|
||||
|
||||
|
||||
/// This is a test for https://github.com/actix/actix-extras/issues/10
|
||||
#[actix_rt::test]
|
||||
async fn test_middleware_panic() {
|
||||
let mut middleware = AuthenticationMiddleware {
|
||||
service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| {
|
||||
async move {
|
||||
actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await;
|
||||
Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
|
||||
}}))),
|
||||
process_fn: Arc::new(|req, _: BearerAuth| async {
|
||||
Ok(req) }),
|
||||
_extractor: PhantomData,
|
||||
};
|
||||
|
||||
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();
|
||||
|
||||
let f = middleware.call(req);
|
||||
|
||||
let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx) );
|
||||
|
||||
|
||||
assert!(join!(f, res).0.is_err());
|
||||
}
|
||||
|
||||
/// This is a test for https://github.com/actix/actix-extras/issues/10
|
||||
#[actix_rt::test]
|
||||
async fn test_middleware_panic_several_orders() {
|
||||
let mut middleware = AuthenticationMiddleware {
|
||||
service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| {
|
||||
async move {
|
||||
actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await;
|
||||
Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
|
||||
}}))),
|
||||
process_fn: Arc::new(|req, _: BearerAuth| async {
|
||||
Ok(req) }),
|
||||
_extractor: PhantomData,
|
||||
};
|
||||
|
||||
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();
|
||||
|
||||
let f1 = middleware.call(req);
|
||||
|
||||
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();
|
||||
|
||||
let f2 = middleware.call(req);
|
||||
|
||||
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();
|
||||
|
||||
let f3 = middleware.call(req);
|
||||
|
||||
let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx));
|
||||
|
||||
let result = join!(f1, f2, f3, res);
|
||||
|
||||
assert!(result.0.is_err());
|
||||
assert!(result.1.is_err());
|
||||
assert!(result.2.is_err());
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user