diff --git a/src/h1/codec.rs b/src/h1/codec.rs index d0faad43..16965768 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -6,7 +6,7 @@ use tokio_codec::{Decoder, Encoder}; use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder}; use super::encoder::{ResponseEncoder, ResponseLength}; -use body::Body; +use body::{Binary, Body}; use config::ServiceConfig; use error::ParseError; use helpers; @@ -26,12 +26,13 @@ bitflags! { const AVERAGE_HEADER_SIZE: usize = 30; +#[derive(Debug)] /// Http response pub enum OutMessage { /// Http response message Response(Response), /// Payload chunk - Payload(Bytes), + Payload(Option), } /// Incoming http/1 request @@ -151,6 +152,7 @@ impl Codec { buffer.extend_from_slice(reason); // content length + let mut len_is_set = true; match self.te.length { ResponseLength::Chunked => { buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") @@ -167,6 +169,10 @@ impl Codec { buffer.extend_from_slice(b"\r\n"); } ResponseLength::None => buffer.extend_from_slice(b"\r\n"), + ResponseLength::HeaderOrZero => { + len_is_set = false; + buffer.extend_from_slice(b"\r\n") + } } // write headers @@ -179,6 +185,9 @@ impl Codec { TRANSFER_ENCODING => continue, CONTENT_LENGTH => match self.te.length { ResponseLength::None => (), + ResponseLength::HeaderOrZero => { + len_is_set = true; + } _ => continue, }, DATE => { @@ -215,11 +224,13 @@ impl Codec { unsafe { buffer.advance_mut(pos); } + if !len_is_set { + buffer.extend_from_slice(b"content-length: 0\r\n") + } // optimized date header, set_date writes \r\n if !has_date { self.config.set_date(buffer); - buffer.extend_from_slice(b"\r\n"); } else { // msg eof buffer.extend_from_slice(b"\r\n"); @@ -272,8 +283,11 @@ impl Encoder for Codec { OutMessage::Response(res) => { self.encode_response(res, dst)?; } - OutMessage::Payload(bytes) => { - dst.extend_from_slice(&bytes); + OutMessage::Payload(Some(bytes)) => { + self.te.encode(bytes.as_ref(), dst)?; + } + OutMessage::Payload(None) => { + self.te.encode_eof(dst)?; } } Ok(()) diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index c8ce7d65..c2ce1203 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -12,7 +12,7 @@ use tokio_timer::Delay; use error::{ParseError, PayloadError}; use payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; -use body::Body; +use body::{Body, BodyStream}; use config::ServiceConfig; use error::DispatchError; use request::Request; @@ -61,9 +61,8 @@ enum Message { enum State { None, ServiceCall(S::Future), - SendResponse(Option), - SendResponseWithPayload(Option<(OutMessage, Body)>), - Payload(Body), + SendResponse(Option<(OutMessage, Body)>), + SendPayload(Option, Option), } impl State { @@ -99,6 +98,7 @@ where }; let framed = Framed::new(stream, Codec::new(config.clone())); + // keep-alive timer let (ka_expire, ka_timer) = if let Some(delay) = timeout { (delay.deadline(), Some(delay)) } else if let Some(delay) = config.keep_alive_timer() { @@ -174,59 +174,32 @@ where break if let Some(msg) = self.messages.pop_front() { match msg { Message::Item(req) => Some(self.handle_request(req)?), - Message::Error(res) => Some(State::SendResponse(Some( + Message::Error(res) => Some(State::SendResponse(Some(( OutMessage::Response(res), - ))), + Body::Empty, + )))), } } else { None }; }, - State::Payload(ref mut _body) => unimplemented!(), - State::ServiceCall(ref mut fut) => match fut - .poll() - .map_err(DispatchError::Service)? - { - Async::Ready(mut res) => { - self.framed.get_codec_mut().prepare_te(&mut res); - let body = res.replace_body(Body::Empty); - if body.is_empty() { - Some(State::SendResponse(Some(OutMessage::Response(res)))) - } else { - Some(State::SendResponseWithPayload(Some(( + // call inner service + State::ServiceCall(ref mut fut) => { + match fut.poll().map_err(DispatchError::Service)? { + Async::Ready(mut res) => { + self.framed.get_codec_mut().prepare_te(&mut res); + let body = res.replace_body(Body::Empty); + Some(State::SendResponse(Some(( OutMessage::Response(res), body, )))) } - } - Async::NotReady => None, - }, - State::SendResponse(ref mut item) => { - let msg = item.take().expect("SendResponse is empty"); - match self.framed.start_send(msg) { - Ok(AsyncSink::Ready) => { - self.flags.set( - Flags::KEEPALIVE, - self.framed.get_codec().keepalive(), - ); - self.flags.remove(Flags::FLUSHED); - Some(State::None) - } - Ok(AsyncSink::NotReady(msg)) => { - *item = Some(msg); - return Ok(()); - } - Err(err) => { - if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); - } - return Err(DispatchError::Io(err)); - } + Async::NotReady => None, } } - State::SendResponseWithPayload(ref mut item) => { - let (msg, body) = - item.take().expect("SendResponseWithPayload is empty"); + // send respons + State::SendResponse(ref mut item) => { + let (msg, body) = item.take().expect("SendResponse is empty"); match self.framed.start_send(msg) { Ok(AsyncSink::Ready) => { self.flags.set( @@ -234,7 +207,16 @@ where self.framed.get_codec().keepalive(), ); self.flags.remove(Flags::FLUSHED); - Some(State::Payload(body)) + match body { + Body::Empty => Some(State::None), + Body::Binary(bin) => Some(State::SendPayload( + None, + Some(OutMessage::Payload(bin.into())), + )), + Body::Streaming(stream) => { + Some(State::SendPayload(Some(stream), None)) + } + } } Ok(AsyncSink::NotReady(msg)) => { *item = Some((msg, body)); @@ -248,6 +230,48 @@ where } } } + // Send payload + State::SendPayload(ref mut stream, ref mut bin) => { + if let Some(item) = bin.take() { + match self.framed.start_send(item) { + Ok(AsyncSink::Ready) => { + self.flags.remove(Flags::FLUSHED); + } + Ok(AsyncSink::NotReady(item)) => { + *bin = Some(item); + return Ok(()); + } + Err(err) => return Err(DispatchError::Io(err)), + } + } + if let Some(ref mut stream) = stream { + match stream.poll() { + Ok(Async::Ready(Some(item))) => match self + .framed + .start_send(OutMessage::Payload(Some(item.into()))) + { + Ok(AsyncSink::Ready) => { + self.flags.remove(Flags::FLUSHED); + continue; + } + Ok(AsyncSink::NotReady(msg)) => { + *bin = Some(msg); + return Ok(()); + } + Err(err) => return Err(DispatchError::Io(err)), + }, + Ok(Async::Ready(None)) => Some(State::SendPayload( + None, + Some(OutMessage::Payload(None)), + )), + Ok(Async::NotReady) => return Ok(()), + // Err(err) => return Err(DispatchError::Io(err)), + Err(_) => return Err(DispatchError::Unknown), + } + } else { + Some(State::None) + } + } }; match state { @@ -275,19 +299,13 @@ where Async::Ready(mut res) => { self.framed.get_codec_mut().prepare_te(&mut res); let body = res.replace_body(Body::Empty); - if body.is_empty() { - Ok(State::SendResponse(Some(OutMessage::Response(res)))) - } else { - Ok(State::SendResponseWithPayload(Some(( - OutMessage::Response(res), - body, - )))) - } + Ok(State::SendResponse(Some((OutMessage::Response(res), body)))) } Async::NotReady => Ok(State::ServiceCall(task)), } } + /// Process one incoming message fn one_message(&mut self, msg: InMessage) -> Result<(), DispatchError> { self.flags.insert(Flags::STARTED); @@ -408,10 +426,12 @@ where // timeout on first request (slow request) return 408 trace!("Slow request timeout"); self.flags.insert(Flags::STARTED | Flags::DISCONNECTED); - self.state = - State::SendResponse(Some(OutMessage::Response( + self.state = State::SendResponse(Some(( + OutMessage::Response( Response::RequestTimeout().finish(), - ))); + ), + Body::Empty, + ))); } } else if let Some(deadline) = self.config.keep_alive_expire() { timer.reset(deadline) diff --git a/src/h1/encoder.rs b/src/h1/encoder.rs index 1544b240..6e8d44ce 100644 --- a/src/h1/encoder.rs +++ b/src/h1/encoder.rs @@ -17,9 +17,13 @@ use response::Response; #[derive(Debug)] pub(crate) enum ResponseLength { Chunked, + /// Content length is 0 Zero, + /// Check if headers contains length or write 0 + HeaderOrZero, Length(usize), Length64(u64), + /// Do no set content-length None, } @@ -41,6 +45,16 @@ impl Default for ResponseEncoder { } impl ResponseEncoder { + /// Encode message + pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + self.te.encode(msg, buf) + } + + /// Encode eof + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { + self.te.encode_eof(buf) + } + pub fn update(&mut self, resp: &mut Response, head: bool, version: Version) { self.head = head; @@ -63,17 +77,13 @@ impl ResponseEncoder { let transfer = match resp.body() { Body::Empty => { - if !self.head { - self.length = match resp.status() { - StatusCode::NO_CONTENT - | StatusCode::CONTINUE - | StatusCode::SWITCHING_PROTOCOLS - | StatusCode::PROCESSING => ResponseLength::None, - _ => ResponseLength::Zero, - }; - } else { - self.length = ResponseLength::Zero; - } + self.length = match resp.status() { + StatusCode::NO_CONTENT + | StatusCode::CONTINUE + | StatusCode::SWITCHING_PROTOCOLS + | StatusCode::PROCESSING => ResponseLength::None, + _ => ResponseLength::HeaderOrZero, + }; TransferEncoding::empty() } Body::Binary(_) => { @@ -253,16 +263,22 @@ impl TransferEncoding { /// Encode eof. Return `EOF` state of encoder #[inline] - pub fn encode_eof(&mut self, buf: &mut BytesMut) -> bool { + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { match self.kind { - TransferEncodingKind::Eof => true, - TransferEncodingKind::Length(rem) => rem == 0, + TransferEncodingKind::Eof => Ok(()), + TransferEncodingKind::Length(rem) => { + if rem != 0 { + Err(io::Error::new(io::ErrorKind::UnexpectedEof, "")) + } else { + Ok(()) + } + } TransferEncodingKind::Chunked(ref mut eof) => { if !*eof { *eof = true; buf.extend_from_slice(b"0\r\n\r\n"); } - true + Ok(()) } } } diff --git a/tests/test_server.rs b/tests/test_server.rs index 43e3966d..e8617609 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -2,6 +2,7 @@ extern crate actix; extern crate actix_http; extern crate actix_net; extern crate actix_web; +extern crate bytes; extern crate futures; use std::{io::Read, io::Write, net, thread, time}; @@ -9,9 +10,11 @@ use std::{io::Read, io::Write, net, thread, time}; use actix::System; use actix_net::server::Server; use actix_web::{client, test, HttpMessage}; -use futures::future; +use bytes::Bytes; +use futures::future::{self, ok}; +use futures::stream::once; -use actix_http::{h1, KeepAlive, Request, Response}; +use actix_http::{h1, Body, KeepAlive, Request, Response}; #[test] fn test_h1_v2() { @@ -33,7 +36,7 @@ fn test_h1_v2() { let mut sys = System::new("test"); { - let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + let req = client::ClientRequest::get(format!("http://{}/", addr)) .finish() .unwrap(); let response = sys.block_on(req.send()).unwrap(); @@ -68,9 +71,7 @@ fn test_malformed_request() { thread::spawn(move || { Server::new() .bind("test", addr, move || { - h1::H1Service::build() - .client_timeout(100) - .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().finish())) }).unwrap() .run(); }); @@ -117,21 +118,306 @@ fn test_content_length() { let mut sys = System::new("test"); { for i in 0..4 { - let req = - client::ClientRequest::get(format!("http://{}/{}", addr, i).as_str()) - .finish() - .unwrap(); + let req = client::ClientRequest::get(format!("http://{}/{}", addr, i)) + .finish() + .unwrap(); let response = sys.block_on(req.send()).unwrap(); assert_eq!(response.headers().get(&header), None); } for i in 4..6 { - let req = - client::ClientRequest::get(format!("http://{}/{}", addr, i).as_str()) - .finish() - .unwrap(); + let req = client::ClientRequest::get(format!("http://{}/{}", addr, i)) + .finish() + .unwrap(); let response = sys.block_on(req.send()).unwrap(); assert_eq!(response.headers().get(&header), Some(&value)); } } } + +#[test] +fn test_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + let data = data.clone(); + h1::H1Service::new(move |_| { + let mut builder = Response::Ok(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + future::ok::<_, ()>(builder.body(data.clone())) + }) + }) + .unwrap() + .run() + }); + thread::sleep(time::Duration::from_millis(400)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::get(format!("http://{}/", addr)) + .finish() + .unwrap(); + + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[test] +fn test_body() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + }).unwrap() + .run(); + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::get(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_head_empty() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::new(|_| { + ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish()) + }) + }).unwrap() + .run() + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::head(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + { + println!("RESP: {:?}", response); + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_head_binary() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::new(|_| { + ok::<_, ()>( + Response::Ok().content_length(STR.len() as u64).body(STR), + ) + }) + }).unwrap() + .run() + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::head(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_head_binary2() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))) + }).unwrap() + .run() + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::head(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[test] +fn test_body_length() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::new(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .content_length(STR.len() as u64) + .body(Body::Streaming(Box::new(body))), + ) + }) + }).unwrap() + .run() + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::get(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_body_chunked_explicit() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::new(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .chunked() + .body(Body::Streaming(Box::new(body))), + ) + }) + }).unwrap() + .run() + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::get(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_body_chunked_implicit() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + h1::H1Service::new(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>(Response::Ok().body(Body::Streaming(Box::new(body)))) + }) + }).unwrap() + .run() + }); + thread::sleep(time::Duration::from_millis(100)); + + let mut sys = System::new("test"); + let req = client::ClientRequest::get(format!("http://{}/", addr)) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +}