diff --git a/src/client/h2proto.rs b/src/client/h2proto.rs index 617c21b60..c05aeddbe 100644 --- a/src/client/h2proto.rs +++ b/src/client/h2proto.rs @@ -6,10 +6,11 @@ use futures::future::{err, Either}; use futures::{Async, Future, Poll}; use h2::{client::SendRequest, SendStream}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; -use http::{request::Request, HttpTryFrom, Version}; +use http::{request::Request, HttpTryFrom, Method, Version}; use crate::body::{BodyLength, MessageBody}; use crate::message::{Message, RequestHead, ResponseHead}; +use crate::payload::Payload; use super::connection::{ConnectionType, IoConnection}; use super::error::SendRequestError; @@ -28,6 +29,7 @@ where B: MessageBody, { trace!("Sending client request: {:?} {:?}", head, body.length()); + let head_req = head.method == Method::HEAD; let length = body.length(); let eof = match length { BodyLength::None | BodyLength::Empty | BodyLength::Sized(0) => true, @@ -99,18 +101,16 @@ where } } }) - .and_then(|resp| { + .and_then(move |resp| { let (parts, body) = resp.into_parts(); + let payload = if head_req { Payload::None } else { body.into() }; let mut head: Message = Message::new(); head.version = parts.version; head.status = parts.status; head.headers = parts.headers; - Ok(ClientResponse { - head, - payload: body.into(), - }) + Ok(ClientResponse { head, payload }) }) .from_err() } diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index a7eb96c3f..8543aa213 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -228,18 +228,21 @@ where } None => None, }, - State::ServiceCall(mut fut) => { - match fut.poll().map_err(|_| DispatchError::Service)? { - Async::Ready(res) => { - let (res, body) = res.into().replace_body(()); - Some(self.send_response(res, body)?) - } - Async::NotReady => { - self.state = State::ServiceCall(fut); - None - } + State::ServiceCall(mut fut) => match fut.poll() { + Ok(Async::Ready(res)) => { + let (res, body) = res.into().replace_body(()); + Some(self.send_response(res, body)?) } - } + Ok(Async::NotReady) => { + self.state = State::ServiceCall(fut); + None + } + Err(_e) => { + let res: Response = Response::InternalServerError().finish(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + }, State::SendPayload(mut stream) => { loop { if !self.framed.is_write_buf_full() { @@ -289,12 +292,17 @@ where fn handle_request(&mut self, req: Request) -> Result, DispatchError> { let mut task = self.service.call(req); - match task.poll().map_err(|_| DispatchError::Service)? { - Async::Ready(res) => { + match task.poll() { + Ok(Async::Ready(res)) => { let (res, body) = res.into().replace_body(()); self.send_response(res, body) } - Async::NotReady => Ok(State::ServiceCall(task)), + Ok(Async::NotReady) => Ok(State::ServiceCall(task)), + Err(_e) => { + let res: Response = Response::InternalServerError().finish(); + let (res, body) = res.replace_body(()); + self.send_response(res, body.into_body()) + } } } diff --git a/src/h2/dispatcher.rs b/src/h2/dispatcher.rs index a3c731ebb..28465b509 100644 --- a/src/h2/dispatcher.rs +++ b/src/h2/dispatcher.rs @@ -107,9 +107,7 @@ where fn poll(&mut self) -> Poll { loop { match self.connection.poll()? { - Async::Ready(None) => { - self.flags.insert(Flags::DISCONNECTED); - } + Async::Ready(None) => return Ok(Async::Ready(())), Async::Ready(Some((req, res))) => { // update keep-alive expire if self.ka_timer.is_some() { @@ -255,7 +253,7 @@ where } } Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { + Err(_e) => { let res: Response = Response::InternalServerError().finish(); let (res, body) = res.replace_body(()); @@ -304,7 +302,9 @@ where } } else { match body.poll_next() { - Ok(Async::NotReady) => return Ok(Async::NotReady), + Ok(Async::NotReady) => { + return Ok(Async::NotReady); + } Ok(Async::Ready(None)) => { if let Err(e) = stream.send_data(Bytes::new(), true) { warn!("{:?}", e); diff --git a/src/h2/mod.rs b/src/h2/mod.rs index c5972123f..919317e05 100644 --- a/src/h2/mod.rs +++ b/src/h2/mod.rs @@ -40,7 +40,10 @@ impl Stream for Payload { } Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => Err(err.into()), + Err(err) => { + println!("======== {:?}", err); + Err(err.into()) + } } } } diff --git a/test-server/src/lib.rs b/test-server/src/lib.rs index 6b70fbc10..d810d8915 100644 --- a/test-server/src/lib.rs +++ b/test-server/src/lib.rs @@ -167,6 +167,11 @@ impl TestServerRuntime { ClientRequest::get(self.url("/").as_str()) } + /// Create https `GET` request + pub fn sget(&self) -> ClientRequestBuilder { + ClientRequest::get(self.surl("/").as_str()) + } + /// Create `POST` request pub fn post(&self) -> ClientRequestBuilder { ClientRequest::post(self.url("/").as_str()) diff --git a/tests/test_server.rs b/tests/test_server.rs index 3771d35c6..49a943dbb 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -3,15 +3,16 @@ use std::time::Duration; use std::{net, thread}; use actix_http_test::TestServer; -use actix_service::NewService; +use actix_server_config::ServerConfig; +use actix_service::{fn_cfg_factory, NewService}; use bytes::Bytes; use futures::future::{self, ok, Future}; use futures::stream::once; use actix_http::body::Body; use actix_http::{ - body, client, http, Error, HttpMessage as HttpMessage2, HttpService, KeepAlive, - Request, Response, + body, client, error, http, http::header, Error, HttpMessage as HttpMessage2, + HttpService, KeepAlive, Request, Response, }; #[test] @@ -152,7 +153,7 @@ fn test_slow_request() { let srv = TestServer::new(|| { HttpService::build() .client_timeout(100) - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -341,67 +342,66 @@ fn test_content_length() { } } -// TODO: fix -// #[test] -// fn test_h2_content_length() { -// use actix_http::http::{ -// header::{HeaderName, HeaderValue}, -// StatusCode, -// }; -// let openssl = ssl_acceptor().unwrap(); +#[test] +fn test_h2_content_length() { + use actix_http::http::{ + header::{HeaderName, HeaderValue}, + StatusCode, + }; + let openssl = ssl_acceptor().unwrap(); -// let mut srv = TestServer::new(move || { -// openssl -// .clone() -// .map_err(|e| println!("Openssl error: {}", e)) -// .and_then( -// HttpService::build() -// .h2(|req: Request| { -// let indx: usize = req.uri().path()[1..].parse().unwrap(); -// let statuses = [ -// StatusCode::NO_CONTENT, -// StatusCode::CONTINUE, -// StatusCode::SWITCHING_PROTOCOLS, -// StatusCode::PROCESSING, -// StatusCode::OK, -// StatusCode::NOT_FOUND, -// ]; -// future::ok::<_, ()>(Response::new(statuses[indx])) -// }) -// .map_err(|_| ()), -// ) -// }); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + .map_err(|_| ()), + ) + }); -// let header = HeaderName::from_static("content-length"); -// let value = HeaderValue::from_static("0"); + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); -// { -// for i in 0..4 { -// let req = client::ClientRequest::get(srv.surl(&format!("/{}", i))) -// .finish() -// .unwrap(); -// let response = srv.send_request(req).unwrap(); -// assert_eq!(response.headers().get(&header), None); + { + for i in 0..4 { + let req = client::ClientRequest::get(srv.surl(&format!("/{}", i))) + .finish() + .unwrap(); + let response = srv.send_request(req).unwrap(); + assert_eq!(response.headers().get(&header), None); -// let req = client::ClientRequest::head(srv.surl(&format!("/{}", i))) -// .finish() -// .unwrap(); -// let response = srv.send_request(req).unwrap(); -// assert_eq!(response.headers().get(&header), None); -// } + let req = client::ClientRequest::head(srv.surl(&format!("/{}", i))) + .finish() + .unwrap(); + let response = srv.send_request(req).unwrap(); + assert_eq!(response.headers().get(&header), None); + } -// for i in 4..6 { -// let req = client::ClientRequest::get(srv.surl(&format!("/{}", i))) -// .finish() -// .unwrap(); -// let response = srv.send_request(req).unwrap(); -// assert_eq!(response.headers().get(&header), Some(&value)); -// } -// } -// } + for i in 4..6 { + let req = client::ClientRequest::get(srv.surl(&format!("/{}", i))) + .finish() + .unwrap(); + let response = srv.send_request(req).unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} #[test] -fn test_headers() { +fn test_h1_headers() { let data = STR.repeat(10); let data2 = data.clone(); @@ -511,7 +511,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World"; #[test] -fn test_body() { +fn test_h1_body() { let mut srv = TestServer::new(|| { HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().body(STR))) }); @@ -526,7 +526,30 @@ fn test_body() { } #[test] -fn test_head_empty() { +fn test_h2_body2() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let req = srv.sget().finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h1_head_empty() { let mut srv = TestServer::new(|| { HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) }); @@ -549,7 +572,39 @@ fn test_head_empty() { } #[test] -fn test_head_binary() { +fn test_h2_head_empty() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let req = client::ClientRequest::head(srv.surl("/")).finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), http::Version::HTTP_2); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h1_head_binary() { let mut srv = TestServer::new(|| { HttpService::build().h1(|_| { ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) @@ -574,7 +629,42 @@ fn test_head_binary() { } #[test] -fn test_head_binary2() { +fn test_h2_head_binary() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| { + ok::<_, ()>( + Response::Ok().content_length(STR.len() as u64).body(STR), + ) + }) + .map_err(|_| ()), + ) + }); + + let req = client::ClientRequest::head(srv.surl("/")).finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h1_head_binary2() { let mut srv = TestServer::new(|| { HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) }); @@ -593,7 +683,34 @@ fn test_head_binary2() { } #[test] -fn test_body_length() { +fn test_h2_head_binary2() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let req = client::ClientRequest::head(srv.surl("/")).finish().unwrap(); + let response = srv.send_request(req).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[test] +fn test_h1_body_length() { let mut srv = TestServer::new(|| { HttpService::build().h1(|_| { let body = once(Ok(Bytes::from_static(STR.as_ref()))); @@ -614,17 +731,58 @@ fn test_body_length() { } #[test] -fn test_body_chunked_explicit() { +fn test_h2_body_length() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>(Response::Ok().body(Body::from_message( + body::SizedStream::new(STR.len(), body), + ))) + }) + .map_err(|_| ()), + ) + }); + + let req = srv.sget().finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h1_body_chunked_explicit() { let mut srv = TestServer::new(|| { HttpService::build().h1(|_| { let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>(Response::Ok().streaming(body)) + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) }) }); let req = srv.get().finish().unwrap(); let mut response = srv.send_request(req).unwrap(); assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); // read response let bytes = srv.block_on(response.body()).unwrap(); @@ -634,7 +792,41 @@ fn test_body_chunked_explicit() { } #[test] -fn test_body_chunked_implicit() { +fn test_h2_body_chunked_explicit() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| { + let body = + once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .map_err(|_| ()), + ) + }); + + let req = srv.sget().finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h1_body_chunked_implicit() { let mut srv = TestServer::new(|| { HttpService::build().h1(|_| { let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); @@ -645,17 +837,23 @@ fn test_body_chunked_implicit() { let req = srv.get().finish().unwrap(); let mut response = srv.send_request(req).unwrap(); assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); // read response let bytes = srv.block_on(response.body()).unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -use actix_server_config::ServerConfig; -use actix_service::fn_cfg_factory; - #[test] -fn test_response_http_error_handling() { +fn test_h1_response_http_error_handling() { let mut srv = TestServer::new(|| { HttpService::build().h1(fn_cfg_factory(|_: &ServerConfig| { Ok::<_, ()>(|_| { @@ -677,3 +875,76 @@ fn test_response_http_error_handling() { let bytes = srv.block_on(response.body()).unwrap(); assert!(bytes.is_empty()); } + +#[test] +fn test_h2_response_http_error_handling() { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(fn_cfg_factory(|_: &ServerConfig| { + Ok::<_, ()>(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + }) + })) + .map_err(|_| ()), + ) + }); + + let req = srv.sget().finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h1_service_error() { + let mut srv = TestServer::new(|| { + HttpService::build() + .h1(|_| Err::(error::ErrorBadRequest("error"))) + }); + + let req = srv.get().finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h2_service_error() { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| Err::(error::ErrorBadRequest("error"))) + .map_err(|_| ()), + ) + }); + + let req = srv.sget().finish().unwrap(); + let mut response = srv.send_request(req).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(response.body()).unwrap(); + assert!(bytes.is_empty()); +}