1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-12-02 19:32:24 +01:00
actix-web/awc/src/response.rs

558 lines
16 KiB
Rust
Raw Normal View History

2021-02-13 16:08:43 +01:00
use std::{
cell::{Ref, RefMut},
2021-02-17 12:55:11 +01:00
fmt,
future::Future,
io,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
time::{Duration, Instant},
2021-02-13 16:08:43 +01:00
};
2021-02-17 12:55:11 +01:00
use actix_http::{
error::PayloadError,
http::{header, HeaderMap, StatusCode, Version},
Extensions, HttpMessage, Payload, PayloadStream, ResponseHead,
};
use actix_rt::time::{sleep, Sleep};
2019-03-27 04:45:00 +01:00
use bytes::{Bytes, BytesMut};
2020-05-18 04:46:57 +02:00
use futures_core::{ready, Stream};
2019-04-01 20:51:18 +02:00
use serde::de::DeserializeOwned;
2021-02-13 16:08:43 +01:00
#[cfg(feature = "cookies")]
use crate::cookie::{Cookie, ParseError as CookieParseError};
2019-04-01 20:51:18 +02:00
use crate::error::JsonPayloadError;
2019-03-27 05:31:18 +01:00
/// Client Response
2019-03-27 04:45:00 +01:00
pub struct ClientResponse<S = PayloadStream> {
pub(crate) head: ResponseHead,
2019-03-27 04:45:00 +01:00
pub(crate) payload: Payload<S>,
2021-02-17 12:55:11 +01:00
pub(crate) timeout: ResponseTimeout,
}
/// Helper enum holding a reusable boxed `Sleep` passed from `SendClientRequest`.
///
/// The timer arrives in the `Disabled` state so that its heap allocation can be reused
/// (reset to a new deadline) if a timeout is enabled later.
/// See `ClientResponse::_timeout` for the reason.
pub(crate) enum ResponseTimeout {
    /// No deadline is enforced; may carry a previously allocated timer for reuse.
    Disabled(Option<Pin<Box<Sleep>>>),
    /// A deadline is enforced; polling the payload fails with a timeout error once it elapses.
    Enabled(Pin<Box<Sleep>>),
}
impl Default for ResponseTimeout {
fn default() -> Self {
Self::Disabled(None)
}
}
impl ResponseTimeout {
fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> {
match *self {
Self::Enabled(ref mut timeout) => {
if timeout.as_mut().poll(cx).is_ready() {
Err(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Response Payload IO timed out",
)))
} else {
Ok(())
}
}
Self::Disabled(_) => Ok(()),
}
}
}
2019-03-27 04:45:00 +01:00
impl<S> HttpMessage for ClientResponse<S> {
type Stream = S;
fn headers(&self) -> &HeaderMap {
&self.head.headers
}
2021-02-17 12:55:11 +01:00
fn take_payload(&mut self) -> Payload<S> {
std::mem::replace(&mut self.payload, Payload::None)
}
2019-12-08 07:31:16 +01:00
fn extensions(&self) -> Ref<'_, Extensions> {
self.head.extensions()
}
2019-12-08 07:31:16 +01:00
fn extensions_mut(&self) -> RefMut<'_, Extensions> {
self.head.extensions_mut()
}
}
2019-03-27 04:45:00 +01:00
impl<S> ClientResponse<S> {
/// Create new Request instance
2019-03-27 04:45:00 +01:00
pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self {
2021-02-17 12:55:11 +01:00
ClientResponse {
head,
payload,
timeout: ResponseTimeout::default(),
}
}
#[inline]
pub(crate) fn head(&self) -> &ResponseHead {
&self.head
}
/// Read the Request Version.
#[inline]
pub fn version(&self) -> Version {
self.head().version
}
/// Get the status from the server.
#[inline]
pub fn status(&self) -> StatusCode {
self.head().status
}
#[inline]
2019-04-02 22:35:01 +02:00
/// Returns request's headers.
pub fn headers(&self) -> &HeaderMap {
&self.head().headers
}
2019-03-27 04:45:00 +01:00
/// Set a body and return previous body value
pub fn map_body<F, U>(mut self, f: F) -> ClientResponse<U>
where
F: FnOnce(&mut ResponseHead, Payload<S>) -> Payload<U>,
{
let payload = f(&mut self.head, self.payload);
ClientResponse {
payload,
head: self.head,
2021-02-17 12:55:11 +01:00
timeout: self.timeout,
2019-03-27 04:45:00 +01:00
}
}
2021-02-17 12:55:11 +01:00
/// Set a timeout duration for [`ClientResponse`](self::ClientResponse).
///
/// This duration covers the duration of processing the response body stream
/// and would end it as timeout error when deadline met.
///
/// Disabled by default.
pub fn timeout(self, dur: Duration) -> Self {
let timeout = match self.timeout {
ResponseTimeout::Disabled(Some(mut timeout))
| ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) {
Some(deadline) => {
timeout.as_mut().reset(deadline.into());
ResponseTimeout::Enabled(timeout)
}
None => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
},
_ => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
};
Self {
payload: self.payload,
head: self.head,
timeout,
}
}
/// This method does not enable timeout. It's used to pass the boxed `Sleep` from
/// `SendClientRequest` and reuse it's heap allocation together with it's slot in
/// timer wheel.
pub(crate) fn _timeout(mut self, timeout: Option<Pin<Box<Sleep>>>) -> Self {
self.timeout = ResponseTimeout::Disabled(timeout);
self
}
/// Load request cookies.
#[cfg(feature = "cookies")]
pub fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
struct Cookies(Vec<Cookie<'static>>);
if self.extensions().get::<Cookies>().is_none() {
let mut cookies = Vec::new();
for hdr in self.headers().get_all(&header::SET_COOKIE) {
let s = std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?;
cookies.push(Cookie::parse_encoded(s)?.into_owned());
}
self.extensions_mut().insert(Cookies(cookies));
}
Ok(Ref::map(self.extensions(), |ext| {
&ext.get::<Cookies>().unwrap().0
}))
}
/// Return request cookie.
#[cfg(feature = "cookies")]
pub fn cookie(&self, name: &str) -> Option<Cookie<'static>> {
if let Ok(cookies) = self.cookies() {
for cookie in cookies.iter() {
if cookie.name() == name {
return Some(cookie.to_owned());
}
}
}
None
}
}
2019-03-27 04:45:00 +01:00
impl<S> ClientResponse<S>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
2019-03-27 04:45:00 +01:00
{
2021-02-11 23:39:54 +01:00
/// Loads HTTP response's body.
2019-04-01 20:29:26 +02:00
pub fn body(&mut self) -> MessageBody<S> {
2019-03-27 04:45:00 +01:00
MessageBody::new(self)
}
2019-04-01 20:51:18 +02:00
/// Loads and parse `application/json` encoded body.
/// Return `JsonBody<T>` future. It resolves to a `T` value.
///
/// Returns error:
///
/// * content type is not `application/json`
/// * content length is greater than 256k
pub fn json<T: DeserializeOwned>(&mut self) -> JsonBody<S, T> {
JsonBody::new(self)
}
2019-03-27 04:45:00 +01:00
}
impl<S> Stream for ClientResponse<S>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
2019-03-27 04:45:00 +01:00
{
type Item = Result<Bytes, PayloadError>;
2021-02-12 00:03:17 +01:00
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
2021-02-17 12:55:11 +01:00
let this = self.get_mut();
this.timeout.poll_timeout(cx)?;
Pin::new(&mut this.payload).poll_next(cx)
}
}
2019-03-27 04:45:00 +01:00
impl<S> fmt::Debug for ClientResponse<S> {
2019-12-08 07:31:16 +01:00
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?;
writeln!(f, " headers:")?;
for (key, val) in self.headers().iter() {
writeln!(f, " {:?}: {:?}", key, val)?;
}
Ok(())
}
}
2019-03-27 04:45:00 +01:00
const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024;
2021-02-11 23:39:54 +01:00
/// Future that resolves to a complete HTTP message body.
2019-03-27 04:45:00 +01:00
pub struct MessageBody<S> {
length: Option<usize>,
2021-02-17 12:55:11 +01:00
timeout: ResponseTimeout,
body: Result<ReadBody<S>, Option<PayloadError>>,
2019-03-27 04:45:00 +01:00
}
impl<S> MessageBody<S>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
2019-03-27 04:45:00 +01:00
{
/// Create `MessageBody` for request.
2019-04-01 20:29:26 +02:00
pub fn new(res: &mut ClientResponse<S>) -> MessageBody<S> {
let length = match res.headers().get(&header::CONTENT_LENGTH) {
Some(value) => {
let len = value.to_str().ok().and_then(|s| s.parse::<usize>().ok());
match len {
None => return Self::err(PayloadError::UnknownLength),
len => len,
2019-03-27 04:45:00 +01:00
}
}
None => None,
};
2019-03-27 04:45:00 +01:00
MessageBody {
length,
2021-02-17 12:55:11 +01:00
timeout: std::mem::take(&mut res.timeout),
body: Ok(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)),
2019-03-27 04:45:00 +01:00
}
}
/// Change max size of payload. By default max size is 2048kB
2019-03-27 04:45:00 +01:00
pub fn limit(mut self, limit: usize) -> Self {
if let Ok(ref mut body) = self.body {
body.limit = limit;
2019-04-02 19:53:44 +02:00
}
2019-03-27 04:45:00 +01:00
self
}
fn err(e: PayloadError) -> Self {
MessageBody {
length: None,
2021-02-17 12:55:11 +01:00
timeout: ResponseTimeout::default(),
body: Err(Some(e)),
2019-03-27 04:45:00 +01:00
}
}
}
impl<S> Future for MessageBody<S>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
2019-03-27 04:45:00 +01:00
{
type Output = Result<Bytes, PayloadError>;
2019-03-27 04:45:00 +01:00
2019-12-08 07:31:16 +01:00
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match this.body {
Err(ref mut err) => Poll::Ready(Err(err.take().unwrap())),
Ok(ref mut body) => {
if let Some(len) = this.length.take() {
if len > body.limit {
return Poll::Ready(Err(PayloadError::Overflow));
}
}
2019-03-27 04:45:00 +01:00
this.timeout.poll_timeout(cx)?;
Pin::new(body).poll(cx)
2019-03-27 04:45:00 +01:00
}
}
}
}
2019-03-27 05:54:57 +01:00
2019-04-01 20:51:18 +02:00
/// Response's payload json parser, it resolves to a deserialized `T` value.
///
/// Returns error:
///
/// * content type is not `application/json`
/// * content length is greater than 64k
pub struct JsonBody<S, U> {
length: Option<usize>,
err: Option<JsonPayloadError>,
2021-02-17 12:55:11 +01:00
timeout: ResponseTimeout,
2019-04-02 19:53:44 +02:00
fut: Option<ReadBody<S>>,
2021-01-04 01:49:02 +01:00
_phantom: PhantomData<U>,
2019-04-01 20:51:18 +02:00
}
impl<S, U> JsonBody<S, U>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
2019-04-01 20:51:18 +02:00
U: DeserializeOwned,
{
/// Create `JsonBody` for request.
2021-02-17 12:55:11 +01:00
pub fn new(res: &mut ClientResponse<S>) -> Self {
2019-04-01 20:51:18 +02:00
// check content-type
2021-02-17 12:55:11 +01:00
let json = if let Ok(Some(mime)) = res.mime_type() {
2019-04-01 20:51:18 +02:00
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
} else {
false
};
if !json {
return JsonBody {
length: None,
fut: None,
2021-02-17 12:55:11 +01:00
timeout: ResponseTimeout::default(),
2019-04-01 20:51:18 +02:00
err: Some(JsonPayloadError::ContentType),
2021-01-04 01:49:02 +01:00
_phantom: PhantomData,
2019-04-01 20:51:18 +02:00
};
}
let mut len = None;
2021-02-17 12:55:11 +01:00
if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) {
2019-04-01 20:51:18 +02:00
if let Ok(s) = l.to_str() {
if let Ok(l) = s.parse::<usize>() {
len = Some(l)
}
}
}
JsonBody {
length: len,
err: None,
2021-02-17 12:55:11 +01:00
timeout: std::mem::take(&mut res.timeout),
fut: Some(ReadBody::new(res.take_payload(), 65536)),
2021-01-04 01:49:02 +01:00
_phantom: PhantomData,
2019-04-01 20:51:18 +02:00
}
}
/// Change max size of payload. By default max size is 64kB
2019-04-01 20:51:18 +02:00
pub fn limit(mut self, limit: usize) -> Self {
2019-04-02 19:53:44 +02:00
if let Some(ref mut fut) = self.fut {
fut.limit = limit;
}
2019-04-01 20:51:18 +02:00
self
}
}
// `U` is only held through `PhantomData` and is never pinned, so the future's
// movability depends only on the stream type.
impl<T: Stream<Item = Result<Bytes, PayloadError>> + Unpin, U: DeserializeOwned> Unpin
    for JsonBody<T, U>
{
}
2019-04-01 20:51:18 +02:00
impl<T, U> Future for JsonBody<T, U>
where
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
2019-04-04 19:59:34 +02:00
U: DeserializeOwned,
2019-04-01 20:51:18 +02:00
{
type Output = Result<U, JsonPayloadError>;
2019-04-01 20:51:18 +02:00
2019-12-08 07:31:16 +01:00
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
2019-04-01 20:51:18 +02:00
if let Some(err) = self.err.take() {
return Poll::Ready(Err(err));
2019-04-01 20:51:18 +02:00
}
if let Some(len) = self.length.take() {
2019-04-02 19:53:44 +02:00
if len > self.fut.as_ref().unwrap().limit {
2021-02-12 00:03:17 +01:00
return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow)));
2019-04-01 20:51:18 +02:00
}
}
2021-02-17 12:55:11 +01:00
self.timeout
.poll_timeout(cx)
.map_err(JsonPayloadError::Payload)?;
let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?;
Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
2019-04-02 19:53:44 +02:00
}
}
struct ReadBody<S> {
stream: Payload<S>,
buf: BytesMut,
limit: usize,
}
impl<S> ReadBody<S> {
fn new(stream: Payload<S>, limit: usize) -> Self {
Self {
stream,
buf: BytesMut::new(),
2019-04-02 19:53:44 +02:00
limit,
}
}
}
impl<S> Future for ReadBody<S>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
2019-04-02 19:53:44 +02:00
{
type Output = Result<Bytes, PayloadError>;
2019-12-08 07:31:16 +01:00
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
2019-04-02 19:53:44 +02:00
while let Some(chunk) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) {
if (this.buf.len() + chunk.len()) > this.limit {
return Poll::Ready(Err(PayloadError::Overflow));
}
this.buf.extend_from_slice(&chunk);
2019-04-02 19:53:44 +02:00
}
Poll::Ready(Ok(this.buf.split().freeze()))
2019-04-01 20:51:18 +02:00
}
}
2019-03-27 05:54:57 +01:00
#[cfg(test)]
mod tests {
    use super::*;

    use serde::{Deserialize, Serialize};

    use crate::{http::header, test::TestResponse};

    #[actix_rt::test]
    async fn test_body() {
        // A non-numeric Content-Length is rejected.
        let mut res = TestResponse::with_header((header::CONTENT_LENGTH, "xxxx")).finish();
        let err = res.body().await.err().unwrap();
        assert!(matches!(err, PayloadError::UnknownLength));

        // A declared length above the default limit overflows.
        let mut res = TestResponse::with_header((header::CONTENT_LENGTH, "10000000")).finish();
        let err = res.body().await.err().unwrap();
        assert!(matches!(err, PayloadError::Overflow));

        // A small payload is returned verbatim.
        let mut res = TestResponse::default()
            .set_payload(Bytes::from_static(b"test"))
            .finish();
        assert_eq!(res.body().await.ok().unwrap(), Bytes::from_static(b"test"));

        // Streamed bytes beyond a custom limit overflow.
        let mut res = TestResponse::default()
            .set_payload(Bytes::from_static(b"11111111111111"))
            .finish();
        let err = res.body().limit(5).await.err().unwrap();
        assert!(matches!(err, PayloadError::Overflow));
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    /// Compares only the error variants exercised by these tests.
    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        matches!(
            (err, other),
            (
                JsonPayloadError::Payload(PayloadError::Overflow),
                JsonPayloadError::Payload(PayloadError::Overflow)
            ) | (JsonPayloadError::ContentType, JsonPayloadError::ContentType)
        )
    }

    #[actix_rt::test]
    async fn test_json_body() {
        // No content type at all.
        let mut res = TestResponse::default().finish();
        let json = JsonBody::<_, MyObject>::new(&mut res).await;
        assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));

        // A non-JSON content type.
        let mut res = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .finish();
        let json = JsonBody::<_, MyObject>::new(&mut res).await;
        assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));

        // A declared length above a custom limit.
        let mut res = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .finish();
        let json = JsonBody::<_, MyObject>::new(&mut res).limit(100).await;
        assert!(json_eq(
            json.err().unwrap(),
            JsonPayloadError::Payload(PayloadError::Overflow)
        ));

        // A valid JSON body deserializes.
        let mut res = TestResponse::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .finish();
        let json = JsonBody::<_, MyObject>::new(&mut res).await;
        let expected = MyObject {
            name: "test".to_owned(),
        };
        assert_eq!(json.ok().unwrap(), expected);
    }
}