1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-18 13:51:50 +01:00

add ClientResponse::timeout (#1931)

This commit is contained in:
fakeshadow 2021-02-17 03:55:11 -08:00 committed by GitHub
parent dfa795ff9d
commit 5efea652e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 202 additions and 40 deletions

View File

@ -1,9 +1,13 @@
# Changes
## Unreleased - 2021-xx-xx
### Added
* `ClientResponse::timeout` for set the timeout of collecting response body. [#1931]
### Changed
* Feature `cookies` is now optional and enabled by default. [#1981]
[#1931]: https://github.com/actix/actix-web/pull/1931
[#1981]: https://github.com/actix/actix-web/pull/1981

View File

@ -1,20 +1,22 @@
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
cell::{Ref, RefMut},
mem,
fmt,
future::Future,
io,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
time::{Duration, Instant},
};
use actix_http::{
error::PayloadError,
http::{header, HeaderMap, StatusCode, Version},
Extensions, HttpMessage, Payload, PayloadStream, ResponseHead,
};
use actix_rt::time::{sleep, Sleep};
use bytes::{Bytes, BytesMut};
use futures_core::{ready, Stream};
use actix_http::error::PayloadError;
use actix_http::http::header;
use actix_http::http::{HeaderMap, StatusCode, Version};
use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead};
use serde::de::DeserializeOwned;
#[cfg(feature = "cookies")]
@ -26,6 +28,38 @@ use crate::error::JsonPayloadError;
pub struct ClientResponse<S = PayloadStream> {
pub(crate) head: ResponseHead,
pub(crate) payload: Payload<S>,
pub(crate) timeout: ResponseTimeout,
}
/// helper enum with reusable sleep passed from `SendClientResponse`.
/// See `ClientResponse::_timeout` for reason.
pub(crate) enum ResponseTimeout {
Disabled(Option<Pin<Box<Sleep>>>),
Enabled(Pin<Box<Sleep>>),
}
impl Default for ResponseTimeout {
fn default() -> Self {
Self::Disabled(None)
}
}
impl ResponseTimeout {
fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> {
match *self {
Self::Enabled(ref mut timeout) => {
if timeout.as_mut().poll(cx).is_ready() {
Err(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Response Payload IO timed out",
)))
} else {
Ok(())
}
}
Self::Disabled(_) => Ok(()),
}
}
}
impl<S> HttpMessage for ClientResponse<S> {
@ -35,6 +69,10 @@ impl<S> HttpMessage for ClientResponse<S> {
&self.head.headers
}
fn take_payload(&mut self) -> Payload<S> {
std::mem::replace(&mut self.payload, Payload::None)
}
fn extensions(&self) -> Ref<'_, Extensions> {
self.head.extensions()
}
@ -43,10 +81,6 @@ impl<S> HttpMessage for ClientResponse<S> {
self.head.extensions_mut()
}
fn take_payload(&mut self) -> Payload<S> {
mem::replace(&mut self.payload, Payload::None)
}
/// Load request cookies.
#[cfg(feature = "cookies")]
fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
@ -69,7 +103,11 @@ impl<S> HttpMessage for ClientResponse<S> {
impl<S> ClientResponse<S> {
/// Create new Request instance
pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self {
ClientResponse { head, payload }
ClientResponse {
head,
payload,
timeout: ResponseTimeout::default(),
}
}
#[inline]
@ -105,8 +143,43 @@ impl<S> ClientResponse<S> {
ClientResponse {
payload,
head: self.head,
timeout: self.timeout,
}
}
/// Set a timeout duration for [`ClientResponse`](self::ClientResponse).
///
/// This duration covers the duration of processing the response body stream
/// and would end it as timeout error when deadline met.
///
/// Disabled by default.
pub fn timeout(self, dur: Duration) -> Self {
let timeout = match self.timeout {
ResponseTimeout::Disabled(Some(mut timeout))
| ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) {
Some(deadline) => {
timeout.as_mut().reset(deadline.into());
ResponseTimeout::Enabled(timeout)
}
None => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
},
_ => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
};
Self {
payload: self.payload,
head: self.head,
timeout,
}
}
/// This method does not enable timeout. It's used to pass the boxed `Sleep` from
/// `SendClientRequest` and reuse it's heap allocation together with it's slot in
/// timer wheel.
pub(crate) fn _timeout(mut self, timeout: Option<Pin<Box<Sleep>>>) -> Self {
self.timeout = ResponseTimeout::Disabled(timeout);
self
}
}
impl<S> ClientResponse<S>
@ -137,7 +210,10 @@ where
type Item = Result<Bytes, PayloadError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.get_mut().payload).poll_next(cx)
let this = self.get_mut();
this.timeout.poll_timeout(cx)?;
Pin::new(&mut this.payload).poll_next(cx)
}
}
@ -156,6 +232,7 @@ impl<S> fmt::Debug for ClientResponse<S> {
pub struct MessageBody<S> {
length: Option<usize>,
err: Option<PayloadError>,
timeout: ResponseTimeout,
fut: Option<ReadBody<S>>,
}
@ -181,6 +258,7 @@ where
MessageBody {
length: len,
err: None,
timeout: std::mem::take(&mut res.timeout),
fut: Some(ReadBody::new(res.take_payload(), 262_144)),
}
}
@ -198,6 +276,7 @@ where
fut: None,
err: Some(e),
length: None,
timeout: ResponseTimeout::default(),
}
}
}
@ -221,6 +300,8 @@ where
}
}
this.timeout.poll_timeout(cx)?;
Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
}
}
@ -234,6 +315,7 @@ where
pub struct JsonBody<S, U> {
length: Option<usize>,
err: Option<JsonPayloadError>,
timeout: ResponseTimeout,
fut: Option<ReadBody<S>>,
_phantom: PhantomData<U>,
}
@ -244,9 +326,9 @@ where
U: DeserializeOwned,
{
/// Create `JsonBody` for request.
pub fn new(req: &mut ClientResponse<S>) -> Self {
pub fn new(res: &mut ClientResponse<S>) -> Self {
// check content-type
let json = if let Ok(Some(mime)) = req.mime_type() {
let json = if let Ok(Some(mime)) = res.mime_type() {
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
} else {
false
@ -255,13 +337,15 @@ where
return JsonBody {
length: None,
fut: None,
timeout: ResponseTimeout::default(),
err: Some(JsonPayloadError::ContentType),
_phantom: PhantomData,
};
}
let mut len = None;
if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) {
if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) {
if let Ok(s) = l.to_str() {
if let Ok(l) = s.parse::<usize>() {
len = Some(l)
@ -272,7 +356,8 @@ where
JsonBody {
length: len,
err: None,
fut: Some(ReadBody::new(req.take_payload(), 65536)),
timeout: std::mem::take(&mut res.timeout),
fut: Some(ReadBody::new(res.take_payload(), 65536)),
_phantom: PhantomData,
}
}
@ -311,6 +396,10 @@ where
}
}
self.timeout
.poll_timeout(cx)
.map_err(JsonPayloadError::Payload)?;
let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?;
Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
}

View File

@ -18,15 +18,11 @@ use actix_http::{
use actix_rt::time::{sleep, Sleep};
use bytes::Bytes;
use derive_more::From;
use futures_core::Stream;
use futures_core::{ready, Stream};
use serde::Serialize;
#[cfg(feature = "compress")]
use actix_http::encoding::Decoder;
#[cfg(feature = "compress")]
use actix_http::http::header::ContentEncoding;
#[cfg(feature = "compress")]
use actix_http::{Payload, PayloadStream};
use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream};
use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError};
use crate::response::ClientResponse;
@ -61,7 +57,6 @@ impl From<PrepForSendingError> for SendRequestError {
pub enum SendClientRequest {
Fut(
Pin<Box<dyn Future<Output = Result<ClientResponse, SendRequestError>>>>,
// FIXME: use a pinned Sleep instead of box.
Option<Pin<Box<Sleep>>>,
bool,
),
@ -88,15 +83,14 @@ impl Future for SendClientRequest {
match this {
SendClientRequest::Fut(send, delay, response_decompress) => {
if delay.is_some() {
match Pin::new(delay.as_mut().unwrap()).poll(cx) {
Poll::Pending => {}
_ => return Poll::Ready(Err(SendRequestError::Timeout)),
if let Some(delay) = delay {
if delay.as_mut().poll(cx).is_ready() {
return Poll::Ready(Err(SendRequestError::Timeout));
}
}
let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| {
res.map_body(|head, payload| {
let res = ready!(send.as_mut().poll(cx)).map(|res| {
res._timeout(delay.take()).map_body(|head, payload| {
if *response_decompress {
Payload::Stream(Decoder::from_headers(payload, &head.headers))
} else {
@ -123,13 +117,15 @@ impl Future for SendClientRequest {
let this = self.get_mut();
match this {
SendClientRequest::Fut(send, delay, _) => {
if delay.is_some() {
match Pin::new(delay.as_mut().unwrap()).poll(cx) {
Poll::Pending => {}
_ => return Poll::Ready(Err(SendRequestError::Timeout)),
if let Some(delay) = delay {
if delay.as_mut().poll(cx).is_ready() {
return Poll::Ready(Err(SendRequestError::Timeout));
}
}
Pin::new(send).poll(cx)
send.as_mut()
.poll(cx)
.map_ok(|res| res._timeout(delay.take()))
}
SendClientRequest::Err(ref mut e) => match e.take() {
Some(e) => Poll::Ready(Err(e)),

View File

@ -24,7 +24,7 @@ use actix_web::{
middleware::Compress,
test, web, App, Error, HttpMessage, HttpRequest, HttpResponse,
};
use awc::error::SendRequestError;
use awc::error::{JsonPayloadError, PayloadError, SendRequestError};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
@ -157,6 +157,79 @@ async fn test_timeout_override() {
}
}
#[actix_rt::test]
async fn test_response_timeout() {
use futures_util::stream::{once, StreamExt};
let srv = test::start(|| {
App::new().service(web::resource("/").route(web::to(|| async {
Ok::<_, Error>(
HttpResponse::Ok()
.content_type("application/json")
.streaming(Box::pin(once(async {
actix_rt::time::sleep(Duration::from_millis(200)).await;
Ok::<_, Error>(Bytes::from(STR))
}))),
)
})))
});
let client = awc::Client::new();
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(500))
.body()
.await
.unwrap();
assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR);
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(100))
.next()
.await
.unwrap();
match res {
Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut),
_ => panic!("Response error type is not matched"),
}
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(100))
.body()
.await;
match res {
Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut),
_ => panic!("Response error type is not matched"),
}
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(100))
.json::<HashMap<String, String>>()
.await;
match res {
Err(JsonPayloadError::Payload(PayloadError::Io(e))) => {
assert_eq!(e.kind(), std::io::ErrorKind::TimedOut)
}
_ => panic!("Response error type is not matched"),
}
}
#[actix_rt::test]
async fn test_connection_reuse() {
let num = Arc::new(AtomicUsize::new(0));