use std::{ future::Future, net, pin::Pin, rc::Rc, task::{Context, Poll}, time::Duration, }; use actix_http::{ body::{Body, BodyStream}, http::{ header::{self, HeaderMap, HeaderName, IntoHeaderValue}, Error as HttpError, }, Error, RequestHead, RequestHeadType, }; use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; use derive_more::From; use futures_core::Stream; use serde::Serialize; #[cfg(feature = "compress")] use actix_http::encoding::Decoder; #[cfg(feature = "compress")] use actix_http::http::header::ContentEncoding; #[cfg(feature = "compress")] use actix_http::{Payload, PayloadStream}; use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use crate::response::ClientResponse; use crate::ClientConfig; #[derive(Debug, From)] pub(crate) enum PrepForSendingError { Url(InvalidUrl), Http(HttpError), } impl From for FreezeRequestError { fn from(err: PrepForSendingError) -> FreezeRequestError { match err { PrepForSendingError::Url(e) => FreezeRequestError::Url(e), PrepForSendingError::Http(e) => FreezeRequestError::Http(e), } } } impl From for SendRequestError { fn from(err: PrepForSendingError) -> SendRequestError { match err { PrepForSendingError::Url(e) => SendRequestError::Url(e), PrepForSendingError::Http(e) => SendRequestError::Http(e), } } } /// Future that sends request's payload and resolves to a server response. #[must_use = "futures do nothing unless polled"] pub enum SendClientRequest { Fut( Pin>>>, // FIXME: use a pinned Sleep instead of box. Option>>, bool, ), Err(Option), } impl SendClientRequest { pub(crate) fn new( send: Pin>>>, response_decompress: bool, timeout: Option, ) -> SendClientRequest { let delay = timeout.map(|d| Box::pin(sleep(d))); SendClientRequest::Fut(send, delay, response_decompress) } } #[cfg(feature = "compress")] impl Future for SendClientRequest { type Output = Result>>, SendRequestError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, response_decompress) => { if delay.is_some() { match Pin::new(delay.as_mut().unwrap()).poll(cx) { Poll::Pending => {} _ => return Poll::Ready(Err(SendRequestError::Timeout)), } } let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { res.map_body(|head, payload| { if *response_decompress { Payload::Stream(Decoder::from_headers(payload, &head.headers)) } else { Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) } }) }); Poll::Ready(res) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), None => panic!("Attempting to call completed future"), }, } } } #[cfg(not(feature = "compress"))] impl Future for SendClientRequest { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, _) => { if delay.is_some() { match Pin::new(delay.as_mut().unwrap()).poll(cx) { Poll::Pending => {} _ => return Poll::Ready(Err(SendRequestError::Timeout)), } } Pin::new(send).poll(cx) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), None => panic!("Attempting to call completed future"), }, } } } impl From for SendClientRequest { fn from(e: SendRequestError) -> Self { SendClientRequest::Err(Some(e)) } } impl From for SendClientRequest { fn from(e: Error) -> Self { SendClientRequest::Err(Some(e.into())) } } impl From for SendClientRequest { fn from(e: HttpError) -> Self { SendClientRequest::Err(Some(e.into())) } } impl From for SendClientRequest { fn from(e: PrepForSendingError) -> Self { SendClientRequest::Err(Some(e.into())) } } #[derive(Debug)] pub(crate) enum RequestSender { Owned(RequestHead), Rc(Rc, Option), } impl RequestSender { pub(crate) fn send_body( self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, body: B, ) -> SendClientRequest where B: Into, { let fut = match self { RequestSender::Owned(head) => { config .connector .send_request(RequestHeadType::Owned(head), body.into(), addr) } RequestSender::Rc(head, extra_headers) => config.connector.send_request( RequestHeadType::Rc(head, extra_headers), body.into(), addr, ), }; SendClientRequest::new(fut, response_decompress, timeout.or(config.timeout)) } pub(crate) fn send_json( mut self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, value: &T, ) -> SendClientRequest { let body = match serde_json::to_string(value) { Ok(body) => body, Err(e) => return Error::from(e).into(), }; if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") { return e.into(); } self.send_body( addr, response_decompress, timeout, config, Body::Bytes(Bytes::from(body)), ) } pub(crate) fn send_form( mut self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, value: &T, ) -> SendClientRequest { let body = match serde_urlencoded::to_string(value) { Ok(body) => body, Err(e) => return Error::from(e).into(), }; // set content-type if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/x-www-form-urlencoded") { return e.into(); } self.send_body( addr, response_decompress, timeout, config, Body::Bytes(Bytes::from(body)), ) } pub(crate) fn send_stream( self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, stream: S, ) -> SendClientRequest where S: Stream> + Unpin + 'static, E: Into + 'static, { self.send_body( addr, response_decompress, timeout, config, Body::from_message(BodyStream::new(stream)), ) } pub(crate) fn send( self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, ) -> SendClientRequest { self.send_body(addr, response_decompress, timeout, config, Body::Empty) } fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> where V: IntoHeaderValue, { match self { RequestSender::Owned(head) => { if !head.headers.contains_key(&key) { match value.try_into_value() { Ok(value) => { head.headers.insert(key, value); } Err(e) => return Err(e.into()), } } } RequestSender::Rc(head, extra_headers) => { if !head.headers.contains_key(&key) && !extra_headers.iter().any(|h| h.contains_key(&key)) { match value.try_into_value() { Ok(v) => { let h = extra_headers.get_or_insert(HeaderMap::new()); h.insert(key, v) } Err(e) => return Err(e.into()), }; } } } Ok(()) } }