From 70df190e0beed29f1b6939be7bae52f9d7104387 Mon Sep 17 00:00:00 2001 From: Boyd Johnson Date: Thu, 11 Jun 2020 10:10:18 -0500 Subject: [PATCH] httpauth: Refactor out Mutex (#69) --- actix-web-httpauth/src/middleware.rs | 90 +++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 9 deletions(-) diff --git a/actix-web-httpauth/src/middleware.rs b/actix-web-httpauth/src/middleware.rs index 15eb25885..31f08fcf6 100644 --- a/actix-web-httpauth/src/middleware.rs +++ b/actix-web-httpauth/src/middleware.rs @@ -1,15 +1,16 @@ //! HTTP Authentication middleware. +use std::cell::RefCell; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; +use std::rc::Rc; use std::sync::Arc; use actix_service::{Service, Transform}; use actix_web::dev::{ServiceRequest, ServiceResponse}; use actix_web::Error; use futures_util::future::{self, FutureExt, LocalBoxFuture, TryFutureExt}; -use futures_util::lock::Mutex; use futures_util::task::{Context, Poll}; use crate::extractors::{basic, bearer, AuthExtractor}; @@ -142,7 +143,7 @@ where fn new_transform(&self, service: S) -> Self::Future { future::ok(AuthenticationMiddleware { - service: Arc::new(Mutex::new(service)), + service: Rc::new(RefCell::new(service)), process_fn: self.process_fn.clone(), _extractor: PhantomData, }) @@ -154,7 +155,7 @@ pub struct AuthenticationMiddleware where T: AuthExtractor, { - service: Arc>, + service: Rc>, process_fn: Arc, _extractor: PhantomData, } @@ -181,21 +182,22 @@ where ctx: &mut Context<'_>, ) -> Poll> { self.service - .try_lock() - .expect("AuthenticationMiddleware was called already") + .borrow_mut() .poll_ready(ctx) } fn call(&mut self, req: Self::Request) -> Self::Future { let process_fn = self.process_fn.clone(); - // Note: cloning the mutex, not the service itself - let inner = self.service.clone(); + + let service = Rc::clone(&self.service); async move { let (req, credentials) = Extract::::new(req).await?; let req = process_fn(req, credentials).await?; - let mut service = inner.lock().await; - service.call(req).await + // It is important that `borrow_mut()` and `.await` are on + // separate lines, or else a panic occurs. + let fut = service.borrow_mut().call(req); + fut.await } .boxed_local() } @@ -246,3 +248,73 @@ where Poll::Ready(Ok((req, credentials))) } } + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::test::TestRequest; + use actix_service::{into_service, Service}; + use futures_util::join; + use crate::extractors::bearer::BearerAuth; + use actix_web::error; + + + /// This is a test for https://github.com/actix/actix-extras/issues/10 + #[actix_rt::test] + async fn test_middleware_panic() { + let mut middleware = AuthenticationMiddleware { + service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| { + async move { + actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await; + Err::(error::ErrorBadRequest("error")) + }}))), + process_fn: Arc::new(|req, _: BearerAuth| async { + Ok(req) }), + _extractor: PhantomData, + }; + + let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); + + let f = middleware.call(req); + + let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx) ); + + + assert!(join!(f, res).0.is_err()); + } + + /// This is a test for https://github.com/actix/actix-extras/issues/10 + #[actix_rt::test] + async fn test_middleware_panic_several_orders() { + let mut middleware = AuthenticationMiddleware { + service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| { + async move { + actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await; + Err::(error::ErrorBadRequest("error")) + }}))), + process_fn: Arc::new(|req, _: BearerAuth| async { + Ok(req) }), + _extractor: PhantomData, + }; + + let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); + + let f1 = middleware.call(req); + + let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); + + let f2 = middleware.call(req); + + let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); + + let f3 = middleware.call(req); + + let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)); + + let result = join!(f1, f2, f3, res); + + assert!(result.0.is_err()); + assert!(result.1.is_err()); + assert!(result.2.is_err()); + } +}