1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-02-20 09:40:34 +01:00

fix cors allowed header validation

This commit is contained in:
Nikolay Kim 2018-01-11 20:11:34 -08:00
parent 11342e4566
commit dab918261c
2 changed files with 26 additions and 39 deletions

View File

@ -109,7 +109,7 @@ pub enum CorsBuilderError {
impl ResponseError for CorsError { impl ResponseError for CorsError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HTTPBadRequest.into() HTTPBadRequest.build().body(format!("{}", self)).unwrap()
} }
} }
@ -159,7 +159,7 @@ impl<T> AllOrSome<T> {
/// for responses to be generated. /// for responses to be generated.
pub struct Cors { pub struct Cors {
methods: HashSet<Method>, methods: HashSet<Method>,
origins: AllOrSome<HashSet<Uri>>, origins: AllOrSome<HashSet<String>>,
origins_str: Option<HeaderValue>, origins_str: Option<HeaderValue>,
headers: AllOrSome<HashSet<HeaderName>>, headers: AllOrSome<HashSet<HeaderName>>,
expose_hdrs: Option<String>, expose_hdrs: Option<String>,
@ -225,17 +225,15 @@ impl Cors {
fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> { fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ORIGIN) { if let Some(hdr) = req.headers().get(header::ORIGIN) {
if let Ok(origin) = hdr.to_str() { if let Ok(origin) = hdr.to_str() {
if let Ok(uri) = Uri::try_from(origin) { return match self.origins {
return match self.origins { AllOrSome::All => Ok(()),
AllOrSome::All => Ok(()), AllOrSome::Some(ref allowed_origins) => {
AllOrSome::Some(ref allowed_origins) => { allowed_origins
allowed_origins .get(origin)
.get(&uri) .and_then(|_| Some(()))
.and_then(|_| Some(())) .ok_or_else(|| CorsError::OriginNotAllowed)
.ok_or_else(|| CorsError::OriginNotAllowed) }
} };
};
}
} }
Err(CorsError::BadOrigin) Err(CorsError::BadOrigin)
} else { } else {
@ -262,11 +260,11 @@ impl Cors {
} }
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> { fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { match self.headers {
if let Ok(headers) = hdr.to_str() { AllOrSome::All => Ok(()),
match self.headers { AllOrSome::Some(ref allowed_headers) => {
AllOrSome::All => return Ok(()), if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
AllOrSome::Some(ref allowed_headers) => { if let Ok(headers) = hdr.to_str() {
let mut hdrs = HashSet::new(); let mut hdrs = HashSet::new();
for hdr in headers.split(',') { for hdr in headers.split(',') {
match HeaderName::try_from(hdr.trim()) { match HeaderName::try_from(hdr.trim()) {
@ -280,11 +278,11 @@ impl Cors {
} }
return Ok(()) return Ok(())
} }
Err(CorsError::BadRequestHeaders)
} else {
Err(CorsError::MissingRequestHeaders)
} }
} }
Err(CorsError::BadRequestHeaders)
} else {
Err(CorsError::MissingRequestHeaders)
} }
} }
} }
@ -437,17 +435,15 @@ impl CorsBuilder {
/// ///
/// Defaults to `All`. /// Defaults to `All`.
/// ``` /// ```
pub fn allowed_origin<U>(&mut self, origin: U) -> &mut CorsBuilder pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder {
where Uri: HttpTryFrom<U>
{
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
match Uri::try_from(origin) { match Uri::try_from(origin) {
Ok(uri) => { Ok(_) => {
if cors.origins.is_all() { if cors.origins.is_all() {
cors.origins = AllOrSome::Some(HashSet::new()); cors.origins = AllOrSome::Some(HashSet::new());
} }
if let AllOrSome::Some(ref mut origins) = cors.origins { if let AllOrSome::Some(ref mut origins) = cors.origins {
origins.insert(uri); origins.insert(origin.to_owned());
} }
} }
Err(e) => { Err(e) => {

View File

@ -26,20 +26,19 @@ pub struct HttpChannel<T, H>
impl<T, H> HttpChannel<T, H> impl<T, H> HttpChannel<T, H>
where T: IoStream, H: HttpHandler + 'static where T: IoStream, H: HttpHandler + 'static
{ {
pub(crate) fn new(h: Rc<WorkerSettings<H>>, pub(crate) fn new(settings: Rc<WorkerSettings<H>>,
io: T, peer: Option<SocketAddr>, http2: bool) -> HttpChannel<T, H> io: T, peer: Option<SocketAddr>, http2: bool) -> HttpChannel<T, H>
{ {
h.add_channel(); settings.add_channel();
if http2 { if http2 {
HttpChannel { HttpChannel {
node: None, node: None,
proto: Some(HttpProtocol::H2( proto: Some(HttpProtocol::H2(
h2::Http2::new(h, io, peer, Bytes::new()))) } h2::Http2::new(settings, io, peer, Bytes::new()))) }
} else { } else {
HttpChannel { HttpChannel {
node: None, node: None,
proto: Some(HttpProtocol::H1( proto: Some(HttpProtocol::H1(h1::Http1::new(settings, io, peer))) }
h1::Http1::new(h, io, peer))) }
} }
} }
@ -58,14 +57,6 @@ impl<T, H> HttpChannel<T, H>
} }
} }
/*impl<T, H> Drop for HttpChannel<T, H>
where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static
{
fn drop(&mut self) {
println!("Drop http channel");
}
}*/
impl<T, H> Future for HttpChannel<T, H> impl<T, H> Future for HttpChannel<T, H>
where T: IoStream, H: HttpHandler + 'static where T: IoStream, H: HttpHandler + 'static
{ {