From a7fdac1043a0a13985e46a5935c9eebd2834e4f4 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 8 Apr 2019 10:31:29 -0700 Subject: [PATCH] fix expect service registration and tests --- actix-http/src/builder.rs | 25 +++++++++++++-- actix-http/tests/test_server.rs | 56 ++++++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index 2a8a8360f..6d93c156f 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -87,6 +87,27 @@ where self } + /// Provide service for `EXPECT: 100-Continue` support. + /// + /// Service get called with request that contains `EXPECT` header. + /// Service must return request in case of success, in that case + /// request will be forwarded to main service. + pub fn expect(self, expect: F) -> HttpServiceBuilder + where + F: IntoNewService, + U: NewService, + U::Error: Into, + U::InitError: fmt::Debug, + { + HttpServiceBuilder { + keep_alive: self.keep_alive, + client_timeout: self.client_timeout, + client_disconnect: self.client_disconnect, + expect: expect.into_new_service(), + _t: PhantomData, + } + } + // #[cfg(feature = "ssl")] // /// Configure alpn protocols for SslAcceptorBuilder. // pub fn configure_openssl( @@ -142,7 +163,7 @@ where } /// Finish service configuration and create `HttpService` instance. - pub fn finish(self, service: F) -> HttpService + pub fn finish(self, service: F) -> HttpService where B: MessageBody + 'static, F: IntoNewService, @@ -156,6 +177,6 @@ where self.client_timeout, self.client_disconnect, ); - HttpService::with_config(cfg, service.into_new_service()) + HttpService::with_config(cfg, service.into_new_service()).expect(self.expect) } } diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index da41492f2..e7b539372 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -5,7 +5,7 @@ use std::{net, thread}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_http_test::TestServer; use actix_server_config::ServerConfig; -use actix_service::{fn_cfg_factory, NewService}; +use actix_service::{fn_cfg_factory, fn_service, NewService}; use bytes::{Bytes, BytesMut}; use futures::future::{self, ok, Future}; use futures::stream::{once, Stream}; @@ -153,6 +153,60 @@ fn test_h2_body() -> std::io::Result<()> { Ok(()) } +#[test] +fn test_expect_continue() { + let srv = TestServer::new(|| { + HttpService::build() + .expect(fn_service(|req: Request| { + if req.head().uri.query() == Some("yes=") { + Ok(req) + } else { + Err(error::ErrorPreconditionFailed("error")) + } + })) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); +} + +#[test] +fn test_expect_continue_h1() { + let srv = TestServer::new(|| { + HttpService::build() + .expect(fn_service(|req: Request| { + if req.head().uri.query() == Some("yes=") { + Ok(req) + } else { + Err(error::ErrorPreconditionFailed("error")) + } + })) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); +} + #[test] fn test_slow_request() { let srv = TestServer::new(|| {