use std::fmt; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; pub use rust_tls::Session; pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::{Service, ServiceFactory}; use futures_util::future::{ok, Ready}; use tokio_rustls::{Connect, TlsConnector}; use webpki::DNSNameRef; use crate::{Address, Connection}; /// Rustls connector factory pub struct RustlsConnector { connector: Arc, _t: PhantomData<(T, U)>, } impl RustlsConnector { pub fn new(connector: Arc) -> Self { RustlsConnector { connector, _t: PhantomData, } } } impl RustlsConnector where T: Address, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, { pub fn service(connector: Arc) -> RustlsConnectorService { RustlsConnectorService { connector, _t: PhantomData, } } } impl Clone for RustlsConnector { fn clone(&self) -> Self { Self { connector: self.connector.clone(), _t: PhantomData, } } } impl ServiceFactory for RustlsConnector where U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, { type Request = Connection; type Response = Connection>; type Error = std::io::Error; type Config = (); type Service = RustlsConnectorService; type InitError = (); type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { ok(RustlsConnectorService { connector: self.connector.clone(), _t: PhantomData, }) } } pub struct RustlsConnectorService { connector: Arc, _t: PhantomData<(T, U)>, } impl Clone for RustlsConnectorService { fn clone(&self) -> Self { Self { connector: self.connector.clone(), _t: PhantomData, } } } impl Service for RustlsConnectorService where U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, { type Request = Connection; type Response = Connection>; type Error = std::io::Error; type Future = ConnectAsyncExt; actix_service::always_ready!(); fn call(&mut self, stream: Connection) -> Self::Future { trace!("SSL Handshake start for: {:?}", stream.host()); let (io, stream) = stream.replace(()); let host = DNSNameRef::try_from_ascii_str(stream.host()) .expect("rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54"); ConnectAsyncExt { fut: TlsConnector::from(self.connector.clone()).connect(host, io), stream: Some(stream), } } } pub struct ConnectAsyncExt { fut: Connect, stream: Option>, } impl Future for ConnectAsyncExt where U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, { type Output = Result>, std::io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); Poll::Ready( futures_util::ready!(Pin::new(&mut this.fut).poll(cx)).map(|stream| { let s = this.stream.take().unwrap(); trace!("SSL Handshake success: {:?}", s.host()); s.replace(stream).1 }), ) } }