From 03912d208928db99d49ac6a0604c9fc284e36fbc Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 19 Feb 2018 22:48:27 -0800 Subject: [PATCH] support client request's async body --- src/client/pipeline.rs | 173 +++++++++++++++++++++++++++++++++++++++-- src/client/writer.rs | 23 ++++-- src/error.rs | 6 +- src/pipeline.rs | 3 - src/server/encoding.rs | 2 +- src/ws/client.rs | 12 +-- tests/test_server.rs | 24 ++++++ tests/test_ws.rs | 2 +- 8 files changed, 217 insertions(+), 28 deletions(-) diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index 273a679ac..e7d10efaa 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -1,10 +1,15 @@ use std::{io, mem}; use bytes::{Bytes, BytesMut}; use futures::{Async, Future, Poll}; +use futures::unsync::oneshot; use actix::prelude::*; +use error::Error; +use body::{Body, BodyStream}; +use context::{Frame, ActorHttpContext}; use error::PayloadError; +use server::WriterState; use server::shared::SharedBytes; use super::{ClientRequest, ClientResponse}; use super::{Connect, Connection, ClientConnector, ClientConnectorError}; @@ -39,7 +44,7 @@ enum State { } /// `SendRequest` is a `Future` which represents asynchronous request sending process. -#[must_use = "SendRequest do nothing unless polled"] +#[must_use = "SendRequest does nothing unless polled"] pub struct SendRequest { req: ClientRequest, state: State, @@ -79,13 +84,25 @@ impl Future for SendRequest { }, Ok(Async::Ready(result)) => match result { Ok(stream) => { + let mut writer = HttpClientWriter::new(SharedBytes::default()); + writer.start(&mut self.req)?; + + let body = match self.req.replace_body(Body::Empty) { + Body::Streaming(stream) => IoBody::Payload(stream), + Body::Actor(ctx) => IoBody::Actor(ctx), + _ => IoBody::Done, + }; + let mut pl = Box::new(Pipeline { + body: body, conn: stream, - writer: HttpClientWriter::new(SharedBytes::default()), + writer: writer, parser: HttpResponseParser::default(), parser_buf: BytesMut::new(), + disconnected: false, + running: RunningState::Running, + drain: None, }); - pl.writer.start(&mut self.req)?; self.state = State::Send(pl); }, Err(err) => return Err(SendRequestError::Connector(err)), @@ -94,7 +111,10 @@ impl Future for SendRequest { return Err(SendRequestError::Connector(ClientConnectorError::Disconnected)) }, State::Send(mut pl) => { - pl.poll_write()?; + pl.poll_write() + .map_err(|e| io::Error::new( + io::ErrorKind::Other, format!("{}", e).as_str()))?; + match pl.parse() { Ok(Async::Ready(mut resp)) => { resp.set_pipeline(pl); @@ -115,10 +135,42 @@ impl Future for SendRequest { pub(crate) struct Pipeline { + body: IoBody, conn: Connection, writer: HttpClientWriter, parser: HttpResponseParser, parser_buf: BytesMut, + disconnected: bool, + running: RunningState, + drain: Option>, +} + +enum IoBody { + Payload(BodyStream), + Actor(Box), + Done, +} + +#[derive(PartialEq)] +enum RunningState { + Running, + Paused, + Done, +} + +impl RunningState { + #[inline] + fn pause(&mut self) { + if *self != RunningState::Done { + *self = RunningState::Paused + } + } + #[inline] + fn resume(&mut self) { + if *self != RunningState::Done { + *self = RunningState::Running + } + } } impl Pipeline { @@ -130,12 +182,117 @@ impl Pipeline { #[inline] pub fn poll(&mut self) -> Poll, PayloadError> { - self.poll_write()?; - self.parser.parse_payload(&mut self.conn, &mut self.parser_buf) + self.poll_write() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e).as_str()))?; + Ok(self.parser.parse_payload(&mut self.conn, &mut self.parser_buf)?) } #[inline] - pub fn poll_write(&mut self) -> Poll<(), io::Error> { - self.writer.poll_completed(&mut self.conn, false) + pub fn poll_write(&mut self) -> Poll<(), Error> { + if self.running == RunningState::Done { + return Ok(Async::Ready(())) + } + + let mut done = false; + + if self.drain.is_none() && self.running != RunningState::Paused { + 'outter: loop { + let result = match mem::replace(&mut self.body, IoBody::Done) { + IoBody::Payload(mut body) => { + match body.poll()? { + Async::Ready(None) => { + self.writer.write_eof()?; + self.disconnected = true; + break + }, + Async::Ready(Some(chunk)) => { + self.body = IoBody::Payload(body); + self.writer.write(chunk.into())? + } + Async::NotReady => { + done = true; + self.body = IoBody::Payload(body); + break + }, + } + }, + IoBody::Actor(mut ctx) => { + if self.disconnected { + ctx.disconnected(); + } + match ctx.poll()? { + Async::Ready(Some(vec)) => { + if vec.is_empty() { + self.body = IoBody::Actor(ctx); + break + } + let mut res = None; + for frame in vec { + match frame { + Frame::Chunk(None) => { + // info.context = Some(ctx); + self.writer.write_eof()?; + break 'outter + }, + Frame::Chunk(Some(chunk)) => + res = Some(self.writer.write(chunk)?), + Frame::Drain(fut) => self.drain = Some(fut), + } + } + self.body = IoBody::Actor(ctx); + if self.drain.is_some() { + self.running.resume(); + break + } + res.unwrap() + }, + Async::Ready(None) => { + done = true; + break + } + Async::NotReady => { + done = true; + self.body = IoBody::Actor(ctx); + break + } + } + }, + IoBody::Done => { + done = true; + break + } + }; + + match result { + WriterState::Pause => { + self.running.pause(); + break + } + WriterState::Done => { + self.running.resume() + }, + } + } + } + + // flush io but only if we need to + match self.writer.poll_completed(&mut self.conn, false) { + Ok(Async::Ready(_)) => { + self.running.resume(); + + // resolve drain futures + if let Some(tx) = self.drain.take() { + let _ = tx.send(()); + } + // restart io processing + if !done { + self.poll_write() + } else { + Ok(Async::NotReady) + } + }, + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(err) => Err(err.into()), + } } } diff --git a/src/client/writer.rs b/src/client/writer.rs index 432baea92..ed63e166a 100644 --- a/src/client/writer.rs +++ b/src/client/writer.rs @@ -111,6 +111,10 @@ impl HttpClientWriter { buffer.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE); } + if msg.upgrade() { + self.flags.insert(Flags::UPGRADE); + } + // status line let _ = write!(buffer, "{} {} {:?}\r\n", msg.method(), msg.uri().path(), msg.version()); @@ -145,10 +149,14 @@ impl HttpClientWriter { Ok(()) } - pub fn write(&mut self, payload: &Binary) -> io::Result { + pub fn write(&mut self, payload: Binary) -> io::Result { self.written += payload.len() as u64; if !self.flags.contains(Flags::DISCONNECTED) { - self.buffer.extend_from_slice(payload.as_ref()) + if self.flags.contains(Flags::UPGRADE) { + self.buffer.extend(payload); + } else { + self.encoder.write(payload)?; + } } if self.buffer.len() > self.high { @@ -158,11 +166,14 @@ impl HttpClientWriter { } } - pub fn write_eof(&mut self) -> io::Result { - if self.buffer.len() > self.high { - Ok(WriterState::Pause) + pub fn write_eof(&mut self) -> io::Result<()> { + self.encoder.write_eof()?; + + if !self.encoder.is_eof() { + Err(io::Error::new(io::ErrorKind::Other, + "Last payload item, but eof is not reached")) } else { - Ok(WriterState::Done) + Ok(()) } } diff --git a/src/error.rs b/src/error.rs index da6745fd5..513c0f4d0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -227,9 +227,9 @@ pub enum PayloadError { /// A payload length is unknown. #[fail(display="A payload length is unknown.")] UnknownLength, - /// Parse error + /// Io error #[fail(display="{}", _0)] - ParseError(#[cause] IoError), + Io(#[cause] IoError), /// Http2 error #[fail(display="{}", _0)] Http2(#[cause] Http2Error), @@ -237,7 +237,7 @@ pub enum PayloadError { impl From for PayloadError { fn from(err: IoError) -> PayloadError { - PayloadError::ParseError(err) + PayloadError::Io(err) } } diff --git a/src/pipeline.rs b/src/pipeline.rs index 83714e09a..2408bf93b 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -494,7 +494,6 @@ impl ProcessResponse { IOState::Payload(mut body) => { match body.poll() { Ok(Async::Ready(None)) => { - self.iostate = IOState::Done; if let Err(err) = io.write_eof() { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init(info, self.resp)) @@ -536,7 +535,6 @@ impl ProcessResponse { match frame { Frame::Chunk(None) => { info.context = Some(ctx); - self.iostate = IOState::Done; if let Err(err) = io.write_eof() { info.error = Some(err.into()); return Ok( @@ -566,7 +564,6 @@ impl ProcessResponse { res.unwrap() }, Ok(Async::Ready(None)) => { - self.iostate = IOState::Done; break } Ok(Async::NotReady) => { diff --git a/src/server/encoding.rs b/src/server/encoding.rs index 2a503bfbc..964754ab0 100644 --- a/src/server/encoding.rs +++ b/src/server/encoding.rs @@ -262,7 +262,7 @@ impl PayloadWriter for EncodedPayload { self.error = true; self.decoder = Decoder::Identity; if let Some(err) = err { - self.set_error(PayloadError::ParseError(err)); + self.set_error(PayloadError::Io(err)); } else { self.set_error(PayloadError::Incomplete); } diff --git a/src/ws/client.rs b/src/ws/client.rs index 98e5b35b9..1d34b864b 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -445,7 +445,7 @@ impl WsClientWriter { /// Write payload #[inline] - fn write(&mut self, data: &Binary) { + fn write(&mut self, data: Binary) { if !self.as_mut().closed { let _ = self.as_mut().writer.write(data); } else { @@ -456,30 +456,30 @@ impl WsClientWriter { /// Send text frame #[inline] pub fn text>(&mut self, text: T) { - self.write(&Frame::message(text.into(), OpCode::Text, true, true)); + self.write(Frame::message(text.into(), OpCode::Text, true, true)); } /// Send binary frame #[inline] pub fn binary>(&mut self, data: B) { - self.write(&Frame::message(data, OpCode::Binary, true, true)); + self.write(Frame::message(data, OpCode::Binary, true, true)); } /// Send ping frame #[inline] pub fn ping(&mut self, message: &str) { - self.write(&Frame::message(Vec::from(message), OpCode::Ping, true, true)); + self.write(Frame::message(Vec::from(message), OpCode::Ping, true, true)); } /// Send pong frame #[inline] pub fn pong(&mut self, message: &str) { - self.write(&Frame::message(Vec::from(message), OpCode::Pong, true, true)); + self.write(Frame::message(Vec::from(message), OpCode::Pong, true, true)); } /// Send close frame #[inline] pub fn close(&mut self, code: CloseCode, reason: &str) { - self.write(&Frame::close(code, reason, true)); + self.write(Frame::close(code, reason, true)); } } diff --git a/tests/test_server.rs b/tests/test_server.rs index cbe289203..2cbeba8fc 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -545,6 +545,30 @@ fn test_h2() { // assert_eq!(_res.unwrap(), Bytes::from_static(STR.as_ref())); } +#[test] +fn test_client_streaming_explicit() { + let mut srv = test::TestServer::new( + |app| app.handler( + |req: HttpRequest| req.body() + .map_err(Error::from) + .and_then(|body| { + Ok(httpcodes::HTTPOk.build() + .chunked() + .content_encoding(headers::ContentEncoding::Identity) + .body(body)?)}) + .responder())); + + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + + let request = srv.get().body(Body::Streaming(Box::new(body))).unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + #[test] fn test_application() { let mut srv = test::TestServer::with_factory( diff --git a/tests/test_ws.rs b/tests/test_ws.rs index cb5a7426c..ac7119914 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -49,5 +49,5 @@ fn test_simple() { writer.close(ws::CloseCode::Normal, ""); let (item, _) = srv.execute(reader.into_future()).unwrap(); - assert!(item.is_none()) + assert!(item.is_none()); }