From bb7d33c9d41acef3b8d57dafc167c4077875374c Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 25 May 2021 10:21:20 +0800 Subject: [PATCH] refactor h2 dispatcher to async/await.reduce duplicate code (#2211) --- actix-http/src/h2/dispatcher.rs | 520 ++++++++++++-------------------- 1 file changed, 195 insertions(+), 325 deletions(-) diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index 5be172aaf..baff20e51 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -1,20 +1,26 @@ -use std::task::{Context, Poll}; -use std::{cmp, future::Future, marker::PhantomData, net, pin::Pin, rc::Rc}; +use std::{ + cmp, + future::Future, + marker::PhantomData, + net, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::Service; +use actix_utils::future::poll_fn; use bytes::{Bytes, BytesMut}; use futures_core::ready; -use h2::{ - server::{Connection, SendResponse}, - SendStream, -}; +use h2::server::{Connection, SendResponse}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use log::{error, trace}; +use pin_project_lite::pin_project; -use crate::body::{Body, BodySize, MessageBody}; +use crate::body::{BodySize, MessageBody}; use crate::config::ServiceConfig; -use crate::error::{DispatchError, Error}; +use crate::error::Error; use crate::message::ResponseHead; use crate::payload::Payload; use crate::request::Request; @@ -24,30 +30,19 @@ use crate::OnConnectData; const CHUNK_SIZE: usize = 16_384; -/// Dispatcher for HTTP/2 protocol. -#[pin_project::pin_project] -pub struct Dispatcher -where - T: AsyncRead + AsyncWrite + Unpin, - S: Service, - B: MessageBody, -{ - flow: Rc>, - connection: Connection, - on_connect_data: OnConnectData, - config: ServiceConfig, - peer_addr: Option, - _phantom: PhantomData, +pin_project! { + /// Dispatcher for HTTP/2 protocol. + pub struct Dispatcher { + flow: Rc>, + connection: Connection, + on_connect_data: OnConnectData, + config: ServiceConfig, + peer_addr: Option, + _phantom: PhantomData, + } } -impl Dispatcher -where - T: AsyncRead + AsyncWrite + Unpin, - S: Service, - S::Error: Into, - S::Response: Into>, - B: MessageBody, -{ +impl Dispatcher { pub(crate) fn new( flow: Rc>, connection: Connection, @@ -55,7 +50,7 @@ where config: ServiceConfig, peer_addr: Option, ) -> Self { - Dispatcher { + Self { flow, config, peer_addr, @@ -71,331 +66,206 @@ where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into + 'static, + S::Error: Into, S::Future: 'static, - S::Response: Into> + 'static, + S::Response: Into>, - B: MessageBody + 'static, + B: MessageBody, B::Error: Into, { - type Output = Result<(), DispatchError>; + type Output = Result<(), crate::error::DispatchError>; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - loop { - match ready!(Pin::new(&mut this.connection).poll_accept(cx)) { - None => return Poll::Ready(Ok(())), + while let Some((req, tx)) = + ready!(Pin::new(&mut this.connection).poll_accept(cx)?) + { + let (parts, body) = req.into_parts(); + let pl = crate::h2::Payload::new(body); + let pl = Payload::::H2(pl); + let mut req = Request::with_payload(pl); - Some(Err(err)) => return Poll::Ready(Err(err.into())), + let head = req.head_mut(); + head.uri = parts.uri; + head.method = parts.method; + head.version = parts.version; + head.headers = parts.headers.into(); + head.peer_addr = this.peer_addr; - Some(Ok((req, res))) => { - let (parts, body) = req.into_parts(); - let pl = crate::h2::Payload::new(body); - let pl = Payload::::H2(pl); - let mut req = Request::with_payload(pl); + // merge on_connect_ext data into request extensions + this.on_connect_data.merge_into(&mut req); - let head = req.head_mut(); - head.uri = parts.uri; - head.method = parts.method; - head.version = parts.version; - head.headers = parts.headers.into(); - head.peer_addr = this.peer_addr; + let fut = this.flow.service.call(req); + let config = this.config.clone(); - // merge on_connect_ext data into request extensions - this.on_connect_data.merge_into(&mut req); + // multiplex request handling with spawn task + actix_rt::spawn(async move { + // resolve service call and send response. + let res = match fut.await { + Ok(res) => handle_response(res.into(), tx, config).await, + Err(err) => { + let res = Response::from_error(err.into()); + handle_response(res, tx, config).await + } + }; - let svc = ServiceResponse { - state: ServiceResponseState::ServiceCall( - this.flow.service.call(req), - Some(res), - ), - config: this.config.clone(), - buffer: None, - _phantom: PhantomData, - }; + // log error. + if let Err(err) = res { + match err { + DispatchError::SendResponse(err) => { + trace!("Error sending HTTP/2 response: {:?}", err) + } + DispatchError::SendData(err) => warn!("{:?}", err), + DispatchError::ResponseBody(err) => { + error!("Response payload stream error: {:?}", err) + } + } + } + }); + } - actix_rt::spawn(svc); + Poll::Ready(Ok(())) + } +} + +enum DispatchError { + SendResponse(h2::Error), + SendData(h2::Error), + ResponseBody(Error), +} + +async fn handle_response( + res: Response, + mut tx: SendResponse, + config: ServiceConfig, +) -> Result<(), DispatchError> +where + B: MessageBody, + B::Error: Into, +{ + let (res, body) = res.replace_body(()); + + // prepare response. + let mut size = body.size(); + let res = prepare_response(config, res.head(), &mut size); + let eof = size.is_eof(); + + // send response head and return on eof. + let mut stream = tx + .send_response(res, eof) + .map_err(DispatchError::SendResponse)?; + + if eof { + return Ok(()); + } + + // poll response body and send chunks to client. + actix_rt::pin!(body); + + while let Some(res) = poll_fn(|cx| body.as_mut().poll_next(cx)).await { + let mut chunk = res.map_err(|err| DispatchError::ResponseBody(err.into()))?; + + 'send: loop { + // reserve enough space and wait for stream ready. + stream.reserve_capacity(cmp::min(chunk.len(), CHUNK_SIZE)); + + match poll_fn(|cx| stream.poll_capacity(cx)).await { + // No capacity left. drop body and return. + None => return Ok(()), + Some(res) => { + // Split chuck to writeable size and send to client. + let cap = res.map_err(DispatchError::SendData)?; + + let len = chunk.len(); + let bytes = chunk.split_to(cmp::min(cap, len)); + + stream + .send_data(bytes, false) + .map_err(DispatchError::SendData)?; + + // Current chuck completely sent. break send loop and poll next one. + if chunk.is_empty() { + break 'send; + } } } } } + + // response body streaming finished. send end of stream and return. + stream + .send_data(Bytes::new(), true) + .map_err(DispatchError::SendData)?; + + Ok(()) } -#[pin_project::pin_project] -struct ServiceResponse { - #[pin] - state: ServiceResponseState, +fn prepare_response( config: ServiceConfig, - buffer: Option, - _phantom: PhantomData<(I, E)>, -} + head: &ResponseHead, + size: &mut BodySize, +) -> http::Response<()> { + let mut has_date = false; + let mut skip_len = size != &BodySize::Stream; -#[pin_project::pin_project(project = ServiceResponseStateProj)] -enum ServiceResponseState { - ServiceCall(#[pin] F, Option>), - SendPayload(SendStream, #[pin] B), - SendErrorPayload(SendStream, #[pin] Body), -} + let mut res = http::Response::new(()); + *res.status_mut() = head.status; + *res.version_mut() = http::Version::HTTP_2; -impl ServiceResponse -where - F: Future>, - E: Into, - I: Into>, + // Content length + match head.status { + http::StatusCode::NO_CONTENT + | http::StatusCode::CONTINUE + | http::StatusCode::PROCESSING => *size = BodySize::None, + http::StatusCode::SWITCHING_PROTOCOLS => { + skip_len = true; + *size = BodySize::Stream; + } + _ => {} + } - B: MessageBody, - B::Error: Into, -{ - fn prepare_response( - &self, - head: &ResponseHead, - size: &mut BodySize, - ) -> http::Response<()> { - let mut has_date = false; - let mut skip_len = size != &BodySize::Stream; + let _ = match size { + BodySize::None | BodySize::Stream => None, + BodySize::Empty => res + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => { + let mut buf = itoa::Buffer::new(); - let mut res = http::Response::new(()); - *res.status_mut() = head.status; - *res.version_mut() = http::Version::HTTP_2; + res.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::from_str(buf.format(*len)).unwrap(), + ) + } + }; - // Content length - match head.status { - http::StatusCode::NO_CONTENT - | http::StatusCode::CONTINUE - | http::StatusCode::PROCESSING => *size = BodySize::None, - http::StatusCode::SWITCHING_PROTOCOLS => { - skip_len = true; - *size = BodySize::Stream; - } + // copy headers + for (key, value) in head.headers.iter() { + match *key { + // TODO: consider skipping other headers according to: + // https://tools.ietf.org/html/rfc7540#section-8.1.2.2 + // omit HTTP/1.x only headers + CONNECTION | TRANSFER_ENCODING => continue, + CONTENT_LENGTH if skip_len => continue, + DATE => has_date = true, _ => {} } - let _ = match size { - BodySize::None | BodySize::Stream => None, - BodySize::Empty => res - .headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), - BodySize::Sized(len) => { - let mut buf = itoa::Buffer::new(); - - res.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::from_str(buf.format(*len)).unwrap(), - ) - } - }; - - // copy headers - for (key, value) in head.headers.iter() { - match *key { - // TODO: consider skipping other headers according to: - // https://tools.ietf.org/html/rfc7540#section-8.1.2.2 - // omit HTTP/1.x only headers - CONNECTION | TRANSFER_ENCODING => continue, - CONTENT_LENGTH if skip_len => continue, - DATE => has_date = true, - _ => {} - } - - res.headers_mut().append(key, value.clone()); - } - - // set date header - if !has_date { - let mut bytes = BytesMut::with_capacity(29); - self.config.set_date_header(&mut bytes); - res.headers_mut().insert( - DATE, - // SAFETY: serialized date-times are known ASCII strings - unsafe { HeaderValue::from_maybe_shared_unchecked(bytes.freeze()) }, - ); - } - - res + res.headers_mut().append(key, value.clone()); } -} -impl Future for ServiceResponse -where - F: Future>, - E: Into, - I: Into>, - - B: MessageBody, - B::Error: Into, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().project(); - - match this.state.project() { - ServiceResponseStateProj::ServiceCall(call, send) => { - match ready!(call.poll(cx)) { - Ok(res) => { - let (res, body) = res.into().replace_body(()); - - let mut send = send.take().unwrap(); - let mut size = body.size(); - let h2_res = - self.as_mut().prepare_response(res.head(), &mut size); - this = self.as_mut().project(); - - let stream = match send.send_response(h2_res, size.is_eof()) { - Err(e) => { - trace!("Error sending HTTP/2 response: {:?}", e); - return Poll::Ready(()); - } - Ok(stream) => stream, - }; - - if size.is_eof() { - Poll::Ready(()) - } else { - this.state - .set(ServiceResponseState::SendPayload(stream, body)); - self.poll(cx) - } - } - - Err(err) => { - let res = Response::from_error(err.into()); - let (res, body) = res.replace_body(()); - - let mut send = send.take().unwrap(); - let mut size = body.size(); - let h2_res = - self.as_mut().prepare_response(res.head(), &mut size); - this = self.as_mut().project(); - - let stream = match send.send_response(h2_res, size.is_eof()) { - Err(e) => { - trace!("Error sending HTTP/2 response: {:?}", e); - return Poll::Ready(()); - } - Ok(stream) => stream, - }; - - if size.is_eof() { - Poll::Ready(()) - } else { - this.state.set(ServiceResponseState::SendErrorPayload( - stream, body, - )); - self.poll(cx) - } - } - } - } - - ServiceResponseStateProj::SendPayload(ref mut stream, ref mut body) => { - loop { - match this.buffer { - Some(ref mut buffer) => match ready!(stream.poll_capacity(cx)) { - None => return Poll::Ready(()), - - Some(Ok(cap)) => { - let len = buffer.len(); - let bytes = buffer.split_to(cmp::min(cap, len)); - - if let Err(e) = stream.send_data(bytes, false) { - warn!("{:?}", e); - return Poll::Ready(()); - } else if !buffer.is_empty() { - let cap = cmp::min(buffer.len(), CHUNK_SIZE); - stream.reserve_capacity(cap); - } else { - this.buffer.take(); - } - } - - Some(Err(e)) => { - warn!("{:?}", e); - return Poll::Ready(()); - } - }, - - None => match ready!(body.as_mut().poll_next(cx)) { - None => { - if let Err(e) = stream.send_data(Bytes::new(), true) { - warn!("{:?}", e); - } - return Poll::Ready(()); - } - - Some(Ok(chunk)) => { - stream - .reserve_capacity(cmp::min(chunk.len(), CHUNK_SIZE)); - *this.buffer = Some(chunk); - } - - Some(Err(err)) => { - error!( - "Response payload stream error: {:?}", - err.into() - ); - - return Poll::Ready(()); - } - }, - } - } - } - - ServiceResponseStateProj::SendErrorPayload(ref mut stream, ref mut body) => { - // TODO: de-dupe impl with SendPayload - - loop { - match this.buffer { - Some(ref mut buffer) => match ready!(stream.poll_capacity(cx)) { - None => return Poll::Ready(()), - - Some(Ok(cap)) => { - let len = buffer.len(); - let bytes = buffer.split_to(cmp::min(cap, len)); - - if let Err(e) = stream.send_data(bytes, false) { - warn!("{:?}", e); - return Poll::Ready(()); - } else if !buffer.is_empty() { - let cap = cmp::min(buffer.len(), CHUNK_SIZE); - stream.reserve_capacity(cap); - } else { - this.buffer.take(); - } - } - - Some(Err(e)) => { - warn!("{:?}", e); - return Poll::Ready(()); - } - }, - - None => match ready!(body.as_mut().poll_next(cx)) { - None => { - if let Err(e) = stream.send_data(Bytes::new(), true) { - warn!("{:?}", e); - } - return Poll::Ready(()); - } - - Some(Ok(chunk)) => { - stream - .reserve_capacity(cmp::min(chunk.len(), CHUNK_SIZE)); - *this.buffer = Some(chunk); - } - - Some(Err(err)) => { - error!("Response payload stream error: {:?}", err); - - return Poll::Ready(()); - } - }, - } - } - } - } + // set date header + if !has_date { + let mut bytes = BytesMut::with_capacity(29); + config.set_date_header(&mut bytes); + res.headers_mut().insert( + DATE, + // SAFETY: serialized date-times are known ASCII strings + unsafe { HeaderValue::from_maybe_shared_unchecked(bytes.freeze()) }, + ); } + + res }