From 91c44a1cf1493a8ffd8a26f3f3c3013d55301d41 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 20 Jan 2018 16:12:38 -0800 Subject: [PATCH] Fix HEAD requests handling --- CHANGES.md | 2 ++ src/server/encoding.rs | 44 +++++++++++++++++-------- src/server/h1.rs | 74 ++++++++++++++++++++---------------------- src/server/h1writer.rs | 8 +++-- tests/test_server.rs | 60 ++++++++++++++++++++++++++++++++++ 5 files changed, 135 insertions(+), 53 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 695ac1509..019e6bc50 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,8 @@ ## 0.3.2 (2018-01-xx) +* Fix HEAD requests handling + * Can't have multiple Applications on a single server with different state #49 diff --git a/src/server/encoding.rs b/src/server/encoding.rs index e5b75c482..e374cf07d 100644 --- a/src/server/encoding.rs +++ b/src/server/encoding.rs @@ -3,7 +3,7 @@ use std::io::{Read, Write}; use std::fmt::Write as FmtWrite; use std::str::FromStr; -use http::Version; +use http::{Version, Method, HttpTryFrom}; use http::header::{HeaderMap, HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; @@ -378,10 +378,12 @@ impl PayloadEncoder { ContentEncoding::Identity }; - let transfer = match body { + let mut transfer = match body { Body::Empty => { - resp.headers_mut().remove(CONTENT_LENGTH); - TransferEncoding::eof(buf) + if req.method != Method::HEAD { + resp.headers_mut().remove(CONTENT_LENGTH); + } + TransferEncoding::length(0, buf) }, Body::Binary(ref mut bytes) => { if encoding.is_compression() { @@ -404,7 +406,14 @@ impl PayloadEncoder { *bytes = Binary::from(tmp.take()); encoding = ContentEncoding::Identity; } - resp.headers_mut().remove(CONTENT_LENGTH); + if req.method == Method::HEAD { + let mut b = BytesMut::new(); + let _ = write!(b, "{}", bytes.len()); + resp.headers_mut().insert( + CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); + } else { + resp.headers_mut().remove(CONTENT_LENGTH); + } TransferEncoding::eof(buf) } Body::Streaming(_) | Body::Actor(_) => { @@ -425,7 +434,12 @@ impl PayloadEncoder { } } }; - resp.replace_body(body); + // + if req.method == Method::HEAD { + transfer.kind = TransferEncodingKind::Length(0); + } else { + resp.replace_body(body); + } PayloadEncoder( match encoding { @@ -714,14 +728,18 @@ impl TransferEncoding { Ok(*eof) }, TransferEncodingKind::Length(ref mut remaining) => { - if msg.is_empty() { - return Ok(*remaining == 0) - } - let max = cmp::min(*remaining, msg.len() as u64); - self.buffer.extend(msg.take().split_to(max as usize).into()); + if *remaining > 0 { + if msg.is_empty() { + return Ok(*remaining == 0) + } + let len = cmp::min(*remaining, msg.len() as u64); + self.buffer.extend(msg.take().split_to(len as usize).into()); - *remaining -= max as u64; - Ok(*remaining == 0) + *remaining -= len as u64; + Ok(*remaining == 0) + } else { + Ok(true) + } }, } } diff --git a/src/server/h1.rs b/src/server/h1.rs index 67ec26372..0171ac568 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -100,8 +100,8 @@ impl Http1 #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))] pub fn poll(&mut self) -> Poll<(), ()> { // keep-alive timer - if self.keepalive_timer.is_some() { - match self.keepalive_timer.as_mut().unwrap().poll() { + if let Some(ref mut timer) = self.keepalive_timer { + match timer.poll() { Ok(Async::Ready(_)) => { trace!("Keep-alive timeout, close connection"); return Ok(Async::Ready(())) @@ -146,10 +146,8 @@ impl Http1 item.flags.insert(EntryFlags::FINISHED); } }, - Ok(Async::NotReady) => { - // no more IO for this iteration - io = true; - }, + // no more IO for this iteration + Ok(Async::NotReady) => io = true, Err(err) => { // it is not possible to recover from error // during pipe handling, so just drop connection @@ -227,38 +225,7 @@ impl Http1 self.tasks.push_back( Entry {pipe: pipe.unwrap_or_else(|| Pipeline::error(HTTPNotFound)), flags: EntryFlags::empty()}); - } - Err(ReaderError::Disconnect) => { - not_ready = false; - self.flags.insert(Flags::ERROR); - self.stream.disconnected(); - for entry in &mut self.tasks { - entry.pipe.disconnected() - } }, - Err(err) => { - // notify all tasks - not_ready = false; - self.stream.disconnected(); - for entry in &mut self.tasks { - entry.pipe.disconnected() - } - - // kill keepalive - self.flags.remove(Flags::KEEPALIVE); - self.keepalive_timer.take(); - - // on parse error, stop reading stream but tasks need to be completed - self.flags.insert(Flags::ERROR); - - if self.tasks.is_empty() { - if let ReaderError::Error(err) = err { - self.tasks.push_back( - Entry {pipe: Pipeline::error(err.error_response()), - flags: EntryFlags::empty()}); - } - } - } Ok(Async::NotReady) => { // start keep-alive timer, this also is slow request timeout if self.tasks.is_empty() { @@ -293,7 +260,38 @@ impl Http1 } } break - } + }, + Err(ReaderError::Disconnect) => { + not_ready = false; + self.flags.insert(Flags::ERROR); + self.stream.disconnected(); + for entry in &mut self.tasks { + entry.pipe.disconnected() + } + }, + Err(err) => { + // notify all tasks + not_ready = false; + self.stream.disconnected(); + for entry in &mut self.tasks { + entry.pipe.disconnected() + } + + // kill keepalive + self.flags.remove(Flags::KEEPALIVE); + self.keepalive_timer.take(); + + // on parse error, stop reading stream but tasks need to be completed + self.flags.insert(Flags::ERROR); + + if self.tasks.is_empty() { + if let ReaderError::Error(err) = err { + self.tasks.push_back( + Entry {pipe: Pipeline::error(err.error_response()), + flags: EntryFlags::empty()}); + } + } + }, } } diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index e1212980e..e423f8758 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -2,7 +2,7 @@ use std::io; use bytes::BufMut; use futures::{Async, Poll}; use tokio_io::AsyncWrite; -use http::Version; +use http::{Method, Version}; use http::header::{HeaderValue, CONNECTION, DATE}; use helpers; @@ -132,7 +132,11 @@ impl Writer for H1Writer { match body { Body::Empty => - buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n"), + if req.method != Method::HEAD { + buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n"); + } else { + buffer.extend_from_slice(b"\r\n"); + }, Body::Binary(ref bytes) => helpers::write_content_length(bytes.len(), &mut buffer), _ => diff --git a/tests/test_server.rs b/tests/test_server.rs index bb6a6baef..3a8321c83 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -152,6 +152,66 @@ fn test_body_br_streaming() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } +#[test] +fn test_head_empty() { + let srv = test::TestServer::new( + |app| app.handler(|_| { + httpcodes::HTTPOk.build() + .content_length(STR.len() as u64).finish()})); + + let client = reqwest::Client::new(); + let mut res = client.head(&srv.url("/")).send().unwrap(); + assert!(res.status().is_success()); + let mut bytes = BytesMut::with_capacity(2048).writer(); + let len = res.headers() + .get::().map(|ct_len| **ct_len).unwrap(); + assert_eq!(len, STR.len() as u64); + let _ = res.copy_to(&mut bytes); + let bytes = bytes.into_inner(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_head_binary() { + let srv = test::TestServer::new( + |app| app.handler(|_| { + httpcodes::HTTPOk.build() + .content_encoding(headers::ContentEncoding::Identity) + .content_length(100).body(STR)})); + + let client = reqwest::Client::new(); + let mut res = client.head(&srv.url("/")).send().unwrap(); + assert!(res.status().is_success()); + let mut bytes = BytesMut::with_capacity(2048).writer(); + let len = res.headers() + .get::().map(|ct_len| **ct_len).unwrap(); + assert_eq!(len, STR.len() as u64); + let _ = res.copy_to(&mut bytes); + let bytes = bytes.into_inner(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_head_binary2() { + let srv = test::TestServer::new( + |app| app.handler(|_| { + httpcodes::HTTPOk.build() + .content_encoding(headers::ContentEncoding::Identity) + .body(STR) + })); + + let client = reqwest::Client::new(); + let mut res = client.head(&srv.url("/")).send().unwrap(); + assert!(res.status().is_success()); + let mut bytes = BytesMut::with_capacity(2048).writer(); + let len = res.headers() + .get::().map(|ct_len| **ct_len).unwrap(); + assert_eq!(len, STR.len() as u64); + let _ = res.copy_to(&mut bytes); + let bytes = bytes.into_inner(); + assert!(bytes.is_empty()); +} + #[test] fn test_body_length() { let srv = test::TestServer::new(