1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-12-04 20:11:55 +01:00

use forward_ready for service definitions

This commit is contained in:
Rob Ede 2021-03-22 05:18:59 +00:00
parent b0854ed144
commit 2254a429d4
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
9 changed files with 38 additions and 68 deletions

View File

@ -20,6 +20,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "4.0.0-beta.4", default-features = false } actix-web = { version = "4.0.0-beta.4", default-features = false }
actix-service = "2.0.0-beta.5"
derive_more = "0.99.5" derive_more = "0.99.5"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
log = "0.4" log = "0.4"

View File

@ -1,8 +1,4 @@
use std::{ use std::{convert::TryInto, rc::Rc};
convert::TryInto,
rc::Rc,
task::{Context, Poll},
};
use actix_web::{ use actix_web::{
dev::{Service, ServiceRequest, ServiceResponse}, dev::{Service, ServiceRequest, ServiceResponse},
@ -131,9 +127,7 @@ where
type Error = Error; type Error = Error;
type Future = CorsMiddlewareServiceFuture<B>; type Future = CorsMiddlewareServiceFuture<B>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { actix_service::forward_ready!(service);
self.service.poll_ready(cx)
}
fn call(&self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
if self.inner.preflight && req.method() == Method::OPTIONS { if self.inner.preflight && req.method() == Method::OPTIONS {

View File

@ -18,7 +18,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "4.0.0-beta.4", default-features = false, features = ["secure-cookies"] } actix-web = { version = "4.0.0-beta.4", default-features = false, features = ["secure-cookies"] }
actix-service = "2.0.0-beta.5" actix-service = "2.0.0-beta.5"
futures-util = { version = "0.3", default-features = false } futures-util = { version = "0.3.7", default-features = false }
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = { version = "0.2.7", default-features = false, features = ["std"] } time = { version = "0.2.7", default-features = false, features = ["std"] }

View File

@ -12,7 +12,7 @@
//! To access current request identity //! To access current request identity
//! [**Identity**](struct.Identity.html) extractor should be used. //! [**Identity**](struct.Identity.html) extractor should be used.
//! //!
//! ```rust //! ```
//! use actix_web::*; //! use actix_web::*;
//! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; //! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService};
//! //!
@ -49,11 +49,7 @@
#![deny(rust_2018_idioms)] #![deny(rust_2018_idioms)]
use std::cell::RefCell; use std::{future::Future, rc::Rc, time::SystemTime};
use std::future::Future;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::SystemTime;
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready};
@ -239,22 +235,22 @@ where
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(IdentityServiceMiddleware { ok(IdentityServiceMiddleware {
backend: self.backend.clone(), backend: self.backend.clone(),
service: Rc::new(RefCell::new(service)), service: Rc::new(service),
}) })
} }
} }
#[doc(hidden)] #[doc(hidden)]
pub struct IdentityServiceMiddleware<S, T> { pub struct IdentityServiceMiddleware<S, T> {
service: Rc<S>,
backend: Rc<T>, backend: Rc<T>,
service: Rc<RefCell<S>>,
} }
impl<S, T> Clone for IdentityServiceMiddleware<S, T> { impl<S, T> Clone for IdentityServiceMiddleware<S, T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
backend: self.backend.clone(), backend: Rc::clone(&self.backend),
service: self.service.clone(), service: Rc::clone(&self.service),
} }
} }
} }
@ -270,13 +266,11 @@ where
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { actix_service::forward_ready!(service);
self.service.borrow_mut().poll_ready(cx)
}
fn call(&self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
let srv = self.service.clone(); let srv = Rc::clone(&self.service);
let backend = self.backend.clone(); let backend = Rc::clone(&self.backend);
let fut = self.backend.from_request(&mut req); let fut = self.backend.from_request(&mut req);
async move { async move {
@ -285,9 +279,7 @@ where
req.extensions_mut() req.extensions_mut()
.insert(IdentityItem { id, changed: false }); .insert(IdentityItem { id, changed: false });
// https://github.com/actix/actix-web/issues/1263 let mut res = srv.call(req).await?;
let fut = srv.borrow_mut().call(req);
let mut res = fut.await?;
let id = res.request().extensions_mut().remove::<IdentityItem>(); let id = res.request().extensions_mut().remove::<IdentityItem>();
if let Some(id) = id { if let Some(id) = id {
@ -1132,12 +1124,10 @@ mod tests {
let srv = IdentityServiceMiddleware { let srv = IdentityServiceMiddleware {
backend: Rc::new(Ident), backend: Rc::new(Ident),
service: Rc::new(RefCell::new(into_service( service: Rc::new(into_service(|_: ServiceRequest| async move {
|_: ServiceRequest| async move {
actix_rt::time::sleep(std::time::Duration::from_secs(100)).await; actix_rt::time::sleep(std::time::Duration::from_secs(100)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error")) Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}, })),
))),
}; };
let srv2 = srv.clone(); let srv2 = srv.clone();

View File

@ -21,7 +21,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "4.0.0-beta.4", default_features = false } actix-web = { version = "4.0.0-beta.4", default_features = false }
actix-rt = "2" actix-rt = "2"
futures-util = { version = "0.3.5", default-features = false } futures-util = { version = "0.3.7", default-features = false }
derive_more = "0.99" derive_more = "0.99"
prost = "0.7" prost = "0.7"

View File

@ -152,8 +152,8 @@ where
actix_service::forward_ready!(service); actix_service::forward_ready!(service);
fn call(&self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
let srv = self.service.clone(); let srv = Rc::clone(&self.service);
let inner = self.inner.clone(); let inner = Rc::clone(&self.inner);
Box::pin(async move { Box::pin(async move {
let state = inner.load(&req).await?; let state = inner.load(&req).await?;

View File

@ -1,8 +1,6 @@
//! Cookie based sessions. See docs for [`CookieSession`]. //! Cookie based sessions. See docs for [`CookieSession`].
use std::collections::HashMap; use std::{collections::HashMap, rc::Rc};
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
@ -326,9 +324,7 @@ where
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { actix_service::forward_ready!(service);
self.service.poll_ready(cx)
}
/// On first request, a new session cookie is returned in response, regardless /// On first request, a new session cookie is returned in response, regardless
/// of whether any session state is set. With subsequent requests, if the /// of whether any session state is set. With subsequent requests, if the

View File

@ -21,6 +21,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "4.0.0-beta.4", default_features = false } actix-web = { version = "4.0.0-beta.4", default_features = false }
actix-service = "2.0.0-beta.5"
base64 = "0.13" base64 = "0.13"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }

View File

@ -1,11 +1,6 @@
//! HTTP Authentication middleware. //! HTTP Authentication middleware.
use std::cell::RefCell; use std::{future::Future, marker::PhantomData, pin::Pin, rc::Rc, sync::Arc};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use actix_web::{ use actix_web::{
dev::{Service, ServiceRequest, ServiceResponse, Transform}, dev::{Service, ServiceRequest, ServiceResponse, Transform},
@ -134,7 +129,7 @@ where
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
future::ok(AuthenticationMiddleware { future::ok(AuthenticationMiddleware {
service: Rc::new(RefCell::new(service)), service: Rc::new(service),
process_fn: self.process_fn.clone(), process_fn: self.process_fn.clone(),
_extractor: PhantomData, _extractor: PhantomData,
}) })
@ -146,7 +141,7 @@ pub struct AuthenticationMiddleware<S, F, T>
where where
T: AuthExtractor, T: AuthExtractor,
{ {
service: Rc<RefCell<S>>, service: Rc<S>,
process_fn: Arc<F>, process_fn: Arc<F>,
_extractor: PhantomData<T>, _extractor: PhantomData<T>,
} }
@ -163,9 +158,7 @@ where
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>; type Future = LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>;
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { actix_service::forward_ready!(service);
self.service.borrow_mut().poll_ready(ctx)
}
fn call(&self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
let process_fn = Arc::clone(&self.process_fn); let process_fn = Arc::clone(&self.process_fn);
@ -183,9 +176,8 @@ where
// TODO: alter to remove ? operator; an error response is required for downstream // TODO: alter to remove ? operator; an error response is required for downstream
// middleware to do their thing (eg. cors adding headers) // middleware to do their thing (eg. cors adding headers)
let req = process_fn(req, credentials).await?; let req = process_fn(req, credentials).await?;
// Ensure `borrow_mut()` and `.await` are on separate lines or else a panic occurs.
let fut = service.borrow_mut().call(req); service.call(req).await
fut.await
} }
.boxed_local() .boxed_local()
} }
@ -252,12 +244,10 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_middleware_panic() { async fn test_middleware_panic() {
let middleware = AuthenticationMiddleware { let middleware = AuthenticationMiddleware {
service: Rc::new(RefCell::new(into_service( service: Rc::new(into_service(|_: ServiceRequest| async move {
|_: ServiceRequest| async move {
actix_rt::time::sleep(std::time::Duration::from_secs(1)).await; actix_rt::time::sleep(std::time::Duration::from_secs(1)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error")) Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}, })),
))),
process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }), process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }),
_extractor: PhantomData, _extractor: PhantomData,
}; };
@ -277,12 +267,10 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_middleware_panic_several_orders() { async fn test_middleware_panic_several_orders() {
let middleware = AuthenticationMiddleware { let middleware = AuthenticationMiddleware {
service: Rc::new(RefCell::new(into_service( service: Rc::new(into_service(|_: ServiceRequest| async move {
|_: ServiceRequest| async move {
actix_rt::time::sleep(std::time::Duration::from_secs(1)).await; actix_rt::time::sleep(std::time::Duration::from_secs(1)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error")) Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}, })),
))),
process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }), process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }),
_extractor: PhantomData, _extractor: PhantomData,
}; };