diff --git a/actix-connect/CHANGES.md b/actix-connect/CHANGES.md index b5b3419c..a97c59c8 100644 --- a/actix-connect/CHANGES.md +++ b/actix-connect/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.2] - 2019-07-24 + +* Add `rustls` support + ## [0.2.1] - 2019-07-17 ### Added diff --git a/actix-connect/Cargo.toml b/actix-connect/Cargo.toml index 97b2007a..85230eeb 100644 --- a/actix-connect/Cargo.toml +++ b/actix-connect/Cargo.toml @@ -26,6 +26,9 @@ default = ["uri"] # openssl ssl = ["openssl", "tokio-openssl"] +#rustls +rust-tls = ["rustls", "tokio-rustls", "webpki"] + # support http::Uri as connect address uri = ["http"] @@ -46,6 +49,11 @@ trust-dns-resolver = { version="0.11.0", default-features = false } openssl = { version="0.10", optional = true } tokio-openssl = { version="0.3", optional = true } +#rustls +rustls = { version = "0.15.2", optional = true } +tokio-rustls = { version = "0.9.1", optional = true } +webpki = { version = "0.19", optional = true } + [dev-dependencies] bytes = "0.4" actix-test-server = { version="0.2.2", features=["ssl"] } diff --git a/actix-connect/src/ssl/mod.rs b/actix-connect/src/ssl/mod.rs index ecbd0a88..2110d01d 100644 --- a/actix-connect/src/ssl/mod.rs +++ b/actix-connect/src/ssl/mod.rs @@ -4,3 +4,7 @@ mod openssl; #[cfg(feature = "ssl")] pub use self::openssl::OpensslConnector; +#[cfg(feature = "rust-tls")] +mod rustls; +#[cfg(feature = "rust-tls")] +pub use self::rustls::RustlsConnector; diff --git a/actix-connect/src/ssl/rustls.rs b/actix-connect/src/ssl/rustls.rs new file mode 100644 index 00000000..2358af07 --- /dev/null +++ b/actix-connect/src/ssl/rustls.rs @@ -0,0 +1,130 @@ +use std::fmt; +use std::marker::PhantomData; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_service::{NewService, Service}; +use futures::{future::ok, future::FutureResult, Async, Future, Poll}; +use webpki::DNSNameRef; +use tokio_rustls::{TlsConnector, TlsStream, Connect, rustls::{ClientConfig, ClientSession}}; +use std::sync::Arc; + +use crate::{Address, Connection}; + +/// Rustls connector factory +pub struct RustlsConnector { + connector: Arc, + _t: PhantomData<(T, U)>, +} + +impl RustlsConnector { + pub fn new(connector: Arc) -> Self { + RustlsConnector { + connector, + _t: PhantomData, + } + } +} + +impl RustlsConnector +where + T: Address, + U: AsyncRead + AsyncWrite + fmt::Debug, +{ + pub fn service( + connector: Arc, + ) -> impl Service< + Request = Connection, + Response = Connection>, + Error = std::io::Error, + > { + RustlsConnectorService { + connector: connector, + _t: PhantomData, + } + } +} + +impl Clone for RustlsConnector { + fn clone(&self) -> Self { + Self { + connector: self.connector.clone(), + _t: PhantomData, + } + } +} + +impl NewService for RustlsConnector +where + U: AsyncRead + AsyncWrite + fmt::Debug, +{ + type Request = Connection; + type Response = Connection>; + type Error = std::io::Error; + type Config = (); + type Service = RustlsConnectorService; + type InitError = (); + type Future = FutureResult; + + fn new_service(&self, _: &()) -> Self::Future { + ok(RustlsConnectorService { + connector: self.connector.clone(), + _t: PhantomData, + }) + } +} + +pub struct RustlsConnectorService { + connector: Arc, + _t: PhantomData<(T, U)>, +} + +impl Service for RustlsConnectorService +where + U: AsyncRead + AsyncWrite + fmt::Debug, +{ + type Request = Connection; + type Response = Connection>; + type Error = std::io::Error; + type Future = ConnectAsyncExt; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, stream: Connection) -> Self::Future { + trace!("SSL Handshake start for: {:?}", stream.host()); + let (io, stream) = stream.replace(()); + let host = DNSNameRef::try_from_ascii_str(stream.host()).unwrap(); + ConnectAsyncExt { + fut: TlsConnector::from(self.connector.clone()).connect(host, io), + stream: Some(stream), + } + } +} + +pub struct ConnectAsyncExt { + fut: Connect, + stream: Option>, +} + +impl Future for ConnectAsyncExt +where + U: AsyncRead + AsyncWrite + fmt::Debug, +{ + type Item = Connection>; + type Error = std::io::Error; + + fn poll(&mut self) -> Poll { + match self.fut.poll().map_err(|e| { + trace!("SSL Handshake error: {:?}", e); + e + })? { + Async::Ready(stream) => { + let s = self.stream.take().unwrap(); + trace!("SSL Handshake success: {:?}", s.host()); + Ok(Async::Ready(s.replace(stream).1)) + } + Async::NotReady => Ok(Async::NotReady), + } + } +} diff --git a/actix-connect/tests/test_connect.rs b/actix-connect/tests/test_connect.rs index 481457d3..738cfc42 100644 --- a/actix-connect/tests/test_connect.rs +++ b/actix-connect/tests/test_connect.rs @@ -26,6 +26,22 @@ fn test_string() { assert_eq!(con.peer_addr().unwrap(), srv.addr()); } +#[cfg(feature = "rust-tls")] +#[test] +fn test_rustls_string() { + let mut srv = TestServer::with(|| { + service_fn(|io: Io| { + Framed::new(io.into_parts().0, BytesCodec) + .send(Bytes::from_static(b"test")) + .then(|_| Ok::<_, ()>(())) + }) + }); + + let mut conn = default_connector(); + let addr = format!("localhost:{}", srv.port()); + let con = srv.run_on(move || conn.call(addr.into())).unwrap(); + assert_eq!(con.peer_addr().unwrap(), srv.addr()); +} #[test] fn test_static_str() { let mut srv = TestServer::with(|| { @@ -107,3 +123,20 @@ fn test_uri() { let con = srv.run_on(move || conn.call(addr.into())).unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); } + +#[cfg(feature = "rust-tls")] +#[test] +fn test_rustls_uri() { + let mut srv = TestServer::with(|| { + service_fn(|io: Io| { + Framed::new(io.into_parts().0, BytesCodec) + .send(Bytes::from_static(b"test")) + .then(|_| Ok::<_, ()>(())) + }) + }); + + let mut conn = default_connector(); + let addr = Uri::try_from(format!("https://localhost:{}", srv.port())).unwrap(); + let con = srv.run_on(move || conn.call(addr.into())).unwrap(); + assert_eq!(con.peer_addr().unwrap(), srv.addr()); +}