diff --git a/Cargo.toml b/Cargo.toml index a0eb17312..9d55b85c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,8 @@ fail = ["failure"] [dependencies] #actix-service = "0.3.2" -actix-codec = "0.1.0" +actix-codec = "0.1.1" + #actix-connector = "0.3.0" #actix-utils = "0.3.1" diff --git a/src/error.rs b/src/error.rs index a3037b985..696162f86 100644 --- a/src/error.rs +++ b/src/error.rs @@ -328,12 +328,11 @@ impl ResponseError for cookie::ParseError { } } -#[derive(Debug, Display)] +#[derive(Debug, Display, From)] /// A set of errors that can occur during dispatching http requests -pub enum DispatchError { +pub enum DispatchError { /// Service error - #[display(fmt = "Service specific error: {:?}", _0)] - Service(E), + Service, /// An `io::Error` that occurred while trying to read or write to a network /// stream. @@ -373,24 +372,6 @@ pub enum DispatchError { Unknown, } -impl From for DispatchError { - fn from(err: ParseError) -> Self { - DispatchError::Parse(err) - } -} - -impl From for DispatchError { - fn from(err: io::Error) -> Self { - DispatchError::Io(err) - } -} - -impl From for DispatchError { - fn from(err: h2::Error) -> Self { - DispatchError::H2(err) - } -} - /// A set of error that can occure during parsing content type #[derive(PartialEq, Debug, Display)] pub enum ContentTypeError { diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 7024ce3af..42ab33e79 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -20,7 +20,7 @@ use crate::response::Response; use super::codec::Codec; use super::payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; -use super::{H1ServiceResult, Message, MessageType}; +use super::{Message, MessageType}; const MAX_PIPELINED_MESSAGES: usize = 16; @@ -50,7 +50,7 @@ where service: CloneableService, flags: Flags, framed: Framed, - error: Option>, + error: Option, config: ServiceConfig, state: State, @@ -93,12 +93,17 @@ where { /// Create http/1 dispatcher. pub fn new(stream: T, config: ServiceConfig, service: CloneableService) -> Self { - Dispatcher::with_timeout(stream, config, None, service) + Dispatcher::with_timeout( + Framed::new(stream, Codec::new(config.clone())), + config, + None, + service, + ) } /// Create http/1 dispatcher with slow request timeout. pub fn with_timeout( - stream: T, + framed: Framed, config: ServiceConfig, timeout: Option, service: CloneableService, @@ -109,7 +114,6 @@ where } else { Flags::empty() }; - let framed = Framed::new(stream, Codec::new(config.clone())); // keep-alive timer let (ka_expire, ka_timer) = if let Some(delay) = timeout { @@ -167,7 +171,7 @@ where } /// Flush stream - fn poll_flush(&mut self) -> Poll> { + fn poll_flush(&mut self) -> Poll { if !self.framed.is_write_buf_empty() { match self.framed.poll_complete() { Ok(Async::NotReady) => Ok(Async::NotReady), @@ -192,7 +196,7 @@ where &mut self, message: Response<()>, body: ResponseBody, - ) -> Result, DispatchError> { + ) -> Result, DispatchError> { self.framed .force_send(Message::Item((message, body.length()))) .map_err(|err| { @@ -210,7 +214,7 @@ where } } - fn poll_response(&mut self) -> Result<(), DispatchError> { + fn poll_response(&mut self) -> Result<(), DispatchError> { let mut retry = self.can_read(); loop { let state = match mem::replace(&mut self.state, State::None) { @@ -225,7 +229,7 @@ where None => None, }, State::ServiceCall(mut fut) => { - match fut.poll().map_err(DispatchError::Service)? { + match fut.poll().map_err(|_| DispatchError::Service)? { Async::Ready(res) => { let (res, body) = res.into().replace_body(()); Some(self.send_response(res, body)?) @@ -283,12 +287,9 @@ where Ok(()) } - fn handle_request( - &mut self, - req: Request, - ) -> Result, DispatchError> { + fn handle_request(&mut self, req: Request) -> Result, DispatchError> { let mut task = self.service.call(req); - match task.poll().map_err(DispatchError::Service)? { + match task.poll().map_err(|_| DispatchError::Service)? { Async::Ready(res) => { let (res, body) = res.into().replace_body(()); self.send_response(res, body) @@ -298,7 +299,7 @@ where } /// Process one incoming requests - pub(self) fn poll_request(&mut self) -> Result> { + pub(self) fn poll_request(&mut self) -> Result { // limit a mount of non processed requests if self.messages.len() >= MAX_PIPELINED_MESSAGES { return Ok(false); @@ -400,7 +401,7 @@ where } /// keep-alive timer - fn poll_keepalive(&mut self) -> Result<(), DispatchError> { + fn poll_keepalive(&mut self) -> Result<(), DispatchError> { if self.ka_timer.is_none() { return Ok(()); } @@ -469,8 +470,8 @@ where S::Response: Into>, B: MessageBody, { - type Item = H1ServiceResult; - type Error = DispatchError; + type Item = (); + type Error = DispatchError; #[inline] fn poll(&mut self) -> Poll { @@ -490,7 +491,7 @@ where } if inner.flags.contains(Flags::DISCONNECTED) { - return Ok(Async::Ready(H1ServiceResult::Disconnected)); + return Ok(Async::Ready(())); } // keep-alive and stream errors @@ -523,14 +524,12 @@ where }; let mut inner = self.inner.take().unwrap(); - if shutdown { - Ok(Async::Ready(H1ServiceResult::Shutdown( - inner.framed.into_inner(), - ))) - } else { - let req = inner.unhandled.take().unwrap(); - Ok(Async::Ready(H1ServiceResult::Unhandled(req, inner.framed))) - } + + // TODO: shutdown + Ok(Async::Ready(())) + //Ok(Async::Ready(HttpServiceResult::Shutdown( + // inner.framed.into_inner(), + //))) } } diff --git a/src/h1/mod.rs b/src/h1/mod.rs index 9054c2665..e3d63c521 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -1,7 +1,4 @@ //! HTTP/1 implementation -use std::fmt; - -use actix_codec::Framed; use bytes::Bytes; mod client; @@ -18,29 +15,6 @@ pub use self::dispatcher::Dispatcher; pub use self::payload::{Payload, PayloadBuffer}; pub use self::service::{H1Service, H1ServiceHandler, OneRequest}; -use crate::request::Request; - -/// H1 service response type -pub enum H1ServiceResult { - Disconnected, - Shutdown(T), - Unhandled(Request, Framed), -} - -impl fmt::Debug for H1ServiceResult { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - H1ServiceResult::Disconnected => write!(f, "H1ServiceResult::Disconnected"), - H1ServiceResult::Shutdown(ref v) => { - write!(f, "H1ServiceResult::Shutdown({:?})", v) - } - H1ServiceResult::Unhandled(ref req, _) => { - write!(f, "H1ServiceResult::Unhandled({:?})", req) - } - } - } -} - #[derive(Debug)] /// Codec message pub enum Message { @@ -67,6 +41,7 @@ pub enum MessageType { #[cfg(test)] mod tests { use super::*; + use crate::request::Request; impl Message { pub fn message(self) -> Request { diff --git a/src/h1/service.rs b/src/h1/service.rs index 1c4f1ae3e..229229e69 100644 --- a/src/h1/service.rs +++ b/src/h1/service.rs @@ -17,7 +17,7 @@ use crate::response::Response; use super::codec::Codec; use super::dispatcher::Dispatcher; -use super::{H1ServiceResult, Message}; +use super::Message; /// `NewService` implementation for HTTP1 transport pub struct H1Service { @@ -72,8 +72,8 @@ where S::Service: 'static, B: MessageBody, { - type Response = H1ServiceResult; - type Error = DispatchError; + type Response = (); + type Error = DispatchError; type InitError = S::InitError; type Service = H1ServiceHandler; type Future = H1ServiceResponse; @@ -275,12 +275,15 @@ where S::Response: Into>, B: MessageBody, { - type Response = H1ServiceResult; - type Error = DispatchError; + type Response = (); + type Error = DispatchError; type Future = Dispatcher; fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.srv.poll_ready().map_err(DispatchError::Service) + self.srv.poll_ready().map_err(|e| { + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service + }) } fn call(&mut self, req: T) -> Self::Future { diff --git a/src/h2/dispatcher.rs b/src/h2/dispatcher.rs index 7f5409c68..be813bd57 100644 --- a/src/h2/dispatcher.rs +++ b/src/h2/dispatcher.rs @@ -26,8 +26,6 @@ use crate::payload::Payload; use crate::request::Request; use crate::response::Response; -use super::H2ServiceResult; - const CHUNK_SIZE: usize = 16_384; bitflags! { @@ -40,7 +38,7 @@ bitflags! { /// Dispatcher for HTTP/2 protocol pub struct Dispatcher< T: AsyncRead + AsyncWrite, - S: Service> + 'static, + S: Service + 'static, B: MessageBody, > { flags: Flags, @@ -55,8 +53,8 @@ pub struct Dispatcher< impl Dispatcher where T: AsyncRead + AsyncWrite, - S: Service> + 'static, - S::Error: Into + fmt::Debug, + S: Service + 'static, + S::Error: fmt::Debug, S::Response: Into>, B: MessageBody + 'static, { @@ -97,13 +95,13 @@ where impl Future for Dispatcher where T: AsyncRead + AsyncWrite, - S: Service> + 'static, - S::Error: Into + fmt::Debug, + S: Service + 'static, + S::Error: fmt::Debug, S::Response: Into>, B: MessageBody + 'static, { type Item = (); - type Error = DispatchError<()>; + type Error = DispatchError; #[inline] fn poll(&mut self) -> Poll { @@ -143,21 +141,21 @@ where } } -struct ServiceResponse>, B> { +struct ServiceResponse, B> { state: ServiceResponseState, config: ServiceConfig, buffer: Option, } -enum ServiceResponseState>, B> { +enum ServiceResponseState, B> { ServiceCall(S::Future, Option>), SendPayload(SendStream, ResponseBody), } impl ServiceResponse where - S: Service> + 'static, - S::Error: Into + fmt::Debug, + S: Service + 'static, + S::Error: fmt::Debug, S::Response: Into>, B: MessageBody + 'static, { @@ -224,8 +222,8 @@ where impl Future for ServiceResponse where - S: Service> + 'static, - S::Error: Into + fmt::Debug, + S: Service + 'static, + S::Error: fmt::Debug, S::Response: Into>, B: MessageBody + 'static, { @@ -258,7 +256,7 @@ where } Ok(Async::NotReady) => Ok(Async::NotReady), Err(e) => { - let res: Response = e.into().into(); + let res: Response = Response::InternalServerError().finish(); let (res, body) = res.replace_body(()); let mut send = send.take().unwrap(); diff --git a/src/h2/mod.rs b/src/h2/mod.rs index 55e057607..c5972123f 100644 --- a/src/h2/mod.rs +++ b/src/h2/mod.rs @@ -9,26 +9,10 @@ use h2::RecvStream; mod dispatcher; mod service; +pub use self::dispatcher::Dispatcher; pub use self::service::H2Service; use crate::error::PayloadError; -/// H1 service response type -pub enum H2ServiceResult { - Disconnected, - Shutdown(T), -} - -impl fmt::Debug for H2ServiceResult { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - H2ServiceResult::Disconnected => write!(f, "H2ServiceResult::Disconnected"), - H2ServiceResult::Shutdown(ref v) => { - write!(f, "H2ServiceResult::Shutdown({:?})", v) - } - } - } -} - /// H2 receive stream pub struct Payload { pl: RecvStream, diff --git a/src/h2/service.rs b/src/h2/service.rs index e225e9fcb..a47f507b5 100644 --- a/src/h2/service.rs +++ b/src/h2/service.rs @@ -20,7 +20,6 @@ use crate::request::Request; use crate::response::Response; use super::dispatcher::Dispatcher; -use super::H2ServiceResult; /// `NewService` implementation for HTTP2 transport pub struct H2Service { @@ -31,14 +30,14 @@ pub struct H2Service { impl H2Service where - S: NewService>, + S: NewService, S::Service: 'static, - S::Error: Into + Debug + 'static, + S::Error: Debug + 'static, S::Response: Into>, B: MessageBody + 'static, { /// Create new `HttpService` instance. - pub fn new>>(service: F) -> Self { + pub fn new>(service: F) -> Self { let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); H2Service { @@ -57,14 +56,14 @@ where impl NewService for H2Service where T: AsyncRead + AsyncWrite, - S: NewService>, + S: NewService, S::Service: 'static, - S::Error: Into + Debug, + S::Error: Debug, S::Response: Into>, B: MessageBody + 'static, { type Response = (); - type Error = DispatchError<()>; + type Error = DispatchError; type InitError = S::InitError; type Service = H2ServiceHandler; type Future = H2ServiceResponse; @@ -94,9 +93,9 @@ pub struct H2ServiceBuilder { impl H2ServiceBuilder where - S: NewService>, + S: NewService, S::Service: 'static, - S::Error: Into + Debug + 'static, + S::Error: Debug + 'static, { /// Create instance of `H2ServiceBuilder` pub fn new() -> H2ServiceBuilder { @@ -189,30 +188,11 @@ where self } - // #[cfg(feature = "ssl")] - // /// Configure alpn protocols for SslAcceptorBuilder. - // pub fn configure_openssl( - // builder: &mut openssl::ssl::SslAcceptorBuilder, - // ) -> io::Result<()> { - // let protos: &[u8] = b"\x02h2"; - // builder.set_alpn_select_callback(|_, protos| { - // const H2: &[u8] = b"\x02h2"; - // if protos.windows(3).any(|window| window == H2) { - // Ok(b"h2") - // } else { - // Err(openssl::ssl::AlpnError::NOACK) - // } - // }); - // builder.set_alpn_protos(&protos)?; - - // Ok(()) - // } - /// Finish service configuration and create `H1Service` instance. pub fn finish(self, service: F) -> H2Service where B: MessageBody, - F: IntoNewService>, + F: IntoNewService, { let cfg = ServiceConfig::new( self.keep_alive, @@ -228,7 +208,7 @@ where } #[doc(hidden)] -pub struct H2ServiceResponse>, B> { +pub struct H2ServiceResponse, B> { fut: ::Future, cfg: Option, _t: PhantomData<(T, B)>, @@ -237,10 +217,10 @@ pub struct H2ServiceResponse>, B> { impl Future for H2ServiceResponse where T: AsyncRead + AsyncWrite, - S: NewService>, + S: NewService, S::Service: 'static, S::Response: Into>, - S::Error: Into + Debug, + S::Error: Debug, B: MessageBody + 'static, { type Item = H2ServiceHandler; @@ -264,8 +244,8 @@ pub struct H2ServiceHandler { impl H2ServiceHandler where - S: Service> + 'static, - S::Error: Into + Debug, + S: Service + 'static, + S::Error: Debug, S::Response: Into>, B: MessageBody + 'static, { @@ -281,19 +261,19 @@ where impl Service for H2ServiceHandler where T: AsyncRead + AsyncWrite, - S: Service> + 'static, - S::Error: Into + Debug, + S: Service + 'static, + S::Error: Debug, S::Response: Into>, B: MessageBody + 'static, { type Response = (); - type Error = DispatchError<()>; + type Error = DispatchError; type Future = H2ServiceHandlerResponse; fn poll_ready(&mut self) -> Poll<(), Self::Error> { self.srv.poll_ready().map_err(|e| { error!("Service readiness error: {:?}", e); - DispatchError::Service(()) + DispatchError::Service }) } @@ -308,11 +288,7 @@ where } } -enum State< - T: AsyncRead + AsyncWrite, - S: Service> + 'static, - B: MessageBody, -> { +enum State + 'static, B: MessageBody> { Incoming(Dispatcher), Handshake( Option>, @@ -324,8 +300,8 @@ enum State< pub struct H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite, - S: Service> + 'static, - S::Error: Into + Debug, + S: Service + 'static, + S::Error: Debug, S::Response: Into>, B: MessageBody + 'static, { @@ -335,13 +311,13 @@ where impl Future for H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite, - S: Service> + 'static, - S::Error: Into + Debug, + S: Service + 'static, + S::Error: Debug, S::Response: Into>, B: MessageBody, { type Item = (); - type Error = DispatchError<()>; + type Error = DispatchError; fn poll(&mut self) -> Poll { match self.state { diff --git a/src/lib.rs b/src/lib.rs index 8750b24ca..74a46fd17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,6 +97,7 @@ pub use self::message::{Head, Message, RequestHead, ResponseHead}; pub use self::payload::{Payload, PayloadStream}; pub use self::request::Request; pub use self::response::Response; +pub use self::service::HttpService; pub use self::service::{SendError, SendResponse}; pub mod dev { diff --git a/src/service/mod.rs b/src/service/mod.rs index 83a40bd12..25e95bf60 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,3 +1,5 @@ mod senderror; +mod service; pub use self::senderror::{SendError, SendResponse}; +pub use self::service::HttpService; diff --git a/src/service/service.rs b/src/service/service.rs new file mode 100644 index 000000000..5f6ee8190 --- /dev/null +++ b/src/service/service.rs @@ -0,0 +1,446 @@ +use std::fmt::Debug; +use std::marker::PhantomData; +use std::{fmt, io, net}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed, FramedParts}; +use actix_service::{IntoNewService, NewService, Service}; +use actix_utils::cloneable::CloneableService; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use h2::server::{self, Handshake}; +use log::error; + +use crate::body::MessageBody; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::DispatchError; +use crate::request::Request; +use crate::response::Response; + +use crate::{h1, h2::Dispatcher}; + +/// `NewService` HTTP1.1/HTTP2 transport implementation +pub struct HttpService { + srv: S, + cfg: ServiceConfig, + _t: PhantomData<(T, B)>, +} + +impl HttpService +where + S: NewService, + S::Service: 'static, + S::Error: Debug + 'static, + S::Response: Into>, + B: MessageBody + 'static, +{ + /// Create new `HttpService` instance. + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + + HttpService { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } + + /// Create builder for `HttpService` instance. + pub fn build() -> HttpServiceBuilder { + HttpServiceBuilder::new() + } +} + +impl NewService for HttpService +where + T: AsyncRead + AsyncWrite + 'static, + S: NewService, + S::Service: 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Response = (); + type Error = DispatchError; + type InitError = S::InitError; + type Service = HttpServiceHandler; + type Future = HttpServiceResponse; + + fn new_service(&self, _: &()) -> Self::Future { + HttpServiceResponse { + fut: self.srv.new_service(&()).into_future(), + cfg: Some(self.cfg.clone()), + _t: PhantomData, + } + } +} + +/// A http service factory builder +/// +/// This type can be used to construct an instance of `ServiceConfig` through a +/// builder-like pattern. +pub struct HttpServiceBuilder { + keep_alive: KeepAlive, + client_timeout: u64, + client_disconnect: u64, + host: String, + addr: net::SocketAddr, + secure: bool, + _t: PhantomData<(T, S)>, +} + +impl HttpServiceBuilder +where + S: NewService, + S::Service: 'static, + S::Error: Debug + 'static, +{ + /// Create instance of `HttpServiceBuilder` type + pub fn new() -> HttpServiceBuilder { + HttpServiceBuilder { + keep_alive: KeepAlive::Timeout(5), + client_timeout: 5000, + client_disconnect: 0, + secure: false, + host: "localhost".to_owned(), + addr: "127.0.0.1:8080".parse().unwrap(), + _t: PhantomData, + } + } + + /// Enable secure flag for current server. + /// This flags also enables `client disconnect timeout`. + /// + /// By default this flag is set to false. + pub fn secure(mut self) -> Self { + self.secure = true; + if self.client_disconnect == 0 { + self.client_disconnect = 3000; + } + self + } + + /// Set server keep-alive setting. + /// + /// By default keep alive is set to a 5 seconds. + pub fn keep_alive>(mut self, val: U) -> Self { + self.keep_alive = val.into(); + self + } + + /// Set server client timeout in milliseconds for first request. + /// + /// Defines a timeout for reading client request header. If a client does not transmit + /// the entire set headers within this time, the request is terminated with + /// the 408 (Request Time-out) error. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_timeout(mut self, val: u64) -> Self { + self.client_timeout = val; + self + } + + /// Set server connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the request get dropped. This timeout affects secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 3000 milliseconds. + pub fn client_disconnect(mut self, val: u64) -> Self { + self.client_disconnect = val; + self + } + + /// Set server host name. + /// + /// Host name is used by application router aa a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + /// + /// By default host name is set to a "localhost" value. + pub fn server_hostname(mut self, val: &str) -> Self { + self.host = val.to_owned(); + self + } + + /// Set server ip address. + /// + /// Host name is used by application router aa a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + /// + /// By default server address is set to a "127.0.0.1:8080" + pub fn server_address(mut self, addr: U) -> Self { + match addr.to_socket_addrs() { + Err(err) => error!("Can not convert to SocketAddr: {}", err), + Ok(mut addrs) => { + if let Some(addr) = addrs.next() { + self.addr = addr; + } + } + } + self + } + + // #[cfg(feature = "ssl")] + // /// Configure alpn protocols for SslAcceptorBuilder. + // pub fn configure_openssl( + // builder: &mut openssl::ssl::SslAcceptorBuilder, + // ) -> io::Result<()> { + // let protos: &[u8] = b"\x02h2"; + // builder.set_alpn_select_callback(|_, protos| { + // const H2: &[u8] = b"\x02h2"; + // if protos.windows(3).any(|window| window == H2) { + // Ok(b"h2") + // } else { + // Err(openssl::ssl::AlpnError::NOACK) + // } + // }); + // builder.set_alpn_protos(&protos)?; + + // Ok(()) + // } + + /// Finish service configuration and create `HttpService` instance. + pub fn finish(self, service: F) -> HttpService + where + B: MessageBody, + F: IntoNewService, + { + let cfg = ServiceConfig::new( + self.keep_alive, + self.client_timeout, + self.client_disconnect, + ); + HttpService { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +pub struct HttpServiceResponse, B> { + fut: ::Future, + cfg: Option, + _t: PhantomData<(T, B)>, +} + +impl Future for HttpServiceResponse +where + T: AsyncRead + AsyncWrite, + S: NewService, + S::Service: 'static, + S::Response: Into>, + S::Error: Debug, + B: MessageBody + 'static, +{ + type Item = HttpServiceHandler; + type Error = S::InitError; + + fn poll(&mut self) -> Poll { + let service = try_ready!(self.fut.poll()); + Ok(Async::Ready(HttpServiceHandler::new( + self.cfg.take().unwrap(), + service, + ))) + } +} + +/// `Service` implementation for http transport +pub struct HttpServiceHandler { + srv: CloneableService, + cfg: ServiceConfig, + _t: PhantomData<(T, B)>, +} + +impl HttpServiceHandler +where + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + fn new(cfg: ServiceConfig, srv: S) -> HttpServiceHandler { + HttpServiceHandler { + cfg, + srv: CloneableService::new(srv), + _t: PhantomData, + } + } +} + +impl Service for HttpServiceHandler +where + T: AsyncRead + AsyncWrite + 'static, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Response = (); + type Error = DispatchError; + type Future = HttpServiceHandlerResponse; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.srv.poll_ready().map_err(|e| { + error!("Service readiness error: {:?}", e); + DispatchError::Service + }) + } + + fn call(&mut self, req: T) -> Self::Future { + HttpServiceHandlerResponse { + state: State::Unknown(Some(( + req, + BytesMut::with_capacity(14), + self.cfg.clone(), + self.srv.clone(), + ))), + } + } +} + +enum State + 'static, B: MessageBody> +where + S::Error: fmt::Debug, + T: AsyncRead + AsyncWrite + 'static, +{ + H1(h1::Dispatcher), + H2(Dispatcher, S, B>), + Unknown(Option<(T, BytesMut, ServiceConfig, CloneableService)>), + Handshake(Option<(Handshake, Bytes>, ServiceConfig, CloneableService)>), +} + +pub struct HttpServiceHandlerResponse +where + T: AsyncRead + AsyncWrite + 'static, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + state: State, +} + +const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; + +impl Future for HttpServiceHandlerResponse +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + type Item = (); + type Error = DispatchError; + + fn poll(&mut self) -> Poll { + match self.state { + State::H1(ref mut disp) => disp.poll(), + State::H2(ref mut disp) => disp.poll(), + State::Unknown(ref mut data) => { + if let Some(ref mut item) = data { + loop { + unsafe { + let b = item.1.bytes_mut(); + let n = { try_ready!(item.0.poll_read(b)) }; + item.1.advance_mut(n); + if item.1.len() >= HTTP2_PREFACE.len() { + break; + } + } + } + } else { + panic!() + } + let (io, buf, cfg, srv) = data.take().unwrap(); + if buf[..14] == HTTP2_PREFACE[..] { + let io = Io { + inner: io, + unread: Some(buf), + }; + self.state = + State::Handshake(Some((server::handshake(io), cfg, srv))); + } else { + let framed = Framed::from_parts(FramedParts::with_read_buf( + io, + h1::Codec::new(cfg.clone()), + buf, + )); + self.state = + State::H1(h1::Dispatcher::with_timeout(framed, cfg, None, srv)) + } + self.poll() + } + State::Handshake(ref mut data) => { + let conn = if let Some(ref mut item) = data { + match item.0.poll() { + Ok(Async::Ready(conn)) => conn, + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => { + trace!("H2 handshake error: {}", err); + return Err(err.into()); + } + } + } else { + panic!() + }; + let (_, cfg, srv) = data.take().unwrap(); + self.state = State::H2(Dispatcher::new(srv, conn, cfg, None)); + self.poll() + } + } + } +} + +/// Wrapper for `AsyncRead + AsyncWrite` types +struct Io { + unread: Option, + inner: T, +} + +impl io::Read for Io { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if let Some(mut bytes) = self.unread.take() { + let size = std::cmp::min(buf.len(), bytes.len()); + buf[..size].copy_from_slice(&bytes[..size]); + if bytes.len() > size { + bytes.split_to(size); + self.unread = Some(bytes); + } + Ok(size) + } else { + self.inner.read(buf) + } + } +} + +impl io::Write for Io { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.write(buf) + } + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() + } +} + +impl AsyncRead for Io { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } +} + +impl AsyncWrite for Io { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.inner.shutdown() + } + fn write_buf(&mut self, buf: &mut B) -> Poll { + self.inner.write_buf(buf) + } +} diff --git a/test-server/src/lib.rs b/test-server/src/lib.rs index 3afee682f..695a477d2 100644 --- a/test-server/src/lib.rs +++ b/test-server/src/lib.rs @@ -73,7 +73,7 @@ impl TestServer { .start(); tx.send((System::current(), local_addr)).unwrap(); - sys.run(); + sys.run() }); let (system, addr) = rx.recv().unwrap(); diff --git a/tests/test_server.rs b/tests/test_server.rs index fd848b82b..4cffdd096 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -10,8 +10,8 @@ use futures::stream::once; use actix_http::body::Body; use actix_http::{ - body, client, h1, h2, http, Error, HttpMessage as HttpMessage2, KeepAlive, Request, - Response, + body, client, h1, h2, http, Error, HttpMessage as HttpMessage2, HttpService, + KeepAlive, Request, Response, }; #[test] @@ -31,6 +31,26 @@ fn test_h1() { assert!(response.status().is_success()); } +#[test] +fn test_h1_2() { + let mut srv = TestServer::new(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_disconnect(1000) + .server_hostname("localhost") + .finish(|req: Request| { + assert_eq!(req.version(), http::Version::HTTP_11); + future::ok::<_, ()>(Response::Ok().finish()) + }) + .map(|_| ()) + }); + + let req = client::ClientRequest::get(srv.url("/")).finish().unwrap(); + let response = srv.send_request(req).unwrap(); + assert!(response.status().is_success()); +} + #[cfg(feature = "ssl")] fn ssl_acceptor() -> std::io::Result> { use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; @@ -71,7 +91,30 @@ fn test_h2() -> std::io::Result<()> { let req = client::ClientRequest::get(srv.surl("/")).finish().unwrap(); let response = srv.send_request(req).unwrap(); - println!("RES: {:?}", response); + assert!(response.status().is_success()); + Ok(()) +} + +#[cfg(feature = "ssl")] +#[test] +fn test_h2_1() -> std::io::Result<()> { + let openssl = ssl_acceptor()?; + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .finish(|req: Request| { + assert_eq!(req.version(), http::Version::HTTP_2); + future::ok::<_, Error>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let req = client::ClientRequest::get(srv.surl("/")).finish().unwrap(); + let response = srv.send_request(req).unwrap(); assert!(response.status().is_success()); Ok(()) } @@ -79,9 +122,6 @@ fn test_h2() -> std::io::Result<()> { #[cfg(feature = "ssl")] #[test] fn test_h2_body() -> std::io::Result<()> { - // std::env::set_var("RUST_LOG", "actix_http=trace"); - // env_logger::init(); - let data = "HELLOWORLD".to_owned().repeat(64 * 1024); let openssl = ssl_acceptor()?; let mut srv = TestServer::new(move || {