diff --git a/actix-rt/src/lib.rs b/actix-rt/src/lib.rs index a7e9f309..1649a8b0 100644 --- a/actix-rt/src/lib.rs +++ b/actix-rt/src/lib.rs @@ -77,6 +77,44 @@ pub mod net { #[cfg(unix)] pub use tokio::net::{UnixDatagram, UnixListener, UnixStream}; + + use std::task::{Context, Poll}; + + use tokio::io::{AsyncRead, AsyncWrite}; + + /// Trait for generic over tokio stream types and various wrapper types around them. + pub trait ActixStream: AsyncRead + AsyncWrite + Unpin + 'static { + /// poll stream and check read readiness of Self. + /// + /// See [tokio::net::TcpStream::poll_read_ready] for detail + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll>; + + /// poll stream and check write readiness of Self. + /// + /// See [tokio::net::TcpStream::poll_write_ready] for detail + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll>; + } + + impl ActixStream for TcpStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + TcpStream::poll_read_ready(self, cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + TcpStream::poll_write_ready(self, cx) + } + } + + #[cfg(unix)] + impl ActixStream for UnixStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + UnixStream::poll_read_ready(self, cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + UnixStream::poll_write_ready(self, cx) + } + } } pub mod time { diff --git a/actix-tls/examples/basic.rs b/actix-tls/examples/basic.rs index d1762b08..d0c20428 100644 --- a/actix-tls/examples/basic.rs +++ b/actix-tls/examples/basic.rs @@ -29,9 +29,10 @@ use std::{ }, }; +use actix_rt::net::TcpStream; use actix_server::Server; use actix_service::pipeline_factory; -use actix_tls::accept::rustls::Acceptor as RustlsAcceptor; +use actix_tls::accept::rustls::{Acceptor as RustlsAcceptor, TlsStream}; use futures_util::future::ok; use log::info; use rustls::{ @@ -74,9 +75,9 @@ async fn main() -> io::Result<()> { // Set up TLS service factory pipeline_factory(tls_acceptor.clone()) .map_err(|err| println!("Rustls error: {:?}", err)) - .and_then(move |stream| { + .and_then(move |stream: TlsStream| { let num = count.fetch_add(1, Ordering::Relaxed); - info!("[{}] Got TLS connection: {:?}", num, stream); + info!("[{}] Got TLS connection: {:?}", num, &*stream); ok(()) }) })? diff --git a/actix-tls/src/accept/nativetls.rs b/actix-tls/src/accept/nativetls.rs index 236ce973..15e43c99 100644 --- a/actix-tls/src/accept/nativetls.rs +++ b/actix-tls/src/accept/nativetls.rs @@ -1,15 +1,94 @@ -use std::task::{Context, Poll}; +use std::{ + io::{self, IoSlice}, + ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, +}; -use actix_codec::{AsyncRead, AsyncWrite}; +use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; +use actix_rt::net::ActixStream; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::Counter; use futures_core::future::LocalBoxFuture; pub use tokio_native_tls::native_tls::Error; -pub use tokio_native_tls::{TlsAcceptor, TlsStream}; +pub use tokio_native_tls::TlsAcceptor; use super::MAX_CONN_COUNTER; +/// wrapper type for `tokio_native_tls::TlsStream` in order to impl `ActixStream` trait. +pub struct TlsStream(tokio_native_tls::TlsStream); + +impl From> for TlsStream { + fn from(stream: tokio_native_tls::TlsStream) -> Self { + Self(stream) + } +} + +impl Deref for TlsStream { + type Target = tokio_native_tls::TlsStream; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for TlsStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_read(cx, buf) + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (&**self).is_write_vectored() + } +} + +impl ActixStream for TlsStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + T::poll_read_ready((&**self).get_ref().get_ref().get_ref(), cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + T::poll_write_ready((&**self).get_ref().get_ref().get_ref(), cx) + } +} + /// Accept TLS connections via `native-tls` package. /// /// `native-tls` feature enables this `Acceptor` type. @@ -34,10 +113,7 @@ impl Clone for Acceptor { } } -impl ServiceFactory for Acceptor -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl ServiceFactory for Acceptor { type Response = TlsStream; type Error = Error; type Config = (); @@ -71,10 +147,7 @@ impl Clone for NativeTlsAcceptorService { } } -impl Service for NativeTlsAcceptorService -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Service for NativeTlsAcceptorService { type Response = TlsStream; type Error = Error; type Future = LocalBoxFuture<'static, Result, Error>>; @@ -93,7 +166,7 @@ where Box::pin(async move { let io = this.acceptor.accept(io).await; drop(guard); - io + io.map(Into::into) }) } } diff --git a/actix-tls/src/accept/openssl.rs b/actix-tls/src/accept/openssl.rs index 8ca88578..b490888a 100644 --- a/actix-tls/src/accept/openssl.rs +++ b/actix-tls/src/accept/openssl.rs @@ -1,10 +1,13 @@ use std::{ future::Future, + io::{self, IoSlice}, + ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, }; -use actix_codec::{AsyncRead, AsyncWrite}; +use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; +use actix_rt::net::ActixStream; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::{Counter, CounterGuard}; use futures_core::{future::LocalBoxFuture, ready}; @@ -12,10 +15,82 @@ use futures_core::{future::LocalBoxFuture, ready}; pub use openssl::ssl::{ AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder, }; -pub use tokio_openssl::SslStream; use super::MAX_CONN_COUNTER; +/// wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait. +pub struct SslStream(tokio_openssl::SslStream); + +impl From> for SslStream { + fn from(stream: tokio_openssl::SslStream) -> Self { + Self(stream) + } +} + +impl Deref for SslStream { + type Target = tokio_openssl::SslStream; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SslStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl AsyncRead for SslStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_read(cx, buf) + } +} + +impl AsyncWrite for SslStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (&**self).is_write_vectored() + } +} + +impl ActixStream for SslStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + T::poll_read_ready((&**self).get_ref(), cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + T::poll_write_ready((&**self).get_ref(), cx) + } +} + /// Accept TLS connections via `openssl` package. /// /// `openssl` feature enables this `Acceptor` type. @@ -40,10 +115,7 @@ impl Clone for Acceptor { } } -impl ServiceFactory for Acceptor -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl ServiceFactory for Acceptor { type Response = SslStream; type Error = SslError; type Config = (); @@ -67,10 +139,7 @@ pub struct AcceptorService { conns: Counter, } -impl Service for AcceptorService -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Service for AcceptorService { type Response = SslStream; type Error = SslError; type Future = AcceptorServiceResponse; @@ -88,24 +157,25 @@ where let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid."); AcceptorServiceResponse { _guard: self.conns.get(), - stream: Some(SslStream::new(ssl, io).unwrap()), + stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()), } } } -pub struct AcceptorServiceResponse -where - T: AsyncRead + AsyncWrite, -{ - stream: Option>, +pub struct AcceptorServiceResponse { + stream: Option>, _guard: CounterGuard, } -impl Future for AcceptorServiceResponse { +impl Future for AcceptorServiceResponse { type Output = Result, SslError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { ready!(Pin::new(self.stream.as_mut().unwrap()).poll_accept(cx))?; - Poll::Ready(Ok(self.stream.take().expect("SSL connect has resolved."))) + Poll::Ready(Ok(self + .stream + .take() + .expect("SSL connect has resolved.") + .into())) } } diff --git a/actix-tls/src/accept/rustls.rs b/actix-tls/src/accept/rustls.rs index ff5cf3e5..c19a6cec 100644 --- a/actix-tls/src/accept/rustls.rs +++ b/actix-tls/src/accept/rustls.rs @@ -1,22 +1,96 @@ use std::{ future::Future, - io, + io::{self, IoSlice}, + ops::{Deref, DerefMut}, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use actix_codec::{AsyncRead, AsyncWrite}; +use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; +use actix_rt::net::ActixStream; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::{Counter, CounterGuard}; use futures_core::future::LocalBoxFuture; use tokio_rustls::{Accept, TlsAcceptor}; pub use tokio_rustls::rustls::{ServerConfig, Session}; -pub use tokio_rustls::server::TlsStream; use super::MAX_CONN_COUNTER; +/// wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait. +pub struct TlsStream(tokio_rustls::server::TlsStream); + +impl From> for TlsStream { + fn from(stream: tokio_rustls::server::TlsStream) -> Self { + Self(stream) + } +} + +impl Deref for TlsStream { + type Target = tokio_rustls::server::TlsStream; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for TlsStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_read(cx, buf) + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (&**self).is_write_vectored() + } +} + +impl ActixStream for TlsStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + T::poll_read_ready((&**self).get_ref().0, cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + T::poll_write_ready((&**self).get_ref().0, cx) + } +} + /// Accept TLS connections via `rustls` package. /// /// `rustls` feature enables this `Acceptor` type. @@ -43,10 +117,7 @@ impl Clone for Acceptor { } } -impl ServiceFactory for Acceptor -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl ServiceFactory for Acceptor { type Response = TlsStream; type Error = io::Error; type Config = (); @@ -72,10 +143,7 @@ pub struct AcceptorService { conns: Counter, } -impl Service for AcceptorService -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl Service for AcceptorService { type Response = TlsStream; type Error = io::Error; type Future = AcceptorServiceFut; @@ -96,22 +164,16 @@ where } } -pub struct AcceptorServiceFut -where - T: AsyncRead + AsyncWrite + Unpin, -{ +pub struct AcceptorServiceFut { fut: Accept, _guard: CounterGuard, } -impl Future for AcceptorServiceFut -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl Future for AcceptorServiceFut { type Output = Result, io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - Pin::new(&mut this.fut).poll(cx) + Pin::new(&mut this.fut).poll(cx).map_ok(TlsStream) } }