From 68cf32e8482bcf12a4129378c09f8fa9016c5014 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 26 Mar 2018 15:58:30 -0700 Subject: [PATCH] add path and query extractors --- Cargo.toml | 1 + src/extractor.rs | 231 +++++++++++++++++++++++++++++++++++ src/handler.rs | 3 + src/header/shared/charset.rs | 1 - src/httprequest.rs | 40 +++++- src/lib.rs | 6 +- src/param.rs | 5 + src/payload.rs | 5 + src/server/h1.rs | 1 + 9 files changed, 290 insertions(+), 3 deletions(-) create mode 100644 src/extractor.rs diff --git a/Cargo.toml b/Cargo.toml index d1facfd37..aa9f2df72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ rand = "0.4" regex = "0.2" serde = "1.0" serde_json = "1.0" +serde_urlencoded = "0.5" sha1 = "0.6" smallvec = "0.6" time = "0.1" diff --git a/src/extractor.rs b/src/extractor.rs new file mode 100644 index 000000000..1dcc2cdf7 --- /dev/null +++ b/src/extractor.rs @@ -0,0 +1,231 @@ +use serde_urlencoded; +use serde::de::{self, Deserializer, Visitor, Error as DeError}; + +use error::{Error, ErrorBadRequest}; +use httprequest::HttpRequest; + +pub trait HttpRequestExtractor<'de> { + fn extract(&self, req: &'de HttpRequest) -> Result + where T: de::Deserialize<'de>, S: 'static; +} + +/// Extract typed information from the request's path. +/// +/// ## Example +/// +/// ```rust +/// # extern crate bytes; +/// # extern crate actix_web; +/// # extern crate futures; +/// #[macro_use] extern crate serde_derive; +/// use actix_web::*; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// fn index(mut req: HttpRequest) -> Result { +/// let info: Info = req.extract(Path)?; // <- extract path info using serde +/// Ok(format!("Welcome {}!", info.username)) +/// } +/// +/// fn main() { +/// let app = Application::new() +/// .resource("/{username}/index.html", // <- define path parameters +/// |r| r.method(Method::GET).f(index)); +/// } +/// ``` +pub struct Path; + +impl<'de> HttpRequestExtractor<'de> for Path { + #[inline] + fn extract(&self, req: &'de HttpRequest) -> Result + where T: de::Deserialize<'de>, S: 'static, + { + Ok(de::Deserialize::deserialize(PathExtractor{req: req}) + .map_err(ErrorBadRequest)?) + } +} + +/// Extract typed information from from the request's query. +/// +/// ## Example +/// +/// ```rust +/// # extern crate bytes; +/// # extern crate actix_web; +/// # extern crate futures; +/// #[macro_use] extern crate serde_derive; +/// use actix_web::*; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// fn index(mut req: HttpRequest) -> Result { +/// let info: Info = req.extract(Query)?; // <- extract query info using serde +/// Ok(format!("Welcome {}!", info.username)) +/// } +/// +/// # fn main() {} +/// ``` +pub struct Query; + +impl<'de> HttpRequestExtractor<'de> for Query { + #[inline] + fn extract(&self, req: &'de HttpRequest) -> Result + where T: de::Deserialize<'de>, S: 'static, + { + Ok(serde_urlencoded::from_str::(req.query_string()) + .map_err(ErrorBadRequest)?) + } +} + +macro_rules! unsupported_type { + ($trait_fn:ident, $name:expr) => { + fn $trait_fn(self, _: V) -> Result + where V: Visitor<'de> + { + Err(de::value::Error::custom(concat!("unsupported type: ", $name))) + } + }; +} + +pub struct PathExtractor<'de, S: 'static> { + req: &'de HttpRequest +} + +impl<'de, S: 'static> Deserializer<'de> for PathExtractor<'de, S> +{ + type Error = de::value::Error; + + fn deserialize_map(self, visitor: V) -> Result + where V: Visitor<'de>, + { + visitor.visit_map(de::value::MapDeserializer::new( + self.req.match_info().iter().map(|&(ref k, ref v)| (k.as_ref(), v.as_ref())))) + } + + fn deserialize_struct(self, _: &'static str, _: &'static [&'static str], visitor: V) + -> Result + where V: Visitor<'de>, + { + self.deserialize_map(visitor) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct(self, _: &'static str, visitor: V) + -> Result + where V: Visitor<'de> + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct(self, _: &'static str, visitor: V) + -> Result + where V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where V: Visitor<'de> + { + if self.req.match_info().len() < len { + Err(de::value::Error::custom( + format!("wrong number of parameters: {} expected {}", + self.req.match_info().len(), len).as_str())) + } else { + visitor.visit_seq(de::value::SeqDeserializer::new( + self.req.match_info().iter().map(|&(_, ref v)| v.as_ref()))) + } + } + + fn deserialize_tuple_struct(self, _: &'static str, _: usize, visitor: V) + -> Result + where V: Visitor<'de> + { + visitor.visit_seq(de::value::SeqDeserializer::new( + self.req.match_info().iter().map(|&(_, ref v)| v.as_ref()))) + } + + fn deserialize_enum(self, _: &'static str, _: &'static [&'static str], _: V) + -> Result + where V: Visitor<'de> + { + Err(de::value::Error::custom("unsupported type: enum")) + } + + unsupported_type!(deserialize_any, "'any'"); + unsupported_type!(deserialize_bool, "bool"); + unsupported_type!(deserialize_i8, "i8"); + unsupported_type!(deserialize_i16, "i16"); + unsupported_type!(deserialize_i32, "i32"); + unsupported_type!(deserialize_i64, "i64"); + unsupported_type!(deserialize_u8, "u8"); + unsupported_type!(deserialize_u16, "u16"); + unsupported_type!(deserialize_u32, "u32"); + unsupported_type!(deserialize_u64, "u64"); + unsupported_type!(deserialize_f32, "f32"); + unsupported_type!(deserialize_f64, "f64"); + unsupported_type!(deserialize_char, "char"); + unsupported_type!(deserialize_str, "str"); + unsupported_type!(deserialize_string, "String"); + unsupported_type!(deserialize_bytes, "bytes"); + unsupported_type!(deserialize_byte_buf, "byte buf"); + unsupported_type!(deserialize_option, "Option"); + unsupported_type!(deserialize_seq, "sequence"); + unsupported_type!(deserialize_identifier, "identifier"); + unsupported_type!(deserialize_ignored_any, "ignored_any"); +} + +#[cfg(test)] +mod tests { + use super::*; + use router::{Router, Pattern}; + use resource::Resource; + use test::TestRequest; + use server::ServerSettings; + + #[derive(Deserialize)] + struct MyStruct { + key: String, + value: String, + } + + #[derive(Deserialize)] + struct Id { + id: String, + } + + #[test] + fn test_request_extract() { + let mut req = TestRequest::with_uri("/name/user1/?id=test").finish(); + + let mut resource = Resource::<()>::default(); + resource.name("index"); + let mut routes = Vec::new(); + routes.push((Pattern::new("index", "/{key}/{value}/"), Some(resource))); + let (router, _) = Router::new("", ServerSettings::default(), routes); + assert!(router.recognize(&mut req).is_some()); + + let s: MyStruct = req.extract(Path).unwrap(); + assert_eq!(s.key, "name"); + assert_eq!(s.value, "user1"); + + let s: (String, String) = req.extract(Path).unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, "user1"); + + let s: Id = req.extract(Query).unwrap(); + assert_eq!(s.id, "test"); + } +} diff --git a/src/handler.rs b/src/handler.rs index 69839073a..7b8f2d480 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -227,6 +227,9 @@ impl From>> for Reply { } } +/// Convenience type alias +pub type FutureResponse = Box>; + impl Responder for Box> where I: Responder + 'static, E: Into + 'static diff --git a/src/header/shared/charset.rs b/src/header/shared/charset.rs index 6a07fda59..f3d3f06f4 100644 --- a/src/header/shared/charset.rs +++ b/src/header/shared/charset.rs @@ -1,7 +1,6 @@ #![allow(unused)] use std::fmt::{self, Display}; use std::str::FromStr; -use std::ascii::AsciiExt; use self::Charset::*; diff --git a/src/httprequest.rs b/src/httprequest.rs index 214861ed7..3584ec52d 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -10,6 +10,7 @@ use failure; use url::{Url, form_urlencoded}; use http::{header, Uri, Method, Version, HeaderMap, Extensions, StatusCode}; use tokio_io::AsyncRead; +use serde::de; use body::Body; use info::ConnectionInfo; @@ -19,7 +20,8 @@ use payload::Payload; use httpmessage::HttpMessage; use httpresponse::{HttpResponse, HttpResponseBuilder}; use helpers::SharedHttpInnerMessage; -use error::{UrlGenerationError, CookieParseError, PayloadError}; +use extractor::HttpRequestExtractor; +use error::{Error, UrlGenerationError, CookieParseError, PayloadError}; pub struct HttpInnerMessage { @@ -395,6 +397,42 @@ impl HttpRequest { unsafe{ mem::transmute(&mut self.as_mut().params) } } + /// Extract typed information from path. + /// + /// ## Example + /// + /// ```rust + /// # extern crate bytes; + /// # extern crate actix_web; + /// # extern crate futures; + /// #[macro_use] extern crate serde_derive; + /// use actix_web::*; + /// + /// #[derive(Deserialize)] + /// struct Info { + /// username: String, + /// } + /// + /// fn index(mut req: HttpRequest) -> Result { + /// let info: Info = req.extract(Path)?; // <- extract path info using serde + /// let info: Info = req.extract(Query)?; // <- extract query info + /// Ok(format!("Welcome {}!", info.username)) + /// } + /// + /// fn main() { + /// let app = Application::new() + /// .resource("/{username}/index.html", // <- define path parameters + /// |r| r.method(Method::GET).f(index)); + /// } + /// ``` + pub fn extract<'a, T, D>(&'a self, ds: D) -> Result + where S: 'static, + T: de::Deserialize<'a>, + D: HttpRequestExtractor<'a> + { + ds.extract(self) + } + /// Checks if a connection should be kept alive. pub fn keep_alive(&self) -> bool { self.as_ref().keep_alive() diff --git a/src/lib.rs b/src/lib.rs index 05552fdc0..30428aed4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,7 @@ extern crate url; extern crate libc; extern crate serde; extern crate serde_json; +extern crate serde_urlencoded; extern crate flate2; #[cfg(feature="brotli")] extern crate brotli2; @@ -118,6 +119,7 @@ mod resource; mod param; mod payload; mod pipeline; +mod extractor; pub mod client; pub mod fs; @@ -137,11 +139,12 @@ pub use application::Application; pub use httpmessage::HttpMessage; pub use httprequest::HttpRequest; pub use httpresponse::HttpResponse; -pub use handler::{Either, Reply, Responder, NormalizePath, AsyncResponder}; +pub use handler::{Either, Reply, Responder, NormalizePath, AsyncResponder, FutureResponse}; pub use route::Route; pub use resource::Resource; pub use context::HttpContext; pub use server::HttpServer; +pub use extractor::{Path, Query}; // re-exports pub use http::{Method, StatusCode, Version}; @@ -187,4 +190,5 @@ pub mod dev { pub use param::{FromParam, Params}; pub use httpmessage::{UrlEncoded, MessageBody}; pub use httpresponse::HttpResponseBuilder; + pub use extractor::HttpRequestExtractor; } diff --git a/src/param.rs b/src/param.rs index a062bd096..2630ac7a5 100644 --- a/src/param.rs +++ b/src/param.rs @@ -46,6 +46,11 @@ impl<'a> Params<'a> { self.0.is_empty() } + /// Check number of extracted parameters + pub fn len(&self) -> usize { + self.0.len() + } + /// Get matched parameter by name without type conversion pub fn get(&'a self, key: &str) -> Option<&'a str> { for item in self.0.iter() { diff --git a/src/payload.rs b/src/payload.rs index 6fb63f69e..8afff81c9 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -302,6 +302,11 @@ impl PayloadHelper where S: Stream { } } + /// Get mutable reference to an inner stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } + #[inline] fn poll_stream(&mut self) -> Poll { self.stream.poll().map(|res| { diff --git a/src/server/h1.rs b/src/server/h1.rs index e1f61461b..6d504e41c 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -549,6 +549,7 @@ impl Reader { msg }; + // https://tools.ietf.org/html/rfc7230#section-3.3.3 let decoder = if has_te && chunked(&msg.get_mut().headers)? { // Chunked encoding Some(Decoder::chunked())