From 58904b9ebc69d57511ae4d718b21ef6dcb3b2b7c Mon Sep 17 00:00:00 2001 From: joshbenaron <73971531+joshbenaron@users.noreply.github.com> Date: Fri, 2 Apr 2021 14:50:43 +0100 Subject: [PATCH] Fixed invalid tests --- awc/src/middleware/retry.rs | 228 ++++++++++++++++++++---------------- 1 file changed, 128 insertions(+), 100 deletions(-) diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs index 27bb3a1c3..32b54ffcb 100644 --- a/awc/src/middleware/retry.rs +++ b/awc/src/middleware/retry.rs @@ -1,15 +1,15 @@ use super::Transform; -use std::rc::Rc; -use actix_http::RequestHeadType; -use actix_http::http::{StatusCode, HeaderMap}; -use std::ops::Deref; -use crate::{ConnectRequest, ConnectResponse}; -use actix_service::Service; -use actix_http::client::SendRequestError; -use std::task::{Context, Poll}; use crate::RequestHead; -use futures_core::future::LocalBoxFuture; +use crate::{ConnectRequest, ConnectResponse}; use actix_http::body::Body; +use actix_http::client::SendRequestError; +use actix_http::http::{HeaderMap, StatusCode}; +use actix_http::RequestHeadType; +use actix_service::Service; +use futures_core::future::LocalBoxFuture; +use std::ops::Deref; +use std::rc::Rc; +use std::task::{Context, Poll}; pub struct Retry(Inner); @@ -60,7 +60,8 @@ impl Retry { /// .finish(); ///``` pub fn policy(mut self, p: T) -> Self - where T: IntoRetryPolicy + where + T: IntoRetryPolicy, { self.0.policies.push(p.into_policy()); self @@ -78,7 +79,8 @@ pub trait IntoRetryPolicy { } impl IntoRetryPolicy for T - where T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static +where + T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, { fn into_policy(self) -> RetryPolicy { RetryPolicy::Custom(Box::new(self)) @@ -92,8 +94,8 @@ impl IntoRetryPolicy for Vec { } impl Transform for Retry - where - S: Service + 'static, +where + S: Service + 'static, { type Transform = RetryService; @@ -114,8 +116,8 @@ pub struct RetryService { } impl Service for RetryService - where - S: Service + 'static, +where + S: Service + 'static, { type Response = S::Response; type Error = S::Error; @@ -139,13 +141,23 @@ impl Service for RetryService loop { let h = clone_request_head_type(&head); - match connector.call(ConnectRequest::Client(h, Body::Bytes(b.clone()), addr)).await + match connector + .call(ConnectRequest::Client( + h, + Body::Bytes(b.clone()), + addr, + )) + .await { Ok(res) => { // ConnectResponse match &res { ConnectResponse::Client(ref r) => { - if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { return Ok(res); } @@ -156,7 +168,11 @@ impl Service for RetryService tries += 1; } ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response(policies.as_ref(), head.status, head.headers()) { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { return Ok(res); } @@ -184,13 +200,19 @@ impl Service for RetryService loop { let h = clone_request_head_type(&head); - match connector.call(ConnectRequest::Client(h, Body::Empty, addr)).await + match connector + .call(ConnectRequest::Client(h, Body::Empty, addr)) + .await { Ok(res) => { // ConnectResponse match &res { ConnectResponse::Client(ref r) => { - if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { return Ok(res); } @@ -202,11 +224,17 @@ impl Service for RetryService tries += 1; } ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response(policies.as_ref(), head.status, head.headers()) { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { return Ok(res); } else { if tries == max_retry { - log::debug!("Request max retry reached"); + log::debug!( + "Request max retry reached" + ); return Ok(res); } @@ -228,17 +256,25 @@ impl Service for RetryService } } _ => { - log::debug!("Non cloneable body type given - defaulting to `Body::None`"); + log::debug!( + "Non cloneable body type given - defaulting to `Body::None`" + ); loop { let h = clone_request_head_type(&head); - match connector.call(ConnectRequest::Client(h, Body::None, addr)).await + match connector + .call(ConnectRequest::Client(h, Body::None, addr)) + .await { Ok(res) => { // ConnectResponse match &res { ConnectResponse::Client(ref r) => { - if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { return Ok(res); } @@ -250,11 +286,17 @@ impl Service for RetryService tries += 1; } ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response(policies.as_ref(), head.status, head.headers()) { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { return Ok(res); } else { if tries == max_retry { - log::debug!("Request max retry reached"); + log::debug!( + "Request max retry reached" + ); return Ok(res); } @@ -277,50 +319,46 @@ impl Service for RetryService } } } - ConnectRequest::Tunnel(head, addr) => { - loop { - let h = clone_request_head(&head); + ConnectRequest::Tunnel(head, addr) => loop { + let h = clone_request_head(&head); - match connector.call(ConnectRequest::Tunnel(h, addr)).await { - Ok(res) => { - match &res { - ConnectResponse::Client(r) => { - if is_valid_response(&policies, r.status(), r.headers()) { - return Ok(res) - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res) - } - - tries += 1; - } - ConnectResponse::Tunnel(head, _) => { - if is_valid_response(&policies, head.status, head.headers()) { - return Ok(res) - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res) - } - - tries += 1; - } + match connector.call(ConnectRequest::Tunnel(h, addr)).await { + Ok(res) => match &res { + ConnectResponse::Client(r) => { + if is_valid_response(&policies, r.status(), r.headers()) { + return Ok(res); } - }, - Err(e) => { + if tries == max_retry { log::debug!("Request max retry reached"); - return Err(e) + return Ok(res); } tries += 1; } + ConnectResponse::Tunnel(head, _) => { + if is_valid_response(&policies, head.status, head.headers()) { + return Ok(res); + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + }, + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; } } - } + }, } }) } @@ -361,7 +399,11 @@ fn clone_request_head(head: &RequestHead) -> RequestHead { #[doc(hidden)] /// Checks whether the response matches the policies -fn is_valid_response(policies: &[RetryPolicy], status_code: StatusCode, headers: &HeaderMap) -> bool { +fn is_valid_response( + policies: &[RetryPolicy], + status_code: StatusCode, + headers: &HeaderMap, +) -> bool { policies.iter().all(|policy| { match policy { RetryPolicy::Status(v) => { @@ -390,19 +432,13 @@ mod tests { async fn test_basic_policy() { let client = ClientBuilder::new() .disable_redirects() - .wrap(Retry::new(3) - .policy(vec![StatusCode::INTERNAL_SERVER_ERROR]) - ) + .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR])) .finish(); let srv = actix_test::start(|| { - App::new() - .service(web::resource("/test").route(web::to(|| async { - Ok::<_, Error>( - HttpResponse::InternalServerError() - .finish(), - ) - }))) + App::new().service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::InternalServerError().finish()) + }))) }); let res = client.get(srv.url("/test")).send().await.unwrap(); @@ -412,27 +448,23 @@ mod tests { #[actix_rt::test] async fn test_header_policy() { - std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); - env_logger::init(); - let client = ClientBuilder::new() .disable_redirects() - .wrap(Retry::new(3) - .policy(|code: StatusCode, headers: &HeaderMap| { + .wrap( + Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| { code.is_success() && headers.contains_key("SOME_HEADER") - }) + }), ) .finish(); let srv = actix_test::start(|| { - App::new() - .service(web::resource("/test").route(web::to(|| async { - Ok::<_, Error>( - HttpResponse::Ok() - .insert_header(("SOME_HEADER", "test")) - .finish(), - ) - }))) + App::new().service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) }); let res = client.get(srv.url("/test")).send().await.unwrap(); @@ -442,27 +474,23 @@ mod tests { #[actix_rt::test] async fn test_bad_header_policy() { - std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); - env_logger::init(); - let client = ClientBuilder::new() .disable_redirects() - .wrap(Retry::new(3) - .policy(|code: StatusCode, headers: &HeaderMap| { + .wrap( + Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| { code.is_success() && headers.contains_key("WRONG_HEADER") - }) + }), ) .finish(); let srv = actix_test::start(|| { - App::new() - .service(web::resource("/test").route(web::to(|| async { - Ok::<_, Error>( - HttpResponse::Ok() - .insert_header(("SOME_HEADER", "test")) - .finish(), - ) - }))) + App::new().service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) }); let res = client.get(srv.url("/test")).send().await.unwrap();