From 6297fe0d4117bd93a4e215b36fee91e7bcb8d750 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 14 Nov 2018 09:38:16 -0800 Subject: [PATCH] refactor client response payload handling --- Cargo.toml | 3 +- src/body.rs | 7 ++- src/client/pipeline.rs | 51 +++++++++++++------- src/client/response.rs | 45 +++++++++-------- src/error.rs | 8 +++- src/h1/client.rs | 106 ++++++++++++++++++++++++++++++----------- src/h1/dispatcher.rs | 7 ++- src/h1/mod.rs | 2 +- src/payload.rs | 8 ++-- src/request.rs | 2 +- tests/test_client.rs | 10 ++-- 11 files changed, 166 insertions(+), 83 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e586f34ed..074e53631 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,8 @@ rust-tls = ["rustls", "actix-net/rust-tls"] [dependencies] actix = "0.7.5" #actix-net = "0.2.0" -actix-net = { git="https://github.com/actix/actix-net.git" } +#actix-net = { git="https://github.com/actix/actix-net.git" } +actix-net = { path="../actix-net" } base64 = "0.9" bitflags = "1.0" diff --git a/src/body.rs b/src/body.rs index e001273c4..1165909e8 100644 --- a/src/body.rs +++ b/src/body.rs @@ -4,10 +4,13 @@ use std::{fmt, mem}; use bytes::{Bytes, BytesMut}; use futures::{Async, Poll, Stream}; -use error::Error; +use error::{Error, PayloadError}; /// Type represent streaming body -pub type BodyStream = Box>; +pub type BodyStream = Box>; + +/// Type represent streaming payload +pub type PayloadStream = Box>; /// Different type of bory pub enum BodyType { diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index fff900131..4081b6354 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -11,8 +11,8 @@ use super::error::{ConnectorError, SendRequestError}; use super::request::RequestHead; use super::response::ClientResponse; use super::{Connect, Connection}; -use body::{BodyStream, BodyType, MessageBody}; -use error::Error; +use body::{BodyType, MessageBody, PayloadStream}; +use error::PayloadError; use h1; pub fn send_request( @@ -44,7 +44,7 @@ where let mut res = item.into_item().unwrap(); match framed.get_codec().message_type() { h1::MessageType::None => release_connection(framed), - _ => res.payload = Some(Payload::stream(framed)), + _ => *res.payload.borrow_mut() = Some(Payload::stream(framed)), } ok(res) } else { @@ -129,41 +129,56 @@ where } } -struct Payload { - framed: Option, h1::ClientCodec>>, +struct EmptyPayload; + +impl Stream for EmptyPayload { + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll, Self::Error> { + Ok(Async::Ready(None)) + } +} + +pub(crate) struct Payload { + framed: Option, h1::ClientPayloadCodec>>, +} + +impl Payload<()> { + pub fn empty() -> PayloadStream { + Box::new(EmptyPayload) + } } impl Payload { - fn stream(framed: Framed, h1::ClientCodec>) -> BodyStream { + fn stream(framed: Framed, h1::ClientCodec>) -> PayloadStream { Box::new(Payload { - framed: Some(framed), + framed: Some(framed.map_codec(|codec| codec.into_payload_codec())), }) } } impl Stream for Payload { type Item = Bytes; - type Error = Error; + type Error = PayloadError; - fn poll(&mut self) -> Poll, Error> { + fn poll(&mut self) -> Poll, Self::Error> { match self.framed.as_mut().unwrap().poll()? { Async::NotReady => Ok(Async::NotReady), - Async::Ready(Some(chunk)) => match chunk { - h1::Message::Chunk(Some(chunk)) => Ok(Async::Ready(Some(chunk))), - h1::Message::Chunk(None) => { - release_connection(self.framed.take().unwrap()); - Ok(Async::Ready(None)) - } - h1::Message::Item(_) => unreachable!(), + Async::Ready(Some(chunk)) => if let Some(chunk) = chunk { + Ok(Async::Ready(Some(chunk))) + } else { + release_connection(self.framed.take().unwrap()); + Ok(Async::Ready(None)) }, Async::Ready(None) => Ok(Async::Ready(None)), } } } -fn release_connection(framed: Framed, h1::ClientCodec>) +fn release_connection(framed: Framed, U>) where - Io: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + 'static, { let parts = framed.into_parts(); if parts.read_buf.is_empty() && parts.write_buf.is_empty() { diff --git a/src/client/response.rs b/src/client/response.rs index 0d5a87a0d..e8e63e4f7 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -6,34 +6,37 @@ use bytes::Bytes; use futures::{Async, Poll, Stream}; use http::{HeaderMap, Method, StatusCode, Version}; -use body::BodyStream; -use error::Error; +use body::PayloadStream; +use error::PayloadError; use extensions::Extensions; +use httpmessage::HttpMessage; use request::{Message, MessageFlags, MessagePool}; use uri::Url; +use super::pipeline::Payload; + /// Client Response pub struct ClientResponse { pub(crate) inner: Rc, - pub(crate) payload: Option, + pub(crate) payload: RefCell>, } -// impl HttpMessage for ClientResponse { -// type Stream = Payload; +impl HttpMessage for ClientResponse { + type Stream = PayloadStream; -// fn headers(&self) -> &HeaderMap { -// &self.inner.headers -// } + fn headers(&self) -> &HeaderMap { + &self.inner.headers + } -// #[inline] -// fn payload(&self) -> Payload { -// if let Some(payload) = self.inner.payload.borrow_mut().take() { -// payload -// } else { -// Payload::empty() -// } -// } -// } + #[inline] + fn payload(&self) -> Self::Stream { + if let Some(payload) = self.payload.borrow_mut().take() { + payload + } else { + Payload::empty() + } + } +} impl ClientResponse { /// Create new Request instance @@ -55,7 +58,7 @@ impl ClientResponse { payload: RefCell::new(None), extensions: RefCell::new(Extensions::new()), }), - payload: None, + payload: RefCell::new(None), } } @@ -114,10 +117,10 @@ impl ClientResponse { impl Stream for ClientResponse { type Item = Bytes; - type Error = Error; + type Error = PayloadError; - fn poll(&mut self) -> Poll, Error> { - if let Some(ref mut payload) = self.payload { + fn poll(&mut self) -> Poll, Self::Error> { + if let Some(ref mut payload) = &mut *self.payload.borrow_mut() { payload.poll() } else { Ok(Async::Ready(None)) diff --git a/src/error.rs b/src/error.rs index 3064cda49..956ec4eb4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -339,7 +339,7 @@ impl From for ParseError { pub enum PayloadError { /// A payload reached EOF, but is not complete. #[fail(display = "A payload reached EOF, but is not complete.")] - Incomplete, + Incomplete(Option), /// Content encoding stream corruption #[fail(display = "Can not decode content-encoding.")] EncodingCorrupted, @@ -351,6 +351,12 @@ pub enum PayloadError { UnknownLength, } +impl From for PayloadError { + fn from(err: io::Error) -> Self { + PayloadError::Incomplete(Some(err)) + } +} + /// `PayloadError` returns two possible results: /// /// - `Overflow` returns `PayloadTooLarge` diff --git a/src/h1/client.rs b/src/h1/client.rs index 9ace98e0e..eb7bdc233 100644 --- a/src/h1/client.rs +++ b/src/h1/client.rs @@ -10,7 +10,7 @@ use super::{Message, MessageType}; use body::{Binary, Body, BodyType}; use client::{ClientResponse, RequestHead}; use config::ServiceConfig; -use error::ParseError; +use error::{ParseError, PayloadError}; use helpers; use http::header::{ HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, @@ -32,6 +32,15 @@ const AVERAGE_HEADER_SIZE: usize = 30; /// HTTP/1 Codec pub struct ClientCodec { + inner: ClientCodecInner, +} + +/// HTTP/1 Payload Codec +pub struct ClientPayloadCodec { + inner: ClientCodecInner, +} + +struct ClientCodecInner { config: ServiceConfig, decoder: ResponseDecoder, payload: Option, @@ -65,32 +74,34 @@ impl ClientCodec { Flags::empty() }; ClientCodec { - config, - decoder: ResponseDecoder::with_pool(pool), - payload: None, - version: Version::HTTP_11, + inner: ClientCodecInner { + config, + decoder: ResponseDecoder::with_pool(pool), + payload: None, + version: Version::HTTP_11, - flags, - headers_size: 0, - te: RequestEncoder::default(), + flags, + headers_size: 0, + te: RequestEncoder::default(), + }, } } /// Check if request is upgrade pub fn upgrade(&self) -> bool { - self.flags.contains(Flags::UPGRADE) + self.inner.flags.contains(Flags::UPGRADE) } /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { - self.flags.contains(Flags::KEEPALIVE) + self.inner.flags.contains(Flags::KEEPALIVE) } /// Check last request's message type pub fn message_type(&self) -> MessageType { - if self.flags.contains(Flags::STREAM) { + if self.inner.flags.contains(Flags::STREAM) { MessageType::Stream - } else if self.payload.is_none() { + } else if self.inner.payload.is_none() { MessageType::None } else { MessageType::Payload @@ -99,10 +110,27 @@ impl ClientCodec { /// prepare transfer encoding pub fn prepare_te(&mut self, head: &mut RequestHead, btype: BodyType) { - self.te - .update(head, self.flags.contains(Flags::HEAD), self.version); + self.inner.te.update( + head, + self.inner.flags.contains(Flags::HEAD), + self.inner.version, + ); } + /// Convert message codec to a payload codec + pub fn into_payload_codec(self) -> ClientPayloadCodec { + ClientPayloadCodec { inner: self.inner } + } +} + +impl ClientPayloadCodec { + /// Transform payload codec to a message codec + pub fn into_message_codec(self) -> ClientCodec { + ClientCodec { inner: self.inner } + } +} + +impl ClientCodecInner { fn encode_response( &mut self, msg: RequestHead, @@ -154,25 +182,26 @@ impl Decoder for ClientCodec { type Error = ParseError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if self.payload.is_some() { - Ok(match self.payload.as_mut().unwrap().decode(src)? { + if self.inner.payload.is_some() { + Ok(match self.inner.payload.as_mut().unwrap().decode(src)? { Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))), Some(PayloadItem::Eof) => Some(Message::Chunk(None)), None => None, }) - } else if let Some((req, payload)) = self.decoder.decode(src)? { - self.flags + } else if let Some((req, payload)) = self.inner.decoder.decode(src)? { + self.inner + .flags .set(Flags::HEAD, req.inner.method == Method::HEAD); - self.version = req.inner.version; - if self.flags.contains(Flags::KEEPALIVE_ENABLED) { - self.flags.set(Flags::KEEPALIVE, req.keep_alive()); + self.inner.version = req.inner.version; + if self.inner.flags.contains(Flags::KEEPALIVE_ENABLED) { + self.inner.flags.set(Flags::KEEPALIVE, req.keep_alive()); } match payload { - PayloadType::None => self.payload = None, - PayloadType::Payload(pl) => self.payload = Some(pl), + PayloadType::None => self.inner.payload = None, + PayloadType::Payload(pl) => self.inner.payload = Some(pl), PayloadType::Stream(pl) => { - self.payload = Some(pl); - self.flags.insert(Flags::STREAM); + self.inner.payload = Some(pl); + self.inner.flags.insert(Flags::STREAM); } }; Ok(Some(Message::Item(req))) @@ -182,6 +211,27 @@ impl Decoder for ClientCodec { } } +impl Decoder for ClientPayloadCodec { + type Item = Option; + type Error = PayloadError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + assert!( + self.inner.payload.is_some(), + "Payload decoder is not specified" + ); + + Ok(match self.inner.payload.as_mut().unwrap().decode(src)? { + Some(PayloadItem::Chunk(chunk)) => Some(Some(chunk)), + Some(PayloadItem::Eof) => { + self.inner.payload.take(); + Some(None) + } + None => None, + }) + } +} + impl Encoder for ClientCodec { type Item = Message<(RequestHead, BodyType)>; type Error = io::Error; @@ -193,13 +243,13 @@ impl Encoder for ClientCodec { ) -> Result<(), Self::Error> { match item { Message::Item((msg, btype)) => { - self.encode_response(msg, btype, dst)?; + self.inner.encode_response(msg, btype, dst)?; } Message::Chunk(Some(bytes)) => { - self.te.encode(bytes.as_ref(), dst)?; + self.inner.te.encode(bytes.as_ref(), dst)?; } Message::Chunk(None) => { - self.te.encode_eof(dst)?; + self.inner.te.encode_eof(dst)?; } } Ok(()) diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index b23af9648..71d6435c1 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -143,7 +143,7 @@ where fn client_disconnected(&mut self) { self.flags.insert(Flags::DISCONNECTED); if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); + payload.set_error(PayloadError::Incomplete(None)); } } @@ -228,7 +228,7 @@ where } Err(err) => { if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); + payload.set_error(PayloadError::Incomplete(None)); } return Err(DispatchError::Io(err)); } @@ -236,7 +236,10 @@ where } // Send payload State::SendPayload(ref mut stream, ref mut bin) => { + println!("SEND payload"); if let Some(item) = bin.take() { + let mut framed = self.framed.as_mut().unwrap(); + if framed.is_ match self.framed.as_mut().unwrap().start_send(item) { Ok(AsyncSink::Ready) => { self.flags.remove(Flags::FLUSHED); diff --git a/src/h1/mod.rs b/src/h1/mod.rs index e7e0759b9..21261e99c 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -9,7 +9,7 @@ mod dispatcher; mod encoder; mod service; -pub use self::client::ClientCodec; +pub use self::client::{ClientCodec, ClientPayloadCodec}; pub use self::codec::Codec; pub use self::decoder::{PayloadDecoder, RequestDecoder}; pub use self::dispatcher::Dispatcher; diff --git a/src/payload.rs b/src/payload.rs index 54539c408..b05924969 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -527,7 +527,7 @@ mod tests { #[test] fn test_error() { - let err = PayloadError::Incomplete; + let err = PayloadError::Incomplete(None); assert_eq!( format!("{}", err), "A payload reached EOF, but is not complete." @@ -584,7 +584,7 @@ mod tests { assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); - sender.set_error(PayloadError::Incomplete); + sender.set_error(PayloadError::Incomplete(None)); payload.readany().err().unwrap(); let res: Result<(), ()> = Ok(()); result(res) @@ -644,7 +644,7 @@ mod tests { ); assert_eq!(payload.len, 4); - sender.set_error(PayloadError::Incomplete); + sender.set_error(PayloadError::Incomplete(None)); payload.read_exact(10).err().unwrap(); let res: Result<(), ()> = Ok(()); @@ -677,7 +677,7 @@ mod tests { ); assert_eq!(payload.len, 0); - sender.set_error(PayloadError::Incomplete); + sender.set_error(PayloadError::Incomplete(None)); payload.read_until(b"b").err().unwrap(); let res: Result<(), ()> = Ok(()); diff --git a/src/request.rs b/src/request.rs index 593080fa6..fb5cb183b 100644 --- a/src/request.rs +++ b/src/request.rs @@ -242,7 +242,7 @@ impl MessagePool { } return ClientResponse { inner: msg, - payload: None, + payload: RefCell::new(None), }; } ClientResponse::with_pool(pool) diff --git a/tests/test_client.rs b/tests/test_client.rs index 6741a2c61..40920d1b8 100644 --- a/tests/test_client.rs +++ b/tests/test_client.rs @@ -9,8 +9,10 @@ use std::{thread, time}; use actix::System; use actix_net::server::Server; use actix_net::service::NewServiceExt; +use bytes::Bytes; use futures::future::{self, lazy, ok}; +use actix_http::HttpMessage; use actix_http::{client, h1, test, Request, Response}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ @@ -73,8 +75,8 @@ fn test_h1_v2() { assert!(response.status().is_success()); // read response - // let bytes = srv.execute(response.body()).unwrap(); - // assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); let request = client::ClientRequest::post(format!("http://{}/", addr)) .finish() @@ -83,8 +85,8 @@ fn test_h1_v2() { assert!(response.status().is_success()); // read response - // let bytes = srv.execute(response.body()).unwrap(); - // assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } #[test]