From 5efea652e342479d6d0a20afc58bd880a691a712 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 17 Feb 2021 03:55:11 -0800 Subject: [PATCH] add ClientResponse::timeout (#1931) --- awc/CHANGES.md | 4 ++ awc/src/response.rs | 131 ++++++++++++++++++++++++++++++++------- awc/src/sender.rs | 32 +++++----- awc/tests/test_client.rs | 75 +++++++++++++++++++++- 4 files changed, 202 insertions(+), 40 deletions(-) diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 9224f414..c67f6556 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,9 +1,13 @@ # Changes ## Unreleased - 2021-xx-xx +### Added +* `ClientResponse::timeout` for set the timeout of collecting response body. [#1931] + ### Changed * Feature `cookies` is now optional and enabled by default. [#1981] +[#1931]: https://github.com/actix/actix-web/pull/1931 [#1981]: https://github.com/actix/actix-web/pull/1981 diff --git a/awc/src/response.rs b/awc/src/response.rs index cf687329..514b8a90 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -1,20 +1,22 @@ -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; use std::{ cell::{Ref, RefMut}, - mem, + fmt, + future::Future, + io, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, }; +use actix_http::{ + error::PayloadError, + http::{header, HeaderMap, StatusCode, Version}, + Extensions, HttpMessage, Payload, PayloadStream, ResponseHead, +}; +use actix_rt::time::{sleep, Sleep}; use bytes::{Bytes, BytesMut}; use futures_core::{ready, Stream}; - -use actix_http::error::PayloadError; -use actix_http::http::header; -use actix_http::http::{HeaderMap, StatusCode, Version}; -use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead}; use serde::de::DeserializeOwned; #[cfg(feature = "cookies")] @@ -26,6 +28,38 @@ use crate::error::JsonPayloadError; pub struct ClientResponse { pub(crate) head: ResponseHead, pub(crate) payload: Payload, + pub(crate) timeout: ResponseTimeout, +} + +/// helper enum with reusable sleep passed from `SendClientResponse`. +/// See `ClientResponse::_timeout` for reason. +pub(crate) enum ResponseTimeout { + Disabled(Option>>), + Enabled(Pin>), +} + +impl Default for ResponseTimeout { + fn default() -> Self { + Self::Disabled(None) + } +} + +impl ResponseTimeout { + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { + match *self { + Self::Enabled(ref mut timeout) => { + if timeout.as_mut().poll(cx).is_ready() { + Err(PayloadError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "Response Payload IO timed out", + ))) + } else { + Ok(()) + } + } + Self::Disabled(_) => Ok(()), + } + } } impl HttpMessage for ClientResponse { @@ -35,6 +69,10 @@ impl HttpMessage for ClientResponse { &self.head.headers } + fn take_payload(&mut self) -> Payload { + std::mem::replace(&mut self.payload, Payload::None) + } + fn extensions(&self) -> Ref<'_, Extensions> { self.head.extensions() } @@ -43,10 +81,6 @@ impl HttpMessage for ClientResponse { self.head.extensions_mut() } - fn take_payload(&mut self) -> Payload { - mem::replace(&mut self.payload, Payload::None) - } - /// Load request cookies. #[cfg(feature = "cookies")] fn cookies(&self) -> Result>>, CookieParseError> { @@ -69,7 +103,11 @@ impl HttpMessage for ClientResponse { impl ClientResponse { /// Create new Request instance pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { - ClientResponse { head, payload } + ClientResponse { + head, + payload, + timeout: ResponseTimeout::default(), + } } #[inline] @@ -105,8 +143,43 @@ impl ClientResponse { ClientResponse { payload, head: self.head, + timeout: self.timeout, } } + + /// Set a timeout duration for [`ClientResponse`](self::ClientResponse). + /// + /// This duration covers the duration of processing the response body stream + /// and would end it as timeout error when deadline met. + /// + /// Disabled by default. + pub fn timeout(self, dur: Duration) -> Self { + let timeout = match self.timeout { + ResponseTimeout::Disabled(Some(mut timeout)) + | ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) { + Some(deadline) => { + timeout.as_mut().reset(deadline.into()); + ResponseTimeout::Enabled(timeout) + } + None => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }, + _ => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }; + + Self { + payload: self.payload, + head: self.head, + timeout, + } + } + + /// This method does not enable timeout. It's used to pass the boxed `Sleep` from + /// `SendClientRequest` and reuse it's heap allocation together with it's slot in + /// timer wheel. + pub(crate) fn _timeout(mut self, timeout: Option>>) -> Self { + self.timeout = ResponseTimeout::Disabled(timeout); + self + } } impl ClientResponse @@ -137,7 +210,10 @@ where type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().payload).poll_next(cx) + let this = self.get_mut(); + this.timeout.poll_timeout(cx)?; + + Pin::new(&mut this.payload).poll_next(cx) } } @@ -156,6 +232,7 @@ impl fmt::Debug for ClientResponse { pub struct MessageBody { length: Option, err: Option, + timeout: ResponseTimeout, fut: Option>, } @@ -181,6 +258,7 @@ where MessageBody { length: len, err: None, + timeout: std::mem::take(&mut res.timeout), fut: Some(ReadBody::new(res.take_payload(), 262_144)), } } @@ -198,6 +276,7 @@ where fut: None, err: Some(e), length: None, + timeout: ResponseTimeout::default(), } } } @@ -221,6 +300,8 @@ where } } + this.timeout.poll_timeout(cx)?; + Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } } @@ -234,6 +315,7 @@ where pub struct JsonBody { length: Option, err: Option, + timeout: ResponseTimeout, fut: Option>, _phantom: PhantomData, } @@ -244,9 +326,9 @@ where U: DeserializeOwned, { /// Create `JsonBody` for request. - pub fn new(req: &mut ClientResponse) -> Self { + pub fn new(res: &mut ClientResponse) -> Self { // check content-type - let json = if let Ok(Some(mime)) = req.mime_type() { + let json = if let Ok(Some(mime)) = res.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) } else { false @@ -255,13 +337,15 @@ where return JsonBody { length: None, fut: None, + timeout: ResponseTimeout::default(), err: Some(JsonPayloadError::ContentType), _phantom: PhantomData, }; } let mut len = None; - if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { + + if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { if let Ok(s) = l.to_str() { if let Ok(l) = s.parse::() { len = Some(l) @@ -272,7 +356,8 @@ where JsonBody { length: len, err: None, - fut: Some(ReadBody::new(req.take_payload(), 65536)), + timeout: std::mem::take(&mut res.timeout), + fut: Some(ReadBody::new(res.take_payload(), 65536)), _phantom: PhantomData, } } @@ -311,6 +396,10 @@ where } } + self.timeout + .poll_timeout(cx) + .map_err(JsonPayloadError::Payload)?; + let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index a72b129f..6bac401c 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -18,15 +18,11 @@ use actix_http::{ use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; use derive_more::From; -use futures_core::Stream; +use futures_core::{ready, Stream}; use serde::Serialize; #[cfg(feature = "compress")] -use actix_http::encoding::Decoder; -#[cfg(feature = "compress")] -use actix_http::http::header::ContentEncoding; -#[cfg(feature = "compress")] -use actix_http::{Payload, PayloadStream}; +use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream}; use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use crate::response::ClientResponse; @@ -61,7 +57,6 @@ impl From for SendRequestError { pub enum SendClientRequest { Fut( Pin>>>, - // FIXME: use a pinned Sleep instead of box. Option>>, bool, ), @@ -88,15 +83,14 @@ impl Future for SendClientRequest { match this { SendClientRequest::Fut(send, delay, response_decompress) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { - res.map_body(|head, payload| { + let res = ready!(send.as_mut().poll(cx)).map(|res| { + res._timeout(delay.take()).map_body(|head, payload| { if *response_decompress { Payload::Stream(Decoder::from_headers(payload, &head.headers)) } else { @@ -123,13 +117,15 @@ impl Future for SendClientRequest { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, _) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - Pin::new(send).poll(cx) + + send.as_mut() + .poll(cx) + .map_ok(|res| res._timeout(delay.take())) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index 7e74d226..bcbaf3f4 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -24,7 +24,7 @@ use actix_web::{ middleware::Compress, test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, }; -use awc::error::SendRequestError; +use awc::error::{JsonPayloadError, PayloadError, SendRequestError}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -157,6 +157,79 @@ async fn test_timeout_override() { } } +#[actix_rt::test] +async fn test_response_timeout() { + use futures_util::stream::{once, StreamExt}; + + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .content_type("application/json") + .streaming(Box::pin(once(async { + actix_rt::time::sleep(Duration::from_millis(200)).await; + Ok::<_, Error>(Bytes::from(STR)) + }))), + ) + }))) + }); + + let client = awc::Client::new(); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(500)) + .body() + .await + .unwrap(); + assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .next() + .await + .unwrap(); + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .body() + .await; + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .json::>() + .await; + match res { + Err(JsonPayloadError::Payload(PayloadError::Io(e))) => { + assert_eq!(e.kind(), std::io::ErrorKind::TimedOut) + } + _ => panic!("Response error type is not matched"), + } +} + #[actix_rt::test] async fn test_connection_reuse() { let num = Arc::new(AtomicUsize::new(0));