From e2c8f17c2cb03494e13c0ca1586962e4ad189f2c Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 27 Feb 2018 16:08:57 -0800 Subject: [PATCH] drop connection if handler get dropped without consuming payload --- src/client/response.rs | 12 ------------ src/payload.rs | 21 +++++++++++++++++---- src/server/encoding.rs | 6 +++--- src/server/h1.rs | 17 ++++++++++------- src/server/h2.rs | 9 +++++---- 5 files changed, 35 insertions(+), 30 deletions(-) diff --git a/src/client/response.rs b/src/client/response.rs index 0f997dcda..392c91332 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -71,24 +71,12 @@ impl ClientResponse { self.as_ref().version } - /// Get a mutable reference to the headers. - #[inline] - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.as_mut().headers - } - /// Get the status from the server. #[inline] pub fn status(&self) -> StatusCode { self.as_ref().status } - /// Set the `StatusCode` for this response. - #[inline] - pub fn set_status(&mut self, status: StatusCode) { - self.as_mut().status = status - } - /// Load request cookies. pub fn cookies(&self) -> Result<&Vec>, CookieParseError> { if self.as_ref().cookies.is_none() { diff --git a/src/payload.rs b/src/payload.rs index 4fb80b0bc..3c0f41532 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -8,6 +8,13 @@ use futures::{Async, Poll, Stream}; use error::PayloadError; +#[derive(Debug, PartialEq)] +pub(crate) enum PayloadStatus { + Read, + Pause, + Dropped, +} + /// Buffered stream of bytes chunks /// /// Payload stores chunks in a vector. First chunk can be received with `.readany()` method. @@ -100,7 +107,7 @@ pub(crate) trait PayloadWriter { fn feed_data(&mut self, data: Bytes); /// Need read data - fn need_read(&self) -> bool; + fn need_read(&self) -> PayloadStatus; } /// Sender part of the payload stream @@ -129,11 +136,17 @@ impl PayloadWriter for PayloadSender { } #[inline] - fn need_read(&self) -> bool { + fn need_read(&self) -> PayloadStatus { + // we check need_read only if Payload (other side) is alive, + // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { - shared.borrow().need_read + if shared.borrow().need_read { + PayloadStatus::Read + } else { + PayloadStatus::Pause + } } else { - false + PayloadStatus::Dropped } } } diff --git a/src/server/encoding.rs b/src/server/encoding.rs index 694d63a1d..5e0ac7b0f 100644 --- a/src/server/encoding.rs +++ b/src/server/encoding.rs @@ -18,7 +18,7 @@ use body::{Body, Binary}; use error::PayloadError; use httprequest::HttpInnerMessage; use httpresponse::HttpResponse; -use payload::{PayloadSender, PayloadWriter}; +use payload::{PayloadSender, PayloadWriter, PayloadStatus}; use super::shared::SharedBytes; @@ -120,7 +120,7 @@ impl PayloadWriter for PayloadType { } #[inline] - fn need_read(&self) -> bool { + fn need_read(&self) -> PayloadStatus { match *self { PayloadType::Sender(ref sender) => sender.need_read(), PayloadType::Encoding(ref enc) => enc.need_read(), @@ -352,7 +352,7 @@ impl PayloadWriter for EncodedPayload { } #[inline] - fn need_read(&self) -> bool { + fn need_read(&self) -> PayloadStatus { self.inner.need_read() } } diff --git a/src/server/h1.rs b/src/server/h1.rs index 8fb3a9e97..11c0bb2d4 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -18,7 +18,7 @@ use pipeline::Pipeline; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; use error::{ParseError, PayloadError, ResponseError}; -use payload::{Payload, PayloadWriter}; +use payload::{Payload, PayloadWriter, PayloadStatus}; use super::{utils, Writer}; use super::h1writer::H1Writer; @@ -190,7 +190,7 @@ impl Http1 true }; - let retry = self.reader.need_read(); + let retry = self.reader.need_read() == PayloadStatus::Read; loop { // check in-flight messages @@ -227,7 +227,7 @@ impl Http1 }, // no more IO for this iteration Ok(Async::NotReady) => { - if self.reader.need_read() && !retry { + if self.reader.need_read() == PayloadStatus::Read && !retry { return Ok(Async::Ready(true)); } io = true; @@ -341,6 +341,7 @@ struct PayloadInfo { enum ReaderError { Disconnect, Payload, + PayloadDropped, Error(ParseError), } @@ -352,11 +353,11 @@ impl Reader { } #[inline] - fn need_read(&self) -> bool { + fn need_read(&self) -> PayloadStatus { if let Some(ref info) = self.payload { info.tx.need_read() } else { - true + PayloadStatus::Read } } @@ -392,8 +393,10 @@ impl Reader { settings: &WorkerSettings) -> Poll where T: IoStream { - if !self.need_read() { - return Ok(Async::NotReady) + match self.need_read() { + PayloadStatus::Read => (), + PayloadStatus::Pause => return Ok(Async::NotReady), + PayloadStatus::Dropped => return Err(ReaderError::PayloadDropped), } // read payload diff --git a/src/server/h2.rs b/src/server/h2.rs index 97974c88e..ed75c97f3 100644 --- a/src/server/h2.rs +++ b/src/server/h2.rs @@ -21,7 +21,7 @@ use error::PayloadError; use httpcodes::HTTPNotFound; use httpmessage::HttpMessage; use httprequest::HttpRequest; -use payload::{Payload, PayloadWriter}; +use payload::{Payload, PayloadWriter, PayloadStatus}; use super::h2writer::H2Writer; use super::encoding::PayloadType; @@ -105,7 +105,7 @@ impl Http2 item.poll_payload(); if !item.flags.contains(EntryFlags::EOF) { - let retry = item.payload.need_read(); + let retry = item.payload.need_read() == PayloadStatus::Read; loop { match item.task.poll_io(&mut item.stream) { Ok(Async::Ready(ready)) => { @@ -116,7 +116,8 @@ impl Http2 not_ready = false; }, Ok(Async::NotReady) => { - if item.payload.need_read() && !retry { + if item.payload.need_read() == PayloadStatus::Read && !retry + { continue } }, @@ -307,7 +308,7 @@ impl Entry { fn poll_payload(&mut self) { if !self.flags.contains(EntryFlags::REOF) { - if self.payload.need_read() { + if self.payload.need_read() == PayloadStatus::Read { if let Err(err) = self.recv.release_capacity().release_capacity(32_768) { self.payload.set_error(PayloadError::Http2(err)) }