1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-23 15:24:36 +01:00

refactor encoder/decoder impl

This commit is contained in:
Nikolay Kim 2018-11-18 17:52:56 -08:00
parent 8fea1367c7
commit adad203314
17 changed files with 255 additions and 234 deletions

View File

@ -91,9 +91,11 @@ native-tls = { version="0.2", optional = true }
# openssl
openssl = { version="0.10", optional = true }
#rustls
# rustls
rustls = { version = "^0.14", optional = true }
backtrace="*"
[dev-dependencies]
actix-web = "0.7"
env_logger = "0.5"

View File

@ -9,7 +9,7 @@ use error::{Error, PayloadError};
/// Type represent streaming payload
pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>;
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Copy, Clone)]
/// Different type of body
pub enum BodyLength {
None,
@ -76,10 +76,11 @@ impl MessageBody for Body {
Body::None => Ok(Async::Ready(None)),
Body::Empty => Ok(Async::Ready(None)),
Body::Bytes(ref mut bin) => {
if bin.len() == 0 {
let len = bin.len();
if len == 0 {
Ok(Async::Ready(None))
} else {
Ok(Async::Ready(Some(bin.slice_to(bin.len()))))
Ok(Async::Ready(Some(bin.split_to(len))))
}
}
Body::Message(ref mut body) => body.poll_next(),

View File

@ -33,7 +33,7 @@ where
.from_err()
// create Framed and send reqest
.map(|io| Framed::new(io, h1::ClientCodec::default()))
.and_then(|framed| framed.send((head, len).into()).from_err())
.and_then(move |framed| framed.send((head, len).into()).from_err())
// send request body
.and_then(move |framed| match body.length() {
BodyLength::None | BodyLength::Empty | BodyLength::Sized(0) => {

View File

@ -16,7 +16,7 @@ use http::{
uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method,
Uri, Version,
};
use message::RequestHead;
use message::{Head, RequestHead};
use super::response::ClientResponse;
use super::{pipeline, Connect, Connection, ConnectorError, SendRequestError};
@ -365,8 +365,21 @@ impl ClientRequestBuilder {
where
V: IntoHeaderValue,
{
{
if let Some(parts) = parts(&mut self.head, &self.err) {
parts.set_upgrade();
}
}
self.set_header(header::UPGRADE, value)
.set_header(header::CONNECTION, "upgrade")
}
/// Close connection
#[inline]
pub fn close(&mut self) -> &mut Self {
if let Some(parts) = parts(&mut self.head, &self.err) {
parts.force_close();
}
self
}
/// Set request's content type

View File

@ -8,7 +8,7 @@ use http::{HeaderMap, StatusCode, Version};
use body::PayloadStream;
use error::PayloadError;
use httpmessage::HttpMessage;
use message::{MessageFlags, ResponseHead};
use message::{Head, ResponseHead};
use super::pipeline::Payload;
@ -81,7 +81,7 @@ impl ClientResponse {
/// Checks if a connection should be kept alive.
#[inline]
pub fn keep_alive(&self) -> bool {
self.head().flags.contains(MessageFlags::KEEPALIVE)
self.head().keep_alive()
}
}

View File

@ -16,7 +16,7 @@ use http::header::{
HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE,
};
use http::{Method, Version};
use message::{MessagePool, RequestHead};
use message::{Head, MessagePool, RequestHead};
bitflags! {
struct Flags: u8 {
@ -135,7 +135,7 @@ fn prn_version(ver: Version) -> &'static str {
}
impl ClientCodecInner {
fn encode_response(
fn encode_request(
&mut self,
msg: RequestHead,
length: BodyLength,
@ -146,7 +146,7 @@ impl ClientCodecInner {
// status line
write!(
Writer(buffer),
"{} {} {}",
"{} {} {}\r\n",
msg.method,
msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
prn_version(msg.version)
@ -156,38 +156,26 @@ impl ClientCodecInner {
buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE);
// content length
let mut len_is_set = true;
match length {
BodyLength::Sized(len) => helpers::write_content_length(len, buffer),
BodyLength::Sized64(len) => {
buffer.extend_from_slice(b"\r\ncontent-length: ");
buffer.extend_from_slice(b"content-length: ");
write!(buffer.writer(), "{}", len)?;
buffer.extend_from_slice(b"\r\n");
}
BodyLength::Chunked => {
buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
}
BodyLength::Empty => {
len_is_set = false;
buffer.extend_from_slice(b"\r\n")
}
BodyLength::None | BodyLength::Stream => {
buffer.extend_from_slice(b"\r\n")
buffer.extend_from_slice(b"transfer-encoding: chunked\r\n")
}
BodyLength::Empty => buffer.extend_from_slice(b"content-length: 0\r\n"),
BodyLength::None | BodyLength::Stream => (),
}
let mut has_date = false;
for (key, value) in &msg.headers {
match *key {
TRANSFER_ENCODING => continue,
CONTENT_LENGTH => match length {
BodyLength::None => (),
BodyLength::Empty => len_is_set = true,
_ => continue,
},
TRANSFER_ENCODING | CONNECTION | CONTENT_LENGTH => continue,
DATE => has_date = true,
UPGRADE => self.flags.insert(Flags::UPGRADE),
_ => (),
}
@ -197,12 +185,19 @@ impl ClientCodecInner {
buffer.put_slice(b"\r\n");
}
// set content length
if !len_is_set {
buffer.extend_from_slice(b"content-length: 0\r\n")
// Connection header
if msg.upgrade() {
self.flags.set(Flags::UPGRADE, msg.upgrade());
buffer.extend_from_slice(b"connection: upgrade\r\n");
} else if msg.keep_alive() {
if self.version < Version::HTTP_11 {
buffer.extend_from_slice(b"connection: keep-alive\r\n");
}
} else if self.version >= Version::HTTP_11 {
buffer.extend_from_slice(b"connection: close\r\n");
}
// set date header
// Date header
if !has_date {
self.config.set_date(buffer);
} else {
@ -276,7 +271,7 @@ impl Encoder for ClientCodec {
) -> Result<(), Self::Error> {
match item {
Message::Item((msg, btype)) => {
self.inner.encode_response(msg, btype, dst)?;
self.inner.encode_request(msg, btype, dst)?;
}
Message::Chunk(Some(bytes)) => {
self.inner.te.encode(bytes.as_ref(), dst)?;

View File

@ -13,8 +13,8 @@ use config::ServiceConfig;
use error::ParseError;
use helpers;
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{Method, Version};
use message::ResponseHead;
use http::{Method, StatusCode, Version};
use message::{Head, ResponseHead};
use request::Request;
use response::Response;
@ -99,69 +99,71 @@ impl Codec {
}
/// prepare transfer encoding
pub fn prepare_te(&mut self, head: &mut ResponseHead, length: &mut BodyLength) {
fn prepare_te(&mut self, head: &mut ResponseHead, length: BodyLength) {
self.te
.update(head, self.flags.contains(Flags::HEAD), self.version, length);
}
fn encode_response(
&mut self,
mut msg: Response<()>,
msg: &mut ResponseHead,
length: BodyLength,
buffer: &mut BytesMut,
) -> io::Result<()> {
let ka = self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg
.keep_alive()
.unwrap_or_else(|| self.flags.contains(Flags::KEEPALIVE));
msg.version = self.version;
// Connection upgrade
if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
self.flags.remove(Flags::KEEPALIVE);
msg.headers_mut()
msg.headers
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
// keep-alive
else if ka {
else if self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg.keep_alive() {
self.flags.insert(Flags::KEEPALIVE);
if self.version < Version::HTTP_11 {
msg.headers_mut()
msg.headers
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
}
} else if self.version >= Version::HTTP_11 {
self.flags.remove(Flags::KEEPALIVE);
msg.headers_mut()
msg.headers
.insert(CONNECTION, HeaderValue::from_static("close"));
}
// render message
{
let reason = msg.reason().as_bytes();
buffer
.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len());
buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
// status line
helpers::write_status_line(self.version, msg.status().as_u16(), buffer);
helpers::write_status_line(self.version, msg.status.as_u16(), buffer);
buffer.extend_from_slice(reason);
// content length
let mut len_is_set = true;
match self.te.length {
BodyLength::Chunked => {
buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
}
BodyLength::Empty => {
len_is_set = false;
buffer.extend_from_slice(b"\r\n")
}
BodyLength::Sized(len) => helpers::write_content_length(len, buffer),
BodyLength::Sized64(len) => {
buffer.extend_from_slice(b"\r\ncontent-length: ");
write!(buffer.writer(), "{}", len)?;
buffer.extend_from_slice(b"\r\n");
}
BodyLength::None | BodyLength::Stream => {
buffer.extend_from_slice(b"\r\n")
}
match msg.status {
StatusCode::NO_CONTENT
| StatusCode::CONTINUE
| StatusCode::SWITCHING_PROTOCOLS
| StatusCode::PROCESSING => buffer.extend_from_slice(b"\r\n"),
_ => match length {
BodyLength::Chunked => {
buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
}
BodyLength::Empty => {
buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n");
}
BodyLength::Sized(len) => helpers::write_content_length(len, buffer),
BodyLength::Sized64(len) => {
buffer.extend_from_slice(b"\r\ncontent-length: ");
write!(buffer.writer(), "{}", len)?;
buffer.extend_from_slice(b"\r\n");
}
BodyLength::None | BodyLength::Stream => {
buffer.extend_from_slice(b"\r\n")
}
},
}
// write headers
@ -169,16 +171,9 @@ impl Codec {
let mut has_date = false;
let mut remaining = buffer.remaining_mut();
let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) };
for (key, value) in msg.headers() {
for (key, value) in &msg.headers {
match *key {
TRANSFER_ENCODING => continue,
CONTENT_LENGTH => match self.te.length {
BodyLength::None => (),
BodyLength::Empty => {
len_is_set = true;
}
_ => continue,
},
TRANSFER_ENCODING | CONTENT_LENGTH => continue,
DATE => {
has_date = true;
}
@ -213,9 +208,6 @@ impl Codec {
unsafe {
buffer.advance_mut(pos);
}
if !len_is_set {
buffer.extend_from_slice(b"content-length: 0\r\n")
}
// optimized date header, set_date writes \r\n
if !has_date {
@ -268,7 +260,7 @@ impl Decoder for Codec {
}
impl Encoder for Codec {
type Item = Message<Response<()>>;
type Item = Message<(Response<()>, BodyLength)>;
type Error = io::Error;
fn encode(
@ -277,8 +269,9 @@ impl Encoder for Codec {
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
match item {
Message::Item(res) => {
self.encode_response(res, dst)?;
Message::Item((mut res, length)) => {
self.prepare_te(res.head_mut(), length);
self.encode_response(res.head_mut(), length, dst)?;
}
Message::Chunk(Some(bytes)) => {
self.te.encode(bytes.as_ref(), dst)?;

View File

@ -10,7 +10,7 @@ use client::ClientResponse;
use error::ParseError;
use http::header::{HeaderName, HeaderValue};
use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version};
use message::MessageFlags;
use message::Head;
use request::Request;
const MAX_BUFFER_SIZE: usize = 131_072;
@ -50,6 +50,8 @@ pub(crate) enum PayloadLength {
pub(crate) trait MessageTypeDecoder: Sized {
fn keep_alive(&mut self);
fn force_close(&mut self);
fn headers_mut(&mut self) -> &mut HeaderMap;
fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError>;
@ -137,6 +139,8 @@ pub(crate) trait MessageTypeDecoder: Sized {
if ka {
self.keep_alive();
} else {
self.force_close();
}
// https://tools.ietf.org/html/rfc7230#section-3.3.3
@ -160,7 +164,11 @@ pub(crate) trait MessageTypeDecoder: Sized {
impl MessageTypeDecoder for Request {
fn keep_alive(&mut self) {
self.inner_mut().flags.set(MessageFlags::KEEPALIVE);
self.inner_mut().head.set_keep_alive()
}
fn force_close(&mut self) {
self.inner_mut().head.force_close()
}
fn headers_mut(&mut self) -> &mut HeaderMap {
@ -234,7 +242,11 @@ impl MessageTypeDecoder for Request {
impl MessageTypeDecoder for ClientResponse {
fn keep_alive(&mut self) {
self.head.flags.insert(MessageFlags::KEEPALIVE);
self.head.set_keep_alive();
}
fn force_close(&mut self) {
self.head.force_close();
}
fn headers_mut(&mut self) -> &mut HeaderMap {

View File

@ -30,7 +30,6 @@ bitflags! {
const KEEPALIVE_ENABLED = 0b0000_0010;
const KEEPALIVE = 0b0000_0100;
const POLLED = 0b0000_1000;
const FLUSHED = 0b0001_0000;
const SHUTDOWN = 0b0010_0000;
const DISCONNECTED = 0b0100_0000;
}
@ -105,9 +104,9 @@ where
) -> Self {
let keepalive = config.keep_alive_enabled();
let flags = if keepalive {
Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED
Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED
} else {
Flags::FLUSHED
Flags::empty()
};
let framed = Framed::new(stream, Codec::new(config.clone()));
@ -167,7 +166,7 @@ where
/// Flush stream
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
if !self.flags.contains(Flags::FLUSHED) {
if !self.framed.is_write_buf_empty() {
match self.framed.poll_complete() {
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => {
@ -179,7 +178,6 @@ where
if self.payload.is_some() && self.state.is_empty() {
return Err(DispatchError::PayloadIsNotConsumed);
}
self.flags.insert(Flags::FLUSHED);
Ok(Async::Ready(()))
}
}
@ -194,7 +192,7 @@ where
body: B1,
) -> Result<State<S, B1>, DispatchError<S::Error>> {
self.framed
.force_send(Message::Item(message))
.force_send(Message::Item((message, body.length())))
.map_err(|err| {
if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete(None));
@ -204,7 +202,6 @@ where
self.flags
.set(Flags::KEEPALIVE, self.framed.get_codec().keepalive());
self.flags.remove(Flags::FLUSHED);
match body.length() {
BodyLength::None | BodyLength::Empty => Ok(State::None),
_ => Ok(State::SendPayload(body)),
@ -228,10 +225,7 @@ where
State::ServiceCall(mut fut) => {
match fut.poll().map_err(DispatchError::Service)? {
Async::Ready(mut res) => {
let (mut res, body) = res.replace_body(());
self.framed
.get_codec_mut()
.prepare_te(res.head_mut(), &mut body.length());
let (res, body) = res.replace_body(());
Some(self.send_response(res, body)?)
}
Async::NotReady => {
@ -248,13 +242,11 @@ where
.map_err(|_| DispatchError::Unknown)?
{
Async::Ready(Some(item)) => {
self.flags.remove(Flags::FLUSHED);
self.framed
.force_send(Message::Chunk(Some(item)))?;
continue;
}
Async::Ready(None) => {
self.flags.remove(Flags::FLUSHED);
self.framed.force_send(Message::Chunk(None))?;
}
Async::NotReady => {
@ -296,10 +288,7 @@ where
let mut task = self.service.call(req);
match task.poll().map_err(DispatchError::Service)? {
Async::Ready(res) => {
let (mut res, body) = res.replace_body(());
self.framed
.get_codec_mut()
.prepare_te(res.head_mut(), &mut body.length());
let (res, body) = res.replace_body(());
self.send_response(res, body)
}
Async::NotReady => Ok(State::ServiceCall(task)),
@ -408,7 +397,7 @@ where
/// keep-alive timer
fn poll_keepalive(&mut self) -> Result<(), DispatchError<S::Error>> {
if self.ka_timer.is_some() {
if self.ka_timer.is_none() {
return Ok(());
}
match self.ka_timer.as_mut().unwrap().poll().map_err(|e| {
@ -421,7 +410,7 @@ where
return Err(DispatchError::DisconnectTimeout);
} else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire {
// check for any outstanding response processing
if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) {
if self.state.is_empty() && self.framed.is_write_buf_empty() {
if self.flags.contains(Flags::STARTED) {
trace!("Keep-alive timeout, close connection");
self.flags.insert(Flags::SHUTDOWN);
@ -490,12 +479,14 @@ where
inner.poll_response()?;
inner.poll_flush()?;
if inner.flags.contains(Flags::DISCONNECTED) {
return Ok(Async::Ready(H1ServiceResult::Disconnected));
}
// keep-alive and stream errors
if inner.state.is_empty() && inner.flags.contains(Flags::FLUSHED) {
if inner.state.is_empty() && inner.framed.is_write_buf_empty() {
if let Some(err) = inner.error.take() {
return Err(err);
} else if inner.flags.contains(Flags::DISCONNECTED) {
return Ok(Async::Ready(H1ServiceResult::Disconnected));
}
// unhandled request (upgrade or connect)
else if inner.unhandled.is_some() {

View File

@ -48,22 +48,13 @@ impl ResponseEncoder {
resp: &mut ResponseHead,
head: bool,
version: Version,
length: &mut BodyLength,
length: BodyLength,
) {
self.head = head;
let transfer = match length {
BodyLength::Empty => {
match resp.status {
StatusCode::NO_CONTENT
| StatusCode::CONTINUE
| StatusCode::SWITCHING_PROTOCOLS
| StatusCode::PROCESSING => *length = BodyLength::None,
_ => (),
}
TransferEncoding::empty()
}
BodyLength::Sized(len) => TransferEncoding::length(*len as u64),
BodyLength::Sized64(len) => TransferEncoding::length(*len),
BodyLength::Empty => TransferEncoding::empty(),
BodyLength::Sized(len) => TransferEncoding::length(len as u64),
BodyLength::Sized64(len) => TransferEncoding::length(len),
BodyLength::Chunked => TransferEncoding::chunked(),
BodyLength::Stream => TransferEncoding::eof(),
BodyLength::None => TransferEncoding::length(0),

View File

@ -109,6 +109,8 @@ extern crate serde_derive;
#[cfg(feature = "ssl")]
extern crate openssl;
extern crate backtrace;
pub mod body;
pub mod client;
mod config;
@ -173,5 +175,4 @@ pub mod http {
pub use header::*;
}
pub use header::ContentEncoding;
pub use response::ConnectionType;
}

View File

@ -12,12 +12,41 @@ use uri::Url;
pub trait Head: Default + 'static {
fn clear(&mut self);
fn flags(&self) -> MessageFlags;
fn flags_mut(&mut self) -> &mut MessageFlags;
fn pool() -> &'static MessagePool<Self>;
/// Set upgrade
fn set_upgrade(&mut self) {
*self.flags_mut() = MessageFlags::UPGRADE;
}
/// Check if request is upgrade request
fn upgrade(&self) -> bool {
self.flags().contains(MessageFlags::UPGRADE)
}
/// Set keep-alive
fn set_keep_alive(&mut self) {
*self.flags_mut() = MessageFlags::KEEP_ALIVE;
}
/// Check if request is keep-alive
fn keep_alive(&self) -> bool;
/// Set force-close connection
fn force_close(&mut self) {
*self.flags_mut() = MessageFlags::FORCE_CLOSE;
}
}
bitflags! {
pub(crate) struct MessageFlags: u8 {
const KEEPALIVE = 0b0000_0001;
pub struct MessageFlags: u8 {
const KEEP_ALIVE = 0b0000_0001;
const FORCE_CLOSE = 0b0000_0010;
const UPGRADE = 0b0000_0100;
}
}
@ -47,6 +76,25 @@ impl Head for RequestHead {
self.flags = MessageFlags::empty();
}
fn flags(&self) -> MessageFlags {
self.flags
}
fn flags_mut(&mut self) -> &mut MessageFlags {
&mut self.flags
}
/// Check if request is keep-alive
fn keep_alive(&self) -> bool {
if self.flags().contains(MessageFlags::FORCE_CLOSE) {
false
} else if self.flags().contains(MessageFlags::KEEP_ALIVE) {
true
} else {
self.version <= Version::HTTP_11
}
}
fn pool() -> &'static MessagePool<Self> {
REQUEST_POOL.with(|p| *p)
}
@ -79,11 +127,44 @@ impl Head for ResponseHead {
self.flags = MessageFlags::empty();
}
fn flags(&self) -> MessageFlags {
self.flags
}
fn flags_mut(&mut self) -> &mut MessageFlags {
&mut self.flags
}
/// Check if response is keep-alive
fn keep_alive(&self) -> bool {
if self.flags().contains(MessageFlags::FORCE_CLOSE) {
false
} else if self.flags().contains(MessageFlags::KEEP_ALIVE) {
true
} else {
self.version <= Version::HTTP_11
}
}
fn pool() -> &'static MessagePool<Self> {
RESPONSE_POOL.with(|p| *p)
}
}
impl ResponseHead {
/// Get custom reason for the response
#[inline]
pub fn reason(&self) -> &str {
if let Some(reason) = self.reason {
reason
} else {
self.status
.canonical_reason()
.unwrap_or("<unknown status code>")
}
}
}
pub struct Message<T: Head> {
pub head: T,
pub url: Url,

View File

@ -8,7 +8,7 @@ use extensions::Extensions;
use httpmessage::HttpMessage;
use payload::Payload;
use message::{Message, MessageFlags, MessagePool, RequestHead};
use message::{Head, Message, MessagePool, RequestHead};
/// Request
pub struct Request {
@ -116,7 +116,7 @@ impl Request {
/// Checks if a connection should be kept alive.
#[inline]
pub fn keep_alive(&self) -> bool {
self.inner().flags.get().contains(MessageFlags::KEEPALIVE)
self.inner().head.keep_alive()
}
/// Request extensions

View File

@ -20,17 +20,6 @@ use message::{Head, MessageFlags, ResponseHead};
/// max write buffer size 64k
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536;
/// Represents various types of connection
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum ConnectionType {
/// Close connection after response
Close,
/// Keep connection alive after response
KeepAlive,
/// Connection is upgraded to different type
Upgrade,
}
/// An HTTP Response
pub struct Response<B: MessageBody = Body>(Box<InnerResponse>, B);
@ -124,27 +113,6 @@ impl<B: MessageBody> Response<B> {
&mut self.get_mut().head.status
}
/// Get custom reason for the response
#[inline]
pub fn reason(&self) -> &str {
if let Some(reason) = self.get_ref().head.reason {
reason
} else {
self.get_ref()
.head
.status
.canonical_reason()
.unwrap_or("<unknown status code>")
}
}
/// Set the custom reason for the response
#[inline]
pub fn set_reason(&mut self, reason: &'static str) -> &mut Self {
self.get_mut().head.reason = Some(reason);
self
}
/// Get the headers from the response
#[inline]
pub fn headers(&self) -> &HeaderMap {
@ -207,28 +175,15 @@ impl<B: MessageBody> Response<B> {
count
}
/// Set connection type
pub fn set_connection_type(&mut self, conn: ConnectionType) -> &mut Self {
self.get_mut().connection_type = Some(conn);
self
}
/// Connection upgrade status
#[inline]
pub fn upgrade(&self) -> bool {
self.get_ref().connection_type == Some(ConnectionType::Upgrade)
self.get_ref().head.upgrade()
}
/// Keep-alive status for this connection
pub fn keep_alive(&self) -> Option<bool> {
if let Some(ct) = self.get_ref().connection_type {
match ct {
ConnectionType::KeepAlive => Some(true),
ConnectionType::Close | ConnectionType::Upgrade => Some(false),
}
} else {
None
}
pub fn keep_alive(&self) -> bool {
self.get_ref().head.keep_alive()
}
/// Get body os this response
@ -275,19 +230,20 @@ impl<B: MessageBody> Response<B> {
}
}
impl fmt::Debug for Response {
impl<B: MessageBody> fmt::Debug for Response<B> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!(
f,
"\nResponse {:?} {}{}",
self.get_ref().head.version,
self.get_ref().head.status,
self.get_ref().head.reason.unwrap_or("")
self.get_ref().head.reason.unwrap_or(""),
);
let _ = writeln!(f, " headers:");
for (key, val) in self.get_ref().head.headers.iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val);
}
let _ = writeln!(f, " body: {:?}", self.body().length());
res
}
}
@ -400,27 +356,31 @@ impl ResponseBuilder {
self
}
/// Set connection type
/// Set connection type to KeepAlive
#[inline]
#[doc(hidden)]
pub fn connection_type(&mut self, conn: ConnectionType) -> &mut Self {
pub fn keep_alive(&mut self) -> &mut Self {
if let Some(parts) = parts(&mut self.response, &self.err) {
parts.connection_type = Some(conn);
parts.head.set_keep_alive();
}
self
}
/// Set connection type to Upgrade
#[inline]
#[doc(hidden)]
pub fn upgrade(&mut self) -> &mut Self {
self.connection_type(ConnectionType::Upgrade)
if let Some(parts) = parts(&mut self.response, &self.err) {
parts.head.set_upgrade();
}
self
}
/// Force close connection, even if it is marked as keep-alive
#[inline]
pub fn force_close(&mut self) -> &mut Self {
self.connection_type(ConnectionType::Close)
if let Some(parts) = parts(&mut self.response, &self.err) {
parts.head.force_close();
}
self
}
/// Set response content type
@ -719,8 +679,6 @@ impl From<BytesMut> for Response {
struct InnerResponse {
head: ResponseHead,
connection_type: Option<ConnectionType>,
write_capacity: usize,
response_size: u64,
error: Option<Error>,
pool: &'static ResponsePool,
@ -728,7 +686,6 @@ struct InnerResponse {
pub(crate) struct ResponseParts {
head: ResponseHead,
connection_type: Option<ConnectionType>,
error: Option<Error>,
}
@ -744,9 +701,7 @@ impl InnerResponse {
flags: MessageFlags::empty(),
},
pool,
connection_type: None,
response_size: 0,
write_capacity: MAX_WRITE_BUFFER_SIZE,
error: None,
}
}
@ -755,7 +710,6 @@ impl InnerResponse {
fn into_parts(self) -> ResponseParts {
ResponseParts {
head: self.head,
connection_type: self.connection_type,
error: self.error,
}
}
@ -763,9 +717,7 @@ impl InnerResponse {
fn from_parts(parts: ResponseParts) -> InnerResponse {
InnerResponse {
head: parts.head,
connection_type: parts.connection_type,
response_size: 0,
write_capacity: MAX_WRITE_BUFFER_SIZE,
error: parts.error,
pool: ResponsePool::pool(),
}
@ -838,10 +790,8 @@ impl ResponsePool {
let mut p = inner.pool.0.borrow_mut();
if p.len() < 128 {
inner.head.clear();
inner.connection_type = None;
inner.response_size = 0;
inner.error = None;
inner.write_capacity = MAX_WRITE_BUFFER_SIZE;
p.push_front(inner);
}
}
@ -937,7 +887,7 @@ mod tests {
#[test]
fn test_force_close() {
let resp = Response::build(StatusCode::OK).force_close().finish();
assert!(!resp.keep_alive().unwrap())
assert!(!resp.keep_alive())
}
#[test]

View File

@ -3,10 +3,10 @@ use std::marker::PhantomData;
use actix_net::codec::Framed;
use actix_net::service::{NewService, Service};
use futures::future::{ok, Either, FutureResult};
use futures::{Async, AsyncSink, Future, Poll, Sink};
use futures::{Async, Future, Poll, Sink};
use tokio_io::{AsyncRead, AsyncWrite};
use body::MessageBody;
use body::{BodyLength, MessageBody};
use error::{Error, ResponseError};
use h1::{Codec, Message};
use response::Response;
@ -15,7 +15,7 @@ pub struct SendError<T, R, E>(PhantomData<(T, R, E)>);
impl<T, R, E> Default for SendError<T, R, E>
where
T: AsyncWrite,
T: AsyncRead + AsyncWrite,
E: ResponseError,
{
fn default() -> Self {
@ -25,7 +25,7 @@ where
impl<T, R, E> NewService for SendError<T, R, E>
where
T: AsyncWrite,
T: AsyncRead + AsyncWrite,
E: ResponseError,
{
type Request = Result<R, (E, Framed<T, Codec>)>;
@ -42,7 +42,7 @@ where
impl<T, R, E> Service for SendError<T, R, E>
where
T: AsyncWrite,
T: AsyncRead + AsyncWrite,
E: ResponseError,
{
type Request = Result<R, (E, Framed<T, Codec>)>;
@ -62,7 +62,7 @@ where
let (res, _body) = res.replace_body(());
Either::B(SendErrorFut {
framed: Some(framed),
res: Some(res.into()),
res: Some((res, BodyLength::Empty).into()),
err: Some(e),
_t: PhantomData,
})
@ -72,7 +72,7 @@ where
}
pub struct SendErrorFut<T, R, E> {
res: Option<Message<Response<()>>>,
res: Option<Message<(Response<()>, BodyLength)>>,
framed: Option<Framed<T, Codec>>,
err: Option<E>,
_t: PhantomData<R>,
@ -81,22 +81,15 @@ pub struct SendErrorFut<T, R, E> {
impl<T, R, E> Future for SendErrorFut<T, R, E>
where
E: ResponseError,
T: AsyncWrite,
T: AsyncRead + AsyncWrite,
{
type Item = R;
type Error = (E, Framed<T, Codec>);
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(res) = self.res.take() {
match self.framed.as_mut().unwrap().start_send(res) {
Ok(AsyncSink::Ready) => (),
Ok(AsyncSink::NotReady(res)) => {
self.res = Some(res);
return Ok(Async::NotReady);
}
Err(_) => {
return Err((self.err.take().unwrap(), self.framed.take().unwrap()))
}
if let Err(_) = self.framed.as_mut().unwrap().force_send(res) {
return Err((self.err.take().unwrap(), self.framed.take().unwrap()));
}
}
match self.framed.as_mut().unwrap().poll_complete() {
@ -123,20 +116,15 @@ where
B: MessageBody,
{
pub fn send(
mut framed: Framed<T, Codec>,
framed: Framed<T, Codec>,
res: Response<B>,
) -> impl Future<Item = Framed<T, Codec>, Error = Error> {
// extract body from response
let (mut res, body) = res.replace_body(());
// init codec
framed
.get_codec_mut()
.prepare_te(&mut res.head_mut(), &mut body.length());
let (res, body) = res.replace_body(());
// write response
SendResponseFut {
res: Some(Message::Item(res)),
res: Some(Message::Item((res, body.length()))),
body: Some(body),
framed: Some(framed),
}
@ -174,13 +162,10 @@ where
Ok(Async::Ready(()))
}
fn call(&mut self, (res, mut framed): Self::Request) -> Self::Future {
let (mut res, body) = res.replace_body(());
framed
.get_codec_mut()
.prepare_te(res.head_mut(), &mut body.length());
fn call(&mut self, (res, framed): Self::Request) -> Self::Future {
let (res, body) = res.replace_body(());
SendResponseFut {
res: Some(Message::Item(res)),
res: Some(Message::Item((res, body.length()))),
body: Some(body),
framed: Some(framed),
}
@ -188,7 +173,7 @@ where
}
pub struct SendResponseFut<T, B> {
res: Option<Message<Response<()>>>,
res: Option<Message<(Response<()>, BodyLength)>>,
body: Option<B>,
framed: Option<Framed<T, Codec>>,
}

View File

@ -8,7 +8,7 @@ use std::io;
use error::ResponseError;
use http::{header, Method, StatusCode};
use request::Request;
use response::{ConnectionType, Response, ResponseBuilder};
use response::{Response, ResponseBuilder};
mod client;
mod codec;
@ -183,7 +183,7 @@ pub fn handshake_response(req: &Request) -> ResponseBuilder {
};
Response::build(StatusCode::SWITCHING_PROTOCOLS)
.connection_type(ConnectionType::Upgrade)
.upgrade()
.header(header::UPGRADE, "websocket")
.header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())

View File

@ -12,10 +12,13 @@ use actix_net::server::Server;
use actix_net::service::NewServiceExt;
use actix_web::{client, test, HttpMessage};
use bytes::Bytes;
use futures::future::{self, ok};
use futures::future::{self, lazy, ok};
use futures::stream::once;
use actix_http::{body, h1, http, Body, Error, KeepAlive, Request, Response};
use actix_http::{
body, client as client2, h1, http, Body, Error, HttpMessage as HttpMessage2,
KeepAlive, Request, Response,
};
#[test]
fn test_h1_v2() {
@ -181,14 +184,19 @@ fn test_headers() {
.unwrap()
.run()
});
thread::sleep(time::Duration::from_millis(400));
thread::sleep(time::Duration::from_millis(200));
let mut sys = System::new("test");
let req = client::ClientRequest::get(format!("http://{}/", addr))
let mut connector = sys
.block_on(lazy(|| {
Ok::<_, ()>(client2::Connector::default().service())
})).unwrap();
let req = client2::ClientRequest::get(format!("http://{}/", addr))
.finish()
.unwrap();
let response = sys.block_on(req.send()).unwrap();
let response = sys.block_on(req.send(&mut connector)).unwrap();
assert!(response.status().is_success());
// read response
@ -249,9 +257,7 @@ fn test_head_empty() {
thread::spawn(move || {
Server::new()
.bind("test", addr, move || {
h1::H1Service::new(|_| {
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish())
}).map(|_| ())
h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))).map(|_| ())
}).unwrap()
.run()
});