diff --git a/Cargo.toml b/Cargo.toml index 12f98ac37..d70a65cf0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -105,6 +105,7 @@ futures = "0.1" futures-cpupool = "0.1" slab = "0.4" tokio = "0.1" +tokio-codec = "0.1" tokio-io = "0.1" tokio-tcp = "0.1" tokio-timer = "0.2" diff --git a/src/lib.rs b/src/lib.rs index 1ed408099..f494c05de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -115,6 +115,7 @@ extern crate parking_lot; extern crate rand; extern crate slab; extern crate tokio; +extern crate tokio_codec; extern crate tokio_current_thread; extern crate tokio_io; extern crate tokio_reactor; diff --git a/src/server/channel.rs b/src/server/channel.rs index cbbe1a95e..b1fef964e 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -9,6 +9,7 @@ use tokio_timer::Delay; use super::error::HttpDispatchError; use super::settings::ServiceConfig; use super::{h1, h2, HttpHandler, IoStream}; +use error::Error; use http::StatusCode; const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; @@ -90,7 +91,7 @@ where H: HttpHandler + 'static, { type Item = (); - type Error = HttpDispatchError; + type Error = HttpDispatchError; fn poll(&mut self) -> Poll { // keep-alive timer @@ -242,7 +243,7 @@ where H: HttpHandler + 'static, { type Item = (); - type Error = HttpDispatchError; + type Error = HttpDispatchError; fn poll(&mut self) -> Poll { if !self.node_reg { diff --git a/src/server/error.rs b/src/server/error.rs index 70f100998..3ae9a107b 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -1,11 +1,12 @@ +use std::fmt::{Debug, Display}; use std::io; use futures::{Async, Poll}; use http2; use super::{helpers, HttpHandlerTask, Writer}; +use error::{Error, ParseError}; use http::{StatusCode, Version}; -use Error; /// Errors produced by `AcceptorError` service. #[derive(Debug)] @@ -20,60 +21,70 @@ pub enum AcceptorError { Timeout, } -#[derive(Fail, Debug)] +#[derive(Debug)] /// A set of errors that can occur during dispatching http requests -pub enum HttpDispatchError { +pub enum HttpDispatchError { /// Application error - #[fail(display = "Application specific error: {}", _0)] - App(Error), + // #[fail(display = "Application specific error: {}", _0)] + App(E), /// An `io::Error` that occurred while trying to read or write to a network /// stream. - #[fail(display = "IO error: {}", _0)] + // #[fail(display = "IO error: {}", _0)] Io(io::Error), + /// Http request parse error. + // #[fail(display = "Parse error: {}", _0)] + Parse(ParseError), + /// The first request did not complete within the specified timeout. - #[fail(display = "The first request did not complete within the specified timeout")] + // #[fail(display = "The first request did not complete within the specified timeout")] SlowRequestTimeout, /// Shutdown timeout - #[fail(display = "Connection shutdown timeout")] + // #[fail(display = "Connection shutdown timeout")] ShutdownTimeout, /// HTTP2 error - #[fail(display = "HTTP2 error: {}", _0)] + // #[fail(display = "HTTP2 error: {}", _0)] Http2(http2::Error), /// Payload is not consumed - #[fail(display = "Task is completed but request's payload is not consumed")] + // #[fail(display = "Task is completed but request's payload is not consumed")] PayloadIsNotConsumed, /// Malformed request - #[fail(display = "Malformed request")] + // #[fail(display = "Malformed request")] MalformedRequest, /// Internal error - #[fail(display = "Internal error")] + // #[fail(display = "Internal error")] InternalError, /// Unknown error - #[fail(display = "Unknown error")] + // #[fail(display = "Unknown error")] Unknown, } -impl From for HttpDispatchError { - fn from(err: Error) -> Self { - HttpDispatchError::App(err) +// impl From for HttpDispatchError { +// fn from(err: E) -> Self { +// HttpDispatchError::App(err) +// } +// } + +impl From for HttpDispatchError { + fn from(err: ParseError) -> Self { + HttpDispatchError::Parse(err) } } -impl From for HttpDispatchError { +impl From for HttpDispatchError { fn from(err: io::Error) -> Self { HttpDispatchError::Io(err) } } -impl From for HttpDispatchError { +impl From for HttpDispatchError { fn from(err: http2::Error) -> Self { HttpDispatchError::Http2(err) } diff --git a/src/server/h1.rs b/src/server/h1.rs index 4fb730f71..e2b4bf45e 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -7,15 +7,16 @@ use futures::{Async, Future, Poll}; use tokio_current_thread::spawn; use tokio_timer::Delay; -use error::{Error, PayloadError}; +use error::{Error, ParseError, PayloadError}; use http::{StatusCode, Version}; use payload::{Payload, PayloadStatus, PayloadWriter}; use super::error::{HttpDispatchError, ServerError}; -use super::h1decoder::{DecoderError, H1Decoder, Message}; +use super::h1decoder::{H1Decoder, Message}; use super::h1writer::H1Writer; use super::handler::{HttpHandler, HttpHandlerTask, HttpHandlerTaskFut}; use super::input::PayloadType; +use super::message::Request; use super::settings::ServiceConfig; use super::{IoStream, Writer}; @@ -44,7 +45,7 @@ pub struct Http1Dispatcher { payload: Option, buf: BytesMut, tasks: VecDeque>, - error: Option, + error: Option>, ka_expire: Instant, ka_timer: Option, } @@ -109,7 +110,7 @@ where Http1Dispatcher { stream: H1Writer::new(stream, settings.clone()), - decoder: H1Decoder::new(), + decoder: H1Decoder::new(settings.request_pool()), payload: None, tasks: VecDeque::new(), error: None, @@ -133,7 +134,7 @@ where let mut disp = Http1Dispatcher { flags: Flags::STARTED | Flags::READ_DISCONNECTED | Flags::FLUSHED, stream: H1Writer::new(stream, settings.clone()), - decoder: H1Decoder::new(), + decoder: H1Decoder::new(settings.request_pool()), payload: None, tasks: VecDeque::new(), error: None, @@ -201,7 +202,7 @@ where } #[inline] - pub fn poll(&mut self) -> Poll<(), HttpDispatchError> { + pub fn poll(&mut self) -> Poll<(), HttpDispatchError> { // check connection keep-alive self.poll_keep_alive()?; @@ -247,7 +248,7 @@ where } /// Flush stream - fn poll_flush(&mut self, shutdown: bool) -> Poll<(), HttpDispatchError> { + fn poll_flush(&mut self, shutdown: bool) -> Poll<(), HttpDispatchError> { if shutdown || self.flags.contains(Flags::STARTED) { match self.stream.poll_completed(shutdown) { Ok(Async::NotReady) => { @@ -277,7 +278,7 @@ where } /// keep-alive timer. returns `true` is keep-alive, otherwise drop - fn poll_keep_alive(&mut self) -> Result<(), HttpDispatchError> { + fn poll_keep_alive(&mut self) -> Result<(), HttpDispatchError> { if let Some(ref mut timer) = self.ka_timer { match timer.poll() { Ok(Async::Ready(_)) => { @@ -336,7 +337,7 @@ where #[inline] /// read data from the stream - pub(self) fn poll_io(&mut self) -> Result { + pub(self) fn poll_io(&mut self) -> Result> { if !self.flags.contains(Flags::POLLED) { self.flags.insert(Flags::POLLED); if !self.buf.is_empty() { @@ -367,7 +368,7 @@ where Ok(updated) } - pub(self) fn poll_handler(&mut self) -> Result<(), HttpDispatchError> { + pub(self) fn poll_handler(&mut self) -> Result<(), HttpDispatchError> { self.poll_io()?; let mut retry = self.can_read(); @@ -409,7 +410,7 @@ where // it is not possible to recover from error // during pipe handling, so just drop connection self.client_disconnected(false); - return Err(err.into()); + return Err(HttpDispatchError::App(err)); } } } @@ -423,7 +424,7 @@ where Ok(Async::NotReady) => false, Ok(Async::Ready(_)) => true, Err(err) => { - self.error = Some(err.into()); + self.error = Some(HttpDispatchError::App(err)); true } }; @@ -462,66 +463,75 @@ where .push_back(Entry::Error(ServerError::err(Version::HTTP_11, status))); } - pub(self) fn parse(&mut self) -> Result { + fn handle_message( + &mut self, mut msg: Request, payload: bool, + ) -> Result<(), HttpDispatchError> { + self.flags.insert(Flags::STARTED); + + if payload { + let (ps, pl) = Payload::new(false); + *msg.inner.payload.borrow_mut() = Some(pl); + self.payload = Some(PayloadType::new(&msg.inner.headers, ps)); + } + + // stream extensions + msg.inner_mut().stream_extensions = self.stream.get_mut().extensions(); + + // set remote addr + msg.inner_mut().addr = self.addr; + + // search handler for request + match self.settings.handler().handle(msg) { + Ok(mut task) => { + if self.tasks.is_empty() { + match task.poll_io(&mut self.stream) { + Ok(Async::Ready(ready)) => { + // override keep-alive state + if self.stream.keepalive() { + self.flags.insert(Flags::KEEPALIVE); + } else { + self.flags.remove(Flags::KEEPALIVE); + } + // prepare stream for next response + self.stream.reset(); + + if !ready { + // task is done with io operations + // but still needs to do more work + spawn(HttpHandlerTaskFut::new(task)); + } + } + Ok(Async::NotReady) => (), + Err(err) => { + error!("Unhandled error: {}", err); + self.client_disconnected(false); + return Err(HttpDispatchError::App(err)); + } + } + } else { + self.tasks.push_back(Entry::Task(task)); + } + } + Err(_) => { + // handler is not found + self.push_response_entry(StatusCode::NOT_FOUND); + } + } + Ok(()) + } + + pub(self) fn parse(&mut self) -> Result> { let mut updated = false; 'outer: loop { - match self.decoder.decode(&mut self.buf, &self.settings) { - Ok(Some(Message::Message { mut msg, payload })) => { + match self.decoder.decode(&mut self.buf) { + Ok(Some(Message::Message(msg))) => { updated = true; - self.flags.insert(Flags::STARTED); - - if payload { - let (ps, pl) = Payload::new(false); - *msg.inner.payload.borrow_mut() = Some(pl); - self.payload = Some(PayloadType::new(&msg.inner.headers, ps)); - } - - // stream extensions - msg.inner_mut().stream_extensions = - self.stream.get_mut().extensions(); - - // set remote addr - msg.inner_mut().addr = self.addr; - - // search handler for request - match self.settings.handler().handle(msg) { - Ok(mut task) => { - if self.tasks.is_empty() { - match task.poll_io(&mut self.stream) { - Ok(Async::Ready(ready)) => { - // override keep-alive state - if self.stream.keepalive() { - self.flags.insert(Flags::KEEPALIVE); - } else { - self.flags.remove(Flags::KEEPALIVE); - } - // prepare stream for next response - self.stream.reset(); - - if !ready { - // task is done with io operations - // but still needs to do more work - spawn(HttpHandlerTaskFut::new(task)); - } - continue 'outer; - } - Ok(Async::NotReady) => (), - Err(err) => { - error!("Unhandled error: {}", err); - self.client_disconnected(false); - return Err(err.into()); - } - } - } - self.tasks.push_back(Entry::Task(task)); - continue 'outer; - } - Err(_) => { - // handler is not found - self.push_response_entry(StatusCode::NOT_FOUND); - } - } + self.handle_message(msg, false)?; + } + Ok(Some(Message::MessageWithPayload(msg))) => { + updated = true; + self.handle_message(msg, true)?; } Ok(Some(Message::Chunk(chunk))) => { updated = true; @@ -556,8 +566,8 @@ where Err(e) => { if let Some(mut payload) = self.payload.take() { let e = match e { - DecoderError::Io(e) => PayloadError::Io(e), - DecoderError::Error(_) => PayloadError::EncodingCorrupted, + ParseError::Io(e) => PayloadError::Io(e), + _ => PayloadError::EncodingCorrupted, }; payload.set_error(e); } @@ -593,6 +603,7 @@ mod tests { use super::*; use application::{App, HttpApplication}; + use error::ParseError; use httpmessage::HttpMessage; use server::h1decoder::Message; use server::handler::IntoHttpHandler; @@ -612,13 +623,14 @@ mod tests { impl Message { fn message(self) -> Request { match self { - Message::Message { msg, payload: _ } => msg, + Message::Message(msg) => msg, + Message::MessageWithPayload(msg) => msg, _ => panic!("error"), } } fn is_payload(&self) -> bool { match *self { - Message::Message { msg: _, payload } => payload, + Message::MessageWithPayload(_) => true, _ => panic!("error"), } } @@ -639,7 +651,7 @@ mod tests { macro_rules! parse_ready { ($e:expr) => {{ let settings = wrk_settings(); - match H1Decoder::new().decode($e, &settings) { + match H1Decoder::new(settings.request_pool()).decode($e) { Ok(Some(msg)) => msg.message(), Ok(_) => unreachable!("Eof during parsing http request"), Err(err) => unreachable!("Error during parsing http request: {:?}", err), @@ -651,10 +663,10 @@ mod tests { ($e:expr) => {{ let settings = wrk_settings(); - match H1Decoder::new().decode($e, &settings) { + match H1Decoder::new(settings.request_pool()).decode($e) { Err(err) => match err { - DecoderError::Error(_) => (), - _ => unreachable!("Parse error expected"), + ParseError::Io(_) => unreachable!("Parse error expected"), + _ => (), }, _ => unreachable!("Error expected"), } @@ -747,8 +759,8 @@ mod tests { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - match reader.decode(&mut buf, &settings) { + let mut reader = H1Decoder::new(settings.request_pool()); + match reader.decode(&mut buf) { Ok(Some(msg)) => { let req = msg.message(); assert_eq!(req.version(), Version::HTTP_11); @@ -764,14 +776,14 @@ mod tests { let mut buf = BytesMut::from("PUT /test HTTP/1"); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - match reader.decode(&mut buf, &settings) { + let mut reader = H1Decoder::new(settings.request_pool()); + match reader.decode(&mut buf) { Ok(None) => (), _ => unreachable!("Error"), } buf.extend(b".1\r\n\r\n"); - match reader.decode(&mut buf, &settings) { + match reader.decode(&mut buf) { Ok(Some(msg)) => { let mut req = msg.message(); assert_eq!(req.version(), Version::HTTP_11); @@ -787,8 +799,8 @@ mod tests { let mut buf = BytesMut::from("POST /test2 HTTP/1.0\r\n\r\n"); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - match reader.decode(&mut buf, &settings) { + let mut reader = H1Decoder::new(settings.request_pool()); + match reader.decode(&mut buf) { Ok(Some(msg)) => { let mut req = msg.message(); assert_eq!(req.version(), Version::HTTP_10); @@ -805,20 +817,15 @@ mod tests { BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - match reader.decode(&mut buf, &settings) { + let mut reader = H1Decoder::new(settings.request_pool()); + match reader.decode(&mut buf) { Ok(Some(msg)) => { let mut req = msg.message(); assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); assert_eq!( - reader - .decode(&mut buf, &settings) - .unwrap() - .unwrap() - .chunk() - .as_ref(), + reader.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), b"body" ); } @@ -832,20 +839,15 @@ mod tests { BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - match reader.decode(&mut buf, &settings) { + let mut reader = H1Decoder::new(settings.request_pool()); + match reader.decode(&mut buf) { Ok(Some(msg)) => { let mut req = msg.message(); assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); assert_eq!( - reader - .decode(&mut buf, &settings) - .unwrap() - .unwrap() - .chunk() - .as_ref(), + reader.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), b"body" ); } @@ -857,11 +859,11 @@ mod tests { fn test_parse_partial_eof() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - assert!(reader.decode(&mut buf, &settings).unwrap().is_none()); + let mut reader = H1Decoder::new(settings.request_pool()); + assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b"\r\n"); - match reader.decode(&mut buf, &settings) { + match reader.decode(&mut buf) { Ok(Some(msg)) => { let req = msg.message(); assert_eq!(req.version(), Version::HTTP_11); @@ -877,17 +879,17 @@ mod tests { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - assert!{ reader.decode(&mut buf, &settings).unwrap().is_none() } + let mut reader = H1Decoder::new(settings.request_pool()); + assert!{ reader.decode(&mut buf).unwrap().is_none() } buf.extend(b"t"); - assert!{ reader.decode(&mut buf, &settings).unwrap().is_none() } + assert!{ reader.decode(&mut buf).unwrap().is_none() } buf.extend(b"es"); - assert!{ reader.decode(&mut buf, &settings).unwrap().is_none() } + assert!{ reader.decode(&mut buf).unwrap().is_none() } buf.extend(b"t: value\r\n\r\n"); - match reader.decode(&mut buf, &settings) { + match reader.decode(&mut buf) { Ok(Some(msg)) => { let req = msg.message(); assert_eq!(req.version(), Version::HTTP_11); @@ -907,8 +909,8 @@ mod tests { Set-Cookie: c2=cookie2\r\n\r\n", ); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let mut reader = H1Decoder::new(settings.request_pool()); + let msg = reader.decode(&mut buf).unwrap().unwrap(); let req = msg.message(); let val: Vec<_> = req @@ -1109,19 +1111,14 @@ mod tests { upgrade: websocket\r\n\r\n\ some raw data", ); - let mut reader = H1Decoder::new(); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let mut reader = H1Decoder::new(settings.request_pool()); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.is_payload()); let req = msg.message(); assert!(!req.keep_alive()); assert!(req.upgrade()); assert_eq!( - reader - .decode(&mut buf, &settings) - .unwrap() - .unwrap() - .chunk() - .as_ref(), + reader.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), b"some raw data" ); } @@ -1169,32 +1166,22 @@ mod tests { transfer-encoding: chunked\r\n\r\n", ); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let mut reader = H1Decoder::new(settings.request_pool()); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.is_payload()); let req = msg.message(); assert!(req.chunked().unwrap()); buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); assert_eq!( - reader - .decode(&mut buf, &settings) - .unwrap() - .unwrap() - .chunk() - .as_ref(), + reader.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), b"data" ); assert_eq!( - reader - .decode(&mut buf, &settings) - .unwrap() - .unwrap() - .chunk() - .as_ref(), + reader.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), b"line" ); - assert!(reader.decode(&mut buf, &settings).unwrap().unwrap().eof()); + assert!(reader.decode(&mut buf).unwrap().unwrap().eof()); } #[test] @@ -1204,8 +1191,8 @@ mod tests { transfer-encoding: chunked\r\n\r\n", ); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let mut reader = H1Decoder::new(settings.request_pool()); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.is_payload()); let req = msg.message(); assert!(req.chunked().unwrap()); @@ -1216,14 +1203,14 @@ mod tests { transfer-encoding: chunked\r\n\r\n" .iter(), ); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(msg.chunk().as_ref(), b"data"); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(msg.chunk().as_ref(), b"line"); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.eof()); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.is_payload()); let req2 = msg.message(); assert!(req2.chunked().unwrap()); @@ -1239,30 +1226,30 @@ mod tests { ); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let mut reader = H1Decoder::new(settings.request_pool()); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.is_payload()); let req = msg.message(); assert!(req.chunked().unwrap()); buf.extend(b"4\r\n1111\r\n"); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(msg.chunk().as_ref(), b"1111"); buf.extend(b"4\r\ndata\r"); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(msg.chunk().as_ref(), b"data"); buf.extend(b"\n4"); - assert!(reader.decode(&mut buf, &settings).unwrap().is_none()); + assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b"\r"); - assert!(reader.decode(&mut buf, &settings).unwrap().is_none()); + assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b"\n"); - assert!(reader.decode(&mut buf, &settings).unwrap().is_none()); + assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b"li"); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(msg.chunk().as_ref(), b"li"); //trailers @@ -1270,12 +1257,12 @@ mod tests { //not_ready!(reader.parse(&mut buf, &mut readbuf)); buf.extend(b"ne\r\n0\r\n"); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(msg.chunk().as_ref(), b"ne"); - assert!(reader.decode(&mut buf, &settings).unwrap().is_none()); + assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b"\r\n"); - assert!(reader.decode(&mut buf, &settings).unwrap().unwrap().eof()); + assert!(reader.decode(&mut buf).unwrap().unwrap().eof()); } #[test] @@ -1286,17 +1273,17 @@ mod tests { ); let settings = wrk_settings(); - let mut reader = H1Decoder::new(); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let mut reader = H1Decoder::new(settings.request_pool()); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.is_payload()); assert!(msg.message().chunked().unwrap()); buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") - let chunk = reader.decode(&mut buf, &settings).unwrap().unwrap().chunk(); + let chunk = reader.decode(&mut buf).unwrap().unwrap().chunk(); assert_eq!(chunk, Bytes::from_static(b"data")); - let chunk = reader.decode(&mut buf, &settings).unwrap().unwrap().chunk(); + let chunk = reader.decode(&mut buf).unwrap().unwrap().chunk(); assert_eq!(chunk, Bytes::from_static(b"line")); - let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); + let msg = reader.decode(&mut buf).unwrap().unwrap(); assert!(msg.eof()); } } diff --git a/src/server/h1codec.rs b/src/server/h1codec.rs new file mode 100644 index 000000000..ea56110d3 --- /dev/null +++ b/src/server/h1codec.rs @@ -0,0 +1,251 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::io::{self, Write}; + +use bytes::{BufMut, Bytes, BytesMut}; +use tokio_codec::{Decoder, Encoder}; + +use super::h1decoder::{H1Decoder, Message}; +use super::helpers; +use super::message::RequestPool; +use super::output::{ResponseInfo, ResponseLength}; +use body::Body; +use error::ParseError; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use http::Version; +use httpresponse::HttpResponse; + +pub(crate) enum OutMessage { + Response(HttpResponse), + Payload(Bytes), +} + +pub(crate) struct H1Codec { + decoder: H1Decoder, + encoder: H1Writer, +} + +impl H1Codec { + pub fn new(pool: &'static RequestPool) -> Self { + H1Codec { + decoder: H1Decoder::new(pool), + encoder: H1Writer::new(), + } + } +} + +impl Decoder for H1Codec { + type Item = Message; + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + self.decoder.decode(src) + } +} + +impl Encoder for H1Codec { + type Item = OutMessage; + type Error = io::Error; + + fn encode( + &mut self, item: Self::Item, dst: &mut BytesMut, + ) -> Result<(), Self::Error> { + match item { + OutMessage::Response(res) => { + self.encoder.encode(res, dst)?; + } + OutMessage::Payload(bytes) => { + dst.extend_from_slice(&bytes); + } + } + Ok(()) + } +} + +bitflags! { + struct Flags: u8 { + const STARTED = 0b0000_0001; + const UPGRADE = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const DISCONNECTED = 0b0000_1000; + } +} + +const AVERAGE_HEADER_SIZE: usize = 30; + +struct H1Writer { + flags: Flags, + written: u64, + headers_size: u32, +} + +impl H1Writer { + fn new() -> H1Writer { + H1Writer { + flags: Flags::empty(), + written: 0, + headers_size: 0, + } + } + + fn written(&self) -> u64 { + self.written + } + + pub fn reset(&mut self) { + self.written = 0; + self.flags = Flags::KEEPALIVE; + } + + pub fn upgrade(&self) -> bool { + self.flags.contains(Flags::UPGRADE) + } + + pub fn keepalive(&self) -> bool { + self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) + } + + fn encode( + &mut self, mut msg: HttpResponse, buffer: &mut BytesMut, + ) -> io::Result<()> { + // prepare task + let info = ResponseInfo::new(false); // req.inner.method == Method::HEAD); + + //if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { + //self.flags = Flags::STARTED | Flags::KEEPALIVE; + //} else { + self.flags = Flags::STARTED; + //} + + // Connection upgrade + let version = msg.version().unwrap_or_else(|| Version::HTTP_11); //req.inner.version); + if msg.upgrade() { + self.flags.insert(Flags::UPGRADE); + msg.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("upgrade")); + } + // keep-alive + else if self.flags.contains(Flags::KEEPALIVE) { + if version < Version::HTTP_11 { + msg.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + } + } else if version >= Version::HTTP_11 { + msg.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("close")); + } + let body = msg.replace_body(Body::Empty); + + // render message + { + let reason = msg.reason().as_bytes(); + if let Body::Binary(ref bytes) = body { + buffer.reserve( + 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + + bytes.len() + + reason.len(), + ); + } else { + buffer.reserve( + 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len(), + ); + } + + // status line + helpers::write_status_line(version, msg.status().as_u16(), buffer); + buffer.extend_from_slice(reason); + + // content length + match info.length { + ResponseLength::Chunked => { + buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") + } + ResponseLength::Zero => { + buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n") + } + ResponseLength::Length(len) => { + helpers::write_content_length(len, buffer) + } + ResponseLength::Length64(len) => { + buffer.extend_from_slice(b"\r\ncontent-length: "); + write!(buffer.writer(), "{}", len)?; + buffer.extend_from_slice(b"\r\n"); + } + ResponseLength::None => buffer.extend_from_slice(b"\r\n"), + } + if let Some(ce) = info.content_encoding { + buffer.extend_from_slice(b"content-encoding: "); + buffer.extend_from_slice(ce.as_ref()); + buffer.extend_from_slice(b"\r\n"); + } + + // write headers + let mut pos = 0; + let mut has_date = false; + let mut remaining = buffer.remaining_mut(); + let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) }; + for (key, value) in msg.headers() { + match *key { + TRANSFER_ENCODING => continue, + CONTENT_LENGTH => match info.length { + ResponseLength::None => (), + _ => continue, + }, + DATE => { + has_date = true; + } + _ => (), + } + + let v = value.as_ref(); + let k = key.as_str().as_bytes(); + let len = k.len() + v.len() + 4; + if len > remaining { + unsafe { + buffer.advance_mut(pos); + } + pos = 0; + buffer.reserve(len); + remaining = buffer.remaining_mut(); + unsafe { + buf = &mut *(buffer.bytes_mut() as *mut _); + } + } + + buf[pos..pos + k.len()].copy_from_slice(k); + pos += k.len(); + buf[pos..pos + 2].copy_from_slice(b": "); + pos += 2; + buf[pos..pos + v.len()].copy_from_slice(v); + pos += v.len(); + buf[pos..pos + 2].copy_from_slice(b"\r\n"); + pos += 2; + remaining -= len; + } + unsafe { + buffer.advance_mut(pos); + } + + // optimized date header, set_date writes \r\n + if !has_date { + // self.settings.set_date(&mut buffer, true); + buffer.extend_from_slice(b"\r\n"); + } else { + // msg eof + buffer.extend_from_slice(b"\r\n"); + } + self.headers_size = buffer.len() as u32; + } + + if let Body::Binary(bytes) = body { + self.written = bytes.len() as u64; + // buffer.write(bytes.as_ref())?; + buffer.extend_from_slice(bytes.as_ref()); + } else { + // capacity, makes sense only for streaming or actor + // self.buffer_capacity = msg.write_buffer_capacity(); + + msg.replace_body(body); + } + Ok(()) + } +} diff --git a/src/server/h1decoder.rs b/src/server/h1decoder.rs index 434dc42df..c6f0974aa 100644 --- a/src/server/h1decoder.rs +++ b/src/server/h1decoder.rs @@ -4,8 +4,7 @@ use bytes::{Bytes, BytesMut}; use futures::{Async, Poll}; use httparse; -use super::message::{MessageFlags, Request}; -use super::settings::ServiceConfig; +use super::message::{MessageFlags, Request, RequestPool}; use error::ParseError; use http::header::{HeaderName, HeaderValue}; use http::{header, HttpTryFrom, Method, Uri, Version}; @@ -16,35 +15,26 @@ const MAX_HEADERS: usize = 96; pub(crate) struct H1Decoder { decoder: Option, + pool: &'static RequestPool, } #[derive(Debug)] -pub(crate) enum Message { - Message { msg: Request, payload: bool }, +pub enum Message { + Message(Request), + MessageWithPayload(Request), Chunk(Bytes), Eof, } -#[derive(Debug)] -pub(crate) enum DecoderError { - Io(io::Error), - Error(ParseError), -} - -impl From for DecoderError { - fn from(err: io::Error) -> DecoderError { - DecoderError::Io(err) - } -} - impl H1Decoder { - pub fn new() -> H1Decoder { - H1Decoder { decoder: None } + pub fn new(pool: &'static RequestPool) -> H1Decoder { + H1Decoder { + pool, + decoder: None, + } } - pub fn decode( - &mut self, src: &mut BytesMut, settings: &ServiceConfig, - ) -> Result, DecoderError> { + pub fn decode(&mut self, src: &mut BytesMut) -> Result, ParseError> { // read payload if self.decoder.is_some() { match self.decoder.as_mut().unwrap().decode(src)? { @@ -57,21 +47,19 @@ impl H1Decoder { } } - match self - .parse_message(src, settings) - .map_err(DecoderError::Error)? - { + match self.parse_message(src)? { Async::Ready((msg, decoder)) => { self.decoder = decoder; - Ok(Some(Message::Message { - msg, - payload: self.decoder.is_some(), - })) + if self.decoder.is_some() { + Ok(Some(Message::MessageWithPayload(msg))) + } else { + Ok(Some(Message::Message(msg))) + } } Async::NotReady => { if src.len() >= MAX_BUFFER_SIZE { error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - Err(DecoderError::Error(ParseError::TooLarge)) + Err(ParseError::TooLarge) } else { Ok(None) } @@ -79,8 +67,8 @@ impl H1Decoder { } } - fn parse_message( - &self, buf: &mut BytesMut, settings: &ServiceConfig, + fn parse_message( + &self, buf: &mut BytesMut, ) -> Poll<(Request, Option), ParseError> { // Parse http message let mut has_upgrade = false; @@ -119,7 +107,7 @@ impl H1Decoder { let slice = buf.split_to(len).freeze(); // convert headers - let mut msg = settings.get_request(); + let mut msg = RequestPool::get(self.pool); { let inner = msg.inner_mut(); inner diff --git a/src/server/h1disp.rs b/src/server/h1disp.rs new file mode 100644 index 000000000..b1c2c8a21 --- /dev/null +++ b/src/server/h1disp.rs @@ -0,0 +1,425 @@ +// #![allow(unused_imports, unused_variables, dead_code)] +use std::collections::VecDeque; +use std::fmt::{Debug, Display}; +use std::net::SocketAddr; +// use std::time::{Duration, Instant}; + +use actix_net::service::Service; + +use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; +use tokio_codec::Framed; +// use tokio_current_thread::spawn; +use tokio_io::AsyncWrite; +// use tokio_timer::Delay; + +use error::{ParseError, PayloadError}; +use payload::{Payload, PayloadStatus, PayloadWriter}; + +use body::Body; +use httpresponse::HttpResponse; + +use super::error::HttpDispatchError; +use super::h1codec::{H1Codec, OutMessage}; +use super::h1decoder::Message; +use super::input::PayloadType; +use super::message::{Request, RequestPool}; +use super::IoStream; + +const MAX_PIPELINED_MESSAGES: usize = 16; + +bitflags! { + pub struct Flags: u8 { + const STARTED = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const SHUTDOWN = 0b0000_1000; + const READ_DISCONNECTED = 0b0001_0000; + const WRITE_DISCONNECTED = 0b0010_0000; + const POLLED = 0b0100_0000; + const FLUSHED = 0b1000_0000; + } +} + +/// Dispatcher for HTTP/1.1 protocol +pub struct Http1Dispatcher +where + S::Error: Debug + Display, +{ + service: S, + flags: Flags, + addr: Option, + framed: Framed, + error: Option>, + + state: State, + payload: Option, + messages: VecDeque, +} + +enum State { + None, + Response(S::Future), + SendResponse(Option), + SendResponseWithPayload(Option<(OutMessage, Body)>), + Payload(Body), +} + +impl State { + fn is_empty(&self) -> bool { + if let State::None = self { + true + } else { + false + } + } +} + +impl Http1Dispatcher +where + T: IoStream, + S: Service, + S::Error: Debug + Display, +{ + pub fn new(stream: T, pool: &'static RequestPool, service: S) -> Self { + let addr = stream.peer_addr(); + let flags = Flags::FLUSHED; + let codec = H1Codec::new(pool); + let framed = Framed::new(stream, codec); + + Http1Dispatcher { + payload: None, + state: State::None, + error: None, + messages: VecDeque::new(), + service, + flags, + addr, + framed, + } + } + + #[inline] + fn can_read(&self) -> bool { + if self.flags.contains(Flags::READ_DISCONNECTED) { + return false; + } + + if let Some(ref info) = self.payload { + info.need_read() == PayloadStatus::Read + } else { + true + } + } + + // if checked is set to true, delay disconnect until all tasks have finished. + fn client_disconnected(&mut self, checked: bool) { + self.flags.insert(Flags::READ_DISCONNECTED); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete); + } + + // if !checked || self.tasks.is_empty() { + // self.flags + // .insert(Flags::WRITE_DISCONNECTED | Flags::FLUSHED); + + // // notify tasks + // for mut task in self.tasks.drain(..) { + // task.disconnected(); + // match task.poll_completed() { + // Ok(Async::NotReady) => { + // // spawn not completed task, it does not require access to io + // // at this point + // spawn(HttpHandlerTaskFut::new(task.into_task())); + // } + // Ok(Async::Ready(_)) => (), + // Err(err) => { + // error!("Unhandled application error: {}", err); + // } + // } + // } + // } + } + + /// Flush stream + fn poll_flush(&mut self) -> Poll<(), HttpDispatchError> { + if self.flags.contains(Flags::STARTED) && !self.flags.contains(Flags::FLUSHED) { + match self.framed.poll_complete() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(err) => { + debug!("Error sending data: {}", err); + self.client_disconnected(false); + Err(err.into()) + } + Ok(Async::Ready(_)) => { + // if payload is not consumed we can not use connection + if self.payload.is_some() && self.state.is_empty() { + return Err(HttpDispatchError::PayloadIsNotConsumed); + } + self.flags.insert(Flags::FLUSHED); + Ok(Async::Ready(())) + } + } + } else { + Ok(Async::Ready(())) + } + } + + pub(self) fn poll_handler(&mut self) -> Result<(), HttpDispatchError> { + self.poll_io()?; + let mut retry = self.can_read(); + + // process + loop { + let state = match self.state { + State::None => loop { + break if let Some(msg) = self.messages.pop_front() { + let mut task = self.service.call(msg); + match task.poll() { + Ok(Async::Ready(res)) => { + if res.body().is_streaming() { + unimplemented!() + } else { + Some(Ok(State::SendResponse(Some( + OutMessage::Response(res), + )))) + } + } + Ok(Async::NotReady) => Some(Ok(State::Response(task))), + Err(err) => Some(Err(HttpDispatchError::App(err))), + } + } else { + None + }; + }, + State::Payload(ref mut body) => unimplemented!(), + State::Response(ref mut fut) => { + match fut.poll() { + Ok(Async::Ready(res)) => { + if res.body().is_streaming() { + unimplemented!() + } else { + Some(Ok(State::SendResponse(Some( + OutMessage::Response(res), + )))) + } + } + Ok(Async::NotReady) => None, + Err(err) => { + // it is not possible to recover from error + // during pipe handling, so just drop connection + Some(Err(HttpDispatchError::App(err))) + } + } + } + State::SendResponse(ref mut item) => { + let msg = item.take().expect("SendResponse is empty"); + match self.framed.start_send(msg) { + Ok(AsyncSink::Ready) => { + self.flags.remove(Flags::FLUSHED); + Some(Ok(State::None)) + } + Ok(AsyncSink::NotReady(msg)) => { + *item = Some(msg); + return Ok(()); + } + Err(err) => Some(Err(HttpDispatchError::Io(err))), + } + } + State::SendResponseWithPayload(ref mut item) => { + let (msg, body) = item.take().expect("SendResponse is empty"); + match self.framed.start_send(msg) { + Ok(AsyncSink::Ready) => { + self.flags.remove(Flags::FLUSHED); + Some(Ok(State::Payload(body))) + } + Ok(AsyncSink::NotReady(msg)) => { + *item = Some((msg, body)); + return Ok(()); + } + Err(err) => Some(Err(HttpDispatchError::Io(err))), + } + } + }; + + match state { + Some(Ok(state)) => self.state = state, + Some(Err(err)) => { + // error!("Unhandled error1: {}", err); + self.client_disconnected(false); + return Err(err); + } + None => { + // if read-backpressure is enabled and we consumed some data. + // we may read more dataand retry + if !retry && self.can_read() && self.poll_io()? { + retry = self.can_read(); + continue; + } + break; + } + } + } + + Ok(()) + } + + fn one_message(&mut self, msg: Message) -> Result<(), HttpDispatchError> { + self.flags.insert(Flags::STARTED); + + match msg { + Message::Message(mut msg) => { + // set remote addr + msg.inner_mut().addr = self.addr; + + // handle request early + if self.state.is_empty() { + let mut task = self.service.call(msg); + match task.poll() { + Ok(Async::Ready(res)) => { + if res.body().is_streaming() { + unimplemented!() + } else { + self.state = + State::SendResponse(Some(OutMessage::Response(res))); + } + } + Ok(Async::NotReady) => self.state = State::Response(task), + Err(err) => { + error!("Unhandled application error: {}", err); + self.client_disconnected(false); + return Err(HttpDispatchError::App(err)); + } + } + } else { + self.messages.push_back(msg); + } + } + Message::MessageWithPayload(mut msg) => { + // set remote addr + msg.inner_mut().addr = self.addr; + + // payload + let (ps, pl) = Payload::new(false); + *msg.inner.payload.borrow_mut() = Some(pl); + self.payload = Some(PayloadType::new(&msg.inner.headers, ps)); + + self.messages.push_back(msg); + } + Message::Chunk(chunk) => { + if let Some(ref mut payload) = self.payload { + payload.feed_data(chunk); + } else { + error!("Internal server error: unexpected payload chunk"); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + // self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); + self.error = Some(HttpDispatchError::InternalError); + } + } + Message::Eof => { + if let Some(mut payload) = self.payload.take() { + payload.feed_eof(); + } else { + error!("Internal server error: unexpected eof"); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + // self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); + self.error = Some(HttpDispatchError::InternalError); + } + } + } + + Ok(()) + } + + pub(self) fn poll_io(&mut self) -> Result> { + let mut updated = false; + + if self.messages.len() < MAX_PIPELINED_MESSAGES { + 'outer: loop { + match self.framed.poll() { + Ok(Async::Ready(Some(msg))) => { + updated = true; + self.one_message(msg)?; + } + Ok(Async::Ready(None)) => { + if self.flags.contains(Flags::READ_DISCONNECTED) { + self.client_disconnected(true); + } + break; + } + Ok(Async::NotReady) => break, + Err(e) => { + if let Some(mut payload) = self.payload.take() { + let e = match e { + ParseError::Io(e) => PayloadError::Io(e), + _ => PayloadError::EncodingCorrupted, + }; + payload.set_error(e); + } + + // Malformed requests should be responded with 400 + // self.push_response_entry(StatusCode::BAD_REQUEST); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + self.error = Some(HttpDispatchError::MalformedRequest); + break; + } + } + } + } + + Ok(updated) + } +} + +impl Future for Http1Dispatcher +where + T: IoStream, + S: Service, + S::Error: Debug + Display, +{ + type Item = (); + type Error = HttpDispatchError; + + #[inline] + fn poll(&mut self) -> Poll<(), Self::Error> { + // shutdown + if self.flags.contains(Flags::SHUTDOWN) { + if self.flags.contains(Flags::WRITE_DISCONNECTED) { + return Ok(Async::Ready(())); + } + try_ready!(self.poll_flush()); + return Ok(AsyncWrite::shutdown(self.framed.get_mut())?); + } + + // process incoming requests + if !self.flags.contains(Flags::WRITE_DISCONNECTED) { + self.poll_handler()?; + + // flush stream + self.poll_flush()?; + + // deal with keep-alive and stream eof (client-side write shutdown) + if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { + // handle stream eof + if self + .flags + .intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) + { + return Ok(Async::Ready(())); + } + // no keep-alive + if self.flags.contains(Flags::STARTED) + && (!self.flags.contains(Flags::KEEPALIVE_ENABLED) + || !self.flags.contains(Flags::KEEPALIVE)) + { + self.flags.insert(Flags::SHUTDOWN); + return self.poll(); + } + } + Ok(Async::NotReady) + } else if let Some(err) = self.error.take() { + Err(err) + } else { + Ok(Async::Ready(())) + } + } +} diff --git a/src/server/h2.rs b/src/server/h2.rs index 2fe2fa073..27ae47851 100644 --- a/src/server/h2.rs +++ b/src/server/h2.rs @@ -86,7 +86,7 @@ where &self.settings } - pub fn poll(&mut self) -> Poll<(), HttpDispatchError> { + pub fn poll(&mut self) -> Poll<(), HttpDispatchError> { // server if let State::Connection(ref mut conn) = self.state { // keep-alive timer diff --git a/src/server/message.rs b/src/server/message.rs index 9c4bc1ec4..74ec5f17c 100644 --- a/src/server/message.rs +++ b/src/server/message.rs @@ -241,10 +241,7 @@ impl fmt::Debug for Request { } } -pub(crate) struct RequestPool( - RefCell>>, - RefCell, -); +pub struct RequestPool(RefCell>>, RefCell); thread_local!(static POOL: &'static RequestPool = RequestPool::create()); @@ -257,7 +254,7 @@ impl RequestPool { Box::leak(Box::new(pool)) } - pub fn pool(settings: ServerSettings) -> &'static RequestPool { + pub(crate) fn pool(settings: ServerSettings) -> &'static RequestPool { POOL.with(|p| { *p.1.borrow_mut() = settings; *p @@ -275,7 +272,7 @@ impl RequestPool { #[inline] /// Release request instance - pub fn release(&self, msg: Rc) { + pub(crate) fn release(&self, msg: Rc) { let v = &mut self.0.borrow_mut(); if v.len() < 128 { v.push_front(msg); diff --git a/src/server/mod.rs b/src/server/mod.rs index 3277dba5a..ce3f2fbf6 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -122,7 +122,10 @@ pub(crate) mod builder; mod channel; mod error; pub(crate) mod h1; -pub(crate) mod h1decoder; +#[doc(hidden)] +pub mod h1codec; +#[doc(hidden)] +pub mod h1decoder; mod h1writer; mod h2; mod h2writer; @@ -145,6 +148,9 @@ pub use self::ssl::*; pub use self::error::{AcceptorError, HttpDispatchError}; pub use self::settings::ServerSettings; +#[doc(hidden)] +pub mod h1disp; + #[doc(hidden)] pub use self::acceptor::AcceptorTimeout; diff --git a/src/server/output.rs b/src/server/output.rs index 70c24facc..1da7e9025 100644 --- a/src/server/output.rs +++ b/src/server/output.rs @@ -10,7 +10,7 @@ use bytes::BytesMut; use flate2::write::{GzEncoder, ZlibEncoder}; #[cfg(feature = "flate2")] use flate2::Compression; -use http::header::{ACCEPT_ENCODING, CONTENT_LENGTH}; +use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH}; use http::{StatusCode, Version}; use super::message::InnerRequest; @@ -18,6 +18,12 @@ use body::{Binary, Body}; use header::ContentEncoding; use httpresponse::HttpResponse; +// #[derive(Debug)] +// pub(crate) struct RequestInfo { +// pub version: Version, +// pub accept_encoding: Option, +// } + #[derive(Debug)] pub(crate) enum ResponseLength { Chunked, diff --git a/src/server/service.rs b/src/server/service.rs index e3402e305..a55c33f72 100644 --- a/src/server/service.rs +++ b/src/server/service.rs @@ -10,6 +10,7 @@ use super::error::HttpDispatchError; use super::handler::HttpHandler; use super::settings::ServiceConfig; use super::IoStream; +use error::Error; /// `NewService` implementation for HTTP1/HTTP2 transports pub struct HttpService @@ -42,7 +43,7 @@ where { type Request = Io; type Response = (); - type Error = HttpDispatchError; + type Error = HttpDispatchError; type InitError = (); type Service = HttpServiceHandler; type Future = FutureResult; @@ -81,7 +82,7 @@ where { type Request = Io; type Response = (); - type Error = HttpDispatchError; + type Error = HttpDispatchError; type Future = HttpChannel; fn poll_ready(&mut self) -> Poll<(), Self::Error> { @@ -124,7 +125,7 @@ where { type Request = Io; type Response = (); - type Error = HttpDispatchError; + type Error = HttpDispatchError; type InitError = (); type Service = H1ServiceHandler; type Future = FutureResult; @@ -164,7 +165,7 @@ where { type Request = Io; type Response = (); - type Error = HttpDispatchError; + type Error = HttpDispatchError; type Future = H1Channel; fn poll_ready(&mut self) -> Poll<(), Self::Error> { diff --git a/src/server/settings.rs b/src/server/settings.rs index 9b27ed5e5..9df1e457f 100644 --- a/src/server/settings.rs +++ b/src/server/settings.rs @@ -215,6 +215,11 @@ impl ServiceConfig { RequestPool::get(self.0.messages) } + #[doc(hidden)] + pub fn request_pool(&self) -> &'static RequestPool { + self.0.messages + } + fn update_date(&self) { // Unsafe: WorkerSetting is !Sync and !Send unsafe { (*self.0.date.get()).0 = false }; diff --git a/tests/test_h1v2.rs b/tests/test_h1v2.rs new file mode 100644 index 000000000..7e8e9a42c --- /dev/null +++ b/tests/test_h1v2.rs @@ -0,0 +1,58 @@ +extern crate actix; +extern crate actix_net; +extern crate actix_web; +extern crate futures; + +use std::thread; + +use actix::System; +use actix_net::server::Server; +use actix_net::service::{IntoNewService, IntoService}; +use futures::future; + +use actix_web::server::h1disp::Http1Dispatcher; +use actix_web::server::KeepAlive; +use actix_web::server::ServiceConfig; +use actix_web::{client, test, App, Error, HttpRequest, HttpResponse}; + +#[test] +fn test_h1_v2() { + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + let app = App::new() + .route("/", http::Method::GET, |_: HttpRequest| "OK") + .finish(); + let settings = ServiceConfig::build(app) + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_shutdown(1000) + .server_hostname("localhost") + .server_address(addr) + .finish(); + + (move |io| { + let pool = settings.request_pool(); + Http1Dispatcher::new( + io, + pool, + (|req| { + println!("REQ: {:?}", req); + future::ok::<_, Error>(HttpResponse::Ok().finish()) + }).into_service(), + ) + }).into_new_service() + }).unwrap() + .run(); + }); + + let mut sys = System::new("test"); + { + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + } +}