1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 01:32:57 +01:00

unify requedt/response encoder

This commit is contained in:
Nikolay Kim 2018-11-19 14:57:12 -08:00
parent 1ca6b44bae
commit 3901239128
12 changed files with 448 additions and 405 deletions

View File

@ -49,7 +49,10 @@ where
.and_then(|(item, framed)| {
if let Some(res) = item {
match framed.get_codec().message_type() {
h1::MessageType::None => release_connection(framed),
h1::MessageType::None => {
let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close)
}
_ => {
*res.payload.borrow_mut() = Some(Payload::stream(framed))
}
@ -174,7 +177,9 @@ impl<Io: Connection> Stream for Payload<Io> {
Async::Ready(Some(chunk)) => if let Some(chunk) = chunk {
Ok(Async::Ready(Some(chunk)))
} else {
release_connection(self.framed.take().unwrap());
let framed = self.framed.take().unwrap();
let force_close = framed.get_codec().keepalive();
release_connection(framed, force_close);
Ok(Async::Ready(None))
},
Async::Ready(None) => Ok(Async::Ready(None)),
@ -182,12 +187,12 @@ impl<Io: Connection> Stream for Payload<Io> {
}
}
fn release_connection<T, U>(framed: Framed<T, U>)
fn release_connection<T, U>(framed: Framed<T, U>, force_close: bool)
where
T: Connection,
{
let mut parts = framed.into_parts();
if parts.read_buf.is_empty() && parts.write_buf.is_empty() {
if !force_close && parts.read_buf.is_empty() && parts.write_buf.is_empty() {
parts.io.release()
} else {
parts.io.close()

View File

@ -4,8 +4,8 @@ use std::io::{self, Write};
use bytes::{BufMut, Bytes, BytesMut};
use tokio_codec::{Decoder, Encoder};
use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType};
use super::encoder::RequestEncoder;
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
use super::{decoder, encoder};
use super::{Message, MessageType};
use body::BodyLength;
use client::ClientResponse;
@ -16,13 +16,11 @@ use http::header::{
HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE,
};
use http::{Method, Version};
use message::{Head, MessagePool, RequestHead};
use message::{ConnectionType, Head, MessagePool, RequestHead};
bitflags! {
struct Flags: u8 {
const HEAD = 0b0000_0001;
const UPGRADE = 0b0000_0010;
const KEEPALIVE = 0b0000_0100;
const KEEPALIVE_ENABLED = 0b0000_1000;
const STREAM = 0b0001_0000;
}
@ -42,14 +40,15 @@ pub struct ClientPayloadCodec {
struct ClientCodecInner {
config: ServiceConfig,
decoder: MessageDecoder<ClientResponse>,
decoder: decoder::MessageDecoder<ClientResponse>,
payload: Option<PayloadDecoder>,
version: Version,
ctype: ConnectionType,
// encoder part
flags: Flags,
headers_size: u32,
te: RequestEncoder,
encoder: encoder::MessageEncoder<RequestHead>,
}
impl Default for ClientCodec {
@ -71,25 +70,26 @@ impl ClientCodec {
ClientCodec {
inner: ClientCodecInner {
config,
decoder: MessageDecoder::default(),
decoder: decoder::MessageDecoder::default(),
payload: None,
version: Version::HTTP_11,
ctype: ConnectionType::Close,
flags,
headers_size: 0,
te: RequestEncoder::default(),
encoder: encoder::MessageEncoder::default(),
},
}
}
/// Check if request is upgrade
pub fn upgrade(&self) -> bool {
self.inner.flags.contains(Flags::UPGRADE)
self.inner.ctype == ConnectionType::Upgrade
}
/// Check if last response is keep-alive
pub fn keepalive(&self) -> bool {
self.inner.flags.contains(Flags::KEEPALIVE)
self.inner.ctype == ConnectionType::KeepAlive
}
/// Check last request's message type
@ -103,15 +103,6 @@ impl ClientCodec {
}
}
/// prepare transfer encoding
pub fn prepare_te(&mut self, head: &mut RequestHead, length: BodyLength) {
self.inner.te.update(
head,
self.inner.flags.contains(Flags::HEAD),
self.inner.version,
);
}
/// Convert message codec to a payload codec
pub fn into_payload_codec(self) -> ClientPayloadCodec {
ClientPayloadCodec { inner: self.inner }
@ -119,96 +110,17 @@ impl ClientCodec {
}
impl ClientPayloadCodec {
/// Check if last response is keep-alive
pub fn keepalive(&self) -> bool {
self.inner.ctype == ConnectionType::KeepAlive
}
/// Transform payload codec to a message codec
pub fn into_message_codec(self) -> ClientCodec {
ClientCodec { inner: self.inner }
}
}
fn prn_version(ver: Version) -> &'static str {
match ver {
Version::HTTP_09 => "HTTP/0.9",
Version::HTTP_10 => "HTTP/1.0",
Version::HTTP_11 => "HTTP/1.1",
Version::HTTP_2 => "HTTP/2.0",
}
}
impl ClientCodecInner {
fn encode_request(
&mut self,
msg: RequestHead,
length: BodyLength,
buffer: &mut BytesMut,
) -> io::Result<()> {
// render message
{
// status line
write!(
Writer(buffer),
"{} {} {}\r\n",
msg.method,
msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
prn_version(msg.version)
).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
// write headers
buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE);
// content length
match length {
BodyLength::Sized(len) => helpers::write_content_length(len, buffer),
BodyLength::Sized64(len) => {
buffer.extend_from_slice(b"content-length: ");
write!(buffer.writer(), "{}", len)?;
buffer.extend_from_slice(b"\r\n");
}
BodyLength::Chunked => {
buffer.extend_from_slice(b"transfer-encoding: chunked\r\n")
}
BodyLength::Empty => buffer.extend_from_slice(b"content-length: 0\r\n"),
BodyLength::None | BodyLength::Stream => (),
}
let mut has_date = false;
for (key, value) in &msg.headers {
match *key {
TRANSFER_ENCODING | CONNECTION | CONTENT_LENGTH => continue,
DATE => has_date = true,
_ => (),
}
buffer.put_slice(key.as_ref());
buffer.put_slice(b": ");
buffer.put_slice(value.as_ref());
buffer.put_slice(b"\r\n");
}
// Connection header
if msg.upgrade() {
self.flags.set(Flags::UPGRADE, msg.upgrade());
buffer.extend_from_slice(b"connection: upgrade\r\n");
} else if msg.keep_alive() {
if self.version < Version::HTTP_11 {
buffer.extend_from_slice(b"connection: keep-alive\r\n");
}
} else if self.version >= Version::HTTP_11 {
buffer.extend_from_slice(b"connection: close\r\n");
}
// Date header
if !has_date {
self.config.set_date(buffer);
} else {
buffer.extend_from_slice(b"\r\n");
}
}
Ok(())
}
}
impl Decoder for ClientCodec {
type Item = ClientResponse;
type Error = ParseError;
@ -217,21 +129,27 @@ impl Decoder for ClientCodec {
debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set");
if let Some((req, payload)) = self.inner.decoder.decode(src)? {
// self.inner
// .flags
// .set(Flags::HEAD, req.head.method == Method::HEAD);
// self.inner.version = req.head.version;
if self.inner.flags.contains(Flags::KEEPALIVE_ENABLED) {
self.inner.flags.set(Flags::KEEPALIVE, req.keep_alive());
if let Some(ctype) = req.head().ctype {
// do not use peer's keep-alive
self.inner.ctype = if ctype == ConnectionType::KeepAlive {
self.inner.ctype
} else {
ctype
};
}
match payload {
PayloadType::None => self.inner.payload = None,
PayloadType::Payload(pl) => self.inner.payload = Some(pl),
PayloadType::Stream(pl) => {
self.inner.payload = Some(pl);
self.inner.flags.insert(Flags::STREAM);
if !self.inner.flags.contains(Flags::HEAD) {
match payload {
PayloadType::None => self.inner.payload = None,
PayloadType::Payload(pl) => self.inner.payload = Some(pl),
PayloadType::Stream(pl) => {
self.inner.payload = Some(pl);
self.inner.flags.insert(Flags::STREAM);
}
}
};
} else {
self.inner.payload = None;
}
Ok(Some(req))
} else {
Ok(None)
@ -270,14 +188,39 @@ impl Encoder for ClientCodec {
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
match item {
Message::Item((msg, btype)) => {
self.inner.encode_request(msg, btype, dst)?;
Message::Item((mut msg, length)) => {
let inner = &mut self.inner;
inner.version = msg.version;
inner.flags.set(Flags::HEAD, msg.method == Method::HEAD);
// connection status
inner.ctype = match msg.connection_type() {
ConnectionType::KeepAlive => {
if inner.flags.contains(Flags::KEEPALIVE_ENABLED) {
ConnectionType::KeepAlive
} else {
ConnectionType::Close
}
}
ConnectionType::Upgrade => ConnectionType::Upgrade,
ConnectionType::Close => ConnectionType::Close,
};
inner.encoder.encode(
dst,
&mut msg,
false,
inner.version,
length,
inner.ctype,
&inner.config,
)?;
}
Message::Chunk(Some(bytes)) => {
self.inner.te.encode(bytes.as_ref(), dst)?;
self.inner.encoder.encode_chunk(bytes.as_ref(), dst)?;
}
Message::Chunk(None) => {
self.inner.te.encode_eof(dst)?;
self.inner.encoder.encode_eof(dst)?;
}
}
Ok(())

View File

@ -5,8 +5,8 @@ use std::io::{self, Write};
use bytes::{BufMut, Bytes, BytesMut};
use tokio_codec::{Decoder, Encoder};
use super::decoder::{MessageDecoder, PayloadDecoder, PayloadItem, PayloadType};
use super::encoder::ResponseEncoder;
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
use super::{decoder, encoder};
use super::{Message, MessageType};
use body::BodyLength;
use config::ServiceConfig;
@ -14,15 +14,13 @@ use error::ParseError;
use helpers;
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{Method, StatusCode, Version};
use message::{Head, ResponseHead};
use message::{ConnectionType, Head, ResponseHead};
use request::Request;
use response::Response;
bitflags! {
struct Flags: u8 {
const HEAD = 0b0000_0001;
const UPGRADE = 0b0000_0010;
const KEEPALIVE = 0b0000_0100;
const KEEPALIVE_ENABLED = 0b0000_1000;
const STREAM = 0b0001_0000;
}
@ -33,14 +31,15 @@ const AVERAGE_HEADER_SIZE: usize = 30;
/// HTTP/1 Codec
pub struct Codec {
config: ServiceConfig,
decoder: MessageDecoder<Request>,
decoder: decoder::MessageDecoder<Request>,
payload: Option<PayloadDecoder>,
version: Version,
ctype: ConnectionType,
// encoder part
flags: Flags,
headers_size: u32,
te: ResponseEncoder,
encoder: encoder::MessageEncoder<Response<()>>,
}
impl Default for Codec {
@ -67,24 +66,25 @@ impl Codec {
};
Codec {
config,
decoder: MessageDecoder::default(),
decoder: decoder::MessageDecoder::default(),
payload: None,
version: Version::HTTP_11,
ctype: ConnectionType::Close,
flags,
headers_size: 0,
te: ResponseEncoder::default(),
encoder: encoder::MessageEncoder::default(),
}
}
/// Check if request is upgrade
pub fn upgrade(&self) -> bool {
self.flags.contains(Flags::UPGRADE)
self.ctype == ConnectionType::Upgrade
}
/// Check if last response is keep-alive
pub fn keepalive(&self) -> bool {
self.flags.contains(Flags::KEEPALIVE)
self.ctype == ConnectionType::KeepAlive
}
/// Check last request's message type
@ -97,130 +97,6 @@ impl Codec {
MessageType::Payload
}
}
/// prepare transfer encoding
fn prepare_te(&mut self, head: &mut ResponseHead, length: BodyLength) {
self.te
.update(head, self.flags.contains(Flags::HEAD), self.version, length);
}
fn encode_response(
&mut self,
msg: &mut ResponseHead,
length: BodyLength,
buffer: &mut BytesMut,
) -> io::Result<()> {
msg.version = self.version;
// Connection upgrade
if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
self.flags.remove(Flags::KEEPALIVE);
msg.headers
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
// keep-alive
else if self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg.keep_alive() {
self.flags.insert(Flags::KEEPALIVE);
if self.version < Version::HTTP_11 {
msg.headers
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
}
} else if self.version >= Version::HTTP_11 {
self.flags.remove(Flags::KEEPALIVE);
msg.headers
.insert(CONNECTION, HeaderValue::from_static("close"));
}
// render message
{
let reason = msg.reason().as_bytes();
buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
// status line
helpers::write_status_line(self.version, msg.status.as_u16(), buffer);
buffer.extend_from_slice(reason);
// content length
match msg.status {
StatusCode::NO_CONTENT
| StatusCode::CONTINUE
| StatusCode::SWITCHING_PROTOCOLS
| StatusCode::PROCESSING => buffer.extend_from_slice(b"\r\n"),
_ => match length {
BodyLength::Chunked => {
buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
}
BodyLength::Empty => {
buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n");
}
BodyLength::Sized(len) => helpers::write_content_length(len, buffer),
BodyLength::Sized64(len) => {
buffer.extend_from_slice(b"\r\ncontent-length: ");
write!(buffer.writer(), "{}", len)?;
buffer.extend_from_slice(b"\r\n");
}
BodyLength::None | BodyLength::Stream => {
buffer.extend_from_slice(b"\r\n")
}
},
}
// write headers
let mut pos = 0;
let mut has_date = false;
let mut remaining = buffer.remaining_mut();
let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) };
for (key, value) in &msg.headers {
match *key {
TRANSFER_ENCODING | CONTENT_LENGTH => continue,
DATE => {
has_date = true;
}
_ => (),
}
let v = value.as_ref();
let k = key.as_str().as_bytes();
let len = k.len() + v.len() + 4;
if len > remaining {
unsafe {
buffer.advance_mut(pos);
}
pos = 0;
buffer.reserve(len);
remaining = buffer.remaining_mut();
unsafe {
buf = &mut *(buffer.bytes_mut() as *mut _);
}
}
buf[pos..pos + k.len()].copy_from_slice(k);
pos += k.len();
buf[pos..pos + 2].copy_from_slice(b": ");
pos += 2;
buf[pos..pos + v.len()].copy_from_slice(v);
pos += v.len();
buf[pos..pos + 2].copy_from_slice(b"\r\n");
pos += 2;
remaining -= len;
}
unsafe {
buffer.advance_mut(pos);
}
// optimized date header, set_date writes \r\n
if !has_date {
self.config.set_date(buffer);
} else {
// msg eof
buffer.extend_from_slice(b"\r\n");
}
self.headers_size = buffer.len() as u32;
}
Ok(())
}
}
impl Decoder for Codec {
@ -240,9 +116,12 @@ impl Decoder for Codec {
} else if let Some((req, payload)) = self.decoder.decode(src)? {
self.flags
.set(Flags::HEAD, req.inner.head.method == Method::HEAD);
self.version = req.inner.head.version;
if self.flags.contains(Flags::KEEPALIVE_ENABLED) {
self.flags.set(Flags::KEEPALIVE, req.keep_alive());
self.version = req.inner().head.version;
self.ctype = req.inner().head.connection_type();
if self.ctype == ConnectionType::KeepAlive
&& !self.flags.contains(Flags::KEEPALIVE_ENABLED)
{
self.ctype = ConnectionType::Close
}
match payload {
PayloadType::None => self.payload = None,
@ -270,14 +149,35 @@ impl Encoder for Codec {
) -> Result<(), Self::Error> {
match item {
Message::Item((mut res, length)) => {
self.prepare_te(res.head_mut(), length);
self.encode_response(res.head_mut(), length, dst)?;
// connection status
self.ctype = if let Some(ct) = res.head().ctype {
if ct == ConnectionType::KeepAlive {
self.ctype
} else {
ct
}
} else {
self.ctype
};
// encode message
let len = dst.len();
self.encoder.encode(
dst,
&mut res,
self.flags.contains(Flags::HEAD),
self.version,
length,
self.ctype,
&self.config,
)?;
self.headers_size = (dst.len() - len) as u32;
}
Message::Chunk(Some(bytes)) => {
self.te.encode(bytes.as_ref(), dst)?;
self.encoder.encode_chunk(bytes.as_ref(), dst)?;
}
Message::Chunk(None) => {
self.te.encode_eof(dst)?;
self.encoder.encode_eof(dst)?;
}
}
Ok(())

View File

@ -10,15 +10,16 @@ use client::ClientResponse;
use error::ParseError;
use http::header::{HeaderName, HeaderValue};
use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version};
use message::{ConnectionType, Head};
use message::ConnectionType;
use request::Request;
const MAX_BUFFER_SIZE: usize = 131_072;
const MAX_HEADERS: usize = 96;
/// Incoming messagd decoder
pub(crate) struct MessageDecoder<T: MessageTypeDecoder>(PhantomData<T>);
pub(crate) struct MessageDecoder<T: MessageType>(PhantomData<T>);
#[derive(Debug)]
/// Incoming request type
pub(crate) enum PayloadType {
None,
@ -26,13 +27,13 @@ pub(crate) enum PayloadType {
Stream(PayloadDecoder),
}
impl<T: MessageTypeDecoder> Default for MessageDecoder<T> {
impl<T: MessageType> Default for MessageDecoder<T> {
fn default() -> Self {
MessageDecoder(PhantomData)
}
}
impl<T: MessageTypeDecoder> Decoder for MessageDecoder<T> {
impl<T: MessageType> Decoder for MessageDecoder<T> {
type Item = (T, PayloadType);
type Error = ParseError;
@ -47,10 +48,8 @@ pub(crate) enum PayloadLength {
None,
}
pub(crate) trait MessageTypeDecoder: Sized {
fn keep_alive(&mut self);
fn force_close(&mut self);
pub(crate) trait MessageType: Sized {
fn set_connection_type(&mut self, ctype: Option<ConnectionType>);
fn headers_mut(&mut self) -> &mut HeaderMap;
@ -59,10 +58,9 @@ pub(crate) trait MessageTypeDecoder: Sized {
fn set_headers(
&mut self,
slice: &Bytes,
version: Version,
raw_headers: &[HeaderIndex],
) -> Result<PayloadLength, ParseError> {
let mut ka = version != Version::HTTP_10;
let mut ka = None;
let mut has_upgrade = false;
let mut chunked = false;
let mut content_length = None;
@ -104,18 +102,18 @@ pub(crate) trait MessageTypeDecoder: Sized {
// connection keep-alive state
header::CONNECTION => {
ka = if let Ok(conn) = value.to_str() {
if version == Version::HTTP_10
&& conn.contains("keep-alive")
{
true
if conn.contains("keep-alive") {
Some(ConnectionType::KeepAlive)
} else if conn.contains("close") {
Some(ConnectionType::Close)
} else if conn.contains("upgrade") {
Some(ConnectionType::Upgrade)
} else {
version == Version::HTTP_11 && !(conn
.contains("close")
|| conn.contains("upgrade"))
None
}
} else {
false
}
None
};
}
header::UPGRADE => {
has_upgrade = true;
@ -136,12 +134,7 @@ pub(crate) trait MessageTypeDecoder: Sized {
}
}
}
if ka {
self.keep_alive();
} else {
self.force_close();
}
self.set_connection_type(ka);
// https://tools.ietf.org/html/rfc7230#section-3.3.3
if chunked {
@ -162,17 +155,9 @@ pub(crate) trait MessageTypeDecoder: Sized {
}
}
impl MessageTypeDecoder for Request {
fn keep_alive(&mut self) {
self.inner_mut()
.head
.set_connection_type(ConnectionType::KeepAlive)
}
fn force_close(&mut self) {
self.inner_mut()
.head
.set_connection_type(ConnectionType::Close)
impl MessageType for Request {
fn set_connection_type(&mut self, ctype: Option<ConnectionType>) {
self.inner_mut().head.ctype = ctype;
}
fn headers_mut(&mut self) -> &mut HeaderMap {
@ -210,8 +195,7 @@ impl MessageTypeDecoder for Request {
let mut msg = Request::new();
// convert headers
let len =
msg.set_headers(&src.split_to(len).freeze(), ver, &headers[..h_len])?;
let len = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?;
// payload decoder
let decoder = match len {
@ -243,13 +227,9 @@ impl MessageTypeDecoder for Request {
}
}
impl MessageTypeDecoder for ClientResponse {
fn keep_alive(&mut self) {
self.head.set_connection_type(ConnectionType::KeepAlive);
}
fn force_close(&mut self) {
self.head.set_connection_type(ConnectionType::Close);
impl MessageType for ClientResponse {
fn set_connection_type(&mut self, ctype: Option<ConnectionType>) {
self.head.ctype = ctype;
}
fn headers_mut(&mut self) -> &mut HeaderMap {
@ -286,8 +266,7 @@ impl MessageTypeDecoder for ClientResponse {
let mut msg = ClientResponse::new();
// convert headers
let len =
msg.set_headers(&src.split_to(len).freeze(), ver, &headers[..h_len])?;
let len = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?;
// message payload
let decoder = if let PayloadLength::Payload(pl) = len {
@ -634,6 +613,7 @@ mod tests {
use super::*;
use error::ParseError;
use httpmessage::HttpMessage;
use message::Head;
impl PayloadType {
fn unwrap(self) -> PayloadDecoder {
@ -886,7 +866,7 @@ mod tests {
let mut buf = BytesMut::from("GET /test HTTP/1.0\r\n\r\n");
let req = parse_ready!(&mut buf);
assert!(!req.keep_alive());
assert_eq!(req.head().connection_type(), ConnectionType::Close);
}
#[test]
@ -894,7 +874,7 @@ mod tests {
let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n");
let req = parse_ready!(&mut buf);
assert!(req.keep_alive());
assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive);
}
#[test]
@ -905,7 +885,7 @@ mod tests {
);
let req = parse_ready!(&mut buf);
assert!(!req.keep_alive());
assert_eq!(req.inner().head.ctype, Some(ConnectionType::Close));
}
#[test]
@ -916,7 +896,7 @@ mod tests {
);
let req = parse_ready!(&mut buf);
assert!(!req.keep_alive());
assert_eq!(req.inner().head.ctype, Some(ConnectionType::Close));
}
#[test]
@ -927,7 +907,7 @@ mod tests {
);
let req = parse_ready!(&mut buf);
assert!(req.keep_alive());
assert_eq!(req.inner().head.ctype, Some(ConnectionType::KeepAlive));
}
#[test]
@ -938,7 +918,7 @@ mod tests {
);
let req = parse_ready!(&mut buf);
assert!(req.keep_alive());
assert_eq!(req.inner().head.ctype, Some(ConnectionType::KeepAlive));
}
#[test]
@ -949,7 +929,7 @@ mod tests {
);
let req = parse_ready!(&mut buf);
assert!(!req.keep_alive());
assert_eq!(req.inner().head.connection_type(), ConnectionType::Close);
}
#[test]
@ -960,7 +940,11 @@ mod tests {
);
let req = parse_ready!(&mut buf);
assert!(req.keep_alive());
assert_eq!(req.inner().head.ctype, None);
assert_eq!(
req.inner().head.connection_type(),
ConnectionType::KeepAlive
);
}
#[test]
@ -973,6 +957,7 @@ mod tests {
let req = parse_ready!(&mut buf);
assert!(req.upgrade());
assert_eq!(req.inner().head.ctype, Some(ConnectionType::Upgrade));
}
#[test]
@ -1070,7 +1055,7 @@ mod tests {
);
let mut reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
assert!(!req.keep_alive());
assert_eq!(req.inner().head.ctype, Some(ConnectionType::Upgrade));
assert!(req.upgrade());
assert!(pl.is_unhandled());
}

View File

@ -1,40 +1,217 @@
#![allow(unused_imports, unused_variables, dead_code)]
use std::fmt::Write as FmtWrite;
use std::io::Write;
use std::marker::PhantomData;
use std::str::FromStr;
use std::{cmp, fmt, io, mem};
use bytes::{Bytes, BytesMut};
use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH};
use http::{StatusCode, Version};
use bytes::{BufMut, Bytes, BytesMut};
use http::header::{
HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
};
use http::{HeaderMap, StatusCode, Version};
use body::BodyLength;
use config::ServiceConfig;
use header::ContentEncoding;
use helpers;
use http::Method;
use message::{RequestHead, ResponseHead};
use message::{ConnectionType, RequestHead, ResponseHead};
use request::Request;
use response::Response;
const AVERAGE_HEADER_SIZE: usize = 30;
#[derive(Debug)]
pub(crate) struct ResponseEncoder {
head: bool,
pub(crate) struct MessageEncoder<T: MessageType> {
pub length: BodyLength,
pub te: TransferEncoding,
_t: PhantomData<T>,
}
impl Default for ResponseEncoder {
impl<T: MessageType> Default for MessageEncoder<T> {
fn default() -> Self {
ResponseEncoder {
head: false,
MessageEncoder {
length: BodyLength::None,
te: TransferEncoding::empty(),
_t: PhantomData,
}
}
}
impl ResponseEncoder {
pub(crate) trait MessageType: Sized {
fn status(&self) -> Option<StatusCode>;
fn connection_type(&self) -> Option<ConnectionType>;
fn headers(&self) -> &HeaderMap;
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;
fn encode_headers(
&mut self,
dst: &mut BytesMut,
version: Version,
mut length: BodyLength,
ctype: ConnectionType,
config: &ServiceConfig,
) -> io::Result<()> {
let mut skip_len = length != BodyLength::Stream;
// Content length
if let Some(status) = self.status() {
match status {
StatusCode::NO_CONTENT
| StatusCode::CONTINUE
| StatusCode::PROCESSING => length = BodyLength::None,
StatusCode::SWITCHING_PROTOCOLS => {
skip_len = true;
length = BodyLength::Stream;
}
_ => (),
}
}
match length {
BodyLength::Chunked => {
dst.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
}
BodyLength::Empty => {
dst.extend_from_slice(b"\r\ncontent-length: 0\r\n");
}
BodyLength::Sized(len) => helpers::write_content_length(len, dst),
BodyLength::Sized64(len) => {
dst.extend_from_slice(b"\r\ncontent-length: ");
write!(dst.writer(), "{}", len)?;
dst.extend_from_slice(b"\r\n");
}
BodyLength::None | BodyLength::Stream => dst.extend_from_slice(b"\r\n"),
}
// Connection
match ctype {
ConnectionType::Upgrade => dst.extend_from_slice(b"connection: upgrade\r\n"),
ConnectionType::KeepAlive if version < Version::HTTP_11 => {
dst.extend_from_slice(b"connection: keep-alive\r\n")
}
ConnectionType::Close if version >= Version::HTTP_11 => {
dst.extend_from_slice(b"connection: close\r\n")
}
_ => (),
}
// write headers
let mut pos = 0;
let mut has_date = false;
let mut remaining = dst.remaining_mut();
let mut buf = unsafe { &mut *(dst.bytes_mut() as *mut [u8]) };
for (key, value) in self.headers() {
match key {
&CONNECTION => continue,
&TRANSFER_ENCODING | &CONTENT_LENGTH if skip_len => continue,
&DATE => {
has_date = true;
}
_ => (),
}
let v = value.as_ref();
let k = key.as_str().as_bytes();
let len = k.len() + v.len() + 4;
if len > remaining {
unsafe {
dst.advance_mut(pos);
}
pos = 0;
dst.reserve(len);
remaining = dst.remaining_mut();
unsafe {
buf = &mut *(dst.bytes_mut() as *mut _);
}
}
buf[pos..pos + k.len()].copy_from_slice(k);
pos += k.len();
buf[pos..pos + 2].copy_from_slice(b": ");
pos += 2;
buf[pos..pos + v.len()].copy_from_slice(v);
pos += v.len();
buf[pos..pos + 2].copy_from_slice(b"\r\n");
pos += 2;
remaining -= len;
}
unsafe {
dst.advance_mut(pos);
}
// optimized date header, set_date writes \r\n
if !has_date {
config.set_date(dst);
} else {
// msg eof
dst.extend_from_slice(b"\r\n");
}
Ok(())
}
}
impl MessageType for Response<()> {
fn status(&self) -> Option<StatusCode> {
Some(self.head().status)
}
fn connection_type(&self) -> Option<ConnectionType> {
self.head().ctype
}
fn headers(&self) -> &HeaderMap {
&self.head().headers
}
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
let head = self.head();
let reason = head.reason().as_bytes();
dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
// status line
helpers::write_status_line(head.version, head.status.as_u16(), dst);
dst.extend_from_slice(reason);
Ok(())
}
}
impl MessageType for RequestHead {
fn status(&self) -> Option<StatusCode> {
None
}
fn connection_type(&self) -> Option<ConnectionType> {
self.ctype
}
fn headers(&self) -> &HeaderMap {
&self.headers
}
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
write!(
Writer(dst),
"{} {} {}",
self.method,
self.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
match self.version {
Version::HTTP_09 => "HTTP/0.9",
Version::HTTP_10 => "HTTP/1.0",
Version::HTTP_11 => "HTTP/1.1",
Version::HTTP_2 => "HTTP/2.0",
}
).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
impl<T: MessageType> MessageEncoder<T> {
/// Encode message
pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
self.te.encode(msg, buf)
}
@ -43,59 +220,32 @@ impl ResponseEncoder {
self.te.encode_eof(buf)
}
pub fn update(
pub fn encode(
&mut self,
resp: &mut ResponseHead,
dst: &mut BytesMut,
message: &mut T,
head: bool,
version: Version,
length: BodyLength,
) {
self.head = head;
let transfer = match length {
BodyLength::Empty => TransferEncoding::empty(),
BodyLength::Sized(len) => TransferEncoding::length(len as u64),
BodyLength::Sized64(len) => TransferEncoding::length(len),
BodyLength::Chunked => TransferEncoding::chunked(),
BodyLength::Stream => TransferEncoding::eof(),
BodyLength::None => TransferEncoding::length(0),
};
// check for head response
if !self.head {
self.te = transfer;
ctype: ConnectionType,
config: &ServiceConfig,
) -> io::Result<()> {
// transfer encoding
if !head {
self.te = match length {
BodyLength::Empty => TransferEncoding::empty(),
BodyLength::Sized(len) => TransferEncoding::length(len as u64),
BodyLength::Sized64(len) => TransferEncoding::length(len),
BodyLength::Chunked => TransferEncoding::chunked(),
BodyLength::Stream => TransferEncoding::eof(),
BodyLength::None => TransferEncoding::empty(),
};
} else {
self.te = TransferEncoding::empty();
}
}
}
#[derive(Debug)]
pub(crate) struct RequestEncoder {
head: bool,
pub length: BodyLength,
pub te: TransferEncoding,
}
impl Default for RequestEncoder {
fn default() -> Self {
RequestEncoder {
head: false,
length: BodyLength::None,
te: TransferEncoding::empty(),
}
}
}
impl RequestEncoder {
/// Encode message
pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
self.te.encode(msg, buf)
}
/// Encode eof
pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
self.te.encode_eof(buf)
}
pub fn update(&mut self, resp: &mut RequestHead, head: bool, version: Version) {
self.head = head;
message.encode_status(dst)?;
message.encode_headers(dst, version, length, ctype, config)
}
}
@ -123,7 +273,7 @@ impl TransferEncoding {
#[inline]
pub fn empty() -> TransferEncoding {
TransferEncoding {
kind: TransferEncodingKind::Eof,
kind: TransferEncodingKind::Length(0),
}
}

View File

@ -39,12 +39,13 @@ pub trait Head: Default + 'static {
fn pool() -> &'static MessagePool<Self>;
}
#[derive(Debug)]
pub struct RequestHead {
pub uri: Uri,
pub method: Method,
pub version: Version,
pub headers: HeaderMap,
ctype: Option<ConnectionType>,
pub ctype: Option<ConnectionType>,
}
impl Default for RequestHead {
@ -72,7 +73,7 @@ impl Head for RequestHead {
fn connection_type(&self) -> ConnectionType {
if let Some(ct) = self.ctype {
ct
} else if self.version <= Version::HTTP_11 {
} else if self.version < Version::HTTP_11 {
ConnectionType::Close
} else {
ConnectionType::KeepAlive
@ -84,6 +85,7 @@ impl Head for RequestHead {
}
}
#[derive(Debug)]
pub struct ResponseHead {
pub version: Version,
pub status: StatusCode,
@ -118,7 +120,7 @@ impl Head for ResponseHead {
fn connection_type(&self) -> ConnectionType {
if let Some(ct) = self.ctype {
ct
} else if self.version <= Version::HTTP_11 {
} else if self.version < Version::HTTP_11 {
ConnectionType::Close
} else {
ConnectionType::KeepAlive

View File

@ -8,7 +8,7 @@ use extensions::Extensions;
use httpmessage::HttpMessage;
use payload::Payload;
use message::{Head, Message, MessagePool, RequestHead};
use message::{Message, MessagePool, RequestHead};
/// Request
pub struct Request {
@ -67,6 +67,19 @@ impl Request {
Rc::get_mut(&mut self.inner).expect("Multiple copies exist")
}
#[inline]
/// Http message part of the request
pub fn head(&self) -> &RequestHead {
&self.inner.as_ref().head
}
#[inline]
#[doc(hidden)]
/// Mutable reference to a http message part of the request
pub fn head_mut(&mut self) -> &mut RequestHead {
&mut self.inner_mut().head
}
/// Request's uri.
#[inline]
pub fn uri(&self) -> &Uri {
@ -109,12 +122,6 @@ impl Request {
&mut self.inner_mut().head.headers
}
/// Checks if a connection should be kept alive.
#[inline]
pub fn keep_alive(&self) -> bool {
self.inner().head.keep_alive()
}
/// Request extensions
#[inline]
pub fn extensions(&self) -> Ref<Extensions> {

View File

@ -85,7 +85,14 @@ impl<B: MessageBody> Response<B> {
}
#[inline]
pub(crate) fn head_mut(&mut self) -> &mut ResponseHead {
/// Http message part of the response
pub fn head(&self) -> &ResponseHead {
&self.0.as_ref().head
}
#[inline]
/// Mutable reference to a http message part of the response
pub fn head_mut(&mut self) -> &mut ResponseHead {
&mut self.0.as_mut().head
}
@ -314,7 +321,7 @@ impl ResponseBuilder {
self
}
/// Set a header.
/// Append a header to existing headers.
///
/// ```rust,ignore
/// # extern crate actix_web;
@ -347,6 +354,39 @@ impl ResponseBuilder {
self
}
/// Set a header.
///
/// ```rust,ignore
/// # extern crate actix_web;
/// use actix_web::{http, Request, Response};
///
/// fn index(req: HttpRequest) -> Response {
/// Response::Ok()
/// .set_header("X-TEST", "value")
/// .set_header(http::header::CONTENT_TYPE, "application/json")
/// .finish()
/// }
/// fn main() {}
/// ```
pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.response, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => match value.try_into() {
Ok(value) => {
parts.head.headers.insert(key, value);
}
Err(e) => self.err = Some(e.into()),
},
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Set the custom reason for the response.
#[inline]
pub fn reason(&mut self, reason: &'static str) -> &mut Self {
@ -367,11 +407,14 @@ impl ResponseBuilder {
/// Set connection type to Upgrade
#[inline]
pub fn upgrade(&mut self) -> &mut Self {
pub fn upgrade<V>(&mut self, value: V) -> &mut Self
where
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.response, &self.err) {
parts.head.set_connection_type(ConnectionType::Upgrade);
}
self
self.set_header(header::UPGRADE, value)
}
/// Force close connection, even if it is marked as keep-alive
@ -880,8 +923,14 @@ mod tests {
#[test]
fn test_upgrade() {
let resp = Response::build(StatusCode::OK).upgrade().finish();
assert!(resp.upgrade())
let resp = Response::build(StatusCode::OK)
.upgrade("websocket")
.finish();
assert!(resp.upgrade());
assert_eq!(
resp.headers().get(header::UPGRADE).unwrap(),
HeaderValue::from_static("websocket")
);
}
#[test]

View File

@ -443,7 +443,7 @@ where
) -> Result<Framed<impl AsyncRead + AsyncWrite, ws::Codec>, ws::ClientError> {
let url = self.url(path);
self.rt
.block_on(ws::Client::default().call(ws::Connect::new(url)))
.block_on(lazy(|| ws::Client::default().call(ws::Connect::new(url))))
}
/// Connect to a websocket server

View File

@ -183,8 +183,7 @@ pub fn handshake_response(req: &Request) -> ResponseBuilder {
};
Response::build(StatusCode::SWITCHING_PROTOCOLS)
.upgrade()
.header(header::UPGRADE, "websocket")
.upgrade("websocket")
.header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
.take()

View File

@ -89,7 +89,9 @@ fn test_content_length() {
{
for i in 0..4 {
let req = client::ClientRequest::get(srv.url("/")).finish().unwrap();
let req = client::ClientRequest::get(srv.url(&format!("/{}", i)))
.finish()
.unwrap();
let response = srv.send_request(req).unwrap();
assert_eq!(response.headers().get(&header), None);

View File

@ -13,7 +13,7 @@ use actix_net::service::NewServiceExt;
use actix_net::stream::TakeItem;
use actix_web::ws as web_ws;
use bytes::{Bytes, BytesMut};
use futures::future::{ok, Either};
use futures::future::{lazy, ok, Either};
use futures::{Future, Sink, Stream};
use actix_http::{h1, test, ws, ResponseError, SendResponse, ServiceConfig};
@ -81,8 +81,9 @@ fn test_simple() {
{
let url = srv.url("/");
let (reader, mut writer) =
srv.block_on(web_ws::Client::new(url).connect()).unwrap();
let (reader, mut writer) = srv
.block_on(lazy(|| web_ws::Client::new(url).connect()))
.unwrap();
writer.text("text");
let (item, reader) = srv.block_on(reader.into_future()).unwrap();