1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-23 23:51:06 +01:00

CORS: Take TryInto instead of TryFrom (#106)

This commit is contained in:
FallenWarrior2k 2020-09-26 21:02:45 +02:00 committed by GitHub
parent 99fe08f332
commit f185a9c7e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 26 deletions

View File

@ -2,6 +2,9 @@
## Unreleased - 2020-xx-xx ## Unreleased - 2020-xx-xx
* Implement `allowed_origin_fn` builder method. * Implement `allowed_origin_fn` builder method.
* Use `TryInto` instead of `TryFrom` where applicable. [#106]
[#106]: https://github.com/actix/actix-extras/pull/106
## 0.3.0 - 2020-09-11 ## 0.3.0 - 2020-09-11

View File

@ -52,7 +52,7 @@
#![deny(missing_docs, missing_debug_implementations, rust_2018_idioms)] #![deny(missing_docs, missing_debug_implementations, rust_2018_idioms)]
use std::collections::HashSet; use std::collections::HashSet;
use std::convert::TryFrom; use std::convert::TryInto;
use std::fmt; use std::fmt;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::rc::Rc; use std::rc::Rc;
@ -265,7 +265,7 @@ impl Cors {
/// Builder panics if supplied origin is not valid uri. /// Builder panics if supplied origin is not valid uri.
pub fn allowed_origin(mut self, origin: &str) -> Cors { pub fn allowed_origin(mut self, origin: &str) -> Cors {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
match Uri::try_from(origin) { match TryInto::<Uri>::try_into(origin) {
Ok(_) => { Ok(_) => {
if cors.origins.is_all() { if cors.origins.is_all() {
cors.origins = AllOrSome::Some(HashSet::new()); cors.origins = AllOrSome::Some(HashSet::new());
@ -306,13 +306,13 @@ impl Cors {
pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
where where
U: IntoIterator<Item = M>, U: IntoIterator<Item = M>,
Method: TryFrom<M>, M: TryInto<Method>,
<Method as TryFrom<M>>::Error: Into<HttpError>, <M as TryInto<Method>>::Error: Into<HttpError>,
{ {
self.methods = true; self.methods = true;
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
for m in methods { for m in methods {
match Method::try_from(m) { match m.try_into() {
Ok(method) => { Ok(method) => {
cors.methods.insert(method); cors.methods.insert(method);
} }
@ -329,11 +329,11 @@ impl Cors {
/// Set an allowed header. /// Set an allowed header.
pub fn allowed_header<H>(mut self, header: H) -> Cors pub fn allowed_header<H>(mut self, header: H) -> Cors
where where
HeaderName: TryFrom<H>, H: TryInto<HeaderName>,
<HeaderName as TryFrom<H>>::Error: Into<HttpError>, <H as TryInto<HeaderName>>::Error: Into<HttpError>,
{ {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
match HeaderName::try_from(header) { match header.try_into() {
Ok(method) => { Ok(method) => {
if cors.headers.is_all() { if cors.headers.is_all() {
cors.headers = AllOrSome::Some(HashSet::new()); cors.headers = AllOrSome::Some(HashSet::new());
@ -362,12 +362,12 @@ impl Cors {
pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
where where
U: IntoIterator<Item = H>, U: IntoIterator<Item = H>,
HeaderName: TryFrom<H>, H: TryInto<HeaderName>,
<HeaderName as TryFrom<H>>::Error: Into<HttpError>, <H as TryInto<HeaderName>>::Error: Into<HttpError>,
{ {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
for h in headers { for h in headers {
match HeaderName::try_from(h) { match h.try_into() {
Ok(method) => { Ok(method) => {
if cors.headers.is_all() { if cors.headers.is_all() {
cors.headers = AllOrSome::Some(HashSet::new()); cors.headers = AllOrSome::Some(HashSet::new());
@ -397,11 +397,11 @@ impl Cors {
pub fn expose_headers<U, H>(mut self, headers: U) -> Cors pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
where where
U: IntoIterator<Item = H>, U: IntoIterator<Item = H>,
HeaderName: TryFrom<H>, H: TryInto<HeaderName>,
<HeaderName as TryFrom<H>>::Error: Into<HttpError>, <H as TryInto<HeaderName>>::Error: Into<HttpError>,
{ {
for h in headers { for h in headers {
match HeaderName::try_from(h) { match h.try_into() {
Ok(method) => { Ok(method) => {
self.expose_hdrs.insert(method); self.expose_hdrs.insert(method);
} }
@ -528,7 +528,7 @@ impl Cors {
let s = origins let s = origins
.iter() .iter()
.fold(String::new(), |s, v| format!("{}, {}", s, v)); .fold(String::new(), |s, v| format!("{}, {}", s, v));
cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap()); cors.origins_str = Some(s[2..].try_into().unwrap());
} }
if !slf.expose_hdrs.is_empty() { if !slf.expose_hdrs.is_empty() {
@ -686,7 +686,7 @@ impl Inner {
fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> { fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) { if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) {
if let Ok(meth) = hdr.to_str() { if let Ok(meth) = hdr.to_str() {
if let Ok(method) = Method::try_from(meth) { if let Ok(method) = meth.try_into() {
return self return self
.methods .methods
.get(&method) .get(&method)
@ -711,7 +711,7 @@ impl Inner {
#[allow(clippy::mutable_key_type)] // FIXME: revisit here #[allow(clippy::mutable_key_type)] // FIXME: revisit here
let mut hdrs = HashSet::new(); let mut hdrs = HashSet::new();
for hdr in headers.split(',') { for hdr in headers.split(',') {
match HeaderName::try_from(hdr.trim()) { match hdr.trim().try_into() {
Ok(hdr) => hdrs.insert(hdr), Ok(hdr) => hdrs.insert(hdr),
Err(_) => return Err(CorsError::BadRequestHeaders), Err(_) => return Err(CorsError::BadRequestHeaders),
}; };
@ -766,13 +766,12 @@ where
// allowed headers // allowed headers
let headers = if let Some(headers) = self.inner.headers.as_ref() { let headers = if let Some(headers) = self.inner.headers.as_ref() {
Some( Some(
HeaderValue::try_from( headers
&headers .iter()
.iter() .fold(String::new(), |s, v| s + "," + v.as_str())
.fold(String::new(), |s, v| s + "," + v.as_str()) .as_str()[1..]
.as_str()[1..], .try_into()
) .unwrap(),
.unwrap(),
) )
} else if let Some(hdr) = } else if let Some(hdr) =
req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS)
@ -842,7 +841,7 @@ where
if let Some(ref expose) = inner.expose_hdrs { if let Some(ref expose) = inner.expose_hdrs {
res.headers_mut().insert( res.headers_mut().insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS, header::ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::try_from(expose.as_str()).unwrap(), expose.as_str().try_into().unwrap(),
); );
} }
if inner.supports_credentials { if inner.supports_credentials {
@ -859,7 +858,7 @@ where
Vec::with_capacity(hdr.as_bytes().len() + 8); Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes()); val.extend(hdr.as_bytes());
val.extend(b", Origin"); val.extend(b", Origin");
HeaderValue::try_from(&val[..]).unwrap() val.try_into().unwrap()
} else { } else {
HeaderValue::from_static("Origin") HeaderValue::from_static("Origin")
}; };
@ -880,9 +879,30 @@ where
mod tests { mod tests {
use actix_service::{fn_service, Transform}; use actix_service::{fn_service, Transform};
use actix_web::test::{self, TestRequest}; use actix_web::test::{self, TestRequest};
use std::convert::Infallible;
use super::*; use super::*;
#[actix_rt::test]
async fn allowed_header_tryfrom() {
let _cors = Cors::new().allowed_header("Content-Type");
}
#[actix_rt::test]
async fn allowed_header_tryinto() {
struct ContentType;
impl TryInto<HeaderName> for ContentType {
type Error = Infallible;
fn try_into(self) -> Result<HeaderName, Self::Error> {
Ok(HeaderName::from_static("content-type"))
}
}
let _cors = Cors::new().allowed_header(ContentType);
}
#[actix_rt::test] #[actix_rt::test]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
async fn cors_validates_illegal_allow_credentials() { async fn cors_validates_illegal_allow_credentials() {