1
0
mirror of https://github.com/fafhrd91/actix-net synced 2024-11-27 20:12:58 +01:00

add io parameters

This commit is contained in:
Nikolay Kim 2019-03-11 12:01:55 -07:00
parent f696914038
commit 787255d030
9 changed files with 267 additions and 87 deletions

View File

@ -31,3 +31,67 @@ impl ServerConfig {
self.secure.as_ref().set(true) self.secure.as_ref().set(true)
} }
} }
#[derive(Copy, Clone, Debug)]
pub enum Protocol {
Unknown,
Http10,
Http11,
Http2,
Proto1,
Proto2,
Proto3,
Proto4,
Proto5,
Proto6,
}
pub struct Io<T, P = ()> {
io: T,
proto: Protocol,
params: P,
}
impl<T> Io<T, ()> {
pub fn new(io: T) -> Self {
Self {
io,
proto: Protocol::Unknown,
params: (),
}
}
}
impl<T, P> Io<T, P> {
pub fn from_parts(io: T, params: P, proto: Protocol) -> Self {
Self { io, params, proto }
}
pub fn into_parts(self) -> (T, P, Protocol) {
(self.io, self.params, self.proto)
}
pub fn io(&self) -> &T {
&self.io
}
pub fn io_mut(&mut self) -> &mut T {
&mut self.io
}
pub fn protocol(&self) -> Protocol {
self.proto
}
/// Maps an Io<_, P> to Io<_, U> by applying a function to a contained value.
pub fn map<U, F>(self, op: F) -> Io<T, U>
where
F: FnOnce(P) -> U,
{
Io {
io: self.io,
proto: self.proto,
params: op(self.params),
}
}
}

View File

@ -23,6 +23,7 @@ use crate::{ssl, Token};
pub struct ServerBuilder { pub struct ServerBuilder {
threads: usize, threads: usize,
token: Token, token: Token,
backlog: i32,
workers: Vec<(usize, WorkerClient)>, workers: Vec<(usize, WorkerClient)>,
services: Vec<Box<InternalServiceFactory>>, services: Vec<Box<InternalServiceFactory>>,
sockets: Vec<(Token, net::TcpListener)>, sockets: Vec<(Token, net::TcpListener)>,
@ -53,6 +54,7 @@ impl ServerBuilder {
services: Vec::new(), services: Vec::new(),
sockets: Vec::new(), sockets: Vec::new(),
accept: AcceptLoop::new(server.clone()), accept: AcceptLoop::new(server.clone()),
backlog: 2048,
exit: false, exit: false,
shutdown_timeout: Duration::from_secs(30), shutdown_timeout: Duration::from_secs(30),
no_signals: false, no_signals: false,
@ -70,6 +72,21 @@ impl ServerBuilder {
self self
} }
/// Set the maximum number of pending connections.
///
/// This refers to the number of clients that can be waiting to be served.
/// Exceeding this number results in the client getting an error when
/// attempting to connect. It should only affect servers under significant
/// load.
///
/// Generally set in the 64-2048 range. Default value is 2048.
///
/// This method should be called before `bind()` method call.
pub fn backlog(mut self, num: i32) -> Self {
self.backlog = num;
self
}
/// Sets the maximum per-worker number of concurrent connections. /// Sets the maximum per-worker number of concurrent connections.
/// ///
/// All socket listeners will stop accepting connections when this limit is /// All socket listeners will stop accepting connections when this limit is
@ -125,7 +142,7 @@ impl ServerBuilder {
where where
F: Fn(&mut ServiceConfig) -> io::Result<()>, F: Fn(&mut ServiceConfig) -> io::Result<()>,
{ {
let mut cfg = ServiceConfig::new(self.threads); let mut cfg = ServiceConfig::new(self.threads, self.backlog);
f(&mut cfg)?; f(&mut cfg)?;
@ -149,7 +166,7 @@ impl ServerBuilder {
F: ServiceFactory, F: ServiceFactory,
U: net::ToSocketAddrs, U: net::ToSocketAddrs,
{ {
let sockets = bind_addr(addr)?; let sockets = bind_addr(addr, self.backlog)?;
for lst in sockets { for lst in sockets {
let token = self.token.next(); let token = self.token.next();
@ -393,12 +410,15 @@ impl Future for ServerBuilder {
} }
} }
pub(super) fn bind_addr<S: net::ToSocketAddrs>(addr: S) -> io::Result<Vec<net::TcpListener>> { pub(super) fn bind_addr<S: net::ToSocketAddrs>(
addr: S,
backlog: i32,
) -> io::Result<Vec<net::TcpListener>> {
let mut err = None; let mut err = None;
let mut succ = false; let mut succ = false;
let mut sockets = Vec::new(); let mut sockets = Vec::new();
for addr in addr.to_socket_addrs()? { for addr in addr.to_socket_addrs()? {
match create_tcp_listener(addr) { match create_tcp_listener(addr, backlog) {
Ok(lst) => { Ok(lst) => {
succ = true; succ = true;
sockets.push(lst); sockets.push(lst);
@ -421,12 +441,12 @@ pub(super) fn bind_addr<S: net::ToSocketAddrs>(addr: S) -> io::Result<Vec<net::T
} }
} }
fn create_tcp_listener(addr: net::SocketAddr) -> io::Result<net::TcpListener> { fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result<net::TcpListener> {
let builder = match addr { let builder = match addr {
net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, net::SocketAddr::V4(_) => TcpBuilder::new_v4()?,
net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, net::SocketAddr::V6(_) => TcpBuilder::new_v6()?,
}; };
builder.reuse_address(true)?; builder.reuse_address(true)?;
builder.bind(addr)?; builder.bind(addr)?;
Ok(builder.listen(1024)?) Ok(builder.listen(backlog)?)
} }

View File

@ -10,7 +10,7 @@ mod signals;
pub mod ssl; pub mod ssl;
mod worker; mod worker;
pub use actix_server_config::ServerConfig; pub use actix_server_config::{Io, Protocol, ServerConfig};
pub use self::builder::ServerBuilder; pub use self::builder::ServerBuilder;
pub use self::server::Server; pub use self::server::Server;

View File

@ -1,6 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::{fmt, io, net}; use std::{fmt, io, net};
use actix_server_config::Io;
use actix_service::{IntoNewService, NewService}; use actix_service::{IntoNewService, NewService};
use futures::future::{join_all, Future}; use futures::future::{join_all, Future};
use log::error; use log::error;
@ -18,12 +19,14 @@ pub struct ServiceConfig {
pub(crate) services: Vec<(String, net::TcpListener)>, pub(crate) services: Vec<(String, net::TcpListener)>,
pub(crate) apply: Option<Box<ServiceRuntimeConfiguration>>, pub(crate) apply: Option<Box<ServiceRuntimeConfiguration>>,
pub(crate) threads: usize, pub(crate) threads: usize,
pub(crate) backlog: i32,
} }
impl ServiceConfig { impl ServiceConfig {
pub(super) fn new(threads: usize) -> ServiceConfig { pub(super) fn new(threads: usize, backlog: i32) -> ServiceConfig {
ServiceConfig { ServiceConfig {
threads, threads,
backlog,
services: Vec::new(), services: Vec::new(),
apply: None, apply: None,
} }
@ -42,7 +45,7 @@ impl ServiceConfig {
where where
U: net::ToSocketAddrs, U: net::ToSocketAddrs,
{ {
let sockets = bind_addr(addr)?; let sockets = bind_addr(addr, self.backlog)?;
for lst in sockets { for lst in sockets {
self.listen(name.as_ref(), lst); self.listen(name.as_ref(), lst);
@ -170,7 +173,7 @@ impl ServiceRuntime {
pub fn service<T, F>(&mut self, name: &str, service: F) pub fn service<T, F>(&mut self, name: &str, service: F)
where where
F: IntoNewService<T>, F: IntoNewService<T>,
T: NewService<Request = TcpStream, Response = ()> + 'static, T: NewService<Request = Io<TcpStream>, Response = ()> + 'static,
T::Future: 'static, T::Future: 'static,
T::Service: 'static, T::Service: 'static,
T::InitError: fmt::Debug, T::InitError: fmt::Debug,
@ -206,7 +209,7 @@ struct ServiceFactory<T> {
impl<T> NewService for ServiceFactory<T> impl<T> NewService for ServiceFactory<T>
where where
T: NewService<Request = TcpStream, Response = ()>, T: NewService<Request = Io<TcpStream>, Response = ()>,
T::Future: 'static, T::Future: 'static,
T::Service: 'static, T::Service: 'static,
T::Error: 'static, T::Error: 'static,

View File

@ -1,14 +1,14 @@
use std::net::{SocketAddr, TcpStream}; use std::net::{self, SocketAddr};
use std::time::Duration; use std::time::Duration;
use actix_rt::spawn; use actix_rt::spawn;
use actix_server_config::ServerConfig; use actix_server_config::{Io, ServerConfig};
use actix_service::{NewService, Service}; use actix_service::{NewService, Service};
use futures::future::{err, ok, FutureResult}; use futures::future::{err, ok, FutureResult};
use futures::{Future, Poll}; use futures::{Future, Poll};
use log::error; use log::error;
use tokio_reactor::Handle; use tokio_reactor::Handle;
use tokio_tcp::TcpStream as TokioTcpStream; use tokio_tcp::TcpStream;
use super::Token; use super::Token;
use crate::counter::CounterGuard; use crate::counter::CounterGuard;
@ -16,7 +16,7 @@ use crate::counter::CounterGuard;
/// Server message /// Server message
pub(crate) enum ServerMessage { pub(crate) enum ServerMessage {
/// New stream /// New stream
Connect(TcpStream), Connect(net::TcpStream),
/// Gracefull shutdown /// Gracefull shutdown
Shutdown(Duration), Shutdown(Duration),
/// Force shutdown /// Force shutdown
@ -24,7 +24,7 @@ pub(crate) enum ServerMessage {
} }
pub trait ServiceFactory: Send + Clone + 'static { pub trait ServiceFactory: Send + Clone + 'static {
type NewService: NewService<ServerConfig, Request = TokioTcpStream>; type NewService: NewService<ServerConfig, Request = Io<TcpStream>>;
fn create(&self) -> Self::NewService; fn create(&self) -> Self::NewService;
} }
@ -58,7 +58,7 @@ impl<T> StreamService<T> {
impl<T> Service for StreamService<T> impl<T> Service for StreamService<T>
where where
T: Service<Request = TokioTcpStream>, T: Service<Request = Io<TcpStream>>,
T::Future: 'static, T::Future: 'static,
T::Error: 'static, T::Error: 'static,
{ {
@ -74,13 +74,12 @@ where
fn call(&mut self, (guard, req): (Option<CounterGuard>, ServerMessage)) -> Self::Future { fn call(&mut self, (guard, req): (Option<CounterGuard>, ServerMessage)) -> Self::Future {
match req { match req {
ServerMessage::Connect(stream) => { ServerMessage::Connect(stream) => {
let stream = let stream = TcpStream::from_std(stream, &Handle::default()).map_err(|e| {
TokioTcpStream::from_std(stream, &Handle::default()).map_err(|e| { error!("Can not convert to an async tcp stream: {}", e);
error!("Can not convert to an async tcp stream: {}", e); });
});
if let Ok(stream) = stream { if let Ok(stream) = stream {
spawn(self.service.call(stream).then(move |res| { spawn(self.service.call(Io::new(stream)).then(move |res| {
drop(guard); drop(guard);
res.map_err(|_| ()).map(|_| ()) res.map_err(|_| ()).map(|_| ())
})); }));
@ -170,7 +169,7 @@ impl InternalServiceFactory for Box<InternalServiceFactory> {
impl<F, T> ServiceFactory for F impl<F, T> ServiceFactory for F
where where
F: Fn() -> T + Send + Clone + 'static, F: Fn() -> T + Send + Clone + 'static,
T: NewService<ServerConfig, Request = TokioTcpStream>, T: NewService<ServerConfig, Request = Io<TcpStream>>,
{ {
type NewService = T; type NewService = T;

View File

@ -1,7 +1,6 @@
use std::io; use std::io;
use std::marker::PhantomData; use std::marker::PhantomData;
use actix_server_config::ServerConfig;
use actix_service::{NewService, Service}; use actix_service::{NewService, Service};
use futures::{future::ok, future::FutureResult, Async, Future, Poll}; use futures::{future::ok, future::FutureResult, Async, Future, Poll};
use native_tls::{self, Error, HandshakeError, TlsAcceptor}; use native_tls::{self, Error, HandshakeError, TlsAcceptor};
@ -9,16 +8,17 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::counter::{Counter, CounterGuard}; use crate::counter::{Counter, CounterGuard};
use crate::ssl::MAX_CONN_COUNTER; use crate::ssl::MAX_CONN_COUNTER;
use crate::{Io, Protocol, ServerConfig};
/// Support `SSL` connections via native-tls package /// Support `SSL` connections via native-tls package
/// ///
/// `tls` feature enables `NativeTlsAcceptor` type /// `tls` feature enables `NativeTlsAcceptor` type
pub struct NativeTlsAcceptor<T> { pub struct NativeTlsAcceptor<T, P> {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
io: PhantomData<T>, io: PhantomData<(T, P)>,
} }
impl<T: AsyncRead + AsyncWrite> NativeTlsAcceptor<T> { impl<T: AsyncRead + AsyncWrite, P> NativeTlsAcceptor<T, P> {
/// Create `NativeTlsAcceptor` instance /// Create `NativeTlsAcceptor` instance
pub fn new(acceptor: TlsAcceptor) -> Self { pub fn new(acceptor: TlsAcceptor) -> Self {
NativeTlsAcceptor { NativeTlsAcceptor {
@ -28,7 +28,7 @@ impl<T: AsyncRead + AsyncWrite> NativeTlsAcceptor<T> {
} }
} }
impl<T: AsyncRead + AsyncWrite> Clone for NativeTlsAcceptor<T> { impl<T: AsyncRead + AsyncWrite, P> Clone for NativeTlsAcceptor<T, P> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
@ -37,11 +37,11 @@ impl<T: AsyncRead + AsyncWrite> Clone for NativeTlsAcceptor<T> {
} }
} }
impl<T: AsyncRead + AsyncWrite> NewService<ServerConfig> for NativeTlsAcceptor<T> { impl<T: AsyncRead + AsyncWrite, P> NewService<ServerConfig> for NativeTlsAcceptor<T, P> {
type Request = T; type Request = Io<T, P>;
type Response = TlsStream<T>; type Response = Io<TlsStream<T>, P>;
type Error = Error; type Error = Error;
type Service = NativeTlsAcceptorService<T>; type Service = NativeTlsAcceptorService<T, P>;
type InitError = (); type InitError = ();
type Future = FutureResult<Self::Service, Self::InitError>; type Future = FutureResult<Self::Service, Self::InitError>;
@ -58,17 +58,17 @@ impl<T: AsyncRead + AsyncWrite> NewService<ServerConfig> for NativeTlsAcceptor<T
} }
} }
pub struct NativeTlsAcceptorService<T> { pub struct NativeTlsAcceptorService<T, P> {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
io: PhantomData<T>, io: PhantomData<(T, P)>,
conns: Counter, conns: Counter,
} }
impl<T: AsyncRead + AsyncWrite> Service for NativeTlsAcceptorService<T> { impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
type Request = T; type Request = Io<T, P>;
type Response = TlsStream<T>; type Response = Io<TlsStream<T>, P>;
type Error = Error; type Error = Error;
type Future = Accept<T>; type Future = Accept<T, P>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
if self.conns.available() { if self.conns.available() {
@ -78,10 +78,12 @@ impl<T: AsyncRead + AsyncWrite> Service for NativeTlsAcceptorService<T> {
} }
} }
fn call(&mut self, req: T) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let (io, params, _) = req.into_parts();
Accept { Accept {
_guard: self.conns.get(), _guard: self.conns.get(),
inner: Some(self.acceptor.accept(req)), inner: Some(self.acceptor.accept(io)),
params: Some(params),
} }
} }
} }
@ -100,21 +102,30 @@ pub struct TlsStream<S> {
/// Future returned from `NativeTlsAcceptor::accept` which will resolve /// Future returned from `NativeTlsAcceptor::accept` which will resolve
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct Accept<S> { pub struct Accept<S, P> {
inner: Option<Result<native_tls::TlsStream<S>, HandshakeError<S>>>, inner: Option<Result<native_tls::TlsStream<S>, HandshakeError<S>>>,
params: Option<P>,
_guard: CounterGuard, _guard: CounterGuard,
} }
impl<Io: AsyncRead + AsyncWrite> Future for Accept<Io> { impl<T: AsyncRead + AsyncWrite, P> Future for Accept<T, P> {
type Item = TlsStream<Io>; type Item = Io<TlsStream<T>, P>;
type Error = Error; type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.inner.take().expect("cannot poll MidHandshake twice") { match self.inner.take().expect("cannot poll MidHandshake twice") {
Ok(stream) => Ok(TlsStream { inner: stream }.into()), Ok(stream) => Ok(Async::Ready(Io::from_parts(
TlsStream { inner: stream },
self.params.take().unwrap(),
Protocol::Unknown,
))),
Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::WouldBlock(s)) => match s.handshake() { Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
Ok(stream) => Ok(TlsStream { inner: stream }.into()), Ok(stream) => Ok(Async::Ready(Io::from_parts(
TlsStream { inner: stream },
self.params.take().unwrap(),
Protocol::Unknown,
))),
Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::WouldBlock(s)) => { Err(HandshakeError::WouldBlock(s)) => {
self.inner = Some(Err(HandshakeError::WouldBlock(s))); self.inner = Some(Err(HandshakeError::WouldBlock(s)));

View File

@ -8,17 +8,17 @@ use tokio_openssl::{AcceptAsync, SslAcceptorExt, SslStream};
use crate::counter::{Counter, CounterGuard}; use crate::counter::{Counter, CounterGuard};
use crate::ssl::MAX_CONN_COUNTER; use crate::ssl::MAX_CONN_COUNTER;
use crate::ServerConfig; use crate::{Io, Protocol, ServerConfig};
/// Support `SSL` connections via openssl package /// Support `SSL` connections via openssl package
/// ///
/// `ssl` feature enables `OpensslAcceptor` type /// `ssl` feature enables `OpensslAcceptor` type
pub struct OpensslAcceptor<T> { pub struct OpensslAcceptor<T, P> {
acceptor: SslAcceptor, acceptor: SslAcceptor,
io: PhantomData<T>, io: PhantomData<(T, P)>,
} }
impl<T> OpensslAcceptor<T> { impl<T, P> OpensslAcceptor<T, P> {
/// Create default `OpensslAcceptor` /// Create default `OpensslAcceptor`
pub fn new(acceptor: SslAcceptor) -> Self { pub fn new(acceptor: SslAcceptor) -> Self {
OpensslAcceptor { OpensslAcceptor {
@ -28,7 +28,7 @@ impl<T> OpensslAcceptor<T> {
} }
} }
impl<T: AsyncRead + AsyncWrite> Clone for OpensslAcceptor<T> { impl<T: AsyncRead + AsyncWrite, P> Clone for OpensslAcceptor<T, P> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
@ -37,11 +37,11 @@ impl<T: AsyncRead + AsyncWrite> Clone for OpensslAcceptor<T> {
} }
} }
impl<T: AsyncRead + AsyncWrite> NewService<ServerConfig> for OpensslAcceptor<T> { impl<T: AsyncRead + AsyncWrite, P> NewService<ServerConfig> for OpensslAcceptor<T, P> {
type Request = T; type Request = Io<T, P>;
type Response = SslStream<T>; type Response = Io<SslStream<T>, P>;
type Error = HandshakeError<T>; type Error = HandshakeError<T>;
type Service = OpensslAcceptorService<T>; type Service = OpensslAcceptorService<T, P>;
type InitError = (); type InitError = ();
type Future = FutureResult<Self::Service, Self::InitError>; type Future = FutureResult<Self::Service, Self::InitError>;
@ -58,17 +58,17 @@ impl<T: AsyncRead + AsyncWrite> NewService<ServerConfig> for OpensslAcceptor<T>
} }
} }
pub struct OpensslAcceptorService<T> { pub struct OpensslAcceptorService<T, P> {
acceptor: SslAcceptor, acceptor: SslAcceptor,
io: PhantomData<T>,
conns: Counter, conns: Counter,
io: PhantomData<(T, P)>,
} }
impl<T: AsyncRead + AsyncWrite> Service for OpensslAcceptorService<T> { impl<T: AsyncRead + AsyncWrite, P> Service for OpensslAcceptorService<T, P> {
type Request = T; type Request = Io<T, P>;
type Response = SslStream<T>; type Response = Io<SslStream<T>, P>;
type Error = HandshakeError<T>; type Error = HandshakeError<T>;
type Future = OpensslAcceptorServiceFut<T>; type Future = OpensslAcceptorServiceFut<T, P>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
if self.conns.available() { if self.conns.available() {
@ -78,27 +78,52 @@ impl<T: AsyncRead + AsyncWrite> Service for OpensslAcceptorService<T> {
} }
} }
fn call(&mut self, req: T) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let (io, params, _) = req.into_parts();
OpensslAcceptorServiceFut { OpensslAcceptorServiceFut {
_guard: self.conns.get(), _guard: self.conns.get(),
fut: SslAcceptorExt::accept_async(&self.acceptor, req), fut: SslAcceptorExt::accept_async(&self.acceptor, io),
params: Some(params),
} }
} }
} }
pub struct OpensslAcceptorServiceFut<T> pub struct OpensslAcceptorServiceFut<T, P>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
{ {
fut: AcceptAsync<T>, fut: AcceptAsync<T>,
params: Option<P>,
_guard: CounterGuard, _guard: CounterGuard,
} }
impl<T: AsyncRead + AsyncWrite> Future for OpensslAcceptorServiceFut<T> { impl<T: AsyncRead + AsyncWrite, P> Future for OpensslAcceptorServiceFut<T, P> {
type Item = SslStream<T>; type Item = Io<SslStream<T>, P>;
type Error = HandshakeError<T>; type Error = HandshakeError<T>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.fut.poll() let io = futures::try_ready!(self.fut.poll());
let proto = if let Some(protos) = io.get_ref().ssl().selected_alpn_protocol() {
const H2: &[u8] = b"\x02h2";
const HTTP10: &[u8] = b"\x08http/1.0";
const HTTP11: &[u8] = b"\x08http/1.1";
if protos.windows(3).any(|window| window == H2) {
Protocol::Http2
} else if protos.windows(9).any(|window| window == HTTP11) {
Protocol::Http11
} else if protos.windows(9).any(|window| window == HTTP10) {
Protocol::Http10
} else {
Protocol::Unknown
}
} else {
Protocol::Unknown
};
Ok(Async::Ready(Io::from_parts(
io,
self.params.take().unwrap(),
proto,
)))
} }
} }

View File

@ -10,17 +10,17 @@ use tokio_rustls::{Accept, TlsAcceptor, TlsStream};
use crate::counter::{Counter, CounterGuard}; use crate::counter::{Counter, CounterGuard};
use crate::ssl::MAX_CONN_COUNTER; use crate::ssl::MAX_CONN_COUNTER;
use crate::ServerConfig as SrvConfig; use crate::{Io, Protocol, ServerConfig as SrvConfig};
/// Support `SSL` connections via rustls package /// Support `SSL` connections via rustls package
/// ///
/// `rust-tls` feature enables `RustlsAcceptor` type /// `rust-tls` feature enables `RustlsAcceptor` type
pub struct RustlsAcceptor<T> { pub struct RustlsAcceptor<T, P> {
config: Arc<ServerConfig>, config: Arc<ServerConfig>,
io: PhantomData<T>, io: PhantomData<(T, P)>,
} }
impl<T: AsyncRead + AsyncWrite> RustlsAcceptor<T> { impl<T: AsyncRead + AsyncWrite, P> RustlsAcceptor<T, P> {
/// Create `RustlsAcceptor` new service /// Create `RustlsAcceptor` new service
pub fn new(config: ServerConfig) -> Self { pub fn new(config: ServerConfig) -> Self {
RustlsAcceptor { RustlsAcceptor {
@ -30,7 +30,7 @@ impl<T: AsyncRead + AsyncWrite> RustlsAcceptor<T> {
} }
} }
impl<T> Clone for RustlsAcceptor<T> { impl<T, P> Clone for RustlsAcceptor<T, P> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
config: self.config.clone(), config: self.config.clone(),
@ -39,11 +39,11 @@ impl<T> Clone for RustlsAcceptor<T> {
} }
} }
impl<T: AsyncRead + AsyncWrite> NewService<SrvConfig> for RustlsAcceptor<T> { impl<T: AsyncRead + AsyncWrite, P> NewService<SrvConfig> for RustlsAcceptor<T, P> {
type Request = T; type Request = Io<T, P>;
type Response = TlsStream<T, ServerSession>; type Response = Io<TlsStream<T, ServerSession>, P>;
type Error = io::Error; type Error = io::Error;
type Service = RustlsAcceptorService<T>; type Service = RustlsAcceptorService<T, P>;
type InitError = (); type InitError = ();
type Future = FutureResult<Self::Service, Self::InitError>; type Future = FutureResult<Self::Service, Self::InitError>;
@ -60,17 +60,17 @@ impl<T: AsyncRead + AsyncWrite> NewService<SrvConfig> for RustlsAcceptor<T> {
} }
} }
pub struct RustlsAcceptorService<T> { pub struct RustlsAcceptorService<T, P> {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
io: PhantomData<T>, io: PhantomData<(T, P)>,
conns: Counter, conns: Counter,
} }
impl<T: AsyncRead + AsyncWrite> Service for RustlsAcceptorService<T> { impl<T: AsyncRead + AsyncWrite, P> Service for RustlsAcceptorService<T, P> {
type Request = T; type Request = Io<T, P>;
type Response = TlsStream<T, ServerSession>; type Response = Io<TlsStream<T, ServerSession>, P>;
type Error = io::Error; type Error = io::Error;
type Future = RustlsAcceptorServiceFut<T>; type Future = RustlsAcceptorServiceFut<T, P>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
if self.conns.available() { if self.conns.available() {
@ -80,27 +80,35 @@ impl<T: AsyncRead + AsyncWrite> Service for RustlsAcceptorService<T> {
} }
} }
fn call(&mut self, req: T) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let (io, params, _) = req.into_parts();
RustlsAcceptorServiceFut { RustlsAcceptorServiceFut {
_guard: self.conns.get(), _guard: self.conns.get(),
fut: self.acceptor.accept(req), fut: self.acceptor.accept(io),
params: Some(params),
} }
} }
} }
pub struct RustlsAcceptorServiceFut<T> pub struct RustlsAcceptorServiceFut<T, P>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
{ {
fut: Accept<T>, fut: Accept<T>,
params: Option<P>,
_guard: CounterGuard, _guard: CounterGuard,
} }
impl<T: AsyncRead + AsyncWrite> Future for RustlsAcceptorServiceFut<T> { impl<T: AsyncRead + AsyncWrite, P> Future for RustlsAcceptorServiceFut<T, P> {
type Item = TlsStream<T, ServerSession>; type Item = Io<TlsStream<T, ServerSession>, P>;
type Error = io::Error; type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.fut.poll() let io = futures::try_ready!(self.fut.poll());
Ok(Async::Ready(Io::from_parts(
io,
self.params.take().unwrap(),
Protocol::Unknown,
)))
} }
} }

View File

@ -1,3 +1,4 @@
use std::sync::mpsc;
use std::{net, thread, time}; use std::{net, thread, time};
use actix_server::{Server, ServerConfig}; use actix_server::{Server, ServerConfig};
@ -68,3 +69,52 @@ fn test_listen() {
thread::sleep(time::Duration::from_millis(500)); thread::sleep(time::Duration::from_millis(500));
assert!(net::TcpStream::connect(addr).is_ok()); assert!(net::TcpStream::connect(addr).is_ok());
} }
#[test]
#[cfg(unix)]
fn test_start() {
let addr = unused_addr();
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let sys = actix_rt::System::new("test");
let srv = Server::build()
.backlog(1)
.bind("test", addr, move || {
fn_cfg_factory(move |cfg: &ServerConfig| {
assert_eq!(cfg.local_addr(), addr);
Ok::<_, ()>((|_| Ok::<_, ()>(())).into_service())
})
})
.unwrap()
.start();
let _ = tx.send((srv, actix_rt::System::current()));
let _ = sys.run();
});
let (srv, sys) = rx.recv().unwrap();
thread::sleep(time::Duration::from_millis(400));
assert!(net::TcpStream::connect(addr).is_ok());
// pause
let _ = srv.pause();
thread::sleep(time::Duration::from_millis(100));
assert!(net::TcpStream::connect_timeout(&addr, time::Duration::from_millis(100)).is_ok());
assert!(net::TcpStream::connect_timeout(&addr, time::Duration::from_millis(100)).is_err());
// resume
let _ = srv.resume();
thread::sleep(time::Duration::from_millis(100));
assert!(net::TcpStream::connect(addr).is_ok());
assert!(net::TcpStream::connect(addr).is_ok());
assert!(net::TcpStream::connect(addr).is_ok());
// stop
let _ = srv.stop(false);
thread::sleep(time::Duration::from_millis(100));
assert!(net::TcpStream::connect(addr).is_err());
let _ = sys.stop();
}