From d54c26ad309c35145144b684b6d0a40ae38e408c Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Sun, 3 Jul 2022 23:56:33 +0100 Subject: [PATCH] add redirect option to NormalizePath --- actix-web/src/middleware/normalize.rs | 202 ++++++++++++++++++++++---- 1 file changed, 174 insertions(+), 28 deletions(-) diff --git a/actix-web/src/middleware/normalize.rs b/actix-web/src/middleware/normalize.rs index 3ab908481..19eca89b4 100644 --- a/actix-web/src/middleware/normalize.rs +++ b/actix-web/src/middleware/normalize.rs @@ -1,17 +1,31 @@ //! For middleware documentation, see [`NormalizePath`]. -use actix_http::uri::{PathAndQuery, Uri}; +use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + use actix_service::{Service, Transform}; use actix_utils::future::{ready, Ready}; use bytes::Bytes; +use futures_core::ready; +use pin_project_lite::pin_project; use regex::Regex; use crate::{ + body::EitherBody, + http::{ + header, + uri::{PathAndQuery, Uri}, + StatusCode, + }, service::{ServiceRequest, ServiceResponse}, - Error, + Error, HttpResponse, }; -/// Determines the behavior of the [`NormalizePath`] middleware. +/// Determines the path rewriting behavior of the [`NormalizePath`] middleware. /// /// The default is `TrailingSlash::Trim`. #[non_exhaustive] @@ -86,7 +100,13 @@ impl Default for TrailingSlash { /// # }) /// ``` #[derive(Debug, Clone, Copy)] -pub struct NormalizePath(TrailingSlash); +pub struct NormalizePath { + /// Controls path normalization behavior. + trailing_slash_behavior: TrailingSlash, + + /// Returns redirects for non-normalized paths if `Some`. + use_redirects: Option, +} impl Default for NormalizePath { fn default() -> Self { @@ -95,14 +115,20 @@ impl Default for NormalizePath { in v4 from `Always` to `Trim`. Update your call to `NormalizePath::new(...)`." ); - Self(TrailingSlash::Trim) + Self { + trailing_slash_behavior: TrailingSlash::default(), + use_redirects: None, + } } } impl NormalizePath { /// Create new `NormalizePath` middleware with the specified trailing slash style. - pub fn new(trailing_slash_style: TrailingSlash) -> Self { - Self(trailing_slash_style) + pub fn new(behavior: TrailingSlash) -> Self { + Self { + trailing_slash_behavior: behavior, + use_redirects: None, + } } /// Constructs a new `NormalizePath` middleware with [trim](TrailingSlash::Trim) semantics. @@ -111,6 +137,32 @@ impl NormalizePath { pub fn trim() -> Self { Self::new(TrailingSlash::Trim) } + + /// Configures middleware to respond to requests with non-normalized paths with a 307 redirect. + /// + /// If configured + /// + /// For example, a request with the path `/api//v1/foo/` would receive a response with a + /// `Location: /api/v1/foo` header (assuming `Trim` trailing slash behavior.) + /// + /// To customize the status code, use [`use_redirects_with`](Self::use_redirects_with). + pub fn use_redirects(mut self) -> Self { + self.use_redirects = Some(StatusCode::TEMPORARY_REDIRECT); + self + } + + /// Configures middleware to respond to requests with non-normalized paths with a redirect. + /// + /// For example, a request with the path `/api//v1/foo/` would receive a 307 response with a + /// `Location: /api/v1/foo` header (assuming `Trim` trailing slash behavior.) + /// + /// # Panics + /// Panics if `status_code` is not a redirect (300-399). + pub fn use_redirects_with(mut self, status_code: StatusCode) -> Self { + assert!(status_code.is_redirection()); + self.use_redirects = Some(status_code); + self + } } impl Transform for NormalizePath @@ -118,35 +170,37 @@ where S: Service, Error = Error>, S::Future: 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse>; type Error = Error; - type Transform = NormalizePathNormalization; + type Transform = NormalizePathService; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ready(Ok(NormalizePathNormalization { + ready(Ok(NormalizePathService { service, merge_slash: Regex::new("//+").unwrap(), - trailing_slash_behavior: self.0, + trailing_slash_behavior: self.trailing_slash_behavior, + use_redirects: self.use_redirects, })) } } -pub struct NormalizePathNormalization { +pub struct NormalizePathService { service: S, merge_slash: Regex, trailing_slash_behavior: TrailingSlash, + use_redirects: Option, } -impl Service for NormalizePathNormalization +impl Service for NormalizePathService where S: Service, Error = Error>, S::Future: 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse>; type Error = Error; - type Future = S::Future; + type Future = NormalizePathFuture; actix_service::forward_ready!(service); @@ -189,7 +243,7 @@ where let query = parts.path_and_query.as_ref().and_then(|pq| pq.query()); let path = match query { - Some(q) => Bytes::from(format!("{}?{}", path, q)), + Some(query) => Bytes::from(format!("{}?{}", path, query)), None => Bytes::copy_from_slice(path.as_bytes()), }; parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap()); @@ -199,20 +253,87 @@ where req.head_mut().uri = uri; } } - self.service.call(req) + + match self.use_redirects { + Some(code) => { + let mut res = HttpResponse::with_body(code, ()); + res.headers_mut().insert( + header::LOCATION, + req.head_mut().uri.to_string().parse().unwrap(), + ); + NormalizePathFuture::redirect(req.into_response(res)) + } + + None => NormalizePathFuture::service(self.service.call(req)), + } + } +} + +pin_project! { + pub struct NormalizePathFuture, B> { + #[pin] inner: Inner, + } +} + +impl, B> NormalizePathFuture { + fn service(fut: S::Future) -> Self { + Self { + inner: Inner::Service { + fut, + _body: PhantomData, + }, + } + } + + fn redirect(res: ServiceResponse<()>) -> Self { + Self { + inner: Inner::Redirect { res: Some(res) }, + } + } +} + +pin_project! { + #[project = InnerProj] + enum Inner, B> { + Redirect { res: Option>, }, + Service { + #[pin] fut: S::Future, + _body: PhantomData, + }, + } +} + +impl Future for NormalizePathFuture +where + S: Service, Error = Error>, +{ + type Output = Result>, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match this.inner.project() { + InnerProj::Redirect { res } => { + Poll::Ready(Ok(res.take().unwrap().map_into_right_body())) + } + + InnerProj::Service { fut, .. } => { + let res = ready!(fut.poll(cx))?; + Poll::Ready(Ok(res.map_into_left_body())) + } + } } } #[cfg(test)] mod tests { - use actix_http::StatusCode; use actix_service::IntoService; use super::*; use crate::{ dev::ServiceRequest, guard::fn_guard, - test::{call_service, init_service, TestRequest}, + test::{self, call_service, init_service, TestRequest}, web, App, HttpResponse, }; @@ -256,7 +377,7 @@ mod tests { async fn trim_trailing_slashes() { let app = init_service( App::new() - .wrap(NormalizePath(TrailingSlash::Trim)) + .wrap(NormalizePath::new(TrailingSlash::Trim)) .service(web::resource("/").to(HttpResponse::Ok)) .service(web::resource("/v1/something").to(HttpResponse::Ok)) .service( @@ -292,11 +413,13 @@ mod tests { #[actix_rt::test] async fn trim_root_trailing_slashes_with_query() { let app = init_service( - App::new().wrap(NormalizePath(TrailingSlash::Trim)).service( - web::resource("/") - .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) - .to(HttpResponse::Ok), - ), + App::new() + .wrap(NormalizePath::new(TrailingSlash::Trim)) + .service( + web::resource("/") + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) + .to(HttpResponse::Ok), + ), ) .await; @@ -313,7 +436,7 @@ mod tests { async fn ensure_trailing_slash() { let app = init_service( App::new() - .wrap(NormalizePath(TrailingSlash::Always)) + .wrap(NormalizePath::new(TrailingSlash::Always)) .service(web::resource("/").to(HttpResponse::Ok)) .service(web::resource("/v1/something/").to(HttpResponse::Ok)) .service( @@ -350,7 +473,7 @@ mod tests { async fn ensure_root_trailing_slash_with_query() { let app = init_service( App::new() - .wrap(NormalizePath(TrailingSlash::Always)) + .wrap(NormalizePath::new(TrailingSlash::Always)) .service( web::resource("/") .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) @@ -372,7 +495,7 @@ mod tests { async fn keep_trailing_slash_unchanged() { let app = init_service( App::new() - .wrap(NormalizePath(TrailingSlash::MergeOnly)) + .wrap(NormalizePath::new(TrailingSlash::MergeOnly)) .service(web::resource("/").to(HttpResponse::Ok)) .service(web::resource("/v1/something").to(HttpResponse::Ok)) .service(web::resource("/v1/").to(HttpResponse::Ok)) @@ -486,4 +609,27 @@ mod tests { let res = normalize.call(req).await.unwrap(); assert!(res.status().is_success()); } + + #[actix_rt::test] + async fn should_return_redirects_when_configured() { + let normalize = NormalizePath::trim() + .use_redirects() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_uri("/v1/something/").to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); + + let normalize = NormalizePath::trim() + .use_redirects_with(StatusCode::PERMANENT_REDIRECT) + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_uri("/v1/something/").to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); + } }