1
0
mirror of https://github.com/fafhrd91/actix-net synced 2025-08-31 09:37:00 +02:00

add timeout for accepting tls connections (#393)

Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
fakeshadow
2021-11-16 08:22:24 +08:00
committed by GitHub
parent ce8ec15eaa
commit 7e7df2f931
10 changed files with 382 additions and 50 deletions

View File

@@ -1,22 +1,28 @@
use std::{
convert::Infallible,
future::Future,
io::{self, IoSlice},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
use actix_rt::net::{ActixStream, Ready};
use actix_rt::{
net::{ActixStream, Ready},
time::{sleep, Sleep},
};
use actix_service::{Service, ServiceFactory};
use actix_utils::counter::{Counter, CounterGuard};
use futures_core::{future::LocalBoxFuture, ready};
use futures_core::future::LocalBoxFuture;
pub use openssl::ssl::{
AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder,
};
use pin_project_lite::pin_project;
use super::MAX_CONN_COUNTER;
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait.
pub struct TlsStream<T>(tokio_openssl::SslStream<T>);
@@ -96,13 +102,25 @@ impl<T: ActixStream> ActixStream for TlsStream<T> {
/// `openssl` feature enables this `Acceptor` type.
pub struct Acceptor {
acceptor: SslAcceptor,
handshake_timeout: Duration,
}
impl Acceptor {
/// Create OpenSSL based `Acceptor` service factory.
#[inline]
pub fn new(acceptor: SslAcceptor) -> Self {
Acceptor { acceptor }
Acceptor {
acceptor,
handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
}
}
/// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
///
/// Default timeout is 3 seconds.
pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
self.handshake_timeout = handshake_timeout;
self
}
}
@@ -111,13 +129,14 @@ impl Clone for Acceptor {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
handshake_timeout: self.handshake_timeout,
}
}
}
impl<T: ActixStream> ServiceFactory<T> for Acceptor {
type Response = TlsStream<T>;
type Error = SslError;
type Error = TlsError<SslError, Infallible>;
type Config = ();
type Service = AcceptorService;
type InitError = ();
@@ -128,8 +147,10 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
Ok(AcceptorService {
acceptor: self.acceptor.clone(),
conns: conns.clone(),
handshake_timeout: self.handshake_timeout,
})
});
Box::pin(async { res })
}
}
@@ -137,11 +158,12 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
pub struct AcceptorService {
acceptor: SslAcceptor,
conns: Counter,
handshake_timeout: Duration,
}
impl<T: ActixStream> Service<T> for AcceptorService {
type Response = TlsStream<T>;
type Error = SslError;
type Error = TlsError<SslError, Infallible>;
type Future = AcceptorServiceResponse<T>;
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@@ -155,27 +177,38 @@ impl<T: ActixStream> Service<T> for AcceptorService {
fn call(&self, io: T) -> Self::Future {
let ssl_ctx = self.acceptor.context();
let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid.");
AcceptorServiceResponse {
_guard: self.conns.get(),
timeout: sleep(self.handshake_timeout),
stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()),
}
}
}
pub struct AcceptorServiceResponse<T: ActixStream> {
stream: Option<tokio_openssl::SslStream<T>>,
_guard: CounterGuard,
pin_project! {
pub struct AcceptorServiceResponse<T: ActixStream> {
stream: Option<tokio_openssl::SslStream<T>>,
#[pin]
timeout: Sleep,
_guard: CounterGuard,
}
}
impl<T: ActixStream> Future for AcceptorServiceResponse<T> {
type Output = Result<TlsStream<T>, SslError>;
type Output = Result<TlsStream<T>, TlsError<SslError, Infallible>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
ready!(Pin::new(self.stream.as_mut().unwrap()).poll_accept(cx))?;
Poll::Ready(Ok(self
.stream
.take()
.expect("SSL connect has resolved.")
.into()))
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match Pin::new(this.stream.as_mut().unwrap()).poll_accept(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(this
.stream
.take()
.expect("Acceptor should not be polled after it has completed.")
.into())),
Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))),
Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)),
}
}
}