diff --git a/actix-router/src/de.rs b/actix-router/src/de.rs index ec7b1066..27aa49ef 100644 --- a/actix-router/src/de.rs +++ b/actix-router/src/de.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use serde::de::{self, Deserializer, Error as DeError, Visitor}; use serde::forward_to_deserialize_any; @@ -20,7 +22,7 @@ macro_rules! unsupported_type { } macro_rules! parse_single_value { - ($trait_fn:ident, $visit_fn:ident, $tp:expr) => { + ($trait_fn:ident) => { fn $trait_fn(self, visitor: V) -> Result where V: Visitor<'de>, @@ -34,15 +36,10 @@ macro_rules! parse_single_value { .as_str(), )) } else { - let decoded = FULL_QUOTER - .with(|q| q.requote(self.path[0].as_bytes())) - .unwrap_or_else(|| self.path[0].to_owned()); - - let v = decoded.parse().map_err(|_| { - de::Error::custom(format!("can not parse {:?} to a {}", &self.path[0], $tp)) - })?; - - visitor.$visit_fn(v) + Value { + value: &self.path[0], + } + .$trait_fn(visitor) } } }; @@ -56,7 +53,8 @@ macro_rules! parse_value { { let decoded = FULL_QUOTER .with(|q| q.requote(self.value.as_bytes())) - .unwrap_or_else(|| self.value.to_owned()); + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed(self.value)); let v = decoded.parse().map_err(|_| { de::value::Error::custom(format!("can not parse {:?} to a {}", self.value, $tp)) @@ -204,26 +202,26 @@ impl<'de, T: ResourcePath + 'de> Deserializer<'de> for PathDeserializer<'de, T> } unsupported_type!(deserialize_any, "'any'"); - unsupported_type!(deserialize_bytes, "bytes"); unsupported_type!(deserialize_option, "Option"); unsupported_type!(deserialize_identifier, "identifier"); unsupported_type!(deserialize_ignored_any, "ignored_any"); - parse_single_value!(deserialize_bool, visit_bool, "bool"); - parse_single_value!(deserialize_i8, visit_i8, "i8"); - parse_single_value!(deserialize_i16, visit_i16, "i16"); - parse_single_value!(deserialize_i32, visit_i32, "i32"); - parse_single_value!(deserialize_i64, visit_i64, "i64"); - parse_single_value!(deserialize_u8, visit_u8, "u8"); - parse_single_value!(deserialize_u16, visit_u16, "u16"); - parse_single_value!(deserialize_u32, visit_u32, "u32"); - parse_single_value!(deserialize_u64, visit_u64, "u64"); - parse_single_value!(deserialize_f32, visit_f32, "f32"); - parse_single_value!(deserialize_f64, visit_f64, "f64"); - parse_single_value!(deserialize_str, visit_string, "String"); - parse_single_value!(deserialize_string, visit_string, "String"); - parse_single_value!(deserialize_byte_buf, visit_string, "String"); - parse_single_value!(deserialize_char, visit_char, "char"); + parse_single_value!(deserialize_bool); + parse_single_value!(deserialize_i8); + parse_single_value!(deserialize_i16); + parse_single_value!(deserialize_i32); + parse_single_value!(deserialize_i64); + parse_single_value!(deserialize_u8); + parse_single_value!(deserialize_u16); + parse_single_value!(deserialize_u32); + parse_single_value!(deserialize_u64); + parse_single_value!(deserialize_f32); + parse_single_value!(deserialize_f64); + parse_single_value!(deserialize_str); + parse_single_value!(deserialize_string); + parse_single_value!(deserialize_bytes); + parse_single_value!(deserialize_byte_buf); + parse_single_value!(deserialize_char); } struct ParamsDeserializer<'de, T: ResourcePath> { @@ -303,8 +301,6 @@ impl<'de> Deserializer<'de> for Value<'de> { parse_value!(deserialize_u64, visit_u64, "u64"); parse_value!(deserialize_f32, visit_f32, "f32"); parse_value!(deserialize_f64, visit_f64, "f64"); - parse_value!(deserialize_string, visit_string, "String"); - parse_value!(deserialize_byte_buf, visit_string, "String"); parse_value!(deserialize_char, visit_char, "char"); fn deserialize_ignored_any(self, visitor: V) -> Result @@ -332,18 +328,38 @@ impl<'de> Deserializer<'de> for Value<'de> { visitor.visit_unit() } - fn deserialize_bytes(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_borrowed_bytes(self.value.as_bytes()) - } - fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, { - visitor.visit_borrowed_str(self.value) + match FULL_QUOTER.with(|q| q.requote(self.value.as_bytes())) { + Some(s) => visitor.visit_string(s), + None => visitor.visit_borrowed_str(self.value), + } + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match FULL_QUOTER.with(|q| q.requote(self.value.as_bytes())) { + Some(s) => visitor.visit_byte_buf(s.into()), + None => visitor.visit_borrowed_bytes(self.value.as_bytes()), + } + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_bytes(visitor) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(visitor) } fn deserialize_option(self, visitor: V) -> Result @@ -671,12 +687,12 @@ mod tests { fn deserialize_path_decode_seq() { let rdef = ResourceDef::new("/{key}/{value}"); - let mut path = Path::new("/%25/%2F"); + let mut path = Path::new("/%30%25/%30%2F"); rdef.capture_match_info(&mut path); let de = PathDeserializer::new(&path); let segment: (String, String) = serde::Deserialize::deserialize(de).unwrap(); - assert_eq!(segment.0, "%"); - assert_eq!(segment.1, "/"); + assert_eq!(segment.0, "0%"); + assert_eq!(segment.1, "0/"); } #[test] @@ -697,6 +713,32 @@ mod tests { assert_eq!(vals.value, "/"); } + #[test] + fn deserialize_borrowed() { + #[derive(Debug, Deserialize)] + struct Params<'a> { + val: &'a str, + } + + let rdef = ResourceDef::new("/{val}"); + + let mut path = Path::new("/X"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + let params: Params<'_> = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(params.val, "X"); + let de = PathDeserializer::new(&path); + let params: &str = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(params, "X"); + + let mut path = Path::new("/%2F"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + assert!( as serde::Deserialize>::deserialize(de).is_err()); + let de = PathDeserializer::new(&path); + assert!(<&str as serde::Deserialize>::deserialize(de).is_err()); + } + // #[test] // fn test_extract_path_decode() { // let mut router = Router::<()>::default();