use std::io; use std::marker::PhantomData; use actix_service::{NewService, Service}; use futures::{future::ok, future::FutureResult, Async, Future, Poll}; use native_tls::{self, Error, HandshakeError, TlsAcceptor}; use tokio_io::{AsyncRead, AsyncWrite}; use crate::counter::{Counter, CounterGuard}; use crate::ssl::MAX_CONN_COUNTER; /// Support `SSL` connections via native-tls package /// /// `tls` feature enables `NativeTlsAcceptor` type pub struct NativeTlsAcceptor { acceptor: TlsAcceptor, io: PhantomData, } impl NativeTlsAcceptor { /// Create `NativeTlsAcceptor` instance pub fn new(acceptor: TlsAcceptor) -> Self { NativeTlsAcceptor { acceptor, io: PhantomData, } } } impl Clone for NativeTlsAcceptor { fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), io: PhantomData, } } } impl NewService for NativeTlsAcceptor { type Response = TlsStream; type Error = Error; type Service = NativeTlsAcceptorService; type InitError = (); type Future = FutureResult; fn new_service(&self, _: &()) -> Self::Future { MAX_CONN_COUNTER.with(|conns| { ok(NativeTlsAcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), io: PhantomData, }) }) } } pub struct NativeTlsAcceptorService { acceptor: TlsAcceptor, io: PhantomData, conns: Counter, } impl Service for NativeTlsAcceptorService { type Response = TlsStream; type Error = Error; type Future = Accept; fn poll_ready(&mut self) -> Poll<(), Self::Error> { if self.conns.available() { Ok(Async::Ready(())) } else { Ok(Async::NotReady) } } fn call(&mut self, req: T) -> Self::Future { Accept { _guard: self.conns.get(), inner: Some(self.acceptor.accept(req)), } } } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. /// /// A `TlsStream` represents a handshake that has been completed successfully /// and both the server and the client are ready for receiving and sending /// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written /// to a `TlsStream` are encrypted when passing through to `S`. #[derive(Debug)] pub struct TlsStream { inner: native_tls::TlsStream, } /// Future returned from `NativeTlsAcceptor::accept` which will resolve /// once the accept handshake has finished. pub struct Accept { inner: Option, HandshakeError>>, _guard: CounterGuard, } impl Future for Accept { type Item = TlsStream; type Error = Error; fn poll(&mut self) -> Poll { match self.inner.take().expect("cannot poll MidHandshake twice") { Ok(stream) => Ok(TlsStream { inner: stream }.into()), Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::WouldBlock(s)) => match s.handshake() { Ok(stream) => Ok(TlsStream { inner: stream }.into()), Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::WouldBlock(s)) => { self.inner = Some(Err(HandshakeError::WouldBlock(s))); Ok(Async::NotReady) } }, } } } impl TlsStream { /// Get access to the internal `native_tls::TlsStream` stream which also /// transitively allows access to `S`. pub fn get_ref(&self) -> &native_tls::TlsStream { &self.inner } /// Get mutable access to the internal `native_tls::TlsStream` stream which /// also transitively allows mutable access to `S`. pub fn get_mut(&mut self) -> &mut native_tls::TlsStream { &mut self.inner } } impl io::Read for TlsStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.inner.read(buf) } } impl io::Write for TlsStream { fn write(&mut self, buf: &[u8]) -> io::Result { self.inner.write(buf) } fn flush(&mut self) -> io::Result<()> { self.inner.flush() } } impl AsyncRead for TlsStream {} impl AsyncWrite for TlsStream { fn shutdown(&mut self) -> Poll<(), io::Error> { match self.inner.shutdown() { Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) => return Err(e), } self.inner.get_mut().shutdown() } }