use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, io}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; use futures::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; use open_ssl::ssl::SslConnector; use tokio_openssl::{HandshakeError, SslStream}; use trust_dns_resolver::AsyncResolver; use crate::{ Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection, }; /// Openssl connector factory pub struct OpensslConnector { connector: SslConnector, _t: PhantomData<(T, U)>, } impl OpensslConnector { pub fn new(connector: SslConnector) -> Self { OpensslConnector { connector, _t: PhantomData, } } } impl OpensslConnector where T: Address + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { pub fn service(connector: SslConnector) -> OpensslConnectorService { OpensslConnectorService { connector, _t: PhantomData, } } } impl Clone for OpensslConnector { fn clone(&self) -> Self { Self { connector: self.connector.clone(), _t: PhantomData, } } } impl ServiceFactory for OpensslConnector where T: Address + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { type Request = Connection; type Response = Connection>; type Error = io::Error; type Config = (); type Service = OpensslConnectorService; type InitError = (); type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { ok(OpensslConnectorService { connector: self.connector.clone(), _t: PhantomData, }) } } pub struct OpensslConnectorService { connector: SslConnector, _t: PhantomData<(T, U)>, } impl Clone for OpensslConnectorService { fn clone(&self) -> Self { Self { connector: self.connector.clone(), _t: PhantomData, } } } impl Service for OpensslConnectorService where T: Address + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { type Request = Connection; type Response = Connection>; type Error = io::Error; type Future = Either, Ready>>; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, stream: Connection) -> Self::Future { trace!("SSL Handshake start for: {:?}", stream.host()); let (io, stream) = stream.replace(()); let host = stream.host().to_string(); match self.connector.configure() { Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))), Ok(config) => Either::Left(ConnectAsyncExt { fut: async move { tokio_openssl::connect(config, &host, io).await } .boxed_local(), stream: Some(stream), _t: PhantomData, }), } } } pub struct ConnectAsyncExt { fut: LocalBoxFuture<'static, Result, HandshakeError>>, stream: Option>, _t: PhantomData, } impl Future for ConnectAsyncExt where U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { type Output = Result>, io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match Pin::new(&mut this.fut).poll(cx) { Poll::Ready(Ok(stream)) => { let s = this.stream.take().unwrap(); trace!("SSL Handshake success: {:?}", s.host()); Poll::Ready(Ok(s.replace(stream).1)) } Poll::Ready(Err(e)) => { trace!("SSL Handshake error: {:?}", e); Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) } Poll::Pending => Poll::Pending, } } } pub struct OpensslConnectServiceFactory { tcp: ConnectServiceFactory, openssl: OpensslConnector, } impl OpensslConnectServiceFactory { /// Construct new OpensslConnectService factory pub fn new(connector: SslConnector) -> Self { OpensslConnectServiceFactory { tcp: ConnectServiceFactory::default(), openssl: OpensslConnector::new(connector), } } /// Construct new connect service with custom dns resolver pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self { OpensslConnectServiceFactory { tcp: ConnectServiceFactory::with_resolver(resolver), openssl: OpensslConnector::new(connector), } } /// Construct openssl connect service pub fn service(&self) -> OpensslConnectService { OpensslConnectService { tcp: self.tcp.service(), openssl: OpensslConnectorService { connector: self.openssl.connector.clone(), _t: PhantomData, }, } } } impl Clone for OpensslConnectServiceFactory { fn clone(&self) -> Self { OpensslConnectServiceFactory { tcp: self.tcp.clone(), openssl: self.openssl.clone(), } } } impl ServiceFactory for OpensslConnectServiceFactory { type Request = Connect; type Response = SslStream; type Error = ConnectError; type Config = (); type Service = OpensslConnectService; type InitError = (); type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { ok(self.service()) } } #[derive(Clone)] pub struct OpensslConnectService { tcp: ConnectService, openssl: OpensslConnectorService, } impl Service for OpensslConnectService { type Request = Connect; type Response = SslStream; type Error = ConnectError; type Future = OpensslConnectServiceResponse; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Connect) -> Self::Future { OpensslConnectServiceResponse { fut1: Some(self.tcp.call(req)), fut2: None, openssl: self.openssl.clone(), } } } pub struct OpensslConnectServiceResponse { fut1: Option< as Service>::Future>, fut2: Option< as Service>::Future>, openssl: OpensslConnectorService, } impl Future for OpensslConnectServiceResponse { type Output = Result, ConnectError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(ref mut fut) = self.fut1 { match futures::ready!(Pin::new(fut).poll(cx)) { Ok(res) => { let _ = self.fut1.take(); self.fut2 = Some(self.openssl.call(res)); } Err(e) => return Poll::Ready(Err(e)), } } if let Some(ref mut fut) = self.fut2 { match futures::ready!(Pin::new(fut).poll(cx)) { Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)), Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new( io::ErrorKind::Other, e, )))), } } else { Poll::Pending } } }