diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index 63551cffa..e1d8421e9 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -49,7 +49,10 @@ where .and_then(|(item, framed)| { if let Some(res) = item { match framed.get_codec().message_type() { - h1::MessageType::None => release_connection(framed), + h1::MessageType::None => { + let force_close = !framed.get_codec().keepalive(); + release_connection(framed, force_close) + } _ => { *res.payload.borrow_mut() = Some(Payload::stream(framed)) } @@ -174,7 +177,9 @@ impl Stream for Payload { Async::Ready(Some(chunk)) => if let Some(chunk) = chunk { Ok(Async::Ready(Some(chunk))) } else { - release_connection(self.framed.take().unwrap()); + let framed = self.framed.take().unwrap(); + let force_close = framed.get_codec().keepalive(); + release_connection(framed, force_close); Ok(Async::Ready(None)) }, Async::Ready(None) => Ok(Async::Ready(None)), @@ -182,12 +187,12 @@ impl Stream for Payload { } } -fn release_connection(framed: Framed) +fn release_connection(framed: Framed, force_close: bool) where T: Connection, { let mut parts = framed.into_parts(); - if parts.read_buf.is_empty() && parts.write_buf.is_empty() { + if !force_close && parts.read_buf.is_empty() && parts.write_buf.is_empty() { parts.io.release() } else { parts.io.close() diff --git a/src/h1/client.rs b/src/h1/client.rs index e2d1eefe6..7704ba97a 100644 --- a/src/h1/client.rs +++ b/src/h1/client.rs @@ -4,8 +4,8 @@ use std::io::{self, Write}; use bytes::{BufMut, Bytes, BytesMut}; use tokio_codec::{Decoder, Encoder}; -use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType}; -use super::encoder::RequestEncoder; +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; +use super::{decoder, encoder}; use super::{Message, MessageType}; use body::BodyLength; use client::ClientResponse; @@ -16,13 +16,11 @@ use http::header::{ HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, }; use http::{Method, Version}; -use message::{Head, MessagePool, RequestHead}; +use message::{ConnectionType, Head, MessagePool, RequestHead}; bitflags! { struct Flags: u8 { const HEAD = 0b0000_0001; - const UPGRADE = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; const KEEPALIVE_ENABLED = 0b0000_1000; const STREAM = 0b0001_0000; } @@ -42,14 +40,15 @@ pub struct ClientPayloadCodec { struct ClientCodecInner { config: ServiceConfig, - decoder: MessageDecoder, + decoder: decoder::MessageDecoder, payload: Option, version: Version, + ctype: ConnectionType, // encoder part flags: Flags, headers_size: u32, - te: RequestEncoder, + encoder: encoder::MessageEncoder, } impl Default for ClientCodec { @@ -71,25 +70,26 @@ impl ClientCodec { ClientCodec { inner: ClientCodecInner { config, - decoder: MessageDecoder::default(), + decoder: decoder::MessageDecoder::default(), payload: None, version: Version::HTTP_11, + ctype: ConnectionType::Close, flags, headers_size: 0, - te: RequestEncoder::default(), + encoder: encoder::MessageEncoder::default(), }, } } /// Check if request is upgrade pub fn upgrade(&self) -> bool { - self.inner.flags.contains(Flags::UPGRADE) + self.inner.ctype == ConnectionType::Upgrade } /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { - self.inner.flags.contains(Flags::KEEPALIVE) + self.inner.ctype == ConnectionType::KeepAlive } /// Check last request's message type @@ -103,15 +103,6 @@ impl ClientCodec { } } - /// prepare transfer encoding - pub fn prepare_te(&mut self, head: &mut RequestHead, length: BodyLength) { - self.inner.te.update( - head, - self.inner.flags.contains(Flags::HEAD), - self.inner.version, - ); - } - /// Convert message codec to a payload codec pub fn into_payload_codec(self) -> ClientPayloadCodec { ClientPayloadCodec { inner: self.inner } @@ -119,96 +110,17 @@ impl ClientCodec { } impl ClientPayloadCodec { + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.inner.ctype == ConnectionType::KeepAlive + } + /// Transform payload codec to a message codec pub fn into_message_codec(self) -> ClientCodec { ClientCodec { inner: self.inner } } } -fn prn_version(ver: Version) -> &'static str { - match ver { - Version::HTTP_09 => "HTTP/0.9", - Version::HTTP_10 => "HTTP/1.0", - Version::HTTP_11 => "HTTP/1.1", - Version::HTTP_2 => "HTTP/2.0", - } -} - -impl ClientCodecInner { - fn encode_request( - &mut self, - msg: RequestHead, - length: BodyLength, - buffer: &mut BytesMut, - ) -> io::Result<()> { - // render message - { - // status line - write!( - Writer(buffer), - "{} {} {}\r\n", - msg.method, - msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), - prn_version(msg.version) - ).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - - // write headers - buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE); - - // content length - match length { - BodyLength::Sized(len) => helpers::write_content_length(len, buffer), - BodyLength::Sized64(len) => { - buffer.extend_from_slice(b"content-length: "); - write!(buffer.writer(), "{}", len)?; - buffer.extend_from_slice(b"\r\n"); - } - BodyLength::Chunked => { - buffer.extend_from_slice(b"transfer-encoding: chunked\r\n") - } - BodyLength::Empty => buffer.extend_from_slice(b"content-length: 0\r\n"), - BodyLength::None | BodyLength::Stream => (), - } - - let mut has_date = false; - - for (key, value) in &msg.headers { - match *key { - TRANSFER_ENCODING | CONNECTION | CONTENT_LENGTH => continue, - DATE => has_date = true, - _ => (), - } - - buffer.put_slice(key.as_ref()); - buffer.put_slice(b": "); - buffer.put_slice(value.as_ref()); - buffer.put_slice(b"\r\n"); - } - - // Connection header - if msg.upgrade() { - self.flags.set(Flags::UPGRADE, msg.upgrade()); - buffer.extend_from_slice(b"connection: upgrade\r\n"); - } else if msg.keep_alive() { - if self.version < Version::HTTP_11 { - buffer.extend_from_slice(b"connection: keep-alive\r\n"); - } - } else if self.version >= Version::HTTP_11 { - buffer.extend_from_slice(b"connection: close\r\n"); - } - - // Date header - if !has_date { - self.config.set_date(buffer); - } else { - buffer.extend_from_slice(b"\r\n"); - } - } - - Ok(()) - } -} - impl Decoder for ClientCodec { type Item = ClientResponse; type Error = ParseError; @@ -217,21 +129,27 @@ impl Decoder for ClientCodec { debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set"); if let Some((req, payload)) = self.inner.decoder.decode(src)? { - // self.inner - // .flags - // .set(Flags::HEAD, req.head.method == Method::HEAD); - // self.inner.version = req.head.version; - if self.inner.flags.contains(Flags::KEEPALIVE_ENABLED) { - self.inner.flags.set(Flags::KEEPALIVE, req.keep_alive()); + if let Some(ctype) = req.head().ctype { + // do not use peer's keep-alive + self.inner.ctype = if ctype == ConnectionType::KeepAlive { + self.inner.ctype + } else { + ctype + }; } - match payload { - PayloadType::None => self.inner.payload = None, - PayloadType::Payload(pl) => self.inner.payload = Some(pl), - PayloadType::Stream(pl) => { - self.inner.payload = Some(pl); - self.inner.flags.insert(Flags::STREAM); + + if !self.inner.flags.contains(Flags::HEAD) { + match payload { + PayloadType::None => self.inner.payload = None, + PayloadType::Payload(pl) => self.inner.payload = Some(pl), + PayloadType::Stream(pl) => { + self.inner.payload = Some(pl); + self.inner.flags.insert(Flags::STREAM); + } } - }; + } else { + self.inner.payload = None; + } Ok(Some(req)) } else { Ok(None) @@ -270,14 +188,39 @@ impl Encoder for ClientCodec { dst: &mut BytesMut, ) -> Result<(), Self::Error> { match item { - Message::Item((msg, btype)) => { - self.inner.encode_request(msg, btype, dst)?; + Message::Item((mut msg, length)) => { + let inner = &mut self.inner; + inner.version = msg.version; + inner.flags.set(Flags::HEAD, msg.method == Method::HEAD); + + // connection status + inner.ctype = match msg.connection_type() { + ConnectionType::KeepAlive => { + if inner.flags.contains(Flags::KEEPALIVE_ENABLED) { + ConnectionType::KeepAlive + } else { + ConnectionType::Close + } + } + ConnectionType::Upgrade => ConnectionType::Upgrade, + ConnectionType::Close => ConnectionType::Close, + }; + + inner.encoder.encode( + dst, + &mut msg, + false, + inner.version, + length, + inner.ctype, + &inner.config, + )?; } Message::Chunk(Some(bytes)) => { - self.inner.te.encode(bytes.as_ref(), dst)?; + self.inner.encoder.encode_chunk(bytes.as_ref(), dst)?; } Message::Chunk(None) => { - self.inner.te.encode_eof(dst)?; + self.inner.encoder.encode_eof(dst)?; } } Ok(()) diff --git a/src/h1/codec.rs b/src/h1/codec.rs index 117c8cde1..f9f455e5d 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -5,8 +5,8 @@ use std::io::{self, Write}; use bytes::{BufMut, Bytes, BytesMut}; use tokio_codec::{Decoder, Encoder}; -use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType}; -use super::encoder::ResponseEncoder; +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; +use super::{decoder, encoder}; use super::{Message, MessageType}; use body::BodyLength; use config::ServiceConfig; @@ -14,15 +14,13 @@ use error::ParseError; use helpers; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::{Method, StatusCode, Version}; -use message::{Head, ResponseHead}; +use message::{ConnectionType, Head, ResponseHead}; use request::Request; use response::Response; bitflags! { struct Flags: u8 { const HEAD = 0b0000_0001; - const UPGRADE = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; const KEEPALIVE_ENABLED = 0b0000_1000; const STREAM = 0b0001_0000; } @@ -33,14 +31,15 @@ const AVERAGE_HEADER_SIZE: usize = 30; /// HTTP/1 Codec pub struct Codec { config: ServiceConfig, - decoder: MessageDecoder, + decoder: decoder::MessageDecoder, payload: Option, version: Version, + ctype: ConnectionType, // encoder part flags: Flags, headers_size: u32, - te: ResponseEncoder, + encoder: encoder::MessageEncoder>, } impl Default for Codec { @@ -67,24 +66,25 @@ impl Codec { }; Codec { config, - decoder: MessageDecoder::default(), + decoder: decoder::MessageDecoder::default(), payload: None, version: Version::HTTP_11, + ctype: ConnectionType::Close, flags, headers_size: 0, - te: ResponseEncoder::default(), + encoder: encoder::MessageEncoder::default(), } } /// Check if request is upgrade pub fn upgrade(&self) -> bool { - self.flags.contains(Flags::UPGRADE) + self.ctype == ConnectionType::Upgrade } /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { - self.flags.contains(Flags::KEEPALIVE) + self.ctype == ConnectionType::KeepAlive } /// Check last request's message type @@ -97,130 +97,6 @@ impl Codec { MessageType::Payload } } - - /// prepare transfer encoding - fn prepare_te(&mut self, head: &mut ResponseHead, length: BodyLength) { - self.te - .update(head, self.flags.contains(Flags::HEAD), self.version, length); - } - - fn encode_response( - &mut self, - msg: &mut ResponseHead, - length: BodyLength, - buffer: &mut BytesMut, - ) -> io::Result<()> { - msg.version = self.version; - - // Connection upgrade - if msg.upgrade() { - self.flags.insert(Flags::UPGRADE); - self.flags.remove(Flags::KEEPALIVE); - msg.headers - .insert(CONNECTION, HeaderValue::from_static("upgrade")); - } - // keep-alive - else if self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg.keep_alive() { - self.flags.insert(Flags::KEEPALIVE); - if self.version < Version::HTTP_11 { - msg.headers - .insert(CONNECTION, HeaderValue::from_static("keep-alive")); - } - } else if self.version >= Version::HTTP_11 { - self.flags.remove(Flags::KEEPALIVE); - msg.headers - .insert(CONNECTION, HeaderValue::from_static("close")); - } - - // render message - { - let reason = msg.reason().as_bytes(); - buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE + reason.len()); - - // status line - helpers::write_status_line(self.version, msg.status.as_u16(), buffer); - buffer.extend_from_slice(reason); - - // content length - match msg.status { - StatusCode::NO_CONTENT - | StatusCode::CONTINUE - | StatusCode::SWITCHING_PROTOCOLS - | StatusCode::PROCESSING => buffer.extend_from_slice(b"\r\n"), - _ => match length { - BodyLength::Chunked => { - buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") - } - BodyLength::Empty => { - buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n"); - } - BodyLength::Sized(len) => helpers::write_content_length(len, buffer), - BodyLength::Sized64(len) => { - buffer.extend_from_slice(b"\r\ncontent-length: "); - write!(buffer.writer(), "{}", len)?; - buffer.extend_from_slice(b"\r\n"); - } - BodyLength::None | BodyLength::Stream => { - buffer.extend_from_slice(b"\r\n") - } - }, - } - - // write headers - let mut pos = 0; - let mut has_date = false; - let mut remaining = buffer.remaining_mut(); - let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) }; - for (key, value) in &msg.headers { - match *key { - TRANSFER_ENCODING | CONTENT_LENGTH => continue, - DATE => { - has_date = true; - } - _ => (), - } - - let v = value.as_ref(); - let k = key.as_str().as_bytes(); - let len = k.len() + v.len() + 4; - if len > remaining { - unsafe { - buffer.advance_mut(pos); - } - pos = 0; - buffer.reserve(len); - remaining = buffer.remaining_mut(); - unsafe { - buf = &mut *(buffer.bytes_mut() as *mut _); - } - } - - buf[pos..pos + k.len()].copy_from_slice(k); - pos += k.len(); - buf[pos..pos + 2].copy_from_slice(b": "); - pos += 2; - buf[pos..pos + v.len()].copy_from_slice(v); - pos += v.len(); - buf[pos..pos + 2].copy_from_slice(b"\r\n"); - pos += 2; - remaining -= len; - } - unsafe { - buffer.advance_mut(pos); - } - - // optimized date header, set_date writes \r\n - if !has_date { - self.config.set_date(buffer); - } else { - // msg eof - buffer.extend_from_slice(b"\r\n"); - } - self.headers_size = buffer.len() as u32; - } - - Ok(()) - } } impl Decoder for Codec { @@ -240,9 +116,12 @@ impl Decoder for Codec { } else if let Some((req, payload)) = self.decoder.decode(src)? { self.flags .set(Flags::HEAD, req.inner.head.method == Method::HEAD); - self.version = req.inner.head.version; - if self.flags.contains(Flags::KEEPALIVE_ENABLED) { - self.flags.set(Flags::KEEPALIVE, req.keep_alive()); + self.version = req.inner().head.version; + self.ctype = req.inner().head.connection_type(); + if self.ctype == ConnectionType::KeepAlive + && !self.flags.contains(Flags::KEEPALIVE_ENABLED) + { + self.ctype = ConnectionType::Close } match payload { PayloadType::None => self.payload = None, @@ -270,14 +149,35 @@ impl Encoder for Codec { ) -> Result<(), Self::Error> { match item { Message::Item((mut res, length)) => { - self.prepare_te(res.head_mut(), length); - self.encode_response(res.head_mut(), length, dst)?; + // connection status + self.ctype = if let Some(ct) = res.head().ctype { + if ct == ConnectionType::KeepAlive { + self.ctype + } else { + ct + } + } else { + self.ctype + }; + + // encode message + let len = dst.len(); + self.encoder.encode( + dst, + &mut res, + self.flags.contains(Flags::HEAD), + self.version, + length, + self.ctype, + &self.config, + )?; + self.headers_size = (dst.len() - len) as u32; } Message::Chunk(Some(bytes)) => { - self.te.encode(bytes.as_ref(), dst)?; + self.encoder.encode_chunk(bytes.as_ref(), dst)?; } Message::Chunk(None) => { - self.te.encode_eof(dst)?; + self.encoder.encode_eof(dst)?; } } Ok(()) diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index de0df367f..a081a5cf3 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -10,15 +10,16 @@ use client::ClientResponse; use error::ParseError; use http::header::{HeaderName, HeaderValue}; use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version}; -use message::{ConnectionType, Head}; +use message::ConnectionType; use request::Request; const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; /// Incoming messagd decoder -pub(crate) struct MessageDecoder(PhantomData); +pub(crate) struct MessageDecoder(PhantomData); +#[derive(Debug)] /// Incoming request type pub(crate) enum PayloadType { None, @@ -26,13 +27,13 @@ pub(crate) enum PayloadType { Stream(PayloadDecoder), } -impl Default for MessageDecoder { +impl Default for MessageDecoder { fn default() -> Self { MessageDecoder(PhantomData) } } -impl Decoder for MessageDecoder { +impl Decoder for MessageDecoder { type Item = (T, PayloadType); type Error = ParseError; @@ -47,10 +48,8 @@ pub(crate) enum PayloadLength { None, } -pub(crate) trait MessageTypeDecoder: Sized { - fn keep_alive(&mut self); - - fn force_close(&mut self); +pub(crate) trait MessageType: Sized { + fn set_connection_type(&mut self, ctype: Option); fn headers_mut(&mut self) -> &mut HeaderMap; @@ -59,10 +58,9 @@ pub(crate) trait MessageTypeDecoder: Sized { fn set_headers( &mut self, slice: &Bytes, - version: Version, raw_headers: &[HeaderIndex], ) -> Result { - let mut ka = version != Version::HTTP_10; + let mut ka = None; let mut has_upgrade = false; let mut chunked = false; let mut content_length = None; @@ -104,18 +102,18 @@ pub(crate) trait MessageTypeDecoder: Sized { // connection keep-alive state header::CONNECTION => { ka = if let Ok(conn) = value.to_str() { - if version == Version::HTTP_10 - && conn.contains("keep-alive") - { - true + if conn.contains("keep-alive") { + Some(ConnectionType::KeepAlive) + } else if conn.contains("close") { + Some(ConnectionType::Close) + } else if conn.contains("upgrade") { + Some(ConnectionType::Upgrade) } else { - version == Version::HTTP_11 && !(conn - .contains("close") - || conn.contains("upgrade")) + None } } else { - false - } + None + }; } header::UPGRADE => { has_upgrade = true; @@ -136,12 +134,7 @@ pub(crate) trait MessageTypeDecoder: Sized { } } } - - if ka { - self.keep_alive(); - } else { - self.force_close(); - } + self.set_connection_type(ka); // https://tools.ietf.org/html/rfc7230#section-3.3.3 if chunked { @@ -162,17 +155,9 @@ pub(crate) trait MessageTypeDecoder: Sized { } } -impl MessageTypeDecoder for Request { - fn keep_alive(&mut self) { - self.inner_mut() - .head - .set_connection_type(ConnectionType::KeepAlive) - } - - fn force_close(&mut self) { - self.inner_mut() - .head - .set_connection_type(ConnectionType::Close) +impl MessageType for Request { + fn set_connection_type(&mut self, ctype: Option) { + self.inner_mut().head.ctype = ctype; } fn headers_mut(&mut self) -> &mut HeaderMap { @@ -210,8 +195,7 @@ impl MessageTypeDecoder for Request { let mut msg = Request::new(); // convert headers - let len = - msg.set_headers(&src.split_to(len).freeze(), ver, &headers[..h_len])?; + let len = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?; // payload decoder let decoder = match len { @@ -243,13 +227,9 @@ impl MessageTypeDecoder for Request { } } -impl MessageTypeDecoder for ClientResponse { - fn keep_alive(&mut self) { - self.head.set_connection_type(ConnectionType::KeepAlive); - } - - fn force_close(&mut self) { - self.head.set_connection_type(ConnectionType::Close); +impl MessageType for ClientResponse { + fn set_connection_type(&mut self, ctype: Option) { + self.head.ctype = ctype; } fn headers_mut(&mut self) -> &mut HeaderMap { @@ -286,8 +266,7 @@ impl MessageTypeDecoder for ClientResponse { let mut msg = ClientResponse::new(); // convert headers - let len = - msg.set_headers(&src.split_to(len).freeze(), ver, &headers[..h_len])?; + let len = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?; // message payload let decoder = if let PayloadLength::Payload(pl) = len { @@ -634,6 +613,7 @@ mod tests { use super::*; use error::ParseError; use httpmessage::HttpMessage; + use message::Head; impl PayloadType { fn unwrap(self) -> PayloadDecoder { @@ -886,7 +866,7 @@ mod tests { let mut buf = BytesMut::from("GET /test HTTP/1.0\r\n\r\n"); let req = parse_ready!(&mut buf); - assert!(!req.keep_alive()); + assert_eq!(req.head().connection_type(), ConnectionType::Close); } #[test] @@ -894,7 +874,7 @@ mod tests { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); let req = parse_ready!(&mut buf); - assert!(req.keep_alive()); + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); } #[test] @@ -905,7 +885,7 @@ mod tests { ); let req = parse_ready!(&mut buf); - assert!(!req.keep_alive()); + assert_eq!(req.inner().head.ctype, Some(ConnectionType::Close)); } #[test] @@ -916,7 +896,7 @@ mod tests { ); let req = parse_ready!(&mut buf); - assert!(!req.keep_alive()); + assert_eq!(req.inner().head.ctype, Some(ConnectionType::Close)); } #[test] @@ -927,7 +907,7 @@ mod tests { ); let req = parse_ready!(&mut buf); - assert!(req.keep_alive()); + assert_eq!(req.inner().head.ctype, Some(ConnectionType::KeepAlive)); } #[test] @@ -938,7 +918,7 @@ mod tests { ); let req = parse_ready!(&mut buf); - assert!(req.keep_alive()); + assert_eq!(req.inner().head.ctype, Some(ConnectionType::KeepAlive)); } #[test] @@ -949,7 +929,7 @@ mod tests { ); let req = parse_ready!(&mut buf); - assert!(!req.keep_alive()); + assert_eq!(req.inner().head.connection_type(), ConnectionType::Close); } #[test] @@ -960,7 +940,11 @@ mod tests { ); let req = parse_ready!(&mut buf); - assert!(req.keep_alive()); + assert_eq!(req.inner().head.ctype, None); + assert_eq!( + req.inner().head.connection_type(), + ConnectionType::KeepAlive + ); } #[test] @@ -973,6 +957,7 @@ mod tests { let req = parse_ready!(&mut buf); assert!(req.upgrade()); + assert_eq!(req.inner().head.ctype, Some(ConnectionType::Upgrade)); } #[test] @@ -1070,7 +1055,7 @@ mod tests { ); let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - assert!(!req.keep_alive()); + assert_eq!(req.inner().head.ctype, Some(ConnectionType::Upgrade)); assert!(req.upgrade()); assert!(pl.is_unhandled()); } diff --git a/src/h1/encoder.rs b/src/h1/encoder.rs index 1664af162..f9b4aab35 100644 --- a/src/h1/encoder.rs +++ b/src/h1/encoder.rs @@ -1,40 +1,217 @@ #![allow(unused_imports, unused_variables, dead_code)] use std::fmt::Write as FmtWrite; use std::io::Write; +use std::marker::PhantomData; use std::str::FromStr; use std::{cmp, fmt, io, mem}; -use bytes::{Bytes, BytesMut}; -use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH}; -use http::{StatusCode, Version}; +use bytes::{BufMut, Bytes, BytesMut}; +use http::header::{ + HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, +}; +use http::{HeaderMap, StatusCode, Version}; use body::BodyLength; +use config::ServiceConfig; use header::ContentEncoding; +use helpers; use http::Method; -use message::{RequestHead, ResponseHead}; +use message::{ConnectionType, RequestHead, ResponseHead}; use request::Request; use response::Response; +const AVERAGE_HEADER_SIZE: usize = 30; + #[derive(Debug)] -pub(crate) struct ResponseEncoder { - head: bool, +pub(crate) struct MessageEncoder { pub length: BodyLength, pub te: TransferEncoding, + _t: PhantomData, } -impl Default for ResponseEncoder { +impl Default for MessageEncoder { fn default() -> Self { - ResponseEncoder { - head: false, + MessageEncoder { length: BodyLength::None, te: TransferEncoding::empty(), + _t: PhantomData, } } } -impl ResponseEncoder { +pub(crate) trait MessageType: Sized { + fn status(&self) -> Option; + + fn connection_type(&self) -> Option; + + fn headers(&self) -> &HeaderMap; + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>; + + fn encode_headers( + &mut self, + dst: &mut BytesMut, + version: Version, + mut length: BodyLength, + ctype: ConnectionType, + config: &ServiceConfig, + ) -> io::Result<()> { + let mut skip_len = length != BodyLength::Stream; + + // Content length + if let Some(status) = self.status() { + match status { + StatusCode::NO_CONTENT + | StatusCode::CONTINUE + | StatusCode::PROCESSING => length = BodyLength::None, + StatusCode::SWITCHING_PROTOCOLS => { + skip_len = true; + length = BodyLength::Stream; + } + _ => (), + } + } + match length { + BodyLength::Chunked => { + dst.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") + } + BodyLength::Empty => { + dst.extend_from_slice(b"\r\ncontent-length: 0\r\n"); + } + BodyLength::Sized(len) => helpers::write_content_length(len, dst), + BodyLength::Sized64(len) => { + dst.extend_from_slice(b"\r\ncontent-length: "); + write!(dst.writer(), "{}", len)?; + dst.extend_from_slice(b"\r\n"); + } + BodyLength::None | BodyLength::Stream => dst.extend_from_slice(b"\r\n"), + } + + // Connection + match ctype { + ConnectionType::Upgrade => dst.extend_from_slice(b"connection: upgrade\r\n"), + ConnectionType::KeepAlive if version < Version::HTTP_11 => { + dst.extend_from_slice(b"connection: keep-alive\r\n") + } + ConnectionType::Close if version >= Version::HTTP_11 => { + dst.extend_from_slice(b"connection: close\r\n") + } + _ => (), + } + + // write headers + let mut pos = 0; + let mut has_date = false; + let mut remaining = dst.remaining_mut(); + let mut buf = unsafe { &mut *(dst.bytes_mut() as *mut [u8]) }; + for (key, value) in self.headers() { + match key { + &CONNECTION => continue, + &TRANSFER_ENCODING | &CONTENT_LENGTH if skip_len => continue, + &DATE => { + has_date = true; + } + _ => (), + } + + let v = value.as_ref(); + let k = key.as_str().as_bytes(); + let len = k.len() + v.len() + 4; + if len > remaining { + unsafe { + dst.advance_mut(pos); + } + pos = 0; + dst.reserve(len); + remaining = dst.remaining_mut(); + unsafe { + buf = &mut *(dst.bytes_mut() as *mut _); + } + } + + buf[pos..pos + k.len()].copy_from_slice(k); + pos += k.len(); + buf[pos..pos + 2].copy_from_slice(b": "); + pos += 2; + buf[pos..pos + v.len()].copy_from_slice(v); + pos += v.len(); + buf[pos..pos + 2].copy_from_slice(b"\r\n"); + pos += 2; + remaining -= len; + } + unsafe { + dst.advance_mut(pos); + } + + // optimized date header, set_date writes \r\n + if !has_date { + config.set_date(dst); + } else { + // msg eof + dst.extend_from_slice(b"\r\n"); + } + + Ok(()) + } +} + +impl MessageType for Response<()> { + fn status(&self) -> Option { + Some(self.head().status) + } + + fn connection_type(&self) -> Option { + self.head().ctype + } + + fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + let head = self.head(); + let reason = head.reason().as_bytes(); + dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len()); + + // status line + helpers::write_status_line(head.version, head.status.as_u16(), dst); + dst.extend_from_slice(reason); + Ok(()) + } +} + +impl MessageType for RequestHead { + fn status(&self) -> Option { + None + } + + fn connection_type(&self) -> Option { + self.ctype + } + + fn headers(&self) -> &HeaderMap { + &self.headers + } + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + write!( + Writer(dst), + "{} {} {}", + self.method, + self.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), + match self.version { + Version::HTTP_09 => "HTTP/0.9", + Version::HTTP_10 => "HTTP/1.0", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2.0", + } + ).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } +} + +impl MessageEncoder { /// Encode message - pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { self.te.encode(msg, buf) } @@ -43,59 +220,32 @@ impl ResponseEncoder { self.te.encode_eof(buf) } - pub fn update( + pub fn encode( &mut self, - resp: &mut ResponseHead, + dst: &mut BytesMut, + message: &mut T, head: bool, version: Version, length: BodyLength, - ) { - self.head = head; - let transfer = match length { - BodyLength::Empty => TransferEncoding::empty(), - BodyLength::Sized(len) => TransferEncoding::length(len as u64), - BodyLength::Sized64(len) => TransferEncoding::length(len), - BodyLength::Chunked => TransferEncoding::chunked(), - BodyLength::Stream => TransferEncoding::eof(), - BodyLength::None => TransferEncoding::length(0), - }; - // check for head response - if !self.head { - self.te = transfer; + ctype: ConnectionType, + config: &ServiceConfig, + ) -> io::Result<()> { + // transfer encoding + if !head { + self.te = match length { + BodyLength::Empty => TransferEncoding::empty(), + BodyLength::Sized(len) => TransferEncoding::length(len as u64), + BodyLength::Sized64(len) => TransferEncoding::length(len), + BodyLength::Chunked => TransferEncoding::chunked(), + BodyLength::Stream => TransferEncoding::eof(), + BodyLength::None => TransferEncoding::empty(), + }; + } else { + self.te = TransferEncoding::empty(); } - } -} -#[derive(Debug)] -pub(crate) struct RequestEncoder { - head: bool, - pub length: BodyLength, - pub te: TransferEncoding, -} - -impl Default for RequestEncoder { - fn default() -> Self { - RequestEncoder { - head: false, - length: BodyLength::None, - te: TransferEncoding::empty(), - } - } -} - -impl RequestEncoder { - /// Encode message - pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { - self.te.encode(msg, buf) - } - - /// Encode eof - pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { - self.te.encode_eof(buf) - } - - pub fn update(&mut self, resp: &mut RequestHead, head: bool, version: Version) { - self.head = head; + message.encode_status(dst)?; + message.encode_headers(dst, version, length, ctype, config) } } @@ -123,7 +273,7 @@ impl TransferEncoding { #[inline] pub fn empty() -> TransferEncoding { TransferEncoding { - kind: TransferEncodingKind::Eof, + kind: TransferEncodingKind::Length(0), } } diff --git a/src/message.rs b/src/message.rs index e1059fa48..d3d52e2f9 100644 --- a/src/message.rs +++ b/src/message.rs @@ -39,12 +39,13 @@ pub trait Head: Default + 'static { fn pool() -> &'static MessagePool; } +#[derive(Debug)] pub struct RequestHead { pub uri: Uri, pub method: Method, pub version: Version, pub headers: HeaderMap, - ctype: Option, + pub ctype: Option, } impl Default for RequestHead { @@ -72,7 +73,7 @@ impl Head for RequestHead { fn connection_type(&self) -> ConnectionType { if let Some(ct) = self.ctype { ct - } else if self.version <= Version::HTTP_11 { + } else if self.version < Version::HTTP_11 { ConnectionType::Close } else { ConnectionType::KeepAlive @@ -84,6 +85,7 @@ impl Head for RequestHead { } } +#[derive(Debug)] pub struct ResponseHead { pub version: Version, pub status: StatusCode, @@ -118,7 +120,7 @@ impl Head for ResponseHead { fn connection_type(&self) -> ConnectionType { if let Some(ct) = self.ctype { ct - } else if self.version <= Version::HTTP_11 { + } else if self.version < Version::HTTP_11 { ConnectionType::Close } else { ConnectionType::KeepAlive diff --git a/src/request.rs b/src/request.rs index d529c09f9..248555e60 100644 --- a/src/request.rs +++ b/src/request.rs @@ -8,7 +8,7 @@ use extensions::Extensions; use httpmessage::HttpMessage; use payload::Payload; -use message::{Head, Message, MessagePool, RequestHead}; +use message::{Message, MessagePool, RequestHead}; /// Request pub struct Request { @@ -67,6 +67,19 @@ impl Request { Rc::get_mut(&mut self.inner).expect("Multiple copies exist") } + #[inline] + /// Http message part of the request + pub fn head(&self) -> &RequestHead { + &self.inner.as_ref().head + } + + #[inline] + #[doc(hidden)] + /// Mutable reference to a http message part of the request + pub fn head_mut(&mut self) -> &mut RequestHead { + &mut self.inner_mut().head + } + /// Request's uri. #[inline] pub fn uri(&self) -> &Uri { @@ -109,12 +122,6 @@ impl Request { &mut self.inner_mut().head.headers } - /// Checks if a connection should be kept alive. - #[inline] - pub fn keep_alive(&self) -> bool { - self.inner().head.keep_alive() - } - /// Request extensions #[inline] pub fn extensions(&self) -> Ref { diff --git a/src/response.rs b/src/response.rs index f2ed72925..bc730718d 100644 --- a/src/response.rs +++ b/src/response.rs @@ -85,7 +85,14 @@ impl Response { } #[inline] - pub(crate) fn head_mut(&mut self) -> &mut ResponseHead { + /// Http message part of the response + pub fn head(&self) -> &ResponseHead { + &self.0.as_ref().head + } + + #[inline] + /// Mutable reference to a http message part of the response + pub fn head_mut(&mut self) -> &mut ResponseHead { &mut self.0.as_mut().head } @@ -314,7 +321,7 @@ impl ResponseBuilder { self } - /// Set a header. + /// Append a header to existing headers. /// /// ```rust,ignore /// # extern crate actix_web; @@ -347,6 +354,39 @@ impl ResponseBuilder { self } + /// Set a header. + /// + /// ```rust,ignore + /// # extern crate actix_web; + /// use actix_web::{http, Request, Response}; + /// + /// fn index(req: HttpRequest) -> Response { + /// Response::Ok() + /// .set_header("X-TEST", "value") + /// .set_header(http::header::CONTENT_TYPE, "application/json") + /// .finish() + /// } + /// fn main() {} + /// ``` + pub fn set_header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.response, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + parts.head.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + /// Set the custom reason for the response. #[inline] pub fn reason(&mut self, reason: &'static str) -> &mut Self { @@ -367,11 +407,14 @@ impl ResponseBuilder { /// Set connection type to Upgrade #[inline] - pub fn upgrade(&mut self) -> &mut Self { + pub fn upgrade(&mut self, value: V) -> &mut Self + where + V: IntoHeaderValue, + { if let Some(parts) = parts(&mut self.response, &self.err) { parts.head.set_connection_type(ConnectionType::Upgrade); } - self + self.set_header(header::UPGRADE, value) } /// Force close connection, even if it is marked as keep-alive @@ -880,8 +923,14 @@ mod tests { #[test] fn test_upgrade() { - let resp = Response::build(StatusCode::OK).upgrade().finish(); - assert!(resp.upgrade()) + let resp = Response::build(StatusCode::OK) + .upgrade("websocket") + .finish(); + assert!(resp.upgrade()); + assert_eq!( + resp.headers().get(header::UPGRADE).unwrap(), + HeaderValue::from_static("websocket") + ); } #[test] diff --git a/src/test.rs b/src/test.rs index 9e14129b6..8f90246b4 100644 --- a/src/test.rs +++ b/src/test.rs @@ -443,7 +443,7 @@ where ) -> Result, ws::ClientError> { let url = self.url(path); self.rt - .block_on(ws::Client::default().call(ws::Connect::new(url))) + .block_on(lazy(|| ws::Client::default().call(ws::Connect::new(url)))) } /// Connect to a websocket server diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 5c86d8c49..f1c91714a 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -183,8 +183,7 @@ pub fn handshake_response(req: &Request) -> ResponseBuilder { }; Response::build(StatusCode::SWITCHING_PROTOCOLS) - .upgrade() - .header(header::UPGRADE, "websocket") + .upgrade("websocket") .header(header::TRANSFER_ENCODING, "chunked") .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) .take() diff --git a/tests/test_server.rs b/tests/test_server.rs index a01af4f0d..c6e03e285 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -89,7 +89,9 @@ fn test_content_length() { { for i in 0..4 { - let req = client::ClientRequest::get(srv.url("/")).finish().unwrap(); + let req = client::ClientRequest::get(srv.url(&format!("/{}", i))) + .finish() + .unwrap(); let response = srv.send_request(req).unwrap(); assert_eq!(response.headers().get(&header), None); diff --git a/tests/test_ws.rs b/tests/test_ws.rs index 21f635129..22ce3ca29 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -13,7 +13,7 @@ use actix_net::service::NewServiceExt; use actix_net::stream::TakeItem; use actix_web::ws as web_ws; use bytes::{Bytes, BytesMut}; -use futures::future::{ok, Either}; +use futures::future::{lazy, ok, Either}; use futures::{Future, Sink, Stream}; use actix_http::{h1, test, ws, ResponseError, SendResponse, ServiceConfig}; @@ -81,8 +81,9 @@ fn test_simple() { { let url = srv.url("/"); - let (reader, mut writer) = - srv.block_on(web_ws::Client::new(url).connect()).unwrap(); + let (reader, mut writer) = srv + .block_on(lazy(|| web_ws::Client::new(url).connect())) + .unwrap(); writer.text("text"); let (item, reader) = srv.block_on(reader.into_future()).unwrap();