mirror of
https://github.com/fafhrd91/actix-web
synced 2025-06-25 22:49:21 +02:00
reduce duplicate code (#2020)
This commit is contained in:
@ -1,5 +1,8 @@
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_codec::{AsyncRead, AsyncWrite};
|
||||
@ -12,7 +15,7 @@ use actix_utils::timeout::{TimeoutError, TimeoutService};
|
||||
use http::Uri;
|
||||
|
||||
use super::config::ConnectorConfig;
|
||||
use super::connection::Connection;
|
||||
use super::connection::{Connection, EitherIoConnection};
|
||||
use super::error::ConnectError;
|
||||
use super::pool::{ConnectionPool, Protocol};
|
||||
use super::Connect;
|
||||
@ -55,7 +58,7 @@ pub struct Connector<T, U> {
|
||||
_phantom: PhantomData<U>,
|
||||
}
|
||||
|
||||
trait Io: AsyncRead + AsyncWrite + Unpin {}
|
||||
pub trait Io: AsyncRead + AsyncWrite + Unpin {}
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
|
||||
|
||||
impl Connector<(), ()> {
|
||||
@ -244,28 +247,43 @@ where
|
||||
self,
|
||||
) -> impl Service<Connect, Response = impl Connection, Error = ConnectError> + Clone
|
||||
{
|
||||
let tcp_service = TimeoutService::new(
|
||||
self.config.timeout,
|
||||
apply_fn(self.connector.clone(), |msg: Connect, srv| {
|
||||
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
|
||||
})
|
||||
.map_err(ConnectError::from)
|
||||
.map(|stream| (stream.into_parts().0, Protocol::Http1)),
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
TimeoutError::Service(e) => e,
|
||||
TimeoutError::Timeout => ConnectError::Timeout,
|
||||
});
|
||||
|
||||
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
|
||||
{
|
||||
let connector = TimeoutService::new(
|
||||
self.config.timeout,
|
||||
apply_fn(self.connector, |msg: Connect, srv| {
|
||||
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
|
||||
})
|
||||
.map_err(ConnectError::from)
|
||||
.map(|stream| (stream.into_parts().0, Protocol::Http1)),
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
TimeoutError::Service(e) => e,
|
||||
TimeoutError::Timeout => ConnectError::Timeout,
|
||||
});
|
||||
// A dummy service for annotate tls pool's type signature.
|
||||
pub type DummyService = Box<
|
||||
dyn Service<
|
||||
Connect,
|
||||
Response = (Box<dyn Io>, Protocol),
|
||||
Error = ConnectError,
|
||||
Future = futures_core::future::LocalBoxFuture<
|
||||
'static,
|
||||
Result<(Box<dyn Io>, Protocol), ConnectError>,
|
||||
>,
|
||||
>,
|
||||
>;
|
||||
|
||||
connect_impl::InnerConnector {
|
||||
InnerConnector::<_, DummyService, _, Box<dyn Io>> {
|
||||
tcp_pool: ConnectionPool::new(
|
||||
connector,
|
||||
tcp_service,
|
||||
self.config.no_disconnect_timeout(),
|
||||
),
|
||||
tls_pool: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "openssl", feature = "rustls"))]
|
||||
{
|
||||
const H2: &[u8] = b"h2";
|
||||
@ -328,172 +346,97 @@ where
|
||||
TimeoutError::Timeout => ConnectError::Timeout,
|
||||
});
|
||||
|
||||
let tcp_service = TimeoutService::new(
|
||||
self.config.timeout,
|
||||
apply_fn(self.connector, |msg: Connect, srv| {
|
||||
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
|
||||
})
|
||||
.map_err(ConnectError::from)
|
||||
.map(|stream| (stream.into_parts().0, Protocol::Http1)),
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
TimeoutError::Service(e) => e,
|
||||
TimeoutError::Timeout => ConnectError::Timeout,
|
||||
});
|
||||
|
||||
connect_impl::InnerConnector {
|
||||
InnerConnector {
|
||||
tcp_pool: ConnectionPool::new(
|
||||
tcp_service,
|
||||
self.config.no_disconnect_timeout(),
|
||||
),
|
||||
ssl_pool: ConnectionPool::new(ssl_service, self.config),
|
||||
tls_pool: Some(ConnectionPool::new(ssl_service, self.config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
|
||||
mod connect_impl {
|
||||
use std::task::{Context, Poll};
|
||||
struct InnerConnector<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
tcp_pool: ConnectionPool<S1, Io1>,
|
||||
tls_pool: Option<ConnectionPool<S2, Io2>>,
|
||||
}
|
||||
|
||||
use futures_core::future::LocalBoxFuture;
|
||||
|
||||
use super::*;
|
||||
use crate::client::connection::IoConnection;
|
||||
|
||||
pub(crate) struct InnerConnector<T, Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<Connect, Response = (Io, Protocol), Error = ConnectError> + 'static,
|
||||
{
|
||||
pub(crate) tcp_pool: ConnectionPool<T, Io>,
|
||||
}
|
||||
|
||||
impl<T, Io> Clone for InnerConnector<T, Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<Connect, Response = (Io, Protocol), Error = ConnectError> + 'static,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
InnerConnector {
|
||||
tcp_pool: self.tcp_pool.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Io> Service<Connect> for InnerConnector<T, Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<Connect, Response = (Io, Protocol), Error = ConnectError> + 'static,
|
||||
{
|
||||
type Response = IoConnection<Io>;
|
||||
type Error = ConnectError;
|
||||
type Future = LocalBoxFuture<'static, Result<IoConnection<Io>, ConnectError>>;
|
||||
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.tcp_pool.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&self, req: Connect) -> Self::Future {
|
||||
match req.uri.scheme_str() {
|
||||
Some("https") | Some("wss") => {
|
||||
Box::pin(async { Err(ConnectError::SslIsNotSupported) })
|
||||
}
|
||||
_ => self.tcp_pool.call(req),
|
||||
}
|
||||
impl<S1, S2, Io1, Io2> Clone for InnerConnector<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
InnerConnector {
|
||||
tcp_pool: self.tcp_pool.clone(),
|
||||
tls_pool: self.tls_pool.as_ref().cloned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "openssl", feature = "rustls"))]
|
||||
mod connect_impl {
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
impl<S1, S2, Io1, Io2> Service<Connect> for InnerConnector<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
type Response = EitherIoConnection<Io1, Io2>;
|
||||
type Error = ConnectError;
|
||||
type Future = InnerConnectorResponse<S1, S2, Io1, Io2>;
|
||||
|
||||
use super::*;
|
||||
use crate::client::connection::EitherIoConnection;
|
||||
|
||||
pub(crate) struct InnerConnector<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
pub(crate) tcp_pool: ConnectionPool<S1, Io1>,
|
||||
pub(crate) ssl_pool: ConnectionPool<S2, Io2>,
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.tcp_pool.poll_ready(cx)
|
||||
}
|
||||
|
||||
impl<S1, S2, Io1, Io2> Clone for InnerConnector<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
InnerConnector {
|
||||
tcp_pool: self.tcp_pool.clone(),
|
||||
ssl_pool: self.ssl_pool.clone(),
|
||||
}
|
||||
fn call(&self, req: Connect) -> Self::Future {
|
||||
match req.uri.scheme_str() {
|
||||
Some("https") | Some("wss") => match self.tls_pool {
|
||||
None => InnerConnectorResponse::SslIsNotSupported,
|
||||
Some(ref pool) => InnerConnectorResponse::Io2(pool.call(req)),
|
||||
},
|
||||
_ => InnerConnectorResponse::Io1(self.tcp_pool.call(req)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S1, S2, Io1, Io2> Service<Connect> for InnerConnector<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
type Response = EitherIoConnection<Io1, Io2>;
|
||||
type Error = ConnectError;
|
||||
type Future = InnerConnectorResponse<S1, S2, Io1, Io2>;
|
||||
#[pin_project::pin_project(project = InnerConnectorProj)]
|
||||
enum InnerConnectorResponse<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
Io1(#[pin] <ConnectionPool<S1, Io1> as Service<Connect>>::Future),
|
||||
Io2(#[pin] <ConnectionPool<S2, Io2> as Service<Connect>>::Future),
|
||||
SslIsNotSupported,
|
||||
}
|
||||
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.tcp_pool.poll_ready(cx)
|
||||
}
|
||||
impl<S1, S2, Io1, Io2> Future for InnerConnectorResponse<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
type Output = Result<EitherIoConnection<Io1, Io2>, ConnectError>;
|
||||
|
||||
fn call(&self, req: Connect) -> Self::Future {
|
||||
match req.uri.scheme_str() {
|
||||
Some("https") | Some("wss") => {
|
||||
InnerConnectorResponse::Io2(self.ssl_pool.call(req))
|
||||
}
|
||||
_ => InnerConnectorResponse::Io1(self.tcp_pool.call(req)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project::pin_project(project = InnerConnectorProj)]
|
||||
pub(crate) enum InnerConnectorResponse<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
Io1(#[pin] <ConnectionPool<S1, Io1> as Service<Connect>>::Future),
|
||||
Io2(#[pin] <ConnectionPool<S2, Io2> as Service<Connect>>::Future),
|
||||
}
|
||||
|
||||
impl<S1, S2, Io1, Io2> Future for InnerConnectorResponse<S1, S2, Io1, Io2>
|
||||
where
|
||||
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
|
||||
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
|
||||
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
type Output = Result<EitherIoConnection<Io1, Io2>, ConnectError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.project() {
|
||||
InnerConnectorProj::Io1(fut) => {
|
||||
fut.poll(cx).map_ok(EitherIoConnection::A)
|
||||
}
|
||||
InnerConnectorProj::Io2(fut) => {
|
||||
fut.poll(cx).map_ok(EitherIoConnection::B)
|
||||
}
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.project() {
|
||||
InnerConnectorProj::Io1(fut) => fut.poll(cx).map_ok(EitherIoConnection::A),
|
||||
InnerConnectorProj::Io2(fut) => fut.poll(cx).map_ok(EitherIoConnection::B),
|
||||
InnerConnectorProj::SslIsNotSupported => {
|
||||
Poll::Ready(Err(ConnectError::SslIsNotSupported))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user