diff --git a/Cargo.toml b/Cargo.toml index 59f101e48..e43712543 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,7 +67,6 @@ rand = "0.5" regex = "1.0" serde = "1.0" serde_json = "1.0" -serde_urlencoded = "0.5" sha1 = "0.6" smallvec = "0.6" time = "0.1" @@ -105,6 +104,10 @@ tokio-tls = { version="0.1", optional = true } openssl = { version="0.10", optional = true } tokio-openssl = { version="0.2", optional = true } +# forked url_encoded +itoa = "0.4" +dtoa = "0.4" + [dev-dependencies] env_logger = "0.5" serde_derive = "1.0" diff --git a/src/extractor.rs b/src/extractor.rs index 683e1526d..bebfbf207 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -139,15 +139,24 @@ impl fmt::Display for Path { /// #[macro_use] extern crate serde_derive; /// use actix_web::{App, Query, http}; /// -/// #[derive(Deserialize)] -/// struct Info { -/// username: String, -/// } +/// +///#[derive(Debug, Deserialize)] +///pub enum ResponseType { +/// Token, +/// Code +///} +/// +///#[derive(Deserialize)] +///pub struct AuthRequest { +/// id: u64, +/// response_type: ResponseType, +///} /// /// // use `with` extractor for query info /// // this handler get called only if request's query contains `username` field -/// fn index(info: Query) -> String { -/// format!("Welcome {}!", info.username) +/// // The correct request for this handler would be `/index.html?id=64&response_type=Code"` +/// fn index(info: Query) -> String { +/// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) /// } /// /// fn main() { diff --git a/src/lib.rs b/src/lib.rs index a1a09982a..d61c94f35 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -133,7 +133,6 @@ extern crate num_cpus; #[macro_use] extern crate percent_encoding; extern crate serde_json; -extern crate serde_urlencoded; extern crate smallvec; #[macro_use] extern crate actix as actix_inner; @@ -152,6 +151,7 @@ extern crate openssl; #[cfg(feature = "openssl")] extern crate tokio_openssl; +mod serde_urlencoded; mod application; mod body; mod context; diff --git a/src/serde_urlencoded/de.rs b/src/serde_urlencoded/de.rs new file mode 100644 index 000000000..affbfb37e --- /dev/null +++ b/src/serde_urlencoded/de.rs @@ -0,0 +1,298 @@ +//! Deserialization support for the `application/x-www-form-urlencoded` format. + +use serde::de::{self, DeserializeSeed, EnumAccess, IntoDeserializer, VariantAccess, Visitor}; +use serde::de::Error as de_Error; + +use serde::de::value::MapDeserializer; +use std::borrow::Cow; +use std::io::Read; +use url::form_urlencoded::Parse as UrlEncodedParse; +use url::form_urlencoded::parse; + +#[doc(inline)] +pub use serde::de::value::Error; + +/// Deserializes a `application/x-wwww-url-encoded` value from a `&[u8]`. +/// +/// ```ignore +/// let meal = vec![ +/// ("bread".to_owned(), "baguette".to_owned()), +/// ("cheese".to_owned(), "comté".to_owned()), +/// ("meat".to_owned(), "ham".to_owned()), +/// ("fat".to_owned(), "butter".to_owned()), +/// ]; +/// +/// assert_eq!( +/// serde_urlencoded::from_bytes::>( +/// b"bread=baguette&cheese=comt%C3%A9&meat=ham&fat=butter"), +/// Ok(meal)); +/// ``` +pub fn from_bytes<'de, T>(input: &'de [u8]) -> Result + where T: de::Deserialize<'de>, +{ + T::deserialize(Deserializer::new(parse(input))) +} + +/// Deserializes a `application/x-wwww-url-encoded` value from a `&str`. +/// +/// ```ignore +/// let meal = vec![ +/// ("bread".to_owned(), "baguette".to_owned()), +/// ("cheese".to_owned(), "comté".to_owned()), +/// ("meat".to_owned(), "ham".to_owned()), +/// ("fat".to_owned(), "butter".to_owned()), +/// ]; +/// +/// assert_eq!( +/// serde_urlencoded::from_str::>( +/// "bread=baguette&cheese=comt%C3%A9&meat=ham&fat=butter"), +/// Ok(meal)); +/// ``` +pub fn from_str<'de, T>(input: &'de str) -> Result + where T: de::Deserialize<'de>, +{ + from_bytes(input.as_bytes()) +} + +#[allow(dead_code)] +/// Convenience function that reads all bytes from `reader` and deserializes +/// them with `from_bytes`. +pub fn from_reader(mut reader: R) -> Result + where T: de::DeserializeOwned, + R: Read, +{ + let mut buf = vec![]; + reader.read_to_end(&mut buf) + .map_err(|e| { + de::Error::custom(format_args!("could not read input: {}", e)) + })?; + from_bytes(&buf) +} + +/// A deserializer for the `application/x-www-form-urlencoded` format. +/// +/// * Supported top-level outputs are structs, maps and sequences of pairs, +/// with or without a given length. +/// +/// * Main `deserialize` methods defers to `deserialize_map`. +/// +/// * Everything else but `deserialize_seq` and `deserialize_seq_fixed_size` +/// defers to `deserialize`. +pub struct Deserializer<'de> { + inner: MapDeserializer<'de, PartIterator<'de>, Error>, +} + +impl<'de> Deserializer<'de> { + /// Returns a new `Deserializer`. + pub fn new(parser: UrlEncodedParse<'de>) -> Self { + Deserializer { + inner: MapDeserializer::new(PartIterator(parser)), + } + } +} + +impl<'de> de::Deserializer<'de> for Deserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where V: de::Visitor<'de>, + { + self.deserialize_map(visitor) + } + + fn deserialize_map(self, visitor: V) -> Result + where V: de::Visitor<'de>, + { + visitor.visit_map(self.inner) + } + + fn deserialize_seq(self, visitor: V) -> Result + where V: de::Visitor<'de>, + { + visitor.visit_seq(self.inner) + } + + fn deserialize_unit(self, visitor: V) -> Result + where V: de::Visitor<'de>, + { + self.inner.end()?; + visitor.visit_unit() + } + + forward_to_deserialize_any! { + bool + u8 + u16 + u32 + u64 + i8 + i16 + i32 + i64 + f32 + f64 + char + str + string + option + bytes + byte_buf + unit_struct + newtype_struct + tuple_struct + struct + identifier + tuple + enum + ignored_any + } +} + +struct PartIterator<'de>(UrlEncodedParse<'de>); + +impl<'de> Iterator for PartIterator<'de> { + type Item = (Part<'de>, Part<'de>); + + fn next(&mut self) -> Option { + self.0.next().map(|(k, v)| (Part(k), Part(v))) + } +} + +struct Part<'de>(Cow<'de, str>); + +impl<'de> IntoDeserializer<'de> for Part<'de> +{ + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +macro_rules! forward_parsed_value { + ($($ty:ident => $method:ident,)*) => { + $( + fn $method(self, visitor: V) -> Result + where V: de::Visitor<'de> + { + match self.0.parse::<$ty>() { + Ok(val) => val.into_deserializer().$method(visitor), + Err(e) => Err(de::Error::custom(e)) + } + } + )* + } +} + +impl<'de> de::Deserializer<'de> for Part<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where V: de::Visitor<'de>, + { + self.0.into_deserializer().deserialize_any(visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where V: de::Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_enum(self, _name: &'static str, _variants: &'static [&'static str], visitor: V) -> Result + where V: de::Visitor<'de>, + { + visitor.visit_enum(ValueEnumAccess { value: self.0 }) + } + + forward_to_deserialize_any! { + char + str + string + unit + bytes + byte_buf + unit_struct + newtype_struct + tuple_struct + struct + identifier + tuple + ignored_any + seq + map + } + + forward_parsed_value! { + bool => deserialize_bool, + u8 => deserialize_u8, + u16 => deserialize_u16, + u32 => deserialize_u32, + u64 => deserialize_u64, + i8 => deserialize_i8, + i16 => deserialize_i16, + i32 => deserialize_i32, + i64 => deserialize_i64, + f32 => deserialize_f32, + f64 => deserialize_f64, + } +} + +/// Provides access to a keyword which can be deserialized into an enum variant. The enum variant +/// must be a unit variant, otherwise deserialization will fail. +struct ValueEnumAccess<'de> { + value: Cow<'de, str>, +} + +impl<'de> EnumAccess<'de> for ValueEnumAccess<'de> { + type Error = Error; + type Variant = UnitOnlyVariantAccess; + + fn variant_seed( + self, + seed: V, + ) -> Result<(V::Value, Self::Variant), Self::Error> + where V: DeserializeSeed<'de>, + { + let variant = seed.deserialize(self.value.into_deserializer())?; + Ok((variant, UnitOnlyVariantAccess)) + } +} + +/// A visitor for deserializing the contents of the enum variant. As we only support +/// `unit_variant`, all other variant types will return an error. +struct UnitOnlyVariantAccess; + +impl<'de> VariantAccess<'de> for UnitOnlyVariantAccess { + type Error = Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where T: DeserializeSeed<'de>, + { + Err(Error::custom("expected unit variant")) + } + + fn tuple_variant( + self, + _len: usize, + _visitor: V, + ) -> Result + where V: Visitor<'de>, + { + Err(Error::custom("expected unit variant")) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where V: Visitor<'de>, + { + Err(Error::custom("expected unit variant")) + } +} diff --git a/src/serde_urlencoded/mod.rs b/src/serde_urlencoded/mod.rs new file mode 100644 index 000000000..6ee62c2a9 --- /dev/null +++ b/src/serde_urlencoded/mod.rs @@ -0,0 +1,114 @@ +//! `x-www-form-urlencoded` meets Serde + +extern crate itoa; +extern crate dtoa; + +pub mod de; +pub mod ser; + +#[doc(inline)] +pub use self::de::{Deserializer, from_bytes, from_reader, from_str}; +#[doc(inline)] +pub use self::ser::{Serializer, to_string}; + +#[cfg(test)] +mod tests { + #[test] + fn deserialize_bytes() { + let result = vec![("first".to_owned(), 23), ("last".to_owned(), 42)]; + + assert_eq!(super::from_bytes(b"first=23&last=42"), + Ok(result)); + } + + #[test] + fn deserialize_str() { + let result = vec![("first".to_owned(), 23), ("last".to_owned(), 42)]; + + assert_eq!(super::from_str("first=23&last=42"), + Ok(result)); + } + + #[test] + fn deserialize_reader() { + let result = vec![("first".to_owned(), 23), ("last".to_owned(), 42)]; + + assert_eq!(super::from_reader(b"first=23&last=42" as &[_]), + Ok(result)); + } + + #[test] + fn deserialize_option() { + let result = vec![ + ("first".to_owned(), Some(23)), + ("last".to_owned(), Some(42)), + ]; + assert_eq!(super::from_str("first=23&last=42"), Ok(result)); + } + + #[test] + fn deserialize_unit() { + assert_eq!(super::from_str(""), Ok(())); + assert_eq!(super::from_str("&"), Ok(())); + assert_eq!(super::from_str("&&"), Ok(())); + assert!(super::from_str::<()>("first=23").is_err()); + } + + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] + enum X { + A, + B, + C, + } + + #[test] + fn deserialize_unit_enum() { + let result = vec![ + ("one".to_owned(), X::A), + ("two".to_owned(), X::B), + ("three".to_owned(), X::C) + ]; + + assert_eq!(super::from_str("one=A&two=B&three=C"), Ok(result)); + } + + #[test] + fn serialize_option_map_int() { + let params = &[("first", Some(23)), ("middle", None), ("last", Some(42))]; + + assert_eq!(super::to_string(params), + Ok("first=23&last=42".to_owned())); + } + + #[test] + fn serialize_option_map_string() { + let params = + &[("first", Some("hello")), ("middle", None), ("last", Some("world"))]; + + assert_eq!(super::to_string(params), + Ok("first=hello&last=world".to_owned())); + } + + #[test] + fn serialize_option_map_bool() { + let params = &[("one", Some(true)), ("two", Some(false))]; + + assert_eq!(super::to_string(params), + Ok("one=true&two=false".to_owned())); + } + + #[test] + fn serialize_map_bool() { + let params = &[("one", true), ("two", false)]; + + assert_eq!(super::to_string(params), + Ok("one=true&two=false".to_owned())); + } + + #[test] + fn serialize_unit_enum() { + let params = &[("one", X::A), ("two", X::B), ("three", X::C)]; + assert_eq!(super::to_string(params), + Ok("one=A&two=B&three=C".to_owned())); + } +} diff --git a/src/serde_urlencoded/ser/key.rs b/src/serde_urlencoded/ser/key.rs new file mode 100644 index 000000000..2d138f18e --- /dev/null +++ b/src/serde_urlencoded/ser/key.rs @@ -0,0 +1,76 @@ +use super::super::ser::Error; +use super::super::ser::part::Sink; +use serde::Serialize; +use std::borrow::Cow; +use std::ops::Deref; + +pub enum Key<'key> { + Static(&'static str), + Dynamic(Cow<'key, str>), +} + +impl<'key> Deref for Key<'key> { + type Target = str; + + fn deref(&self) -> &str { + match *self { + Key::Static(key) => key, + Key::Dynamic(ref key) => key, + } + } +} + +impl<'key> From> for Cow<'static, str> { + fn from(key: Key<'key>) -> Self { + match key { + Key::Static(key) => key.into(), + Key::Dynamic(key) => key.into_owned().into(), + } + } +} + +pub struct KeySink { + end: End, +} + +impl KeySink + where End: for<'key> FnOnce(Key<'key>) -> Result +{ + pub fn new(end: End) -> Self { + KeySink { end: end } + } +} + +impl Sink for KeySink + where End: for<'key> FnOnce(Key<'key>) -> Result +{ + type Ok = Ok; + + fn serialize_static_str(self, + value: &'static str) + -> Result { + (self.end)(Key::Static(value)) + } + + fn serialize_str(self, value: &str) -> Result { + (self.end)(Key::Dynamic(value.into())) + } + + fn serialize_string(self, value: String) -> Result { + (self.end)(Key::Dynamic(value.into())) + } + + fn serialize_none(self) -> Result { + Err(self.unsupported()) + } + + fn serialize_some(self, + _value: &T) + -> Result { + Err(self.unsupported()) + } + + fn unsupported(self) -> Error { + Error::Custom("unsupported key".into()) + } +} diff --git a/src/serde_urlencoded/ser/mod.rs b/src/serde_urlencoded/ser/mod.rs new file mode 100644 index 000000000..f8d5e13e6 --- /dev/null +++ b/src/serde_urlencoded/ser/mod.rs @@ -0,0 +1,507 @@ +//! Serialization support for the `application/x-www-form-urlencoded` format. + +mod key; +mod pair; +mod part; +mod value; + +use serde::ser; +use std::borrow::Cow; +use std::error; +use std::fmt; +use std::str; +use url::form_urlencoded::Serializer as UrlEncodedSerializer; +use url::form_urlencoded::Target as UrlEncodedTarget; + +/// Serializes a value into a `application/x-wwww-url-encoded` `String` buffer. +/// +/// ```ignore +/// let meal = &[ +/// ("bread", "baguette"), +/// ("cheese", "comté"), +/// ("meat", "ham"), +/// ("fat", "butter"), +/// ]; +/// +/// assert_eq!( +/// serde_urlencoded::to_string(meal), +/// Ok("bread=baguette&cheese=comt%C3%A9&meat=ham&fat=butter".to_owned())); +/// ``` +pub fn to_string(input: T) -> Result { + let mut urlencoder = UrlEncodedSerializer::new("".to_owned()); + input.serialize(Serializer::new(&mut urlencoder))?; + Ok(urlencoder.finish()) +} + +/// A serializer for the `application/x-www-form-urlencoded` format. +/// +/// * Supported top-level inputs are structs, maps and sequences of pairs, +/// with or without a given length. +/// +/// * Supported keys and values are integers, bytes (if convertible to strings), +/// unit structs and unit variants. +/// +/// * Newtype structs defer to their inner values. +pub struct Serializer<'output, Target: 'output + UrlEncodedTarget> { + urlencoder: &'output mut UrlEncodedSerializer, +} + +impl<'output, Target: 'output + UrlEncodedTarget> Serializer<'output, Target> { + /// Returns a new `Serializer`. + pub fn new(urlencoder: &'output mut UrlEncodedSerializer) -> Self { + Serializer { urlencoder: urlencoder } + } +} + +/// Errors returned during serializing to `application/x-www-form-urlencoded`. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Error { + Custom(Cow<'static, str>), + Utf8(str::Utf8Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::Custom(ref msg) => msg.fmt(f), + Error::Utf8(ref err) => write!(f, "invalid UTF-8: {}", err), + } + } +} + +impl error::Error for Error { + fn description(&self) -> &str { + match *self { + Error::Custom(ref msg) => msg, + Error::Utf8(ref err) => error::Error::description(err), + } + } + + /// The lower-level cause of this error, in the case of a `Utf8` error. + fn cause(&self) -> Option<&error::Error> { + match *self { + Error::Custom(_) => None, + Error::Utf8(ref err) => Some(err), + } + } +} + +impl ser::Error for Error { + fn custom(msg: T) -> Self { + Error::Custom(format!("{}", msg).into()) + } +} + +/// Sequence serializer. +pub struct SeqSerializer<'output, Target: 'output + UrlEncodedTarget> { + urlencoder: &'output mut UrlEncodedSerializer, +} + +/// Tuple serializer. +/// +/// Mostly used for arrays. +pub struct TupleSerializer<'output, Target: 'output + UrlEncodedTarget> { + urlencoder: &'output mut UrlEncodedSerializer, +} + +/// Tuple struct serializer. +/// +/// Never instantiated, tuple structs are not supported. +pub struct TupleStructSerializer<'output, T: 'output + UrlEncodedTarget> { + inner: ser::Impossible<&'output mut UrlEncodedSerializer, Error>, +} + +/// Tuple variant serializer. +/// +/// Never instantiated, tuple variants are not supported. +pub struct TupleVariantSerializer<'output, T: 'output + UrlEncodedTarget> { + inner: ser::Impossible<&'output mut UrlEncodedSerializer, Error>, +} + +/// Map serializer. +pub struct MapSerializer<'output, Target: 'output + UrlEncodedTarget> { + urlencoder: &'output mut UrlEncodedSerializer, + key: Option>, +} + +/// Struct serializer. +pub struct StructSerializer<'output, Target: 'output + UrlEncodedTarget> { + urlencoder: &'output mut UrlEncodedSerializer, +} + +/// Struct variant serializer. +/// +/// Never instantiated, struct variants are not supported. +pub struct StructVariantSerializer<'output, T: 'output + UrlEncodedTarget> { + inner: ser::Impossible<&'output mut UrlEncodedSerializer, Error>, +} + +impl<'output, Target> ser::Serializer for Serializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + type SerializeSeq = SeqSerializer<'output, Target>; + type SerializeTuple = TupleSerializer<'output, Target>; + type SerializeTupleStruct = TupleStructSerializer<'output, Target>; + type SerializeTupleVariant = TupleVariantSerializer<'output, Target>; + type SerializeMap = MapSerializer<'output, Target>; + type SerializeStruct = StructSerializer<'output, Target>; + type SerializeStructVariant = StructVariantSerializer<'output, Target>; + + /// Returns an error. + fn serialize_bool(self, _v: bool) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_i8(self, _v: i8) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_i16(self, _v: i16) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_i32(self, _v: i32) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_i64(self, _v: i64) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_u8(self, _v: u8) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_u16(self, _v: u16) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_u32(self, _v: u32) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_u64(self, _v: u64) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_f32(self, _v: f32) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_f64(self, _v: f64) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_char(self, _v: char) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_str(self, _value: &str) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_bytes(self, _value: &[u8]) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_unit(self) -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_unit_struct(self, + _name: &'static str) + -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_unit_variant(self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str) + -> Result { + Err(Error::top_level()) + } + + /// Serializes the inner value, ignoring the newtype name. + fn serialize_newtype_struct + (self, + _name: &'static str, + value: &T) + -> Result { + value.serialize(self) + } + + /// Returns an error. + fn serialize_newtype_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T) + -> Result { + Err(Error::top_level()) + } + + /// Returns `Ok`. + fn serialize_none(self) -> Result { + Ok(self.urlencoder) + } + + /// Serializes the given value. + fn serialize_some + (self, + value: &T) + -> Result { + value.serialize(self) + } + + /// Serialize a sequence, given length (if any) is ignored. + fn serialize_seq(self, + _len: Option) + -> Result { + Ok(SeqSerializer { urlencoder: self.urlencoder }) + } + + /// Returns an error. + fn serialize_tuple(self, + _len: usize) + -> Result { + Ok(TupleSerializer { urlencoder: self.urlencoder }) + } + + /// Returns an error. + fn serialize_tuple_struct(self, + _name: &'static str, + _len: usize) + -> Result { + Err(Error::top_level()) + } + + /// Returns an error. + fn serialize_tuple_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize) + -> Result { + Err(Error::top_level()) + } + + /// Serializes a map, given length is ignored. + fn serialize_map(self, + _len: Option) + -> Result { + Ok(MapSerializer { + urlencoder: self.urlencoder, + key: None, + }) + } + + /// Serializes a struct, given length is ignored. + fn serialize_struct(self, + _name: &'static str, + _len: usize) + -> Result { + Ok(StructSerializer { urlencoder: self.urlencoder }) + } + + /// Returns an error. + fn serialize_struct_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize) + -> Result { + Err(Error::top_level()) + } +} + +impl<'output, Target> ser::SerializeSeq for SeqSerializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + + fn serialize_element(&mut self, + value: &T) + -> Result<(), Error> { + value.serialize(pair::PairSerializer::new(self.urlencoder)) + } + + fn end(self) -> Result { + Ok(self.urlencoder) + } +} + +impl<'output, Target> ser::SerializeTuple for TupleSerializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + + fn serialize_element(&mut self, + value: &T) + -> Result<(), Error> { + value.serialize(pair::PairSerializer::new(self.urlencoder)) + } + + fn end(self) -> Result { + Ok(self.urlencoder) + } +} + +impl<'output, Target> ser::SerializeTupleStruct + for + TupleStructSerializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + + fn serialize_field(&mut self, + value: &T) + -> Result<(), Error> { + self.inner.serialize_field(value) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +impl<'output, Target> ser::SerializeTupleVariant + for + TupleVariantSerializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + + fn serialize_field(&mut self, + value: &T) + -> Result<(), Error> { + self.inner.serialize_field(value) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +impl<'output, Target> ser::SerializeMap for MapSerializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + + fn serialize_entry + (&mut self, + key: &K, + value: &V) + -> Result<(), Error> { + let key_sink = key::KeySink::new(|key| { + let value_sink = value::ValueSink::new(self.urlencoder, &key); + value.serialize(part::PartSerializer::new(value_sink))?; + self.key = None; + Ok(()) + }); + let entry_serializer = part::PartSerializer::new(key_sink); + key.serialize(entry_serializer) + } + + fn serialize_key(&mut self, + key: &T) + -> Result<(), Error> { + let key_sink = key::KeySink::new(|key| Ok(key.into())); + let key_serializer = part::PartSerializer::new(key_sink); + self.key = Some(key.serialize(key_serializer)?); + Ok(()) + } + + fn serialize_value(&mut self, + value: &T) + -> Result<(), Error> { + { + let key = self.key.as_ref().ok_or_else(|| Error::no_key())?; + let value_sink = value::ValueSink::new(self.urlencoder, &key); + value.serialize(part::PartSerializer::new(value_sink))?; + } + self.key = None; + Ok(()) + } + + fn end(self) -> Result { + Ok(self.urlencoder) + } +} + +impl<'output, Target> ser::SerializeStruct for StructSerializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + + fn serialize_field(&mut self, + key: &'static str, + value: &T) + -> Result<(), Error> { + let value_sink = value::ValueSink::new(self.urlencoder, key); + value.serialize(part::PartSerializer::new(value_sink)) + } + + fn end(self) -> Result { + Ok(self.urlencoder) + } +} + +impl<'output, Target> ser::SerializeStructVariant + for + StructVariantSerializer<'output, Target> + where Target: 'output + UrlEncodedTarget, +{ + type Ok = &'output mut UrlEncodedSerializer; + type Error = Error; + + fn serialize_field(&mut self, + key: &'static str, + value: &T) + -> Result<(), Error> { + self.inner.serialize_field(key, value) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +impl Error { + fn top_level() -> Self { + let msg = "top-level serializer supports only maps and structs"; + Error::Custom(msg.into()) + } + + fn no_key() -> Self { + let msg = "tried to serialize a value before serializing key"; + Error::Custom(msg.into()) + } +} diff --git a/src/serde_urlencoded/ser/pair.rs b/src/serde_urlencoded/ser/pair.rs new file mode 100644 index 000000000..37f984755 --- /dev/null +++ b/src/serde_urlencoded/ser/pair.rs @@ -0,0 +1,257 @@ +use super::super::ser::Error; +use super::super::ser::key::KeySink; +use super::super::ser::part::PartSerializer; +use super::super::ser::value::ValueSink; +use serde::ser; +use std::borrow::Cow; +use std::mem; +use url::form_urlencoded::Serializer as UrlEncodedSerializer; +use url::form_urlencoded::Target as UrlEncodedTarget; + +pub struct PairSerializer<'target, Target: 'target + UrlEncodedTarget> { + urlencoder: &'target mut UrlEncodedSerializer, + state: PairState, +} + +impl<'target, Target> PairSerializer<'target, Target> + where Target: 'target + UrlEncodedTarget, +{ + pub fn new(urlencoder: &'target mut UrlEncodedSerializer) -> Self { + PairSerializer { + urlencoder: urlencoder, + state: PairState::WaitingForKey, + } + } +} + +impl<'target, Target> ser::Serializer for PairSerializer<'target, Target> + where Target: 'target + UrlEncodedTarget, +{ + type Ok = (); + type Error = Error; + type SerializeSeq = ser::Impossible<(), Error>; + type SerializeTuple = Self; + type SerializeTupleStruct = ser::Impossible<(), Error>; + type SerializeTupleVariant = ser::Impossible<(), Error>; + type SerializeMap = ser::Impossible<(), Error>; + type SerializeStruct = ser::Impossible<(), Error>; + type SerializeStructVariant = ser::Impossible<(), Error>; + + fn serialize_bool(self, _v: bool) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_i8(self, _v: i8) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_i16(self, _v: i16) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_i32(self, _v: i32) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_i64(self, _v: i64) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_u8(self, _v: u8) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_u16(self, _v: u16) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_u32(self, _v: u32) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_u64(self, _v: u64) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_f32(self, _v: f32) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_f64(self, _v: f64) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_char(self, _v: char) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_str(self, _value: &str) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_bytes(self, _value: &[u8]) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_unit(self) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_unit_variant(self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str) + -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_newtype_struct + (self, + _name: &'static str, + value: &T) + -> Result<(), Error> { + value.serialize(self) + } + + fn serialize_newtype_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T) + -> Result<(), Error> { + Err(Error::unsupported_pair()) + } + + fn serialize_none(self) -> Result<(), Error> { + Ok(()) + } + + fn serialize_some(self, + value: &T) + -> Result<(), Error> { + value.serialize(self) + } + + fn serialize_seq(self, + _len: Option) + -> Result { + Err(Error::unsupported_pair()) + } + + fn serialize_tuple(self, len: usize) -> Result { + if len == 2 { + Ok(self) + } else { + Err(Error::unsupported_pair()) + } + } + + fn serialize_tuple_struct(self, + _name: &'static str, + _len: usize) + -> Result { + Err(Error::unsupported_pair()) + } + + fn serialize_tuple_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize) + -> Result { + Err(Error::unsupported_pair()) + } + + fn serialize_map(self, + _len: Option) + -> Result { + Err(Error::unsupported_pair()) + } + + fn serialize_struct(self, + _name: &'static str, + _len: usize) + -> Result { + Err(Error::unsupported_pair()) + } + + fn serialize_struct_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize) + -> Result { + Err(Error::unsupported_pair()) + } +} + +impl<'target, Target> ser::SerializeTuple for PairSerializer<'target, Target> + where Target: 'target + UrlEncodedTarget, +{ + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, + value: &T) + -> Result<(), Error> { + match mem::replace(&mut self.state, PairState::Done) { + PairState::WaitingForKey => { + let key_sink = KeySink::new(|key| Ok(key.into())); + let key_serializer = PartSerializer::new(key_sink); + self.state = PairState::WaitingForValue { + key: value.serialize(key_serializer)?, + }; + Ok(()) + }, + PairState::WaitingForValue { key } => { + let result = { + let value_sink = ValueSink::new(self.urlencoder, &key); + let value_serializer = PartSerializer::new(value_sink); + value.serialize(value_serializer) + }; + if result.is_ok() { + self.state = PairState::Done; + } else { + self.state = PairState::WaitingForValue { key: key }; + } + result + }, + PairState::Done => Err(Error::done()), + } + } + + fn end(self) -> Result<(), Error> { + if let PairState::Done = self.state { + Ok(()) + } else { + Err(Error::not_done()) + } + } +} + +enum PairState { + WaitingForKey, + WaitingForValue { key: Cow<'static, str> }, + Done, +} + +impl Error { + fn done() -> Self { + Error::Custom("this pair has already been serialized".into()) + } + + fn not_done() -> Self { + Error::Custom("this pair has not yet been serialized".into()) + } + + fn unsupported_pair() -> Self { + Error::Custom("unsupported pair".into()) + } +} diff --git a/src/serde_urlencoded/ser/part.rs b/src/serde_urlencoded/ser/part.rs new file mode 100644 index 000000000..2ce352d24 --- /dev/null +++ b/src/serde_urlencoded/ser/part.rs @@ -0,0 +1,222 @@ +use ::serde; + +use super::super::dtoa; +use super::super::itoa; +use super::super::ser::Error; +use std::str; + +pub struct PartSerializer { + sink: S, +} + +impl PartSerializer { + pub fn new(sink: S) -> Self { + PartSerializer { sink: sink } + } +} + +pub trait Sink: Sized { + type Ok; + + fn serialize_static_str(self, + value: &'static str) + -> Result; + + fn serialize_str(self, value: &str) -> Result; + fn serialize_string(self, value: String) -> Result; + fn serialize_none(self) -> Result; + + fn serialize_some + (self, + value: &T) + -> Result; + + fn unsupported(self) -> Error; +} + +impl serde::ser::Serializer for PartSerializer { + type Ok = S::Ok; + type Error = Error; + type SerializeSeq = serde::ser::Impossible; + type SerializeTuple = serde::ser::Impossible; + type SerializeTupleStruct = serde::ser::Impossible; + type SerializeTupleVariant = serde::ser::Impossible; + type SerializeMap = serde::ser::Impossible; + type SerializeStruct = serde::ser::Impossible; + type SerializeStructVariant = serde::ser::Impossible; + + fn serialize_bool(self, v: bool) -> Result { + self.sink.serialize_static_str(if v { "true" } else { "false" }) + } + + fn serialize_i8(self, v: i8) -> Result { + self.serialize_integer(v) + } + + fn serialize_i16(self, v: i16) -> Result { + self.serialize_integer(v) + } + + fn serialize_i32(self, v: i32) -> Result { + self.serialize_integer(v) + } + + fn serialize_i64(self, v: i64) -> Result { + self.serialize_integer(v) + } + + fn serialize_u8(self, v: u8) -> Result { + self.serialize_integer(v) + } + + fn serialize_u16(self, v: u16) -> Result { + self.serialize_integer(v) + } + + fn serialize_u32(self, v: u32) -> Result { + self.serialize_integer(v) + } + + fn serialize_u64(self, v: u64) -> Result { + self.serialize_integer(v) + } + + fn serialize_f32(self, v: f32) -> Result { + self.serialize_floating(v) + } + + fn serialize_f64(self, v: f64) -> Result { + self.serialize_floating(v) + } + + fn serialize_char(self, v: char) -> Result { + self.sink.serialize_string(v.to_string()) + } + + fn serialize_str(self, value: &str) -> Result { + self.sink.serialize_str(value) + } + + fn serialize_bytes(self, value: &[u8]) -> Result { + match str::from_utf8(value) { + Ok(value) => self.sink.serialize_str(value), + Err(err) => Err(Error::Utf8(err)), + } + } + + fn serialize_unit(self) -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + self.sink.serialize_static_str(name.into()) + } + + fn serialize_unit_variant(self, + _name: &'static str, + _variant_index: u32, + variant: &'static str) + -> Result { + self.sink.serialize_static_str(variant.into()) + } + + fn serialize_newtype_struct + (self, + _name: &'static str, + value: &T) + -> Result { + value.serialize(self) + } + + fn serialize_newtype_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T) + -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_none(self) -> Result { + self.sink.serialize_none() + } + + fn serialize_some(self, + value: &T) + -> Result { + self.sink.serialize_some(value) + } + + fn serialize_seq(self, + _len: Option) + -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_tuple(self, + _len: usize) + -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_tuple_struct(self, + _name: &'static str, + _len: usize) + -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_tuple_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize) + -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_map(self, + _len: Option) + -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_struct(self, + _name: &'static str, + _len: usize) + -> Result { + Err(self.sink.unsupported()) + } + + fn serialize_struct_variant + (self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize) + -> Result { + Err(self.sink.unsupported()) + } +} + +impl PartSerializer { + fn serialize_integer(self, value: I) -> Result + where I: itoa::Integer, + { + let mut buf = [b'\0'; 20]; + let len = itoa::write(&mut buf[..], value).unwrap(); + let part = unsafe { str::from_utf8_unchecked(&buf[0..len]) }; + serde::ser::Serializer::serialize_str(self, part) + } + + fn serialize_floating(self, value: F) -> Result + where F: dtoa::Floating, + { + let mut buf = [b'\0'; 24]; + let len = dtoa::write(&mut buf[..], value).unwrap(); + let part = unsafe { str::from_utf8_unchecked(&buf[0..len]) }; + serde::ser::Serializer::serialize_str(self, part) + } +} diff --git a/src/serde_urlencoded/ser/value.rs b/src/serde_urlencoded/ser/value.rs new file mode 100644 index 000000000..ef63b010d --- /dev/null +++ b/src/serde_urlencoded/ser/value.rs @@ -0,0 +1,59 @@ +use super::super::ser::Error; +use super::super::ser::part::{PartSerializer, Sink}; +use serde::ser::Serialize; +use std::str; +use url::form_urlencoded::Serializer as UrlEncodedSerializer; +use url::form_urlencoded::Target as UrlEncodedTarget; + +pub struct ValueSink<'key, 'target, Target> + where Target: 'target + UrlEncodedTarget, +{ + urlencoder: &'target mut UrlEncodedSerializer, + key: &'key str, +} + +impl<'key, 'target, Target> ValueSink<'key, 'target, Target> + where Target: 'target + UrlEncodedTarget, +{ + pub fn new(urlencoder: &'target mut UrlEncodedSerializer, + key: &'key str) + -> Self { + ValueSink { + urlencoder: urlencoder, + key: key, + } + } +} + +impl<'key, 'target, Target> Sink for ValueSink<'key, 'target, Target> + where Target: 'target + UrlEncodedTarget, +{ + type Ok = (); + + fn serialize_str(self, value: &str) -> Result<(), Error> { + self.urlencoder.append_pair(self.key, value); + Ok(()) + } + + fn serialize_static_str(self, value: &'static str) -> Result<(), Error> { + self.serialize_str(value) + } + + fn serialize_string(self, value: String) -> Result<(), Error> { + self.serialize_str(&value) + } + + fn serialize_none(self) -> Result { + Ok(()) + } + + fn serialize_some(self, + value: &T) + -> Result { + value.serialize(PartSerializer::new(self)) + } + + fn unsupported(self) -> Error { + Error::Custom("unsupported value".into()) + } +} diff --git a/tests/test_handlers.rs b/tests/test_handlers.rs index 95bd5be2e..bc65b93f8 100644 --- a/tests/test_handlers.rs +++ b/tests/test_handlers.rs @@ -91,6 +91,48 @@ fn test_query_extractor() { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } +#[derive(Deserialize, Debug)] +pub enum ResponseType { + Token, + Code +} + +#[derive(Debug, Deserialize)] +pub struct AuthRequest { + id: u64, + response_type: ResponseType, +} + +#[test] +fn test_query_enum_extractor() { + let mut srv = test::TestServer::new(|app| { + app.resource("/index.html", |r| { + r.with(|p: Query| format!("{:?}", p.into_inner())) + }); + }); + + // client request + let request = srv + .get() + .uri(srv.url("/index.html?id=64&response_type=Code")) + .finish() + .unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(b"AuthRequest { id: 64, response_type: Code }")); + + let request = srv.get().uri(srv.url("/index.html?id=64&response_type=Co")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let request = srv.get().uri(srv.url("/index.html?response_type=Code")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); +} + #[test] fn test_async_extractor_async() { let mut srv = test::TestServer::new(|app| {