From 00b7dc7887540249560f0f195f2e59df51e10f61 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 18 Mar 2019 09:44:48 -0700 Subject: [PATCH] handle socket shutdown for h1 connections --- src/client/h2proto.rs | 6 +-- src/config.rs | 10 ++-- src/h1/dispatcher.rs | 112 ++++++++++++++++++++++------------------- src/h2/dispatcher.rs | 2 +- src/payload.rs | 2 +- src/service/service.rs | 2 +- src/ws/client/mod.rs | 24 ++++----- 7 files changed, 81 insertions(+), 77 deletions(-) diff --git a/src/client/h2proto.rs b/src/client/h2proto.rs index c05aeddbe..bf2d3e1b2 100644 --- a/src/client/h2proto.rs +++ b/src/client/h2proto.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use futures::future::{err, Either}; use futures::{Async, Future, Poll}; use h2::{client::SendRequest, SendStream}; -use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::{request::Request, HttpTryFrom, Method, Version}; use crate::body::{BodyLength, MessageBody}; @@ -45,7 +45,7 @@ where *req.version_mut() = Version::HTTP_2; let mut skip_len = true; - let mut has_date = false; + // let mut has_date = false; // Content length let _ = match length { @@ -72,7 +72,7 @@ where match *key { CONNECTION | TRANSFER_ENCODING => continue, // http2 specific CONTENT_LENGTH if skip_len => continue, - DATE => has_date = true, + // DATE => has_date = true, _ => (), } req.headers_mut().append(key, value.clone()); diff --git a/src/config.rs b/src/config.rs index 3c7df2feb..f7d7f5f94 100644 --- a/src/config.rs +++ b/src/config.rs @@ -85,7 +85,7 @@ impl ServiceConfig { ka_enabled, client_timeout, client_disconnect, - timer: DateService::with(Duration::from_millis(500)), + timer: DateService::new(), })) } @@ -204,14 +204,12 @@ impl fmt::Write for Date { struct DateService(Rc); struct DateServiceInner { - interval: Duration, current: UnsafeCell>, } impl DateServiceInner { - fn new(interval: Duration) -> Self { + fn new() -> Self { DateServiceInner { - interval, current: UnsafeCell::new(None), } } @@ -232,8 +230,8 @@ impl DateServiceInner { } impl DateService { - fn with(resolution: Duration) -> Self { - DateService(Rc::new(DateServiceInner::new(resolution))) + fn new() -> Self { + DateService(Rc::new(DateServiceInner::new())) } fn check_date(&self) { diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 82813a526..8e21e9b09 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -7,7 +7,7 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::Service; use actix_utils::cloneable::CloneableService; use bitflags::bitflags; -use futures::{try_ready, Async, Future, Poll, Sink, Stream}; +use futures::{Async, Future, Poll, Sink, Stream}; use log::{debug, error, trace}; use tokio_timer::Delay; @@ -32,6 +32,7 @@ bitflags! { const POLLED = 0b0000_1000; const SHUTDOWN = 0b0010_0000; const DISCONNECTED = 0b0100_0000; + const DROPPING = 0b1000_0000; } } @@ -56,7 +57,6 @@ where state: State, payload: Option, messages: VecDeque, - unhandled: Option, ka_expire: Instant, ka_timer: Option, @@ -131,7 +131,6 @@ where state: State::None, error: None, messages: VecDeque::new(), - unhandled: None, service, flags, config, @@ -411,8 +410,19 @@ where /// keep-alive timer fn poll_keepalive(&mut self) -> Result<(), DispatchError> { if self.ka_timer.is_none() { - return Ok(()); + // shutdown timeout + if self.flags.contains(Flags::SHUTDOWN) { + if let Some(interval) = self.config.client_disconnect_timer() { + self.ka_timer = Some(Delay::new(interval)); + } else { + self.flags.insert(Flags::DISCONNECTED); + return Ok(()); + } + } else { + return Ok(()); + } } + match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { error!("Timer error {:?}", e); DispatchError::Unknown @@ -436,6 +446,8 @@ where let _ = timer.poll(); } } else { + // no shutdown timeout, drop socket + self.flags.insert(Flags::DISCONNECTED); return Ok(()); } } else { @@ -483,61 +495,55 @@ where #[inline] fn poll(&mut self) -> Poll { - let shutdown = if let Some(ref mut inner) = self.inner { - if inner.flags.contains(Flags::SHUTDOWN) { - inner.poll_keepalive()?; - try_ready!(inner.poll_flush()); - true + let inner = self.inner.as_mut().unwrap(); + + if inner.flags.contains(Flags::SHUTDOWN) { + inner.poll_keepalive()?; + if inner.flags.contains(Flags::DISCONNECTED) { + Ok(Async::Ready(())) } else { - inner.poll_keepalive()?; - inner.poll_request()?; - loop { - inner.poll_response()?; - if let Async::Ready(false) = inner.poll_flush()? { - break; - } - } - - if inner.flags.contains(Flags::DISCONNECTED) { - return Ok(Async::Ready(())); - } - - // keep-alive and stream errors - if inner.state.is_empty() && inner.framed.is_write_buf_empty() { - if let Some(err) = inner.error.take() { - return Err(err); - } - // unhandled request (upgrade or connect) - else if inner.unhandled.is_some() { - false - } - // disconnect if keep-alive is not enabled - else if inner.flags.contains(Flags::STARTED) - && !inner.flags.intersects(Flags::KEEPALIVE) - { - true - } - // disconnect if shutdown - else if inner.flags.contains(Flags::SHUTDOWN) { - true - } else { - return Ok(Async::NotReady); - } - } else { - return Ok(Async::NotReady); + // try_ready!(inner.poll_flush()); + match inner.framed.get_mut().shutdown()? { + Async::Ready(_) => Ok(Async::Ready(())), + Async::NotReady => Ok(Async::NotReady), } } } else { - unreachable!() - }; + inner.poll_keepalive()?; + inner.poll_request()?; + loop { + inner.poll_response()?; + if let Async::Ready(false) = inner.poll_flush()? { + break; + } + } - let mut inner = self.inner.take().unwrap(); + if inner.flags.contains(Flags::DISCONNECTED) { + return Ok(Async::Ready(())); + } - // TODO: shutdown - Ok(Async::Ready(())) - //Ok(Async::Ready(HttpServiceResult::Shutdown( - // inner.framed.into_inner(), - //))) + // keep-alive and stream errors + if inner.state.is_empty() && inner.framed.is_write_buf_empty() { + if let Some(err) = inner.error.take() { + return Err(err); + } + // disconnect if keep-alive is not enabled + else if inner.flags.contains(Flags::STARTED) + && !inner.flags.intersects(Flags::KEEPALIVE) + { + inner.flags.insert(Flags::SHUTDOWN); + self.poll() + } + // disconnect if shutdown + else if inner.flags.contains(Flags::SHUTDOWN) { + self.poll() + } else { + return Ok(Async::NotReady); + } + } else { + return Ok(Async::NotReady); + } + } } } diff --git a/src/h2/dispatcher.rs b/src/h2/dispatcher.rs index cbba34c0d..ea63dc2bc 100644 --- a/src/h2/dispatcher.rs +++ b/src/h2/dispatcher.rs @@ -56,7 +56,7 @@ where config: ServiceConfig, timeout: Option, ) -> Self { - let keepalive = config.keep_alive_enabled(); + // let keepalive = config.keep_alive_enabled(); // let flags = if keepalive { // Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED // } else { diff --git a/src/payload.rs b/src/payload.rs index 8f96bab95..91e6b5c95 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -41,7 +41,7 @@ impl From for Payload { impl Payload { /// Takes current payload and replaces it with `None` value - fn take(&mut self) -> Payload { + pub fn take(&mut self) -> Payload { std::mem::replace(self, Payload::None) } } diff --git a/src/service/service.rs b/src/service/service.rs index 3ddf55739..0bc1634d8 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -169,7 +169,7 @@ where } fn call(&mut self, req: Self::Request) -> Self::Future { - let (io, params, proto) = req.into_parts(); + let (io, _, proto) = req.into_parts(); match proto { Protocol::Http2 => { let io = Io { diff --git a/src/ws/client/mod.rs b/src/ws/client/mod.rs index 0dbf081c6..a5c221967 100644 --- a/src/ws/client/mod.rs +++ b/src/ws/client/mod.rs @@ -25,19 +25,19 @@ impl Protocol { } } - fn is_http(self) -> bool { - match self { - Protocol::Https | Protocol::Http => true, - _ => false, - } - } + // fn is_http(self) -> bool { + // match self { + // Protocol::Https | Protocol::Http => true, + // _ => false, + // } + // } - fn is_secure(self) -> bool { - match self { - Protocol::Https | Protocol::Wss => true, - _ => false, - } - } + // fn is_secure(self) -> bool { + // match self { + // Protocol::Https | Protocol::Wss => true, + // _ => false, + // } + // } fn port(self) -> u16 { match self {