From 625469f0f421fef4ee7266fd459c6efed31cbdb0 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 16 Nov 2018 19:28:07 -0800 Subject: [PATCH] refactor decoder --- src/client/pipeline.rs | 2 +- src/client/request.rs | 2 +- src/client/response.rs | 64 ++--- src/h1/client.rs | 21 +- src/h1/codec.rs | 13 +- src/h1/decoder.rs | 551 ++++++++++++++++++++--------------------- src/h1/encoder.rs | 3 +- src/h1/mod.rs | 1 - src/lib.rs | 1 + src/message.rs | 163 ++++++++++++ src/request.rs | 148 ++--------- src/test.rs | 2 +- src/uri.rs | 17 +- 13 files changed, 493 insertions(+), 495 deletions(-) create mode 100644 src/message.rs diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index 26bd1c62..17ba93e7 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -13,7 +13,7 @@ use super::{Connect, Connection}; use body::{BodyType, MessageBody, PayloadStream}; use error::PayloadError; use h1; -use request::RequestHead; +use message::RequestHead; pub(crate) fn send_request( head: RequestHead, diff --git a/src/client/request.rs b/src/client/request.rs index c4c7f2f6..d3d1544c 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -16,7 +16,7 @@ use http::{ uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method, Uri, Version, }; -use request::RequestHead; +use message::RequestHead; use super::response::ClientResponse; use super::{pipeline, Connect, Connection, ConnectorError, SendRequestError}; diff --git a/src/client/response.rs b/src/client/response.rs index 56e13fa4..797c88df 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -1,6 +1,5 @@ -use std::cell::{Cell, Ref, RefCell, RefMut}; +use std::cell::RefCell; use std::fmt; -use std::rc::Rc; use bytes::Bytes; use futures::{Async, Poll, Stream}; @@ -8,16 +7,14 @@ use http::{HeaderMap, StatusCode, Version}; use body::PayloadStream; use error::PayloadError; -use extensions::Extensions; use httpmessage::HttpMessage; -use request::{Message, MessageFlags, MessagePool, RequestHead}; -use uri::Url; +use message::{MessageFlags, ResponseHead}; use super::pipeline::Payload; /// Client Response pub struct ClientResponse { - pub(crate) inner: Rc, + pub(crate) head: ResponseHead, pub(crate) payload: RefCell>, } @@ -25,7 +22,7 @@ impl HttpMessage for ClientResponse { type Stream = PayloadStream; fn headers(&self) -> &HeaderMap { - &self.inner.head.headers + &self.head.headers } #[inline] @@ -41,75 +38,50 @@ impl HttpMessage for ClientResponse { impl ClientResponse { /// Create new Request instance pub fn new() -> ClientResponse { - ClientResponse::with_pool(MessagePool::pool()) - } - - /// Create new Request instance with pool - pub(crate) fn with_pool(pool: &'static MessagePool) -> ClientResponse { ClientResponse { - inner: Rc::new(Message { - pool, - head: RequestHead::default(), - status: StatusCode::OK, - url: Url::default(), - flags: Cell::new(MessageFlags::empty()), - payload: RefCell::new(None), - extensions: RefCell::new(Extensions::new()), - }), + head: ResponseHead::default(), payload: RefCell::new(None), } } #[inline] - pub(crate) fn inner(&self) -> &Message { - self.inner.as_ref() + pub(crate) fn head(&self) -> &ResponseHead { + &self.head } #[inline] - pub(crate) fn inner_mut(&mut self) -> &mut Message { - Rc::get_mut(&mut self.inner).expect("Multiple copies exist") + pub(crate) fn head_mut(&mut self) -> &mut ResponseHead { + &mut self.head } /// Read the Request Version. #[inline] pub fn version(&self) -> Version { - self.inner().head.version + self.head().version.clone().unwrap() } /// Get the status from the server. #[inline] pub fn status(&self) -> StatusCode { - self.inner().status + self.head().status } #[inline] /// Returns Request's headers. pub fn headers(&self) -> &HeaderMap { - &self.inner().head.headers + &self.head().headers } #[inline] /// Returns mutable Request's headers. pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.inner_mut().head.headers + &mut self.head_mut().headers } /// Checks if a connection should be kept alive. #[inline] pub fn keep_alive(&self) -> bool { - self.inner().flags.get().contains(MessageFlags::KEEPALIVE) - } - - /// Request extensions - #[inline] - pub fn extensions(&self) -> Ref { - self.inner().extensions.borrow() - } - - /// Mutable reference to a the request's extensions - #[inline] - pub fn extensions_mut(&self) -> RefMut { - self.inner().extensions.borrow_mut() + self.head().flags.contains(MessageFlags::KEEPALIVE) } } @@ -126,14 +98,6 @@ impl Stream for ClientResponse { } } -impl Drop for ClientResponse { - fn drop(&mut self) { - if Rc::strong_count(&self.inner) == 1 { - self.inner.pool.release(self.inner.clone()); - } - } -} - impl fmt::Debug for ClientResponse { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; diff --git a/src/h1/client.rs b/src/h1/client.rs index 81a6f568..8d4051d3 100644 --- a/src/h1/client.rs +++ b/src/h1/client.rs @@ -4,7 +4,7 @@ use std::io::{self, Write}; use bytes::{BufMut, Bytes, BytesMut}; use tokio_codec::{Decoder, Encoder}; -use super::decoder::{PayloadDecoder, PayloadItem, PayloadType, ResponseDecoder}; +use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType}; use super::encoder::{RequestEncoder, ResponseLength}; use super::{Message, MessageType}; use body::{Binary, Body, BodyType}; @@ -16,7 +16,7 @@ use http::header::{ HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, }; use http::{Method, Version}; -use request::{MessagePool, RequestHead}; +use message::{MessagePool, RequestHead}; bitflags! { struct Flags: u8 { @@ -42,7 +42,7 @@ pub struct ClientPayloadCodec { struct ClientCodecInner { config: ServiceConfig, - decoder: ResponseDecoder, + decoder: MessageDecoder, payload: Option, version: Version, @@ -63,11 +63,6 @@ impl ClientCodec { /// /// `keepalive_enabled` how response `connection` header get generated. pub fn new(config: ServiceConfig) -> Self { - ClientCodec::with_pool(MessagePool::pool(), config) - } - - /// Create HTTP/1 codec with request's pool - pub(crate) fn with_pool(pool: &'static MessagePool, config: ServiceConfig) -> Self { let flags = if config.keep_alive_enabled() { Flags::KEEPALIVE_ENABLED } else { @@ -76,7 +71,7 @@ impl ClientCodec { ClientCodec { inner: ClientCodecInner { config, - decoder: ResponseDecoder::with_pool(pool), + decoder: MessageDecoder::default(), payload: None, version: Version::HTTP_11, @@ -185,10 +180,10 @@ impl Decoder for ClientCodec { debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set"); if let Some((req, payload)) = self.inner.decoder.decode(src)? { - self.inner - .flags - .set(Flags::HEAD, req.inner.head.method == Method::HEAD); - self.inner.version = req.inner.head.version; + // self.inner + // .flags + // .set(Flags::HEAD, req.head.method == Method::HEAD); + // self.inner.version = req.head.version; if self.inner.flags.contains(Flags::KEEPALIVE_ENABLED) { self.inner.flags.set(Flags::KEEPALIVE, req.keep_alive()); } diff --git a/src/h1/codec.rs b/src/h1/codec.rs index 57e84898..c1c6091d 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -5,7 +5,7 @@ use std::io::{self, Write}; use bytes::{BufMut, Bytes, BytesMut}; use tokio_codec::{Decoder, Encoder}; -use super::decoder::{PayloadDecoder, PayloadItem, PayloadType, RequestDecoder}; +use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType}; use super::encoder::{ResponseEncoder, ResponseLength}; use super::{Message, MessageType}; use body::{Binary, Body}; @@ -14,7 +14,7 @@ use error::ParseError; use helpers; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::{Method, Version}; -use request::{MessagePool, Request}; +use request::Request; use response::Response; bitflags! { @@ -32,7 +32,7 @@ const AVERAGE_HEADER_SIZE: usize = 30; /// HTTP/1 Codec pub struct Codec { config: ServiceConfig, - decoder: RequestDecoder, + decoder: MessageDecoder, payload: Option, version: Version, @@ -59,11 +59,6 @@ impl Codec { /// /// `keepalive_enabled` how response `connection` header get generated. pub fn new(config: ServiceConfig) -> Self { - Codec::with_pool(MessagePool::pool(), config) - } - - /// Create HTTP/1 codec with request's pool - pub(crate) fn with_pool(pool: &'static MessagePool, config: ServiceConfig) -> Self { let flags = if config.keep_alive_enabled() { Flags::KEEPALIVE_ENABLED } else { @@ -71,7 +66,7 @@ impl Codec { }; Codec { config, - decoder: RequestDecoder::with_pool(pool), + decoder: MessageDecoder::default(), payload: None, version: Version::HTTP_11, diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index 75bd0f7c..fe2aa707 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -1,3 +1,4 @@ +use std::marker::PhantomData; use std::{io, mem}; use bytes::{Bytes, BytesMut}; @@ -8,327 +9,305 @@ use tokio_codec::Decoder; use client::ClientResponse; use error::ParseError; use http::header::{HeaderName, HeaderValue}; -use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version}; -use request::{MessageFlags, MessagePool, Request}; -use uri::Url; +use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version}; +use message::MessageFlags; +use request::Request; const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; -/// Client request decoder -pub struct RequestDecoder(&'static MessagePool); - -/// Server response decoder -pub struct ResponseDecoder(&'static MessagePool); +/// Incoming messagd decoder +pub(crate) struct MessageDecoder(PhantomData); /// Incoming request type -pub enum PayloadType { +pub(crate) enum PayloadType { None, Payload(PayloadDecoder), Stream(PayloadDecoder), } -impl RequestDecoder { - pub(crate) fn with_pool(pool: &'static MessagePool) -> RequestDecoder { - RequestDecoder(pool) +impl Default for MessageDecoder { + fn default() -> Self { + MessageDecoder(PhantomData) } } -impl Default for RequestDecoder { - fn default() -> RequestDecoder { - RequestDecoder::with_pool(MessagePool::pool()) - } -} - -impl Decoder for RequestDecoder { - type Item = (Request, PayloadType); +impl Decoder for MessageDecoder { + type Item = (T, PayloadType); type Error = ParseError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - // Parse http message + T::decode(src) + } +} + +pub(crate) enum PayloadLength { + None, + Chunked, + Upgrade, + Length(u64), +} + +pub(crate) trait MessageTypeDecoder: Sized { + fn keep_alive(&mut self); + + fn headers_mut(&mut self) -> &mut HeaderMap; + + fn decode(src: &mut BytesMut) -> Result, ParseError>; + + fn process_headers( + &mut self, + slice: &Bytes, + version: Version, + raw_headers: &[HeaderIndex], + ) -> Result { + let mut ka = version != Version::HTTP_10; let mut has_upgrade = false; let mut chunked = false; let mut content_length = None; - let msg = { - // Unsafe: we read only this data only after httparse parses headers into. - // performance bump for pipeline benchmarks. - let mut headers: [HeaderIndex; MAX_HEADERS] = + { + let headers = self.headers_mut(); + + for idx in raw_headers.iter() { + if let Ok(name) = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]) + { + // Unsafe: httparse check header value for valid utf-8 + let value = unsafe { + HeaderValue::from_shared_unchecked( + slice.slice(idx.value.0, idx.value.1), + ) + }; + match name { + header::CONTENT_LENGTH => { + if let Ok(s) = value.to_str() { + if let Ok(len) = s.parse::() { + content_length = Some(len); + } else { + debug!("illegal Content-Length: {:?}", s); + return Err(ParseError::Header); + } + } else { + debug!("illegal Content-Length: {:?}", value); + return Err(ParseError::Header); + } + } + // transfer-encoding + header::TRANSFER_ENCODING => { + if let Ok(s) = value.to_str() { + chunked = s.to_lowercase().contains("chunked"); + } else { + return Err(ParseError::Header); + } + } + // connection keep-alive state + header::CONNECTION => { + ka = if let Ok(conn) = value.to_str() { + if version == Version::HTTP_10 + && conn.contains("keep-alive") + { + true + } else { + version == Version::HTTP_11 && !(conn + .contains("close") + || conn.contains("upgrade")) + } + } else { + false + } + } + header::UPGRADE => { + has_upgrade = true; + // check content-length, some clients (dart) + // sends "content-length: 0" with websocket upgrade + if let Ok(val) = value.to_str() { + if val == "websocket" { + content_length = None; + } + } + } + _ => (), + } + + headers.append(name, value); + } else { + return Err(ParseError::Header); + } + } + } + + if ka { + self.keep_alive(); + } + + if chunked { + Ok(PayloadLength::Chunked) + } else if let Some(len) = content_length { + Ok(PayloadLength::Length(len)) + } else if has_upgrade { + Ok(PayloadLength::Upgrade) + } else { + Ok(PayloadLength::None) + } + } +} + +impl MessageTypeDecoder for Request { + fn keep_alive(&mut self) { + self.inner_mut().flags.set(MessageFlags::KEEPALIVE); + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + self.headers_mut() + } + + fn decode(src: &mut BytesMut) -> Result, ParseError> { + // Unsafe: we read only this data only after httparse parses headers into. + // performance bump for pipeline benchmarks. + let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; + + let (len, method, uri, version, headers_len) = { + let mut parsed: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; - let (len, method, path, version, headers_len) = { - let mut parsed: [httparse::Header; MAX_HEADERS] = - unsafe { mem::uninitialized() }; - - let mut req = httparse::Request::new(&mut parsed); - match req.parse(src)? { - httparse::Status::Complete(len) => { - let method = Method::from_bytes(req.method.unwrap().as_bytes()) - .map_err(|_| ParseError::Method)?; - let path = Url::new(Uri::try_from(req.path.unwrap())?); - let version = if req.version.unwrap() == 1 { - Version::HTTP_11 - } else { - Version::HTTP_10 - }; - HeaderIndex::record(src, req.headers, &mut headers); - - (len, method, path, version, req.headers.len()) - } - httparse::Status::Partial => return Ok(None), - } - }; - - let slice = src.split_to(len).freeze(); - - // convert headers - let mut msg = MessagePool::get_request(self.0); - { - let inner = msg.inner_mut(); - inner - .flags - .get_mut() - .set(MessageFlags::KEEPALIVE, version != Version::HTTP_10); - - for idx in headers[..headers_len].iter() { - if let Ok(name) = - HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]) - { - // Unsafe: httparse check header value for valid utf-8 - let value = unsafe { - HeaderValue::from_shared_unchecked( - slice.slice(idx.value.0, idx.value.1), - ) - }; - match name { - header::CONTENT_LENGTH => { - if let Ok(s) = value.to_str() { - if let Ok(len) = s.parse::() { - content_length = Some(len); - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header); - } - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header); - } - } - // transfer-encoding - header::TRANSFER_ENCODING => { - if let Ok(s) = value.to_str() { - chunked = s.to_lowercase().contains("chunked"); - } else { - return Err(ParseError::Header); - } - } - // connection keep-alive state - header::CONNECTION => { - let ka = if let Ok(conn) = value.to_str() { - if version == Version::HTTP_10 - && conn.contains("keep-alive") - { - true - } else { - version == Version::HTTP_11 && !(conn - .contains("close") - || conn.contains("upgrade")) - } - } else { - false - }; - inner.flags.get_mut().set(MessageFlags::KEEPALIVE, ka); - } - header::UPGRADE => { - has_upgrade = true; - // check content-length, some clients (dart) - // sends "content-length: 0" with websocket upgrade - if let Ok(val) = value.to_str() { - if val == "websocket" { - content_length = None; - } - } - } - _ => (), - } - - inner.head.headers.append(name, value); + let mut req = httparse::Request::new(&mut parsed); + match req.parse(src)? { + httparse::Status::Complete(len) => { + let method = Method::from_bytes(req.method.unwrap().as_bytes()) + .map_err(|_| ParseError::Method)?; + let uri = Uri::try_from(req.path.unwrap())?; + let version = if req.version.unwrap() == 1 { + Version::HTTP_11 } else { - return Err(ParseError::Header); - } - } + Version::HTTP_10 + }; + HeaderIndex::record(src, req.headers, &mut headers); - inner.url = path; - inner.head.method = method; - inner.head.version = version; + (len, method, uri, version, req.headers.len()) + } + httparse::Status::Partial => return Ok(None), } - msg }; + // convert headers + let mut msg = Request::new(); + + let len = msg.process_headers( + &src.split_to(len).freeze(), + version, + &headers[..headers_len], + )?; + // https://tools.ietf.org/html/rfc7230#section-3.3.3 - let decoder = if chunked { - // Chunked encoding - PayloadType::Payload(PayloadDecoder::chunked()) - } else if let Some(len) = content_length { - // Content-Length - PayloadType::Payload(PayloadDecoder::length(len)) - } else if has_upgrade || msg.inner.head.method == Method::CONNECT { - // upgrade(websocket) or connect - PayloadType::Stream(PayloadDecoder::eof()) - } else if src.len() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - return Err(ParseError::TooLarge); - } else { - PayloadType::None + let decoder = match len { + PayloadLength::Chunked => { + // Chunked encoding + PayloadType::Payload(PayloadDecoder::chunked()) + } + PayloadLength::Length(len) => { + // Content-Length + PayloadType::Payload(PayloadDecoder::length(len)) + } + PayloadLength::Upgrade => { + // upgrade(websocket) or connect + PayloadType::Stream(PayloadDecoder::eof()) + } + PayloadLength::None => { + if method == Method::CONNECT { + // upgrade(websocket) or connect + PayloadType::Stream(PayloadDecoder::eof()) + } else if src.len() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + PayloadType::None + } + } }; + { + let inner = msg.inner_mut(); + inner.url.update(&uri); + inner.head.uri = uri; + inner.head.method = method; + inner.head.version = version; + } + Ok(Some((msg, decoder))) } } -impl ResponseDecoder { - pub(crate) fn with_pool(pool: &'static MessagePool) -> ResponseDecoder { - ResponseDecoder(pool) +impl MessageTypeDecoder for ClientResponse { + fn keep_alive(&mut self) { + self.head.flags.insert(MessageFlags::KEEPALIVE); } -} -impl Default for ResponseDecoder { - fn default() -> ResponseDecoder { - ResponseDecoder::with_pool(MessagePool::pool()) + fn headers_mut(&mut self) -> &mut HeaderMap { + self.headers_mut() } -} -impl Decoder for ResponseDecoder { - type Item = (ClientResponse, PayloadType); - type Error = ParseError; + fn decode(src: &mut BytesMut) -> Result, ParseError> { + // Unsafe: we read only this data only after httparse parses headers into. + // performance bump for pipeline benchmarks. + let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - // Parse http message - let mut chunked = false; - let mut content_length = None; - - let msg = { - // Unsafe: we read only this data only after httparse parses headers into. - // performance bump for pipeline benchmarks. - let mut headers: [HeaderIndex; MAX_HEADERS] = + let (len, version, status, headers_len) = { + let mut parsed: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; - let (len, version, status, headers_len) = { - let mut parsed: [httparse::Header; MAX_HEADERS] = - unsafe { mem::uninitialized() }; - - let mut res = httparse::Response::new(&mut parsed); - match res.parse(src)? { - httparse::Status::Complete(len) => { - let version = if res.version.unwrap() == 1 { - Version::HTTP_11 - } else { - Version::HTTP_10 - }; - let status = StatusCode::from_u16(res.code.unwrap()) - .map_err(|_| ParseError::Status)?; - HeaderIndex::record(src, res.headers, &mut headers); - - (len, version, status, res.headers.len()) - } - httparse::Status::Partial => return Ok(None), - } - }; - - let slice = src.split_to(len).freeze(); - - // convert headers - let mut msg = MessagePool::get_response(self.0); - { - let inner = msg.inner_mut(); - inner - .flags - .get_mut() - .set(MessageFlags::KEEPALIVE, version != Version::HTTP_10); - - for idx in headers[..headers_len].iter() { - if let Ok(name) = - HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]) - { - // Unsafe: httparse check header value for valid utf-8 - let value = unsafe { - HeaderValue::from_shared_unchecked( - slice.slice(idx.value.0, idx.value.1), - ) - }; - match name { - header::CONTENT_LENGTH => { - if let Ok(s) = value.to_str() { - if let Ok(len) = s.parse::() { - content_length = Some(len); - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header); - } - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header); - } - } - // transfer-encoding - header::TRANSFER_ENCODING => { - if let Ok(s) = value.to_str() { - chunked = s.to_lowercase().contains("chunked"); - } else { - return Err(ParseError::Header); - } - } - // connection keep-alive state - header::CONNECTION => { - let ka = if let Ok(conn) = value.to_str() { - if version == Version::HTTP_10 - && conn.contains("keep-alive") - { - true - } else { - version == Version::HTTP_11 && !(conn - .contains("close") - || conn.contains("upgrade")) - } - } else { - false - }; - inner.flags.get_mut().set(MessageFlags::KEEPALIVE, ka); - } - _ => (), - } - - inner.head.headers.append(name, value); + let mut res = httparse::Response::new(&mut parsed); + match res.parse(src)? { + httparse::Status::Complete(len) => { + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 } else { - return Err(ParseError::Header); - } - } + Version::HTTP_10 + }; + let status = StatusCode::from_u16(res.code.unwrap()) + .map_err(|_| ParseError::Status)?; + HeaderIndex::record(src, res.headers, &mut headers); - inner.status = status; - inner.head.version = version; + (len, version, status, res.headers.len()) + } + httparse::Status::Partial => return Ok(None), } - msg }; + let mut msg = ClientResponse::new(); + + // convert headers + let len = msg.process_headers( + &src.split_to(len).freeze(), + version, + &headers[..headers_len], + )?; + // https://tools.ietf.org/html/rfc7230#section-3.3.3 - let decoder = if chunked { - // Chunked encoding - PayloadType::Payload(PayloadDecoder::chunked()) - } else if let Some(len) = content_length { - // Content-Length - PayloadType::Payload(PayloadDecoder::length(len)) - } else if msg.inner.status == StatusCode::SWITCHING_PROTOCOLS - || msg.inner.head.method == Method::CONNECT - { - // switching protocol or connect - PayloadType::Stream(PayloadDecoder::eof()) - } else if src.len() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - return Err(ParseError::TooLarge); - } else { - PayloadType::None + let decoder = match len { + PayloadLength::Chunked => { + // Chunked encoding + PayloadType::Payload(PayloadDecoder::chunked()) + } + PayloadLength::Length(len) => { + // Content-Length + PayloadType::Payload(PayloadDecoder::length(len)) + } + _ => { + if status == StatusCode::SWITCHING_PROTOCOLS { + // switching protocol or connect + PayloadType::Stream(PayloadDecoder::eof()) + } else if src.len() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + PayloadType::None + } + } }; + msg.head.status = status; + msg.head.version = Some(version); + Ok(Some((msg, decoder))) } } @@ -690,7 +669,7 @@ mod tests { macro_rules! parse_ready { ($e:expr) => {{ - match RequestDecoder::default().decode($e) { + match MessageDecoder::::default().decode($e) { Ok(Some((msg, _))) => msg, Ok(_) => unreachable!("Eof during parsing http request"), Err(err) => unreachable!("Error during parsing http request: {:?}", err), @@ -700,7 +679,7 @@ mod tests { macro_rules! expect_parse_err { ($e:expr) => {{ - match RequestDecoder::default().decode($e) { + match MessageDecoder::::default().decode($e) { Err(err) => match err { ParseError::Io(_) => unreachable!("Parse error expected"), _ => (), @@ -779,7 +758,7 @@ mod tests { fn test_parse() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); match reader.decode(&mut buf) { Ok(Some((req, _))) => { assert_eq!(req.version(), Version::HTTP_11); @@ -794,7 +773,7 @@ mod tests { fn test_parse_partial() { let mut buf = BytesMut::from("PUT /test HTTP/1"); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b".1\r\n\r\n"); @@ -808,7 +787,7 @@ mod tests { fn test_parse_post() { let mut buf = BytesMut::from("POST /test2 HTTP/1.0\r\n\r\n"); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(req.version(), Version::HTTP_10); assert_eq!(*req.method(), Method::POST); @@ -820,7 +799,7 @@ mod tests { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); let mut pl = pl.unwrap(); assert_eq!(req.version(), Version::HTTP_11); @@ -837,7 +816,7 @@ mod tests { let mut buf = BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); let mut pl = pl.unwrap(); assert_eq!(req.version(), Version::HTTP_11); @@ -852,7 +831,7 @@ mod tests { #[test] fn test_parse_partial_eof() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b"\r\n"); @@ -866,7 +845,7 @@ mod tests { fn test_headers_split_field() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); assert!{ reader.decode(&mut buf).unwrap().is_none() } buf.extend(b"t"); @@ -890,7 +869,7 @@ mod tests { Set-Cookie: c1=cookie1\r\n\ Set-Cookie: c2=cookie2\r\n\r\n", ); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); let val: Vec<_> = req @@ -1090,7 +1069,7 @@ mod tests { upgrade: websocket\r\n\r\n\ some raw data", ); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); assert!(!req.keep_alive()); assert!(req.upgrade()); @@ -1139,7 +1118,7 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); let mut pl = pl.unwrap(); assert!(req.chunked().unwrap()); @@ -1162,7 +1141,7 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); let mut pl = pl.unwrap(); assert!(req.chunked().unwrap()); @@ -1193,7 +1172,7 @@ mod tests { transfer-encoding: chunked\r\n\r\n", ); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); let mut pl = pl.unwrap(); assert!(req.chunked().unwrap()); @@ -1238,7 +1217,7 @@ mod tests { transfer-encoding: chunked\r\n\r\n"[..], ); - let mut reader = RequestDecoder::default(); + let mut reader = MessageDecoder::::default(); let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); let mut pl = pl.unwrap(); assert!(msg.chunked().unwrap()); diff --git a/src/h1/encoder.rs b/src/h1/encoder.rs index 7527eeea..de45351d 100644 --- a/src/h1/encoder.rs +++ b/src/h1/encoder.rs @@ -11,7 +11,8 @@ use http::{StatusCode, Version}; use body::{Binary, Body}; use header::ContentEncoding; use http::Method; -use request::{Request, RequestHead}; +use message::RequestHead; +use request::Request; use response::Response; #[derive(Debug)] diff --git a/src/h1/mod.rs b/src/h1/mod.rs index 81838700..395d6619 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -11,7 +11,6 @@ mod service; pub use self::client::{ClientCodec, ClientPayloadCodec}; pub use self::codec::Codec; -pub use self::decoder::{PayloadDecoder, RequestDecoder}; pub use self::dispatcher::Dispatcher; pub use self::service::{H1Service, H1ServiceHandler, OneRequest}; diff --git a/src/lib.rs b/src/lib.rs index 32369d16..f64876e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,6 +117,7 @@ mod header; mod httpcodes; mod httpmessage; mod json; +mod message; mod payload; mod request; mod response; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 00000000..b38c2f80 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,163 @@ +use std::cell::{Cell, RefCell}; +use std::collections::VecDeque; +use std::rc::Rc; + +use http::{HeaderMap, Method, StatusCode, Uri, Version}; + +use extensions::Extensions; +use payload::Payload; +use uri::Url; + +#[doc(hidden)] +pub trait Head: Default + 'static { + fn clear(&mut self); + + fn pool() -> &'static MessagePool; +} + +bitflags! { + pub(crate) struct MessageFlags: u8 { + const KEEPALIVE = 0b0000_0001; + } +} + +pub struct RequestHead { + pub uri: Uri, + pub method: Method, + pub version: Version, + pub headers: HeaderMap, + pub(crate) flags: MessageFlags, +} + +impl Default for RequestHead { + fn default() -> RequestHead { + RequestHead { + uri: Uri::default(), + method: Method::default(), + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + flags: MessageFlags::empty(), + } + } +} + +impl Head for RequestHead { + fn clear(&mut self) { + self.headers.clear(); + self.flags = MessageFlags::empty(); + } + + fn pool() -> &'static MessagePool { + REQUEST_POOL.with(|p| *p) + } +} + +pub struct ResponseHead { + pub version: Option, + pub status: StatusCode, + pub headers: HeaderMap, + pub reason: Option<&'static str>, + pub(crate) flags: MessageFlags, +} + +impl Default for ResponseHead { + fn default() -> ResponseHead { + ResponseHead { + version: None, + status: StatusCode::OK, + headers: HeaderMap::with_capacity(16), + reason: None, + flags: MessageFlags::empty(), + } + } +} + +impl Head for ResponseHead { + fn clear(&mut self) { + self.headers.clear(); + self.flags = MessageFlags::empty(); + } + + fn pool() -> &'static MessagePool { + RESPONSE_POOL.with(|p| *p) + } +} + +pub struct Message { + pub head: T, + pub url: Url, + pub status: StatusCode, + pub extensions: RefCell, + pub payload: RefCell>, + pub(crate) pool: &'static MessagePool, + pub(crate) flags: Cell, +} + +impl Message { + #[inline] + /// Reset request instance + pub fn reset(&mut self) { + self.head.clear(); + self.extensions.borrow_mut().clear(); + self.flags.set(MessageFlags::empty()); + *self.payload.borrow_mut() = None; + } +} + +impl Default for Message { + fn default() -> Self { + Message { + pool: T::pool(), + url: Url::default(), + head: T::default(), + status: StatusCode::OK, + flags: Cell::new(MessageFlags::empty()), + payload: RefCell::new(None), + extensions: RefCell::new(Extensions::new()), + } + } +} + +#[doc(hidden)] +/// Request's objects pool +pub struct MessagePool(RefCell>>>); + +thread_local!(static REQUEST_POOL: &'static MessagePool = MessagePool::::create()); +thread_local!(static RESPONSE_POOL: &'static MessagePool = MessagePool::::create()); + +impl MessagePool { + /// Get default request's pool + pub fn pool() -> &'static MessagePool { + REQUEST_POOL.with(|p| *p) + } + + /// Get Request object + #[inline] + pub fn get_message() -> Rc> { + REQUEST_POOL.with(|pool| { + if let Some(mut msg) = pool.0.borrow_mut().pop_front() { + if let Some(r) = Rc::get_mut(&mut msg) { + r.reset(); + } + return msg; + } + Rc::new(Message::default()) + }) + } +} + +impl MessagePool { + fn create() -> &'static MessagePool { + let pool = MessagePool(RefCell::new(VecDeque::with_capacity(128))); + Box::leak(Box::new(pool)) + } + + #[inline] + /// Release request instance + pub(crate) fn release(&self, msg: Rc>) { + let v = &mut self.0.borrow_mut(); + if v.len() < 128 { + v.push_front(msg); + } + } +} diff --git a/src/request.rs b/src/request.rs index edad4d50..1e191047 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,65 +1,18 @@ -use std::cell::{Cell, Ref, RefCell, RefMut}; -use std::collections::VecDeque; +use std::cell::{Ref, RefMut}; use std::fmt; use std::rc::Rc; -use http::{header, HeaderMap, Method, StatusCode, Uri, Version}; +use http::{header, HeaderMap, Method, Uri, Version}; -use client::ClientResponse; use extensions::Extensions; use httpmessage::HttpMessage; use payload::Payload; -use uri::Url; -bitflags! { - pub(crate) struct MessageFlags: u8 { - const KEEPALIVE = 0b0000_0001; - const CONN_INFO = 0b0000_0010; - } -} +use message::{Message, MessageFlags, MessagePool, RequestHead}; /// Request pub struct Request { - pub(crate) inner: Rc, -} - -pub struct RequestHead { - pub uri: Uri, - pub method: Method, - pub version: Version, - pub headers: HeaderMap, -} - -impl Default for RequestHead { - fn default() -> RequestHead { - RequestHead { - uri: Uri::default(), - method: Method::default(), - version: Version::HTTP_11, - headers: HeaderMap::with_capacity(16), - } - } -} - -pub struct Message { - pub head: RequestHead, - pub url: Url, - pub status: StatusCode, - pub extensions: RefCell, - pub payload: RefCell>, - pub(crate) pool: &'static MessagePool, - pub(crate) flags: Cell, -} - -impl Message { - #[inline] - /// Reset request instance - pub fn reset(&mut self) { - self.head.clear(); - self.extensions.borrow_mut().clear(); - self.flags.set(MessageFlags::empty()); - *self.payload.borrow_mut() = None; - } + pub(crate) inner: Rc>, } impl HttpMessage for Request { @@ -82,33 +35,35 @@ impl HttpMessage for Request { impl Request { /// Create new Request instance pub fn new() -> Request { - Request::with_pool(MessagePool::pool()) - } - - /// Create new Request instance with pool - pub(crate) fn with_pool(pool: &'static MessagePool) -> Request { Request { - inner: Rc::new(Message { - pool, - url: Url::default(), - head: RequestHead::default(), - status: StatusCode::OK, - flags: Cell::new(MessageFlags::empty()), - payload: RefCell::new(None), - extensions: RefCell::new(Extensions::new()), - }), + inner: MessagePool::get_message(), } } + // /// Create new Request instance with pool + // pub(crate) fn with_pool(pool: &'static MessagePool) -> Request { + // Request { + // inner: Rc::new(Message { + // pool, + // url: Url::default(), + // head: RequestHead::default(), + // status: StatusCode::OK, + // flags: Cell::new(MessageFlags::empty()), + // payload: RefCell::new(None), + // extensions: RefCell::new(Extensions::new()), + // }), + // } + // } + #[inline] #[doc(hidden)] - pub fn inner(&self) -> &Message { + pub fn inner(&self) -> &Message { self.inner.as_ref() } #[inline] #[doc(hidden)] - pub fn inner_mut(&mut self) -> &mut Message { + pub fn inner_mut(&mut self) -> &mut Message { Rc::get_mut(&mut self.inner).expect("Multiple copies exist") } @@ -139,7 +94,11 @@ impl Request { /// The target path of this Request. #[inline] pub fn path(&self) -> &str { - self.inner().url.path() + if let Some(path) = self.inner().url.path() { + path + } else { + self.inner().head.uri.path() + } } #[inline] @@ -219,56 +178,3 @@ impl fmt::Debug for Request { Ok(()) } } - -/// Request's objects pool -pub(crate) struct MessagePool(RefCell>>); - -thread_local!(static POOL: &'static MessagePool = MessagePool::create()); - -impl MessagePool { - fn create() -> &'static MessagePool { - let pool = MessagePool(RefCell::new(VecDeque::with_capacity(128))); - Box::leak(Box::new(pool)) - } - - /// Get default request's pool - pub fn pool() -> &'static MessagePool { - POOL.with(|p| *p) - } - - /// Get Request object - #[inline] - pub fn get_request(pool: &'static MessagePool) -> Request { - if let Some(mut msg) = pool.0.borrow_mut().pop_front() { - if let Some(r) = Rc::get_mut(&mut msg) { - r.reset(); - } - return Request { inner: msg }; - } - Request::with_pool(pool) - } - - /// Get Client Response object - #[inline] - pub fn get_response(pool: &'static MessagePool) -> ClientResponse { - if let Some(mut msg) = pool.0.borrow_mut().pop_front() { - if let Some(r) = Rc::get_mut(&mut msg) { - r.reset(); - } - return ClientResponse { - inner: msg, - payload: RefCell::new(None), - }; - } - ClientResponse::with_pool(pool) - } - - #[inline] - /// Release request instance - pub(crate) fn release(&self, msg: Rc) { - let v = &mut self.0.borrow_mut(); - if v.len() < 128 { - v.push_front(msg); - } - } -} diff --git a/src/test.rs b/src/test.rs index b0b8dc83..43974934 100644 --- a/src/test.rs +++ b/src/test.rs @@ -392,7 +392,7 @@ impl TestRequest { { let inner = req.inner_mut(); inner.head.method = method; - inner.url = InnerUrl::new(uri); + inner.url = InnerUrl::new(&uri); inner.head.version = version; inner.head.headers = headers; *inner.payload.borrow_mut() = payload; diff --git a/src/uri.rs b/src/uri.rs index 6edd220c..89f6d3b1 100644 --- a/src/uri.rs +++ b/src/uri.rs @@ -37,27 +37,22 @@ lazy_static! { #[derive(Default, Clone, Debug)] pub struct Url { - uri: Uri, path: Option>, } impl Url { - pub fn new(uri: Uri) -> Url { + pub fn new(uri: &Uri) -> Url { let path = DEFAULT_QUOTER.requote(uri.path().as_bytes()); - Url { uri, path } + Url { path } } - pub fn uri(&self) -> &Uri { - &self.uri + pub(crate) fn update(&mut self, uri: &Uri) { + self.path = DEFAULT_QUOTER.requote(uri.path().as_bytes()); } - pub fn path(&self) -> &str { - if let Some(ref s) = self.path { - s - } else { - self.uri.path() - } + pub fn path(&self) -> Option<&str> { + self.path.as_ref().map(|s| s.as_str()) } }