diff --git a/CHANGES.md b/CHANGES.md index 67a97f14..2e4d23ba 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +### Add + +* `QueryConfig`, similar to `JsonConfig` for customizing error handling of query extractors. + ### Changes * `JsonConfig` is now `Send + Sync`, this implies that `error_handler` must be `Send + Sync` too. diff --git a/src/error.rs b/src/error.rs index e9e225f2..a3062b58 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,7 @@ use url::ParseError as UrlParseError; use crate::http::StatusCode; use crate::HttpResponse; +use serde_urlencoded::de; /// Errors which can occur when attempting to generate resource uri. #[derive(Debug, PartialEq, Display, From)] @@ -91,6 +92,23 @@ impl ResponseError for JsonPayloadError { } } +/// A set of errors that can occur during parsing query strings +#[derive(Debug, Display, From)] +pub enum QueryPayloadError { + /// Deserialize error + #[display(fmt = "Query deserialize error: {}", _0)] + Deserialize(de::Error), +} + +/// Return `BadRequest` for `QueryPayloadError` +impl ResponseError for QueryPayloadError { + fn error_response(&self) -> HttpResponse { + match *self { + QueryPayloadError::Deserialize(_) => HttpResponse::new(StatusCode::BAD_REQUEST), + } + } +} + /// Error type returned when reading body as lines. #[derive(From, Display, Debug)] pub enum ReadlinesError { @@ -143,6 +161,12 @@ mod tests { assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } + #[test] + fn test_query_payload_error() { + let resp: HttpResponse = QueryPayloadError::Deserialize(serde_urlencoded::from_str::("bad query").unwrap_err()).error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + #[test] fn test_readlines_error() { let resp: HttpResponse = ReadlinesError::LimitOverflow.error_response(); diff --git a/src/types/mod.rs b/src/types/mod.rs index 30ee7309..d01d597b 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -11,4 +11,4 @@ pub use self::form::{Form, FormConfig}; pub use self::json::{Json, JsonConfig}; pub use self::path::Path; pub use self::payload::{Payload, PayloadConfig}; -pub use self::query::Query; +pub use self::query::{Query, QueryConfig}; diff --git a/src/types/query.rs b/src/types/query.rs index f9f545d6..24225dad 100644 --- a/src/types/query.rs +++ b/src/types/query.rs @@ -1,5 +1,6 @@ //! Query extractor +use std::sync::Arc; use std::{fmt, ops}; use actix_http::error::Error; @@ -9,6 +10,7 @@ use serde_urlencoded; use crate::dev::Payload; use crate::extract::FromRequest; use crate::request::HttpRequest; +use crate::error::QueryPayloadError; #[derive(PartialEq, Eq, PartialOrd, Ord)] /// Extract typed information from from the request's query. @@ -115,32 +117,103 @@ impl FromRequest for Query where T: de::DeserializeOwned, { - type Config = (); type Error = Error; type Future = Result; + type Config = QueryConfig; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + let error_handler = req + .app_data::() + .map(|c| c.ehandler.clone()) + .unwrap_or(None); + serde_urlencoded::from_str::(req.query_string()) .map(|val| Ok(Query(val))) - .unwrap_or_else(|e| { + .unwrap_or_else(move |e| { + let e = QueryPayloadError::Deserialize(e); + log::debug!( "Failed during Query extractor deserialization. \ Request path: {:?}", req.path() ); - Err(e.into()) + + let e = if let Some(error_handler) = error_handler { + (error_handler)(e, req) + } else { + e.into() + }; + + Err(e) }) } } +/// Query extractor configuration +/// +/// ```rust +/// #[macro_use] extern crate serde_derive; +/// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// deserialize `Info` from request's querystring +/// fn index(info: web::Query) -> String { +/// format!("Welcome {}!", info.username) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").data( +/// // change query extractor configuration +/// web::Query::::configure(|cfg| { +/// cfg.error_handler(|err, req| { // <- create custom error response +/// error::InternalError::from_response( +/// err, HttpResponse::Conflict().finish()).into() +/// }) +/// })) +/// .route(web::post().to(index)) +/// ); +/// } +/// ``` +#[derive(Clone)] +pub struct QueryConfig { + ehandler: Option Error + Send + Sync>>, +} + +impl QueryConfig { + /// Set custom error handler + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(QueryPayloadError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.ehandler = Some(Arc::new(f)); + self + } +} + +impl Default for QueryConfig { + fn default() -> Self { + QueryConfig { + ehandler: None, + } + } +} + #[cfg(test)] mod tests { use derive_more::Display; use serde_derive::Deserialize; + use actix_http::http::StatusCode; use super::*; use crate::test::TestRequest; + use crate::error::InternalError; + use crate::HttpResponse; #[derive(Deserialize, Debug, Display)] struct Id { @@ -164,4 +237,19 @@ mod tests { let s = s.into_inner(); assert_eq!(s.id, "test1"); } + + #[test] + fn test_custom_error_responder() { + let req = TestRequest::with_uri("/name/user1/") + .data(QueryConfig::default().error_handler(|e, _| { + let resp = HttpResponse::UnprocessableEntity().finish(); + InternalError::from_response(e, resp).into() + })).to_srv_request(); + + let (req, mut pl) = req.into_parts(); + let query = Query::::from_request(&req, &mut pl); + + assert!(query.is_err()); + assert_eq!(query.unwrap_err().as_response_error().error_response().status(), StatusCode::UNPROCESSABLE_ENTITY); + } }