diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs index 32b54ffcb..6e5957339 100644 --- a/awc/src/middleware/retry.rs +++ b/awc/src/middleware/retry.rs @@ -60,8 +60,8 @@ impl Retry { /// .finish(); ///``` pub fn policy(mut self, p: T) -> Self - where - T: IntoRetryPolicy, + where + T: IntoRetryPolicy, { self.0.policies.push(p.into_policy()); self @@ -79,8 +79,8 @@ pub trait IntoRetryPolicy { } impl IntoRetryPolicy for T -where - T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, + where + T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, { fn into_policy(self) -> RetryPolicy { RetryPolicy::Custom(Box::new(self)) @@ -94,8 +94,8 @@ impl IntoRetryPolicy for Vec { } impl Transform for Retry -where - S: Service + 'static, + where + S: Service + 'static, { type Transform = RetryService; @@ -116,8 +116,8 @@ pub struct RetryService { } impl Service for RetryService -where - S: Service + 'static, + where + S: Service + 'static, { type Response = S::Response; type Error = S::Error; @@ -133,237 +133,87 @@ where let max_retry = self.max_retry; Box::pin(async move { - let mut tries = 0; match req { ConnectRequest::Client(head, body, addr) => { - match body { - Body::Bytes(b) => { - loop { - let h = clone_request_head_type(&head); + for _ in 1..max_retry { + let h = clone_request_head_type(&head); - match connector - .call(ConnectRequest::Client( - h, - Body::Bytes(b.clone()), - addr, - )) - .await - { - Ok(res) => { - // ConnectResponse - match &res { - ConnectResponse::Client(ref r) => { - if is_valid_response( - policies.as_ref(), - r.status(), - r.headers(), - ) { - return Ok(res); - } + let result = connector + .call(ConnectRequest::Client(h, body_to_retry_body(&body), addr)) + .await; - if tries == max_retry { - return Ok(res); - } - - tries += 1; - } - ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response( - policies.as_ref(), - head.status, - head.headers(), - ) { - return Ok(res); - } - - if tries == max_retry { - return Ok(res); - } - - tries += 1; - } - } - } - // SendRequestError - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; + if let Ok(res) = result { + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { + return Ok(res); } } - } - } - Body::Empty => { - loop { - let h = clone_request_head_type(&head); - - match connector - .call(ConnectRequest::Client(h, Body::Empty, addr)) - .await - { - Ok(res) => { - // ConnectResponse - match &res { - ConnectResponse::Client(ref r) => { - if is_valid_response( - policies.as_ref(), - r.status(), - r.headers(), - ) { - return Ok(res); - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); - } - - tries += 1; - } - ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response( - policies.as_ref(), - head.status, - head.headers(), - ) { - return Ok(res); - } else { - if tries == max_retry { - log::debug!( - "Request max retry reached" - ); - return Ok(res); - } - - tries += 1; - } - } - } - } - // SendRequestError - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; - } - } - } - } - _ => { - log::debug!( - "Non cloneable body type given - defaulting to `Body::None`" - ); - loop { - let h = clone_request_head_type(&head); - - match connector - .call(ConnectRequest::Client(h, Body::None, addr)) - .await - { - Ok(res) => { - // ConnectResponse - match &res { - ConnectResponse::Client(ref r) => { - if is_valid_response( - policies.as_ref(), - r.status(), - r.headers(), - ) { - return Ok(res); - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); - } - - tries += 1; - } - ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response( - policies.as_ref(), - head.status, - head.headers(), - ) { - return Ok(res); - } else { - if tries == max_retry { - log::debug!( - "Request max retry reached" - ); - return Ok(res); - } - - tries += 1; - } - } - } - } - // SendRequestError - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { + return Ok(res); } } } } } + + // Exceed max retry so just return whatever response is received + log::debug!("Request max retry reached"); + connector.call(ConnectRequest::Client( + head, + body, + addr, + )) + .await } - ConnectRequest::Tunnel(head, addr) => loop { - let h = clone_request_head(&head); + ConnectRequest::Tunnel(head, addr) => { + for _ in 1..max_retry { + let h = clone_request_head(&head); - match connector.call(ConnectRequest::Tunnel(h, addr)).await { - Ok(res) => match &res { - ConnectResponse::Client(r) => { - if is_valid_response(&policies, r.status(), r.headers()) { - return Ok(res); + let result = connector.call(ConnectRequest::Tunnel(h, addr)).await; + + if let Ok(res) = result { + match &res { + ConnectResponse::Client(r) => { + if is_valid_response(&policies, r.status(), r.headers()) { + return Ok(res); + } } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); + ConnectResponse::Tunnel(head, _) => { + if is_valid_response(&policies, head.status, head.headers()) { + return Ok(res); + } } - - tries += 1; } - ConnectResponse::Tunnel(head, _) => { - if is_valid_response(&policies, head.status, head.headers()) { - return Ok(res); - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); - } - - tries += 1; - } - }, - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; } - } - }, + }; + + // Exceed max retry so just return whatever response is received + log::debug!("Request max retry reached"); + connector.call(ConnectRequest::Tunnel(head, addr)).await + } } }) } } +fn body_to_retry_body(body: &Body) -> Body { + match body { + Body::Empty => Body::Empty, + Body::Bytes(b) => Body::Bytes(b.clone()), + _ => Body::None + } +} + #[doc(hidden)] /// Clones [RequestHeadType] except for the extensions (not required for this middleware) fn clone_request_head_type(head_type: &RequestHeadType) -> RequestHeadType { @@ -430,6 +280,9 @@ mod tests { #[actix_rt::test] async fn test_basic_policy() { + std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); + env_logger::init(); + let client = ClientBuilder::new() .disable_redirects() .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR]))