From cb70d5ec3d0fd74e48384791ad55ebb42e64b67a Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 19 Feb 2018 03:11:11 -0800 Subject: [PATCH] refactor http client --- src/client/mod.rs | 8 +- src/client/parser.rs | 120 +++++--------------- src/client/pipeline.rs | 126 +++++++++++++++++++++ src/client/request.rs | 47 +++++++- src/client/response.rs | 249 +++++++++++++++++++---------------------- src/client/writer.rs | 157 +++++++++++++++++++++++++- src/multipart.rs | 25 ++--- src/server/encoding.rs | 4 +- src/ws/client.rs | 3 +- 9 files changed, 483 insertions(+), 256 deletions(-) create mode 100644 src/client/pipeline.rs diff --git a/src/client/mod.rs b/src/client/mod.rs index 66244e415..4240aa3c4 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -2,10 +2,12 @@ mod connector; mod parser; mod request; mod response; +mod pipeline; mod writer; -pub(crate) use self::writer::HttpClientWriter; +pub use self::pipeline::{SendRequest, SendRequestError}; pub use self::request::{ClientRequest, ClientRequestBuilder}; -pub use self::response::{ClientResponse, JsonResponse}; -pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError}; +pub use self::response::{ClientResponse, JsonResponse, UrlEncoded}; pub use self::connector::{Connect, Connection, ClientConnector, ClientConnectorError}; +pub(crate) use self::writer::HttpClientWriter; +pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError}; diff --git a/src/client/parser.rs b/src/client/parser.rs index 32995ecaa..67f65374c 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -2,15 +2,13 @@ use std::mem; use httparse; use http::{Version, HttpTryFrom, HeaderMap, StatusCode}; use http::header::{self, HeaderName, HeaderValue}; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use futures::{Poll, Async}; use error::{ParseError, PayloadError}; -use payload::{Payload, PayloadWriter, DEFAULT_BUFFER_SIZE}; use server::{utils, IoStream}; use server::h1::{Decoder, chunked}; -use server::encoding::PayloadType; use super::ClientResponse; use super::response::ClientMessage; @@ -20,18 +18,7 @@ const MAX_HEADERS: usize = 96; #[derive(Default)] pub struct HttpResponseParser { - payload: Option, -} - -enum Decoding { - Paused, - Ready, - NotReady, -} - -struct PayloadInfo { - tx: PayloadType, - decoder: Decoder, + decoder: Option, } #[derive(Debug)] @@ -47,31 +34,6 @@ impl HttpResponseParser { -> Poll where T: IoStream { - // read payload - if self.payload.is_some() { - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - if let Some(ref mut payload) = self.payload { - payload.tx.set_error(PayloadError::Incomplete); - } - // http channel should not deal with payload errors - return Err(HttpResponseParserError::Payload) - }, - Err(err) => { - if let Some(ref mut payload) = self.payload { - payload.tx.set_error(err.into()); - } - // http channel should not deal with payload errors - return Err(HttpResponseParserError::Payload) - } - _ => (), - } - match self.decode(buf)? { - Decoding::Ready => self.payload = None, - Decoding::Paused | Decoding::NotReady => return Ok(Async::NotReady), - } - } - // if buf is empty parse_message will always return NotReady, let's avoid that let read = if buf.is_empty() { match utils::read_from_io(io, buf) { @@ -93,32 +55,19 @@ impl HttpResponseParser { loop { match HttpResponseParser::parse_message(buf).map_err(HttpResponseParserError::Error)? { Async::Ready((msg, decoder)) => { - // process payload - if let Some(payload) = decoder { - self.payload = Some(payload); - match self.decode(buf)? { - Decoding::Paused | Decoding::NotReady => (), - Decoding::Ready => self.payload = None, - } - } + self.decoder = decoder; return Ok(Async::Ready(msg)); }, Async::NotReady => { if buf.capacity() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(HttpResponseParserError::Error(ParseError::TooLarge)); } if read { match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - debug!("Ignored premature client disconnection"); - return Err(HttpResponseParserError::Disconnect); - }, + Ok(Async::Ready(0)) => return Err(HttpResponseParserError::Disconnect), Ok(Async::Ready(_)) => (), - Ok(Async::NotReady) => - return Ok(Async::NotReady), - Err(err) => - return Err(HttpResponseParserError::Error(err.into())) + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => return Err(HttpResponseParserError::Error(err.into())), } } else { return Ok(Async::NotReady) @@ -128,35 +77,24 @@ impl HttpResponseParser { } } - fn decode(&mut self, buf: &mut BytesMut) -> Result { - if let Some(ref mut payload) = self.payload { - if payload.tx.capacity() > DEFAULT_BUFFER_SIZE { - return Ok(Decoding::Paused) - } - loop { - match payload.decoder.decode(buf) { - Ok(Async::Ready(Some(bytes))) => { - payload.tx.feed_data(bytes) - }, - Ok(Async::Ready(None)) => { - payload.tx.feed_eof(); - return Ok(Decoding::Ready) - }, - Ok(Async::NotReady) => return Ok(Decoding::NotReady), - Err(err) => { - payload.tx.set_error(err.into()); - return Err(HttpResponseParserError::Payload) - } - } + pub fn parse_payload(&mut self, io: &mut T, buf: &mut BytesMut) + -> Poll, PayloadError> + where T: IoStream + { + if let Some(ref mut decoder) = self.decoder { + // read payload + match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => return Err(PayloadError::Incomplete), + Err(err) => return Err(err.into()), + _ => (), } + decoder.decode(buf).map_err(|e| e.into()) } else { - return Ok(Decoding::Ready) + Ok(Async::Ready(None)) } } - fn parse_message(buf: &mut BytesMut) - -> Poll<(ClientResponse, Option), ParseError> - { + fn parse_message(buf: &mut BytesMut) -> Poll<(ClientResponse, Option), ParseError> { // Parse http message let bytes_ptr = buf.as_ref().as_ptr() as usize; let mut headers: [httparse::Header; MAX_HEADERS] = @@ -181,7 +119,6 @@ impl HttpResponseParser { } }; - let slice = buf.split_to(len).freeze(); // convert headers @@ -221,22 +158,19 @@ impl HttpResponseParser { }; if let Some(decoder) = decoder { - let (psender, payload) = Payload::new(false); - let info = PayloadInfo { - tx: PayloadType::new(&hdrs, psender), - decoder: decoder, - }; + //let info = PayloadInfo { + //tx: PayloadType::new(&hdrs, psender), + // decoder: decoder, + //}; Ok(Async::Ready( (ClientResponse::new( - ClientMessage{ - status: status, version: version, - headers: hdrs, cookies: None, payload: Some(payload)}), Some(info)))) + ClientMessage{status: status, version: version, + headers: hdrs, cookies: None}), Some(decoder)))) } else { Ok(Async::Ready( (ClientResponse::new( - ClientMessage{ - status: status, version: version, - headers: hdrs, cookies: None, payload: None}), None))) + ClientMessage{status: status, version: version, + headers: hdrs, cookies: None}), None))) } } } diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs new file mode 100644 index 000000000..8dfdb855d --- /dev/null +++ b/src/client/pipeline.rs @@ -0,0 +1,126 @@ +use std::{io, mem}; +use bytes::{Bytes, BytesMut}; +use futures::{Async, Future, Poll}; + +use actix::prelude::*; + +use error::PayloadError; +use server::shared::SharedBytes; +use super::{ClientRequest, ClientResponse}; +use super::{Connect, Connection, ClientConnector, ClientConnectorError}; +use super::HttpClientWriter; +use super::{HttpResponseParser, HttpResponseParserError}; + +pub enum SendRequestError { + Connector(ClientConnectorError), + ParseError(HttpResponseParserError), + Io(io::Error), +} + +impl From for SendRequestError { + fn from(err: io::Error) -> SendRequestError { + SendRequestError::Io(err) + } +} + +enum State { + New, + Connect(actix::dev::Request), + Send(Box), + None, +} + +/// `SendRequest` is a `Future` which represents asynchronous request sending process. +#[must_use = "SendRequest do nothing unless polled"] +pub struct SendRequest { + req: ClientRequest, + state: State, +} + +impl SendRequest { + pub(crate) fn new(req: ClientRequest) -> SendRequest { + SendRequest{ + req: req, + state: State::New, + } + } +} + +impl Future for SendRequest { + type Item = ClientResponse; + type Error = SendRequestError; + + fn poll(&mut self) -> Poll { + loop { + let state = mem::replace(&mut self.state, State::None); + + match state { + State::New => + self.state = State::Connect( + ClientConnector::from_registry().send(Connect(self.req.uri().clone()))), + State::Connect(mut conn) => match conn.poll() { + Ok(Async::NotReady) => { + self.state = State::Connect(conn); + return Ok(Async::NotReady); + }, + Ok(Async::Ready(result)) => match result { + Ok(stream) => { + let mut pl = Box::new(Pipeline { + conn: stream, + writer: HttpClientWriter::new(SharedBytes::default()), + parser: HttpResponseParser::default(), + parser_buf: BytesMut::new(), + }); + pl.writer.start(&mut self.req)?; + self.state = State::Send(pl); + }, + Err(err) => return Err(SendRequestError::Connector(err)), + }, + Err(_) => + return Err(SendRequestError::Connector(ClientConnectorError::Disconnected)) + }, + State::Send(mut pl) => { + pl.poll_write()?; + match pl.parse() { + Ok(Async::Ready(mut resp)) => { + resp.set_pipeline(pl); + return Ok(Async::Ready(resp)) + }, + Ok(Async::NotReady) => { + self.state = State::Send(pl); + return Ok(Async::NotReady) + }, + Err(err) => return Err(SendRequestError::ParseError(err)) + } + } + State::None => unreachable!(), + } + } + } +} + + +pub(crate) struct Pipeline { + conn: Connection, + writer: HttpClientWriter, + parser: HttpResponseParser, + parser_buf: BytesMut, +} + +impl Pipeline { + + #[inline] + pub fn parse(&mut self) -> Poll { + self.parser.parse(&mut self.conn, &mut self.parser_buf) + } + + #[inline] + pub fn poll(&mut self) -> Poll, PayloadError> { + self.parser.parse_payload(&mut self.conn, &mut self.parser_buf) + } + + #[inline] + pub fn poll_write(&mut self) -> Poll<(), io::Error> { + self.writer.poll_completed(&mut self.conn, false) + } +} diff --git a/src/client/request.rs b/src/client/request.rs index 980f18796..2983094fe 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -11,6 +11,7 @@ use serde::Serialize; use body::Body; use error::Error; use headers::ContentEncoding; +use super::pipeline::SendRequest; /// An HTTP Client Request pub struct ClientRequest { @@ -19,7 +20,8 @@ pub struct ClientRequest { version: Version, headers: HeaderMap, body: Body, - chunked: Option, + chunked: bool, + upgrade: bool, encoding: ContentEncoding, } @@ -32,7 +34,8 @@ impl Default for ClientRequest { version: Version::HTTP_11, headers: HeaderMap::with_capacity(16), body: Body::Empty, - chunked: None, + chunked: false, + upgrade: false, encoding: ContentEncoding::Auto, } } @@ -135,6 +138,24 @@ impl ClientRequest { &mut self.headers } + /// is chunked encoding enabled + #[inline] + pub fn chunked(&self) -> bool { + self.chunked + } + + /// is upgrade request + #[inline] + pub fn upgrade(&self) -> bool { + self.upgrade + } + + /// Content encoding + #[inline] + pub fn content_encoding(&self) -> ContentEncoding { + self.encoding + } + /// Get body os this response #[inline] pub fn body(&self) -> &Body { @@ -146,9 +167,14 @@ impl ClientRequest { self.body = body.into(); } - /// Set a body and return previous body value - pub fn replace_body>(&mut self, body: B) -> Body { - mem::replace(&mut self.body, body.into()) + /// Extract body, replace it with Empty + pub(crate) fn replace_body(&mut self, body: Body) -> Body { + mem::replace(&mut self.body, body) + } + + /// Send request + pub fn send(self) -> SendRequest { + SendRequest::new(self) } } @@ -288,7 +314,16 @@ impl ClientRequestBuilder { #[inline] pub fn chunked(&mut self) -> &mut Self { if let Some(parts) = parts(&mut self.request, &self.err) { - parts.chunked = Some(true); + parts.chunked = true; + } + self + } + + /// Enable connection upgrade + #[inline] + pub fn upgrade(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.upgrade = true; } self } diff --git a/src/client/response.rs b/src/client/response.rs index 9ae45553c..06ccf0837 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -1,21 +1,22 @@ use std::{fmt, str}; use std::rc::Rc; use std::cell::UnsafeCell; +use std::collections::HashMap; use bytes::{Bytes, BytesMut}; use cookie::Cookie; use futures::{Async, Future, Poll, Stream}; -use http_range::HttpRange; use http::{HeaderMap, StatusCode, Version}; use http::header::{self, HeaderValue}; use mime::Mime; use serde_json; use serde::de::DeserializeOwned; +use url::form_urlencoded; -use payload::{Payload, ReadAny}; -use multipart::Multipart; -use httprequest::UrlEncoded; -use error::{CookieParseError, ParseError, PayloadError, JsonPayloadError, HttpRangeError}; +// use multipart::Multipart; +use error::{CookieParseError, ParseError, PayloadError, JsonPayloadError, UrlencodedError}; + +use super::pipeline::Pipeline; pub(crate) struct ClientMessage { @@ -23,7 +24,6 @@ pub(crate) struct ClientMessage { pub version: Version, pub headers: HeaderMap, pub cookies: Option>>, - pub payload: Option, } impl Default for ClientMessage { @@ -34,18 +34,21 @@ impl Default for ClientMessage { version: Version::HTTP_11, headers: HeaderMap::with_capacity(16), cookies: None, - payload: None, } } } /// An HTTP Client response -pub struct ClientResponse(Rc>); +pub struct ClientResponse(Rc>, Option>); impl ClientResponse { pub(crate) fn new(msg: ClientMessage) -> ClientResponse { - ClientResponse(Rc::new(UnsafeCell::new(msg))) + ClientResponse(Rc::new(UnsafeCell::new(msg)), None) + } + + pub(crate) fn set_pipeline(&mut self, pl: Box) { + self.1 = Some(pl); } #[inline] @@ -155,53 +158,12 @@ impl ClientResponse { } } - /// Parses Range HTTP header string as per RFC 2616. - /// `size` is full size of response (file). - pub fn range(&self, size: u64) -> Result, HttpRangeError> { - if let Some(range) = self.headers().get(header::RANGE) { - HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size) - .map_err(|e| e.into()) - } else { - Ok(Vec::new()) - } - } - - /// Returns reference to the associated http payload. - #[inline] - pub fn payload(&self) -> &Payload { - let msg = self.as_mut(); - if msg.payload.is_none() { - msg.payload = Some(Payload::empty()); - } - msg.payload.as_ref().unwrap() - } - - /// Returns mutable reference to the associated http payload. - #[inline] - pub fn payload_mut(&mut self) -> &mut Payload { - let msg = self.as_mut(); - if msg.payload.is_none() { - msg.payload = Some(Payload::empty()); - } - msg.payload.as_mut().unwrap() - } - - /// Load request body. - /// - /// By default only 256Kb payload reads to a memory, then `ResponseBody` - /// resolves to an error. Use `RequestBody::limit()` - /// method to change upper limit. - pub fn body(&self) -> ResponseBody { - ResponseBody::from_response(self) - } - - - /// Return stream to http payload processes as multipart. - /// - /// Content-type: multipart/form-data; - pub fn multipart(&mut self) -> Multipart { - Multipart::from_response(self) - } + // /// Return stream to http payload processes as multipart. + // /// + // /// Content-type: multipart/form-data; + // pub fn multipart(mut self) -> Multipart { + // Multipart::from_response(&mut self) + // } /// Parse `application/x-www-form-urlencoded` encoded body. /// Return `UrlEncoded` future. It resolves to a `HashMap` which @@ -212,10 +174,8 @@ impl ClientResponse { /// * content type is not `application/x-www-form-urlencoded` /// * transfer encoding is `chunked`. /// * content-length is greater than 256k - pub fn urlencoded(&self) -> UrlEncoded { - UrlEncoded::from(self.payload().clone(), - self.headers(), - self.chunked().unwrap_or(false)) + pub fn urlencoded(self) -> UrlEncoded { + UrlEncoded::new(self) } /// Parse `application/json` encoded body. @@ -225,7 +185,7 @@ impl ClientResponse { /// /// * content type is not `application/json` /// * content length is greater than 256k - pub fn json(&self) -> JsonResponse { + pub fn json(self) -> JsonResponse { JsonResponse::from_response(self) } } @@ -247,77 +207,16 @@ impl fmt::Debug for ClientResponse { } } -impl Clone for ClientResponse { - fn clone(&self) -> ClientResponse { - ClientResponse(self.0.clone()) - } -} - /// Future that resolves to a complete request body. -pub struct ResponseBody { - pl: ReadAny, - body: BytesMut, - limit: usize, - resp: Option, -} - -impl ResponseBody { - - /// Create `RequestBody` for request. - pub fn from_response(resp: &ClientResponse) -> ResponseBody { - let pl = resp.payload().readany(); - ResponseBody { - pl: pl, - body: BytesMut::new(), - limit: 262_144, - resp: Some(resp.clone()), - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } -} - -impl Future for ResponseBody { +impl Stream for ClientResponse { type Item = Bytes; type Error = PayloadError; - fn poll(&mut self) -> Poll { - if let Some(resp) = self.resp.take() { - if let Some(len) = resp.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > self.limit { - return Err(PayloadError::Overflow); - } - } else { - return Err(PayloadError::UnknownLength); - } - } else { - return Err(PayloadError::UnknownLength); - } - } - } - - loop { - return match self.pl.poll() { - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(None)) => { - Ok(Async::Ready(self.body.take().freeze())) - }, - Ok(Async::Ready(Some(chunk))) => { - if (self.body.len() + chunk.len()) > self.limit { - Err(PayloadError::Overflow) - } else { - self.body.extend_from_slice(&chunk); - continue - } - }, - Err(err) => Err(err), - } + fn poll(&mut self) -> Poll, Self::Error> { + if let Some(ref mut pl) = self.1 { + pl.poll() + } else { + Ok(Async::Ready(None)) } } } @@ -328,6 +227,7 @@ impl Future for ResponseBody { /// /// * content type is not `application/json` /// * content length is greater than 256k +#[must_use = "JsonResponse does nothing unless polled"] pub struct JsonResponse{ limit: usize, ct: &'static str, @@ -338,12 +238,12 @@ pub struct JsonResponse{ impl JsonResponse { /// Create `JsonBody` for request. - pub fn from_response(resp: &ClientResponse) -> Self { + pub fn from_response(resp: ClientResponse) -> Self { JsonResponse{ limit: 262_144, - resp: Some(resp.clone()), - fut: None, + resp: Some(resp), ct: "application/json", + fut: None, } } @@ -386,8 +286,7 @@ impl Future for JsonResponse { } let limit = self.limit; - let fut = resp.payload().readany() - .from_err() + let fut = resp.from_err() .fold(BytesMut::new(), move |mut body, chunk| { if (body.len() + chunk.len()) > limit { Err(JsonPayloadError::Overflow) @@ -397,9 +296,93 @@ impl Future for JsonResponse { } }) .and_then(|body| Ok(serde_json::from_slice::(&body)?)); + self.fut = Some(Box::new(fut)); } self.fut.as_mut().expect("JsonResponse could not be used second time").poll() } } + +/// Future that resolves to a parsed urlencoded values. +#[must_use = "UrlEncoded does nothing unless polled"] +pub struct UrlEncoded { + resp: Option, + limit: usize, + fut: Option, Error=UrlencodedError>>>, +} + +impl UrlEncoded { + pub fn new(resp: ClientResponse) -> UrlEncoded { + UrlEncoded{resp: Some(resp), + limit: 262_144, + fut: None} + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } +} + +impl Future for UrlEncoded { + type Item = HashMap; + type Error = UrlencodedError; + + fn poll(&mut self) -> Poll { + if let Some(resp) = self.resp.take() { + if resp.chunked().unwrap_or(false) { + return Err(UrlencodedError::Chunked) + } else if let Some(len) = resp.headers().get(header::CONTENT_LENGTH) { + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + if len > 262_144 { + return Err(UrlencodedError::Overflow); + } + } else { + return Err(UrlencodedError::UnknownLength); + } + } else { + return Err(UrlencodedError::UnknownLength); + } + } + + // check content type + let mut encoding = false; + if let Some(content_type) = resp.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + if content_type.to_lowercase() == "application/x-www-form-urlencoded" { + encoding = true; + } + } + } + if !encoding { + return Err(UrlencodedError::ContentType); + } + + // urlencoded body + let limit = self.limit; + let fut = resp.from_err() + .fold(BytesMut::new(), move |mut body, chunk| { + if (body.len() + chunk.len()) > limit { + Err(UrlencodedError::Overflow) + } else { + body.extend_from_slice(&chunk); + Ok(body) + } + }) + .and_then(|body| { + let mut m = HashMap::new(); + for (k, v) in form_urlencoded::parse(&body) { + m.insert(k.into(), v.into()); + } + Ok(m) + }); + + self.fut = Some(Box::new(fut)); + } + + self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() + } +} diff --git a/src/client/writer.rs b/src/client/writer.rs index 4370b37b6..17d43a966 100644 --- a/src/client/writer.rs +++ b/src/client/writer.rs @@ -1,13 +1,20 @@ #![allow(dead_code)] use std::io; use std::fmt::Write; -use bytes::BufMut; +use bytes::{BytesMut, BufMut}; use futures::{Async, Poll}; use tokio_io::AsyncWrite; +use http::{Version, HttpTryFrom}; +use http::header::{HeaderValue, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; +use flate2::Compression; +use flate2::write::{GzEncoder, DeflateEncoder}; +use brotli2::write::BrotliEncoder; -use body::Binary; +use body::{Body, Binary}; +use headers::ContentEncoding; use server::WriterState; use server::shared::SharedBytes; +use server::encoding::{ContentEncoder, TransferEncoding}; use client::ClientRequest; @@ -30,6 +37,7 @@ pub(crate) struct HttpClientWriter { written: u64, headers_size: u32, buffer: SharedBytes, + encoder: ContentEncoder, low: usize, high: usize, } @@ -37,11 +45,13 @@ pub(crate) struct HttpClientWriter { impl HttpClientWriter { pub fn new(buf: SharedBytes) -> HttpClientWriter { + let encoder = ContentEncoder::Identity(TransferEncoding::eof(buf.clone())); HttpClientWriter { flags: Flags::empty(), written: 0, headers_size: 0, buffer: buf, + encoder: encoder, low: LOW_WATERMARK, high: HIGH_WATERMARK, } @@ -87,18 +97,23 @@ impl HttpClientWriter { impl HttpClientWriter { - pub fn start(&mut self, msg: &mut ClientRequest) { + pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> { // prepare task self.flags.insert(Flags::STARTED); + self.encoder = content_encoder(self.buffer.clone(), msg); // render message { let buffer = self.buffer.get_mut(); - buffer.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE); + if let Body::Binary(ref bytes) = *msg.body() { + buffer.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len()); + } else { + buffer.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE); + } // status line let _ = write!(buffer, "{} {} {:?}\r\n", - msg.method(), msg.uri().path(), msg.version()); + msg.method(), msg.uri().path(), msg.version()); // write headers for (key, value) in msg.headers() { @@ -119,7 +134,15 @@ impl HttpClientWriter { buffer.extend_from_slice(b"\r\n"); //} self.headers_size = buffer.len() as u32; + + if msg.body().is_binary() { + if let Body::Binary(bytes) = msg.replace_body(Body::Empty) { + self.written += bytes.len() as u64; + self.encoder.write(bytes)?; + } + } } + Ok(()) } pub fn write(&mut self, payload: &Binary) -> io::Result { @@ -160,3 +183,127 @@ impl HttpClientWriter { } } } + + +fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder { + let version = req.version(); + let mut body = req.replace_body(Body::Empty); + let mut encoding = req.content_encoding(); + + let transfer = match body { + Body::Empty => { + req.headers_mut().remove(CONTENT_LENGTH); + TransferEncoding::length(0, buf) + }, + Body::Binary(ref mut bytes) => { + if encoding.is_compression() { + let tmp = SharedBytes::default(); + let transfer = TransferEncoding::eof(tmp.clone()); + let mut enc = match encoding { + ContentEncoding::Deflate => ContentEncoder::Deflate( + DeflateEncoder::new(transfer, Compression::default())), + ContentEncoding::Gzip => ContentEncoder::Gzip( + GzEncoder::new(transfer, Compression::default())), + ContentEncoding::Br => ContentEncoder::Br( + BrotliEncoder::new(transfer, 5)), + ContentEncoding::Identity => ContentEncoder::Identity(transfer), + ContentEncoding::Auto => unreachable!() + }; + // TODO return error! + let _ = enc.write(bytes.clone()); + let _ = enc.write_eof(); + + *bytes = Binary::from(tmp.take()); + encoding = ContentEncoding::Identity; + } + let mut b = BytesMut::new(); + let _ = write!(b, "{}", bytes.len()); + req.headers_mut().insert( + CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); + TransferEncoding::eof(buf) + }, + Body::Streaming(_) | Body::Actor(_) => { + if req.upgrade() { + if version == Version::HTTP_2 { + error!("Connection upgrade is forbidden for HTTP/2"); + } else { + req.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade")); + } + if encoding != ContentEncoding::Identity { + encoding = ContentEncoding::Identity; + req.headers_mut().remove(CONTENT_ENCODING); + } + TransferEncoding::eof(buf) + } else { + streaming_encoding(buf, version, req) + } + } + }; + + req.replace_body(body); + match encoding { + ContentEncoding::Deflate => ContentEncoder::Deflate( + DeflateEncoder::new(transfer, Compression::default())), + ContentEncoding::Gzip => ContentEncoder::Gzip( + GzEncoder::new(transfer, Compression::default())), + ContentEncoding::Br => ContentEncoder::Br( + BrotliEncoder::new(transfer, 5)), + ContentEncoding::Identity | ContentEncoding::Auto => ContentEncoder::Identity(transfer), + } +} + +fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientRequest) + -> TransferEncoding { + if req.chunked() { + // Enable transfer encoding + req.headers_mut().remove(CONTENT_LENGTH); + if version == Version::HTTP_2 { + req.headers_mut().remove(TRANSFER_ENCODING); + TransferEncoding::eof(buf) + } else { + req.headers_mut().insert( + TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + TransferEncoding::chunked(buf) + } + } else { + // if Content-Length is specified, then use it as length hint + let (len, chunked) = + if let Some(len) = req.headers().get(CONTENT_LENGTH) { + // Content-Length + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + (Some(len), false) + } else { + error!("illegal Content-Length: {:?}", len); + (None, false) + } + } else { + error!("illegal Content-Length: {:?}", len); + (None, false) + } + } else { + (None, true) + }; + + if !chunked { + if let Some(len) = len { + TransferEncoding::length(len, buf) + } else { + TransferEncoding::eof(buf) + } + } else { + // Enable transfer encoding + match version { + Version::HTTP_11 => { + req.headers_mut().insert( + TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + TransferEncoding::chunked(buf) + }, + _ => { + req.headers_mut().remove(TRANSFER_ENCODING); + TransferEncoding::eof(buf) + } + } + } + } +} diff --git a/src/multipart.rs b/src/multipart.rs index 00f2676c4..9da15ba59 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -14,7 +14,6 @@ use futures::task::{Task, current as current_task}; use error::{ParseError, PayloadError, MultipartError}; use payload::Payload; -use client::ClientResponse; use httprequest::HttpRequest; const MAX_HEADERS: usize = 32; @@ -98,18 +97,18 @@ impl Multipart { } } - /// Create multipart instance for client response. - pub fn from_response(resp: &mut ClientResponse) -> Multipart { - match Multipart::boundary(resp.headers()) { - Ok(boundary) => Multipart::new(boundary, resp.payload().clone()), - Err(err) => - Multipart { - error: Some(err), - safety: Safety::new(), - inner: None, - } - } - } + // /// Create multipart instance for client response. + // pub fn from_response(resp: &mut ClientResponse) -> Multipart { + // match Multipart::boundary(resp.headers()) { + // Ok(boundary) => Multipart::new(boundary, resp.payload().clone()), + // Err(err) => + // Multipart { + // error: Some(err), + // safety: Safety::new(), + // inner: None, + // } + // } + // } /// Extract boundary info from headers. pub fn boundary(headers: &HeaderMap) -> Result { diff --git a/src/server/encoding.rs b/src/server/encoding.rs index 125a9e523..d53f921ce 100644 --- a/src/server/encoding.rs +++ b/src/server/encoding.rs @@ -26,7 +26,7 @@ use super::shared::SharedBytes; impl ContentEncoding { #[inline] - fn is_compression(&self) -> bool { + pub fn is_compression(&self) -> bool { match *self { ContentEncoding::Identity | ContentEncoding::Auto => false, _ => true @@ -546,7 +546,7 @@ impl PayloadEncoder { } } -enum ContentEncoder { +pub(crate) enum ContentEncoder { Deflate(DeflateEncoder), Gzip(GzEncoder), Br(BrotliEncoder), diff --git a/src/ws/client.rs b/src/ws/client.rs index ebb2f3849..98e5b35b9 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -178,6 +178,7 @@ impl WsClient { self.request.set_header(header::ORIGIN, origin); } + self.request.upgrade(); self.request.set_header(header::UPGRADE, "websocket"); self.request.set_header(header::CONNECTION, "upgrade"); self.request.set_header("SEC-WEBSOCKET-VERSION", "13"); @@ -265,7 +266,7 @@ impl Future for WsHandshake { if !self.sent { self.sent = true; - inner.writer.start(&mut self.request); + inner.writer.start(&mut self.request)?; } if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) { return Err(err.into())