use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, io, mem, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{Buf, Bytes}; use futures_util::future::{err, Either, FutureExt, LocalBoxFuture, Ready}; use h2::client::SendRequest; use pin_project::pin_project; use crate::body::MessageBody; use crate::h1::ClientCodec; use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; use super::error::SendRequestError; use super::pool::{Acquired, Protocol}; use super::{h1proto, h2proto}; pub(crate) enum ConnectionType { H1(Io), H2(SendRequest), } pub trait Connection { type Io: AsyncRead + AsyncWrite + Unpin; type Future: Future>; fn protocol(&self) -> Protocol; /// Send request and body fn send_request>( self, head: H, body: B, ) -> Self::Future; type TunnelFuture: Future< Output = Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed fn open_tunnel>(self, head: H) -> Self::TunnelFuture; } pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { /// Close connection fn close(&mut self); /// Release connection to the connection pool fn release(&mut self); } #[doc(hidden)] /// HTTP client connection pub struct IoConnection { io: Option>, created: time::Instant, pool: Option>, } impl fmt::Debug for IoConnection where T: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.io { Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), None => write!(f, "Connection(Empty)"), } } } impl IoConnection { pub(crate) fn new( io: ConnectionType, created: time::Instant, pool: Option>, ) -> Self { IoConnection { pool, created, io: Some(io), } } pub(crate) fn into_inner(self) -> (ConnectionType, time::Instant) { (self.io.unwrap(), self.created) } } impl Connection for IoConnection where T: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = T; type Future = LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self.io { Some(ConnectionType::H1(_)) => Protocol::Http1, Some(ConnectionType::H2(_)) => Protocol::Http2, None => Protocol::Http1, } } fn send_request>( mut self, head: H, body: B, ) -> Self::Future { match self.io.take().unwrap() { ConnectionType::H1(io) => { h1proto::send_request(io, head.into(), body, self.created, self.pool) .boxed_local() } ConnectionType::H2(io) => { h2proto::send_request(io, head.into(), body, self.created, self.pool) .boxed_local() } } } type TunnelFuture = Either< LocalBoxFuture< 'static, Result<(ResponseHead, Framed), SendRequestError>, >, Ready), SendRequestError>>, >; /// Send request, returns Response and Framed fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { match self.io.take().unwrap() { ConnectionType::H1(io) => { Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) } ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { pool.release(IoConnection::new( ConnectionType::H2(io), self.created, None, )); } Either::Right(err(SendRequestError::TunnelNotSupported)) } } } } #[allow(dead_code)] pub(crate) enum EitherConnection { A(IoConnection), B(IoConnection), } impl Connection for EitherConnection where A: AsyncRead + AsyncWrite + Unpin + 'static, B: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = EitherIo; type Future = LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self { EitherConnection::A(con) => con.protocol(), EitherConnection::B(con) => con.protocol(), } } fn send_request>( self, head: H, body: RB, ) -> Self::Future { match self { EitherConnection::A(con) => con.send_request(head, body), EitherConnection::B(con) => con.send_request(head, body), } } type TunnelFuture = LocalBoxFuture< 'static, Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed fn open_tunnel>(self, head: H) -> Self::TunnelFuture { match self { EitherConnection::A(con) => con .open_tunnel(head) .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A)))) .boxed_local(), EitherConnection::B(con) => con .open_tunnel(head) .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B)))) .boxed_local(), } } } #[pin_project(project = EitherIoProj)] pub enum EitherIo { A(#[pin] A), B(#[pin] B), } impl AsyncRead for EitherIo where A: AsyncRead, B: AsyncRead, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_read(cx, buf), EitherIoProj::B(val) => val.poll_read(cx, buf), } } unsafe fn prepare_uninitialized_buffer( &self, buf: &mut [mem::MaybeUninit], ) -> bool { match self { EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf), } } } impl AsyncWrite for EitherIo where A: AsyncWrite, B: AsyncWrite, { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_write(cx, buf), EitherIoProj::B(val) => val.poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_flush(cx), EitherIoProj::B(val) => val.poll_flush(cx), } } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_shutdown(cx), EitherIoProj::B(val) => val.poll_shutdown(cx), } } fn poll_write_buf( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut U, ) -> Poll> where Self: Sized, { match self.project() { EitherIoProj::A(val) => val.poll_write_buf(cx, buf), EitherIoProj::B(val) => val.poll_write_buf(cx, buf), } } }