diff --git a/src/client/connector.rs b/src/client/connector.rs index 88d6dfd6b..88be77f9b 100644 --- a/src/client/connector.rs +++ b/src/client/connector.rs @@ -1283,6 +1283,11 @@ impl IoStream for Connection { fn set_linger(&mut self, dur: Option) -> io::Result<()> { IoStream::set_linger(&mut *self.stream, dur) } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + IoStream::set_keepalive(&mut *self.stream, dur) + } } impl io::Read for Connection { diff --git a/src/server/channel.rs b/src/server/channel.rs index 3cea291fd..d8cad9707 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -13,7 +13,7 @@ use super::{h1, h2, HttpHandler, IoStream}; const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; enum HttpProtocol { - H1(h1::Http1), + H1(h1::Http1Dispatcher), H2(h2::Http2), Unknown(WorkerSettings, Option, T, BytesMut), } @@ -167,7 +167,7 @@ where if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() { match kind { ProtocolKind::Http1 => { - self.proto = Some(HttpProtocol::H1(h1::Http1::new( + self.proto = Some(HttpProtocol::H1(h1::Http1Dispatcher::new( settings, io, addr, @@ -311,6 +311,10 @@ where fn set_linger(&mut self, _: Option) -> io::Result<()> { Ok(()) } + #[inline] + fn set_keepalive(&mut self, _: Option) -> io::Result<()> { + Ok(()) + } } impl io::Read for WrapperStream diff --git a/src/server/h1.rs b/src/server/h1.rs index 433a916b0..6875972ee 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -24,15 +24,18 @@ const MAX_PIPELINED_MESSAGES: usize = 16; bitflags! { pub struct Flags: u8 { const STARTED = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_0010; const KEEPALIVE = 0b0000_0100; const SHUTDOWN = 0b0000_1000; const READ_DISCONNECTED = 0b0001_0000; const WRITE_DISCONNECTED = 0b0010_0000; const POLLED = 0b0100_0000; + } } -pub(crate) struct Http1 { +/// Dispatcher for HTTP/1.1 protocol +pub struct Http1Dispatcher { flags: Flags, settings: WorkerSettings, addr: Option, @@ -42,7 +45,6 @@ pub(crate) struct Http1 { buf: BytesMut, tasks: VecDeque>, error: Option, - ka_enabled: bool, ka_expire: Instant, ka_timer: Option, } @@ -79,7 +81,7 @@ impl Entry { } } -impl Http1 +impl Http1Dispatcher where T: IoStream, H: HttpHandler + 'static, @@ -88,7 +90,6 @@ where settings: WorkerSettings, stream: T, addr: Option, buf: BytesMut, is_eof: bool, keepalive_timer: Option, ) -> Self { - let ka_enabled = settings.keep_alive_enabled(); let (ka_expire, ka_timer) = if let Some(delay) = keepalive_timer { (delay.deadline(), Some(delay)) } else if let Some(delay) = settings.keep_alive_timer() { @@ -97,12 +98,16 @@ where (settings.now(), None) }; - Http1 { - flags: if is_eof { - Flags::READ_DISCONNECTED - } else { - Flags::KEEPALIVE - }, + let mut flags = if is_eof { + Flags::READ_DISCONNECTED + } else if settings.keep_alive_enabled() { + Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED + } else { + Flags::empty() + }; + + Http1Dispatcher { + flags, stream: H1Writer::new(stream, settings.clone()), decoder: H1Decoder::new(), payload: None, @@ -113,7 +118,6 @@ where settings, ka_timer, ka_expire, - ka_enabled, } } @@ -212,7 +216,7 @@ where } // no keep-alive if self.flags.contains(Flags::STARTED) - && (!self.ka_enabled + && (!self.flags.contains(Flags::KEEPALIVE_ENABLED) || !self.flags.contains(Flags::KEEPALIVE)) { self.flags.insert(Flags::SHUTDOWN); @@ -280,7 +284,7 @@ where #[inline] /// read data from stream - pub fn poll_io(&mut self) -> Result<(), HttpDispatchError> { + pub(self) fn poll_io(&mut self) -> Result<(), HttpDispatchError> { if !self.flags.contains(Flags::POLLED) { self.parse()?; self.flags.insert(Flags::POLLED); @@ -308,7 +312,7 @@ where Ok(()) } - pub fn poll_handler(&mut self) -> Poll { + pub(self) fn poll_handler(&mut self) -> Poll { let retry = self.can_read(); // process first pipelined response, only one task can do io operation in http/1 @@ -419,7 +423,7 @@ where .push_back(Entry::Error(ServerError::err(Version::HTTP_11, status))); } - pub fn parse(&mut self) -> Result<(), HttpDispatchError> { + pub(self) fn parse(&mut self) -> Result<(), HttpDispatchError> { let mut updated = false; 'outer: loop { @@ -686,7 +690,8 @@ mod tests { let readbuf = BytesMut::new(); let settings = wrk_settings(); - let mut h1 = Http1::new(settings.clone(), buf, None, readbuf, false, None); + let mut h1 = + Http1Dispatcher::new(settings.clone(), buf, None, readbuf, false, None); assert!(h1.poll_io().is_ok()); assert!(h1.poll_io().is_ok()); assert!(h1.flags.contains(Flags::READ_DISCONNECTED)); diff --git a/src/server/mod.rs b/src/server/mod.rs index b72410516..456b46183 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -143,9 +143,11 @@ pub use self::message::Request; pub use self::ssl::*; pub use self::error::{AcceptorError, HttpDispatchError}; -pub use self::service::HttpService; pub use self::settings::{ServerSettings, WorkerSettings, WorkerSettingsBuilder}; +#[doc(hidden)] +pub use self::service::{HttpService, StreamConfiguration}; + #[doc(hidden)] pub use self::helpers::write_content_length; @@ -268,6 +270,8 @@ pub trait IoStream: AsyncRead + AsyncWrite + 'static { fn set_linger(&mut self, dur: Option) -> io::Result<()>; + fn set_keepalive(&mut self, dur: Option) -> io::Result<()>; + fn read_available(&mut self, buf: &mut BytesMut) -> Poll<(bool, bool), io::Error> { let mut read_some = false; loop { @@ -324,6 +328,11 @@ impl IoStream for ::tokio_uds::UnixStream { fn set_linger(&mut self, _dur: Option) -> io::Result<()> { Ok(()) } + + #[inline] + fn set_keepalive(&mut self, _nodelay: bool) -> io::Result<()> { + Ok(()) + } } impl IoStream for TcpStream { @@ -341,4 +350,9 @@ impl IoStream for TcpStream { fn set_linger(&mut self, dur: Option) -> io::Result<()> { TcpStream::set_linger(self, dur) } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + TcpStream::set_keepalive(self, dur) + } } diff --git a/src/server/service.rs b/src/server/service.rs index 2988bc661..89a58af75 100644 --- a/src/server/service.rs +++ b/src/server/service.rs @@ -1,4 +1,5 @@ use std::marker::PhantomData; +use std::time::Duration; use actix_net::service::{NewService, Service}; use futures::future::{ok, FutureResult}; @@ -10,6 +11,7 @@ use super::handler::HttpHandler; use super::settings::WorkerSettings; use super::IoStream; +/// `NewService` implementation for HTTP1/HTTP2 transports pub struct HttpService where H: HttpHandler, @@ -56,7 +58,6 @@ where Io: IoStream, { settings: WorkerSettings, - // tcp_ka: Option, _t: PhantomData, } @@ -66,12 +67,6 @@ where Io: IoStream, { fn new(settings: WorkerSettings) -> HttpServiceHandler { - // let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive { - // Some(Duration::new(val as u64, 0)) - // } else { - // None - // }; - HttpServiceHandler { settings, _t: PhantomData, @@ -94,7 +89,89 @@ where } fn call(&mut self, mut req: Self::Request) -> Self::Future { - let _ = req.set_nodelay(true); HttpChannel::new(self.settings.clone(), req, None) } } + +/// `NewService` implementation for stream configuration service +pub struct StreamConfiguration { + no_delay: Option, + tcp_ka: Option>, + _t: PhantomData<(T, E)>, +} + +impl StreamConfiguration { + /// Create new `StreamConfigurationService` instance. + pub fn new() -> Self { + Self { + no_delay: None, + tcp_ka: None, + _t: PhantomData, + } + } + + /// Sets the value of the `TCP_NODELAY` option on this socket. + pub fn nodelay(mut self, nodelay: bool) -> Self { + self.no_delay = Some(nodelay); + self + } + + /// Sets whether keepalive messages are enabled to be sent on this socket. + pub fn tcp_keepalive(mut self, keepalive: Option) -> Self { + self.tcp_ka = Some(keepalive); + self + } +} + +impl NewService for StreamConfiguration { + type Request = T; + type Response = T; + type Error = E; + type InitError = (); + type Service = StreamConfigurationService; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(StreamConfigurationService { + no_delay: self.no_delay.clone(), + tcp_ka: self.tcp_ka.clone(), + _t: PhantomData, + }) + } +} + +/// Stream configuration service +pub struct StreamConfigurationService { + no_delay: Option, + tcp_ka: Option>, + _t: PhantomData<(T, E)>, +} + +impl Service for StreamConfigurationService +where + T: IoStream, +{ + type Request = T; + type Response = T; + type Error = E; + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, mut req: Self::Request) -> Self::Future { + if let Some(no_delay) = self.no_delay { + if req.set_nodelay(no_delay).is_err() { + error!("Can not set socket no-delay option"); + } + } + if let Some(keepalive) = self.tcp_ka { + if req.set_keepalive(keepalive).is_err() { + error!("Can not set socket keep-alive option"); + } + } + + ok(req) + } +} diff --git a/src/server/ssl/nativetls.rs b/src/server/ssl/nativetls.rs index d59948c79..e56b4521b 100644 --- a/src/server/ssl/nativetls.rs +++ b/src/server/ssl/nativetls.rs @@ -21,4 +21,9 @@ impl IoStream for TlsStream { fn set_linger(&mut self, dur: Option) -> io::Result<()> { self.get_mut().get_mut().set_linger(dur) } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_keepalive(dur) + } } diff --git a/src/server/ssl/openssl.rs b/src/server/ssl/openssl.rs index 590dc0bbb..99ca40e03 100644 --- a/src/server/ssl/openssl.rs +++ b/src/server/ssl/openssl.rs @@ -74,4 +74,9 @@ impl IoStream for SslStream { fn set_linger(&mut self, dur: Option) -> io::Result<()> { self.get_mut().get_mut().set_linger(dur) } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_keepalive(dur) + } } diff --git a/src/server/ssl/rustls.rs b/src/server/ssl/rustls.rs index c74b62ea4..df78d1dc6 100644 --- a/src/server/ssl/rustls.rs +++ b/src/server/ssl/rustls.rs @@ -51,6 +51,11 @@ impl IoStream for TlsStream { fn set_linger(&mut self, dur: Option) -> io::Result<()> { self.get_mut().0.set_linger(dur) } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_keepalive(dur) + } } impl IoStream for TlsStream { @@ -69,4 +74,9 @@ impl IoStream for TlsStream { fn set_linger(&mut self, dur: Option) -> io::Result<()> { self.get_mut().0.set_linger(dur) } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_keepalive(dur) + } } diff --git a/tests/test_server.rs b/tests/test_server.rs index f8fabef6d..a74cb809a 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1016,7 +1016,10 @@ fn test_server_cookies() { #[test] fn test_custom_pipeline() { use actix::System; - use actix_web::server::{HttpService, KeepAlive, WorkerSettings}; + use actix_net::service::NewServiceExt; + use actix_web::server::{ + HttpService, KeepAlive, StreamConfiguration, WorkerSettings, + }; let addr = test::TestServer::unused_addr(); @@ -1034,7 +1037,9 @@ fn test_custom_pipeline() { .server_address(addr) .finish(); - HttpService::new(settings) + StreamConfiguration::new() + .nodelay(true) + .and_then(HttpService::new(settings)) }).unwrap() .run(); }); diff --git a/tests/test_ws.rs b/tests/test_ws.rs index f67314e8a..3baa48eb7 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -7,7 +7,7 @@ extern crate rand; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use std::{thread, time}; +use std::thread; use bytes::Bytes; use futures::Stream;