1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-19 06:04:40 +01:00

refactor h2 dispatcher to async/await.reduce duplicate code (#2211)

This commit is contained in:
fakeshadow 2021-05-25 10:21:20 +08:00 committed by GitHub
parent 4598a7c0cc
commit bb7d33c9d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,20 +1,26 @@
use std::task::{Context, Poll}; use std::{
use std::{cmp, future::Future, marker::PhantomData, net, pin::Pin, rc::Rc}; cmp,
future::Future,
marker::PhantomData,
net,
pin::Pin,
rc::Rc,
task::{Context, Poll},
};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::Service; use actix_service::Service;
use actix_utils::future::poll_fn;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_core::ready; use futures_core::ready;
use h2::{ use h2::server::{Connection, SendResponse};
server::{Connection, SendResponse},
SendStream,
};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use log::{error, trace}; use log::{error, trace};
use pin_project_lite::pin_project;
use crate::body::{Body, BodySize, MessageBody}; use crate::body::{BodySize, MessageBody};
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::Error;
use crate::message::ResponseHead; use crate::message::ResponseHead;
use crate::payload::Payload; use crate::payload::Payload;
use crate::request::Request; use crate::request::Request;
@ -24,14 +30,9 @@ use crate::OnConnectData;
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
pin_project! {
/// Dispatcher for HTTP/2 protocol. /// Dispatcher for HTTP/2 protocol.
#[pin_project::pin_project] pub struct Dispatcher<T, S, B, X, U> {
pub struct Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
B: MessageBody,
{
flow: Rc<HttpFlow<S, X, U>>, flow: Rc<HttpFlow<S, X, U>>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
on_connect_data: OnConnectData, on_connect_data: OnConnectData,
@ -39,15 +40,9 @@ where
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
_phantom: PhantomData<B>, _phantom: PhantomData<B>,
} }
}
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Dispatcher<T, S, B, X, U> {
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
B: MessageBody,
{
pub(crate) fn new( pub(crate) fn new(
flow: Rc<HttpFlow<S, X, U>>, flow: Rc<HttpFlow<S, X, U>>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
@ -55,7 +50,7 @@ where
config: ServiceConfig, config: ServiceConfig,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
) -> Self { ) -> Self {
Dispatcher { Self {
flow, flow,
config, config,
peer_addr, peer_addr,
@ -71,26 +66,22 @@ where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody,
B::Error: Into<Error>, B::Error: Into<Error>,
{ {
type Output = Result<(), DispatchError>; type Output = Result<(), crate::error::DispatchError>;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
loop { while let Some((req, tx)) =
match ready!(Pin::new(&mut this.connection).poll_accept(cx)) { ready!(Pin::new(&mut this.connection).poll_accept(cx)?)
None => return Poll::Ready(Ok(())), {
Some(Err(err)) => return Poll::Ready(Err(err.into())),
Some(Ok((req, res))) => {
let (parts, body) = req.into_parts(); let (parts, body) = req.into_parts();
let pl = crate::h2::Payload::new(body); let pl = crate::h2::Payload::new(body);
let pl = Payload::<crate::payload::PayloadStream>::H2(pl); let pl = Payload::<crate::payload::PayloadStream>::H2(pl);
@ -106,50 +97,113 @@ where
// merge on_connect_ext data into request extensions // merge on_connect_ext data into request extensions
this.on_connect_data.merge_into(&mut req); this.on_connect_data.merge_into(&mut req);
let svc = ServiceResponse { let fut = this.flow.service.call(req);
state: ServiceResponseState::ServiceCall( let config = this.config.clone();
this.flow.service.call(req),
Some(res), // multiplex request handling with spawn task
), actix_rt::spawn(async move {
config: this.config.clone(), // resolve service call and send response.
buffer: None, let res = match fut.await {
_phantom: PhantomData, Ok(res) => handle_response(res.into(), tx, config).await,
Err(err) => {
let res = Response::from_error(err.into());
handle_response(res, tx, config).await
}
}; };
actix_rt::spawn(svc); // log error.
if let Err(err) = res {
match err {
DispatchError::SendResponse(err) => {
trace!("Error sending HTTP/2 response: {:?}", err)
}
DispatchError::SendData(err) => warn!("{:?}", err),
DispatchError::ResponseBody(err) => {
error!("Response payload stream error: {:?}", err)
} }
} }
} }
});
}
Poll::Ready(Ok(()))
} }
} }
#[pin_project::pin_project] enum DispatchError {
struct ServiceResponse<F, I, E, B> { SendResponse(h2::Error),
#[pin] SendData(h2::Error),
state: ServiceResponseState<F, B>, ResponseBody(Error),
}
async fn handle_response<B>(
res: Response<B>,
mut tx: SendResponse<Bytes>,
config: ServiceConfig, config: ServiceConfig,
buffer: Option<Bytes>, ) -> Result<(), DispatchError>
_phantom: PhantomData<(I, E)>,
}
#[pin_project::pin_project(project = ServiceResponseStateProj)]
enum ServiceResponseState<F, B> {
ServiceCall(#[pin] F, Option<SendResponse<Bytes>>),
SendPayload(SendStream<Bytes>, #[pin] B),
SendErrorPayload(SendStream<Bytes>, #[pin] Body),
}
impl<F, I, E, B> ServiceResponse<F, I, E, B>
where where
F: Future<Output = Result<I, E>>,
E: Into<Error>,
I: Into<Response<B>>,
B: MessageBody, B: MessageBody,
B::Error: Into<Error>, B::Error: Into<Error>,
{ {
let (res, body) = res.replace_body(());
// prepare response.
let mut size = body.size();
let res = prepare_response(config, res.head(), &mut size);
let eof = size.is_eof();
// send response head and return on eof.
let mut stream = tx
.send_response(res, eof)
.map_err(DispatchError::SendResponse)?;
if eof {
return Ok(());
}
// poll response body and send chunks to client.
actix_rt::pin!(body);
while let Some(res) = poll_fn(|cx| body.as_mut().poll_next(cx)).await {
let mut chunk = res.map_err(|err| DispatchError::ResponseBody(err.into()))?;
'send: loop {
// reserve enough space and wait for stream ready.
stream.reserve_capacity(cmp::min(chunk.len(), CHUNK_SIZE));
match poll_fn(|cx| stream.poll_capacity(cx)).await {
// No capacity left. drop body and return.
None => return Ok(()),
Some(res) => {
// Split chuck to writeable size and send to client.
let cap = res.map_err(DispatchError::SendData)?;
let len = chunk.len();
let bytes = chunk.split_to(cmp::min(cap, len));
stream
.send_data(bytes, false)
.map_err(DispatchError::SendData)?;
// Current chuck completely sent. break send loop and poll next one.
if chunk.is_empty() {
break 'send;
}
}
}
}
}
// response body streaming finished. send end of stream and return.
stream
.send_data(Bytes::new(), true)
.map_err(DispatchError::SendData)?;
Ok(())
}
fn prepare_response( fn prepare_response(
&self, config: ServiceConfig,
head: &ResponseHead, head: &ResponseHead,
size: &mut BodySize, size: &mut BodySize,
) -> http::Response<()> { ) -> http::Response<()> {
@ -205,7 +259,7 @@ where
// set date header // set date header
if !has_date { if !has_date {
let mut bytes = BytesMut::with_capacity(29); let mut bytes = BytesMut::with_capacity(29);
self.config.set_date_header(&mut bytes); config.set_date_header(&mut bytes);
res.headers_mut().insert( res.headers_mut().insert(
DATE, DATE,
// SAFETY: serialized date-times are known ASCII strings // SAFETY: serialized date-times are known ASCII strings
@ -215,187 +269,3 @@ where
res res
} }
}
impl<F, I, E, B> Future for ServiceResponse<F, I, E, B>
where
F: Future<Output = Result<I, E>>,
E: Into<Error>,
I: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Error>,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
match this.state.project() {
ServiceResponseStateProj::ServiceCall(call, send) => {
match ready!(call.poll(cx)) {
Ok(res) => {
let (res, body) = res.into().replace_body(());
let mut send = send.take().unwrap();
let mut size = body.size();
let h2_res =
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = match send.send_response(h2_res, size.is_eof()) {
Err(e) => {
trace!("Error sending HTTP/2 response: {:?}", e);
return Poll::Ready(());
}
Ok(stream) => stream,
};
if size.is_eof() {
Poll::Ready(())
} else {
this.state
.set(ServiceResponseState::SendPayload(stream, body));
self.poll(cx)
}
}
Err(err) => {
let res = Response::from_error(err.into());
let (res, body) = res.replace_body(());
let mut send = send.take().unwrap();
let mut size = body.size();
let h2_res =
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = match send.send_response(h2_res, size.is_eof()) {
Err(e) => {
trace!("Error sending HTTP/2 response: {:?}", e);
return Poll::Ready(());
}
Ok(stream) => stream,
};
if size.is_eof() {
Poll::Ready(())
} else {
this.state.set(ServiceResponseState::SendErrorPayload(
stream, body,
));
self.poll(cx)
}
}
}
}
ServiceResponseStateProj::SendPayload(ref mut stream, ref mut body) => {
loop {
match this.buffer {
Some(ref mut buffer) => match ready!(stream.poll_capacity(cx)) {
None => return Poll::Ready(()),
Some(Ok(cap)) => {
let len = buffer.len();
let bytes = buffer.split_to(cmp::min(cap, len));
if let Err(e) = stream.send_data(bytes, false) {
warn!("{:?}", e);
return Poll::Ready(());
} else if !buffer.is_empty() {
let cap = cmp::min(buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap);
} else {
this.buffer.take();
}
}
Some(Err(e)) => {
warn!("{:?}", e);
return Poll::Ready(());
}
},
None => match ready!(body.as_mut().poll_next(cx)) {
None => {
if let Err(e) = stream.send_data(Bytes::new(), true) {
warn!("{:?}", e);
}
return Poll::Ready(());
}
Some(Ok(chunk)) => {
stream
.reserve_capacity(cmp::min(chunk.len(), CHUNK_SIZE));
*this.buffer = Some(chunk);
}
Some(Err(err)) => {
error!(
"Response payload stream error: {:?}",
err.into()
);
return Poll::Ready(());
}
},
}
}
}
ServiceResponseStateProj::SendErrorPayload(ref mut stream, ref mut body) => {
// TODO: de-dupe impl with SendPayload
loop {
match this.buffer {
Some(ref mut buffer) => match ready!(stream.poll_capacity(cx)) {
None => return Poll::Ready(()),
Some(Ok(cap)) => {
let len = buffer.len();
let bytes = buffer.split_to(cmp::min(cap, len));
if let Err(e) = stream.send_data(bytes, false) {
warn!("{:?}", e);
return Poll::Ready(());
} else if !buffer.is_empty() {
let cap = cmp::min(buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap);
} else {
this.buffer.take();
}
}
Some(Err(e)) => {
warn!("{:?}", e);
return Poll::Ready(());
}
},
None => match ready!(body.as_mut().poll_next(cx)) {
None => {
if let Err(e) = stream.send_data(Bytes::new(), true) {
warn!("{:?}", e);
}
return Poll::Ready(());
}
Some(Ok(chunk)) => {
stream
.reserve_capacity(cmp::min(chunk.len(), CHUNK_SIZE));
*this.buffer = Some(chunk);
}
Some(Err(err)) => {
error!("Response payload stream error: {:?}", err);
return Poll::Ready(());
}
},
}
}
}
}
}
}