diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index c7cdcf0ab..80da691cd 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,5 +1,12 @@ # Changes +## + +### Added + +* Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests with `RequestHead` + + ## [0.2.10] - 2019-09-xx ### Fixed diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index 36913c5f0..d2b94b3e5 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -8,7 +8,7 @@ use h2::client::SendRequest; use crate::body::MessageBody; use crate::h1::ClientCodec; -use crate::message::{RequestHead, ResponseHead}; +use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; use super::error::SendRequestError; @@ -27,9 +27,9 @@ pub trait Connection { fn protocol(&self) -> Protocol; /// Send request and body - fn send_request( + fn send_request>( self, - head: RequestHead, + head: H, body: B, ) -> Self::Future; @@ -39,7 +39,7 @@ pub trait Connection { >; /// Send request, returns Response and Framed - fn open_tunnel(self, head: RequestHead) -> Self::TunnelFuture; + fn open_tunnel>(self, head: H) -> Self::TunnelFuture; } pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { @@ -105,22 +105,22 @@ where } } - fn send_request( + fn send_request>( mut self, - head: RequestHead, + head: H, body: B, ) -> Self::Future { match self.io.take().unwrap() { ConnectionType::H1(io) => Box::new(h1proto::send_request( io, - head, + head.into(), body, self.created, self.pool, )), ConnectionType::H2(io) => Box::new(h2proto::send_request( io, - head, + head.into(), body, self.created, self.pool, @@ -139,10 +139,10 @@ where >; /// Send request, returns Response and Framed - fn open_tunnel(mut self, head: RequestHead) -> Self::TunnelFuture { + fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { match self.io.take().unwrap() { ConnectionType::H1(io) => { - Either::A(Box::new(h1proto::open_tunnel(io, head))) + Either::A(Box::new(h1proto::open_tunnel(io, head.into()))) } ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { @@ -180,9 +180,9 @@ where } } - fn send_request( + fn send_request>( self, - head: RequestHead, + head: H, body: RB, ) -> Self::Future { match self { @@ -199,7 +199,7 @@ where >; /// Send request, returns Response and Framed - fn open_tunnel(self, head: RequestHead) -> Self::TunnelFuture { + fn open_tunnel>(self, head: H) -> Self::TunnelFuture { match self { EitherConnection::A(con) => Box::new( con.open_tunnel(head) diff --git a/actix-http/src/client/error.rs b/actix-http/src/client/error.rs index fc4b5b72b..40aef2cce 100644 --- a/actix-http/src/client/error.rs +++ b/actix-http/src/client/error.rs @@ -128,3 +128,23 @@ impl ResponseError for SendRequestError { .into() } } + +/// A set of errors that can occur during freezing a request +#[derive(Debug, Display, From)] +pub enum FreezeRequestError { + /// Invalid URL + #[display(fmt = "Invalid URL: {}", _0)] + Url(InvalidUrl), + /// Http error + #[display(fmt = "{}", _0)] + Http(HttpError), +} + +impl From for SendRequestError { + fn from(e: FreezeRequestError) -> Self { + match e { + FreezeRequestError::Url(e) => e.into(), + FreezeRequestError::Http(e) => e.into(), + } + } +} \ No newline at end of file diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index 97ed3bbc7..fa920ab92 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -9,8 +9,9 @@ use futures::{Async, Future, Poll, Sink, Stream}; use crate::error::PayloadError; use crate::h1; use crate::http::header::{IntoHeaderValue, HOST}; -use crate::message::{RequestHead, ResponseHead}; +use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::{Payload, PayloadStream}; +use crate::header::HeaderMap; use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; use super::error::{ConnectError, SendRequestError}; @@ -19,7 +20,7 @@ use crate::body::{BodySize, MessageBody}; pub(crate) fn send_request( io: T, - mut head: RequestHead, + mut head: RequestHeadType, body: B, created: time::Instant, pool: Option>, @@ -29,21 +30,29 @@ where B: MessageBody, { // set request host header - if !head.headers.contains_key(HOST) { - if let Some(host) = head.uri.host() { + if !head.as_ref().headers.contains_key(HOST) && !head.extra_headers().iter().any(|h| h.contains_key(HOST)) { + if let Some(host) = head.as_ref().uri.host() { let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); - let _ = match head.uri.port_u16() { + let _ = match head.as_ref().uri.port_u16() { None | Some(80) | Some(443) => write!(wrt, "{}", host), Some(port) => write!(wrt, "{}:{}", host, port), }; match wrt.get_mut().take().freeze().try_into() { Ok(value) => { - head.headers.insert(HOST, value); + match head { + RequestHeadType::Owned(ref mut head) => { + head.headers.insert(HOST, value) + }, + RequestHeadType::Rc(_, ref mut extra_headers) => { + let headers = extra_headers.get_or_insert(HeaderMap::new()); + headers.insert(HOST, value) + }, + } } Err(e) => { - log::error!("Can not set HOST header {}", e); + log::error!("Can not set HOST header {}", e) } } } @@ -57,7 +66,7 @@ where let len = body.size(); - // create Framed and send reqest + // create Framed and send request Framed::new(io, h1::ClientCodec::default()) .send((head, len).into()) .from_err() @@ -95,12 +104,12 @@ where pub(crate) fn open_tunnel( io: T, - head: RequestHead, + head: RequestHeadType, ) -> impl Future), Error = SendRequestError> where T: AsyncRead + AsyncWrite + 'static, { - // create Framed and send reqest + // create Framed and send request Framed::new(io, h1::ClientCodec::default()) .send((head, BodySize::None).into()) .from_err() diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 91240268e..2993d89d8 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -9,8 +9,9 @@ use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::{request::Request, HttpTryFrom, Method, Version}; use crate::body::{BodySize, MessageBody}; -use crate::message::{RequestHead, ResponseHead}; +use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; +use crate::header::HeaderMap; use super::connection::{ConnectionType, IoConnection}; use super::error::SendRequestError; @@ -18,7 +19,7 @@ use super::pool::Acquired; pub(crate) fn send_request( io: SendRequest, - head: RequestHead, + head: RequestHeadType, body: B, created: time::Instant, pool: Option>, @@ -28,7 +29,7 @@ where B: MessageBody, { trace!("Sending client request: {:?} {:?}", head, body.size()); - let head_req = head.method == Method::HEAD; + let head_req = head.as_ref().method == Method::HEAD; let length = body.size(); let eof = match length { BodySize::None | BodySize::Empty | BodySize::Sized(0) => true, @@ -39,8 +40,8 @@ where .map_err(SendRequestError::from) .and_then(move |mut io| { let mut req = Request::new(()); - *req.uri_mut() = head.uri; - *req.method_mut() = head.method; + *req.uri_mut() = head.as_ref().uri.clone(); + *req.method_mut() = head.as_ref().method.clone(); *req.version_mut() = Version::HTTP_2; let mut skip_len = true; @@ -66,8 +67,21 @@ where ), }; + // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. + let (head, extra_headers) = match head { + RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), + RequestHeadType::Rc(head, extra_headers) => (RequestHeadType::Rc(head, None), extra_headers.unwrap_or(HeaderMap::new())), + }; + + // merging headers from head and extra headers. + let headers = head.as_ref().headers.iter() + .filter(|(name, _)| { + !extra_headers.contains_key(*name) + }) + .chain(extra_headers.iter()); + // copy headers - for (key, value) in head.headers.iter() { + for (key, value) in headers { match *key { CONNECTION | TRANSFER_ENCODING => continue, // http2 specific CONTENT_LENGTH if skip_len => continue, diff --git a/actix-http/src/client/mod.rs b/actix-http/src/client/mod.rs index 1d10117cd..04427ce42 100644 --- a/actix-http/src/client/mod.rs +++ b/actix-http/src/client/mod.rs @@ -10,7 +10,7 @@ mod pool; pub use self::connection::Connection; pub use self::connector::Connector; -pub use self::error::{ConnectError, InvalidUrl, SendRequestError}; +pub use self::error::{ConnectError, InvalidUrl, SendRequestError, FreezeRequestError}; pub use self::pool::Protocol; #[derive(Clone)] diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs index f93bc496a..c0bbcc694 100644 --- a/actix-http/src/h1/client.rs +++ b/actix-http/src/h1/client.rs @@ -1,5 +1,6 @@ #![allow(unused_imports, unused_variables, dead_code)] use std::io::{self, Write}; +use std::rc::Rc; use actix_codec::{Decoder, Encoder}; use bitflags::bitflags; @@ -16,7 +17,8 @@ use crate::body::BodySize; use crate::config::ServiceConfig; use crate::error::{ParseError, PayloadError}; use crate::helpers; -use crate::message::{ConnectionType, Head, MessagePool, RequestHead, ResponseHead}; +use crate::message::{ConnectionType, Head, MessagePool, RequestHead, RequestHeadType, ResponseHead}; +use crate::header::HeaderMap; bitflags! { struct Flags: u8 { @@ -48,7 +50,7 @@ struct ClientCodecInner { // encoder part flags: Flags, headers_size: u32, - encoder: encoder::MessageEncoder, + encoder: encoder::MessageEncoder, } impl Default for ClientCodec { @@ -183,7 +185,7 @@ impl Decoder for ClientPayloadCodec { } impl Encoder for ClientCodec { - type Item = Message<(RequestHead, BodySize)>; + type Item = Message<(RequestHeadType, BodySize)>; type Error = io::Error; fn encode( @@ -192,13 +194,13 @@ impl Encoder for ClientCodec { dst: &mut BytesMut, ) -> Result<(), Self::Error> { match item { - Message::Item((mut msg, length)) => { + Message::Item((mut head, length)) => { let inner = &mut self.inner; - inner.version = msg.version; - inner.flags.set(Flags::HEAD, msg.method == Method::HEAD); + inner.version = head.as_ref().version; + inner.flags.set(Flags::HEAD, head.as_ref().method == Method::HEAD); // connection status - inner.ctype = match msg.connection_type() { + inner.ctype = match head.as_ref().connection_type() { ConnectionType::KeepAlive => { if inner.flags.contains(Flags::KEEPALIVE_ENABLED) { ConnectionType::KeepAlive @@ -212,7 +214,7 @@ impl Encoder for ClientCodec { inner.encoder.encode( dst, - &mut msg, + &mut head, false, false, inner.version, diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index 61ca48b1d..380dfe328 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -4,6 +4,7 @@ use std::io::Write; use std::marker::PhantomData; use std::str::FromStr; use std::{cmp, fmt, io, mem}; +use std::rc::Rc; use bytes::{BufMut, Bytes, BytesMut}; @@ -15,7 +16,7 @@ use crate::http::header::{ HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, }; use crate::http::{HeaderMap, Method, StatusCode, Version}; -use crate::message::{ConnectionType, Head, RequestHead, ResponseHead}; +use crate::message::{ConnectionType, Head, RequestHead, ResponseHead, RequestHeadType}; use crate::request::Request; use crate::response::Response; @@ -43,6 +44,8 @@ pub(crate) trait MessageType: Sized { fn headers(&self) -> &HeaderMap; + fn extra_headers(&self) -> Option<&HeaderMap>; + fn camel_case(&self) -> bool { false } @@ -128,12 +131,21 @@ pub(crate) trait MessageType: Sized { _ => (), } + // merging headers from head and extra headers. HeaderMap::new() does not allocate. + let empty_headers = HeaderMap::new(); + let extra_headers = self.extra_headers().unwrap_or(&empty_headers); + let headers = self.headers().inner.iter() + .filter(|(name, _)| { + !extra_headers.contains_key(*name) + }) + .chain(extra_headers.inner.iter()); + // write headers let mut pos = 0; let mut has_date = false; let mut remaining = dst.remaining_mut(); let mut buf = unsafe { &mut *(dst.bytes_mut() as *mut [u8]) }; - for (key, value) in self.headers().inner.iter() { + for (key, value) in headers { match *key { CONNECTION => continue, TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => continue, @@ -235,6 +247,10 @@ impl MessageType for Response<()> { &self.head().headers } + fn extra_headers(&self) -> Option<&HeaderMap> { + None + } + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { let head = self.head(); let reason = head.reason().as_bytes(); @@ -247,31 +263,36 @@ impl MessageType for Response<()> { } } -impl MessageType for RequestHead { +impl MessageType for RequestHeadType { fn status(&self) -> Option { None } fn chunked(&self) -> bool { - self.chunked() + self.as_ref().chunked() } fn camel_case(&self) -> bool { - RequestHead::camel_case_headers(self) + self.as_ref().camel_case_headers() } fn headers(&self) -> &HeaderMap { - &self.headers + self.as_ref().headers() + } + + fn extra_headers(&self) -> Option<&HeaderMap> { + self.extra_headers() } fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { - dst.reserve(256 + self.headers.len() * AVERAGE_HEADER_SIZE); + let head = self.as_ref(); + dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE); write!( Writer(dst), "{} {} {}", - self.method, - self.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), - match self.version { + head.method, + head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), + match head.version { Version::HTTP_09 => "HTTP/0.9", Version::HTTP_10 => "HTTP/1.0", Version::HTTP_11 => "HTTP/1.1", @@ -488,9 +509,11 @@ fn write_camel_case(value: &[u8], buffer: &mut [u8]) { #[cfg(test)] mod tests { use bytes::Bytes; + //use std::rc::Rc; use super::*; use crate::http::header::{HeaderValue, CONTENT_TYPE}; + use http::header::AUTHORIZATION; #[test] fn test_chunked_te() { @@ -515,6 +538,8 @@ mod tests { head.headers .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text")); + let mut head = RequestHeadType::Owned(head); + let _ = head.encode_headers( &mut bytes, Version::HTTP_11, @@ -551,21 +576,16 @@ mod tests { Bytes::from_static(b"\r\nContent-Length: 100\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") ); + let mut head = RequestHead::default(); + head.set_camel_case_headers(false); + head.headers.insert(DATE, HeaderValue::from_static("date")); + head.headers + .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text")); head.headers .append(CONTENT_TYPE, HeaderValue::from_static("xml")); - let _ = head.encode_headers( - &mut bytes, - Version::HTTP_11, - BodySize::Stream, - ConnectionType::KeepAlive, - &ServiceConfig::default(), - ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\nTransfer-Encoding: chunked\r\nDate: date\r\nContent-Type: xml\r\nContent-Type: plain/text\r\n\r\n") - ); - head.set_camel_case_headers(false); + let mut head = RequestHeadType::Owned(head); + let _ = head.encode_headers( &mut bytes, Version::HTTP_11, @@ -578,4 +598,30 @@ mod tests { Bytes::from_static(b"\r\ntransfer-encoding: chunked\r\ndate: date\r\ncontent-type: xml\r\ncontent-type: plain/text\r\n\r\n") ); } + + #[test] + fn test_extra_headers() { + let mut bytes = BytesMut::with_capacity(2048); + + let mut head = RequestHead::default(); + head.headers.insert(AUTHORIZATION, HeaderValue::from_static("some authorization")); + + let mut extra_headers = HeaderMap::new(); + extra_headers.insert(AUTHORIZATION,HeaderValue::from_static("another authorization")); + extra_headers.insert(DATE, HeaderValue::from_static("date")); + + let mut head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers)); + + let _ = head.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Empty, + ConnectionType::Close, + &ServiceConfig::default(), + ); + assert_eq!( + bytes.take().freeze(), + Bytes::from_static(b"\r\ncontent-length: 0\r\nconnection: close\r\nauthorization: another authorization\r\ndate: date\r\n\r\n") + ); + } } diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index 6b8874b23..b57fdddce 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -39,7 +39,7 @@ pub use self::config::{KeepAlive, ServiceConfig}; pub use self::error::{Error, ResponseError, Result}; pub use self::extensions::Extensions; pub use self::httpmessage::HttpMessage; -pub use self::message::{Message, RequestHead, ResponseHead}; +pub use self::message::{Message, RequestHead, RequestHeadType, ResponseHead}; pub use self::payload::{Payload, PayloadStream}; pub use self::request::Request; pub use self::response::{Response, ResponseBuilder}; diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index cf23a401c..316df2611 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -181,6 +181,36 @@ impl RequestHead { } } +#[derive(Debug)] +pub enum RequestHeadType { + Owned(RequestHead), + Rc(Rc, Option), +} + +impl RequestHeadType { + pub fn extra_headers(&self) -> Option<&HeaderMap> { + match self { + RequestHeadType::Owned(_) => None, + RequestHeadType::Rc(_, headers) => headers.as_ref(), + } + } +} + +impl AsRef for RequestHeadType { + fn as_ref(&self) -> &RequestHead { + match self { + RequestHeadType::Owned(head) => &head, + RequestHeadType::Rc(head, _) => head.as_ref(), + } + } +} + +impl From for RequestHeadType { + fn from(head: RequestHead) -> Self { + RequestHeadType::Owned(head) + } +} + #[derive(Debug)] pub struct ResponseHead { pub version: Version, diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 5442e9dbd..27962a6f5 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,11 +1,19 @@ # Changes +## + +### Added + +* Add `FrozenClientRequest` to support retries for sending HTTP requests + + ## [0.2.5] - 2019-09-06 ### Changed * Ensure that the `Host` header is set when initiating a WebSocket client connection. + ## [0.2.4] - 2019-08-13 ### Changed diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 04f08ecdc..82fd6a759 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,4 +1,5 @@ use std::{fmt, io, net}; +use std::rc::Rc; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::body::Body; @@ -6,7 +7,8 @@ use actix_http::client::{ Connect as ClientConnect, ConnectError, Connection, SendRequestError, }; use actix_http::h1::ClientCodec; -use actix_http::{RequestHead, ResponseHead}; +use actix_http::{RequestHead, RequestHeadType, ResponseHead}; +use actix_http::http::HeaderMap; use actix_service::Service; use futures::{Future, Poll}; @@ -22,6 +24,14 @@ pub(crate) trait Connect { addr: Option, ) -> Box>; + fn send_request_extra( + &mut self, + head: Rc, + extra_headers: Option, + body: Body, + addr: Option, + ) -> Box>; + /// Send request, returns Response and Framed fn open_tunnel( &mut self, @@ -33,6 +43,19 @@ pub(crate) trait Connect { Error = SendRequestError, >, >; + + /// Send request and extra headers, returns Response and Framed + fn open_tunnel_extra( + &mut self, + head: Rc, + extra_headers: Option, + addr: Option, + ) -> Box< + dyn Future< + Item = (ResponseHead, Framed), + Error = SendRequestError, + >, + >; } impl Connect for ConnectorWrapper @@ -59,7 +82,28 @@ where }) .from_err() // send request - .and_then(move |connection| connection.send_request(head, body)) + .and_then(move |connection| connection.send_request(RequestHeadType::from(head), body)) + .map(|(head, payload)| ClientResponse::new(head, payload)), + ) + } + + fn send_request_extra( + &mut self, + head: Rc, + extra_headers: Option, + body: Body, + addr: Option, + ) -> Box> { + Box::new( + self.0 + // connect to the host + .call(ClientConnect { + uri: head.uri.clone(), + addr, + }) + .from_err() + // send request + .and_then(move |connection| connection.send_request(RequestHeadType::Rc(head, extra_headers), body)) .map(|(head, payload)| ClientResponse::new(head, payload)), ) } @@ -83,7 +127,35 @@ where }) .from_err() // send request - .and_then(move |connection| connection.open_tunnel(head)) + .and_then(move |connection| connection.open_tunnel(RequestHeadType::from(head))) + .map(|(head, framed)| { + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + (head, framed) + }), + ) + } + + fn open_tunnel_extra( + &mut self, + head: Rc, + extra_headers: Option, + addr: Option, + ) -> Box< + dyn Future< + Item = (ResponseHead, Framed), + Error = SendRequestError, + >, + > { + Box::new( + self.0 + // connect to the host + .call(ClientConnect { + uri: head.uri.clone(), + addr, + }) + .from_err() + // send request + .and_then(move |connection| connection.open_tunnel(RequestHeadType::Rc(head, extra_headers))) .map(|(head, framed)| { let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); (head, framed) diff --git a/awc/src/error.rs b/awc/src/error.rs index f78355c67..4eb929007 100644 --- a/awc/src/error.rs +++ b/awc/src/error.rs @@ -1,5 +1,5 @@ //! Http client errors -pub use actix_http::client::{ConnectError, InvalidUrl, SendRequestError}; +pub use actix_http::client::{ConnectError, InvalidUrl, SendRequestError, FreezeRequestError}; pub use actix_http::error::PayloadError; pub use actix_http::ws::HandshakeError as WsHandshakeError; pub use actix_http::ws::ProtocolError as WsProtocolError; diff --git a/awc/src/request.rs b/awc/src/request.rs index 437157853..4dd07c5d8 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -1,16 +1,16 @@ use std::fmt::Write as FmtWrite; use std::io::Write; use std::rc::Rc; -use std::time::Duration; +use std::time::{Duration, Instant}; use std::{fmt, net}; use bytes::{BufMut, Bytes, BytesMut}; -use futures::future::{err, Either}; -use futures::{Future, Stream}; +use futures::{Async, Future, Poll, Stream, try_ready}; use percent_encoding::percent_encode; use serde::Serialize; use serde_json; -use tokio_timer::Timeout; +use tokio_timer::Delay; +use derive_more::From; use actix_http::body::{Body, BodyStream}; use actix_http::cookie::{Cookie, CookieJar, USERINFO}; @@ -20,9 +20,9 @@ use actix_http::http::{ uri, ConnectionType, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method, Uri, Version, }; -use actix_http::{Error, Payload, RequestHead}; +use actix_http::{Error, Payload, PayloadStream, RequestHead}; -use crate::error::{InvalidUrl, PayloadError, SendRequestError}; +use crate::error::{InvalidUrl, SendRequestError, FreezeRequestError}; use crate::response::ClientResponse; use crate::ClientConfig; @@ -99,6 +99,11 @@ impl ClientRequest { self } + /// Get HTTP URI of request + pub fn get_uri(&self) -> &Uri { + &self.head.uri + } + /// Set socket address of the server. /// /// This address is used for connection. If address is not @@ -115,6 +120,11 @@ impl ClientRequest { self } + /// Get HTTP method of this request + pub fn get_method(&self) -> &Method { + &self.head.method + } + #[doc(hidden)] /// Set HTTP version of this request. /// @@ -365,34 +375,122 @@ impl ClientRequest { } } + pub fn freeze(self) -> Result { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return Err(e.into()), + }; + + let request = FrozenClientRequest { + head: Rc::new(slf.head), + addr: slf.addr, + response_decompress: slf.response_decompress, + timeout: slf.timeout, + config: slf.config, + }; + + Ok(request) + } + /// Complete request construction and send body. pub fn send_body( - mut self, + self, body: B, - ) -> impl Future< - Item = ClientResponse>, - Error = SendRequestError, - > + ) -> SendBody where B: Into, { - if let Some(e) = self.err.take() { - return Either::A(err(e.into())); + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head) + .send_body(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), body) + } + + /// Set a JSON body and generate `ClientRequest` + pub fn send_json( + self, + value: &T, + ) -> SendBody + { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head) + .send_json(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), value) + } + + /// Set a urlencoded body and generate `ClientRequest` + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn send_form( + self, + value: &T, + ) -> SendBody + { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head) + .send_form(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), value) + } + + /// Set an streaming body and generate `ClientRequest`. + pub fn send_stream( + self, + stream: S, + ) -> SendBody + where + S: Stream + 'static, + E: Into + 'static, + { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head) + .send_stream(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), stream) + } + + /// Set an empty body and generate `ClientRequest`. + pub fn send( + self, + ) -> SendBody + { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head) + .send(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref()) + } + + fn prep_for_sending(mut self) -> Result { + if let Some(e) = self.err { + return Err(e.into()); } // validate uri let uri = &self.head.uri; if uri.host().is_none() { - return Either::A(err(InvalidUrl::MissingHost.into())); + return Err(InvalidUrl::MissingHost.into()); } else if uri.scheme_part().is_none() { - return Either::A(err(InvalidUrl::MissingScheme.into())); + return Err(InvalidUrl::MissingScheme.into()); } else if let Some(scheme) = uri.scheme_part() { match scheme.as_str() { "http" | "ws" | "https" | "wss" => (), - _ => return Either::A(err(InvalidUrl::UnknownScheme.into())), + _ => return Err(InvalidUrl::UnknownScheme.into()), } } else { - return Either::A(err(InvalidUrl::UnknownScheme.into())); + return Err(InvalidUrl::UnknownScheme.into()); } // set cookies @@ -430,112 +528,15 @@ impl ClientRequest { slf = slf.set_header_if_none(header::ACCEPT_ENCODING, HTTPS_ENCODING) } else { #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] - { - slf = slf - .set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate") - } + { + slf = slf + .set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate") + } }; } } - let head = slf.head; - let config = slf.config.as_ref(); - let response_decompress = slf.response_decompress; - - let fut = config - .connector - .borrow_mut() - .send_request(head, body.into(), slf.addr) - .map(move |res| { - res.map_body(|head, payload| { - if response_decompress { - Payload::Stream(Decoder::from_headers(payload, &head.headers)) - } else { - Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) - } - }) - }); - - // set request timeout - if let Some(timeout) = slf.timeout.or_else(|| config.timeout) { - Either::B(Either::A(Timeout::new(fut, timeout).map_err(|e| { - if let Some(e) = e.into_inner() { - e - } else { - SendRequestError::Timeout - } - }))) - } else { - Either::B(Either::B(fut)) - } - } - - /// Set a JSON body and generate `ClientRequest` - pub fn send_json( - self, - value: &T, - ) -> impl Future< - Item = ClientResponse>, - Error = SendRequestError, - > { - let body = match serde_json::to_string(value) { - Ok(body) => body, - Err(e) => return Either::A(err(Error::from(e).into())), - }; - - // set content-type - let slf = self.set_header_if_none(header::CONTENT_TYPE, "application/json"); - - Either::B(slf.send_body(Body::Bytes(Bytes::from(body)))) - } - - /// Set a urlencoded body and generate `ClientRequest` - /// - /// `ClientRequestBuilder` can not be used after this call. - pub fn send_form( - self, - value: &T, - ) -> impl Future< - Item = ClientResponse>, - Error = SendRequestError, - > { - let body = match serde_urlencoded::to_string(value) { - Ok(body) => body, - Err(e) => return Either::A(err(Error::from(e).into())), - }; - - // set content-type - let slf = self.set_header_if_none( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ); - - Either::B(slf.send_body(Body::Bytes(Bytes::from(body)))) - } - - /// Set an streaming body and generate `ClientRequest`. - pub fn send_stream( - self, - stream: S, - ) -> impl Future< - Item = ClientResponse>, - Error = SendRequestError, - > - where - S: Stream + 'static, - E: Into + 'static, - { - self.send_body(Body::from_message(BodyStream::new(stream))) - } - - /// Set an empty body and generate `ClientRequest`. - pub fn send( - self, - ) -> impl Future< - Item = ClientResponse>, - Error = SendRequestError, - > { - self.send_body(Body::Empty) + Ok(slf) } } @@ -554,6 +555,441 @@ impl fmt::Debug for ClientRequest { } } +#[derive(Clone)] +pub struct FrozenClientRequest { + pub(crate) head: Rc, + pub(crate) addr: Option, + pub(crate) response_decompress: bool, + pub(crate) timeout: Option, + pub(crate) config: Rc, +} + +impl FrozenClientRequest { + /// Get HTTP URI of request + pub fn get_uri(&self) -> &Uri { + &self.head.uri + } + + /// Get HTTP method of this request + pub fn get_method(&self) -> &Method { + &self.head.method + } + + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + /// Send a body. + pub fn send_body( + &self, + body: B, + ) -> SendBody + where + B: Into, + { + RequestSender::Rc(self.head.clone(), None) + .send_body(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), body) + } + + /// Send a json body. + pub fn send_json( + &self, + value: &T, + ) -> SendBody + { + RequestSender::Rc(self.head.clone(), None) + .send_json(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), value) + } + + /// Send an urlencoded body. + pub fn send_form( + &self, + value: &T, + ) -> SendBody + { + RequestSender::Rc(self.head.clone(), None) + .send_form(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), value) + } + + /// Send a streaming body. + pub fn send_stream( + &self, + stream: S, + ) -> SendBody + where + S: Stream + 'static, + E: Into + 'static, + { + RequestSender::Rc(self.head.clone(), None) + .send_stream(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), stream) + } + + /// Send an empty body. + pub fn send( + &self, + ) -> SendBody + { + RequestSender::Rc(self.head.clone(), None) + .send(self.addr, self.response_decompress, self.timeout, self.config.as_ref()) + } + + /// Create a `FrozenSendBuilder` with extra headers + pub fn extra_headers(&self, extra_headers: HeaderMap) -> FrozenSendBuilder { + FrozenSendBuilder::new(self.clone(), extra_headers) + } + + /// Create a `FrozenSendBuilder` with an extra header + pub fn extra_header(&self, key: K, value: V) -> FrozenSendBuilder + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + self.extra_headers(HeaderMap::new()).extra_header(key, value) + } +} + +pub struct FrozenSendBuilder { + req: FrozenClientRequest, + extra_headers: HeaderMap, + err: Option, +} + +impl FrozenSendBuilder { + pub(crate) fn new(req: FrozenClientRequest, extra_headers: HeaderMap) -> Self { + Self { + req, + extra_headers, + err: None, + } + } + + /// Insert a header, it overrides existing header in `FrozenClientRequest`. + pub fn extra_header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => self.extra_headers.insert(key, value), + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Complete request construction and send a body. + pub fn send_body( + self, + body: B, + ) -> SendBody + where + B: Into, + { + if let Some(e) = self.err { + return e.into() + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)) + .send_body(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), body) + } + + /// Complete request construction and send a json body. + pub fn send_json( + self, + value: &T, + ) -> SendBody + { + if let Some(e) = self.err { + return e.into() + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)) + .send_json(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), value) + } + + /// Complete request construction and send an urlencoded body. + pub fn send_form( + self, + value: &T, + ) -> SendBody + { + if let Some(e) = self.err { + return e.into() + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)) + .send_form(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), value) + } + + /// Complete request construction and send a streaming body. + pub fn send_stream( + self, + stream: S, + ) -> SendBody + where + S: Stream + 'static, + E: Into + 'static, + { + if let Some(e) = self.err { + return e.into() + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)) + .send_stream(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), stream) + } + + /// Complete request construction and send an empty body. + pub fn send( + self, + ) -> SendBody + { + if let Some(e) = self.err { + return e.into() + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)) + .send(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref()) + } +} + +#[derive(Debug, From)] +enum PrepForSendingError { + Url(InvalidUrl), + Http(HttpError), +} + +impl Into for PrepForSendingError { + fn into(self) -> FreezeRequestError { + match self { + PrepForSendingError::Url(e) => FreezeRequestError::Url(e), + PrepForSendingError::Http(e) => FreezeRequestError::Http(e), + } + } +} + +impl Into for PrepForSendingError { + fn into(self) -> SendRequestError { + match self { + PrepForSendingError::Url(e) => SendRequestError::Url(e), + PrepForSendingError::Http(e) => SendRequestError::Http(e), + } + } +} + +pub enum SendBody +{ + Fut(Box>, Option, bool), + Err(Option), +} + +impl SendBody +{ + pub fn new( + send: Box>, + response_decompress: bool, + timeout: Option, + ) -> SendBody + { + let delay = timeout.map(|t| Delay::new(Instant::now() + t)); + SendBody::Fut(send, delay, response_decompress) + } +} + +impl Future for SendBody +{ + type Item = ClientResponse>>; + type Error = SendRequestError; + + fn poll(&mut self) -> Poll { + match self { + SendBody::Fut(send, delay, response_decompress) => { + if delay.is_some() { + match delay.poll() { + Ok(Async::NotReady) => (), + _ => return Err(SendRequestError::Timeout), + } + } + + let res = try_ready!(send.poll()) + .map_body(|head, payload| { + if *response_decompress { + Payload::Stream(Decoder::from_headers(payload, &head.headers)) + } else { + Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) + } + }); + + Ok(Async::Ready(res)) + }, + SendBody::Err(ref mut e) => { + match e.take() { + Some(e) => Err(e.into()), + None => panic!("Attempting to call completed future"), + } + } + } + } +} + + +impl From for SendBody +{ + fn from(e: SendRequestError) -> Self { + SendBody::Err(Some(e)) + } +} + +impl From for SendBody +{ + fn from(e: Error) -> Self { + SendBody::Err(Some(e.into())) + } +} + +impl From for SendBody +{ + fn from(e: HttpError) -> Self { + SendBody::Err(Some(e.into())) + } +} + +impl From for SendBody +{ + fn from(e: PrepForSendingError) -> Self { + SendBody::Err(Some(e.into())) + } +} + +#[derive(Debug)] +enum RequestSender { + Owned(RequestHead), + Rc(Rc, Option), +} + +impl RequestSender { + pub fn send_body( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + body: B, + ) -> SendBody + where + B: Into, + { + let mut connector = config.connector.borrow_mut(); + + let fut = match self { + RequestSender::Owned(head) => connector.send_request(head, body.into(), addr), + RequestSender::Rc(head, extra_headers) => connector.send_request_extra(head, extra_headers, body.into(), addr), + }; + + SendBody::new(fut, response_decompress, timeout.or_else(|| config.timeout.clone())) + } + + pub fn send_json( + mut self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + value: &T, + ) -> SendBody + { + let body = match serde_json::to_string(value) { + Ok(body) => body, + Err(e) => return Error::from(e).into(), + }; + + if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") { + return e.into(); + } + + self.send_body(addr, response_decompress, timeout, config, Body::Bytes(Bytes::from(body))) + } + + pub fn send_form( + mut self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + value: &T, + ) -> SendBody + { + let body = match serde_urlencoded::to_string(value) { + Ok(body) => body, + Err(e) => return Error::from(e).into(), + }; + + // set content-type + if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/x-www-form-urlencoded") { + return e.into(); + } + + self.send_body(addr, response_decompress, timeout, config, Body::Bytes(Bytes::from(body))) + } + + pub fn send_stream( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + stream: S, + ) -> SendBody + where + S: Stream + 'static, + E: Into + 'static, + { + self.send_body(addr, response_decompress, timeout, config, Body::from_message(BodyStream::new(stream))) + } + + pub fn send( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + ) -> SendBody + { + self.send_body(addr, response_decompress, timeout, config, Body::Empty) + } + + fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> + where + V: IntoHeaderValue, + { + match self { + RequestSender::Owned(head) => { + if !head.headers.contains_key(&key) { + match value.try_into() { + Ok(value) => head.headers.insert(key, value), + Err(e) => return Err(e.into()), + } + } + }, + RequestSender::Rc(head, extra_headers) => { + if !head.headers.contains_key(&key) && !extra_headers.iter().any(|h| h.contains_key(&key)) { + match value.try_into(){ + Ok(v) => { + let h = extra_headers.get_or_insert(HeaderMap::new()); + h.insert(key, v) + }, + Err(e) => return Err(e.into()), + }; + } + } + } + + Ok(()) + } +} + #[cfg(test)] mod tests { use std::time::SystemTime;