From f3f1e04853dbaf1ff7f014ff50319fc822e20240 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 3 Aug 2018 16:09:46 -0700 Subject: [PATCH] refactor ssl support --- CHANGES.md | 5 +- src/server/accept.rs | 60 ++-- src/server/h1.rs | 2 +- src/server/h2.rs | 2 +- src/server/mod.rs | 129 +++----- src/server/settings.rs | 24 +- src/server/srv.rs | 379 ++++++++-------------- src/server/ssl/mod.rs | 14 + src/server/ssl/nativetls.rs | 67 ++++ src/server/ssl/openssl.rs | 96 ++++++ src/server/ssl/rustls.rs | 92 ++++++ src/server/worker.rs | 622 ++++++++++++++++++++++-------------- 12 files changed, 879 insertions(+), 613 deletions(-) create mode 100644 src/server/ssl/mod.rs create mode 100644 src/server/ssl/nativetls.rs create mode 100644 src/server/ssl/openssl.rs create mode 100644 src/server/ssl/rustls.rs diff --git a/CHANGES.md b/CHANGES.md index c6e4a9436..4d1610c09 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,9 +4,12 @@ ### Added -* Added `HttpServer::max_connections()` and `HttpServer::max_sslrate()`, +* Added `HttpServer::maxconn()` and `HttpServer::maxconnrate()`, accept backpressure #250 +* Allow to customize connection handshake process via `HttpServer::listen_with()` + and `HttpServer::bind_with()` methods + ### Fixed * Use zlib instead of raw deflate for decoding and encoding payloads with diff --git a/src/server/accept.rs b/src/server/accept.rs index f846e4a40..e837852d3 100644 --- a/src/server/accept.rs +++ b/src/server/accept.rs @@ -9,8 +9,8 @@ use tokio_timer::Delay; use actix::{msgs::Execute, Arbiter, System}; -use super::srv::{ServerCommand, Socket}; -use super::worker::{Conn, WorkerClient}; +use super::srv::ServerCommand; +use super::worker::{Conn, Socket, Token, WorkerClient}; pub(crate) enum Command { Pause, @@ -21,7 +21,7 @@ pub(crate) enum Command { struct ServerSocketInfo { addr: net::SocketAddr, - token: usize, + token: Token, sock: mio::net::TcpListener, timeout: Option, } @@ -31,20 +31,24 @@ pub(crate) struct AcceptNotify { ready: mio::SetReadiness, maxconn: usize, maxconn_low: usize, - maxsslrate: usize, - maxsslrate_low: usize, + maxconnrate: usize, + maxconnrate_low: usize, } impl AcceptNotify { - pub fn new(ready: mio::SetReadiness, maxconn: usize, maxsslrate: usize) -> Self { + pub fn new(ready: mio::SetReadiness, maxconn: usize, maxconnrate: usize) -> Self { let maxconn_low = if maxconn > 10 { maxconn - 10 } else { 0 }; - let maxsslrate_low = if maxsslrate > 10 { maxsslrate - 10 } else { 0 }; + let maxconnrate_low = if maxconnrate > 10 { + maxconnrate - 10 + } else { + 0 + }; AcceptNotify { ready, maxconn, maxconn_low, - maxsslrate, - maxsslrate_low, + maxconnrate, + maxconnrate_low, } } @@ -53,8 +57,8 @@ impl AcceptNotify { let _ = self.ready.set_readiness(mio::Ready::readable()); } } - pub fn notify_maxsslrate(&self, sslrate: usize) { - if sslrate > self.maxsslrate_low && sslrate <= self.maxsslrate { + pub fn notify_maxconnrate(&self, connrate: usize) { + if connrate > self.maxconnrate_low && connrate <= self.maxconnrate { let _ = self.ready.set_readiness(mio::Ready::readable()); } } @@ -78,7 +82,7 @@ pub(crate) struct AcceptLoop { mpsc::UnboundedReceiver, )>, maxconn: usize, - maxsslrate: usize, + maxconnrate: usize, } impl AcceptLoop { @@ -94,7 +98,7 @@ impl AcceptLoop { notify_ready, notify_reg: Some(notify_reg), maxconn: 102_400, - maxsslrate: 256, + maxconnrate: 256, rx: Some(rx), srv: Some(mpsc::unbounded()), } @@ -106,19 +110,19 @@ impl AcceptLoop { } pub fn get_notify(&self) -> AcceptNotify { - AcceptNotify::new(self.notify_ready.clone(), self.maxconn, self.maxsslrate) + AcceptNotify::new(self.notify_ready.clone(), self.maxconn, self.maxconnrate) } - pub fn max_connections(&mut self, num: usize) { + pub fn maxconn(&mut self, num: usize) { self.maxconn = num; } - pub fn max_sslrate(&mut self, num: usize) { - self.maxsslrate = num; + pub fn maxconnrate(&mut self, num: usize) { + self.maxconnrate = num; } pub(crate) fn start( - &mut self, socks: Vec<(usize, Socket)>, workers: Vec, + &mut self, socks: Vec, workers: Vec, ) -> mpsc::UnboundedReceiver { let (tx, rx) = self.srv.take().expect("Can not re-use AcceptInfo"); @@ -127,7 +131,7 @@ impl AcceptLoop { self.cmd_reg.take().expect("Can not re-use AcceptInfo"), self.notify_reg.take().expect("Can not re-use AcceptInfo"), self.maxconn, - self.maxsslrate, + self.maxconnrate, socks, tx, workers, @@ -145,7 +149,7 @@ struct Accept { timer: (mio::Registration, mio::SetReadiness), next: usize, maxconn: usize, - maxsslrate: usize, + maxconnrate: usize, backpressure: bool, } @@ -171,8 +175,8 @@ impl Accept { #![cfg_attr(feature = "cargo-clippy", allow(too_many_arguments))] pub(crate) fn start( rx: sync_mpsc::Receiver, cmd_reg: mio::Registration, - notify_reg: mio::Registration, maxconn: usize, maxsslrate: usize, - socks: Vec<(usize, Socket)>, srv: mpsc::UnboundedSender, + notify_reg: mio::Registration, maxconn: usize, maxconnrate: usize, + socks: Vec, srv: mpsc::UnboundedSender, workers: Vec, ) { let sys = System::current(); @@ -184,7 +188,7 @@ impl Accept { System::set_current(sys); let mut accept = Accept::new(rx, socks, workers, srv); accept.maxconn = maxconn; - accept.maxsslrate = maxsslrate; + accept.maxconnrate = maxconnrate; // Start listening for incoming commands if let Err(err) = accept.poll.register( @@ -211,7 +215,7 @@ impl Accept { } fn new( - rx: sync_mpsc::Receiver, socks: Vec<(usize, Socket)>, + rx: sync_mpsc::Receiver, socks: Vec, workers: Vec, srv: mpsc::UnboundedSender, ) -> Accept { // Create a poll instance @@ -222,7 +226,7 @@ impl Accept { // Start accept let mut sockets = Slab::new(); - for (stoken, sock) in socks { + for sock in socks { let server = mio::net::TcpListener::from_std(sock.lst) .expect("Can not create mio::net::TcpListener"); @@ -240,7 +244,7 @@ impl Accept { } entry.insert(ServerSocketInfo { - token: stoken, + token: sock.token, addr: sock.addr, sock: server, timeout: None, @@ -264,7 +268,7 @@ impl Accept { next: 0, timer: (tm, tmr), maxconn: 102_400, - maxsslrate: 256, + maxconnrate: 256, backpressure: false, } } @@ -427,7 +431,7 @@ impl Accept { let mut idx = 0; while idx < self.workers.len() { idx += 1; - if self.workers[self.next].available(self.maxconn, self.maxsslrate) { + if self.workers[self.next].available(self.maxconn, self.maxconnrate) { match self.workers[self.next].send(msg) { Ok(_) => { self.next = (self.next + 1) % self.workers.len(); diff --git a/src/server/h1.rs b/src/server/h1.rs index 9f3bda28f..085cea005 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -375,7 +375,7 @@ where self.keepalive_timer.take(); // search handler for request - for h in self.settings.handlers().iter_mut() { + for h in self.settings.handlers().iter() { msg = match h.handle(msg) { Ok(mut pipe) => { if self.tasks.is_empty() { diff --git a/src/server/h2.rs b/src/server/h2.rs index e5355a1fd..cb5367c5e 100644 --- a/src/server/h2.rs +++ b/src/server/h2.rs @@ -347,7 +347,7 @@ impl Entry { // start request processing let mut task = None; - for h in settings.handlers().iter_mut() { + for h in settings.handlers().iter() { msg = match h.handle(msg) { Ok(t) => { task = Some(t); diff --git a/src/server/mod.rs b/src/server/mod.rs index 429e293f2..55de25db4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,10 +1,11 @@ //! Http server use std::net::Shutdown; -use std::{io, time}; +use std::{io, net, time}; use bytes::{BufMut, BytesMut}; -use futures::{Async, Poll}; +use futures::{Async, Future, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_reactor::Handle; use tokio_tcp::TcpStream; pub(crate) mod accept; @@ -21,11 +22,13 @@ pub(crate) mod message; pub(crate) mod output; pub(crate) mod settings; mod srv; +mod ssl; mod worker; pub use self::message::Request; pub use self::settings::ServerSettings; pub use self::srv::HttpServer; +pub use self::ssl::*; #[doc(hidden)] pub use self::helpers::write_content_length; @@ -72,6 +75,13 @@ where HttpServer::new(factory) } +bitflags! { + pub struct ServerFlags: u8 { + const HTTP1 = 0b0000_0001; + const HTTP2 = 0b0000_0010; + } +} + #[derive(Debug, PartialEq, Clone, Copy)] /// Server keep-alive setting pub enum KeepAlive { @@ -179,6 +189,34 @@ impl IntoHttpHandler for T { } } +pub(crate) trait IntoAsyncIo { + type Io: AsyncRead + AsyncWrite; + + fn into_async_io(self) -> Result; +} + +impl IntoAsyncIo for net::TcpStream { + type Io = TcpStream; + + fn into_async_io(self) -> Result { + TcpStream::from_std(self, &Handle::default()) + } +} + +/// Trait implemented by types that could accept incomming socket connections. +pub trait AcceptorService: Clone { + /// Established connection type + type Accepted: IoStream; + /// Future describes async accept process. + type Future: Future + 'static; + + /// Establish new connection + fn accept(&self, io: Io) -> Self::Future; + + /// Scheme + fn scheme(&self) -> &'static str; +} + #[doc(hidden)] #[derive(Debug)] pub enum WriterState { @@ -267,90 +305,3 @@ impl IoStream for TcpStream { TcpStream::set_linger(self, dur) } } - -#[cfg(feature = "alpn")] -use tokio_openssl::SslStream; - -#[cfg(feature = "alpn")] -impl IoStream for SslStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = self.get_mut().shutdown(); - Ok(()) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().get_mut().set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().get_mut().set_linger(dur) - } -} - -#[cfg(feature = "tls")] -use tokio_tls::TlsStream; - -#[cfg(feature = "tls")] -impl IoStream for TlsStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = self.get_mut().shutdown(); - Ok(()) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().get_mut().set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().get_mut().set_linger(dur) - } -} - -#[cfg(feature = "rust-tls")] -use rustls::{ClientSession, ServerSession}; -#[cfg(feature = "rust-tls")] -use tokio_rustls::TlsStream as RustlsStream; - -#[cfg(feature = "rust-tls")] -impl IoStream for RustlsStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = ::shutdown(self); - Ok(()) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().0.set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().0.set_linger(dur) - } -} - -#[cfg(feature = "rust-tls")] -impl IoStream for RustlsStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = ::shutdown(self); - Ok(()) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().0.set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().0.set_linger(dur) - } -} diff --git a/src/server/settings.rs b/src/server/settings.rs index 8e30646d9..508be67dd 100644 --- a/src/server/settings.rs +++ b/src/server/settings.rs @@ -132,7 +132,7 @@ impl ServerSettings { const DATE_VALUE_LENGTH: usize = 29; pub(crate) struct WorkerSettings { - h: RefCell>, + h: Vec, keep_alive: u64, ka_enabled: bool, bytes: Rc, @@ -140,14 +140,14 @@ pub(crate) struct WorkerSettings { channels: Arc, node: RefCell>, date: UnsafeCell, - sslrate: Arc, + connrate: Arc, notify: AcceptNotify, } impl WorkerSettings { pub(crate) fn new( h: Vec, keep_alive: KeepAlive, settings: ServerSettings, - notify: AcceptNotify, channels: Arc, sslrate: Arc, + notify: AcceptNotify, channels: Arc, connrate: Arc, ) -> WorkerSettings { let (keep_alive, ka_enabled) = match keep_alive { KeepAlive::Timeout(val) => (val as u64, true), @@ -156,7 +156,7 @@ impl WorkerSettings { }; WorkerSettings { - h: RefCell::new(h), + h, bytes: Rc::new(SharedBytesPool::new()), messages: RequestPool::pool(settings), node: RefCell::new(Node::head()), @@ -164,7 +164,7 @@ impl WorkerSettings { keep_alive, ka_enabled, channels, - sslrate, + connrate, notify, } } @@ -177,8 +177,8 @@ impl WorkerSettings { self.node.borrow_mut() } - pub fn handlers(&self) -> RefMut> { - self.h.borrow_mut() + pub fn handlers(&self) -> &Vec { + &self.h } pub fn keep_alive(&self) -> u64 { @@ -230,13 +230,13 @@ impl WorkerSettings { } #[allow(dead_code)] - pub(crate) fn ssl_conn_add(&self) { - self.sslrate.fetch_add(1, Ordering::Relaxed); + pub(crate) fn conn_rate_add(&self) { + self.connrate.fetch_add(1, Ordering::Relaxed); } #[allow(dead_code)] - pub(crate) fn ssl_conn_del(&self) { - let val = self.sslrate.fetch_sub(1, Ordering::Relaxed); - self.notify.notify_maxsslrate(val); + pub(crate) fn conn_rate_del(&self) { + let val = self.connrate.fetch_sub(1, Ordering::Relaxed); + self.notify.notify_maxconnrate(val); } } diff --git a/src/server/srv.rs b/src/server/srv.rs index b6bd21967..33c820aa7 100644 --- a/src/server/srv.rs +++ b/src/server/srv.rs @@ -10,15 +10,15 @@ use actix::{ use futures::sync::mpsc; use futures::{Future, Sink, Stream}; -use net2::TcpBuilder; use num_cpus; use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_tcp::TcpStream; #[cfg(feature = "tls")] use native_tls::TlsAcceptor; #[cfg(feature = "alpn")] -use openssl::ssl::{AlpnError, SslAcceptorBuilder}; +use openssl::ssl::SslAcceptorBuilder; #[cfg(feature = "rust-tls")] use rustls::ServerConfig; @@ -26,43 +26,25 @@ use rustls::ServerConfig; use super::accept::{AcceptLoop, AcceptNotify, Command}; use super::channel::{HttpChannel, WrapperStream}; use super::settings::{ServerSettings, WorkerSettings}; -use super::worker::{ - Conn, StopWorker, StreamHandlerType, Worker, WorkerClient, WorkersPool, -}; -use super::{IntoHttpHandler, IoStream, KeepAlive}; +use super::worker::{Conn, StopWorker, Token, Worker, WorkerClient, WorkerFactory}; +use super::{AcceptorService, IntoHttpHandler, IoStream, KeepAlive}; use super::{PauseServer, ResumeServer, StopServer}; -#[cfg(feature = "alpn")] -fn configure_alpn(builder: &mut SslAcceptorBuilder) -> io::Result<()> { - builder.set_alpn_protos(b"\x02h2\x08http/1.1")?; - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(AlpnError::NOACK) - } - }); - Ok(()) -} - /// An HTTP Server pub struct HttpServer where H: IntoHttpHandler + 'static, { - h: Option>>, threads: usize, - backlog: i32, - sockets: Vec, - pool: WorkersPool, - workers: Vec<(usize, Addr>)>, + factory: WorkerFactory, + workers: Vec<(usize, Addr)>, accept: AcceptLoop, exit: bool, shutdown_timeout: u16, signals: Option>, no_http2: bool, no_signals: bool, + settings: Option>>, } pub(crate) enum ServerCommand { @@ -76,12 +58,6 @@ where type Context = Context; } -pub(crate) struct Socket { - pub lst: net::TcpListener, - pub addr: net::SocketAddr, - pub tp: StreamHandlerType, -} - impl HttpServer where H: IntoHttpHandler + 'static, @@ -95,18 +71,16 @@ where let f = move || (factory)().into_iter().collect(); HttpServer { - h: None, threads: num_cpus::get(), - backlog: 2048, - pool: WorkersPool::new(f), + factory: WorkerFactory::new(f), workers: Vec::new(), - sockets: Vec::new(), accept: AcceptLoop::new(), exit: false, shutdown_timeout: 30, signals: None, no_http2: false, no_signals: false, + settings: None, } } @@ -130,7 +104,7 @@ where /// /// This method should be called before `bind()` method call. pub fn backlog(mut self, num: i32) -> Self { - self.backlog = num; + self.factory.backlog = num; self } @@ -140,20 +114,19 @@ where /// for each worker. /// /// By default max connections is set to a 100k. - pub fn max_connections(mut self, num: usize) -> Self { - self.accept.max_connections(num); + pub fn maxconn(mut self, num: usize) -> Self { + self.accept.maxconn(num); self } - /// Sets the maximum concurrent per-worker number of SSL handshakes. + /// Sets the maximum per-worker concurrent connection establish process. /// /// All listeners will stop accepting connections when this limit is reached. It - /// can be used to limit the global SSL CPU usage regardless of each worker - /// capacity. + /// can be used to limit the global SSL CPU usage. /// /// By default max connections is set to a 256. - pub fn max_sslrate(mut self, num: usize) -> Self { - self.accept.max_sslrate(num); + pub fn maxconnrate(mut self, num: usize) -> Self { + self.accept.maxconnrate(num); self } @@ -161,7 +134,7 @@ where /// /// By default keep alive is set to a `Os`. pub fn keep_alive>(mut self, val: T) -> Self { - self.pool.keep_alive = val.into(); + self.factory.keep_alive = val.into(); self } @@ -171,7 +144,7 @@ where /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. /// html#method.host) documentation for more information. pub fn server_hostname(mut self, val: String) -> Self { - self.pool.host = Some(val); + self.factory.host = Some(val); self } @@ -215,7 +188,7 @@ where /// Get addresses of bound sockets. pub fn addrs(&self) -> Vec { - self.sockets.iter().map(|s| s.addr).collect() + self.factory.addrs() } /// Get addresses of bound sockets and the scheme for it. @@ -225,10 +198,7 @@ where /// and the user should be presented with an enumeration of which /// socket requires which protocol. pub fn addrs_with_scheme(&self) -> Vec<(net::SocketAddr, &str)> { - self.sockets - .iter() - .map(|s| (s.addr, s.tp.scheme())) - .collect() + self.factory.addrs_with_scheme() } /// Use listener for accepting incoming connection requests @@ -236,175 +206,177 @@ where /// HttpServer does not change any configuration for TcpListener, /// it needs to be configured before passing it to listen() method. pub fn listen(mut self, lst: net::TcpListener) -> Self { - let addr = lst.local_addr().unwrap(); - self.sockets.push(Socket { - addr, - lst, - tp: StreamHandlerType::Normal, - }); + self.factory.listen(lst); self } + /// Use listener for accepting incoming connection requests + pub fn listen_with( + mut self, lst: net::TcpListener, acceptor: A, + ) -> io::Result + where + A: AcceptorService + Send + 'static, + { + self.factory.listen_with(lst, acceptor); + Ok(self) + } + #[cfg(feature = "tls")] + #[doc(hidden)] + #[deprecated( + since = "0.7.4", + note = "please use `actix_web::HttpServer::listen_with()` and `actix_web::server::NativeTlsAcceptor` instead" + )] /// Use listener for accepting incoming tls connection requests /// /// HttpServer does not change any configuration for TcpListener, /// it needs to be configured before passing it to listen() method. - pub fn listen_tls(mut self, lst: net::TcpListener, acceptor: TlsAcceptor) -> Self { - let addr = lst.local_addr().unwrap(); - self.sockets.push(Socket { - addr, - lst, - tp: StreamHandlerType::Tls(acceptor.clone()), - }); - self + pub fn listen_tls( + self, lst: net::TcpListener, acceptor: TlsAcceptor, + ) -> io::Result { + use super::NativeTlsAcceptor; + + self.listen_with(lst, NativeTlsAcceptor::new(acceptor)) } #[cfg(feature = "alpn")] + #[doc(hidden)] + #[deprecated( + since = "0.7.4", + note = "please use `actix_web::HttpServer::listen_with()` and `actix_web::server::OpensslAcceptor` instead" + )] /// Use listener for accepting incoming tls connection requests /// /// This method sets alpn protocols to "h2" and "http/1.1" pub fn listen_ssl( - mut self, lst: net::TcpListener, mut builder: SslAcceptorBuilder, + self, lst: net::TcpListener, builder: SslAcceptorBuilder, ) -> io::Result { + use super::{OpensslAcceptor, ServerFlags}; + // alpn support - if !self.no_http2 { - configure_alpn(&mut builder)?; - } - let acceptor = builder.build(); - let addr = lst.local_addr().unwrap(); - self.sockets.push(Socket { - addr, - lst, - tp: StreamHandlerType::Alpn(acceptor.clone()), - }); - Ok(self) + let flags = if !self.no_http2 { + ServerFlags::HTTP1 + } else { + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + self.listen_with(lst, OpensslAcceptor::with_flags(builder, flags)?) } #[cfg(feature = "rust-tls")] + #[doc(hidden)] + #[deprecated( + since = "0.7.4", + note = "please use `actix_web::HttpServer::listen_with()` and `actix_web::server::RustlsAcceptor` instead" + )] /// Use listener for accepting incoming tls connection requests /// /// This method sets alpn protocols to "h2" and "http/1.1" pub fn listen_rustls( - mut self, lst: net::TcpListener, mut builder: ServerConfig, + self, lst: net::TcpListener, builder: ServerConfig, ) -> io::Result { + use super::{RustlsAcceptor, ServerFlags}; + // alpn support - if !self.no_http2 { - builder.set_protocols(&vec!["h2".to_string(), "http/1.1".to_string()]); - } - let addr = lst.local_addr().unwrap(); - self.sockets.push(Socket { - addr, - lst, - tp: StreamHandlerType::Rustls(Arc::new(builder)), - }); - Ok(self) - } - - fn bind2(&mut self, addr: S) -> io::Result> { - let mut err = None; - let mut succ = false; - let mut sockets = Vec::new(); - for addr in addr.to_socket_addrs()? { - match create_tcp_listener(addr, self.backlog) { - Ok(lst) => { - succ = true; - let addr = lst.local_addr().unwrap(); - sockets.push(Socket { - lst, - addr, - tp: StreamHandlerType::Normal, - }); - } - Err(e) => err = Some(e), - } - } - - if !succ { - if let Some(e) = err.take() { - Err(e) - } else { - Err(io::Error::new( - io::ErrorKind::Other, - "Can not bind to address.", - )) - } + let flags = if !self.no_http2 { + ServerFlags::HTTP1 } else { - Ok(sockets) - } + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + self.listen_with(lst, RustlsAcceptor::with_flags(builder, flags)) } /// The socket address to bind /// /// To bind multiple addresses this method can be called multiple times. pub fn bind(mut self, addr: S) -> io::Result { - let sockets = self.bind2(addr)?; - self.sockets.extend(sockets); + self.factory.bind(addr)?; + Ok(self) + } + + /// Start listening for incoming connections with supplied acceptor. + #[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))] + pub fn bind_with(mut self, addr: S, acceptor: A) -> io::Result + where + S: net::ToSocketAddrs, + A: AcceptorService + Send + 'static, + { + self.factory.bind_with(addr, &acceptor)?; Ok(self) } #[cfg(feature = "tls")] + #[doc(hidden)] + #[deprecated( + since = "0.7.4", + note = "please use `actix_web::HttpServer::bind_with()` and `actix_web::server::NativeTlsAcceptor` instead" + )] /// The ssl socket address to bind /// /// To bind multiple addresses this method can be called multiple times. pub fn bind_tls( - mut self, addr: S, acceptor: TlsAcceptor, + self, addr: S, acceptor: TlsAcceptor, ) -> io::Result { - let sockets = self.bind2(addr)?; - self.sockets.extend(sockets.into_iter().map(|mut s| { - s.tp = StreamHandlerType::Tls(acceptor.clone()); - s - })); - Ok(self) + use super::NativeTlsAcceptor; + + self.bind_with(addr, NativeTlsAcceptor::new(acceptor)) } #[cfg(feature = "alpn")] + #[doc(hidden)] + #[deprecated( + since = "0.7.4", + note = "please use `actix_web::HttpServer::bind_with()` and `actix_web::server::OpensslAcceptor` instead" + )] /// Start listening for incoming tls connections. /// /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn bind_ssl( - mut self, addr: S, mut builder: SslAcceptorBuilder, - ) -> io::Result { - // alpn support - if !self.no_http2 { - configure_alpn(&mut builder)?; - } + pub fn bind_ssl(self, addr: S, builder: SslAcceptorBuilder) -> io::Result + where + S: net::ToSocketAddrs, + { + use super::{OpensslAcceptor, ServerFlags}; - let acceptor = builder.build(); - let sockets = self.bind2(addr)?; - self.sockets.extend(sockets.into_iter().map(|mut s| { - s.tp = StreamHandlerType::Alpn(acceptor.clone()); - s - })); - Ok(self) + // alpn support + let flags = if !self.no_http2 { + ServerFlags::HTTP1 + } else { + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + self.bind_with(addr, OpensslAcceptor::with_flags(builder, flags)?) } #[cfg(feature = "rust-tls")] + #[doc(hidden)] + #[deprecated( + since = "0.7.4", + note = "please use `actix_web::HttpServer::bind_with()` and `actix_web::server::RustlsAcceptor` instead" + )] /// Start listening for incoming tls connections. /// /// This method sets alpn protocols to "h2" and "http/1.1" pub fn bind_rustls( - mut self, addr: S, mut builder: ServerConfig, + self, addr: S, builder: ServerConfig, ) -> io::Result { - // alpn support - if !self.no_http2 { - builder.set_protocols(&vec!["h2".to_string(), "http/1.1".to_string()]); - } + use super::{RustlsAcceptor, ServerFlags}; - let builder = Arc::new(builder); - let sockets = self.bind2(addr)?; - self.sockets.extend(sockets.into_iter().map(move |mut s| { - s.tp = StreamHandlerType::Rustls(builder.clone()); - s - })); - Ok(self) + // alpn support + let flags = if !self.no_http2 { + ServerFlags::HTTP1 + } else { + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + self.bind_with(addr, RustlsAcceptor::with_flags(builder, flags)) } fn start_workers(&mut self, notify: &AcceptNotify) -> Vec { // start workers let mut workers = Vec::new(); for idx in 0..self.threads { - let (worker, addr) = self.pool.start(idx, notify.clone()); + let (worker, addr) = self.factory.start(idx, notify.clone()); workers.push(worker); self.workers.push((idx, addr)); } @@ -453,23 +425,18 @@ impl HttpServer { /// } /// ``` pub fn start(mut self) -> Addr { - if self.sockets.is_empty() { + let sockets = self.factory.take_sockets(); + if sockets.is_empty() { panic!("HttpServer::bind() has to be called before start()"); } else { - let mut addrs: Vec<(usize, Socket)> = Vec::new(); - - for socket in self.sockets.drain(..) { - let token = self.pool.insert(socket.addr, socket.tp.clone()); - addrs.push((token, socket)); - } let notify = self.accept.get_notify(); let workers = self.start_workers(¬ify); // start accept thread - for (_, sock) in &addrs { + for sock in &sockets { info!("Starting server on http://{}", sock.addr); } - let rx = self.accept.start(addrs, workers.clone()); + let rx = self.accept.start(sockets, workers.clone()); // start http server actor let signals = self.subscribe_to_signals(); @@ -511,64 +478,6 @@ impl HttpServer { } } -#[doc(hidden)] -#[cfg(feature = "tls")] -#[deprecated( - since = "0.6.0", - note = "please use `actix_web::HttpServer::bind_tls` instead" -)] -impl HttpServer { - /// Start listening for incoming tls connections. - pub fn start_tls(mut self, acceptor: TlsAcceptor) -> io::Result> { - for sock in &mut self.sockets { - match sock.tp { - StreamHandlerType::Normal => (), - _ => continue, - } - sock.tp = StreamHandlerType::Tls(acceptor.clone()); - } - Ok(self.start()) - } -} - -#[doc(hidden)] -#[cfg(feature = "alpn")] -#[deprecated( - since = "0.6.0", - note = "please use `actix_web::HttpServer::bind_ssl` instead" -)] -impl HttpServer { - /// Start listening for incoming tls connections. - /// - /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn start_ssl( - mut self, mut builder: SslAcceptorBuilder, - ) -> io::Result> { - // alpn support - if !self.no_http2 { - builder.set_alpn_protos(b"\x02h2\x08http/1.1")?; - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(AlpnError::NOACK) - } - }); - } - - let acceptor = builder.build(); - for sock in &mut self.sockets { - match sock.tp { - StreamHandlerType::Normal => (), - _ => continue, - } - sock.tp = StreamHandlerType::Alpn(acceptor.clone()); - } - Ok(self.start()) - } -} - impl HttpServer { /// Start listening for incoming connections from a stream. /// @@ -580,14 +489,14 @@ impl HttpServer { { // set server settings let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap(); - let settings = ServerSettings::new(Some(addr), &self.pool.host, secure); - let apps: Vec<_> = (*self.pool.factory)() + let settings = ServerSettings::new(Some(addr), &self.factory.host, secure); + let apps: Vec<_> = (*self.factory.factory)() .into_iter() .map(|h| h.into_handler()) .collect(); - self.h = Some(Rc::new(WorkerSettings::new( + self.settings = Some(Rc::new(WorkerSettings::new( apps, - self.pool.keep_alive, + self.factory.keep_alive, settings, AcceptNotify::default(), Arc::new(AtomicUsize::new(0)), @@ -599,7 +508,7 @@ impl HttpServer { let addr = HttpServer::create(move |ctx| { ctx.add_message_stream(stream.map_err(|_| ()).map(move |t| Conn { io: WrapperStream::new(t), - token: 0, + token: Token::new(0), peer: None, http2: false, })); @@ -672,7 +581,7 @@ impl StreamHandler for HttpServer { } let (worker, addr) = - self.pool.start(new_idx, self.accept.get_notify()); + self.factory.start(new_idx, self.accept.get_notify()); self.workers.push((new_idx, addr)); self.accept.send(Command::Worker(worker)); } @@ -690,7 +599,7 @@ where fn handle(&mut self, msg: Conn, _: &mut Context) -> Self::Result { Arbiter::spawn(HttpChannel::new( - Rc::clone(self.h.as_ref().unwrap()), + Rc::clone(self.settings.as_ref().unwrap()), msg.io, msg.peer, msg.http2, @@ -766,15 +675,3 @@ impl Handler for HttpServer { } } } - -fn create_tcp_listener( - addr: net::SocketAddr, backlog: i32, -) -> io::Result { - let builder = match addr { - net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, - net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, - }; - builder.reuse_address(true)?; - builder.bind(addr)?; - Ok(builder.listen(backlog)?) -} diff --git a/src/server/ssl/mod.rs b/src/server/ssl/mod.rs new file mode 100644 index 000000000..d99c4a584 --- /dev/null +++ b/src/server/ssl/mod.rs @@ -0,0 +1,14 @@ +#[cfg(feature = "alpn")] +mod openssl; +#[cfg(feature = "alpn")] +pub use self::openssl::OpensslAcceptor; + +#[cfg(feature = "tls")] +mod nativetls; +#[cfg(feature = "tls")] +pub use self::nativetls::NativeTlsAcceptor; + +#[cfg(feature = "rust-tls")] +mod rustls; +#[cfg(feature = "rust-tls")] +pub use self::rustls::RustlsAcceptor; diff --git a/src/server/ssl/nativetls.rs b/src/server/ssl/nativetls.rs new file mode 100644 index 000000000..8749599e9 --- /dev/null +++ b/src/server/ssl/nativetls.rs @@ -0,0 +1,67 @@ +use std::net::Shutdown; +use std::{io, time}; + +use futures::{Future, Poll}; +use native_tls::TlsAcceptor; +use tokio_tls::{AcceptAsync, TlsAcceptorExt, TlsStream}; + +use server::{AcceptorService, IoStream}; + +#[derive(Clone)] +/// Support `SSL` connections via native-tls package +/// +/// `tls` feature enables `NativeTlsAcceptor` type +pub struct NativeTlsAcceptor { + acceptor: TlsAcceptor, +} + +impl NativeTlsAcceptor { + /// Create `NativeTlsAcceptor` instance + pub fn new(acceptor: TlsAcceptor) -> Self { + NativeTlsAcceptor { acceptor } + } +} + +pub struct AcceptorFut(AcceptAsync); + +impl Future for AcceptorFut { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0 + .poll() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } +} + +impl AcceptorService for NativeTlsAcceptor { + type Accepted = TlsStream; + type Future = AcceptorFut; + + fn scheme(&self) -> &'static str { + "https" + } + + fn accept(&self, io: Io) -> Self::Future { + AcceptorFut(TlsAcceptorExt::accept_async(&self.acceptor, io)) + } +} + +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = self.get_mut().shutdown(); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().get_mut().set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_linger(dur) + } +} diff --git a/src/server/ssl/openssl.rs b/src/server/ssl/openssl.rs new file mode 100644 index 000000000..996c510dc --- /dev/null +++ b/src/server/ssl/openssl.rs @@ -0,0 +1,96 @@ +use std::net::Shutdown; +use std::{io, time}; + +use futures::{Future, Poll}; +use openssl::ssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; +use tokio_openssl::{AcceptAsync, SslAcceptorExt, SslStream}; + +use server::{AcceptorService, IoStream, ServerFlags}; + +#[derive(Clone)] +/// Support `SSL` connections via openssl package +/// +/// `alpn` feature enables `OpensslAcceptor` type +pub struct OpensslAcceptor { + acceptor: SslAcceptor, +} + +impl OpensslAcceptor { + /// Create `OpensslAcceptor` with enabled `HTTP/2` and `HTTP1.1` support. + pub fn new(builder: SslAcceptorBuilder) -> io::Result { + OpensslAcceptor::with_flags(builder, ServerFlags::HTTP1 | ServerFlags::HTTP2) + } + + /// Create `OpensslAcceptor` with custom server flags. + pub fn with_flags( + mut builder: SslAcceptorBuilder, flags: ServerFlags, + ) -> io::Result { + let mut protos = Vec::new(); + if flags.contains(ServerFlags::HTTP1) { + protos.extend(b"\x08http/1.1"); + } + if flags.contains(ServerFlags::HTTP2) { + protos.extend(b"\x02h2"); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(AlpnError::NOACK) + } + }); + } + + if !protos.is_empty() { + builder.set_alpn_protos(&protos)?; + } + + Ok(OpensslAcceptor { + acceptor: builder.build(), + }) + } +} + +pub struct AcceptorFut(AcceptAsync); + +impl Future for AcceptorFut { + type Item = SslStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0 + .poll() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } +} + +impl AcceptorService for OpensslAcceptor { + type Accepted = SslStream; + type Future = AcceptorFut; + + fn scheme(&self) -> &'static str { + "https" + } + + fn accept(&self, io: Io) -> Self::Future { + AcceptorFut(SslAcceptorExt::accept_async(&self.acceptor, io)) + } +} + +impl IoStream for SslStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = self.get_mut().shutdown(); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().get_mut().set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_linger(dur) + } +} diff --git a/src/server/ssl/rustls.rs b/src/server/ssl/rustls.rs new file mode 100644 index 000000000..45cb61be7 --- /dev/null +++ b/src/server/ssl/rustls.rs @@ -0,0 +1,92 @@ +use std::net::Shutdown; +use std::sync::Arc; +use std::{io, time}; + +use rustls::{ClientSession, ServerConfig, ServerSession}; +use tokio_io::AsyncWrite; +use tokio_rustls::{AcceptAsync, ServerConfigExt, TlsStream}; + +use server::{AcceptorService, IoStream, ServerFlags}; + +#[derive(Clone)] +/// Support `SSL` connections via rustls package +/// +/// `rust-tls` feature enables `RustlsAcceptor` type +pub struct RustlsAcceptor { + config: Arc, +} + +impl RustlsAcceptor { + /// Create `OpensslAcceptor` with enabled `HTTP/2` and `HTTP1.1` support. + pub fn new(config: ServerConfig) -> Self { + RustlsAcceptor::with_flags(config, ServerFlags::HTTP1 | ServerFlags::HTTP2) + } + + /// Create `OpensslAcceptor` with custom server flags. + pub fn with_flags(mut config: ServerConfig, flags: ServerFlags) -> Self { + let mut protos = Vec::new(); + if flags.contains(ServerFlags::HTTP1) { + protos.push("http/1.1".to_string()); + } + if flags.contains(ServerFlags::HTTP2) { + protos.push("h2".to_string()); + } + + if !protos.is_empty() { + config.set_protocols(&protos); + } + + RustlsAcceptor { + config: Arc::new(config), + } + } +} + +impl AcceptorService for RustlsAcceptor { + type Accepted = TlsStream; + type Future = AcceptAsync; + + fn scheme(&self) -> &'static str { + "https" + } + + fn accept(&self, io: Io) -> Self::Future { + ServerConfigExt::accept_async(&self.config, io) + } +} + +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = ::shutdown(self); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().0.set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_linger(dur) + } +} + +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = ::shutdown(self); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().0.set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_linger(dur) + } +} diff --git a/src/server/worker.rs b/src/server/worker.rs index e9bf42250..3b8f426db 100644 --- a/src/server/worker.rs +++ b/src/server/worker.rs @@ -1,102 +1,195 @@ +use std::marker::PhantomData; use std::rc::Rc; use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc}; -use std::{net, time}; +use std::{io, mem, net, time}; use futures::sync::mpsc::{unbounded, SendError, UnboundedSender}; use futures::sync::oneshot; use futures::Future; -use net2::TcpStreamExt; -use slab::Slab; +use net2::{TcpBuilder, TcpStreamExt}; use tokio::executor::current_thread; -use tokio_reactor::Handle; use tokio_tcp::TcpStream; -#[cfg(any(feature = "tls", feature = "alpn", feature = "rust-tls"))] -use futures::future; - -#[cfg(feature = "tls")] -use native_tls::TlsAcceptor; -#[cfg(feature = "tls")] -use tokio_tls::TlsAcceptorExt; - -#[cfg(feature = "alpn")] -use openssl::ssl::SslAcceptor; -#[cfg(feature = "alpn")] -use tokio_openssl::SslAcceptorExt; - -#[cfg(feature = "rust-tls")] -use rustls::{ServerConfig, Session}; -#[cfg(feature = "rust-tls")] -use tokio_rustls::ServerConfigExt; - use actix::msgs::StopArbiter; use actix::{Actor, Addr, Arbiter, AsyncContext, Context, Handler, Message, Response}; use super::accept::AcceptNotify; use super::channel::HttpChannel; use super::settings::{ServerSettings, WorkerSettings}; -use super::{HttpHandler, IntoHttpHandler, KeepAlive}; +use super::{ + AcceptorService, HttpHandler, IntoAsyncIo, IntoHttpHandler, IoStream, KeepAlive, +}; #[derive(Message)] pub(crate) struct Conn { pub io: T, - pub token: usize, + pub token: Token, pub peer: Option, pub http2: bool, } -#[derive(Clone)] -pub(crate) struct SocketInfo { - pub addr: net::SocketAddr, - pub htype: StreamHandlerType, +#[derive(Clone, Copy)] +pub struct Token(usize); + +impl Token { + pub(crate) fn new(val: usize) -> Token { + Token(val) + } } -pub(crate) struct WorkersPool { - sockets: Slab, +pub(crate) struct Socket { + pub lst: net::TcpListener, + pub addr: net::SocketAddr, + pub token: Token, +} + +pub(crate) struct WorkerFactory { pub factory: Arc Vec + Send + Sync>, pub host: Option, pub keep_alive: KeepAlive, + pub backlog: i32, + sockets: Vec, + handlers: Vec>>, } -impl WorkersPool { +impl WorkerFactory { pub fn new(factory: F) -> Self where F: Fn() -> Vec + Send + Sync + 'static, { - WorkersPool { + WorkerFactory { factory: Arc::new(factory), host: None, + backlog: 2048, keep_alive: KeepAlive::Os, - sockets: Slab::new(), + sockets: Vec::new(), + handlers: Vec::new(), } } - pub fn insert(&mut self, addr: net::SocketAddr, htype: StreamHandlerType) -> usize { - let entry = self.sockets.vacant_entry(); - let token = entry.key(); - entry.insert(SocketInfo { addr, htype }); - token + pub fn addrs(&self) -> Vec { + self.sockets.iter().map(|s| s.addr).collect() + } + + pub fn addrs_with_scheme(&self) -> Vec<(net::SocketAddr, &str)> { + self.handlers + .iter() + .map(|s| (s.addr(), s.scheme())) + .collect() + } + + pub fn take_sockets(&mut self) -> Vec { + mem::replace(&mut self.sockets, Vec::new()) + } + + pub fn listen(&mut self, lst: net::TcpListener) { + let token = Token(self.handlers.len()); + let addr = lst.local_addr().unwrap(); + self.handlers + .push(Box::new(SimpleHandler::new(lst.local_addr().unwrap()))); + self.sockets.push(Socket { lst, addr, token }) + } + + pub fn listen_with(&mut self, lst: net::TcpListener, acceptor: A) + where + A: AcceptorService + Send + 'static, + { + let token = Token(self.handlers.len()); + let addr = lst.local_addr().unwrap(); + self.handlers.push(Box::new(StreamHandler::new( + lst.local_addr().unwrap(), + acceptor, + ))); + self.sockets.push(Socket { lst, addr, token }) + } + + pub fn bind(&mut self, addr: S) -> io::Result<()> + where + S: net::ToSocketAddrs, + { + let sockets = self.bind2(addr)?; + + for lst in sockets { + let token = Token(self.handlers.len()); + let addr = lst.local_addr().unwrap(); + self.handlers + .push(Box::new(SimpleHandler::new(lst.local_addr().unwrap()))); + self.sockets.push(Socket { lst, addr, token }) + } + Ok(()) + } + + pub fn bind_with(&mut self, addr: S, acceptor: &A) -> io::Result<()> + where + S: net::ToSocketAddrs, + A: AcceptorService + Send + 'static, + { + let sockets = self.bind2(addr)?; + + for lst in sockets { + let token = Token(self.handlers.len()); + let addr = lst.local_addr().unwrap(); + self.handlers.push(Box::new(StreamHandler::new( + lst.local_addr().unwrap(), + acceptor.clone(), + ))); + self.sockets.push(Socket { lst, addr, token }) + } + Ok(()) + } + + fn bind2( + &self, addr: S, + ) -> io::Result> { + let mut err = None; + let mut succ = false; + let mut sockets = Vec::new(); + for addr in addr.to_socket_addrs()? { + match create_tcp_listener(addr, self.backlog) { + Ok(lst) => { + succ = true; + sockets.push(lst); + } + Err(e) => err = Some(e), + } + } + + if !succ { + if let Some(e) = err.take() { + Err(e) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Can not bind to address.", + )) + } + } else { + Ok(sockets) + } } pub fn start( &mut self, idx: usize, notify: AcceptNotify, - ) -> (WorkerClient, Addr>) { + ) -> (WorkerClient, Addr) { let host = self.host.clone(); - let addr = self.sockets[0].addr; + let addr = self.handlers[0].addr(); let factory = Arc::clone(&self.factory); - let socks = self.sockets.clone(); let ka = self.keep_alive; let (tx, rx) = unbounded::>(); - let client = WorkerClient::new(idx, tx, self.sockets.clone()); + let client = WorkerClient::new(idx, tx); let conn = client.conn.clone(); let sslrate = client.sslrate.clone(); + let handlers: Vec<_> = self.handlers.iter().map(|v| v.clone()).collect(); let addr = Arbiter::start(move |ctx: &mut Context<_>| { let s = ServerSettings::new(Some(addr), &host, false); let apps: Vec<_> = (*factory)().into_iter().map(|h| h.into_handler()).collect(); ctx.add_message_stream(rx); - Worker::new(apps, socks, ka, s, conn, sslrate, notify) + let inner = WorkerInner::new(apps, handlers, ka, s, conn, sslrate, notify); + Worker { + inner: Box::new(inner), + } }); (client, addr) @@ -107,19 +200,15 @@ impl WorkersPool { pub(crate) struct WorkerClient { pub idx: usize, tx: UnboundedSender>, - info: Slab, pub conn: Arc, pub sslrate: Arc, } impl WorkerClient { - fn new( - idx: usize, tx: UnboundedSender>, info: Slab, - ) -> Self { + fn new(idx: usize, tx: UnboundedSender>) -> Self { WorkerClient { idx, tx, - info, conn: Arc::new(AtomicUsize::new(0)), sslrate: Arc::new(AtomicUsize::new(0)), } @@ -154,47 +243,30 @@ impl Message for StopWorker { /// /// Worker accepts Socket objects via unbounded channel and start requests /// processing. -pub(crate) struct Worker -where - H: HttpHandler + 'static, -{ - settings: Rc>, - socks: Slab, - tcp_ka: Option, +pub(crate) struct Worker { + inner: Box, } -impl Worker { - pub(crate) fn new( - h: Vec, socks: Slab, keep_alive: KeepAlive, - settings: ServerSettings, conn: Arc, sslrate: Arc, - notify: AcceptNotify, - ) -> Worker { - let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive { - Some(time::Duration::new(val as u64, 0)) - } else { - None - }; +impl Actor for Worker { + type Context = Context; - Worker { - settings: Rc::new(WorkerSettings::new( - h, keep_alive, settings, notify, conn, sslrate, - )), - socks, - tcp_ka, - } + fn started(&mut self, ctx: &mut Self::Context) { + self.update_date(ctx); } +} - fn update_time(&self, ctx: &mut Context) { - self.settings.update_date(); - ctx.run_later(time::Duration::new(1, 0), |slf, ctx| slf.update_time(ctx)); +impl Worker { + fn update_date(&self, ctx: &mut Context) { + self.inner.update_date(); + ctx.run_later(time::Duration::new(1, 0), |slf, ctx| slf.update_date(ctx)); } fn shutdown_timeout( - &self, ctx: &mut Context, tx: oneshot::Sender, dur: time::Duration, + &self, ctx: &mut Context, tx: oneshot::Sender, dur: time::Duration, ) { // sleep for 1 second and then check again ctx.run_later(time::Duration::new(1, 0), move |slf, ctx| { - let num = slf.settings.num_channels(); + let num = slf.inner.num_channels(); if num == 0 { let _ = tx.send(true); Arbiter::current().do_send(StopArbiter(0)); @@ -202,7 +274,7 @@ impl Worker { slf.shutdown_timeout(ctx, tx, d); } else { info!("Force shutdown http worker, {} connections", num); - slf.settings.head().traverse::(); + slf.inner.force_shutdown(); let _ = tx.send(false); Arbiter::current().do_send(StopArbiter(0)); } @@ -210,44 +282,20 @@ impl Worker { } } -impl Actor for Worker -where - H: HttpHandler + 'static, -{ - type Context = Context; - - fn started(&mut self, ctx: &mut Self::Context) { - self.update_time(ctx); - } -} - -impl Handler> for Worker -where - H: HttpHandler + 'static, -{ +impl Handler> for Worker { type Result = (); fn handle(&mut self, msg: Conn, _: &mut Context) { - if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() { - error!("Can not set socket keep-alive option"); - } - self.socks - .get_mut(msg.token) - .unwrap() - .htype - .handle(Rc::clone(&self.settings), msg); + self.inner.handle_connect(msg) } } /// `StopWorker` message handler -impl Handler for Worker -where - H: HttpHandler + 'static, -{ +impl Handler for Worker { type Result = Response; fn handle(&mut self, msg: StopWorker, ctx: &mut Context) -> Self::Result { - let num = self.settings.num_channels(); + let num = self.inner.num_channels(); if num == 0 { info!("Shutting down http worker, 0 connections"); Response::reply(Ok(true)) @@ -258,148 +306,242 @@ where Response::async(rx.map_err(|_| ())) } else { info!("Force shutdown http worker, {} connections", num); - self.settings.head().traverse::(); + self.inner.force_shutdown(); Response::reply(Ok(false)) } } } -#[derive(Clone)] -pub(crate) enum StreamHandlerType { - Normal, - #[cfg(feature = "tls")] - Tls(TlsAcceptor), - #[cfg(feature = "alpn")] - Alpn(SslAcceptor), - #[cfg(feature = "rust-tls")] - Rustls(Arc), +trait WorkerHandler { + fn update_date(&self); + + fn handle_connect(&mut self, Conn); + + fn force_shutdown(&self); + + fn num_channels(&self) -> usize; } -impl StreamHandlerType { - pub fn is_ssl(&self) -> bool { - match *self { - StreamHandlerType::Normal => false, - #[cfg(feature = "tls")] - StreamHandlerType::Tls(_) => true, - #[cfg(feature = "alpn")] - StreamHandlerType::Alpn(_) => true, - #[cfg(feature = "rust-tls")] - StreamHandlerType::Rustls(_) => true, - } - } +struct WorkerInner +where + H: HttpHandler + 'static, +{ + settings: Rc>, + socks: Vec>>, + tcp_ka: Option, +} - fn handle( - &mut self, h: Rc>, msg: Conn, - ) { - match *self { - StreamHandlerType::Normal => { - let _ = msg.io.set_nodelay(true); - let io = TcpStream::from_std(msg.io, &Handle::default()) - .expect("failed to associate TCP stream"); +impl WorkerInner { + pub(crate) fn new( + h: Vec, socks: Vec>>, + keep_alive: KeepAlive, settings: ServerSettings, conn: Arc, + sslrate: Arc, notify: AcceptNotify, + ) -> WorkerInner { + let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive { + Some(time::Duration::new(val as u64, 0)) + } else { + None + }; - current_thread::spawn(HttpChannel::new(h, io, msg.peer, msg.http2)); - } - #[cfg(feature = "tls")] - StreamHandlerType::Tls(ref acceptor) => { - let Conn { - io, peer, http2, .. - } = msg; - let _ = io.set_nodelay(true); - let io = TcpStream::from_std(io, &Handle::default()) - .expect("failed to associate TCP stream"); - h.ssl_conn_add(); - - current_thread::spawn(TlsAcceptorExt::accept_async(acceptor, io).then( - move |res| { - h.ssl_conn_del(); - match res { - Ok(io) => current_thread::spawn(HttpChannel::new( - h, io, peer, http2, - )), - Err(err) => { - trace!("Error during handling tls connection: {}", err) - } - }; - future::result(Ok(())) - }, - )); - } - #[cfg(feature = "alpn")] - StreamHandlerType::Alpn(ref acceptor) => { - let Conn { io, peer, .. } = msg; - let _ = io.set_nodelay(true); - let io = TcpStream::from_std(io, &Handle::default()) - .expect("failed to associate TCP stream"); - h.ssl_conn_add(); - - current_thread::spawn(SslAcceptorExt::accept_async(acceptor, io).then( - move |res| { - h.ssl_conn_del(); - match res { - Ok(io) => { - let http2 = if let Some(p) = - io.get_ref().ssl().selected_alpn_protocol() - { - p.len() == 2 && &p == b"h2" - } else { - false - }; - current_thread::spawn(HttpChannel::new( - h, io, peer, http2, - )); - } - Err(err) => { - trace!("Error during handling tls connection: {}", err) - } - }; - future::result(Ok(())) - }, - )); - } - #[cfg(feature = "rust-tls")] - StreamHandlerType::Rustls(ref acceptor) => { - let Conn { io, peer, .. } = msg; - let _ = io.set_nodelay(true); - let io = TcpStream::from_std(io, &Handle::default()) - .expect("failed to associate TCP stream"); - h.ssl_conn_add(); - - current_thread::spawn(ServerConfigExt::accept_async(acceptor, io).then( - move |res| { - h.ssl_conn_del(); - match res { - Ok(io) => { - let http2 = if let Some(p) = - io.get_ref().1.get_alpn_protocol() - { - p.len() == 2 && &p == &"h2" - } else { - false - }; - current_thread::spawn(HttpChannel::new( - h, io, peer, http2, - )); - } - Err(err) => { - trace!("Error during handling tls connection: {}", err) - } - }; - future::result(Ok(())) - }, - )); - } - } - } - - pub(crate) fn scheme(&self) -> &'static str { - match *self { - StreamHandlerType::Normal => "http", - #[cfg(feature = "tls")] - StreamHandlerType::Tls(_) => "https", - #[cfg(feature = "alpn")] - StreamHandlerType::Alpn(_) => "https", - #[cfg(feature = "rust-tls")] - StreamHandlerType::Rustls(_) => "https", + WorkerInner { + settings: Rc::new(WorkerSettings::new( + h, keep_alive, settings, notify, conn, sslrate, + )), + socks, + tcp_ka, } } } + +impl WorkerHandler for WorkerInner +where + H: HttpHandler + 'static, +{ + fn update_date(&self) { + self.settings.update_date(); + } + + fn handle_connect(&mut self, msg: Conn) { + if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() { + error!("Can not set socket keep-alive option"); + } + self.socks[msg.token.0].handle(Rc::clone(&self.settings), msg.io, msg.peer); + } + + fn num_channels(&self) -> usize { + self.settings.num_channels() + } + + fn force_shutdown(&self) { + self.settings.head().traverse::(); + } +} + +struct SimpleHandler { + addr: net::SocketAddr, + io: PhantomData, +} + +impl Clone for SimpleHandler { + fn clone(&self) -> Self { + SimpleHandler { + addr: self.addr, + io: PhantomData, + } + } +} + +impl SimpleHandler { + fn new(addr: net::SocketAddr) -> Self { + SimpleHandler { + addr, + io: PhantomData, + } + } +} + +impl IoStreamHandler for SimpleHandler +where + H: HttpHandler, + Io: IntoAsyncIo + Send + 'static, + Io::Io: IoStream, +{ + fn addr(&self) -> net::SocketAddr { + self.addr + } + + fn clone(&self) -> Box> { + Box::new(Clone::clone(self)) + } + + fn scheme(&self) -> &'static str { + "http" + } + + fn handle(&self, h: Rc>, io: Io, peer: Option) { + let mut io = match io.into_async_io() { + Ok(io) => io, + Err(err) => { + trace!("Failed to create async io: {}", err); + return; + } + }; + let _ = io.set_nodelay(true); + + current_thread::spawn(HttpChannel::new(h, io, peer, false)); + } +} + +struct StreamHandler { + acceptor: A, + addr: net::SocketAddr, + io: PhantomData, +} + +impl> StreamHandler { + fn new(addr: net::SocketAddr, acceptor: A) -> Self { + StreamHandler { + addr, + acceptor, + io: PhantomData, + } + } +} + +impl> Clone for StreamHandler { + fn clone(&self) -> Self { + StreamHandler { + addr: self.addr, + acceptor: self.acceptor.clone(), + io: PhantomData, + } + } +} + +impl IoStreamHandler for StreamHandler +where + H: HttpHandler, + Io: IntoAsyncIo + Send + 'static, + Io::Io: IoStream, + A: AcceptorService + Send + 'static, +{ + fn addr(&self) -> net::SocketAddr { + self.addr + } + + fn clone(&self) -> Box> { + Box::new(Clone::clone(self)) + } + + fn scheme(&self) -> &'static str { + self.acceptor.scheme() + } + + fn handle(&self, h: Rc>, io: Io, peer: Option) { + let mut io = match io.into_async_io() { + Ok(io) => io, + Err(err) => { + trace!("Failed to create async io: {}", err); + return; + } + }; + let _ = io.set_nodelay(true); + + h.conn_rate_add(); + current_thread::spawn(self.acceptor.accept(io).then(move |res| { + h.conn_rate_del(); + match res { + Ok(io) => current_thread::spawn(HttpChannel::new(h, io, peer, false)), + Err(err) => trace!("Can not establish connection: {}", err), + } + Ok(()) + })) + } +} + +impl IoStreamHandler for Box> +where + H: HttpHandler, + Io: IntoAsyncIo, +{ + fn addr(&self) -> net::SocketAddr { + self.as_ref().addr() + } + + fn clone(&self) -> Box> { + self.as_ref().clone() + } + + fn scheme(&self) -> &'static str { + self.as_ref().scheme() + } + + fn handle(&self, h: Rc>, io: Io, peer: Option) { + self.as_ref().handle(h, io, peer) + } +} + +pub(crate) trait IoStreamHandler: Send +where + H: HttpHandler, +{ + fn clone(&self) -> Box>; + + fn addr(&self) -> net::SocketAddr; + + fn scheme(&self) -> &'static str; + + fn handle(&self, h: Rc>, io: Io, peer: Option); +} + +fn create_tcp_listener( + addr: net::SocketAddr, backlog: i32, +) -> io::Result { + let builder = match addr { + net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, + net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, + }; + builder.reuse_address(true)?; + builder.bind(addr)?; + Ok(builder.listen(backlog)?) +}