1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-24 00:21:08 +01:00

simplify client::connection::Connection trait (#1998)

This commit is contained in:
fakeshadow 2021-02-16 06:10:22 -08:00 committed by GitHub
parent 3e0a9b99ff
commit 117025a96b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 50 additions and 60 deletions

View File

@ -1,4 +1,3 @@
use std::future::Future;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -8,7 +7,6 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use actix_rt::task::JoinHandle; use actix_rt::task::JoinHandle;
use bytes::Bytes; use bytes::Bytes;
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::future::{err, Either, FutureExt, Ready};
use h2::client::SendRequest; use h2::client::SendRequest;
use pin_project::pin_project; use pin_project::pin_project;
@ -75,7 +73,6 @@ impl DerefMut for H2Connection {
pub trait Connection { pub trait Connection {
type Io: AsyncRead + AsyncWrite + Unpin; type Io: AsyncRead + AsyncWrite + Unpin;
type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol; fn protocol(&self) -> Protocol;
@ -84,14 +81,16 @@ pub trait Connection {
self, self,
head: H, head: H,
body: B, body: B,
) -> Self::Future; ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
type TunnelFuture: Future<
Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture; fn open_tunnel<H: Into<RequestHeadType> + 'static>(
self,
head: H,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>;
} }
pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static {
@ -154,8 +153,6 @@ where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Io = T; type Io = T;
type Future =
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self.io { match self.io {
@ -169,33 +166,35 @@ where
mut self, mut self,
head: H, head: H,
body: B, body: B,
) -> Self::Future { ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => { ConnectionType::H1(io) => Box::pin(h1proto::send_request(
h1proto::send_request(io, head.into(), body, self.created, self.pool) io,
.boxed_local() head.into(),
} body,
ConnectionType::H2(io) => { self.created,
h2proto::send_request(io, head.into(), body, self.created, self.pool) self.pool,
.boxed_local() )),
} ConnectionType::H2(io) => Box::pin(h2proto::send_request(
io,
head.into(),
body,
self.created,
self.pool,
)),
} }
} }
type TunnelFuture = Either<
LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>,
Ready<Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>>,
>;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture { fn open_tunnel<H: Into<RequestHeadType>>(
mut self,
head: H,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
> {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => { ConnectionType::H1(io) => Box::pin(h1proto::open_tunnel(io, head.into())),
Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local())
}
ConnectionType::H2(io) => { ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
pool.release(IoConnection::new( pool.release(IoConnection::new(
@ -204,7 +203,7 @@ where
None, None,
)); ));
} }
Either::Right(err(SendRequestError::TunnelNotSupported)) Box::pin(async { Err(SendRequestError::TunnelNotSupported) })
} }
} }
} }
@ -226,8 +225,6 @@ where
B: AsyncRead + AsyncWrite + Unpin + 'static, B: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Io = EitherIo<A, B>; type Io = EitherIo<A, B>;
type Future =
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self { match self {
@ -240,33 +237,30 @@ where
self, self,
head: H, head: H,
body: RB, body: RB,
) -> Self::Future { ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> {
match self { match self {
EitherConnection::A(con) => con.send_request(head, body), EitherConnection::A(con) => con.send_request(head, body),
EitherConnection::B(con) => con.send_request(head, body), EitherConnection::B(con) => con.send_request(head, body),
} }
} }
type TunnelFuture = LocalBoxFuture< /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType> + 'static>(
self,
head: H,
) -> LocalBoxFuture<
'static, 'static,
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>; > {
/// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture {
match self { match self {
EitherConnection::A(con) => con EitherConnection::A(con) => Box::pin(async {
.open_tunnel(head) let (head, framed) = con.open_tunnel(head).await?;
.map(|res| { Ok((head, framed.into_map_io(EitherIo::A)))
res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::A))) }),
}) EitherConnection::B(con) => Box::pin(async {
.boxed_local(), let (head, framed) = con.open_tunnel(head).await?;
EitherConnection::B(con) => con Ok((head, framed.into_map_io(EitherIo::B)))
.open_tunnel(head) }),
.map(|res| {
res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::B)))
})
.boxed_local(),
} }
} }
} }

View File

@ -8,7 +8,7 @@ use bytes::buf::BufMut;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_core::Stream; use futures_core::Stream;
use futures_util::future::poll_fn; use futures_util::future::poll_fn;
use futures_util::{pin_mut, SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::h1; use crate::h1;
@ -127,7 +127,7 @@ where
T: ConnectionLifetime + Unpin, T: ConnectionLifetime + Unpin,
B: MessageBody, B: MessageBody,
{ {
pin_mut!(body); actix_rt::pin!(body);
let mut eof = false; let mut eof = false;
while !eof { while !eof {

View File

@ -5,7 +5,6 @@ use std::time;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use bytes::Bytes; use bytes::Bytes;
use futures_util::future::poll_fn; use futures_util::future::poll_fn;
use futures_util::pin_mut;
use h2::{ use h2::{
client::{Builder, Connection, SendRequest}, client::{Builder, Connection, SendRequest},
SendStream, SendStream,
@ -131,7 +130,7 @@ async fn send_body<B: MessageBody>(
mut send: SendStream<Bytes>, mut send: SendStream<Bytes>,
) -> Result<(), SendRequestError> { ) -> Result<(), SendRequestError> {
let mut buf = None; let mut buf = None;
pin_mut!(body); actix_rt::pin!(body);
loop { loop {
if buf.is_none() { if buf.is_none() {
match poll_fn(|cx| body.as_mut().poll_next(cx)).await { match poll_fn(|cx| body.as_mut().poll_next(cx)).await {

View File

@ -52,7 +52,6 @@ impl ClientBuilder {
where where
T: Service<HttpConnect, Error = ConnectError> + 'static, T: Service<HttpConnect, Error = ConnectError> + 'static,
T::Response: Connection, T::Response: Connection,
<T::Response as Connection>::Future: 'static,
T::Future: 'static, T::Future: 'static,
{ {
self.connector = Some(Box::new(ConnectorWrapper(connector))); self.connector = Some(Box::new(ConnectorWrapper(connector)));

View File

@ -41,8 +41,6 @@ where
T: Service<ClientConnect, Error = ConnectError>, T: Service<ClientConnect, Error = ConnectError>,
T::Response: Connection, T::Response: Connection,
<T::Response as Connection>::Io: 'static, <T::Response as Connection>::Io: 'static,
<T::Response as Connection>::Future: 'static,
<T::Response as Connection>::TunnelFuture: 'static,
T::Future: 'static, T::Future: 'static,
{ {
fn send_request( fn send_request(