mirror of
https://github.com/fafhrd91/actix-web
synced 2024-11-30 18:44:35 +01:00
fix cors allowed header validation
This commit is contained in:
parent
11342e4566
commit
dab918261c
@ -109,7 +109,7 @@ pub enum CorsBuilderError {
|
||||
impl ResponseError for CorsError {
|
||||
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
HTTPBadRequest.into()
|
||||
HTTPBadRequest.build().body(format!("{}", self)).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
@ -159,7 +159,7 @@ impl<T> AllOrSome<T> {
|
||||
/// for responses to be generated.
|
||||
pub struct Cors {
|
||||
methods: HashSet<Method>,
|
||||
origins: AllOrSome<HashSet<Uri>>,
|
||||
origins: AllOrSome<HashSet<String>>,
|
||||
origins_str: Option<HeaderValue>,
|
||||
headers: AllOrSome<HashSet<HeaderName>>,
|
||||
expose_hdrs: Option<String>,
|
||||
@ -225,17 +225,15 @@ impl Cors {
|
||||
fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
|
||||
if let Some(hdr) = req.headers().get(header::ORIGIN) {
|
||||
if let Ok(origin) = hdr.to_str() {
|
||||
if let Ok(uri) = Uri::try_from(origin) {
|
||||
return match self.origins {
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_origins) => {
|
||||
allowed_origins
|
||||
.get(&uri)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| CorsError::OriginNotAllowed)
|
||||
}
|
||||
};
|
||||
}
|
||||
return match self.origins {
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_origins) => {
|
||||
allowed_origins
|
||||
.get(origin)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| CorsError::OriginNotAllowed)
|
||||
}
|
||||
};
|
||||
}
|
||||
Err(CorsError::BadOrigin)
|
||||
} else {
|
||||
@ -262,11 +260,11 @@ impl Cors {
|
||||
}
|
||||
|
||||
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
|
||||
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
|
||||
if let Ok(headers) = hdr.to_str() {
|
||||
match self.headers {
|
||||
AllOrSome::All => return Ok(()),
|
||||
AllOrSome::Some(ref allowed_headers) => {
|
||||
match self.headers {
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_headers) => {
|
||||
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
|
||||
if let Ok(headers) = hdr.to_str() {
|
||||
let mut hdrs = HashSet::new();
|
||||
for hdr in headers.split(',') {
|
||||
match HeaderName::try_from(hdr.trim()) {
|
||||
@ -280,11 +278,11 @@ impl Cors {
|
||||
}
|
||||
return Ok(())
|
||||
}
|
||||
Err(CorsError::BadRequestHeaders)
|
||||
} else {
|
||||
Err(CorsError::MissingRequestHeaders)
|
||||
}
|
||||
}
|
||||
Err(CorsError::BadRequestHeaders)
|
||||
} else {
|
||||
Err(CorsError::MissingRequestHeaders)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -437,17 +435,15 @@ impl CorsBuilder {
|
||||
///
|
||||
/// Defaults to `All`.
|
||||
/// ```
|
||||
pub fn allowed_origin<U>(&mut self, origin: U) -> &mut CorsBuilder
|
||||
where Uri: HttpTryFrom<U>
|
||||
{
|
||||
pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder {
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
match Uri::try_from(origin) {
|
||||
Ok(uri) => {
|
||||
Ok(_) => {
|
||||
if cors.origins.is_all() {
|
||||
cors.origins = AllOrSome::Some(HashSet::new());
|
||||
}
|
||||
if let AllOrSome::Some(ref mut origins) = cors.origins {
|
||||
origins.insert(uri);
|
||||
origins.insert(origin.to_owned());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
@ -26,20 +26,19 @@ pub struct HttpChannel<T, H>
|
||||
impl<T, H> HttpChannel<T, H>
|
||||
where T: IoStream, H: HttpHandler + 'static
|
||||
{
|
||||
pub(crate) fn new(h: Rc<WorkerSettings<H>>,
|
||||
pub(crate) fn new(settings: Rc<WorkerSettings<H>>,
|
||||
io: T, peer: Option<SocketAddr>, http2: bool) -> HttpChannel<T, H>
|
||||
{
|
||||
h.add_channel();
|
||||
settings.add_channel();
|
||||
if http2 {
|
||||
HttpChannel {
|
||||
node: None,
|
||||
proto: Some(HttpProtocol::H2(
|
||||
h2::Http2::new(h, io, peer, Bytes::new()))) }
|
||||
h2::Http2::new(settings, io, peer, Bytes::new()))) }
|
||||
} else {
|
||||
HttpChannel {
|
||||
node: None,
|
||||
proto: Some(HttpProtocol::H1(
|
||||
h1::Http1::new(h, io, peer))) }
|
||||
proto: Some(HttpProtocol::H1(h1::Http1::new(settings, io, peer))) }
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,14 +57,6 @@ impl<T, H> HttpChannel<T, H>
|
||||
}
|
||||
}
|
||||
|
||||
/*impl<T, H> Drop for HttpChannel<T, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
println!("Drop http channel");
|
||||
}
|
||||
}*/
|
||||
|
||||
impl<T, H> Future for HttpChannel<T, H>
|
||||
where T: IoStream, H: HttpHandler + 'static
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user