1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-23 23:51:06 +01:00

Update dependencies (Tokio 1.0) (#144)

This commit is contained in:
Andrey Kutejko 2021-03-21 23:50:26 +01:00 committed by GitHub
parent 86ff1302ad
commit ca85f6b245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 429 additions and 503 deletions

View File

@ -13,7 +13,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
version: version:
- 1.42.0 - 1.46.0
name: ${{ matrix.version }} - x86_64-unknown-linux-gnu name: ${{ matrix.version }} - x86_64-unknown-linux-gnu
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -1,7 +1,7 @@
# Changes # Changes
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
* Minimum supported Rust version (MSRV) is now 1.46.0.
## 0.5.4 - 2020-12-31 ## 0.5.4 - 2020-12-31

View File

@ -19,7 +19,7 @@ name = "actix_cors"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "3.0.0", default-features = false } actix-web = { version = "4.0.0-beta.4", default-features = false }
derive_more = "0.99.5" derive_more = "0.99.5"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
log = "0.4" log = "0.4"
@ -27,7 +27,7 @@ once_cell = "1"
tinyvec = { version = "1", features = ["alloc"] } tinyvec = { version = "1", features = ["alloc"] }
[dev-dependencies] [dev-dependencies]
actix-service = "1" actix-service = "2.0.0-beta.5"
actix-rt = "1" actix-rt = "2"
pretty_env_logger = "0.4" pretty_env_logger = "0.4"
regex = "1.4" regex = "1.4"

View File

@ -12,4 +12,4 @@
- [API Documentation](https://docs.rs/actix-cors) - [API Documentation](https://docs.rs/actix-cors)
- [Example Project](https://github.com/actix/examples/tree/master/security/web-cors) - [Example Project](https://github.com/actix/examples/tree/master/security/web-cors)
- [Chat on Gitter](https://gitter.im/actix/actix-web) - [Chat on Gitter](https://gitter.im/actix/actix-web)
- Minimum Supported Rust Version (MSRV): 1.42.0 - Minimum Supported Rust Version (MSRV): 1.46.0

View File

@ -145,7 +145,7 @@ impl Cors {
match TryInto::<Uri>::try_into(origin) { match TryInto::<Uri>::try_into(origin) {
Ok(_) if origin == "*" => { Ok(_) if origin == "*" => {
error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`."); error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
self.error = Some(Either::B(CorsError::WildcardOrigin)); self.error = Some(Either::Right(CorsError::WildcardOrigin));
} }
Ok(_) => { Ok(_) => {
@ -162,7 +162,7 @@ impl Cors {
} }
Err(err) => { Err(err) => {
self.error = Some(Either::A(err.into())); self.error = Some(Either::Left(err.into()));
} }
} }
} }
@ -224,7 +224,7 @@ impl Cors {
} }
Err(err) => { Err(err) => {
self.error = Some(Either::A(err.into())); self.error = Some(Either::Left(err.into()));
break; break;
} }
} }
@ -266,7 +266,7 @@ impl Cors {
} }
} }
Err(err) => self.error = Some(Either::A(err.into())), Err(err) => self.error = Some(Either::Left(err.into())),
} }
} }
@ -303,7 +303,7 @@ impl Cors {
} }
} }
Err(err) => { Err(err) => {
self.error = Some(Either::A(err.into())); self.error = Some(Either::Left(err.into()));
break; break;
} }
} }
@ -351,7 +351,7 @@ impl Cors {
} }
} }
Err(err) => { Err(err) => {
self.error = Some(Either::A(err.into())); self.error = Some(Either::Left(err.into()));
break; break;
} }
} }
@ -483,13 +483,12 @@ impl Default for Cors {
} }
} }
impl<S, B> Transform<S> for Cors impl<S, B> Transform<S, ServiceRequest> for Cors
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
@ -499,8 +498,8 @@ where
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
if let Some(ref err) = self.error { if let Some(ref err) = self.error {
match err { match err {
Either::A(err) => error!("{}", err), Either::Left(err) => error!("{}", err),
Either::B(err) => error!("{}", err), Either::Right(err) => error!("{}", err),
} }
return future::err(()); return future::err(());
@ -592,15 +591,16 @@ mod test {
#[actix_rt::test] #[actix_rt::test]
async fn restrictive_defaults() { async fn restrictive_defaults() {
let mut cors = Cors::default() let cors = Cors::default()
.new_transform(test::ok_service()) .new_transform(test::ok_service())
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
} }

View File

@ -235,8 +235,8 @@ mod test {
.unwrap(); .unwrap();
let req = TestRequest::get() let req = TestRequest::get()
.header(header::ORIGIN, "https://www.unknown.com") .insert_header((header::ORIGIN, "https://www.unknown.com"))
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "DNT") .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "DNT"))
.to_srv_request(); .to_srv_request();
assert!(cors.inner.validate_origin(req.head()).is_err()); assert!(cors.inner.validate_origin(req.head()).is_err());
@ -257,34 +257,37 @@ mod test {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.method(Method::OPTIONS) .method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") .insert_header(("Origin", "https://www.example.com"))
.insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed"))
.to_srv_request(); .to_srv_request();
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.insert_header(("Origin", "https://www.example.com"))
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "put"))
.to_srv_request(); .to_srv_request();
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); assert!(cors.inner.validate_allowed_headers(req.head()).is_ok());
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .method(Method::OPTIONS)
.header( .insert_header(("Origin", "https://www.example.com"))
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
.insert_header((
header::ACCESS_CONTROL_REQUEST_HEADERS, header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT", "AUTHORIZATION,ACCEPT",
) ))
.method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"*"[..]), Some(&b"*"[..]),
resp.headers() resp.headers()
@ -319,16 +322,17 @@ mod test {
Rc::get_mut(&mut cors.inner).unwrap().preflight = false; Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .method(Method::OPTIONS)
.header( .insert_header(("Origin", "https://www.example.com"))
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
.insert_header((
header::ACCESS_CONTROL_REQUEST_HEADERS, header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT", "AUTHORIZATION,ACCEPT",
) ))
.method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
} }

View File

@ -42,34 +42,34 @@ impl<S> CorsMiddleware<S> {
let mut res = HttpResponse::Ok(); let mut res = HttpResponse::Ok();
if let Some(origin) = inner.access_control_allow_origin(req.head()) { if let Some(origin) = inner.access_control_allow_origin(req.head()) {
res.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); res.insert_header((header::ACCESS_CONTROL_ALLOW_ORIGIN, origin));
} }
if let Some(ref allowed_methods) = inner.allowed_methods_baked { if let Some(ref allowed_methods) = inner.allowed_methods_baked {
res.header( res.insert_header((
header::ACCESS_CONTROL_ALLOW_METHODS, header::ACCESS_CONTROL_ALLOW_METHODS,
allowed_methods.clone(), allowed_methods.clone(),
); ));
} }
if let Some(ref headers) = inner.allowed_headers_baked { if let Some(ref headers) = inner.allowed_headers_baked {
res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()); res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
} else if let Some(headers) = } else if let Some(headers) =
req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS)
{ {
// all headers allowed, return // all headers allowed, return
res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()); res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
} }
if inner.supports_credentials { if inner.supports_credentials {
res.header( res.insert_header((
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"), HeaderValue::from_static("true"),
); ));
} }
if let Some(max_age) = inner.max_age { if let Some(max_age) = inner.max_age {
res.header(header::ACCESS_CONTROL_MAX_AGE, max_age.to_string()); res.insert_header((header::ACCESS_CONTROL_MAX_AGE, max_age.to_string()));
} }
let res = res.finish(); let res = res.finish();
@ -121,22 +121,21 @@ type CorsMiddlewareServiceFuture<B> = Either<
LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>, LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>,
>; >;
impl<S, B> Service for CorsMiddleware<S> impl<S, B> Service<ServiceRequest> for CorsMiddleware<S>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = CorsMiddlewareServiceFuture<B>; type Future = CorsMiddlewareServiceFuture<B>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx) self.service.poll_ready(cx)
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
if self.inner.preflight && req.method() == Method::OPTIONS { if self.inner.preflight && req.method() == Method::OPTIONS {
let inner = Rc::clone(&self.inner); let inner = Rc::clone(&self.inner);
let res = Self::handle_preflight(&inner, req); let res = Self::handle_preflight(&inner, req);
@ -187,7 +186,7 @@ mod tests {
// Tests case where allowed_origins is All but there are validate functions to run incase. // Tests case where allowed_origins is All but there are validate functions to run incase.
// In this case, origins are only allowed when the DNT header is sent. // In this case, origins are only allowed when the DNT header is sent.
let mut cors = Cors::default() let cors = Cors::default()
.allow_any_origin() .allow_any_origin()
.allowed_origin_fn(|origin, req_head| { .allowed_origin_fn(|origin, req_head| {
assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap()); assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap());
@ -199,7 +198,7 @@ mod tests {
.unwrap(); .unwrap();
let req = TestRequest::get() let req = TestRequest::get()
.header(header::ORIGIN, "http://example.com") .insert_header((header::ORIGIN, "http://example.com"))
.to_srv_request(); .to_srv_request();
let res = cors.call(req).await.unwrap(); let res = cors.call(req).await.unwrap();
assert_eq!( assert_eq!(
@ -210,8 +209,8 @@ mod tests {
); );
let req = TestRequest::get() let req = TestRequest::get()
.header(header::ORIGIN, "http://example.com") .insert_header((header::ORIGIN, "http://example.com"))
.header(header::DNT, "1") .insert_header((header::DNT, "1"))
.to_srv_request(); .to_srv_request();
let res = cors.call(req).await.unwrap(); let res = cors.call(req).await.unwrap();
assert_eq!( assert_eq!(

View File

@ -26,7 +26,7 @@ async fn test_wildcard_origin() {
#[actix_rt::test] #[actix_rt::test]
async fn test_not_allowed_origin_fn() { async fn test_not_allowed_origin_fn() {
let mut cors = Cors::default() let cors = Cors::default()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allowed_origin_fn(|origin, req| { .allowed_origin_fn(|origin, req| {
assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap());
@ -42,11 +42,11 @@ async fn test_not_allowed_origin_fn() {
.unwrap(); .unwrap();
{ {
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://www.example.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://www.example.com"[..]), Some(&b"https://www.example.com"[..]),
@ -57,11 +57,11 @@ async fn test_not_allowed_origin_fn() {
} }
{ {
let req = TestRequest::with_header("Origin", "https://www.known.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://www.known.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
None, None,
@ -72,7 +72,7 @@ async fn test_not_allowed_origin_fn() {
#[actix_rt::test] #[actix_rt::test]
async fn test_allowed_origin_fn() { async fn test_allowed_origin_fn() {
let mut cors = Cors::default() let cors = Cors::default()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allowed_origin_fn(|origin, req| { .allowed_origin_fn(|origin, req| {
assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap());
@ -87,11 +87,11 @@ async fn test_allowed_origin_fn() {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://www.example.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
"https://www.example.com", "https://www.example.com",
@ -101,11 +101,11 @@ async fn test_allowed_origin_fn() {
.unwrap() .unwrap()
); );
let req = TestRequest::with_header("Origin", "https://www.unknown.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://www.unknown.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://www.unknown.com"[..]), Some(&b"https://www.unknown.com"[..]),
@ -119,7 +119,7 @@ async fn test_allowed_origin_fn() {
async fn test_allowed_origin_fn_with_environment() { async fn test_allowed_origin_fn_with_environment() {
let regex = Regex::new("https:.+\\.unknown\\.com").unwrap(); let regex = Regex::new("https:.+\\.unknown\\.com").unwrap();
let mut cors = Cors::default() let cors = Cors::default()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allowed_origin_fn(move |origin, req| { .allowed_origin_fn(move |origin, req| {
assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap());
@ -134,11 +134,11 @@ async fn test_allowed_origin_fn_with_environment() {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://www.example.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
"https://www.example.com", "https://www.example.com",
@ -148,11 +148,11 @@ async fn test_allowed_origin_fn_with_environment() {
.unwrap() .unwrap()
); );
let req = TestRequest::with_header("Origin", "https://www.unknown.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://www.unknown.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://www.unknown.com"[..]), Some(&b"https://www.unknown.com"[..]),
@ -164,7 +164,7 @@ async fn test_allowed_origin_fn_with_environment() {
#[actix_rt::test] #[actix_rt::test]
async fn test_multiple_origins_preflight() { async fn test_multiple_origins_preflight() {
let mut cors = Cors::default() let cors = Cors::default()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
@ -172,12 +172,13 @@ async fn test_multiple_origins_preflight() {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::default()
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .insert_header(("Origin", "https://example.com"))
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET"))
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://example.com"[..]), Some(&b"https://example.com"[..]),
resp.headers() resp.headers()
@ -185,12 +186,13 @@ async fn test_multiple_origins_preflight() {
.map(HeaderValue::as_bytes) .map(HeaderValue::as_bytes)
); );
let req = TestRequest::with_header("Origin", "https://example.org") let req = TestRequest::default()
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .insert_header(("Origin", "https://example.org"))
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET"))
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://example.org"[..]), Some(&b"https://example.org"[..]),
resp.headers() resp.headers()
@ -201,7 +203,7 @@ async fn test_multiple_origins_preflight() {
#[actix_rt::test] #[actix_rt::test]
async fn test_multiple_origins() { async fn test_multiple_origins() {
let mut cors = Cors::default() let cors = Cors::default()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
@ -209,11 +211,11 @@ async fn test_multiple_origins() {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://example.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://example.com"[..]), Some(&b"https://example.com"[..]),
resp.headers() resp.headers()
@ -221,11 +223,11 @@ async fn test_multiple_origins() {
.map(HeaderValue::as_bytes) .map(HeaderValue::as_bytes)
); );
let req = TestRequest::with_header("Origin", "https://example.org") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://example.org"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://example.org"[..]), Some(&b"https://example.org"[..]),
resp.headers() resp.headers()
@ -237,7 +239,7 @@ async fn test_multiple_origins() {
#[actix_rt::test] #[actix_rt::test]
async fn test_response() { async fn test_response() {
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::default() let cors = Cors::default()
.allow_any_origin() .allow_any_origin()
.send_wildcard() .send_wildcard()
.disable_preflight() .disable_preflight()
@ -250,10 +252,11 @@ async fn test_response() {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"*"[..]), Some(&b"*"[..]),
resp.headers() resp.headers()
@ -283,7 +286,7 @@ async fn test_response() {
} }
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::default() let cors = Cors::default()
.allow_any_origin() .allow_any_origin()
.send_wildcard() .send_wildcard()
.disable_preflight() .disable_preflight()
@ -294,22 +297,25 @@ async fn test_response() {
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.new_transform(fn_service(|req: ServiceRequest| { .new_transform(fn_service(|req: ServiceRequest| {
ok(req.into_response({ ok(req.into_response({
HttpResponse::Ok().header(header::VARY, "Accept").finish() HttpResponse::Ok()
.insert_header((header::VARY, "Accept"))
.finish()
})) }))
})) }))
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"Accept, Origin"[..]), Some(&b"Accept, Origin"[..]),
resp.headers().get(header::VARY).map(HeaderValue::as_bytes) resp.headers().get(header::VARY).map(HeaderValue::as_bytes)
); );
let mut cors = Cors::default() let cors = Cors::default()
.disable_vary_header() .disable_vary_header()
.allowed_methods(vec!["POST"]) .allowed_methods(vec!["POST"])
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
@ -318,11 +324,12 @@ async fn test_response() {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS) .method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
let origins_str = resp let origins_str = resp
.headers() .headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
@ -332,39 +339,40 @@ async fn test_response() {
#[actix_rt::test] #[actix_rt::test]
async fn test_validate_origin() { async fn test_validate_origin() {
let mut cors = Cors::default() let cors = Cors::default()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.new_transform(test::ok_service()) .new_transform(test::ok_service())
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::get()
.method(Method::GET) .insert_header(("Origin", "https://www.example.com"))
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[actix_rt::test]
async fn test_no_origin_response() { async fn test_no_origin_response() {
let mut cors = Cors::permissive() let cors = Cors::permissive()
.disable_preflight() .disable_preflight()
.new_transform(test::ok_service()) .new_transform(test::ok_service())
.await .await
.unwrap(); .unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request(); let req = TestRequest::default().method(Method::GET).to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert!(resp assert!(resp
.headers() .headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none()); .is_none());
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
Some(&b"https://www.example.com"[..]), Some(&b"https://www.example.com"[..]),
resp.headers() resp.headers()
@ -375,21 +383,22 @@ async fn test_no_origin_response() {
#[actix_rt::test] #[actix_rt::test]
async fn validate_origin_allows_all_origins() { async fn validate_origin_allows_all_origins() {
let mut cors = Cors::permissive() let cors = Cors::permissive()
.new_transform(test::ok_service()) .new_transform(test::ok_service())
.await .await
.unwrap(); .unwrap();
let req = let req = TestRequest::default()
TestRequest::with_header("Origin", "https://www.example.com").to_srv_request(); .insert_header(("Origin", "https://www.example.com"))
.to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[actix_rt::test]
async fn test_allow_any_origin_any_method_any_header() { async fn test_allow_any_origin_any_method_any_header() {
let mut cors = Cors::default() let cors = Cors::default()
.allow_any_origin() .allow_any_origin()
.allow_any_method() .allow_any_method()
.allow_any_header() .allow_any_header()
@ -397,12 +406,13 @@ async fn test_allow_any_origin_any_method_any_header() {
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") let req = TestRequest::default()
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type") .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
.header(header::ORIGIN, "https://www.example.com") .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type"))
.insert_header((header::ORIGIN, "https://www.example.com"))
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }

View File

@ -1,6 +1,7 @@
# Changes # Changes
## Unreleased - 2020-xx-xx ## Unreleased - 2020-xx-xx
* Minimum supported Rust version (MSRV) is now 1.46.0.
## 0.3.1 - 2020-09-20 ## 0.3.1 - 2020-09-20

View File

@ -16,13 +16,13 @@ name = "actix_identity"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "3.0.0", default-features = false, features = ["secure-cookies"] } actix-web = { version = "4.0.0-beta.4", default-features = false, features = ["secure-cookies"] }
actix-service = "1.0.6" actix-service = "2.0.0-beta.5"
futures-util = { version = "0.3.4", default-features = false } futures-util = { version = "0.3", default-features = false }
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = { version = "0.2.7", default-features = false, features = ["std"] } time = { version = "0.2.7", default-features = false, features = ["std"] }
[dev-dependencies] [dev-dependencies]
actix-rt = "1.1.1" actix-rt = "2"
actix-http = "2.0.0" actix-http = "3.0.0-beta.4"

View File

@ -223,15 +223,13 @@ impl<T> IdentityService<T> {
} }
} }
impl<S, T, B> Transform<S> for IdentityService<T> impl<S, T, B> Transform<S, ServiceRequest> for IdentityService<T>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error> S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+ 'static,
S::Future: 'static, S::Future: 'static,
T: IdentityPolicy, T: IdentityPolicy,
B: 'static, B: 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
@ -261,24 +259,22 @@ impl<S, T> Clone for IdentityServiceMiddleware<S, T> {
} }
} }
impl<S, T, B> Service for IdentityServiceMiddleware<S, T> impl<S, T, B> Service<ServiceRequest> for IdentityServiceMiddleware<S, T>
where where
B: 'static, B: 'static,
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error> S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+ 'static,
S::Future: 'static, S::Future: 'static,
T: IdentityPolicy, T: IdentityPolicy,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.borrow_mut().poll_ready(cx) self.service.borrow_mut().poll_ready(cx)
} }
fn call(&mut self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
let srv = self.service.clone(); let srv = self.service.clone();
let backend = self.backend.clone(); let backend = self.backend.clone();
let fut = self.backend.from_request(&mut req); let fut = self.backend.from_request(&mut req);
@ -637,7 +633,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity() { async fn test_identity() {
let mut srv = test::init_service( let srv = test::init_service(
App::new() App::new()
.wrap(IdentityService::new( .wrap(IdentityService::new(
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
@ -668,18 +664,16 @@ mod tests {
) )
.await; .await;
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()) test::call_service(&srv, TestRequest::with_uri("/index").to_request()).await;
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) test::call_service(&srv, TestRequest::with_uri("/login").to_request()).await;
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let c = resp.response().cookies().next().unwrap().to_owned(); let c = resp.response().cookies().next().unwrap().to_owned();
let resp = test::call_service( let resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/index") TestRequest::with_uri("/index")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
@ -688,7 +682,7 @@ mod tests {
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
let resp = test::call_service( let resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/logout") TestRequest::with_uri("/logout")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
@ -701,7 +695,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_max_age_time() { async fn test_identity_max_age_time() {
let duration = Duration::days(1); let duration = Duration::days(1);
let mut srv = test::init_service( let srv = test::init_service(
App::new() App::new()
.wrap(IdentityService::new( .wrap(IdentityService::new(
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
@ -718,8 +712,7 @@ mod tests {
) )
.await; .await;
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) test::call_service(&srv, TestRequest::with_uri("/login").to_request()).await;
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE)); assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned(); let c = resp.response().cookies().next().unwrap().to_owned();
@ -728,7 +721,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_http_only_same_site() { async fn test_http_only_same_site() {
let mut srv = test::init_service( let srv = test::init_service(
App::new() App::new()
.wrap(IdentityService::new( .wrap(IdentityService::new(
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
@ -746,8 +739,7 @@ mod tests {
.await; .await;
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) test::call_service(&srv, TestRequest::with_uri("/login").to_request()).await;
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE)); assert!(resp.headers().contains_key(header::SET_COOKIE));
@ -760,7 +752,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_max_age() { async fn test_identity_max_age() {
let seconds = 60; let seconds = 60;
let mut srv = test::init_service( let srv = test::init_service(
App::new() App::new()
.wrap(IdentityService::new( .wrap(IdentityService::new(
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
@ -777,8 +769,7 @@ mod tests {
) )
.await; .await;
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) test::call_service(&srv, TestRequest::with_uri("/login").to_request()).await;
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE)); assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned(); let c = resp.response().cookies().next().unwrap().to_owned();
@ -790,7 +781,7 @@ mod tests {
>( >(
f: F, f: F,
) -> impl actix_service::Service< ) -> impl actix_service::Service<
Request = actix_http::Request, actix_http::Request,
Response = ServiceResponse<actix_web::body::Body>, Response = ServiceResponse<actix_web::body::Body>,
Error = Error, Error = Error,
> { > {
@ -925,19 +916,19 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_legacy_cookie_is_set() { async fn test_identity_legacy_cookie_is_set() {
let mut srv = create_identity_server(|c| c).await; let srv = create_identity_server(|c| c).await;
let mut resp = let mut resp =
test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await; test::call_service(&srv, TestRequest::with_uri("/").to_request()).await;
assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN);
assert_logged_in(resp, None).await; assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_legacy_cookie_works() { async fn test_identity_legacy_cookie_works() {
let mut srv = create_identity_server(|c| c).await; let srv = create_identity_server(|c| c).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN); let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -949,11 +940,10 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { async fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = let srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN); let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -970,11 +960,10 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { async fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() {
let mut srv = let srv = create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN); let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -991,11 +980,10 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_cookie_rejected_if_login_timestamp_needed() { async fn test_identity_cookie_rejected_if_login_timestamp_needed() {
let mut srv = let srv = create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now()));
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -1012,11 +1000,10 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_cookie_rejected_if_visit_timestamp_needed() { async fn test_identity_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = let srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -1033,15 +1020,14 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_cookie_rejected_if_login_timestamp_too_old() { async fn test_identity_cookie_rejected_if_login_timestamp_too_old() {
let mut srv = let srv = create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie( let cookie = login_cookie(
COOKIE_LOGIN, COOKIE_LOGIN,
Some(SystemTime::now() - Duration::days(180)), Some(SystemTime::now() - Duration::days(180)),
None, None,
); );
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -1058,15 +1044,14 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { async fn test_identity_cookie_rejected_if_visit_timestamp_too_old() {
let mut srv = let srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = login_cookie( let cookie = login_cookie(
COOKIE_LOGIN, COOKIE_LOGIN,
None, None,
Some(SystemTime::now() - Duration::days(180)), Some(SystemTime::now() - Duration::days(180)),
); );
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -1083,11 +1068,10 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_cookie_not_updated_on_login_deadline() { async fn test_identity_cookie_not_updated_on_login_deadline() {
let mut srv = let srv = create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -1100,7 +1084,7 @@ mod tests {
// https://github.com/actix/actix-web/issues/1263 // https://github.com/actix/actix-web/issues/1263
#[actix_rt::test] #[actix_rt::test]
async fn test_identity_cookie_updated_on_visit_deadline() { async fn test_identity_cookie_updated_on_visit_deadline() {
let mut srv = create_identity_server(|c| { let srv = create_identity_server(|c| {
c.visit_deadline(Duration::days(90)) c.visit_deadline(Duration::days(90))
.login_deadline(Duration::days(90)) .login_deadline(Duration::days(90))
}) })
@ -1108,7 +1092,7 @@ mod tests {
let timestamp = SystemTime::now() - Duration::days(1); let timestamp = SystemTime::now() - Duration::days(1);
let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp));
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
@ -1146,22 +1130,22 @@ mod tests {
} }
} }
let mut srv = IdentityServiceMiddleware { let srv = IdentityServiceMiddleware {
backend: Rc::new(Ident), backend: Rc::new(Ident),
service: Rc::new(RefCell::new(into_service( service: Rc::new(RefCell::new(into_service(
|_: ServiceRequest| async move { |_: ServiceRequest| async move {
actix_rt::time::delay_for(std::time::Duration::from_secs(100)).await; actix_rt::time::sleep(std::time::Duration::from_secs(100)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error")) Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}, },
))), ))),
}; };
let mut srv2 = srv.clone(); let srv2 = srv.clone();
let req = TestRequest::default().to_srv_request(); let req = TestRequest::default().to_srv_request();
actix_rt::spawn(async move { actix_rt::spawn(async move {
let _ = srv2.call(req).await; let _ = srv2.call(req).await;
}); });
actix_rt::time::delay_for(std::time::Duration::from_millis(50)).await; actix_rt::time::sleep(std::time::Duration::from_millis(50)).await;
let _ = lazy(|cx| srv.poll_ready(cx)).await; let _ = lazy(|cx| srv.poll_ready(cx)).await;
} }

View File

@ -1,6 +1,7 @@
# Changes # Changes
## Unreleased - 2020-xx-xx ## Unreleased - 2020-xx-xx
* Minimum supported Rust version (MSRV) is now 1.46.0.
## 0.6.0 - 2020-09-11 ## 0.6.0 - 2020-09-11

View File

@ -19,11 +19,11 @@ name = "actix_protobuf"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "3.0.0", default_features = false } actix-web = { version = "4.0.0-beta.4", default_features = false }
actix-rt = "1.1.1" actix-rt = "2"
futures-util = { version = "0.3.5", default-features = false } futures-util = { version = "0.3.5", default-features = false }
derive_more = "0.99" derive_more = "0.99"
prost = "0.6.0" prost = "0.7"
[dev-dependencies] [dev-dependencies]
prost-derive = "0.6.0" prost-derive = "0.7"

View File

@ -8,9 +8,9 @@ authors = [
] ]
[dependencies] [dependencies]
actix-web = "3.0.0" actix-web = "4.0.0-beta.4"
actix-protobuf = { path = "../../" } actix-protobuf = { path = "../../" }
env_logger = "0.7" env_logger = "0.8"
prost = "0.6.0" prost = "0.7"
prost-derive = "0.6.0" prost-derive = "0.7"

View File

@ -17,7 +17,7 @@ use actix_web::error::{Error, PayloadError, ResponseError};
use actix_web::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; use actix_web::http::header::{CONTENT_LENGTH, CONTENT_TYPE};
use actix_web::web::BytesMut; use actix_web::web::BytesMut;
use actix_web::{FromRequest, HttpMessage, HttpRequest, HttpResponse, Responder}; use actix_web::{FromRequest, HttpMessage, HttpRequest, HttpResponse, Responder};
use futures_util::future::{ready, FutureExt, LocalBoxFuture, Ready}; use futures_util::future::{FutureExt, LocalBoxFuture};
use futures_util::StreamExt; use futures_util::StreamExt;
#[derive(Debug, Display)] #[derive(Debug, Display)]
@ -137,21 +137,16 @@ where
} }
impl<T: Message + Default> Responder for ProtoBuf<T> { impl<T: Message + Default> Responder for ProtoBuf<T> {
type Error = Error; fn respond_to(self, _: &HttpRequest) -> HttpResponse {
type Future = Ready<Result<HttpResponse, Error>>;
fn respond_to(self, _: &HttpRequest) -> Self::Future {
let mut buf = Vec::new(); let mut buf = Vec::new();
ready( match self.0.encode(&mut buf) {
self.0 Ok(()) => HttpResponse::Ok()
.encode(&mut buf) .content_type("application/protobuf")
.map_err(|e| Error::from(ProtoBufPayloadError::Serialize(e))) .body(buf),
.map(|()| { Err(err) => HttpResponse::from_error(Error::from(
HttpResponse::Ok() ProtoBufPayloadError::Serialize(err),
.content_type("application/protobuf") )),
.body(buf) }
}),
)
} }
} }
@ -255,7 +250,7 @@ pub trait ProtoBufResponseBuilder {
impl ProtoBufResponseBuilder for HttpResponseBuilder { impl ProtoBufResponseBuilder for HttpResponseBuilder {
fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error> { fn protobuf<T: Message>(&mut self, value: T) -> Result<HttpResponse, Error> {
self.header(CONTENT_TYPE, "application/protobuf"); self.append_header((CONTENT_TYPE, "application/protobuf"));
let mut body = Vec::new(); let mut body = Vec::new();
value value
@ -313,16 +308,16 @@ mod tests {
let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await; let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await;
assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType); assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType);
let (req, mut pl) = let (req, mut pl) = TestRequest::get()
TestRequest::with_header(header::CONTENT_TYPE, "application/text") .append_header((header::CONTENT_TYPE, "application/text"))
.to_http_parts(); .to_http_parts();
let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await; let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl).await;
assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType); assert_eq!(protobuf.err().unwrap(), ProtoBufPayloadError::ContentType);
let (req, mut pl) = let (req, mut pl) = TestRequest::get()
TestRequest::with_header(header::CONTENT_TYPE, "application/protobuf") .append_header((header::CONTENT_TYPE, "application/protobuf"))
.header(header::CONTENT_LENGTH, "10000") .append_header((header::CONTENT_LENGTH, "10000"))
.to_http_parts(); .to_http_parts();
let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl) let protobuf = ProtoBufMessage::<MyObject>::new(&req, &mut pl)
.limit(100) .limit(100)
.await; .await;

View File

@ -22,7 +22,6 @@ default = ["web"]
# actix-web integration # actix-web integration
web = [ web = [
"actix-http/actors",
"actix-service", "actix-service",
"actix-web", "actix-web",
"actix-session/cookie-session", "actix-session/cookie-session",
@ -32,28 +31,27 @@ web = [
] ]
[dependencies] [dependencies]
actix = "0.10.0"
actix-utils = "2.0.0"
log = "0.4.6" log = "0.4.6"
backoff = "0.2.1"
derive_more = "0.99.2" derive_more = "0.99.2"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
redis-async = "0.6.3" futures-channel = { version = "0.3.5", default-features = false }
actix-rt = "1.1.1" redis-async = "0.9"
time = "0.2.23" time = "0.2.23"
tokio = "0.2.6" actix-rt = "2"
tokio-util = "0.3.0" tokio = "1"
tokio-util = "0.6"
trust-dns-resolver = { version = "0.20.0", default-features = false, features = ["tokio-runtime", "system-config"] }
# actix-session # actix-session
actix-web = { version = "3.0.0", default_features = false, optional = true } actix-web = { version = "4.0.0-beta.4", default_features = false, optional = true }
actix-http = { version = "2.0.0", optional = true } actix-http = { version = "3.0.0-beta.4", optional = true }
actix-service = { version = "1.0.6", optional = true } actix-service = { version = "2.0.0-beta.5", optional = true }
actix-session = { version = "0.4.0", optional = true } actix-session = { version = "0.4.0", optional = true }
rand = { version = "0.7.0", optional = true } rand = { version = "0.8", optional = true }
serde = { version = "1.0.101", optional = true } serde = { version = "1.0.101", optional = true }
serde_json = { version = "1.0.40", optional = true } serde_json = { version = "1.0.40", optional = true }
[dev-dependencies] [dev-dependencies]
env_logger = "0.7" env_logger = "0.8"
serde_derive = "1.0" serde_derive = "1.0"

View File

@ -19,7 +19,7 @@ struct User {
} }
impl User { impl User {
fn authenticate(credentials: Credentials) -> Result<Self, actix_http::Response> { fn authenticate(credentials: Credentials) -> Result<Self, HttpResponse> {
// TODO: figure out why I keep getting hacked // TODO: figure out why I keep getting hacked
if &credentials.password != "hunter2" { if &credentials.password != "hunter2" {
return Err(HttpResponse::Unauthorized().json("Unauthorized")); return Err(HttpResponse::Unauthorized().json("Unauthorized"));
@ -33,7 +33,7 @@ impl User {
} }
} }
pub fn validate_session(session: &Session) -> Result<i64, actix_http::Response> { pub fn validate_session(session: &Session) -> Result<i64, HttpResponse> {
let user_id: Option<i64> = session.get("user_id").unwrap_or(None); let user_id: Option<i64> = session.get("user_id").unwrap_or(None);
match user_id { match user_id {
@ -49,7 +49,7 @@ pub fn validate_session(session: &Session) -> Result<i64, actix_http::Response>
async fn login( async fn login(
credentials: web::Json<Credentials>, credentials: web::Json<Credentials>,
session: Session, session: Session,
) -> Result<impl Responder, actix_http::Response> { ) -> Result<impl Responder, HttpResponse> {
let credentials = credentials.into_inner(); let credentials = credentials.into_inner();
match User::authenticate(credentials) { match User::authenticate(credentials) {

View File

@ -3,7 +3,7 @@
#![deny(rust_2018_idioms)] #![deny(rust_2018_idioms)]
mod redis; mod redis;
pub use redis::{Command, RedisActor}; pub use redis::RedisClient;
use derive_more::{Display, Error, From}; use derive_more::{Display, Error, From};
@ -25,6 +25,12 @@ pub enum Error {
/// Cancel all waters when connection get dropped /// Cancel all waters when connection get dropped
#[display(fmt = "Redis: Disconnected")] #[display(fmt = "Redis: Disconnected")]
Disconnected, Disconnected,
/// Invalid address
#[display(fmt = "Redis: Invalid address")]
InvalidAddress,
/// DNS resolve error
#[display(fmt = "Redis: DNS resolve error")]
ResolveError,
} }
#[cfg(feature = "web")] #[cfg(feature = "web")]

View File

@ -1,150 +1,98 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::io; use std::net::SocketAddr;
use actix::actors::resolver::{Connect, Resolver}; use redis_async::client::{paired_connect, PairedConnection};
use actix::prelude::*; use redis_async::resp::RespValue;
use actix_utils::oneshot; use tokio::sync::Mutex;
use backoff::backoff::Backoff; use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use backoff::ExponentialBackoff; use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
use futures_util::FutureExt;
use log::{error, info, warn};
use redis_async::error::Error as RespError;
use redis_async::resp::{RespCodec, RespValue};
use tokio::io::{split, WriteHalf};
use tokio::net::TcpStream;
use tokio_util::codec::FramedRead;
use crate::Error; use crate::Error;
/// Command for send data to Redis pub struct RedisClient {
#[derive(Debug)]
pub struct Command(pub RespValue);
impl Message for Command {
type Result = Result<RespValue, Error>;
}
/// Redis comminucation actor
pub struct RedisActor {
addr: String, addr: String,
backoff: ExponentialBackoff, connection: Mutex<Option<PairedConnection>>,
cell: Option<actix::io::FramedWrite<RespValue, WriteHalf<TcpStream>, RespCodec>>,
queue: VecDeque<oneshot::Sender<Result<RespValue, Error>>>,
} }
impl RedisActor { impl RedisClient {
/// Start new `Supervisor` with `RedisActor`. pub fn new(addr: impl Into<String>) -> Self {
pub fn start<S: Into<String>>(addr: S) -> Addr<RedisActor> { Self {
let addr = addr.into(); addr: addr.into(),
connection: Mutex::new(None),
}
}
let backoff = ExponentialBackoff { async fn get_connection(&self) -> Result<PairedConnection, Error> {
max_elapsed_time: None, let mut connection = self.connection.lock().await;
..Default::default() if let Some(ref connection) = *connection {
}; return Ok(connection.clone());
}
Supervisor::start(|_| RedisActor { let mut addrs = resolve(&self.addr).await?;
addr, loop {
cell: None, // try to connect
backoff, let socket_addr = addrs.pop_front().ok_or_else(|| {
queue: VecDeque::new(), log::warn!("Cannot connect to {}.", self.addr);
Error::NotConnected
})?;
match paired_connect(socket_addr).await {
Ok(conn) => {
*connection = Some(conn.clone());
return Ok(conn);
}
Err(err) => log::warn!(
"Attempt to connect to {} as {} failed: {}.",
self.addr,
socket_addr,
err
),
}
}
}
pub async fn send(&self, req: RespValue) -> Result<RespValue, Error> {
let res = self.get_connection().await?.send(req).await?;
Ok(res)
}
}
fn parse_addr(addr: &str, default_port: u16) -> Option<(&str, u16)> {
// split the string by ':' and convert the second part to u16
let mut parts_iter = addr.splitn(2, ':');
let host = parts_iter.next()?;
let port_str = parts_iter.next().unwrap_or("");
let port: u16 = port_str.parse().unwrap_or(default_port);
Some((host, port))
}
async fn resolve(addr: &str) -> Result<VecDeque<SocketAddr>, Error> {
// try to parse as a regular SocketAddr first
if let Ok(addr) = addr.parse::<SocketAddr>() {
let mut addrs = VecDeque::new();
addrs.push_back(addr);
return Ok(addrs);
}
let (host, port) = parse_addr(addr, 6379).ok_or(Error::InvalidAddress)?;
// we need to do dns resolution
let resolver = AsyncResolver::tokio_from_system_conf()
.or_else(|err| {
log::warn!("Cannot create system DNS resolver: {}", err);
AsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default())
}) })
} .map_err(|err| {
} log::error!("Cannot create DNS resolver: {}", err);
Error::ResolveError
impl Actor for RedisActor { })?;
type Context = Context<Self>;
let addrs = resolver
fn started(&mut self, ctx: &mut Context<Self>) { .lookup_ip(host)
Resolver::from_registry() .await
.send(Connect::host(self.addr.as_str())) .map_err(|_| Error::ResolveError)?
.into_actor(self) .into_iter()
.map(|res, act, ctx| match res { .map(|ip| SocketAddr::new(ip, port))
Ok(res) => match res { .collect();
Ok(stream) => {
info!("Connected to redis server: {}", act.addr); Ok(addrs)
let (r, w) = split(stream);
// configure write side of the connection
let framed = actix::io::FramedWrite::new(w, RespCodec, ctx);
act.cell = Some(framed);
// read side of the connection
ctx.add_stream(FramedRead::new(r, RespCodec));
act.backoff.reset();
}
Err(err) => {
error!("Can not connect to redis server: {}", err);
// re-connect with backoff time.
// we stop current context, supervisor will restart it.
if let Some(timeout) = act.backoff.next_backoff() {
ctx.run_later(timeout, |_, ctx| ctx.stop());
}
}
},
Err(err) => {
error!("Can not connect to redis server: {}", err);
// re-connect with backoff time.
// we stop current context, supervisor will restart it.
if let Some(timeout) = act.backoff.next_backoff() {
ctx.run_later(timeout, |_, ctx| ctx.stop());
}
}
})
.wait(ctx);
}
}
impl Supervised for RedisActor {
fn restarting(&mut self, _: &mut Self::Context) {
self.cell.take();
for tx in self.queue.drain(..) {
let _ = tx.send(Err(Error::Disconnected));
}
}
}
impl actix::io::WriteHandler<io::Error> for RedisActor {
fn error(&mut self, err: io::Error, _: &mut Self::Context) -> Running {
warn!("Redis connection dropped: {} error: {}", self.addr, err);
Running::Stop
}
}
impl StreamHandler<Result<RespValue, RespError>> for RedisActor {
fn handle(&mut self, msg: Result<RespValue, RespError>, ctx: &mut Self::Context) {
match msg {
Err(e) => {
if let Some(tx) = self.queue.pop_front() {
let _ = tx.send(Err(e.into()));
}
ctx.stop();
}
Ok(val) => {
if let Some(tx) = self.queue.pop_front() {
let _ = tx.send(Ok(val));
}
}
}
}
}
impl Handler<Command> for RedisActor {
type Result = ResponseFuture<Result<RespValue, Error>>;
fn handle(&mut self, msg: Command, _: &mut Self::Context) -> Self::Result {
let (tx, rx) = oneshot::channel();
if let Some(ref mut cell) = self.cell {
self.queue.push_back(tx);
cell.write(msg.0);
} else {
let _ = tx.send(Err(Error::NotConnected));
}
Box::pin(rx.map(|res| match res {
Ok(res) => res,
Err(_) => Err(Error::Disconnected),
}))
}
} }

View File

@ -3,7 +3,6 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{collections::HashMap, iter, rc::Rc}; use std::{collections::HashMap, iter, rc::Rc};
use actix::prelude::*;
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use actix_session::{Session, SessionStatus}; use actix_session::{Session, SessionStatus};
use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
@ -16,7 +15,7 @@ use redis_async::resp::RespValue;
use redis_async::resp_array; use redis_async::resp_array;
use time::{self, Duration, OffsetDateTime}; use time::{self, Duration, OffsetDateTime};
use crate::redis::{Command, RedisActor}; use crate::redis::RedisClient;
/// Use redis as session storage. /// Use redis as session storage.
/// ///
@ -36,7 +35,7 @@ impl RedisSession {
key: Key::derive_from(key), key: Key::derive_from(key),
cache_keygen: Box::new(|key: &str| format!("session:{}", &key)), cache_keygen: Box::new(|key: &str| format!("session:{}", &key)),
ttl: "7200".to_owned(), ttl: "7200".to_owned(),
addr: RedisActor::start(addr), redis_client: RedisClient::new(addr),
name: "actix-session".to_owned(), name: "actix-session".to_owned(),
path: "/".to_owned(), path: "/".to_owned(),
domain: None, domain: None,
@ -113,14 +112,12 @@ impl RedisSession {
} }
} }
impl<S, B> Transform<S> for RedisSession impl<S, B> Transform<S, ServiceRequest> for RedisSession
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error> S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+ 'static,
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = S::Error; type Error = S::Error;
type InitError = (); type InitError = ();
@ -141,25 +138,23 @@ pub struct RedisSessionMiddleware<S: 'static> {
inner: Rc<Inner>, inner: Rc<Inner>,
} }
impl<S, B> Service for RedisSessionMiddleware<S> impl<S, B> Service<ServiceRequest> for RedisSessionMiddleware<S>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error> S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+ 'static,
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.borrow_mut().poll_ready(cx) self.service.borrow_mut().poll_ready(cx)
} }
fn call(&mut self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
let mut srv = self.service.clone(); let srv = self.service.clone();
let inner = self.inner.clone(); let inner = self.inner.clone();
Box::pin(async move { Box::pin(async move {
@ -215,7 +210,7 @@ struct Inner {
key: Key, key: Key,
cache_keygen: Box<dyn Fn(&str) -> String>, cache_keygen: Box<dyn Fn(&str) -> String>,
ttl: String, ttl: String,
addr: Addr<RedisActor>, redis_client: RedisClient,
name: String, name: String,
path: String, path: String,
domain: Option<String>, domain: Option<String>,
@ -256,13 +251,11 @@ impl Inner {
} }
}; };
let res = self let val = self
.addr .redis_client
.send(Command(resp_array!["GET", cache_key])) .send(resp_array!["GET", cache_key])
.await?; .await?;
let val = res.map_err(error::ErrorInternalServerError)?;
match val { match val {
RespValue::Error(err) => { RespValue::Error(err) => {
return Err(error::ErrorInternalServerError(err)); return Err(error::ErrorInternalServerError(err));
@ -294,6 +287,7 @@ impl Inner {
} else { } else {
let value: String = iter::repeat(()) let value: String = iter::repeat(())
.map(|()| OsRng.sample(Alphanumeric)) .map(|()| OsRng.sample(Alphanumeric))
.map(char::from)
.take(32) .take(32)
.collect(); .collect();
@ -331,12 +325,9 @@ impl Inner {
Ok(body) => body, Ok(body) => body,
}; };
let cmd = Command(resp_array!["SET", cache_key, body, "EX", &self.ttl]); self.redis_client
.send(resp_array!["SET", cache_key, body, "EX", &self.ttl])
self.addr .await?;
.send(cmd)
.await?
.map_err(error::ErrorInternalServerError)?;
if let Some(jar) = jar { if let Some(jar) = jar {
for cookie in jar.delta() { for cookie in jar.delta() {
@ -352,17 +343,16 @@ impl Inner {
async fn clear_cache(&self, key: String) -> Result<(), Error> { async fn clear_cache(&self, key: String) -> Result<(), Error> {
let cache_key = (self.cache_keygen)(&key); let cache_key = (self.cache_keygen)(&key);
match self.addr.send(Command(resp_array!["DEL", cache_key])).await { match self
Err(e) => Err(Error::from(e)), .redis_client
Ok(res) => { .send(resp_array!["DEL", cache_key])
match res { .await?
// redis responds with number of deleted records {
Ok(RespValue::Integer(x)) if x > 0 => Ok(()), // redis responds with number of deleted records
_ => Err(error::ErrorInternalServerError( RespValue::Integer(x) if x > 0 => Ok(()),
"failed to remove session from cache", _ => Err(error::ErrorInternalServerError(
)), "failed to remove session from cache",
} )),
}
} }
} }
@ -406,7 +396,7 @@ mod test {
.unwrap_or(Some(0)) .unwrap_or(Some(0))
.unwrap_or(0); .unwrap_or(0);
Ok(HttpResponse::Ok().json(IndexResponse { user_id, counter })) Ok(HttpResponse::Ok().json(&IndexResponse { user_id, counter }))
} }
async fn do_something(session: Session) -> Result<HttpResponse> { async fn do_something(session: Session) -> Result<HttpResponse> {
@ -417,7 +407,7 @@ mod test {
.map_or(1, |inner| inner + 1); .map_or(1, |inner| inner + 1);
session.set("counter", counter)?; session.set("counter", counter)?;
Ok(HttpResponse::Ok().json(IndexResponse { user_id, counter })) Ok(HttpResponse::Ok().json(&IndexResponse { user_id, counter }))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -438,7 +428,7 @@ mod test {
.unwrap_or(Some(0)) .unwrap_or(Some(0))
.unwrap_or(0); .unwrap_or(0);
Ok(HttpResponse::Ok().json(IndexResponse { Ok(HttpResponse::Ok().json(&IndexResponse {
user_id: Some(id), user_id: Some(id),
counter, counter,
})) }))

View File

@ -1,42 +1,31 @@
#[macro_use] #[macro_use]
extern crate redis_async; extern crate redis_async;
use actix_redis::{Command, Error, RedisActor, RespValue}; use actix_redis::{Error, RedisClient, RespValue};
#[actix_rt::test] #[actix_rt::test]
async fn test_error_connect() { async fn test_error_connect() {
let addr = RedisActor::start("localhost:54000"); let addr = RedisClient::new("localhost:54000");
let _addr2 = addr.clone();
let res = addr.send(Command(resp_array!["GET", "test"])).await; let res = addr.send(resp_array!["GET", "test"]).await;
match res { match res {
Ok(Err(Error::NotConnected)) => (), Err(Error::NotConnected) => (),
_ => panic!("Should not happen {:?}", res), _ => panic!("Should not happen {:?}", res),
} }
} }
#[actix_rt::test] #[actix_rt::test]
async fn test_redis() { async fn test_redis() -> Result<(), Error> {
env_logger::init(); env_logger::init();
let addr = RedisActor::start("127.0.0.1:6379"); let addr = RedisClient::new("127.0.0.1:6379");
let res = addr
.send(Command(resp_array!["SET", "test", "value"]))
.await;
match res { let resp = addr.send(resp_array!["SET", "test", "value"]).await?;
Ok(Ok(resp)) => {
assert_eq!(resp, RespValue::SimpleString("OK".to_owned()));
let res = addr.send(Command(resp_array!["GET", "test"])).await; assert_eq!(resp, RespValue::SimpleString("OK".to_owned()));
match res {
Ok(Ok(resp)) => { let resp = addr.send(resp_array!["GET", "test"]).await?;
println!("RESP: {:?}", resp); println!("RESP: {:?}", resp);
assert_eq!(resp, RespValue::BulkString((&b"value"[..]).into())); assert_eq!(resp, RespValue::BulkString((&b"value"[..]).into()));
} Ok(())
_ => panic!("Should not happen {:?}", res),
}
}
_ => panic!("Should not happen {:?}", res),
}
} }

View File

@ -6,6 +6,7 @@
## 0.4.1 - 2020-03-21 ## 0.4.1 - 2020-03-21
* `Session::set_session` takes a `IntoIterator` instead of `Iterator`. [#105] * `Session::set_session` takes a `IntoIterator` instead of `Iterator`. [#105]
* Fix calls to `session.purge()` from paths other than the one specified in the cookie. [#129] * Fix calls to `session.purge()` from paths other than the one specified in the cookie. [#129]
* Minimum supported Rust version (MSRV) is now 1.46.0.
[#105]: https://github.com/actix/actix-extras/pull/105 [#105]: https://github.com/actix/actix-extras/pull/105
[#129]: https://github.com/actix/actix-extras/pull/129 [#129]: https://github.com/actix/actix-extras/pull/129

View File

@ -20,8 +20,8 @@ default = ["cookie-session"]
cookie-session = ["actix-web/secure-cookies"] cookie-session = ["actix-web/secure-cookies"]
[dependencies] [dependencies]
actix-web = { version = "3.0.0", default_features = false } actix-web = { version = "4.0.0-beta.4", default_features = false, features = ["cookies"] }
actix-service = "1.0.6" actix-service = "2.0.0-beta.5"
derive_more = "0.99.2" derive_more = "0.99.2"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
serde = "1.0" serde = "1.0"
@ -29,4 +29,4 @@ serde_json = "1.0"
time = { version = "0.2.23", default-features = false, features = ["std"] } time = { version = "0.2.23", default-features = false, features = ["std"] }
[dev-dependencies] [dev-dependencies]
actix-rt = "1" actix-rt = "2"

View File

@ -9,7 +9,7 @@ use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
use actix_web::dev::{ServiceRequest, ServiceResponse}; use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::http::{header::SET_COOKIE, HeaderValue}; use actix_web::http::{header::SET_COOKIE, HeaderValue};
use actix_web::{Error, HttpMessage, ResponseError}; use actix_web::{Error, HttpMessage, ResponseError};
use derive_more::{Display, From}; use derive_more::Display;
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready};
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
use time::{Duration, OffsetDateTime}; use time::{Duration, OffsetDateTime};
@ -17,7 +17,7 @@ use time::{Duration, OffsetDateTime};
use crate::{Session, SessionStatus}; use crate::{Session, SessionStatus};
/// Errors that can occur during handling cookie session /// Errors that can occur during handling cookie session
#[derive(Debug, From, Display)] #[derive(Debug, Display)]
pub enum CookieSessionError { pub enum CookieSessionError {
/// Size of the serialized session is greater than 4000 bytes. /// Size of the serialized session is greater than 4000 bytes.
#[display(fmt = "Size of the serialized session is greater than 4000 bytes.")] #[display(fmt = "Size of the serialized session is greater than 4000 bytes.")]
@ -290,13 +290,12 @@ impl CookieSession {
} }
} }
impl<S, B: 'static> Transform<S> for CookieSession impl<S, B: 'static> Transform<S, ServiceRequest> for CookieSession
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>>, S: Service<ServiceRequest, Response = ServiceResponse<B>>,
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = S::Error; type Error = S::Error;
type InitError = (); type InitError = ();
@ -317,18 +316,17 @@ pub struct CookieSessionMiddleware<S> {
inner: Rc<CookieSessionInner>, inner: Rc<CookieSessionInner>,
} }
impl<S, B: 'static> Service for CookieSessionMiddleware<S> impl<S, B: 'static> Service<ServiceRequest> for CookieSessionMiddleware<S>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>>, S: Service<ServiceRequest, Response = ServiceResponse<B>>,
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx) self.service.poll_ready(cx)
} }
@ -337,7 +335,7 @@ where
/// session state changes, then set-cookie is returned in response. As /// session state changes, then set-cookie is returned in response. As
/// a user logs out, call session.purge() to set SessionStatus accordingly /// a user logs out, call session.purge() to set SessionStatus accordingly
/// and this will trigger removal of the session cookie in the response. /// and this will trigger removal of the session cookie in the response.
fn call(&mut self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();
let (is_new, state) = self.inner.load(&req); let (is_new, state) = self.inner.load(&req);
let prolong_expiration = self.inner.expires_in.is_some(); let prolong_expiration = self.inner.expires_in.is_some();
@ -387,7 +385,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn cookie_session() { async fn cookie_session() {
let mut app = test::init_service( let app = test::init_service(
App::new() App::new()
.wrap(CookieSession::signed(&[0; 32]).secure(false)) .wrap(CookieSession::signed(&[0; 32]).secure(false))
.service(web::resource("/").to(|ses: Session| async move { .service(web::resource("/").to(|ses: Session| async move {
@ -407,7 +405,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn private_cookie() { async fn private_cookie() {
let mut app = test::init_service( let app = test::init_service(
App::new() App::new()
.wrap(CookieSession::private(&[0; 32]).secure(false)) .wrap(CookieSession::private(&[0; 32]).secure(false))
.service(web::resource("/").to(|ses: Session| async move { .service(web::resource("/").to(|ses: Session| async move {
@ -427,7 +425,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn lazy_cookie() { async fn lazy_cookie() {
let mut app = test::init_service( let app = test::init_service(
App::new() App::new()
.wrap(CookieSession::signed(&[0; 32]).secure(false).lazy(true)) .wrap(CookieSession::signed(&[0; 32]).secure(false).lazy(true))
.service(web::resource("/count").to(|ses: Session| async move { .service(web::resource("/count").to(|ses: Session| async move {
@ -453,7 +451,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn cookie_session_extractor() { async fn cookie_session_extractor() {
let mut app = test::init_service( let app = test::init_service(
App::new() App::new()
.wrap(CookieSession::signed(&[0; 32]).secure(false)) .wrap(CookieSession::signed(&[0; 32]).secure(false))
.service(web::resource("/").to(|ses: Session| async move { .service(web::resource("/").to(|ses: Session| async move {
@ -473,7 +471,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn basics() { async fn basics() {
let mut app = test::init_service( let app = test::init_service(
App::new() App::new()
.wrap( .wrap(
CookieSession::signed(&[0; 32]) CookieSession::signed(&[0; 32])
@ -508,13 +506,13 @@ mod tests {
let request = test::TestRequest::with_uri("/test/") let request = test::TestRequest::with_uri("/test/")
.cookie(cookie) .cookie(cookie)
.to_request(); .to_request();
let body = test::read_response(&mut app, request).await; let body = test::read_response(&app, request).await;
assert_eq!(body, Bytes::from_static(b"counter: 100")); assert_eq!(body, Bytes::from_static(b"counter: 100"));
} }
#[actix_rt::test] #[actix_rt::test]
async fn prolong_expiration() { async fn prolong_expiration() {
let mut app = test::init_service( let app = test::init_service(
App::new() App::new()
.wrap(CookieSession::signed(&[0; 32]).secure(false).expires_in(60)) .wrap(CookieSession::signed(&[0; 32]).secure(false).expires_in(60))
.service(web::resource("/").to(|ses: Session| async move { .service(web::resource("/").to(|ses: Session| async move {
@ -538,7 +536,7 @@ mod tests {
.expires() .expires()
.expect("Expiration is set"); .expect("Expiration is set");
actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await; actix_rt::time::sleep(std::time::Duration::from_secs(1)).await;
let request = test::TestRequest::with_uri("/test/").to_request(); let request = test::TestRequest::with_uri("/test/").to_request();
let response = app.call(request).await.unwrap(); let response = app.call(request).await.unwrap();

View File

@ -5,6 +5,7 @@
## 0.5.1 - 2020-03-21 ## 0.5.1 - 2020-03-21
* Correct error handling when extracting auth details from request. [#128] * Correct error handling when extracting auth details from request. [#128]
* Minimum supported Rust version (MSRV) is now 1.46.0.
[#128]: https://github.com/actix/actix-extras/pull/128 [#128]: https://github.com/actix/actix-extras/pull/128

View File

@ -20,11 +20,11 @@ name = "actix_web_httpauth"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "3.0.0", default_features = false } actix-web = { version = "4.0.0-beta.4", default_features = false }
base64 = "0.13" base64 = "0.13"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
[dev-dependencies] [dev-dependencies]
actix-cors = "0.5" actix-cors = "0.5"
actix-rt = "1.1.1" actix-rt = "2"
actix-service = "1.0.6" actix-service = "2.0.0-beta.5"

View File

@ -54,7 +54,7 @@ impl<C: 'static + Challenge> ResponseError for AuthenticationError<C> {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::build(self.status_code) HttpResponse::build(self.status_code)
// TODO: Get rid of the `.clone()` // TODO: Get rid of the `.clone()`
.set(WwwAuthenticate(self.challenge.clone())) .insert_header(WwwAuthenticate(self.challenge.clone()))
.finish() .finish()
} }

View File

@ -91,8 +91,8 @@ impl<S: Scheme> Header for Authorization<S> {
impl<S: Scheme> IntoHeaderValue for Authorization<S> { impl<S: Scheme> IntoHeaderValue for Authorization<S> {
type Error = <S as IntoHeaderValue>::Error; type Error = <S as IntoHeaderValue>::Error;
fn try_into(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> { fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
self.0.try_into() self.0.try_into_value()
} }
} }

View File

@ -97,7 +97,7 @@ impl fmt::Display for Basic {
impl IntoHeaderValue for Basic { impl IntoHeaderValue for Basic {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValue;
fn try_into(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> { fn try_into_value(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> {
let mut credentials = BytesMut::with_capacity( let mut credentials = BytesMut::with_capacity(
self.user_id.len() self.user_id.len()
+ 1 // ':' + 1 // ':'
@ -187,7 +187,7 @@ mod tests {
password: Some("open sesame".into()), password: Some("open sesame".into()),
}; };
let result = basic.try_into(); let result = basic.try_into_value();
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),

View File

@ -76,7 +76,7 @@ impl fmt::Display for Bearer {
impl IntoHeaderValue for Bearer { impl IntoHeaderValue for Bearer {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValue;
fn try_into(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> { fn try_into_value(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> {
let mut buffer = BytesMut::with_capacity(7 + self.token.len()); let mut buffer = BytesMut::with_capacity(7 + self.token.len());
buffer.put(&b"Bearer "[..]); buffer.put(&b"Bearer "[..]);
buffer.extend_from_slice(self.token.as_bytes()); buffer.extend_from_slice(self.token.as_bytes());
@ -128,7 +128,7 @@ mod tests {
fn test_into_header_value() { fn test_into_header_value() {
let bearer = Bearer::new("mF_9.B5f-4.1JqM"); let bearer = Bearer::new("mF_9.B5f-4.1JqM");
let result = bearer.try_into(); let result = bearer.try_into_value();
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),

View File

@ -25,7 +25,7 @@ use crate::utils;
/// let challenge = Basic::with_realm("Restricted area"); /// let challenge = Basic::with_realm("Restricted area");
/// ///
/// HttpResponse::Unauthorized() /// HttpResponse::Unauthorized()
/// .set(WwwAuthenticate(challenge)) /// .insert_header(WwwAuthenticate(challenge))
/// .finish() /// .finish()
/// } /// }
/// ``` /// ```
@ -106,7 +106,7 @@ impl fmt::Display for Basic {
impl IntoHeaderValue for Basic { impl IntoHeaderValue for Basic {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValue;
fn try_into(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> { fn try_into_value(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> {
HeaderValue::from_maybe_shared(self.to_bytes()) HeaderValue::from_maybe_shared(self.to_bytes())
} }
} }
@ -120,7 +120,7 @@ mod tests {
fn test_plain_into_header_value() { fn test_plain_into_header_value() {
let challenge = Basic { realm: None }; let challenge = Basic { realm: None };
let value = challenge.try_into(); let value = challenge.try_into_value();
assert!(value.is_ok()); assert!(value.is_ok());
let value = value.unwrap(); let value = value.unwrap();
assert_eq!(value, "Basic"); assert_eq!(value, "Basic");
@ -132,7 +132,7 @@ mod tests {
realm: Some("Restricted area".into()), realm: Some("Restricted area".into()),
}; };
let value = challenge.try_into(); let value = challenge.try_into_value();
assert!(value.is_ok()); assert!(value.is_ok());
let value = value.unwrap(); let value = value.unwrap();
assert_eq!(value, "Basic realm=\"Restricted area\""); assert_eq!(value, "Basic realm=\"Restricted area\"");

View File

@ -31,7 +31,7 @@ use crate::utils;
/// .finish(); /// .finish();
/// ///
/// HttpResponse::Unauthorized() /// HttpResponse::Unauthorized()
/// .set(WwwAuthenticate(challenge)) /// .insert_header(WwwAuthenticate(challenge))
/// .finish() /// .finish()
/// } /// }
/// ``` /// ```
@ -133,7 +133,7 @@ impl fmt::Display for Bearer {
impl IntoHeaderValue for Bearer { impl IntoHeaderValue for Bearer {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValue;
fn try_into(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> { fn try_into_value(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> {
HeaderValue::from_maybe_shared(self.to_bytes()) HeaderValue::from_maybe_shared(self.to_bytes())
} }
} }

View File

@ -27,7 +27,7 @@ impl<C: Challenge> Header for WwwAuthenticate<C> {
impl<C: Challenge> IntoHeaderValue for WwwAuthenticate<C> { impl<C: Challenge> IntoHeaderValue for WwwAuthenticate<C> {
type Error = <C as IntoHeaderValue>::Error; type Error = <C as IntoHeaderValue>::Error;
fn try_into(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> { fn try_into_value(self) -> Result<HeaderValue, <Self as IntoHeaderValue>::Error> {
self.0.try_into() self.0.try_into_value()
} }
} }

View File

@ -118,16 +118,14 @@ where
} }
} }
impl<S, B, T, F, O> Transform<S> for HttpAuthentication<T, F> impl<S, B, T, F, O> Transform<S, ServiceRequest> for HttpAuthentication<T, F>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error> S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+ 'static,
S::Future: 'static, S::Future: 'static,
F: Fn(ServiceRequest, T) -> O + 'static, F: Fn(ServiceRequest, T) -> O + 'static,
O: Future<Output = Result<ServiceRequest, Error>> + 'static, O: Future<Output = Result<ServiceRequest, Error>> + 'static,
T: AuthExtractor + 'static, T: AuthExtractor + 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Transform = AuthenticationMiddleware<S, F, T>; type Transform = AuthenticationMiddleware<S, F, T>;
@ -153,25 +151,23 @@ where
_extractor: PhantomData<T>, _extractor: PhantomData<T>,
} }
impl<S, B, F, T, O> Service for AuthenticationMiddleware<S, F, T> impl<S, B, F, T, O> Service<ServiceRequest> for AuthenticationMiddleware<S, F, T>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error> S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+ 'static,
S::Future: 'static, S::Future: 'static,
F: Fn(ServiceRequest, T) -> O + 'static, F: Fn(ServiceRequest, T) -> O + 'static,
O: Future<Output = Result<ServiceRequest, Error>> + 'static, O: Future<Output = Result<ServiceRequest, Error>> + 'static,
T: AuthExtractor + 'static, T: AuthExtractor + 'static,
{ {
type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>; type Future = LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>;
fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.borrow_mut().poll_ready(ctx) self.service.borrow_mut().poll_ready(ctx)
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
let process_fn = Arc::clone(&self.process_fn); let process_fn = Arc::clone(&self.process_fn);
let service = Rc::clone(&self.service); let service = Rc::clone(&self.service);
@ -251,15 +247,14 @@ mod tests {
use actix_service::{into_service, Service}; use actix_service::{into_service, Service};
use actix_web::error; use actix_web::error;
use actix_web::test::TestRequest; use actix_web::test::TestRequest;
use futures_util::join;
/// This is a test for https://github.com/actix/actix-extras/issues/10 /// This is a test for https://github.com/actix/actix-extras/issues/10
#[actix_rt::test] #[actix_rt::test]
async fn test_middleware_panic() { async fn test_middleware_panic() {
let mut middleware = AuthenticationMiddleware { let middleware = AuthenticationMiddleware {
service: Rc::new(RefCell::new(into_service( service: Rc::new(RefCell::new(into_service(
|_: ServiceRequest| async move { |_: ServiceRequest| async move {
actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await; actix_rt::time::sleep(std::time::Duration::from_secs(1)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error")) Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}, },
))), ))),
@ -267,22 +262,24 @@ mod tests {
_extractor: PhantomData, _extractor: PhantomData,
}; };
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); let req = TestRequest::get()
.append_header(("Authorization", "Bearer 1"))
.to_srv_request();
let f = middleware.call(req); let f = middleware.call(req).await;
let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)); let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;
assert!(join!(f, res).0.is_err()); assert!(f.is_err());
} }
/// This is a test for https://github.com/actix/actix-extras/issues/10 /// This is a test for https://github.com/actix/actix-extras/issues/10
#[actix_rt::test] #[actix_rt::test]
async fn test_middleware_panic_several_orders() { async fn test_middleware_panic_several_orders() {
let mut middleware = AuthenticationMiddleware { let middleware = AuthenticationMiddleware {
service: Rc::new(RefCell::new(into_service( service: Rc::new(RefCell::new(into_service(
|_: ServiceRequest| async move { |_: ServiceRequest| async move {
actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await; actix_rt::time::sleep(std::time::Duration::from_secs(1)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error")) Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}, },
))), ))),
@ -290,24 +287,28 @@ mod tests {
_extractor: PhantomData, _extractor: PhantomData,
}; };
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); let req = TestRequest::get()
.append_header(("Authorization", "Bearer 1"))
.to_srv_request();
let f1 = middleware.call(req); let f1 = middleware.call(req).await;
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); let req = TestRequest::get()
.append_header(("Authorization", "Bearer 1"))
.to_srv_request();
let f2 = middleware.call(req); let f2 = middleware.call(req).await;
let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request(); let req = TestRequest::get()
.append_header(("Authorization", "Bearer 1"))
.to_srv_request();
let f3 = middleware.call(req); let f3 = middleware.call(req).await;
let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)); let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;
let result = join!(f1, f2, f3, res); assert!(f1.is_err());
assert!(f2.is_err());
assert!(result.0.is_err()); assert!(f3.is_err());
assert!(result.1.is_err());
assert!(result.2.is_err());
} }
} }