use std::convert::Infallible; use std::marker::PhantomData; use std::task::{Context, Poll}; use actix_service::{Service, ServiceFactory}; use futures::future::{self, FutureExt as _, LocalBoxFuture, TryFutureExt as _}; use native_tls::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_tls::{TlsAcceptor, TlsStream}; use crate::counter::Counter; use crate::ssl::MAX_CONN_COUNTER; use crate::{Io, ServerConfig}; /// Support `SSL` connections via native-tls package /// /// `tls` feature enables `NativeTlsAcceptor` type pub struct NativeTlsAcceptor { acceptor: TlsAcceptor, io: PhantomData<(T, P)>, } impl NativeTlsAcceptor where T: AsyncRead + AsyncWrite + Unpin, { /// Create `NativeTlsAcceptor` instance #[inline] pub fn new(acceptor: TlsAcceptor) -> Self { NativeTlsAcceptor { acceptor, io: PhantomData, } } } impl Clone for NativeTlsAcceptor { #[inline] fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), io: PhantomData, } } } impl ServiceFactory for NativeTlsAcceptor where T: AsyncRead + AsyncWrite + Unpin + 'static, P: 'static, { type Request = Io; type Response = Io, P>; type Error = Error; type Config = ServerConfig; type Service = NativeTlsAcceptorService; type InitError = Infallible; type Future = future::Ready>; fn new_service(&self, cfg: &ServerConfig) -> Self::Future { cfg.set_secure(); MAX_CONN_COUNTER.with(|conns| { future::ok(NativeTlsAcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), io: PhantomData, }) }) } } pub struct NativeTlsAcceptorService { acceptor: TlsAcceptor, io: PhantomData<(T, P)>, conns: Counter, } impl Clone for NativeTlsAcceptorService { fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), io: PhantomData, conns: self.conns.clone(), } } } impl Service for NativeTlsAcceptorService where T: AsyncRead + AsyncWrite + Unpin + 'static, P: 'static, { type Request = Io; type Response = Io, P>; type Error = Error; type Future = LocalBoxFuture<'static, Result, P>, Error>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { if self.conns.available(cx) { Poll::Ready(Ok(())) } else { Poll::Pending } } fn call(&mut self, req: Self::Request) -> Self::Future { let guard = self.conns.get(); let this = self.clone(); let (io, params, proto) = req.into_parts(); async move { this.acceptor.accept(io).await } .map_ok(move |stream| Io::from_parts(stream, params, proto)) .map_ok(move |io| { // Required to preserve `CounterGuard` until `Self::Future` // is completely resolved. let _ = guard; io }) .boxed_local() } }