From f0f67072aece8f7cacb6be8fcf24d147ecfe1ee7 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 3 Sep 2018 21:35:11 -0700 Subject: [PATCH] Read client response until eof if connection header set to close #464 --- CHANGES.md | 4 ++++ src/client/parser.rs | 37 ++++++++++++++++++++++++++++++------- tests/test_client.rs | 34 +++++++++++++++++++++++++++++++++- 3 files changed, 67 insertions(+), 8 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index dd6cdcd20..b48c743c8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -15,6 +15,10 @@ * Handling scoped paths without leading slashes #460 +### Changed + +* Read client response until eof if connection header set to close #464 + ## [0.7.4] - 2018-08-23 diff --git a/src/client/parser.rs b/src/client/parser.rs index 0ee4598de..11252fa52 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -20,6 +20,7 @@ const MAX_HEADERS: usize = 96; #[derive(Default)] pub struct HttpResponseParser { decoder: Option, + eof: bool, // indicate that we read payload until stream eof } #[derive(Debug, Fail)] @@ -44,8 +45,14 @@ impl HttpResponseParser { match HttpResponseParser::parse_message(buf) .map_err(HttpResponseParserError::Error)? { - Async::Ready((msg, decoder)) => { - self.decoder = decoder; + Async::Ready((msg, info)) => { + if let Some((decoder, eof)) = info { + self.eof = eof; + self.decoder = Some(decoder); + } else { + self.eof = false; + self.decoder = None; + } return Ok(Async::Ready(msg)); } Async::NotReady => { @@ -97,7 +104,12 @@ impl HttpResponseParser { return Ok(Async::NotReady); } if stream_finished { - return Err(PayloadError::Incomplete); + // read untile eof? + if self.eof { + return Ok(Async::Ready(None)); + } else { + return Err(PayloadError::Incomplete); + } } } Err(err) => return Err(err.into()), @@ -110,7 +122,7 @@ impl HttpResponseParser { fn parse_message( buf: &mut BytesMut, - ) -> Poll<(ClientResponse, Option), ParseError> { + ) -> Poll<(ClientResponse, Option<(EncodingDecoder, bool)>), ParseError> { // Unsafe: we read only this data only after httparse parses headers into. // performance bump for pipeline benchmarks. let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; @@ -156,12 +168,12 @@ impl HttpResponseParser { } let decoder = if status == StatusCode::SWITCHING_PROTOCOLS { - Some(EncodingDecoder::eof()) + Some((EncodingDecoder::eof(), true)) } else if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { // Content-Length if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::() { - Some(EncodingDecoder::length(len)) + Some((EncodingDecoder::length(len), false)) } else { debug!("illegal Content-Length: {:?}", len); return Err(ParseError::Header); @@ -172,7 +184,18 @@ impl HttpResponseParser { } } else if chunked(&hdrs)? { // Chunked encoding - Some(EncodingDecoder::chunked()) + Some((EncodingDecoder::chunked(), false)) + } else if let Some(value) = hdrs.get(header::CONNECTION) { + let close = if let Ok(s) = value.to_str() { + s == "close" + } else { + false + }; + if close { + Some((EncodingDecoder::eof(), true)) + } else { + None + } } else { None }; diff --git a/tests/test_client.rs b/tests/test_client.rs index d4a2ce1f3..8707114fa 100644 --- a/tests/test_client.rs +++ b/tests/test_client.rs @@ -8,7 +8,8 @@ extern crate rand; #[cfg(all(unix, feature = "uds"))] extern crate tokio_uds; -use std::io::Read; +use std::io::{Read, Write}; +use std::{net, thread}; use bytes::Bytes; use flate2::read::GzDecoder; @@ -470,3 +471,34 @@ fn test_default_headers() { "\"" ))); } + +#[test] +fn client_read_until_eof() { + let addr = test::TestServer::unused_addr(); + + thread::spawn(move || { + let lst = net::TcpListener::bind(addr).unwrap(); + + for stream in lst.incoming() { + let mut stream = stream.unwrap(); + let mut b = [0; 1000]; + let _ = stream.read(&mut b).unwrap(); + let _ = stream + .write_all(b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\nwelcome!"); + } + }); + + let mut sys = actix::System::new("test"); + + // client request + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + .finish() + .unwrap(); + println!("TEST: {:?}", req); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(b"welcome!")); +}