diff --git a/actix-web-httpauth/CHANGES.md b/actix-web-httpauth/CHANGES.md index 9d216f8dd..2b6eef1c2 100644 --- a/actix-web-httpauth/CHANGES.md +++ b/actix-web-httpauth/CHANGES.md @@ -1,6 +1,9 @@ # Changes ## Unreleased - 2021-xx-xx +* impl `AuthExtractor` trait for `Option` and `Result`. [#205] + +[#205]: https://github.com/actix/actix-extras/pull/205 ## 0.6.0-beta.3 - 2021-10-21 diff --git a/actix-web-httpauth/Cargo.toml b/actix-web-httpauth/Cargo.toml index 33c3eee5b..06b16f172 100644 --- a/actix-web-httpauth/Cargo.toml +++ b/actix-web-httpauth/Cargo.toml @@ -22,6 +22,7 @@ actix-web = { version = "4.0.0-beta.10", default_features = false } actix-service = "2.0.0" base64 = "0.13" futures-util = { version = "0.3.7", default-features = false } +pin-project-lite = "0.2.7" [dev-dependencies] actix-cors = "0.6.0-beta.3" diff --git a/actix-web-httpauth/src/extractors/mod.rs b/actix-web-httpauth/src/extractors/mod.rs index 1fe9b7d9a..497cc983d 100644 --- a/actix-web-httpauth/src/extractors/mod.rs +++ b/actix-web-httpauth/src/extractors/mod.rs @@ -1,8 +1,15 @@ //! Type-safe authentication information extractors +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + use actix_web::dev::ServiceRequest; use actix_web::Error; -use std::future::Future; +use futures_util::ready; +use pin_project_lite::pin_project; pub mod basic; pub mod bearer; @@ -31,3 +38,66 @@ pub trait AuthExtractor: Sized { /// Parse the authentication credentials from the actix' `ServiceRequest`. fn from_service_request(req: &ServiceRequest) -> Self::Future; } + +impl AuthExtractor for Option { + type Error = T::Error; + + type Future = AuthExtractorOptFut; + + fn from_service_request(req: &ServiceRequest) -> Self::Future { + let fut = T::from_service_request(req); + AuthExtractorOptFut { fut } + } +} + +pin_project! { + #[doc(hidden)] + pub struct AuthExtractorOptFut { + #[pin] + fut: F + } +} + +impl Future for AuthExtractorOptFut +where + F: Future>, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = ready!(self.project().fut.poll(cx)); + Poll::Ready(Ok(res.ok())) + } +} + +impl AuthExtractor for Result { + type Error = T::Error; + + type Future = AuthExtractorResFut; + + fn from_service_request(req: &ServiceRequest) -> Self::Future { + AuthExtractorResFut { + fut: T::from_service_request(req), + } + } +} + +pin_project! { + #[doc(hidden)] + pub struct AuthExtractorResFut { + #[pin] + fut: F + } +} + +impl Future for AuthExtractorResFut +where + F: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = ready!(self.project().fut.poll(cx)); + Poll::Ready(Ok(res)) + } +} diff --git a/actix-web-httpauth/src/middleware.rs b/actix-web-httpauth/src/middleware.rs index 983b9ff0a..828d3401d 100644 --- a/actix-web-httpauth/src/middleware.rs +++ b/actix-web-httpauth/src/middleware.rs @@ -247,8 +247,8 @@ mod tests { use super::*; use crate::extractors::bearer::BearerAuth; use actix_service::{into_service, Service}; - use actix_web::error; use actix_web::test::TestRequest; + use actix_web::{error, HttpResponse}; /// This is a test for https://github.com/actix/actix-extras/issues/10 #[actix_rt::test] @@ -309,4 +309,54 @@ mod tests { assert!(f2.is_err()); assert!(f3.is_err()); } + + #[actix_rt::test] + async fn test_middleware_opt_extractor() { + let middleware = AuthenticationMiddleware { + service: Rc::new(into_service(|req: ServiceRequest| async move { + Ok::(req.into_response(HttpResponse::Ok().finish())) + })), + process_fn: Arc::new(|req, auth: Option| { + assert!(auth.is_none()); + async { Ok(req) } + }), + _extractor: PhantomData, + }; + + let req = TestRequest::get() + .append_header(("Authorization996", "Bearer 1")) + .to_srv_request(); + + let f = middleware.call(req).await; + + let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await; + + assert!(f.is_ok()); + } + + #[actix_rt::test] + async fn test_middleware_res_extractor() { + let middleware = AuthenticationMiddleware { + service: Rc::new(into_service(|req: ServiceRequest| async move { + Ok::(req.into_response(HttpResponse::Ok().finish())) + })), + process_fn: Arc::new( + |req, auth: Result::Error>| { + assert!(auth.is_err()); + async { Ok(req) } + }, + ), + _extractor: PhantomData, + }; + + let req = TestRequest::get() + .append_header(("Authorization", "BearerLOL")) + .to_srv_request(); + + let f = middleware.call(req).await; + + let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await; + + assert!(f.is_ok()); + } }