use std::net::ToSocketAddrs;

use actix_web::{
    dev::PeerAddr, error, http::Method, middleware, web, App, Error, HttpRequest, HttpResponse,
    HttpServer,
};
use awc::Client;
use clap::Parser;
use futures_util::StreamExt as _;
use tokio_stream::wrappers::UnboundedReceiverStream;
use url::Url;

/// Route prefix under which requests are forwarded with `reqwest` instead of `awc`.
const REQWEST_PREFIX: &str = "/using-reqwest";

/// Forwards the incoming HTTP request using `awc`.
///
/// The target URL is the configured forwarding base `url` with the path and query string of
/// the incoming request applied to it. The incoming payload is streamed to the upstream server,
/// and the upstream status, headers (except `Connection`) and body are streamed back to the
/// caller.
///
/// When the peer address is known, its IP is sent upstream in the `x-forwarded-for` header.
///
/// # Errors
///
/// Returns a 500 Internal Server Error if sending the request to the upstream server fails.
async fn forward(
    req: HttpRequest,
    payload: web::Payload,
    peer_addr: Option<PeerAddr>,
    url: web::Data<Url>,
    client: web::Data<Client>,
) -> Result<HttpResponse, Error> {
    // Start from the shared base URL and overlay the request's own path and query.
    let mut new_url = (**url).clone();
    new_url.set_path(req.uri().path());
    new_url.set_query(req.uri().query());

    // `request_from` builds the upstream request from the incoming request head;
    // `no_decompress` leaves the upstream body untouched so it can be relayed as received.
    let forwarded_req = client
        .request_from(new_url.as_str(), req.head())
        .no_decompress();

    // TODO: This forwarded implementation is incomplete as it only handles the unofficial
    // X-Forwarded-For header but not the official Forwarded one.
    let forwarded_req = match peer_addr {
        Some(PeerAddr(addr)) => {
            forwarded_req.insert_header(("x-forwarded-for", addr.ip().to_string()))
        }
        None => forwarded_req,
    };

    let res = forwarded_req
        .send_stream(payload)
        .await
        .map_err(error::ErrorInternalServerError)?;

    let mut client_resp = HttpResponse::build(res.status());
    // Remove `Connection` as per
    // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection#Directives
    for (header_name, header_value) in res.headers().iter().filter(|(h, _)| *h != "connection") {
        client_resp.insert_header((header_name.clone(), header_value.clone()));
    }

    Ok(client_resp.streaming(res))
}

/// Same as `forward` but uses `reqwest` as the client used to forward the request.
async fn forward_reqwest( req: HttpRequest, mut payload: web::Payload, method: Method, peer_addr: Option, url: web::Data, client: web::Data, ) -> Result { let path = req .uri() .path() .strip_prefix(REQWEST_PREFIX) .unwrap_or(req.uri().path()); let mut new_url = (**url).clone(); new_url.set_path(path); new_url.set_query(req.uri().query()); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); actix_web::rt::spawn(async move { while let Some(chunk) = payload.next().await { tx.send(chunk).unwrap(); } }); let forwarded_req = client .request(method, new_url) .body(reqwest::Body::wrap_stream(UnboundedReceiverStream::new(rx))); // TODO: This forwarded implementation is incomplete as it only handles the unofficial // X-Forwarded-For header but not the official Forwarded one. let forwarded_req = match peer_addr { Some(PeerAddr(addr)) => forwarded_req.header("x-forwarded-for", addr.ip().to_string()), None => forwarded_req, }; let res = forwarded_req .send() .await .map_err(error::ErrorInternalServerError)?; let mut client_resp = HttpResponse::build(res.status()); // Remove `Connection` as per // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection#Directives for (header_name, header_value) in res.headers().iter().filter(|(h, _)| *h != "connection") { client_resp.insert_header((header_name.clone(), header_value.clone())); } Ok(client_resp.streaming(res.bytes_stream())) } #[derive(clap::Parser, Debug)] struct CliArguments { listen_addr: String, listen_port: u16, forward_addr: String, forward_port: u16, } #[actix_web::main] async fn main() -> std::io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); let args = CliArguments::parse(); let forward_socket_addr = (args.forward_addr, args.forward_port) .to_socket_addrs()? 
.next() .expect("given forwarding address was not valid"); let forward_url = format!("http://{forward_socket_addr}"); let forward_url = Url::parse(&forward_url).unwrap(); log::info!( "starting HTTP server at http://{}:{}", &args.listen_addr, args.listen_port ); log::info!("forwarding to {forward_url}"); let reqwest_client = reqwest::Client::default(); HttpServer::new(move || { App::new() .app_data(web::Data::new(Client::default())) .app_data(web::Data::new(reqwest_client.clone())) .app_data(web::Data::new(forward_url.clone())) .wrap(middleware::Logger::default()) .service(web::scope(REQWEST_PREFIX).default_service(web::to(forward_reqwest))) .default_service(web::to(forward)) }) .bind((args.listen_addr, args.listen_port))? .workers(2) .run() .await }