mirror of
https://github.com/fafhrd91/actix-web
synced 2025-01-18 22:01:50 +01:00
fix h1 client for handling expect header request (#2049)
This commit is contained in:
parent
78384c3ff5
commit
880b863f95
@ -7,13 +7,15 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
|
|||||||
use bytes::buf::BufMut;
|
use bytes::buf::BufMut;
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use futures_core::Stream;
|
use futures_core::Stream;
|
||||||
use futures_util::future::poll_fn;
|
use futures_util::{future::poll_fn, SinkExt, StreamExt};
|
||||||
use futures_util::{SinkExt, StreamExt};
|
|
||||||
|
|
||||||
use crate::error::PayloadError;
|
use crate::error::PayloadError;
|
||||||
use crate::h1;
|
use crate::h1;
|
||||||
use crate::header::HeaderMap;
|
use crate::header::HeaderMap;
|
||||||
use crate::http::header::{IntoHeaderValue, HOST};
|
use crate::http::{
|
||||||
|
header::{IntoHeaderValue, EXPECT, HOST},
|
||||||
|
StatusCode,
|
||||||
|
};
|
||||||
use crate::message::{RequestHeadType, ResponseHead};
|
use crate::message::{RequestHeadType, ResponseHead};
|
||||||
use crate::payload::{Payload, PayloadStream};
|
use crate::payload::{Payload, PayloadStream};
|
||||||
|
|
||||||
@ -66,33 +68,72 @@ where
|
|||||||
io: Some(io),
|
io: Some(io),
|
||||||
};
|
};
|
||||||
|
|
||||||
// create Framed and send request
|
// create Framed and prepare sending request
|
||||||
let mut framed_inner = Framed::new(io, h1::ClientCodec::default());
|
let mut framed = Framed::new(io, h1::ClientCodec::default());
|
||||||
framed_inner.send((head, body.size()).into()).await?;
|
|
||||||
|
|
||||||
// send request body
|
// Check EXPECT header and enable expect handle flag accordingly.
|
||||||
match body.size() {
|
//
|
||||||
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {}
|
// RFC: https://tools.ietf.org/html/rfc7231#section-5.1.1
|
||||||
_ => send_body(body, Pin::new(&mut framed_inner)).await?,
|
let is_expect = if head.as_ref().headers.contains_key(EXPECT) {
|
||||||
};
|
match body.size() {
|
||||||
|
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
|
||||||
|
let pin_framed = Pin::new(&mut framed);
|
||||||
|
|
||||||
// read response and init read body
|
let force_close = !pin_framed.codec_ref().keepalive();
|
||||||
let res = Pin::new(&mut framed_inner).into_future().await;
|
release_connection(pin_framed, force_close);
|
||||||
let (head, framed) = if let (Some(result), framed) = res {
|
|
||||||
let item = result.map_err(SendRequestError::from)?;
|
// TODO: use a new variant or a new type better describing error violate
|
||||||
(item, framed)
|
// `Requirements for clients` session of above RFC
|
||||||
|
return Err(SendRequestError::Connect(ConnectError::Disconnected));
|
||||||
|
}
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(SendRequestError::from(ConnectError::Disconnected));
|
false
|
||||||
};
|
};
|
||||||
|
|
||||||
match framed.codec_ref().message_type() {
|
framed.send((head, body.size()).into()).await?;
|
||||||
|
|
||||||
|
let mut pin_framed = Pin::new(&mut framed);
|
||||||
|
|
||||||
|
// special handle for EXPECT request.
|
||||||
|
let (do_send, mut res_head) = if is_expect {
|
||||||
|
let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
|
||||||
|
.await
|
||||||
|
.ok_or(ConnectError::Disconnected)??;
|
||||||
|
|
||||||
|
// return response head in case status code is not continue
|
||||||
|
// and current head would be used as final response head.
|
||||||
|
(head.status == StatusCode::CONTINUE, Some(head))
|
||||||
|
} else {
|
||||||
|
(true, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
if do_send {
|
||||||
|
// send request body
|
||||||
|
match body.size() {
|
||||||
|
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {}
|
||||||
|
_ => send_body(body, pin_framed.as_mut()).await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
// read response and init read body
|
||||||
|
let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
|
||||||
|
.await
|
||||||
|
.ok_or(ConnectError::Disconnected)??;
|
||||||
|
|
||||||
|
res_head = Some(head);
|
||||||
|
}
|
||||||
|
|
||||||
|
let head = res_head.unwrap();
|
||||||
|
|
||||||
|
match pin_framed.codec_ref().message_type() {
|
||||||
h1::MessageType::None => {
|
h1::MessageType::None => {
|
||||||
let force_close = !framed.codec_ref().keepalive();
|
let force_close = !pin_framed.codec_ref().keepalive();
|
||||||
release_connection(framed, force_close);
|
release_connection(pin_framed, force_close);
|
||||||
Ok((head, Payload::None))
|
Ok((head, Payload::None))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
let pl: PayloadStream = PlStream::new(framed_inner).boxed_local();
|
let pl: PayloadStream = Box::pin(PlStream::new(framed));
|
||||||
Ok((head, pl.into()))
|
Ok((head, pl.into()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
use actix_http::{http, HttpService, Request, Response};
|
use actix_http::{
|
||||||
|
error, http, http::StatusCode, HttpMessage, HttpService, Request, Response,
|
||||||
|
};
|
||||||
use actix_http_test::test_server;
|
use actix_http_test::test_server;
|
||||||
use actix_service::ServiceFactoryExt;
|
use actix_service::ServiceFactoryExt;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::future::{self, ok};
|
use futures_util::{
|
||||||
|
future::{self, ok},
|
||||||
|
StreamExt,
|
||||||
|
};
|
||||||
|
|
||||||
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
|
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
|
||||||
Hello World Hello World Hello World Hello World Hello World \
|
Hello World Hello World Hello World Hello World Hello World \
|
||||||
@ -88,3 +93,55 @@ async fn test_with_query_parameter() {
|
|||||||
let response = request.send().await.unwrap();
|
let response = request.send().await.unwrap();
|
||||||
assert!(response.status().is_success());
|
assert!(response.status().is_success());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn test_h1_expect() {
|
||||||
|
let srv = test_server(move || {
|
||||||
|
HttpService::build()
|
||||||
|
.expect(|req: Request| async {
|
||||||
|
if req.headers().contains_key("AUTH") {
|
||||||
|
Ok(req)
|
||||||
|
} else {
|
||||||
|
Err(error::ErrorExpectationFailed("expect failed"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.h1(|req: Request| async move {
|
||||||
|
let (_, mut body) = req.into_parts();
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
while let Some(Ok(chunk)) = body.next().await {
|
||||||
|
buf.extend_from_slice(&chunk);
|
||||||
|
}
|
||||||
|
let str = std::str::from_utf8(&buf).unwrap();
|
||||||
|
assert_eq!(str, "expect body");
|
||||||
|
|
||||||
|
Ok::<_, ()>(Response::Ok().finish())
|
||||||
|
})
|
||||||
|
.tcp()
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// test expect without payload.
|
||||||
|
let request = srv
|
||||||
|
.request(http::Method::GET, srv.url("/"))
|
||||||
|
.insert_header(("Expect", "100-continue"));
|
||||||
|
|
||||||
|
let response = request.send().await;
|
||||||
|
assert!(response.is_err());
|
||||||
|
|
||||||
|
// test expect would fail to continue
|
||||||
|
let request = srv
|
||||||
|
.request(http::Method::GET, srv.url("/"))
|
||||||
|
.insert_header(("Expect", "100-continue"));
|
||||||
|
|
||||||
|
let response = request.send_body("expect body").await.unwrap();
|
||||||
|
assert_eq!(response.status(), StatusCode::EXPECTATION_FAILED);
|
||||||
|
|
||||||
|
// test exepct would continue
|
||||||
|
let request = srv
|
||||||
|
.request(http::Method::GET, srv.url("/"))
|
||||||
|
.insert_header(("Expect", "100-continue"))
|
||||||
|
.insert_header(("AUTH", "996"));
|
||||||
|
|
||||||
|
let response = request.send_body("expect body").await.unwrap();
|
||||||
|
assert!(response.status().is_success());
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user