From e482b887412a0c77c6662fd7e5355cf3ad8144db Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 11 Jan 2018 21:48:36 -0800 Subject: [PATCH] refactor http protocol selection procedure --- examples/basics/Cargo.toml | 2 +- src/server/channel.rs | 99 +++++++++++++++------ src/server/h1.rs | 172 +++++++++---------------------------- src/server/h1writer.rs | 4 - src/server/mod.rs | 1 + src/server/utils.rs | 30 +++++++ 6 files changed, 142 insertions(+), 166 deletions(-) create mode 100644 src/server/utils.rs diff --git a/examples/basics/Cargo.toml b/examples/basics/Cargo.toml index 88b5f61e0..44b745392 100644 --- a/examples/basics/Cargo.toml +++ b/examples/basics/Cargo.toml @@ -8,4 +8,4 @@ workspace = "../.." futures = "*" env_logger = "0.4" actix = "0.4" -actix-web = { git = "https://github.com/actix/actix-web" } +actix-web = { path = "../../" } diff --git a/src/server/channel.rs b/src/server/channel.rs index 6ea14c45d..da4c613e2 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -2,29 +2,43 @@ use std::{ptr, mem, time, io}; use std::rc::Rc; use std::net::{SocketAddr, Shutdown}; -use bytes::{Bytes, Buf, BufMut}; +use bytes::{Bytes, BytesMut, Buf, BufMut}; use futures::{Future, Poll, Async}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_core::net::TcpStream; -use super::{h1, h2, HttpHandler, IoStream}; +use super::{h1, h2, utils, HttpHandler, IoStream}; use super::settings::WorkerSettings; +const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; + + enum HttpProtocol { H1(h1::Http1), H2(h2::Http2), + Unknown(Rc>, Option, T, BytesMut), +} +impl HttpProtocol { + fn is_unknown(&self) -> bool { + match *self { + HttpProtocol::Unknown(_, _, _, _) => true, + _ => false + } + } +} + +enum ProtocolKind { + Http1, + Http2, } #[doc(hidden)] -pub struct HttpChannel - where T: IoStream, H: HttpHandler + 'static -{ +pub struct HttpChannel where T: IoStream, H: HttpHandler + 'static { proto: Option>, node: Option>>, } -impl HttpChannel - where T: IoStream, H: HttpHandler + 'static +impl HttpChannel where T: IoStream, H: HttpHandler + 'static { pub(crate) fn new(settings: Rc>, io: T, peer: Option, http2: bool) -> HttpChannel @@ -38,7 +52,8 @@ impl HttpChannel } else { HttpChannel { node: None, - proto: Some(HttpProtocol::H1(h1::Http1::new(settings, io, peer))) } + proto: Some(HttpProtocol::Unknown( + settings, peer, io, BytesMut::with_capacity(4096))) } } } @@ -52,7 +67,7 @@ impl HttpChannel Some(HttpProtocol::H2(ref mut h2)) => { h2.shutdown() } - _ => unreachable!(), + _ => (), } } } @@ -64,28 +79,25 @@ impl Future for HttpChannel type Error = (); fn poll(&mut self) -> Poll { - if self.node.is_none() { + if !self.proto.as_ref().map(|p| p.is_unknown()).unwrap_or(false) && self.node.is_none() { self.node = Some(Node::new(self)); match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => { - h1.settings().head().insert(self.node.as_ref().unwrap()); - } - Some(HttpProtocol::H2(ref mut h2)) => { - h2.settings().head().insert(self.node.as_ref().unwrap()); - } - _ => unreachable!(), + Some(HttpProtocol::H1(ref mut h1)) => + h1.settings().head().insert(self.node.as_ref().unwrap()), + Some(HttpProtocol::H2(ref mut h2)) => + h2.settings().head().insert(self.node.as_ref().unwrap()), + _ => (), } } - match self.proto { + let kind = match self.proto { Some(HttpProtocol::H1(ref mut h1)) => { match h1.poll() { - Ok(Async::Ready(h1::Http1Result::Done)) => { + Ok(Async::Ready(())) => { h1.settings().remove_channel(); self.node.as_ref().unwrap().remove(); return Ok(Async::Ready(())) } - Ok(Async::Ready(h1::Http1Result::Switch)) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(_) => { @@ -94,7 +106,7 @@ impl Future for HttpChannel return Err(()) } } - } + }, Some(HttpProtocol::H2(ref mut h2)) => { let result = h2.poll(); match result { @@ -105,18 +117,49 @@ impl Future for HttpChannel _ => (), } return result - } + }, + Some(HttpProtocol::Unknown(_, _, ref mut io, ref mut buf)) => { + match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => { + debug!("Ignored premature client disconnection"); + return Err(()) + }, + Err(err) => { + debug!("Ignored premature client disconnection {}", err); + return Err(()) + } + _ => (), + } + + if buf.len() >= 14 { + if buf[..14] == HTTP2_PREFACE[..] { + ProtocolKind::Http2 + } else { + ProtocolKind::Http1 + } + } else { + return Ok(Async::NotReady); + } + }, None => unreachable!(), - } + }; // upgrade to h2 let proto = self.proto.take().unwrap(); match proto { - HttpProtocol::H1(h1) => { - let (h, io, addr, buf) = h1.into_inner(); - self.proto = Some( - HttpProtocol::H2(h2::Http2::new(h, io, addr, buf))); - self.poll() + HttpProtocol::Unknown(settings, addr, io, buf) => { + match kind { + ProtocolKind::Http1 => { + self.proto = Some( + HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf))); + self.poll() + }, + ProtocolKind::Http2 => { + self.proto = Some( + HttpProtocol::H2(h2::Http2::new(settings, io, addr, buf.freeze()))); + self.poll() + }, + } } _ => unreachable!() } diff --git a/src/server/h1.rs b/src/server/h1.rs index 0da6f1fc4..67ec26372 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -8,7 +8,7 @@ use actix::Arbiter; use httparse; use http::{Uri, Method, Version, HttpTryFrom, HeaderMap}; use http::header::{self, HeaderName, HeaderValue}; -use bytes::{Bytes, BytesMut, BufMut}; +use bytes::{Bytes, BytesMut}; use futures::{Future, Poll, Async}; use tokio_core::reactor::Timeout; @@ -18,24 +18,20 @@ use httprequest::HttpRequest; use error::{ParseError, PayloadError, ResponseError}; use payload::{Payload, PayloadWriter, DEFAULT_BUFFER_SIZE}; -use super::Writer; +use super::{utils, Writer}; use super::h1writer::H1Writer; use super::encoding::PayloadType; use super::settings::WorkerSettings; use super::{HttpHandler, HttpHandlerTask, IoStream}; -const LW_BUFFER_SIZE: usize = 4096; -const HW_BUFFER_SIZE: usize = 16_384; const MAX_BUFFER_SIZE: usize = 131_072; -const MAX_HEADERS: usize = 100; +const MAX_HEADERS: usize = 96; const MAX_PIPELINED_MESSAGES: usize = 16; -const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; bitflags! { struct Flags: u8 { const ERROR = 0b0000_0010; const KEEPALIVE = 0b0000_0100; - const H2 = 0b0000_1000; } } @@ -47,17 +43,6 @@ bitflags! { } } -pub(crate) enum Http1Result { - Done, - Switch, -} - -#[derive(Debug)] -enum Item { - Http1(HttpRequest), - Http2, -} - pub(crate) struct Http1 { flags: Flags, settings: Rc>, @@ -77,14 +62,16 @@ struct Entry { impl Http1 where T: IoStream, H: HttpHandler + 'static { - pub fn new(h: Rc>, stream: T, addr: Option) -> Self { + pub fn new(h: Rc>, stream: T, addr: Option, buf: BytesMut) + -> Self + { let bytes = h.get_shared_bytes(); Http1{ flags: Flags::KEEPALIVE, settings: h, addr: addr, stream: H1Writer::new(stream, bytes), reader: Reader::new(), - read_buf: BytesMut::new(), + read_buf: buf, tasks: VecDeque::new(), keepalive_timer: None } } @@ -93,10 +80,6 @@ impl Http1 self.settings.as_ref() } - pub fn into_inner(self) -> (Rc>, T, Option, Bytes) { - (self.settings, self.stream.into_inner(), self.addr, self.read_buf.freeze()) - } - pub(crate) fn io(&mut self) -> &mut T { self.stream.get_mut() } @@ -115,13 +98,13 @@ impl Http1 // TODO: refacrtor #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))] - pub fn poll(&mut self) -> Poll { + pub fn poll(&mut self) -> Poll<(), ()> { // keep-alive timer if self.keepalive_timer.is_some() { match self.keepalive_timer.as_mut().unwrap().poll() { Ok(Async::Ready(_)) => { trace!("Keep-alive timeout, close connection"); - return Ok(Async::Ready(Http1Result::Done)) + return Ok(Async::Ready(())) } Ok(Async::NotReady) => (), Err(_) => unreachable!(), @@ -209,27 +192,18 @@ impl Http1 // no keep-alive if !self.flags.contains(Flags::KEEPALIVE) && self.tasks.is_empty() { - let h2 = self.flags.contains(Flags::H2); - // check stream state - if !self.poll_completed(!h2)? { + if !self.poll_completed(true)? { return Ok(Async::NotReady) } - - if h2 { - return Ok(Async::Ready(Http1Result::Switch)) - } else { - return Ok(Async::Ready(Http1Result::Done)) - } + return Ok(Async::Ready(())) } // read incoming data - while !self.flags.contains(Flags::ERROR) && !self.flags.contains(Flags::H2) && - self.tasks.len() < MAX_PIPELINED_MESSAGES - { + while !self.flags.contains(Flags::ERROR) && self.tasks.len() < MAX_PIPELINED_MESSAGES { match self.reader.parse(self.stream.get_mut(), &mut self.read_buf, &self.settings) { - Ok(Async::Ready(Item::Http1(mut req))) => { + Ok(Async::Ready(mut req)) => { not_ready = false; // set remote addr @@ -254,9 +228,6 @@ impl Http1 Entry {pipe: pipe.unwrap_or_else(|| Pipeline::error(HTTPNotFound)), flags: EntryFlags::empty()}); } - Ok(Async::Ready(Item::Http2)) => { - self.flags.insert(Flags::H2); - } Err(ReaderError::Disconnect) => { not_ready = false; self.flags.insert(Flags::ERROR); @@ -309,7 +280,7 @@ impl Http1 return Ok(Async::NotReady) } // keep-alive disable, drop connection - return Ok(Async::Ready(Http1Result::Done)) + return Ok(Async::Ready(())) } } else if !self.poll_completed(false)? || self.flags.contains(Flags::KEEPALIVE) @@ -318,7 +289,7 @@ impl Http1 // if keep-alive unset, rely on operating system return Ok(Async::NotReady) } else { - return Ok(Async::Ready(Http1Result::Done)) + return Ok(Async::Ready(())) } } break @@ -328,17 +299,12 @@ impl Http1 // check for parse error if self.tasks.is_empty() { - let h2 = self.flags.contains(Flags::H2); - // check stream state - if !self.poll_completed(!h2)? { + if !self.poll_completed(true)? { return Ok(Async::NotReady) } - if h2 { - return Ok(Async::Ready(Http1Result::Switch)) - } if self.flags.contains(Flags::ERROR) || self.keepalive_timer.is_none() { - return Ok(Async::Ready(Http1Result::Done)) + return Ok(Async::Ready(())) } } @@ -351,7 +317,6 @@ impl Http1 } struct Reader { - h1: bool, payload: Option, } @@ -373,22 +338,14 @@ enum ReaderError { Error(ParseError), } -enum Message { - Http1(HttpRequest, Option), - Http2, - NotReady, -} - impl Reader { pub fn new() -> Reader { Reader { - h1: false, payload: None, } } - fn decode(&mut self, buf: &mut BytesMut) -> std::result::Result - { + fn decode(&mut self, buf: &mut BytesMut) -> std::result::Result { if let Some(ref mut payload) = self.payload { if payload.tx.capacity() > DEFAULT_BUFFER_SIZE { return Ok(Decoding::Paused) @@ -416,12 +373,12 @@ impl Reader { pub fn parse(&mut self, io: &mut T, buf: &mut BytesMut, - settings: &WorkerSettings) -> Poll + settings: &WorkerSettings) -> Poll where T: IoStream { // read payload if self.payload.is_some() { - match self.read_from_io(io, buf) { + match utils::read_from_io(io, buf) { Ok(Async::Ready(0)) => { if let Some(ref mut payload) = self.payload { payload.tx.set_error(PayloadError::Incomplete); @@ -446,7 +403,7 @@ impl Reader { // if buf is empty parse_message will always return NotReady, let's avoid that let read = if buf.is_empty() { - match self.read_from_io(io, buf) { + match utils::read_from_io(io, buf) { Ok(Async::Ready(0)) => { // debug!("Ignored premature client disconnection"); return Err(ReaderError::Disconnect); @@ -464,7 +421,7 @@ impl Reader { loop { match Reader::parse_message(buf, settings).map_err(ReaderError::Error)? { - Message::Http1(msg, decoder) => { + Async::Ready((msg, decoder)) => { // process payload if let Some(payload) = decoder { self.payload = Some(payload); @@ -473,22 +430,15 @@ impl Reader { Decoding::Ready => self.payload = None, } } - self.h1 = true; - return Ok(Async::Ready(Item::Http1(msg))); + return Ok(Async::Ready(msg)); }, - Message::Http2 => { - if self.h1 { - return Err(ReaderError::Error(ParseError::Version)) - } - return Ok(Async::Ready(Item::Http2)); - }, - Message::NotReady => { + Async::NotReady => { if buf.capacity() >= MAX_BUFFER_SIZE { error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(ReaderError::Error(ParseError::TooLarge)); } if read { - match self.read_from_io(io, buf) { + match utils::read_from_io(io, buf) { Ok(Async::Ready(0)) => { debug!("Ignored premature client disconnection"); return Err(ReaderError::Disconnect); @@ -507,39 +457,8 @@ impl Reader { } } - fn read_from_io(&mut self, io: &mut T, buf: &mut BytesMut) - -> Poll - { - unsafe { - if buf.remaining_mut() < LW_BUFFER_SIZE { - buf.reserve(HW_BUFFER_SIZE); - } - match io.read(buf.bytes_mut()) { - Ok(n) => { - buf.advance_mut(n); - Ok(Async::Ready(n)) - }, - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - Ok(Async::NotReady) - } else { - Err(e) - } - } - } - } - } - fn parse_message(buf: &mut BytesMut, settings: &WorkerSettings) - -> Result - { - if buf.is_empty() { - return Ok(Message::NotReady); - } - if buf.len() >= 14 && buf[..14] == HTTP2_PREFACE[..] { - return Ok(Message::Http2) - } - + -> Poll<(HttpRequest, Option), ParseError> { // Parse http message let msg = { let bytes_ptr = buf.as_ref().as_ptr() as usize; @@ -565,7 +484,7 @@ impl Reader { }; (len, method, path, version, req.headers.len()) } - httparse::Status::Partial => return Ok(Message::NotReady), + httparse::Status::Partial => return Ok(Async::NotReady), } }; @@ -625,9 +544,9 @@ impl Reader { decoder: decoder, }; msg.get_mut().payload = Some(payload); - Ok(Message::Http1(HttpRequest::from_message(msg), Some(info))) + Ok(Async::Ready((HttpRequest::from_message(msg), Some(info)))) } else { - Ok(Message::Http1(HttpRequest::from_message(msg), None)) + Ok(Async::Ready((HttpRequest::from_message(msg), None))) } } } @@ -977,7 +896,7 @@ mod tests { ($e:expr) => ({ let settings = WorkerSettings::::new(Vec::new(), None); match Reader::new().parse($e, &mut BytesMut::new(), &settings) { - Ok(Async::Ready(Item::Http1(req))) => req, + Ok(Async::Ready(req)) => req, Ok(_) => panic!("Eof during parsing http request"), Err(err) => panic!("Error during parsing http request: {:?}", err), } @@ -987,7 +906,7 @@ mod tests { macro_rules! reader_parse_ready { ($e:expr) => ( match $e { - Ok(Async::Ready(Item::Http1(req))) => req, + Ok(Async::Ready(req)) => req, Ok(_) => panic!("Eof during parsing http request"), Err(err) => panic!("Error during parsing http request: {:?}", err), } @@ -1019,7 +938,7 @@ mod tests { let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(req))) => { + Ok(Async::Ready(req)) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -1042,7 +961,7 @@ mod tests { buf.feed_data(".1\r\n\r\n"); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(req))) => { + Ok(Async::Ready(req)) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::PUT); assert_eq!(req.path(), "/test"); @@ -1059,7 +978,7 @@ mod tests { let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(req))) => { + Ok(Async::Ready(req)) => { assert_eq!(req.version(), Version::HTTP_10); assert_eq!(*req.method(), Method::POST); assert_eq!(req.path(), "/test2"); @@ -1076,7 +995,7 @@ mod tests { let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(mut req))) => { + Ok(Async::Ready(mut req)) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -1095,7 +1014,7 @@ mod tests { let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(mut req))) => { + Ok(Async::Ready(mut req)) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -1116,7 +1035,7 @@ mod tests { buf.feed_data("\r\n"); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(req))) => { + Ok(Async::Ready(req)) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -1142,7 +1061,7 @@ mod tests { buf.feed_data("t: value\r\n\r\n"); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(req))) => { + Ok(Async::Ready(req)) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -1163,7 +1082,7 @@ mod tests { let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http1(req))) => { + Ok(Async::Ready(req)) => { let val: Vec<_> = req.headers().get_all("Set-Cookie") .iter().map(|v| v.to_str().unwrap().to_owned()).collect(); assert_eq!(val[0], "c1=cookie1"); @@ -1512,17 +1431,4 @@ mod tests { Err(err) => panic!("{:?}", err), } }*/ - - #[test] - fn test_http2_prefix() { - let mut buf = Buffer::new("PRI * HTTP/2.0\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); - - let mut reader = Reader::new(); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(Item::Http2)) => (), - Ok(_) | Err(_) => panic!("Error during parsing http request"), - } - } } diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index 04c304d95..2a2601c5d 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -55,10 +55,6 @@ impl H1Writer { self.flags = Flags::empty(); } - pub fn into_inner(self) -> T { - self.stream - } - pub fn disconnected(&mut self) { self.encoder.get_mut().take(); } diff --git a/src/server/mod.rs b/src/server/mod.rs index 6f4b9ebe5..a62a04ba7 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -14,6 +14,7 @@ mod h2; mod h1writer; mod h2writer; mod settings; +mod utils; pub use self::srv::HttpServer; pub use self::settings::ServerSettings; diff --git a/src/server/utils.rs b/src/server/utils.rs new file mode 100644 index 000000000..79e0a11c5 --- /dev/null +++ b/src/server/utils.rs @@ -0,0 +1,30 @@ +use std::io; +use bytes::{BytesMut, BufMut}; +use futures::{Async, Poll}; + +use super::IoStream; + +const LW_BUFFER_SIZE: usize = 4096; +const HW_BUFFER_SIZE: usize = 16_384; + + +pub fn read_from_io(io: &mut T, buf: &mut BytesMut) -> Poll { + unsafe { + if buf.remaining_mut() < LW_BUFFER_SIZE { + buf.reserve(HW_BUFFER_SIZE); + } + match io.read(buf.bytes_mut()) { + Ok(n) => { + buf.advance_mut(n); + Ok(Async::Ready(n)) + }, + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + Ok(Async::NotReady) + } else { + Err(e) + } + } + } + } +}