diff --git a/Cargo.lock b/Cargo.lock index 8ba1b5b..a11f0cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3419,7 +3419,11 @@ dependencies = [ "awc", "clap", "env_logger", + "futures-util", "log", + "reqwest 0.11.13", + "tokio 1.23.0", + "tokio-stream", "url", ] @@ -5648,6 +5652,7 @@ dependencies = [ "serde_urlencoded", "tokio 1.23.0", "tokio-native-tls", + "tokio-util 0.7.4", "tower-service", "url", "wasm-bindgen", @@ -7290,6 +7295,7 @@ dependencies = [ "futures-core", "pin-project-lite 0.2.9", "tokio 1.23.0", + "tokio-util 0.7.4", ] [[package]] diff --git a/http-proxy/Cargo.toml b/http-proxy/Cargo.toml index 9efdc9c..32f7c88 100644 --- a/http-proxy/Cargo.toml +++ b/http-proxy/Cargo.toml @@ -9,5 +9,9 @@ awc = "3" clap = { version = "4", features = ["derive"] } env_logger.workspace = true +futures-util = { version = "0.3.17", default-features = false, features = ["std"] } log.workspace = true +reqwest = { version = "0.11", features = ["stream"] } +tokio = { version = "1.13.1", features = ["sync"] } +tokio-stream = { version = "0.1.3", features = ["sync"] } url = "2.2" diff --git a/http-proxy/src/main.rs b/http-proxy/src/main.rs index 83f874f..d95ed59 100644 --- a/http-proxy/src/main.rs +++ b/http-proxy/src/main.rs @@ -1,27 +1,39 @@ use std::net::ToSocketAddrs; -use actix_web::{error, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer}; +use actix_web::{ + dev::PeerAddr, error, http::Method, middleware, web, App, Error, HttpRequest, HttpResponse, + HttpServer, +}; use awc::Client; use clap::Parser; +use futures_util::StreamExt as _; +use tokio_stream::wrappers::UnboundedReceiverStream; use url::Url; +const REQWEST_PREFIX: &str = "/using-reqwest"; + +/// Forwards the incoming HTTP request using `awc`. async fn forward( req: HttpRequest, payload: web::Payload, + peer_addr: Option, url: web::Data, client: web::Data, ) -> Result { - let mut new_url = url.get_ref().clone(); + let mut new_url = (**url).clone(); new_url.set_path(req.uri().path()); new_url.set_query(req.uri().query()); - // TODO: This forwarded implementation is incomplete as it only handles the unofficial - // X-Forwarded-For header but not the official Forwarded one. let forwarded_req = client .request_from(new_url.as_str(), req.head()) .no_decompress(); - let forwarded_req = match req.head().peer_addr { - Some(addr) => forwarded_req.insert_header(("x-forwarded-for", format!("{}", addr.ip()))), + + // TODO: This forwarded implementation is incomplete as it only handles the unofficial + // X-Forwarded-For header but not the official Forwarded one. + let forwarded_req = match peer_addr { + Some(PeerAddr(addr)) => { + forwarded_req.insert_header(("x-forwarded-for", addr.ip().to_string())) + } None => forwarded_req, }; @@ -40,6 +52,59 @@ async fn forward( Ok(client_resp.streaming(res)) } +/// Same as `forward` but uses `reqwest` as the client used to forward the request. +async fn forward_reqwest( + req: HttpRequest, + mut payload: web::Payload, + method: Method, + peer_addr: Option, + url: web::Data, + client: web::Data, +) -> Result { + let path = req + .uri() + .path() + .strip_prefix(REQWEST_PREFIX) + .unwrap_or(req.uri().path()); + + let mut new_url = (**url).clone(); + new_url.set_path(path); + new_url.set_query(req.uri().query()); + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + actix_web::rt::spawn(async move { + while let Some(chunk) = payload.next().await { + tx.send(chunk).unwrap(); + } + }); + + let forwarded_req = client + .request(method, new_url) + .body(reqwest::Body::wrap_stream(UnboundedReceiverStream::new(rx))); + + // TODO: This forwarded implementation is incomplete as it only handles the unofficial + // X-Forwarded-For header but not the official Forwarded one. + let forwarded_req = match peer_addr { + Some(PeerAddr(addr)) => forwarded_req.header("x-forwarded-for", addr.ip().to_string()), + None => forwarded_req, + }; + + let res = forwarded_req + .send() + .await + .map_err(error::ErrorInternalServerError)?; + + let mut client_resp = HttpResponse::build(res.status()); + // Remove `Connection` as per + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection#Directives + for (header_name, header_value) in res.headers().iter().filter(|(h, _)| *h != "connection") { + client_resp.insert_header((header_name.clone(), header_value.clone())); + } + + Ok(client_resp.streaming(res.bytes_stream())) +} + #[derive(clap::Parser, Debug)] struct CliArguments { listen_addr: String, @@ -70,11 +135,15 @@ async fn main() -> std::io::Result<()> { log::info!("forwarding to {forward_url}"); + let reqwest_client = reqwest::Client::default(); + HttpServer::new(move || { App::new() .app_data(web::Data::new(Client::default())) + .app_data(web::Data::new(reqwest_client.clone())) .app_data(web::Data::new(forward_url.clone())) .wrap(middleware::Logger::default()) + .service(web::scope(REQWEST_PREFIX).default_service(web::to(forward_reqwest))) .default_service(web::to(forward)) }) .bind((args.listen_addr, args.listen_port))?