use std::fmt; use std::marker::PhantomData; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::{NewService, Service}; use futures::{future::ok, future::FutureResult, Async, Future, Poll}; use std::sync::Arc; use tokio_rustls::{client::TlsStream, rustls::ClientConfig, Connect, TlsConnector}; use webpki::DNSNameRef; use crate::{Address, Connection}; /// Rustls connector factory pub struct RustlsConnector { connector: Arc, _t: PhantomData<(T, U)>, } impl RustlsConnector { pub fn new(connector: Arc) -> Self { RustlsConnector { connector, _t: PhantomData, } } } impl RustlsConnector where T: Address, U: AsyncRead + AsyncWrite + fmt::Debug, { pub fn service( connector: Arc, ) -> impl Service< Request = Connection, Response = Connection>, Error = std::io::Error, > { RustlsConnectorService { connector: connector, _t: PhantomData, } } } impl Clone for RustlsConnector { fn clone(&self) -> Self { Self { connector: self.connector.clone(), _t: PhantomData, } } } impl NewService for RustlsConnector where U: AsyncRead + AsyncWrite + fmt::Debug, { type Request = Connection; type Response = Connection>; type Error = std::io::Error; type Config = (); type Service = RustlsConnectorService; type InitError = (); type Future = FutureResult; fn new_service(&self, _: &()) -> Self::Future { ok(RustlsConnectorService { connector: self.connector.clone(), _t: PhantomData, }) } } pub struct RustlsConnectorService { connector: Arc, _t: PhantomData<(T, U)>, } impl Service for RustlsConnectorService where U: AsyncRead + AsyncWrite + fmt::Debug, { type Request = Connection; type Response = Connection>; type Error = std::io::Error; type Future = ConnectAsyncExt; fn poll_ready(&mut self) -> Poll<(), Self::Error> { Ok(Async::Ready(())) } fn call(&mut self, stream: Connection) -> Self::Future { trace!("SSL Handshake start for: {:?}", stream.host()); let (io, stream) = stream.replace(()); let host = DNSNameRef::try_from_ascii_str(stream.host()) .expect("rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54"); ConnectAsyncExt { fut: TlsConnector::from(self.connector.clone()).connect(host, io), stream: Some(stream), } } } pub struct ConnectAsyncExt { fut: Connect, stream: Option>, } impl Future for ConnectAsyncExt where U: AsyncRead + AsyncWrite + fmt::Debug, { type Item = Connection>; type Error = std::io::Error; fn poll(&mut self) -> Poll { match self.fut.poll().map_err(|e| { trace!("SSL Handshake error: {:?}", e); e })? { Async::Ready(stream) => { let s = self.stream.take().unwrap(); trace!("SSL Handshake success: {:?}", s.host()); Ok(Async::Ready(s.replace(stream).1)) } Async::NotReady => Ok(Async::NotReady), } } }