From 359d5d5c80dc893f4c75bf0bda20d332d3b306b9 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Mon, 6 Feb 2023 17:06:47 +0000 Subject: [PATCH] refactor codegen route guards --- actix-web-codegen/src/route.rs | 255 ++++++++++-------- actix-web-codegen/tests/test_macro.rs | 5 + .../tests/trybuild/route-custom-method.rs | 30 ++- 3 files changed, 168 insertions(+), 122 deletions(-) diff --git a/actix-web-codegen/src/route.rs b/actix-web-codegen/src/route.rs index 594a58626..717ac844c 100644 --- a/actix-web-codegen/src/route.rs +++ b/actix-web-codegen/src/route.rs @@ -6,11 +6,11 @@ use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{parse_macro_input, AttributeArgs, Ident, LitStr, Meta, NestedMeta, Path}; -macro_rules! method_type { +macro_rules! standard_method_type { ( $($variant:ident, $upper:ident, $lower:ident,)+ ) => { - #[derive(Debug, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum MethodType { $( $variant, @@ -27,13 +27,7 @@ macro_rules! method_type { fn parse(method: &str) -> Result { match method { $(stringify!($upper) => Ok(Self::$variant),)+ - _ => { - if method.chars().all(|c| c.is_ascii_uppercase()) { - Ok(Self::Method) - } else { - Err(format!("HTTP method must be uppercase: `{}`", method)) - } - }, + _ => Err(format!("HTTP method must be uppercase: `{}`", method)), } } @@ -47,13 +41,7 @@ macro_rules! method_type { }; } -#[derive(Eq, Hash, PartialEq)] -struct MethodTypeExt { - method: MethodType, - custom_method: Option, -} - -method_type! { +standard_method_type! { Get, GET, get, Post, POST, post, Put, PUT, put, @@ -63,7 +51,15 @@ method_type! { Options, OPTIONS, options, Trace, TRACE, trace, Patch, PATCH, patch, - Method, METHOD, method, +} + +impl TryFrom<&syn::LitStr> for MethodType { + type Error = syn::Error; + + fn try_from(value: &syn::LitStr) -> Result { + Self::parse(value.value().as_str()) + .map_err(|message| syn::Error::new_spanned(value, message)) + } } impl ToTokens for MethodType { @@ -73,27 +69,107 @@ impl ToTokens for MethodType { } } -impl ToTokens for MethodTypeExt { - fn to_tokens(&self, stream: &mut TokenStream2) { - match self.method { - MethodType::Method => { - let ident = Ident::new( - self.custom_method.as_ref().unwrap().value().as_str(), - Span::call_site(), - ); - stream.append(ident); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum MethodTypeExt { + Standard(MethodType), + Custom(LitStr), +} + +impl MethodTypeExt { + /// Returns a single method guard token stream. + fn to_tokens_single_guard(&self) -> TokenStream2 { + match self { + MethodTypeExt::Standard(method) => { + quote! { + .guard(::actix_web::guard::#method()) + } + } + MethodTypeExt::Custom(lit) => { + quote! { + .guard(::actix_web::guard::Method( + ::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap() + )) + } + } + } + } + + /// Returns a multi-method guard chain token stream. + fn to_tokens_multi_guard(&self, or_chain: Vec) -> TokenStream2 { + debug_assert!( + or_chain.len() > 0, + "empty or_chain passed to multi-guard constructor" + ); + + match self { + MethodTypeExt::Standard(method) => { + quote! { + .guard( + ::actix_web::guard::Any(::actix_web::guard::#method()) + #(#or_chain)* + ) + } + } + MethodTypeExt::Custom(lit) => { + quote! { + .guard( + ::actix_web::guard::Any( + ::actix_web::guard::Method( + ::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap() + ) + ) + #(#or_chain)* + ) + } + } + } + } + + /// Returns a token stream containing the `.or` chain to be passed in to + /// [`MethodTypeExt::to_tokens_multi_guard()`]. + fn to_tokens_multi_guard_or_chain(&self) -> TokenStream2 { + match self { + MethodTypeExt::Standard(method_type) => { + quote! { + .or(::actix_web::guard::#method_type()) + } + } + MethodTypeExt::Custom(lit) => { + quote! { + .or( + ::actix_web::guard::Method( + ::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap() + ) + ) + } } - _ => self.method.to_tokens(stream), } } } -impl TryFrom<&syn::LitStr> for MethodType { +impl ToTokens for MethodTypeExt { + fn to_tokens(&self, stream: &mut TokenStream2) { + match self { + MethodTypeExt::Custom(lit_str) => { + let ident = Ident::new(lit_str.value().as_str(), Span::call_site()); + stream.append(ident); + } + MethodTypeExt::Standard(method) => method.to_tokens(stream), + } + } +} + +impl TryFrom<&syn::LitStr> for MethodTypeExt { type Error = syn::Error; fn try_from(value: &syn::LitStr) -> Result { - Self::parse(value.value().as_str()) - .map_err(|message| syn::Error::new_spanned(value, message)) + match MethodType::try_from(value) { + Ok(method) => Ok(MethodTypeExt::Standard(method)), + Err(_) if value.value().chars().all(|c| c.is_ascii_uppercase()) => { + Ok(MethodTypeExt::Custom(value.clone())) + } + Err(err) => Err(err), + } } } @@ -127,12 +203,7 @@ impl Args { let is_route_macro = method.is_none(); if let Some(method) = method { - methods.insert({ - MethodTypeExt { - method, - custom_method: None, - } - }); + methods.insert(MethodTypeExt::Standard(method)); } for arg in args { @@ -149,6 +220,7 @@ impl Args { )); } }, + NestedMeta::Meta(syn::Meta::NameValue(nv)) => { if nv.path.is_ident("name") { if let syn::Lit::Str(lit) = nv.lit { @@ -184,23 +256,10 @@ impl Args { "HTTP method forbidden here. To handle multiple methods, use `route` instead", )); } else if let syn::Lit::Str(ref lit) = nv.lit { - let method = MethodType::try_from(lit)?; - if !methods.insert({ - if method == MethodType::Method { - MethodTypeExt { - method, - custom_method: Some(lit.clone()), - } - } else { - MethodTypeExt { - method, - custom_method: None, - } - } - }) { + if !methods.insert(MethodTypeExt::try_from(lit)?) { return Err(syn::Error::new_spanned( &nv.lit, - &format!( + format!( "HTTP method defined more than once: `{}`", lit.value() ), @@ -219,11 +278,13 @@ impl Args { )); } } + arg => { return Err(syn::Error::new_spanned(arg, "Unknown attribute.")); } } } + Ok(Args { path: path.unwrap(), resource_name, @@ -343,72 +404,34 @@ impl ToTokens for Route { .as_ref() .map_or_else(|| name.to_string(), LitStr::value); - let method_guards = { - let mut others = methods.iter(); - let first = others.next().unwrap(); - let first_method = &first.method; - if methods.len() > 1 { - let mut mult_method_guards: Vec = Vec::new(); - for method_ext in methods { - let method_type = &method_ext.method; - let custom_method = &method_ext.custom_method; - match custom_method { - Some(lit) => { - mult_method_guards.push(quote! { - .or(::actix_web::guard::#method_type(::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap())) - }); - } - None => { - mult_method_guards.push(quote! { - .or(::actix_web::guard::#method_type()) - }); - } - } - } - match &first.custom_method { - Some(lit) => { - quote! { - .guard( - ::actix_web::guard::Any(::actix_web::guard::#first_method(::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap())) - #(#mult_method_guards)* - ) - } - } - None => { - quote! { - .guard( - ::actix_web::guard::Any(::actix_web::guard::#first_method()) - #(#mult_method_guards)* - ) - } - } - } - } else { - match &first.custom_method { - Some(lit) => { - quote! { - .guard(::actix_web::guard::#first_method(::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap())) - } - } - None => { - quote! { - .guard(::actix_web::guard::#first_method()) - } - } - } - } - }; - quote! { - let __resource = ::actix_web::Resource::new(#path) - .name(#resource_name) - #method_guards - #(.guard(::actix_web::guard::fn_guard(#guards)))* - #(.wrap(#wrappers))* - .to(#name); - ::actix_web::dev::HttpServiceFactory::register(__resource, __config); + let method_guards = { + debug_assert!(methods.len() > 0, "Args::methods should not be empty"); + + let mut others = methods.iter(); + let first = others.next().unwrap(); + + if methods.len() > 1 { + let other_method_guards = others + .map(|method_ext| method_ext.to_tokens_multi_guard_or_chain()) + .collect(); + + first.to_tokens_multi_guard(other_method_guards) + } else { + first.to_tokens_single_guard() } - }) - .collect(); + }; + + quote! { + let __resource = ::actix_web::Resource::new(#path) + .name(#resource_name) + #method_guards + #(.guard(::actix_web::guard::fn_guard(#guards)))* + #(.wrap(#wrappers))* + .to(#name); + ::actix_web::dev::HttpServiceFactory::register(__resource, __config); + } + }) + .collect(); let stream = quote! { #(#doc_attributes)* diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs index a95d2aa37..f28654cd9 100644 --- a/actix-web-codegen/tests/test_macro.rs +++ b/actix-web-codegen/tests/test_macro.rs @@ -86,6 +86,11 @@ async fn get_param_test(_: web::Path) -> impl Responder { HttpResponse::Ok() } +#[route("/hello", method = "HELLO")] +async fn custom_route_test() -> impl Responder { + HttpResponse::Ok() +} + #[route( "/multi", method = "GET", diff --git a/actix-web-codegen/tests/trybuild/route-custom-method.rs b/actix-web-codegen/tests/trybuild/route-custom-method.rs index 6cc6af5a8..525a60b83 100644 --- a/actix-web-codegen/tests/trybuild/route-custom-method.rs +++ b/actix-web-codegen/tests/trybuild/route-custom-method.rs @@ -1,19 +1,37 @@ -use actix_web_codegen::*; -use actix_web::http::Method; use std::str::FromStr; -#[route("/", method="UNEXPECTED")] +use actix_web::http::Method; +use actix_web_codegen::route; + +#[route("/single", method = "CUSTOM")] async fn index() -> String { - "Hello World!".to_owned() + "Hello Single!".to_owned() +} + +#[route("/multi", method = "GET", method = "CUSTOM")] +async fn custom() -> String { + "Hello Multi!".to_owned() } #[actix_web::main] async fn main() { use actix_web::App; - let srv = actix_test::start(|| App::new().service(index)); + let srv = actix_test::start(|| App::new().service(index).service(custom)); - let request = srv.request(Method::from_str("UNEXPECTED").unwrap(), srv.url("/")); + let request = srv.request(Method::GET, srv.url("/")); + let response = request.send().await.unwrap(); + assert!(response.status().is_client_error()); + + let request = srv.request(Method::from_str("CUSTOM").unwrap(), srv.url("/single")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(Method::GET, srv.url("/multi")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(Method::from_str("CUSTOM").unwrap(), srv.url("/multi")); let response = request.send().await.unwrap(); assert!(response.status().is_success()); }