From c836de44af0667246a2691a01df9f0dc55bcafd7 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Sun, 28 Feb 2021 10:17:08 -0800 Subject: [PATCH] add client middleware (#2013) --- awc/Cargo.toml | 3 +- awc/src/builder.rs | 59 +++++- awc/src/connect.rs | 123 ++++++++---- awc/src/lib.rs | 13 +- awc/src/middleware/mod.rs | 71 +++++++ awc/src/middleware/redirect.rs | 350 +++++++++++++++++++++++++++++++++ awc/src/response.rs | 4 +- src/types/form.rs | 8 +- src/types/json.rs | 4 +- 9 files changed, 570 insertions(+), 65 deletions(-) create mode 100644 awc/src/middleware/mod.rs create mode 100644 awc/src/middleware/redirect.rs diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 8cbba432c..ca345d3cb 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -47,7 +47,7 @@ trust-dns = ["actix-http/trust-dns"] actix-codec = "0.4.0-beta.1" actix-service = "2.0.0-beta.4" actix-http = "3.0.0-beta.3" -actix-rt = "2.1" +actix-rt = { version = "2.1", default-features = false } base64 = "0.13" bytes = "1" @@ -57,6 +57,7 @@ futures-core = { version = "0.3.7", default-features = false } log =" 0.4" mime = "0.3" percent-encoding = "2.1" +pin-project-lite = "0.2" rand = "0.8" serde = "1.0" serde_json = "1.0" diff --git a/awc/src/builder.rs b/awc/src/builder.rs index b7cdefd40..363056c02 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -10,24 +10,27 @@ use actix_http::{ http::{self, header, Error as HttpError, HeaderMap, HeaderName, Uri}, }; use actix_rt::net::TcpStream; -use actix_service::Service; +use actix_service::{boxed, Service}; -use crate::connect::ConnectorWrapper; -use crate::{Client, ClientConfig}; +use crate::connect::DefaultConnector; +use crate::error::SendRequestError; +use crate::middleware::{NestTransform, Transform}; +use crate::{Client, ClientConfig, ConnectRequest, ConnectResponse, ConnectorService}; /// An HTTP Client builder /// /// This type can be used to construct an instance of `Client` through a /// builder-like pattern. -pub struct ClientBuilder { +pub struct ClientBuilder { default_headers: bool, max_http_version: Option, stream_window_size: Option, conn_window_size: Option, headers: HeaderMap, timeout: Option, + connector: Connector, + middleware: M, local_address: Option, - connector: Connector, } impl ClientBuilder { @@ -39,8 +42,10 @@ impl ClientBuilder { Error = TcpConnectError, > + Clone, TcpStream, + (), > { ClientBuilder { + middleware: (), default_headers: true, headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), @@ -53,7 +58,7 @@ impl ClientBuilder { } } -impl ClientBuilder +impl ClientBuilder where S: Service, Response = TcpConnection, Error = TcpConnectError> + Clone @@ -61,7 +66,7 @@ where Io: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { /// Use custom connector service. - pub fn connector(self, connector: Connector) -> ClientBuilder + pub fn connector(self, connector: Connector) -> ClientBuilder where S1: Service< TcpConnect, @@ -72,10 +77,11 @@ where Io1: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { ClientBuilder { + middleware: self.middleware, default_headers: self.default_headers, headers: self.headers, timeout: self.timeout, - local_address: None, + local_address: self.local_address, connector, max_http_version: self.max_http_version, stream_window_size: self.stream_window_size, @@ -181,8 +187,38 @@ where self.header(header::AUTHORIZATION, format!("Bearer {}", token)) } + /// Registers middleware, in the form of a middleware component (type), + /// that runs during inbound and/or outbound processing in the request + /// life-cycle (request -> response), modifying request/response as + /// necessary, across all requests managed by the Client. + pub fn wrap( + self, + mw: M1, + ) -> ClientBuilder> + where + M: Transform, + M1: Transform, + { + ClientBuilder { + middleware: NestTransform::new(self.middleware, mw), + default_headers: self.default_headers, + max_http_version: self.max_http_version, + stream_window_size: self.stream_window_size, + conn_window_size: self.conn_window_size, + headers: self.headers, + timeout: self.timeout, + connector: self.connector, + local_address: self.local_address, + } + } + /// Finish build process and create `Client` instance. - pub fn finish(self) -> Client { + pub fn finish(self) -> Client + where + M: Transform + 'static, + M::Transform: + Service, + { let mut connector = self.connector; if let Some(val) = self.max_http_version { @@ -198,10 +234,13 @@ where connector = connector.local_address(val); } + let connector = boxed::service(DefaultConnector::new(connector.finish())); + let connector = boxed::service(self.middleware.new_transform(connector)); + let config = ClientConfig { headers: self.headers, timeout: self.timeout, - connector: Box::new(ConnectorWrapper::new(connector.finish())) as _, + connector, }; Client(Rc::new(config)) diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 97af2d1cc..a4abbc46b 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,5 +1,7 @@ use std::{ - fmt, io, net, + fmt, + future::Future, + io, net, pin::Pin, task::{Context, Poll}, }; @@ -9,24 +11,14 @@ use actix_http::{ body::Body, client::{Connect as ClientConnect, ConnectError, Connection, SendRequestError}, h1::ClientCodec, - RequestHead, RequestHeadType, ResponseHead, + Payload, RequestHead, RequestHeadType, ResponseHead, }; use actix_service::Service; -use futures_core::future::LocalBoxFuture; +use futures_core::{future::LocalBoxFuture, ready}; use crate::response::ClientResponse; -pub(crate) struct ConnectorWrapper { - connector: T, -} - -impl ConnectorWrapper { - pub(crate) fn new(connector: T) -> Self { - Self { connector } - } -} - -pub type ConnectService = Box< +pub type ConnectorService = Box< dyn Service< ConnectRequest, Response = ConnectResponse, @@ -65,16 +57,25 @@ impl ConnectResponse { } } -impl Service for ConnectorWrapper +pub(crate) struct DefaultConnector { + connector: S, +} + +impl DefaultConnector { + pub(crate) fn new(connector: S) -> Self { + Self { connector } + } +} + +impl Service for DefaultConnector where - T: Service, - T::Response: Connection, - ::Io: 'static, - T::Future: 'static, + S: Service, + S::Response: Connection, + ::Io: 'static, { type Response = ConnectResponse; type Error = SendRequestError; - type Future = LocalBoxFuture<'static, Result>; + type Future = ConnectRequestFuture::Io>; actix_service::forward_ready!(connector); @@ -91,26 +92,76 @@ where }), }; - Box::pin(async move { - let connection = fut.await?; + ConnectRequestFuture::Connection { + fut, + req: Some(req), + } + } +} - match req { - ConnectRequest::Client(head, body, ..) => { - // send request - let (head, payload) = connection.send_request(head, body).await?; +pin_project_lite::pin_project! { + #[project = ConnectRequestProj] + pub(crate) enum ConnectRequestFuture { + Connection { + #[pin] + fut: Fut, + req: Option + }, + Client { + fut: LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + }, + Tunnel { + fut: LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + >, + } + } +} - Ok(ConnectResponse::Client(ClientResponse::new(head, payload))) - } - ConnectRequest::Tunnel(head, ..) => { - // send request - let (head, framed) = - connection.open_tunnel(RequestHeadType::from(head)).await?; - - let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); - Ok(ConnectResponse::Tunnel(head, framed)) +impl Future for ConnectRequestFuture +where + Fut: Future>, + C: Connection, + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + ConnectRequestProj::Connection { fut, req } => { + let connection = ready!(fut.poll(cx))?; + let req = req.take().unwrap(); + match req { + ConnectRequest::Client(head, body, ..) => { + // send request + let fut = ConnectRequestFuture::Client { + fut: connection.send_request(head, body), + }; + self.as_mut().set(fut); + } + ConnectRequest::Tunnel(head, ..) => { + // send request + let fut = ConnectRequestFuture::Tunnel { + fut: connection.open_tunnel(RequestHeadType::from(head)), + }; + self.as_mut().set(fut); + } } + self.poll(cx) } - }) + ConnectRequestProj::Client { fut } => { + let (head, payload) = ready!(fut.as_mut().poll(cx))?; + Poll::Ready(Ok(ConnectResponse::Client(ClientResponse::new( + head, payload, + )))) + } + ConnectRequestProj::Tunnel { fut } => { + let (head, framed) = ready!(fut.as_mut().poll(cx))?; + let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed))) + } + } } } diff --git a/awc/src/lib.rs b/awc/src/lib.rs index 66ff55402..2f48dca79 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -107,12 +107,13 @@ use actix_http::{ RequestHead, }; use actix_rt::net::TcpStream; -use actix_service::Service; +use actix_service::{boxed, Service}; mod builder; mod connect; pub mod error; mod frozen; +pub mod middleware; mod request; mod response; mod sender; @@ -120,14 +121,12 @@ pub mod test; pub mod ws; pub use self::builder::ClientBuilder; -pub use self::connect::{BoxedSocket, ConnectRequest, ConnectResponse, ConnectService}; +pub use self::connect::{BoxedSocket, ConnectRequest, ConnectResponse, ConnectorService}; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; pub use self::response::{ClientResponse, JsonBody, MessageBody}; pub use self::sender::SendClientRequest; -use self::connect::ConnectorWrapper; - /// An asynchronous HTTP and WebSocket client. /// /// ## Examples @@ -151,7 +150,7 @@ use self::connect::ConnectorWrapper; pub struct Client(Rc); pub(crate) struct ClientConfig { - pub(crate) connector: ConnectService, + pub(crate) connector: ConnectorService, pub(crate) headers: HeaderMap, pub(crate) timeout: Option, } @@ -159,7 +158,9 @@ pub(crate) struct ClientConfig { impl Default for Client { fn default() -> Self { Client(Rc::new(ClientConfig { - connector: Box::new(ConnectorWrapper::new(Connector::new().finish())), + connector: boxed::service(self::connect::DefaultConnector::new( + Connector::new().finish(), + )), headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), })) diff --git a/awc/src/middleware/mod.rs b/awc/src/middleware/mod.rs new file mode 100644 index 000000000..330e3b7fe --- /dev/null +++ b/awc/src/middleware/mod.rs @@ -0,0 +1,71 @@ +mod redirect; + +pub use self::redirect::Redirect; + +use std::marker::PhantomData; + +use actix_service::Service; + +/// Trait for transform a type to another one. +/// Both the input and output type should impl [actix_service::Service] trait. +pub trait Transform { + type Transform: Service; + + /// Creates and returns a new Transform component. + fn new_transform(self, service: S) -> Self::Transform; +} + +#[doc(hidden)] +/// Helper struct for constructing Nested types that would call `Transform::new_transform` +/// in a chain. +/// +/// The child field would be called first and the output `Service` type is +/// passed to parent as input type. +pub struct NestTransform +where + T1: Transform, + T2: Transform, +{ + child: T1, + parent: T2, + _service: PhantomData<(S, Req)>, +} + +impl NestTransform +where + T1: Transform, + T2: Transform, +{ + pub(crate) fn new(child: T1, parent: T2) -> Self { + NestTransform { + child, + parent, + _service: PhantomData, + } + } +} + +impl Transform for NestTransform +where + T1: Transform, + T2: Transform, +{ + type Transform = T2::Transform; + + fn new_transform(self, service: S) -> Self::Transform { + let service = self.child.new_transform(service); + self.parent.new_transform(service) + } +} + +/// Dummy impl for kick start `NestTransform` type in `ClientBuilder` type +impl Transform for () +where + S: Service, +{ + type Transform = S; + + fn new_transform(self, service: S) -> Self::Transform { + service + } +} diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs new file mode 100644 index 000000000..1d0ace166 --- /dev/null +++ b/awc/src/middleware/redirect.rs @@ -0,0 +1,350 @@ +use std::{ + convert::TryFrom, + future::Future, + net::SocketAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; + +use actix_http::{ + body::Body, + client::{InvalidUrl, SendRequestError}, + http::{header, Method, StatusCode, Uri}, + RequestHead, RequestHeadType, +}; +use actix_service::Service; +use bytes::Bytes; +use futures_core::ready; + +use super::Transform; + +use crate::connect::{ConnectRequest, ConnectResponse}; +use crate::ClientResponse; + +pub struct Redirect { + max_redirect_times: u8, +} + +impl Default for Redirect { + fn default() -> Self { + Self::new() + } +} + +impl Redirect { + pub fn new() -> Self { + Self { + max_redirect_times: 10, + } + } + + pub fn max_redirect_times(mut self, times: u8) -> Self { + self.max_redirect_times = times; + self + } +} + +impl Transform for Redirect +where + S: Service + 'static, +{ + type Transform = RedirectService; + + fn new_transform(self, service: S) -> Self::Transform { + RedirectService { + max_redirect_times: self.max_redirect_times, + connector: Rc::new(service), + } + } +} + +pub struct RedirectService { + max_redirect_times: u8, + connector: Rc, +} + +impl Service for RedirectService +where + S: Service + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = RedirectServiceFuture; + + actix_service::forward_ready!(connector); + + fn call(&self, req: ConnectRequest) -> Self::Future { + match req { + ConnectRequest::Tunnel(head, addr) => { + let fut = self.connector.call(ConnectRequest::Tunnel(head, addr)); + RedirectServiceFuture::Tunnel { fut } + } + ConnectRequest::Client(head, body, addr) => { + let connector = self.connector.clone(); + let max_redirect_times = self.max_redirect_times; + + // backup the uri and method for reuse schema and authority. + let (uri, method) = match head { + RequestHeadType::Owned(ref head) => (head.uri.clone(), head.method.clone()), + RequestHeadType::Rc(ref head, ..) => { + (head.uri.clone(), head.method.clone()) + } + }; + + let body_opt = match body { + Body::Bytes(ref b) => Some(b.clone()), + _ => None, + }; + + let fut = connector.call(ConnectRequest::Client(head, body, addr)); + + RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + body: body_opt, + addr, + connector: Some(connector), + } + } + } + } +} + +pin_project_lite::pin_project! { + #[project = RedirectServiceProj] + pub enum RedirectServiceFuture + where + S: Service, + S: 'static + { + Tunnel { #[pin] fut: S::Future }, + Client { + #[pin] + fut: S::Future, + max_redirect_times: u8, + uri: Option, + method: Option, + body: Option, + addr: Option, + connector: Option> + } + } +} + +impl Future for RedirectServiceFuture +where + S: Service + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + RedirectServiceProj::Tunnel { fut } => fut.poll(cx), + RedirectServiceProj::Client { + fut, + max_redirect_times, + uri, + method, + body, + addr, + connector, + } => match ready!(fut.poll(cx))? { + ConnectResponse::Client(res) => match res.head().status { + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::SEE_OTHER + if *max_redirect_times > 0 => + { + let org_uri = uri.take().unwrap(); + // rebuild uri from the location header value. + let uri = rebuild_uri(&res, org_uri)?; + + // reset method + let method = method.take().unwrap(); + let method = match method { + Method::GET | Method::HEAD => method, + _ => Method::GET, + }; + + // take ownership of states that could be reused + let addr = addr.take(); + let connector = connector.take(); + let mut max_redirect_times = *max_redirect_times; + + // use a new request head. + let mut head = RequestHead::default(); + head.uri = uri.clone(); + head.method = method.clone(); + + let head = RequestHeadType::Owned(head); + + max_redirect_times -= 1; + + let fut = connector + .as_ref() + .unwrap() + // remove body + .call(ConnectRequest::Client(head, Body::None, addr)); + + self.as_mut().set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + // body is dropped on 301,302,303 + body: None, + addr, + connector, + }); + + self.poll(cx) + } + StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT + if *max_redirect_times > 0 => + { + let org_uri = uri.take().unwrap(); + // rebuild uri from the location header value. + let uri = rebuild_uri(&res, org_uri)?; + + // try to reuse body + let body = body.take(); + let body_new = match body { + Some(ref bytes) => Body::Bytes(bytes.clone()), + // TODO: should this be Body::Empty or Body::None. + _ => Body::Empty, + }; + + let addr = addr.take(); + let method = method.take().unwrap(); + let connector = connector.take(); + let mut max_redirect_times = *max_redirect_times; + + // use a new request head. + let mut head = RequestHead::default(); + head.uri = uri.clone(); + head.method = method.clone(); + + let head = RequestHeadType::Owned(head); + + max_redirect_times -= 1; + + let fut = connector + .as_ref() + .unwrap() + .call(ConnectRequest::Client(head, body_new, addr)); + + self.as_mut().set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + body, + addr, + connector, + }); + + self.poll(cx) + } + _ => Poll::Ready(Ok(ConnectResponse::Client(res))), + }, + _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"), + }, + } + } +} + +fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result { + let uri = res + .headers() + .get(header::LOCATION) + .map(|value| { + // try to parse the location to a full uri + let uri = Uri::try_from(value.as_bytes()) + .map_err(|e| SendRequestError::Url(InvalidUrl::HttpError(e.into())))?; + if uri.scheme().is_none() || uri.authority().is_none() { + let uri = Uri::builder() + .scheme(org_uri.scheme().cloned().unwrap()) + .authority(org_uri.authority().cloned().unwrap()) + .path_and_query(value.as_bytes()) + .build()?; + Ok::<_, SendRequestError>(uri) + } else { + Ok(uri) + } + }) + // TODO: this error type is wrong. + .ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))??; + + Ok(uri) +} + +#[cfg(test)] +mod tests { + use actix_web::{test::start, web, App, Error, HttpResponse}; + + use super::*; + + use crate::ClientBuilder; + + #[actix_rt::test] + async fn test_basic_redirect() { + let client = ClientBuilder::new() + .connector(crate::Connector::new()) + .wrap(Redirect::new().max_redirect_times(10)) + .finish(); + + let srv = start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::BadRequest()) + }))) + .service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test")) + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 400); + } + + #[actix_rt::test] + async fn test_redirect_limit() { + let client = ClientBuilder::new() + .wrap(Redirect::new().max_redirect_times(1)) + .connector(crate::Connector::new()) + .finish(); + + let srv = start(|| { + App::new() + .service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test")) + .finish(), + ) + }))) + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test2")) + .finish(), + ) + }))) + .service(web::resource("/test2").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::BadRequest()) + }))) + }); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 302); + } +} diff --git a/awc/src/response.rs b/awc/src/response.rs index 514b8a90b..40de3dc17 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -492,9 +492,7 @@ mod tests { JsonPayloadError::Payload(PayloadError::Overflow) => { matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) } - JsonPayloadError::ContentType => { - matches!(other, JsonPayloadError::ContentType) - } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), _ => false, } } diff --git a/src/types/form.rs b/src/types/form.rs index 0b5c3c1b4..4f3ecbe7c 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -400,12 +400,8 @@ mod tests { UrlencodedError::Overflow { .. } => { matches!(other, UrlencodedError::Overflow { .. }) } - UrlencodedError::UnknownLength => { - matches!(other, UrlencodedError::UnknownLength) - } - UrlencodedError::ContentType => { - matches!(other, UrlencodedError::ContentType) - } + UrlencodedError::UnknownLength => matches!(other, UrlencodedError::UnknownLength), + UrlencodedError::ContentType => matches!(other, UrlencodedError::ContentType), _ => false, } } diff --git a/src/types/json.rs b/src/types/json.rs index 31ff680f4..866d835f2 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -441,9 +441,7 @@ mod tests { fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { match err { JsonPayloadError::Overflow => matches!(other, JsonPayloadError::Overflow), - JsonPayloadError::ContentType => { - matches!(other, JsonPayloadError::ContentType) - } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), _ => false, } }