diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 131811a9e..b7b9db2d7 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -693,18 +693,20 @@ where } } else { // read socket into a buf - if !inner.flags.contains(Flags::READ_DISCONNECT) { - if let Some(true) = - read_available(&mut inner.io, &mut inner.read_buf)? - { - inner.flags.insert(Flags::READ_DISCONNECT); - if let Some(mut payload) = inner.payload.take() { - payload.feed_eof(); - } - } - } + let should_disconnect = if !inner.flags.contains(Flags::READ_DISCONNECT) { + read_available(&mut inner.io, &mut inner.read_buf)? + } else { + None + }; inner.poll_request()?; + if let Some(true) = should_disconnect { + inner.flags.insert(Flags::READ_DISCONNECT); + if let Some(mut payload) = inner.payload.take() { + payload.feed_eof(); + } + }; + loop { if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { inner.write_buf.reserve(HW_BUFFER_SIZE); diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index d0c5e3526..a299f58d7 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -9,6 +9,7 @@ use actix_service::{new_service_cfg, service_fn, NewService}; use bytes::{Bytes, BytesMut}; use futures::future::{self, ok, Future}; use futures::stream::{once, Stream}; +use regex::Regex; use tokio_timer::sleep; use actix_http::body::Body; @@ -215,6 +216,56 @@ fn test_expect_continue_h1() { assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); } +#[test] +fn test_chunked_payload() { + let chunk_sizes = vec![ 32768, 32, 32768 ]; + let total_size: usize = chunk_sizes.iter().sum(); + + let srv = TestServer::new(|| { + HttpService::build() + .h1(|mut request: Request| { + request.take_payload() + .map_err(|e| panic!(format!("Error reading payload: {}", e))) + .fold(0usize, |acc, chunk| { + future::ok::<_, ()>(acc + chunk.len()) + }) + .map(|req_size| { + Response::Ok().body(format!("size={}", req_size)) + }) + }) + }); + + let returned_size = { + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"POST /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); + + for chunk_size in chunk_sizes.iter() { + let mut bytes = Vec::new(); + let random_bytes: Vec = (0..*chunk_size).map(|_| rand::random::()).collect(); + + bytes.extend(format!("{:X}\r\n", chunk_size).as_bytes()); + bytes.extend(&random_bytes[..]); + bytes.extend(b"\r\n"); + let _ = stream.write_all(&bytes); + } + + let _ = stream.write_all(b"0\r\n\r\n"); + stream.shutdown(net::Shutdown::Write).unwrap(); + + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + + let re = Regex::new(r"size=(\d+)").unwrap(); + let size: usize = match re.captures(&data) { + Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(), + None => panic!(format!("Failed to find size in HTTP Response: {}", data)), + }; + size + }; + + assert_eq!(returned_size, total_size); +} + #[test] fn test_slow_request() { let srv = TestServer::new(|| {