mirror of
https://github.com/fafhrd91/actix-net
synced 2025-08-31 09:37:00 +02:00
add timeout for accepting tls connections (#393)
Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
@@ -1,22 +1,28 @@
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::Future,
|
||||
io::{self, IoSlice},
|
||||
ops::{Deref, DerefMut},
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use actix_rt::net::{ActixStream, Ready};
|
||||
use actix_rt::{
|
||||
net::{ActixStream, Ready},
|
||||
time::{sleep, Sleep},
|
||||
};
|
||||
use actix_service::{Service, ServiceFactory};
|
||||
use actix_utils::counter::{Counter, CounterGuard};
|
||||
use futures_core::{future::LocalBoxFuture, ready};
|
||||
use futures_core::future::LocalBoxFuture;
|
||||
|
||||
pub use openssl::ssl::{
|
||||
AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder,
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use super::MAX_CONN_COUNTER;
|
||||
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
|
||||
|
||||
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait.
|
||||
pub struct TlsStream<T>(tokio_openssl::SslStream<T>);
|
||||
@@ -96,13 +102,25 @@ impl<T: ActixStream> ActixStream for TlsStream<T> {
|
||||
/// `openssl` feature enables this `Acceptor` type.
|
||||
pub struct Acceptor {
|
||||
acceptor: SslAcceptor,
|
||||
handshake_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Acceptor {
|
||||
/// Create OpenSSL based `Acceptor` service factory.
|
||||
#[inline]
|
||||
pub fn new(acceptor: SslAcceptor) -> Self {
|
||||
Acceptor { acceptor }
|
||||
Acceptor {
|
||||
acceptor,
|
||||
handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
|
||||
}
|
||||
}
|
||||
|
||||
/// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
|
||||
///
|
||||
/// Default timeout is 3 seconds.
|
||||
pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
|
||||
self.handshake_timeout = handshake_timeout;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,13 +129,14 @@ impl Clone for Acceptor {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
acceptor: self.acceptor.clone(),
|
||||
handshake_timeout: self.handshake_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
||||
type Response = TlsStream<T>;
|
||||
type Error = SslError;
|
||||
type Error = TlsError<SslError, Infallible>;
|
||||
type Config = ();
|
||||
type Service = AcceptorService;
|
||||
type InitError = ();
|
||||
@@ -128,8 +147,10 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
||||
Ok(AcceptorService {
|
||||
acceptor: self.acceptor.clone(),
|
||||
conns: conns.clone(),
|
||||
handshake_timeout: self.handshake_timeout,
|
||||
})
|
||||
});
|
||||
|
||||
Box::pin(async { res })
|
||||
}
|
||||
}
|
||||
@@ -137,11 +158,12 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
|
||||
pub struct AcceptorService {
|
||||
acceptor: SslAcceptor,
|
||||
conns: Counter,
|
||||
handshake_timeout: Duration,
|
||||
}
|
||||
|
||||
impl<T: ActixStream> Service<T> for AcceptorService {
|
||||
type Response = TlsStream<T>;
|
||||
type Error = SslError;
|
||||
type Error = TlsError<SslError, Infallible>;
|
||||
type Future = AcceptorServiceResponse<T>;
|
||||
|
||||
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
@@ -155,27 +177,38 @@ impl<T: ActixStream> Service<T> for AcceptorService {
|
||||
fn call(&self, io: T) -> Self::Future {
|
||||
let ssl_ctx = self.acceptor.context();
|
||||
let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid.");
|
||||
|
||||
AcceptorServiceResponse {
|
||||
_guard: self.conns.get(),
|
||||
timeout: sleep(self.handshake_timeout),
|
||||
stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AcceptorServiceResponse<T: ActixStream> {
|
||||
stream: Option<tokio_openssl::SslStream<T>>,
|
||||
_guard: CounterGuard,
|
||||
pin_project! {
|
||||
pub struct AcceptorServiceResponse<T: ActixStream> {
|
||||
stream: Option<tokio_openssl::SslStream<T>>,
|
||||
#[pin]
|
||||
timeout: Sleep,
|
||||
_guard: CounterGuard,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ActixStream> Future for AcceptorServiceResponse<T> {
|
||||
type Output = Result<TlsStream<T>, SslError>;
|
||||
type Output = Result<TlsStream<T>, TlsError<SslError, Infallible>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
ready!(Pin::new(self.stream.as_mut().unwrap()).poll_accept(cx))?;
|
||||
Poll::Ready(Ok(self
|
||||
.stream
|
||||
.take()
|
||||
.expect("SSL connect has resolved.")
|
||||
.into()))
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
|
||||
match Pin::new(this.stream.as_mut().unwrap()).poll_accept(cx) {
|
||||
Poll::Ready(Ok(())) => Poll::Ready(Ok(this
|
||||
.stream
|
||||
.take()
|
||||
.expect("Acceptor should not be polled after it has completed.")
|
||||
.into())),
|
||||
Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))),
|
||||
Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user