From 65fd23c648c7e8599ce16c971529a5775602d9f8 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 27 Sep 2018 20:23:30 -0700 Subject: [PATCH] add native-tls support --- .travis.yml | 6 +- src/ssl/mod.rs | 10 +-- src/ssl/nativetls.rs | 166 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 8 deletions(-) create mode 100644 src/ssl/nativetls.rs diff --git a/.travis.yml b/.travis.yml index 54d57dec..48a1414a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -32,12 +32,12 @@ script: - | if [[ "$TRAVIS_RUST_VERSION" != "nightly" ]]; then cargo clean - cargo test --features="ssl" -- --nocapture + cargo test --features="ssl,tls" -- --nocapture fi - | if [[ "$TRAVIS_RUST_VERSION" == "nightly" ]]; then RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install -f cargo-tarpaulin - cargo tarpaulin --features="ssl" --out Xml + cargo tarpaulin --features="ssl,tls" --out Xml bash <(curl -s https://codecov.io/bash) echo "Uploaded code coverage" fi @@ -46,7 +46,7 @@ script: after_success: - | if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "beta" ]]; then - cargo doc --features "ssl" --no-deps && + cargo doc --features "ssl,tls" --no-deps && echo "" > target/doc/index.html && git clone https://github.com/davisp/ghp-import.git && ./ghp-import/ghp_import.py -n -p -f -m "Documentation upload" -r https://"$GH_TOKEN"@github.com/"$TRAVIS_REPO_SLUG.git" target/doc && diff --git a/src/ssl/mod.rs b/src/ssl/mod.rs index f512ab29..81a364df 100644 --- a/src/ssl/mod.rs +++ b/src/ssl/mod.rs @@ -8,6 +8,11 @@ mod openssl; #[cfg(feature = "ssl")] pub use self::openssl::{OpensslAcceptor, OpensslConnector}; +#[cfg(feature = "tls")] +mod nativetls; +#[cfg(feature = "tls")] +pub use self::nativetls::{NativeTlsAcceptor, TlsStream}; + pub(crate) const MAX_CONN: AtomicUsize = AtomicUsize::new(256); /// Sets the maximum per-worker concurrent ssl connection establish process. @@ -24,11 +29,6 @@ thread_local! { static MAX_CONN_COUNTER: Counter = Counter::new(MAX_CONN.load(Ordering::Relaxed)); } -// #[cfg(feature = "tls")] -// mod nativetls; -// #[cfg(feature = "tls")] -// pub use self::nativetls::{NativeTlsAcceptor, TlsStream}; - // #[cfg(feature = "rust-tls")] // mod rustls; // #[cfg(feature = "rust-tls")] diff --git a/src/ssl/nativetls.rs b/src/ssl/nativetls.rs new file mode 100644 index 00000000..7b9e2cd9 --- /dev/null +++ b/src/ssl/nativetls.rs @@ -0,0 +1,166 @@ +use std::io; +use std::marker::PhantomData; + +use futures::{future::ok, future::FutureResult, Async, Future, Poll}; +use native_tls::{self, Error, HandshakeError, TlsAcceptor}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use super::MAX_CONN_COUNTER; +use counter::{Counter, CounterGuard}; +use service::{NewService, Service}; + +/// Support `SSL` connections via native-tls package +/// +/// `tls` feature enables `NativeTlsAcceptor` type +pub struct NativeTlsAcceptor { + acceptor: TlsAcceptor, + io: PhantomData, +} + +impl NativeTlsAcceptor { + /// Create `NativeTlsAcceptor` instance + pub fn new(acceptor: TlsAcceptor) -> Self { + NativeTlsAcceptor { + acceptor: acceptor.into(), + io: PhantomData, + } + } +} + +impl Clone for NativeTlsAcceptor { + fn clone(&self) -> Self { + Self { + acceptor: self.acceptor.clone(), + io: PhantomData, + } + } +} + +impl NewService for NativeTlsAcceptor { + type Request = T; + type Response = TlsStream; + type Error = Error; + type Service = NativeTlsAcceptorService; + type InitError = (); + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + MAX_CONN_COUNTER.with(|conns| { + ok(NativeTlsAcceptorService { + acceptor: self.acceptor.clone(), + conns: conns.clone(), + io: PhantomData, + }) + }) + } +} + +pub struct NativeTlsAcceptorService { + acceptor: TlsAcceptor, + io: PhantomData, + conns: Counter, +} + +impl Service for NativeTlsAcceptorService { + type Request = T; + type Response = TlsStream; + type Error = Error; + type Future = Accept; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + if self.conns.available() { + Ok(Async::Ready(())) + } else { + Ok(Async::NotReady) + } + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + Accept { + _guard: self.conns.get(), + inner: Some(self.acceptor.accept(req)), + } + } +} + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +/// +/// A `TlsStream` represents a handshake that has been completed successfully +/// and both the server and the client are ready for receiving and sending +/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written +/// to a `TlsStream` are encrypted when passing through to `S`. +#[derive(Debug)] +pub struct TlsStream { + inner: native_tls::TlsStream, +} + +/// Future returned from `NativeTlsAcceptor::accept` which will resolve +/// once the accept handshake has finished. +pub struct Accept { + inner: Option, HandshakeError>>, + _guard: CounterGuard, +} + +impl Future for Accept { + type Item = TlsStream; + type Error = Error; + + fn poll(&mut self) -> Poll { + match self.inner.take().expect("cannot poll MidHandshake twice") { + Ok(stream) => Ok(TlsStream { inner: stream }.into()), + Err(HandshakeError::Failure(e)) => Err(e), + Err(HandshakeError::WouldBlock(s)) => match s.handshake() { + Ok(stream) => Ok(TlsStream { inner: stream }.into()), + Err(HandshakeError::Failure(e)) => Err(e), + Err(HandshakeError::WouldBlock(s)) => { + self.inner = Some(Err(HandshakeError::WouldBlock(s))); + Ok(Async::NotReady) + } + }, + } + } +} + +impl TlsStream { + /// Get access to the internal `native_tls::TlsStream` stream which also + /// transitively allows access to `S`. + pub fn get_ref(&self) -> &native_tls::TlsStream { + &self.inner + } + + /// Get mutable access to the internal `native_tls::TlsStream` stream which + /// also transitively allows mutable access to `S`. + pub fn get_mut(&mut self) -> &mut native_tls::TlsStream { + &mut self.inner + } +} + +impl io::Read for TlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.inner.read(buf) + } +} + +impl io::Write for TlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() + } +} + +impl AsyncRead for TlsStream {} + +impl AsyncWrite for TlsStream { + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.inner.shutdown() { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e), + } + self.inner.get_mut().shutdown() + } +}