diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index d5f8f969..49f74c72 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -109,7 +109,7 @@ pub enum CorsBuilderError { impl ResponseError for CorsError { fn error_response(&self) -> HttpResponse { - HTTPBadRequest.into() + HTTPBadRequest.build().body(format!("{}", self)).unwrap() } } @@ -159,7 +159,7 @@ impl AllOrSome { /// for responses to be generated. pub struct Cors { methods: HashSet, - origins: AllOrSome>, + origins: AllOrSome>, origins_str: Option, headers: AllOrSome>, expose_hdrs: Option, @@ -225,17 +225,15 @@ impl Cors { fn validate_origin(&self, req: &mut HttpRequest) -> Result<(), CorsError> { if let Some(hdr) = req.headers().get(header::ORIGIN) { if let Ok(origin) = hdr.to_str() { - if let Ok(uri) = Uri::try_from(origin) { - return match self.origins { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_origins) => { - allowed_origins - .get(&uri) - .and_then(|_| Some(())) - .ok_or_else(|| CorsError::OriginNotAllowed) - } - }; - } + return match self.origins { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_origins) => { + allowed_origins + .get(origin) + .and_then(|_| Some(())) + .ok_or_else(|| CorsError::OriginNotAllowed) + } + }; } Err(CorsError::BadOrigin) } else { @@ -262,11 +260,11 @@ impl Cors { } fn validate_allowed_headers(&self, req: &mut HttpRequest) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { - if let Ok(headers) = hdr.to_str() { - match self.headers { - AllOrSome::All => return Ok(()), - AllOrSome::Some(ref allowed_headers) => { + match self.headers { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_headers) => { + if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { + if let Ok(headers) = hdr.to_str() { let mut hdrs = HashSet::new(); for hdr in headers.split(',') { match HeaderName::try_from(hdr.trim()) { @@ -280,11 +278,11 @@ impl Cors { } return Ok(()) } + Err(CorsError::BadRequestHeaders) + } else { + Err(CorsError::MissingRequestHeaders) } } - Err(CorsError::BadRequestHeaders) - } else { - Err(CorsError::MissingRequestHeaders) } } } @@ -437,17 +435,15 @@ impl CorsBuilder { /// /// Defaults to `All`. /// ``` - pub fn allowed_origin(&mut self, origin: U) -> &mut CorsBuilder - where Uri: HttpTryFrom - { + pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder { if let Some(cors) = cors(&mut self.cors, &self.error) { match Uri::try_from(origin) { - Ok(uri) => { + Ok(_) => { if cors.origins.is_all() { cors.origins = AllOrSome::Some(HashSet::new()); } if let AllOrSome::Some(ref mut origins) = cors.origins { - origins.insert(uri); + origins.insert(origin.to_owned()); } } Err(e) => { diff --git a/src/server/channel.rs b/src/server/channel.rs index aadbe853..6ea14c45 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -26,20 +26,19 @@ pub struct HttpChannel impl HttpChannel where T: IoStream, H: HttpHandler + 'static { - pub(crate) fn new(h: Rc>, + pub(crate) fn new(settings: Rc>, io: T, peer: Option, http2: bool) -> HttpChannel { - h.add_channel(); + settings.add_channel(); if http2 { HttpChannel { node: None, proto: Some(HttpProtocol::H2( - h2::Http2::new(h, io, peer, Bytes::new()))) } + h2::Http2::new(settings, io, peer, Bytes::new()))) } } else { HttpChannel { node: None, - proto: Some(HttpProtocol::H1( - h1::Http1::new(h, io, peer))) } + proto: Some(HttpProtocol::H1(h1::Http1::new(settings, io, peer))) } } } @@ -58,14 +57,6 @@ impl HttpChannel } } -/*impl Drop for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static -{ - fn drop(&mut self) { - println!("Drop http channel"); - } -}*/ - impl Future for HttpChannel where T: IoStream, H: HttpHandler + 'static {