diff --git a/Cargo.toml b/Cargo.toml index 80d24595..2eef1c50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,9 +91,11 @@ native-tls = { version="0.2", optional = true } # openssl openssl = { version="0.10", optional = true } -#rustls +# rustls rustls = { version = "^0.14", optional = true } +backtrace="*" + [dev-dependencies] actix-web = "0.7" env_logger = "0.5" diff --git a/src/body.rs b/src/body.rs index a44bc603..3449448e 100644 --- a/src/body.rs +++ b/src/body.rs @@ -9,7 +9,7 @@ use error::{Error, PayloadError}; /// Type represent streaming payload pub type PayloadStream = Box>; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Copy, Clone)] /// Different type of body pub enum BodyLength { None, @@ -76,10 +76,11 @@ impl MessageBody for Body { Body::None => Ok(Async::Ready(None)), Body::Empty => Ok(Async::Ready(None)), Body::Bytes(ref mut bin) => { - if bin.len() == 0 { + let len = bin.len(); + if len == 0 { Ok(Async::Ready(None)) } else { - Ok(Async::Ready(Some(bin.slice_to(bin.len())))) + Ok(Async::Ready(Some(bin.split_to(len)))) } } Body::Message(ref mut body) => body.poll_next(), diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index 8be860ae..63551cff 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -33,7 +33,7 @@ where .from_err() // create Framed and send reqest .map(|io| Framed::new(io, h1::ClientCodec::default())) - .and_then(|framed| framed.send((head, len).into()).from_err()) + .and_then(move |framed| framed.send((head, len).into()).from_err()) // send request body .and_then(move |framed| match body.length() { BodyLength::None | BodyLength::Empty | BodyLength::Sized(0) => { diff --git a/src/client/request.rs b/src/client/request.rs index dd418a6f..cc65b9db 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -16,7 +16,7 @@ use http::{ uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method, Uri, Version, }; -use message::RequestHead; +use message::{Head, RequestHead}; use super::response::ClientResponse; use super::{pipeline, Connect, Connection, ConnectorError, SendRequestError}; @@ -365,8 +365,21 @@ impl ClientRequestBuilder { where V: IntoHeaderValue, { + { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.set_upgrade(); + } + } self.set_header(header::UPGRADE, value) - .set_header(header::CONNECTION, "upgrade") + } + + /// Close connection + #[inline] + pub fn close(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.force_close(); + } + self } /// Set request's content type diff --git a/src/client/response.rs b/src/client/response.rs index 41c18562..dc7b13c1 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -8,7 +8,7 @@ use http::{HeaderMap, StatusCode, Version}; use body::PayloadStream; use error::PayloadError; use httpmessage::HttpMessage; -use message::{MessageFlags, ResponseHead}; +use message::{Head, ResponseHead}; use super::pipeline::Payload; @@ -81,7 +81,7 @@ impl ClientResponse { /// Checks if a connection should be kept alive. #[inline] pub fn keep_alive(&self) -> bool { - self.head().flags.contains(MessageFlags::KEEPALIVE) + self.head().keep_alive() } } diff --git a/src/h1/client.rs b/src/h1/client.rs index 2cb2fb2e..e2d1eefe 100644 --- a/src/h1/client.rs +++ b/src/h1/client.rs @@ -16,7 +16,7 @@ use http::header::{ HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, }; use http::{Method, Version}; -use message::{MessagePool, RequestHead}; +use message::{Head, MessagePool, RequestHead}; bitflags! { struct Flags: u8 { @@ -135,7 +135,7 @@ fn prn_version(ver: Version) -> &'static str { } impl ClientCodecInner { - fn encode_response( + fn encode_request( &mut self, msg: RequestHead, length: BodyLength, @@ -146,7 +146,7 @@ impl ClientCodecInner { // status line write!( Writer(buffer), - "{} {} {}", + "{} {} {}\r\n", msg.method, msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), prn_version(msg.version) @@ -156,38 +156,26 @@ impl ClientCodecInner { buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE); // content length - let mut len_is_set = true; match length { BodyLength::Sized(len) => helpers::write_content_length(len, buffer), BodyLength::Sized64(len) => { - buffer.extend_from_slice(b"\r\ncontent-length: "); + buffer.extend_from_slice(b"content-length: "); write!(buffer.writer(), "{}", len)?; buffer.extend_from_slice(b"\r\n"); } BodyLength::Chunked => { - buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") - } - BodyLength::Empty => { - len_is_set = false; - buffer.extend_from_slice(b"\r\n") - } - BodyLength::None | BodyLength::Stream => { - buffer.extend_from_slice(b"\r\n") + buffer.extend_from_slice(b"transfer-encoding: chunked\r\n") } + BodyLength::Empty => buffer.extend_from_slice(b"content-length: 0\r\n"), + BodyLength::None | BodyLength::Stream => (), } let mut has_date = false; for (key, value) in &msg.headers { match *key { - TRANSFER_ENCODING => continue, - CONTENT_LENGTH => match length { - BodyLength::None => (), - BodyLength::Empty => len_is_set = true, - _ => continue, - }, + TRANSFER_ENCODING | CONNECTION | CONTENT_LENGTH => continue, DATE => has_date = true, - UPGRADE => self.flags.insert(Flags::UPGRADE), _ => (), } @@ -197,12 +185,19 @@ impl ClientCodecInner { buffer.put_slice(b"\r\n"); } - // set content length - if !len_is_set { - buffer.extend_from_slice(b"content-length: 0\r\n") + // Connection header + if msg.upgrade() { + self.flags.set(Flags::UPGRADE, msg.upgrade()); + buffer.extend_from_slice(b"connection: upgrade\r\n"); + } else if msg.keep_alive() { + if self.version < Version::HTTP_11 { + buffer.extend_from_slice(b"connection: keep-alive\r\n"); + } + } else if self.version >= Version::HTTP_11 { + buffer.extend_from_slice(b"connection: close\r\n"); } - // set date header + // Date header if !has_date { self.config.set_date(buffer); } else { @@ -276,7 +271,7 @@ impl Encoder for ClientCodec { ) -> Result<(), Self::Error> { match item { Message::Item((msg, btype)) => { - self.inner.encode_response(msg, btype, dst)?; + self.inner.encode_request(msg, btype, dst)?; } Message::Chunk(Some(bytes)) => { self.inner.te.encode(bytes.as_ref(), dst)?; diff --git a/src/h1/codec.rs b/src/h1/codec.rs index b4a62a50..117c8cde 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -13,8 +13,8 @@ use config::ServiceConfig; use error::ParseError; use helpers; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; -use http::{Method, Version}; -use message::ResponseHead; +use http::{Method, StatusCode, Version}; +use message::{Head, ResponseHead}; use request::Request; use response::Response; @@ -99,69 +99,71 @@ impl Codec { } /// prepare transfer encoding - pub fn prepare_te(&mut self, head: &mut ResponseHead, length: &mut BodyLength) { + fn prepare_te(&mut self, head: &mut ResponseHead, length: BodyLength) { self.te .update(head, self.flags.contains(Flags::HEAD), self.version, length); } fn encode_response( &mut self, - mut msg: Response<()>, + msg: &mut ResponseHead, + length: BodyLength, buffer: &mut BytesMut, ) -> io::Result<()> { - let ka = self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg - .keep_alive() - .unwrap_or_else(|| self.flags.contains(Flags::KEEPALIVE)); + msg.version = self.version; // Connection upgrade if msg.upgrade() { self.flags.insert(Flags::UPGRADE); self.flags.remove(Flags::KEEPALIVE); - msg.headers_mut() + msg.headers .insert(CONNECTION, HeaderValue::from_static("upgrade")); } // keep-alive - else if ka { + else if self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg.keep_alive() { self.flags.insert(Flags::KEEPALIVE); if self.version < Version::HTTP_11 { - msg.headers_mut() + msg.headers .insert(CONNECTION, HeaderValue::from_static("keep-alive")); } } else if self.version >= Version::HTTP_11 { self.flags.remove(Flags::KEEPALIVE); - msg.headers_mut() + msg.headers .insert(CONNECTION, HeaderValue::from_static("close")); } // render message { let reason = msg.reason().as_bytes(); - buffer - .reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len()); + buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE + reason.len()); // status line - helpers::write_status_line(self.version, msg.status().as_u16(), buffer); + helpers::write_status_line(self.version, msg.status.as_u16(), buffer); buffer.extend_from_slice(reason); // content length - let mut len_is_set = true; - match self.te.length { - BodyLength::Chunked => { - buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") - } - BodyLength::Empty => { - len_is_set = false; - buffer.extend_from_slice(b"\r\n") - } - BodyLength::Sized(len) => helpers::write_content_length(len, buffer), - BodyLength::Sized64(len) => { - buffer.extend_from_slice(b"\r\ncontent-length: "); - write!(buffer.writer(), "{}", len)?; - buffer.extend_from_slice(b"\r\n"); - } - BodyLength::None | BodyLength::Stream => { - buffer.extend_from_slice(b"\r\n") - } + match msg.status { + StatusCode::NO_CONTENT + | StatusCode::CONTINUE + | StatusCode::SWITCHING_PROTOCOLS + | StatusCode::PROCESSING => buffer.extend_from_slice(b"\r\n"), + _ => match length { + BodyLength::Chunked => { + buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") + } + BodyLength::Empty => { + buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n"); + } + BodyLength::Sized(len) => helpers::write_content_length(len, buffer), + BodyLength::Sized64(len) => { + buffer.extend_from_slice(b"\r\ncontent-length: "); + write!(buffer.writer(), "{}", len)?; + buffer.extend_from_slice(b"\r\n"); + } + BodyLength::None | BodyLength::Stream => { + buffer.extend_from_slice(b"\r\n") + } + }, } // write headers @@ -169,16 +171,9 @@ impl Codec { let mut has_date = false; let mut remaining = buffer.remaining_mut(); let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) }; - for (key, value) in msg.headers() { + for (key, value) in &msg.headers { match *key { - TRANSFER_ENCODING => continue, - CONTENT_LENGTH => match self.te.length { - BodyLength::None => (), - BodyLength::Empty => { - len_is_set = true; - } - _ => continue, - }, + TRANSFER_ENCODING | CONTENT_LENGTH => continue, DATE => { has_date = true; } @@ -213,9 +208,6 @@ impl Codec { unsafe { buffer.advance_mut(pos); } - if !len_is_set { - buffer.extend_from_slice(b"content-length: 0\r\n") - } // optimized date header, set_date writes \r\n if !has_date { @@ -268,7 +260,7 @@ impl Decoder for Codec { } impl Encoder for Codec { - type Item = Message>; + type Item = Message<(Response<()>, BodyLength)>; type Error = io::Error; fn encode( @@ -277,8 +269,9 @@ impl Encoder for Codec { dst: &mut BytesMut, ) -> Result<(), Self::Error> { match item { - Message::Item(res) => { - self.encode_response(res, dst)?; + Message::Item((mut res, length)) => { + self.prepare_te(res.head_mut(), length); + self.encode_response(res.head_mut(), length, dst)?; } Message::Chunk(Some(bytes)) => { self.te.encode(bytes.as_ref(), dst)?; diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index 26154ef1..12008a77 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -10,7 +10,7 @@ use client::ClientResponse; use error::ParseError; use http::header::{HeaderName, HeaderValue}; use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version}; -use message::MessageFlags; +use message::Head; use request::Request; const MAX_BUFFER_SIZE: usize = 131_072; @@ -50,6 +50,8 @@ pub(crate) enum PayloadLength { pub(crate) trait MessageTypeDecoder: Sized { fn keep_alive(&mut self); + fn force_close(&mut self); + fn headers_mut(&mut self) -> &mut HeaderMap; fn decode(src: &mut BytesMut) -> Result, ParseError>; @@ -137,6 +139,8 @@ pub(crate) trait MessageTypeDecoder: Sized { if ka { self.keep_alive(); + } else { + self.force_close(); } // https://tools.ietf.org/html/rfc7230#section-3.3.3 @@ -160,7 +164,11 @@ pub(crate) trait MessageTypeDecoder: Sized { impl MessageTypeDecoder for Request { fn keep_alive(&mut self) { - self.inner_mut().flags.set(MessageFlags::KEEPALIVE); + self.inner_mut().head.set_keep_alive() + } + + fn force_close(&mut self) { + self.inner_mut().head.force_close() } fn headers_mut(&mut self) -> &mut HeaderMap { @@ -234,7 +242,11 @@ impl MessageTypeDecoder for Request { impl MessageTypeDecoder for ClientResponse { fn keep_alive(&mut self) { - self.head.flags.insert(MessageFlags::KEEPALIVE); + self.head.set_keep_alive(); + } + + fn force_close(&mut self) { + self.head.force_close(); } fn headers_mut(&mut self) -> &mut HeaderMap { diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index ff9d54e6..bf0abb04 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -30,7 +30,6 @@ bitflags! { const KEEPALIVE_ENABLED = 0b0000_0010; const KEEPALIVE = 0b0000_0100; const POLLED = 0b0000_1000; - const FLUSHED = 0b0001_0000; const SHUTDOWN = 0b0010_0000; const DISCONNECTED = 0b0100_0000; } @@ -105,9 +104,9 @@ where ) -> Self { let keepalive = config.keep_alive_enabled(); let flags = if keepalive { - Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED + Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED } else { - Flags::FLUSHED + Flags::empty() }; let framed = Framed::new(stream, Codec::new(config.clone())); @@ -167,7 +166,7 @@ where /// Flush stream fn poll_flush(&mut self) -> Poll<(), DispatchError> { - if !self.flags.contains(Flags::FLUSHED) { + if !self.framed.is_write_buf_empty() { match self.framed.poll_complete() { Ok(Async::NotReady) => Ok(Async::NotReady), Err(err) => { @@ -179,7 +178,6 @@ where if self.payload.is_some() && self.state.is_empty() { return Err(DispatchError::PayloadIsNotConsumed); } - self.flags.insert(Flags::FLUSHED); Ok(Async::Ready(())) } } @@ -194,7 +192,7 @@ where body: B1, ) -> Result, DispatchError> { self.framed - .force_send(Message::Item(message)) + .force_send(Message::Item((message, body.length()))) .map_err(|err| { if let Some(mut payload) = self.payload.take() { payload.set_error(PayloadError::Incomplete(None)); @@ -204,7 +202,6 @@ where self.flags .set(Flags::KEEPALIVE, self.framed.get_codec().keepalive()); - self.flags.remove(Flags::FLUSHED); match body.length() { BodyLength::None | BodyLength::Empty => Ok(State::None), _ => Ok(State::SendPayload(body)), @@ -228,10 +225,7 @@ where State::ServiceCall(mut fut) => { match fut.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { - let (mut res, body) = res.replace_body(()); - self.framed - .get_codec_mut() - .prepare_te(res.head_mut(), &mut body.length()); + let (res, body) = res.replace_body(()); Some(self.send_response(res, body)?) } Async::NotReady => { @@ -248,13 +242,11 @@ where .map_err(|_| DispatchError::Unknown)? { Async::Ready(Some(item)) => { - self.flags.remove(Flags::FLUSHED); self.framed .force_send(Message::Chunk(Some(item)))?; continue; } Async::Ready(None) => { - self.flags.remove(Flags::FLUSHED); self.framed.force_send(Message::Chunk(None))?; } Async::NotReady => { @@ -296,10 +288,7 @@ where let mut task = self.service.call(req); match task.poll().map_err(DispatchError::Service)? { Async::Ready(res) => { - let (mut res, body) = res.replace_body(()); - self.framed - .get_codec_mut() - .prepare_te(res.head_mut(), &mut body.length()); + let (res, body) = res.replace_body(()); self.send_response(res, body) } Async::NotReady => Ok(State::ServiceCall(task)), @@ -408,7 +397,7 @@ where /// keep-alive timer fn poll_keepalive(&mut self) -> Result<(), DispatchError> { - if self.ka_timer.is_some() { + if self.ka_timer.is_none() { return Ok(()); } match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { @@ -421,7 +410,7 @@ where return Err(DispatchError::DisconnectTimeout); } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { // check for any outstanding response processing - if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { + if self.state.is_empty() && self.framed.is_write_buf_empty() { if self.flags.contains(Flags::STARTED) { trace!("Keep-alive timeout, close connection"); self.flags.insert(Flags::SHUTDOWN); @@ -490,12 +479,14 @@ where inner.poll_response()?; inner.poll_flush()?; + if inner.flags.contains(Flags::DISCONNECTED) { + return Ok(Async::Ready(H1ServiceResult::Disconnected)); + } + // keep-alive and stream errors - if inner.state.is_empty() && inner.flags.contains(Flags::FLUSHED) { + if inner.state.is_empty() && inner.framed.is_write_buf_empty() { if let Some(err) = inner.error.take() { return Err(err); - } else if inner.flags.contains(Flags::DISCONNECTED) { - return Ok(Async::Ready(H1ServiceResult::Disconnected)); } // unhandled request (upgrade or connect) else if inner.unhandled.is_some() { diff --git a/src/h1/encoder.rs b/src/h1/encoder.rs index fd52d731..1664af16 100644 --- a/src/h1/encoder.rs +++ b/src/h1/encoder.rs @@ -48,22 +48,13 @@ impl ResponseEncoder { resp: &mut ResponseHead, head: bool, version: Version, - length: &mut BodyLength, + length: BodyLength, ) { self.head = head; let transfer = match length { - BodyLength::Empty => { - match resp.status { - StatusCode::NO_CONTENT - | StatusCode::CONTINUE - | StatusCode::SWITCHING_PROTOCOLS - | StatusCode::PROCESSING => *length = BodyLength::None, - _ => (), - } - TransferEncoding::empty() - } - BodyLength::Sized(len) => TransferEncoding::length(*len as u64), - BodyLength::Sized64(len) => TransferEncoding::length(*len), + BodyLength::Empty => TransferEncoding::empty(), + BodyLength::Sized(len) => TransferEncoding::length(len as u64), + BodyLength::Sized64(len) => TransferEncoding::length(len), BodyLength::Chunked => TransferEncoding::chunked(), BodyLength::Stream => TransferEncoding::eof(), BodyLength::None => TransferEncoding::length(0), diff --git a/src/lib.rs b/src/lib.rs index 615f0401..57ff5df7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,6 +109,8 @@ extern crate serde_derive; #[cfg(feature = "ssl")] extern crate openssl; +extern crate backtrace; + pub mod body; pub mod client; mod config; @@ -173,5 +175,4 @@ pub mod http { pub use header::*; } pub use header::ContentEncoding; - pub use response::ConnectionType; } diff --git a/src/message.rs b/src/message.rs index 3dcb203d..e7ce22fc 100644 --- a/src/message.rs +++ b/src/message.rs @@ -12,12 +12,41 @@ use uri::Url; pub trait Head: Default + 'static { fn clear(&mut self); + fn flags(&self) -> MessageFlags; + + fn flags_mut(&mut self) -> &mut MessageFlags; + fn pool() -> &'static MessagePool; + + /// Set upgrade + fn set_upgrade(&mut self) { + *self.flags_mut() = MessageFlags::UPGRADE; + } + + /// Check if request is upgrade request + fn upgrade(&self) -> bool { + self.flags().contains(MessageFlags::UPGRADE) + } + + /// Set keep-alive + fn set_keep_alive(&mut self) { + *self.flags_mut() = MessageFlags::KEEP_ALIVE; + } + + /// Check if request is keep-alive + fn keep_alive(&self) -> bool; + + /// Set force-close connection + fn force_close(&mut self) { + *self.flags_mut() = MessageFlags::FORCE_CLOSE; + } } bitflags! { - pub(crate) struct MessageFlags: u8 { - const KEEPALIVE = 0b0000_0001; + pub struct MessageFlags: u8 { + const KEEP_ALIVE = 0b0000_0001; + const FORCE_CLOSE = 0b0000_0010; + const UPGRADE = 0b0000_0100; } } @@ -47,6 +76,25 @@ impl Head for RequestHead { self.flags = MessageFlags::empty(); } + fn flags(&self) -> MessageFlags { + self.flags + } + + fn flags_mut(&mut self) -> &mut MessageFlags { + &mut self.flags + } + + /// Check if request is keep-alive + fn keep_alive(&self) -> bool { + if self.flags().contains(MessageFlags::FORCE_CLOSE) { + false + } else if self.flags().contains(MessageFlags::KEEP_ALIVE) { + true + } else { + self.version <= Version::HTTP_11 + } + } + fn pool() -> &'static MessagePool { REQUEST_POOL.with(|p| *p) } @@ -79,11 +127,44 @@ impl Head for ResponseHead { self.flags = MessageFlags::empty(); } + fn flags(&self) -> MessageFlags { + self.flags + } + + fn flags_mut(&mut self) -> &mut MessageFlags { + &mut self.flags + } + + /// Check if response is keep-alive + fn keep_alive(&self) -> bool { + if self.flags().contains(MessageFlags::FORCE_CLOSE) { + false + } else if self.flags().contains(MessageFlags::KEEP_ALIVE) { + true + } else { + self.version <= Version::HTTP_11 + } + } + fn pool() -> &'static MessagePool { RESPONSE_POOL.with(|p| *p) } } +impl ResponseHead { + /// Get custom reason for the response + #[inline] + pub fn reason(&self) -> &str { + if let Some(reason) = self.reason { + reason + } else { + self.status + .canonical_reason() + .unwrap_or("") + } + } +} + pub struct Message { pub head: T, pub url: Url, diff --git a/src/request.rs b/src/request.rs index 1e191047..1ee47edb 100644 --- a/src/request.rs +++ b/src/request.rs @@ -8,7 +8,7 @@ use extensions::Extensions; use httpmessage::HttpMessage; use payload::Payload; -use message::{Message, MessageFlags, MessagePool, RequestHead}; +use message::{Head, Message, MessagePool, RequestHead}; /// Request pub struct Request { @@ -116,7 +116,7 @@ impl Request { /// Checks if a connection should be kept alive. #[inline] pub fn keep_alive(&self) -> bool { - self.inner().flags.get().contains(MessageFlags::KEEPALIVE) + self.inner().head.keep_alive() } /// Request extensions diff --git a/src/response.rs b/src/response.rs index a6f7c81c..542d4963 100644 --- a/src/response.rs +++ b/src/response.rs @@ -20,17 +20,6 @@ use message::{Head, MessageFlags, ResponseHead}; /// max write buffer size 64k pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536; -/// Represents various types of connection -#[derive(Copy, Clone, PartialEq, Debug)] -pub enum ConnectionType { - /// Close connection after response - Close, - /// Keep connection alive after response - KeepAlive, - /// Connection is upgraded to different type - Upgrade, -} - /// An HTTP Response pub struct Response(Box, B); @@ -124,27 +113,6 @@ impl Response { &mut self.get_mut().head.status } - /// Get custom reason for the response - #[inline] - pub fn reason(&self) -> &str { - if let Some(reason) = self.get_ref().head.reason { - reason - } else { - self.get_ref() - .head - .status - .canonical_reason() - .unwrap_or("") - } - } - - /// Set the custom reason for the response - #[inline] - pub fn set_reason(&mut self, reason: &'static str) -> &mut Self { - self.get_mut().head.reason = Some(reason); - self - } - /// Get the headers from the response #[inline] pub fn headers(&self) -> &HeaderMap { @@ -207,28 +175,15 @@ impl Response { count } - /// Set connection type - pub fn set_connection_type(&mut self, conn: ConnectionType) -> &mut Self { - self.get_mut().connection_type = Some(conn); - self - } - /// Connection upgrade status #[inline] pub fn upgrade(&self) -> bool { - self.get_ref().connection_type == Some(ConnectionType::Upgrade) + self.get_ref().head.upgrade() } /// Keep-alive status for this connection - pub fn keep_alive(&self) -> Option { - if let Some(ct) = self.get_ref().connection_type { - match ct { - ConnectionType::KeepAlive => Some(true), - ConnectionType::Close | ConnectionType::Upgrade => Some(false), - } - } else { - None - } + pub fn keep_alive(&self) -> bool { + self.get_ref().head.keep_alive() } /// Get body os this response @@ -275,19 +230,20 @@ impl Response { } } -impl fmt::Debug for Response { +impl fmt::Debug for Response { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = writeln!( f, "\nResponse {:?} {}{}", self.get_ref().head.version, self.get_ref().head.status, - self.get_ref().head.reason.unwrap_or("") + self.get_ref().head.reason.unwrap_or(""), ); let _ = writeln!(f, " headers:"); for (key, val) in self.get_ref().head.headers.iter() { let _ = writeln!(f, " {:?}: {:?}", key, val); } + let _ = writeln!(f, " body: {:?}", self.body().length()); res } } @@ -400,27 +356,31 @@ impl ResponseBuilder { self } - /// Set connection type + /// Set connection type to KeepAlive #[inline] - #[doc(hidden)] - pub fn connection_type(&mut self, conn: ConnectionType) -> &mut Self { + pub fn keep_alive(&mut self) -> &mut Self { if let Some(parts) = parts(&mut self.response, &self.err) { - parts.connection_type = Some(conn); + parts.head.set_keep_alive(); } self } /// Set connection type to Upgrade #[inline] - #[doc(hidden)] pub fn upgrade(&mut self) -> &mut Self { - self.connection_type(ConnectionType::Upgrade) + if let Some(parts) = parts(&mut self.response, &self.err) { + parts.head.set_upgrade(); + } + self } /// Force close connection, even if it is marked as keep-alive #[inline] pub fn force_close(&mut self) -> &mut Self { - self.connection_type(ConnectionType::Close) + if let Some(parts) = parts(&mut self.response, &self.err) { + parts.head.force_close(); + } + self } /// Set response content type @@ -719,8 +679,6 @@ impl From for Response { struct InnerResponse { head: ResponseHead, - connection_type: Option, - write_capacity: usize, response_size: u64, error: Option, pool: &'static ResponsePool, @@ -728,7 +686,6 @@ struct InnerResponse { pub(crate) struct ResponseParts { head: ResponseHead, - connection_type: Option, error: Option, } @@ -744,9 +701,7 @@ impl InnerResponse { flags: MessageFlags::empty(), }, pool, - connection_type: None, response_size: 0, - write_capacity: MAX_WRITE_BUFFER_SIZE, error: None, } } @@ -755,7 +710,6 @@ impl InnerResponse { fn into_parts(self) -> ResponseParts { ResponseParts { head: self.head, - connection_type: self.connection_type, error: self.error, } } @@ -763,9 +717,7 @@ impl InnerResponse { fn from_parts(parts: ResponseParts) -> InnerResponse { InnerResponse { head: parts.head, - connection_type: parts.connection_type, response_size: 0, - write_capacity: MAX_WRITE_BUFFER_SIZE, error: parts.error, pool: ResponsePool::pool(), } @@ -838,10 +790,8 @@ impl ResponsePool { let mut p = inner.pool.0.borrow_mut(); if p.len() < 128 { inner.head.clear(); - inner.connection_type = None; inner.response_size = 0; inner.error = None; - inner.write_capacity = MAX_WRITE_BUFFER_SIZE; p.push_front(inner); } } @@ -937,7 +887,7 @@ mod tests { #[test] fn test_force_close() { let resp = Response::build(StatusCode::OK).force_close().finish(); - assert!(!resp.keep_alive().unwrap()) + assert!(!resp.keep_alive()) } #[test] diff --git a/src/service.rs b/src/service.rs index 51a16b48..f934305c 100644 --- a/src/service.rs +++ b/src/service.rs @@ -3,10 +3,10 @@ use std::marker::PhantomData; use actix_net::codec::Framed; use actix_net::service::{NewService, Service}; use futures::future::{ok, Either, FutureResult}; -use futures::{Async, AsyncSink, Future, Poll, Sink}; +use futures::{Async, Future, Poll, Sink}; use tokio_io::{AsyncRead, AsyncWrite}; -use body::MessageBody; +use body::{BodyLength, MessageBody}; use error::{Error, ResponseError}; use h1::{Codec, Message}; use response::Response; @@ -15,7 +15,7 @@ pub struct SendError(PhantomData<(T, R, E)>); impl Default for SendError where - T: AsyncWrite, + T: AsyncRead + AsyncWrite, E: ResponseError, { fn default() -> Self { @@ -25,7 +25,7 @@ where impl NewService for SendError where - T: AsyncWrite, + T: AsyncRead + AsyncWrite, E: ResponseError, { type Request = Result)>; @@ -42,7 +42,7 @@ where impl Service for SendError where - T: AsyncWrite, + T: AsyncRead + AsyncWrite, E: ResponseError, { type Request = Result)>; @@ -62,7 +62,7 @@ where let (res, _body) = res.replace_body(()); Either::B(SendErrorFut { framed: Some(framed), - res: Some(res.into()), + res: Some((res, BodyLength::Empty).into()), err: Some(e), _t: PhantomData, }) @@ -72,7 +72,7 @@ where } pub struct SendErrorFut { - res: Option>>, + res: Option, BodyLength)>>, framed: Option>, err: Option, _t: PhantomData, @@ -81,22 +81,15 @@ pub struct SendErrorFut { impl Future for SendErrorFut where E: ResponseError, - T: AsyncWrite, + T: AsyncRead + AsyncWrite, { type Item = R; type Error = (E, Framed); fn poll(&mut self) -> Poll { if let Some(res) = self.res.take() { - match self.framed.as_mut().unwrap().start_send(res) { - Ok(AsyncSink::Ready) => (), - Ok(AsyncSink::NotReady(res)) => { - self.res = Some(res); - return Ok(Async::NotReady); - } - Err(_) => { - return Err((self.err.take().unwrap(), self.framed.take().unwrap())) - } + if let Err(_) = self.framed.as_mut().unwrap().force_send(res) { + return Err((self.err.take().unwrap(), self.framed.take().unwrap())); } } match self.framed.as_mut().unwrap().poll_complete() { @@ -123,20 +116,15 @@ where B: MessageBody, { pub fn send( - mut framed: Framed, + framed: Framed, res: Response, ) -> impl Future, Error = Error> { // extract body from response - let (mut res, body) = res.replace_body(()); - - // init codec - framed - .get_codec_mut() - .prepare_te(&mut res.head_mut(), &mut body.length()); + let (res, body) = res.replace_body(()); // write response SendResponseFut { - res: Some(Message::Item(res)), + res: Some(Message::Item((res, body.length()))), body: Some(body), framed: Some(framed), } @@ -174,13 +162,10 @@ where Ok(Async::Ready(())) } - fn call(&mut self, (res, mut framed): Self::Request) -> Self::Future { - let (mut res, body) = res.replace_body(()); - framed - .get_codec_mut() - .prepare_te(res.head_mut(), &mut body.length()); + fn call(&mut self, (res, framed): Self::Request) -> Self::Future { + let (res, body) = res.replace_body(()); SendResponseFut { - res: Some(Message::Item(res)), + res: Some(Message::Item((res, body.length()))), body: Some(body), framed: Some(framed), } @@ -188,7 +173,7 @@ where } pub struct SendResponseFut { - res: Option>>, + res: Option, BodyLength)>>, body: Option, framed: Option>, } diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 1fbbd03d..5c86d8c4 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -8,7 +8,7 @@ use std::io; use error::ResponseError; use http::{header, Method, StatusCode}; use request::Request; -use response::{ConnectionType, Response, ResponseBuilder}; +use response::{Response, ResponseBuilder}; mod client; mod codec; @@ -183,7 +183,7 @@ pub fn handshake_response(req: &Request) -> ResponseBuilder { }; Response::build(StatusCode::SWITCHING_PROTOCOLS) - .connection_type(ConnectionType::Upgrade) + .upgrade() .header(header::UPGRADE, "websocket") .header(header::TRANSFER_ENCODING, "chunked") .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) diff --git a/tests/test_server.rs b/tests/test_server.rs index c499cf63..9d16e92e 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -12,10 +12,13 @@ use actix_net::server::Server; use actix_net::service::NewServiceExt; use actix_web::{client, test, HttpMessage}; use bytes::Bytes; -use futures::future::{self, ok}; +use futures::future::{self, lazy, ok}; use futures::stream::once; -use actix_http::{body, h1, http, Body, Error, KeepAlive, Request, Response}; +use actix_http::{ + body, client as client2, h1, http, Body, Error, HttpMessage as HttpMessage2, + KeepAlive, Request, Response, +}; #[test] fn test_h1_v2() { @@ -181,14 +184,19 @@ fn test_headers() { .unwrap() .run() }); - thread::sleep(time::Duration::from_millis(400)); + thread::sleep(time::Duration::from_millis(200)); let mut sys = System::new("test"); - let req = client::ClientRequest::get(format!("http://{}/", addr)) + let mut connector = sys + .block_on(lazy(|| { + Ok::<_, ()>(client2::Connector::default().service()) + })).unwrap(); + + let req = client2::ClientRequest::get(format!("http://{}/", addr)) .finish() .unwrap(); - let response = sys.block_on(req.send()).unwrap(); + let response = sys.block_on(req.send(&mut connector)).unwrap(); assert!(response.status().is_success()); // read response @@ -249,9 +257,7 @@ fn test_head_empty() { thread::spawn(move || { Server::new() .bind("test", addr, move || { - h1::H1Service::new(|_| { - ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish()) - }).map(|_| ()) + h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))).map(|_| ()) }).unwrap() .run() });