mirror of
https://github.com/fafhrd91/actix-net
synced 2025-02-17 14:43:31 +01:00
add timeout for accepting tls connections (#393)
Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
parent
ce8ec15eaa
commit
7e7df2f931
@ -14,7 +14,7 @@ ci-check = "hack --workspace --feature-powerset --exclude-features=io-uring chec
|
|||||||
ci-check-linux = "hack --workspace --feature-powerset check --tests --examples"
|
ci-check-linux = "hack --workspace --feature-powerset check --tests --examples"
|
||||||
|
|
||||||
# tests avoiding io-uring feature
|
# tests avoiding io-uring feature
|
||||||
ci-test = "hack test --workspace --exclude=actix-rt --exclude=actix-server --all-features --lib --tests --no-fail-fast -- --nocapture"
|
ci-test = " hack --feature-powerset --exclude=actix-rt --exclude=actix-server --exclude-features=io-uring test --workspace --lib --tests --no-fail-fast -- --nocapture"
|
||||||
ci-test-rt = " hack --feature-powerset --exclude-features=io-uring test --package=actix-rt --lib --tests --no-fail-fast -- --nocapture"
|
ci-test-rt = " hack --feature-powerset --exclude-features=io-uring test --package=actix-rt --lib --tests --no-fail-fast -- --nocapture"
|
||||||
ci-test-server = "hack --feature-powerset --exclude-features=io-uring test --package=actix-server --lib --tests --no-fail-fast -- --nocapture"
|
ci-test-server = "hack --feature-powerset --exclude-features=io-uring test --package=actix-server --lib --tests --no-fail-fast -- --nocapture"
|
||||||
|
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
## Unreleased - 2021-xx-xx
|
## Unreleased - 2021-xx-xx
|
||||||
|
* Add configurable timeout for accepting TLS connection. [#393]
|
||||||
|
* Added `TlsError::Timeout` variant. [#393]
|
||||||
|
* All TLS acceptor services now use `TlsError` for their error types. [#393]
|
||||||
|
|
||||||
|
[#393]: https://github.com/actix/actix-net/pull/393
|
||||||
|
|
||||||
|
|
||||||
## 3.0.0-beta.8 - 2021-11-15
|
## 3.0.0-beta.8 - 2021-11-15
|
||||||
|
@ -47,6 +47,7 @@ derive_more = "0.99.5"
|
|||||||
futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] }
|
futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] }
|
||||||
http = { version = "0.2.3", optional = true }
|
http = { version = "0.2.3", optional = true }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
|
pin-project-lite = "0.2.7"
|
||||||
tokio-util = { version = "0.6.3", default-features = false }
|
tokio-util = { version = "0.6.3", default-features = false }
|
||||||
|
|
||||||
# openssl
|
# openssl
|
||||||
@ -67,7 +68,9 @@ bytes = "1"
|
|||||||
env_logger = "0.9"
|
env_logger = "0.9"
|
||||||
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
|
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
|
rcgen = "0.8"
|
||||||
rustls-pemfile = "0.2.1"
|
rustls-pemfile = "0.2.1"
|
||||||
|
tokio-rustls = { version = "0.23", features = ["dangerous_configuration"] }
|
||||||
trust-dns-resolver = "0.20.0"
|
trust-dns-resolver = "0.20.0"
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
|
@ -1,9 +1,4 @@
|
|||||||
//! TLS acceptor services for Actix ecosystem.
|
//! TLS acceptor services.
|
||||||
//!
|
|
||||||
//! ## Crate Features
|
|
||||||
//! * `openssl` - TLS acceptor using the `openssl` crate.
|
|
||||||
//! * `rustls` - TLS acceptor using the `rustls` crate.
|
|
||||||
//! * `native-tls` - TLS acceptor using the `native-tls` crate.
|
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
@ -20,6 +15,10 @@ pub mod native_tls;
|
|||||||
|
|
||||||
pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256);
|
pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256);
|
||||||
|
|
||||||
|
#[cfg(any(feature = "openssl", feature = "rustls", feature = "native-tls"))]
|
||||||
|
pub(crate) const DEFAULT_TLS_HANDSHAKE_TIMEOUT: std::time::Duration =
|
||||||
|
std::time::Duration::from_secs(3);
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static MAX_CONN_COUNTER: Counter = Counter::new(MAX_CONN.load(Ordering::Relaxed));
|
static MAX_CONN_COUNTER: Counter = Counter::new(MAX_CONN.load(Ordering::Relaxed));
|
||||||
}
|
}
|
||||||
@ -36,7 +35,8 @@ pub fn max_concurrent_tls_connect(num: usize) {
|
|||||||
|
|
||||||
/// TLS error combined with service error.
|
/// TLS error combined with service error.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum TlsError<E1, E2> {
|
pub enum TlsError<TlsErr, SvcErr> {
|
||||||
Tls(E1),
|
Tls(TlsErr),
|
||||||
Service(E2),
|
Timeout,
|
||||||
|
Service(SvcErr),
|
||||||
}
|
}
|
||||||
|
@ -1,20 +1,24 @@
|
|||||||
use std::{
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
io::{self, IoSlice},
|
io::{self, IoSlice},
|
||||||
ops::{Deref, DerefMut},
|
ops::{Deref, DerefMut},
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
|
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use actix_rt::net::{ActixStream, Ready};
|
use actix_rt::{
|
||||||
|
net::{ActixStream, Ready},
|
||||||
|
time::timeout,
|
||||||
|
};
|
||||||
use actix_service::{Service, ServiceFactory};
|
use actix_service::{Service, ServiceFactory};
|
||||||
use actix_utils::counter::Counter;
|
use actix_utils::counter::Counter;
|
||||||
use futures_core::future::LocalBoxFuture;
|
use futures_core::future::LocalBoxFuture;
|
||||||
|
|
||||||
pub use tokio_native_tls::native_tls::Error;
|
pub use tokio_native_tls::{native_tls::Error, TlsAcceptor};
|
||||||
pub use tokio_native_tls::TlsAcceptor;
|
|
||||||
|
|
||||||
use super::MAX_CONN_COUNTER;
|
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
|
||||||
|
|
||||||
/// Wrapper type for `tokio_native_tls::TlsStream` in order to impl `ActixStream` trait.
|
/// Wrapper type for `tokio_native_tls::TlsStream` in order to impl `ActixStream` trait.
|
||||||
pub struct TlsStream<T>(tokio_native_tls::TlsStream<T>);
|
pub struct TlsStream<T>(tokio_native_tls::TlsStream<T>);
|
||||||
@ -94,13 +98,25 @@ impl<T: ActixStream> ActixStream for TlsStream<T> {
|
|||||||
/// `native-tls` feature enables this `Acceptor` type.
|
/// `native-tls` feature enables this `Acceptor` type.
|
||||||
pub struct Acceptor {
|
pub struct Acceptor {
|
||||||
acceptor: TlsAcceptor,
|
acceptor: TlsAcceptor,
|
||||||
|
handshake_timeout: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Acceptor {
|
impl Acceptor {
|
||||||
/// Create `native-tls` based `Acceptor` service factory.
|
/// Create `native-tls` based `Acceptor` service factory.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new(acceptor: TlsAcceptor) -> Self {
|
pub fn new(acceptor: TlsAcceptor) -> Self {
|
||||||
Acceptor { acceptor }
|
Acceptor {
|
||||||
|
acceptor,
|
||||||
|
handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
|
||||||
|
///
|
||||||
|
/// Default timeout is 3 seconds.
|
||||||
|
pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
|
||||||
|
self.handshake_timeout = handshake_timeout;
|
||||||
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,13 +125,14 @@ impl Clone for Acceptor {
|
|||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
acceptor: self.acceptor.clone(),
|
acceptor: self.acceptor.clone(),
|
||||||
|
handshake_timeout: self.handshake_timeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor {
|
impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor {
|
||||||
type Response = TlsStream<T>;
|
type Response = TlsStream<T>;
|
||||||
type Error = Error;
|
type Error = TlsError<Error, Infallible>;
|
||||||
type Config = ();
|
type Config = ();
|
||||||
|
|
||||||
type Service = NativeTlsAcceptorService;
|
type Service = NativeTlsAcceptorService;
|
||||||
@ -127,8 +144,10 @@ impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor {
|
|||||||
Ok(NativeTlsAcceptorService {
|
Ok(NativeTlsAcceptorService {
|
||||||
acceptor: self.acceptor.clone(),
|
acceptor: self.acceptor.clone(),
|
||||||
conns: conns.clone(),
|
conns: conns.clone(),
|
||||||
|
handshake_timeout: self.handshake_timeout,
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
Box::pin(async { res })
|
Box::pin(async { res })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,12 +155,13 @@ impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor {
|
|||||||
pub struct NativeTlsAcceptorService {
|
pub struct NativeTlsAcceptorService {
|
||||||
acceptor: TlsAcceptor,
|
acceptor: TlsAcceptor,
|
||||||
conns: Counter,
|
conns: Counter,
|
||||||
|
handshake_timeout: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService {
|
impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService {
|
||||||
type Response = TlsStream<T>;
|
type Response = TlsStream<T>;
|
||||||
type Error = Error;
|
type Error = TlsError<Error, Infallible>;
|
||||||
type Future = LocalBoxFuture<'static, Result<TlsStream<T>, Error>>;
|
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||||
|
|
||||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
if self.conns.available(cx) {
|
if self.conns.available(cx) {
|
||||||
@ -154,10 +174,18 @@ impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService {
|
|||||||
fn call(&self, io: T) -> Self::Future {
|
fn call(&self, io: T) -> Self::Future {
|
||||||
let guard = self.conns.get();
|
let guard = self.conns.get();
|
||||||
let acceptor = self.acceptor.clone();
|
let acceptor = self.acceptor.clone();
|
||||||
|
|
||||||
|
let dur = self.handshake_timeout;
|
||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let io = acceptor.accept(io).await;
|
match timeout(dur, acceptor.accept(io)).await {
|
||||||
drop(guard);
|
Ok(Ok(io)) => {
|
||||||
io.map(Into::into)
|
drop(guard);
|
||||||
|
Ok(TlsStream(io))
|
||||||
|
}
|
||||||
|
Ok(Err(err)) => Err(TlsError::Tls(err)),
|
||||||
|
Err(_timeout) => Err(TlsError::Timeout),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,22 +1,28 @@
|
|||||||
use std::{
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
future::Future,
|
future::Future,
|
||||||
io::{self, IoSlice},
|
io::{self, IoSlice},
|
||||||
ops::{Deref, DerefMut},
|
ops::{Deref, DerefMut},
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
|
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use actix_rt::net::{ActixStream, Ready};
|
use actix_rt::{
|
||||||
|
net::{ActixStream, Ready},
|
||||||
|
time::{sleep, Sleep},
|
||||||
|
};
|
||||||
use actix_service::{Service, ServiceFactory};
|
use actix_service::{Service, ServiceFactory};
|
||||||
use actix_utils::counter::{Counter, CounterGuard};
|
use actix_utils::counter::{Counter, CounterGuard};
|
||||||
use futures_core::{future::LocalBoxFuture, ready};
|
use futures_core::future::LocalBoxFuture;
|
||||||
|
|
||||||
pub use openssl::ssl::{
|
pub use openssl::ssl::{
|
||||||
AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder,
|
AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder,
|
||||||
};
|
};
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
use super::MAX_CONN_COUNTER;
|
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
|
||||||
|
|
||||||
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait.
|
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait.
|
||||||
pub struct TlsStream<T>(tokio_openssl::SslStream<T>);
|
pub struct TlsStream<T>(tokio_openssl::SslStream<T>);
|
||||||
@ -96,13 +102,25 @@ impl<T: ActixStream> ActixStream for TlsStream<T> {
|
|||||||
/// `openssl` feature enables this `Acceptor` type.
|
/// `openssl` feature enables this `Acceptor` type.
|
||||||
pub struct Acceptor {
|
pub struct Acceptor {
|
||||||
acceptor: SslAcceptor,
|
acceptor: SslAcceptor,
|
||||||
|
handshake_timeout: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Acceptor {
|
impl Acceptor {
|
||||||
/// Create OpenSSL based `Acceptor` service factory.
|
/// Create OpenSSL based `Acceptor` service factory.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new(acceptor: SslAcceptor) -> Self {
|
pub fn new(acceptor: SslAcceptor) -> Self {
|
||||||
Acceptor { acceptor }
|
Acceptor {
|
||||||
|
acceptor,
|
||||||
|
handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
|
||||||
|
///
|
||||||
|
/// Default timeout is 3 seconds.
|
||||||
|
pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
|
||||||
|
self.handshake_timeout = handshake_timeout;
|
||||||
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,13 +129,14 @@ impl Clone for Acceptor {
|
|||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
acceptor: self.acceptor.clone(),
|
acceptor: self.acceptor.clone(),
|
||||||
|
handshake_timeout: self.handshake_timeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
||||||
type Response = TlsStream<T>;
|
type Response = TlsStream<T>;
|
||||||
type Error = SslError;
|
type Error = TlsError<SslError, Infallible>;
|
||||||
type Config = ();
|
type Config = ();
|
||||||
type Service = AcceptorService;
|
type Service = AcceptorService;
|
||||||
type InitError = ();
|
type InitError = ();
|
||||||
@ -128,8 +147,10 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
|||||||
Ok(AcceptorService {
|
Ok(AcceptorService {
|
||||||
acceptor: self.acceptor.clone(),
|
acceptor: self.acceptor.clone(),
|
||||||
conns: conns.clone(),
|
conns: conns.clone(),
|
||||||
|
handshake_timeout: self.handshake_timeout,
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
Box::pin(async { res })
|
Box::pin(async { res })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -137,11 +158,12 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
|||||||
pub struct AcceptorService {
|
pub struct AcceptorService {
|
||||||
acceptor: SslAcceptor,
|
acceptor: SslAcceptor,
|
||||||
conns: Counter,
|
conns: Counter,
|
||||||
|
handshake_timeout: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream> Service<T> for AcceptorService {
|
impl<T: ActixStream> Service<T> for AcceptorService {
|
||||||
type Response = TlsStream<T>;
|
type Response = TlsStream<T>;
|
||||||
type Error = SslError;
|
type Error = TlsError<SslError, Infallible>;
|
||||||
type Future = AcceptorServiceResponse<T>;
|
type Future = AcceptorServiceResponse<T>;
|
||||||
|
|
||||||
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
@ -155,27 +177,38 @@ impl<T: ActixStream> Service<T> for AcceptorService {
|
|||||||
fn call(&self, io: T) -> Self::Future {
|
fn call(&self, io: T) -> Self::Future {
|
||||||
let ssl_ctx = self.acceptor.context();
|
let ssl_ctx = self.acceptor.context();
|
||||||
let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid.");
|
let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid.");
|
||||||
|
|
||||||
AcceptorServiceResponse {
|
AcceptorServiceResponse {
|
||||||
_guard: self.conns.get(),
|
_guard: self.conns.get(),
|
||||||
|
timeout: sleep(self.handshake_timeout),
|
||||||
stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()),
|
stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AcceptorServiceResponse<T: ActixStream> {
|
pin_project! {
|
||||||
stream: Option<tokio_openssl::SslStream<T>>,
|
pub struct AcceptorServiceResponse<T: ActixStream> {
|
||||||
_guard: CounterGuard,
|
stream: Option<tokio_openssl::SslStream<T>>,
|
||||||
|
#[pin]
|
||||||
|
timeout: Sleep,
|
||||||
|
_guard: CounterGuard,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream> Future for AcceptorServiceResponse<T> {
|
impl<T: ActixStream> Future for AcceptorServiceResponse<T> {
|
||||||
type Output = Result<TlsStream<T>, SslError>;
|
type Output = Result<TlsStream<T>, TlsError<SslError, Infallible>>;
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
ready!(Pin::new(self.stream.as_mut().unwrap()).poll_accept(cx))?;
|
let this = self.project();
|
||||||
Poll::Ready(Ok(self
|
|
||||||
.stream
|
match Pin::new(this.stream.as_mut().unwrap()).poll_accept(cx) {
|
||||||
.take()
|
Poll::Ready(Ok(())) => Poll::Ready(Ok(this
|
||||||
.expect("SSL connect has resolved.")
|
.stream
|
||||||
.into()))
|
.take()
|
||||||
|
.expect("Acceptor should not be polled after it has completed.")
|
||||||
|
.into())),
|
||||||
|
Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))),
|
||||||
|
Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,22 +1,28 @@
|
|||||||
use std::{
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
future::Future,
|
future::Future,
|
||||||
io::{self, IoSlice},
|
io::{self, IoSlice},
|
||||||
ops::{Deref, DerefMut},
|
ops::{Deref, DerefMut},
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
|
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use actix_rt::net::{ActixStream, Ready};
|
use actix_rt::{
|
||||||
|
net::{ActixStream, Ready},
|
||||||
|
time::{sleep, Sleep},
|
||||||
|
};
|
||||||
use actix_service::{Service, ServiceFactory};
|
use actix_service::{Service, ServiceFactory};
|
||||||
use actix_utils::counter::{Counter, CounterGuard};
|
use actix_utils::counter::{Counter, CounterGuard};
|
||||||
use futures_core::future::LocalBoxFuture;
|
use futures_core::future::LocalBoxFuture;
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
use tokio_rustls::{Accept, TlsAcceptor};
|
use tokio_rustls::{Accept, TlsAcceptor};
|
||||||
|
|
||||||
pub use tokio_rustls::rustls::ServerConfig;
|
pub use tokio_rustls::rustls::ServerConfig;
|
||||||
|
|
||||||
use super::MAX_CONN_COUNTER;
|
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
|
||||||
|
|
||||||
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait.
|
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait.
|
||||||
pub struct TlsStream<T>(tokio_rustls::server::TlsStream<T>);
|
pub struct TlsStream<T>(tokio_rustls::server::TlsStream<T>);
|
||||||
@ -96,6 +102,7 @@ impl<T: ActixStream> ActixStream for TlsStream<T> {
|
|||||||
/// `rustls` feature enables this `Acceptor` type.
|
/// `rustls` feature enables this `Acceptor` type.
|
||||||
pub struct Acceptor {
|
pub struct Acceptor {
|
||||||
config: Arc<ServerConfig>,
|
config: Arc<ServerConfig>,
|
||||||
|
handshake_timeout: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Acceptor {
|
impl Acceptor {
|
||||||
@ -104,8 +111,17 @@ impl Acceptor {
|
|||||||
pub fn new(config: ServerConfig) -> Self {
|
pub fn new(config: ServerConfig) -> Self {
|
||||||
Acceptor {
|
Acceptor {
|
||||||
config: Arc::new(config),
|
config: Arc::new(config),
|
||||||
|
handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
|
||||||
|
///
|
||||||
|
/// Default timeout is 3 seconds.
|
||||||
|
pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
|
||||||
|
self.handshake_timeout = handshake_timeout;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Clone for Acceptor {
|
impl Clone for Acceptor {
|
||||||
@ -113,13 +129,14 @@ impl Clone for Acceptor {
|
|||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
config: self.config.clone(),
|
config: self.config.clone(),
|
||||||
|
handshake_timeout: self.handshake_timeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
||||||
type Response = TlsStream<T>;
|
type Response = TlsStream<T>;
|
||||||
type Error = io::Error;
|
type Error = TlsError<io::Error, Infallible>;
|
||||||
type Config = ();
|
type Config = ();
|
||||||
|
|
||||||
type Service = AcceptorService;
|
type Service = AcceptorService;
|
||||||
@ -131,8 +148,10 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
|||||||
Ok(AcceptorService {
|
Ok(AcceptorService {
|
||||||
acceptor: self.config.clone().into(),
|
acceptor: self.config.clone().into(),
|
||||||
conns: conns.clone(),
|
conns: conns.clone(),
|
||||||
|
handshake_timeout: self.handshake_timeout,
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
Box::pin(async { res })
|
Box::pin(async { res })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -141,11 +160,12 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
|||||||
pub struct AcceptorService {
|
pub struct AcceptorService {
|
||||||
acceptor: TlsAcceptor,
|
acceptor: TlsAcceptor,
|
||||||
conns: Counter,
|
conns: Counter,
|
||||||
|
handshake_timeout: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream> Service<T> for AcceptorService {
|
impl<T: ActixStream> Service<T> for AcceptorService {
|
||||||
type Response = TlsStream<T>;
|
type Response = TlsStream<T>;
|
||||||
type Error = io::Error;
|
type Error = TlsError<io::Error, Infallible>;
|
||||||
type Future = AcceptorServiceFut<T>;
|
type Future = AcceptorServiceFut<T>;
|
||||||
|
|
||||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
@ -158,22 +178,31 @@ impl<T: ActixStream> Service<T> for AcceptorService {
|
|||||||
|
|
||||||
fn call(&self, req: T) -> Self::Future {
|
fn call(&self, req: T) -> Self::Future {
|
||||||
AcceptorServiceFut {
|
AcceptorServiceFut {
|
||||||
_guard: self.conns.get(),
|
|
||||||
fut: self.acceptor.accept(req),
|
fut: self.acceptor.accept(req),
|
||||||
|
timeout: sleep(self.handshake_timeout),
|
||||||
|
_guard: self.conns.get(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AcceptorServiceFut<T: ActixStream> {
|
pin_project! {
|
||||||
fut: Accept<T>,
|
pub struct AcceptorServiceFut<T: ActixStream> {
|
||||||
_guard: CounterGuard,
|
fut: Accept<T>,
|
||||||
|
#[pin]
|
||||||
|
timeout: Sleep,
|
||||||
|
_guard: CounterGuard,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ActixStream> Future for AcceptorServiceFut<T> {
|
impl<T: ActixStream> Future for AcceptorServiceFut<T> {
|
||||||
type Output = Result<TlsStream<T>, io::Error>;
|
type Output = Result<TlsStream<T>, TlsError<io::Error, Infallible>>;
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
let this = self.get_mut();
|
let mut this = self.project();
|
||||||
Pin::new(&mut this.fut).poll(cx).map_ok(TlsStream)
|
match Pin::new(&mut this.fut).poll(cx) {
|
||||||
|
Poll::Ready(Ok(stream)) => Poll::Ready(Ok(TlsStream(stream))),
|
||||||
|
Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))),
|
||||||
|
Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
130
actix-tls/tests/accept-openssl.rs
Normal file
130
actix-tls/tests/accept-openssl.rs
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
//! Use Rustls connector to test OpenSSL acceptor.
|
||||||
|
|
||||||
|
#![cfg(all(
|
||||||
|
feature = "accept",
|
||||||
|
feature = "connect",
|
||||||
|
feature = "rustls",
|
||||||
|
feature = "openssl"
|
||||||
|
))]
|
||||||
|
|
||||||
|
use std::{convert::TryFrom, io::Write, sync::Arc};
|
||||||
|
|
||||||
|
use actix_rt::net::TcpStream;
|
||||||
|
use actix_server::TestServer;
|
||||||
|
use actix_service::ServiceFactoryExt as _;
|
||||||
|
use actix_tls::accept::openssl::{Acceptor, TlsStream};
|
||||||
|
use actix_utils::future::ok;
|
||||||
|
use tokio_rustls::rustls::{Certificate, ClientConfig, RootCertStore, ServerName};
|
||||||
|
|
||||||
|
fn new_cert_and_key() -> (String, String) {
|
||||||
|
let cert = rcgen::generate_simple_self_signed(vec![
|
||||||
|
"127.0.0.1".to_owned(),
|
||||||
|
"localhost".to_owned(),
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let key = cert.serialize_private_key_pem();
|
||||||
|
let cert = cert.serialize_pem().unwrap();
|
||||||
|
|
||||||
|
(cert, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn openssl_acceptor(cert: String, key: String) -> tls_openssl::ssl::SslAcceptor {
|
||||||
|
use tls_openssl::{
|
||||||
|
pkey::PKey,
|
||||||
|
ssl::{SslAcceptor, SslMethod},
|
||||||
|
x509::X509,
|
||||||
|
};
|
||||||
|
|
||||||
|
let cert = X509::from_pem(cert.as_bytes()).unwrap();
|
||||||
|
let key = PKey::private_key_from_pem(key.as_bytes()).unwrap();
|
||||||
|
|
||||||
|
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
|
||||||
|
builder.set_certificate(&cert).unwrap();
|
||||||
|
builder.set_private_key(&key).unwrap();
|
||||||
|
builder.set_alpn_select_callback(|_, _protocols| Ok(b"http/1.1"));
|
||||||
|
builder.set_alpn_protos(b"\x08http/1.1").unwrap();
|
||||||
|
builder.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
mod danger {
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
use tokio_rustls::rustls::{
|
||||||
|
self,
|
||||||
|
client::{ServerCertVerified, ServerCertVerifier},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct NoCertificateVerification;
|
||||||
|
|
||||||
|
impl ServerCertVerifier for NoCertificateVerification {
|
||||||
|
fn verify_server_cert(
|
||||||
|
&self,
|
||||||
|
_end_entity: &Certificate,
|
||||||
|
_intermediates: &[Certificate],
|
||||||
|
_server_name: &ServerName,
|
||||||
|
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||||
|
_ocsp_response: &[u8],
|
||||||
|
_now: SystemTime,
|
||||||
|
) -> Result<ServerCertVerified, rustls::Error> {
|
||||||
|
Ok(ServerCertVerified::assertion())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn rustls_connector(_cert: String, _key: String) -> ClientConfig {
|
||||||
|
let mut config = ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(RootCertStore::empty())
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
config
|
||||||
|
.dangerous()
|
||||||
|
.set_certificate_verifier(Arc::new(danger::NoCertificateVerification));
|
||||||
|
|
||||||
|
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn accepts_connections() {
|
||||||
|
let (cert, key) = new_cert_and_key();
|
||||||
|
|
||||||
|
let srv = TestServer::with({
|
||||||
|
let cert = cert.clone();
|
||||||
|
let key = key.clone();
|
||||||
|
|
||||||
|
move || {
|
||||||
|
let openssl_acceptor = openssl_acceptor(cert.clone(), key.clone());
|
||||||
|
let tls_acceptor = Acceptor::new(openssl_acceptor);
|
||||||
|
|
||||||
|
tls_acceptor
|
||||||
|
.map_err(|err| println!("OpenSSL error: {:?}", err))
|
||||||
|
.and_then(move |_stream: TlsStream<TcpStream>| ok(()))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut sock = srv
|
||||||
|
.connect()
|
||||||
|
.expect("cannot connect to test server")
|
||||||
|
.into_std()
|
||||||
|
.unwrap();
|
||||||
|
sock.set_nonblocking(false).unwrap();
|
||||||
|
|
||||||
|
let config = rustls_connector(cert, key);
|
||||||
|
let config = Arc::new(config);
|
||||||
|
|
||||||
|
let mut conn = tokio_rustls::rustls::ClientConnection::new(
|
||||||
|
config,
|
||||||
|
ServerName::try_from("localhost").unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut stream = tokio_rustls::rustls::Stream::new(&mut conn, &mut sock);
|
||||||
|
|
||||||
|
stream.flush().expect("TLS handshake failed");
|
||||||
|
}
|
104
actix-tls/tests/accept-rustls.rs
Normal file
104
actix-tls/tests/accept-rustls.rs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
//! Use OpenSSL connector to test Rustls acceptor.
|
||||||
|
|
||||||
|
#![cfg(all(
|
||||||
|
feature = "accept",
|
||||||
|
feature = "connect",
|
||||||
|
feature = "rustls",
|
||||||
|
feature = "openssl"
|
||||||
|
))]
|
||||||
|
|
||||||
|
use std::io::{BufReader, Write};
|
||||||
|
|
||||||
|
use actix_rt::net::TcpStream;
|
||||||
|
use actix_server::TestServer;
|
||||||
|
use actix_service::ServiceFactoryExt as _;
|
||||||
|
use actix_tls::accept::rustls::{Acceptor, TlsStream};
|
||||||
|
use actix_tls::connect::tls::openssl::SslConnector;
|
||||||
|
use actix_utils::future::ok;
|
||||||
|
use rustls_pemfile::{certs, pkcs8_private_keys};
|
||||||
|
use tls_openssl::ssl::SslVerifyMode;
|
||||||
|
use tokio_rustls::rustls::{self, Certificate, PrivateKey, ServerConfig};
|
||||||
|
|
||||||
|
fn new_cert_and_key() -> (String, String) {
|
||||||
|
let cert = rcgen::generate_simple_self_signed(vec![
|
||||||
|
"127.0.0.1".to_owned(),
|
||||||
|
"localhost".to_owned(),
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let key = cert.serialize_private_key_pem();
|
||||||
|
let cert = cert.serialize_pem().unwrap();
|
||||||
|
|
||||||
|
(cert, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rustls_server_config(cert: String, key: String) -> rustls::ServerConfig {
|
||||||
|
// Load TLS key and cert files
|
||||||
|
|
||||||
|
let cert = &mut BufReader::new(cert.as_bytes());
|
||||||
|
let key = &mut BufReader::new(key.as_bytes());
|
||||||
|
|
||||||
|
let cert_chain = certs(cert).unwrap().into_iter().map(Certificate).collect();
|
||||||
|
let mut keys = pkcs8_private_keys(key).unwrap();
|
||||||
|
|
||||||
|
let mut config = ServerConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(cert_chain, PrivateKey(keys.remove(0)))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||||
|
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
fn openssl_connector(cert: String, key: String) -> SslConnector {
|
||||||
|
use actix_tls::connect::tls::openssl::{SslConnector as OpensslConnector, SslMethod};
|
||||||
|
use tls_openssl::{pkey::PKey, x509::X509};
|
||||||
|
|
||||||
|
let cert = X509::from_pem(cert.as_bytes()).unwrap();
|
||||||
|
let key = PKey::private_key_from_pem(key.as_bytes()).unwrap();
|
||||||
|
|
||||||
|
let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
|
||||||
|
ssl.set_verify(SslVerifyMode::NONE);
|
||||||
|
ssl.set_certificate(&cert).unwrap();
|
||||||
|
ssl.set_private_key(&key).unwrap();
|
||||||
|
ssl.set_alpn_protos(b"\x08http/1.1").unwrap();
|
||||||
|
|
||||||
|
ssl.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn accepts_connections() {
|
||||||
|
let (cert, key) = new_cert_and_key();
|
||||||
|
|
||||||
|
let srv = TestServer::with({
|
||||||
|
let cert = cert.clone();
|
||||||
|
let key = key.clone();
|
||||||
|
|
||||||
|
move || {
|
||||||
|
let tls_acceptor = Acceptor::new(rustls_server_config(cert.clone(), key.clone()));
|
||||||
|
|
||||||
|
tls_acceptor
|
||||||
|
.map_err(|err| println!("Rustls error: {:?}", err))
|
||||||
|
.and_then(move |_stream: TlsStream<TcpStream>| ok(()))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let sock = srv
|
||||||
|
.connect()
|
||||||
|
.expect("cannot connect to test server")
|
||||||
|
.into_std()
|
||||||
|
.unwrap();
|
||||||
|
sock.set_nonblocking(false).unwrap();
|
||||||
|
|
||||||
|
let connector = openssl_connector(cert, key);
|
||||||
|
|
||||||
|
let mut stream = connector
|
||||||
|
.connect("localhost", sock)
|
||||||
|
.expect("TLS handshake failed");
|
||||||
|
|
||||||
|
stream.do_handshake().expect("TLS handshake failed");
|
||||||
|
|
||||||
|
stream.flush().expect("TLS handshake failed");
|
||||||
|
}
|
@ -26,7 +26,7 @@ impl Counter {
|
|||||||
CounterGuard::new(self.0.clone())
|
CounterGuard::new(self.0.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Notify current task and return true if counter is at capacity.
|
/// Returns true if counter is below capacity. Otherwise, register to wake task when it is.
|
||||||
pub fn available(&self, cx: &mut task::Context<'_>) -> bool {
|
pub fn available(&self, cx: &mut task::Context<'_>) -> bool {
|
||||||
self.0.available(cx)
|
self.0.available(cx)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user