diff --git a/actix-tls/Cargo.toml b/actix-tls/Cargo.toml index c3ed89b4..baf1615a 100755 --- a/actix-tls/Cargo.toml +++ b/actix-tls/Cargo.toml @@ -115,6 +115,7 @@ actix-server = "2" bytes = "1" env_logger = "0.10" futures-util = { version = "0.3.17", default-features = false, features = ["sink"] } +itertools = "0.12" rcgen = "0.11" rustls-pemfile = "2" tokio-rustls-025 = { package = "tokio-rustls", version = "0.25" } @@ -122,4 +123,4 @@ trust-dns-resolver = "0.23" [[example]] name = "accept-rustls" -required-features = ["accept", "rustls-0_22"] +required-features = ["accept", "rustls-0_22-webpki-roots"] diff --git a/actix-tls/examples/accept-rustls.rs b/actix-tls/examples/accept-rustls.rs index 40f51753..6e1e267a 100644 --- a/actix-tls/examples/accept-rustls.rs +++ b/actix-tls/examples/accept-rustls.rs @@ -15,11 +15,8 @@ //! http --verify=false https://127.0.0.1:8443 //! ``` -#[rustfmt::skip] // this `use` is only exists because of how we have organised the crate // it is not necessary for your actual code; you should import from `rustls` normally -use tokio_rustls_024::rustls; - use std::{ fs::File, io::{self, BufReader}, @@ -33,10 +30,13 @@ use std::{ use actix_rt::net::TcpStream; use actix_server::Server; use actix_service::ServiceFactoryExt as _; -use actix_tls::accept::rustls_0_21::{Acceptor as RustlsAcceptor, TlsStream}; +use actix_tls::accept::rustls_0_22::{Acceptor as RustlsAcceptor, TlsStream}; use futures_util::future::ok; -use rustls::{server::ServerConfig, Certificate, PrivateKey}; +use itertools::Itertools as _; +use rustls::server::ServerConfig; use rustls_pemfile::{certs, rsa_private_keys}; +use rustls_pki_types_1::PrivateKeyDer; +use tokio_rustls_025::rustls; use tracing::info; #[actix_rt::main] @@ -54,17 +54,15 @@ async fn main() -> io::Result<()> { let cert_file = &mut BufReader::new(File::open(cert_path).unwrap()); let key_file = &mut BufReader::new(File::open(key_path).unwrap()); - let cert_chain = certs(cert_file) - .unwrap() - .into_iter() - .map(Certificate) - .collect(); - let mut keys = rsa_private_keys(key_file).unwrap(); + let cert_chain = certs(cert_file); + let mut keys = rsa_private_keys(key_file); let tls_config = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .with_single_cert( + cert_chain.try_collect::<_, Vec<_>, _>()?, + PrivateKeyDer::Pkcs1(keys.next().unwrap()?), + ) .unwrap(); let tls_acceptor = RustlsAcceptor::new(tls_config); diff --git a/actix-tls/src/accept/mod.rs b/actix-tls/src/accept/mod.rs index 18585970..8d89f006 100644 --- a/actix-tls/src/accept/mod.rs +++ b/actix-tls/src/accept/mod.rs @@ -22,6 +22,12 @@ pub use rustls_0_20 as rustls; #[cfg(feature = "rustls-0_21")] pub mod rustls_0_21; +#[cfg(any( + feature = "rustls-0_22-webpki-roots", + feature = "rustls-0_22-native-roots", +))] +pub mod rustls_0_22; + #[cfg(feature = "native-tls")] pub mod native_tls; diff --git a/actix-tls/src/accept/rustls_0_22.rs b/actix-tls/src/accept/rustls_0_22.rs new file mode 100644 index 00000000..46b4c03e --- /dev/null +++ b/actix-tls/src/accept/rustls_0_22.rs @@ -0,0 +1,198 @@ +//! `rustls` v0.22 based TLS connection acceptor service. +//! +//! See [`Acceptor`] for main service factory docs. + +use std::{ + convert::Infallible, + future::Future, + io::{self, IoSlice}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use actix_rt::{ + net::{ActixStream, Ready}, + time::{sleep, Sleep}, +}; +use actix_service::{Service, ServiceFactory}; +use actix_utils::{ + counter::{Counter, CounterGuard}, + future::{ready, Ready as FutReady}, +}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::{Accept, TlsAcceptor}; +use tokio_rustls_025 as tokio_rustls; + +use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; + +pub mod reexports { + //! Re-exports from `rustls` that are useful for acceptors. + + pub use tokio_rustls_025::rustls::ServerConfig; +} + +/// Wraps a `rustls` based async TLS stream in order to implement [`ActixStream`]. +pub struct TlsStream(tokio_rustls::server::TlsStream); + +impl_more::impl_from!( in tokio_rustls::server::TlsStream => TlsStream); +impl_more::impl_deref_and_mut!( in TlsStream => tokio_rustls::server::TlsStream); + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_read(cx, buf) + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } +} + +impl ActixStream for TlsStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + IO::poll_read_ready((**self).get_ref().0, cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + IO::poll_write_ready((**self).get_ref().0, cx) + } +} + +/// Accept TLS connections via the `rustls` crate. +pub struct Acceptor { + config: Arc, + handshake_timeout: Duration, +} + +impl Acceptor { + /// Constructs `rustls` based acceptor service factory. + pub fn new(config: reexports::ServerConfig) -> Self { + Acceptor { + config: Arc::new(config), + handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT, + } + } + + /// Limit the amount of time that the acceptor will wait for a TLS handshake to complete. + /// + /// Default timeout is 3 seconds. + pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self { + self.handshake_timeout = handshake_timeout; + self + } +} + +impl Clone for Acceptor { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + handshake_timeout: self.handshake_timeout, + } + } +} + +impl ServiceFactory for Acceptor { + type Response = TlsStream; + type Error = TlsError; + type Config = (); + type Service = AcceptorService; + type InitError = (); + type Future = FutReady>; + + fn new_service(&self, _: ()) -> Self::Future { + let res = MAX_CONN_COUNTER.with(|conns| { + Ok(AcceptorService { + acceptor: self.config.clone().into(), + conns: conns.clone(), + handshake_timeout: self.handshake_timeout, + }) + }); + + ready(res) + } +} + +/// Rustls based acceptor service. +pub struct AcceptorService { + acceptor: TlsAcceptor, + conns: Counter, + handshake_timeout: Duration, +} + +impl Service for AcceptorService { + type Response = TlsStream; + type Error = TlsError; + type Future = AcceptFut; + + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + if self.conns.available(cx) { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn call(&self, req: IO) -> Self::Future { + AcceptFut { + fut: self.acceptor.accept(req), + timeout: sleep(self.handshake_timeout), + _guard: self.conns.get(), + } + } +} + +pin_project! { + /// Accept future for Rustls service. + #[doc(hidden)] + pub struct AcceptFut { + fut: Accept, + #[pin] + timeout: Sleep, + _guard: CounterGuard, + } +} + +impl Future for AcceptFut { + type Output = Result, TlsError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + match Pin::new(&mut this.fut).poll(cx) { + Poll::Ready(Ok(stream)) => Poll::Ready(Ok(TlsStream(stream))), + Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))), + Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)), + } + } +}