diff --git a/src/message.rs b/src/message.rs index f9dfe973..bf465ee6 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,7 +3,7 @@ use std::collections::VecDeque; use std::rc::Rc; use crate::extensions::Extensions; -use crate::http::{HeaderMap, Method, StatusCode, Uri, Version}; +use crate::http::{header, HeaderMap, Method, StatusCode, Uri, Version}; /// Represents various types of connection #[derive(Copy, Clone, PartialEq, Debug)] @@ -33,7 +33,15 @@ pub trait Head: Default + 'static { fn set_connection_type(&mut self, ctype: ConnectionType); fn upgrade(&self) -> bool { - self.connection_type() == ConnectionType::Upgrade + if let Some(hdr) = self.headers().get(header::CONNECTION) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("upgrade") + } else { + false + } + } else { + false + } } /// Check if keep-alive is enabled diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 3d3f5b92..88fabde9 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -132,7 +132,7 @@ pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> { // Check for "UPGRADE" to websocket header let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { if let Ok(s) = hdr.to_str() { - s.to_lowercase().contains("websocket") + s.to_ascii_lowercase().contains("websocket") } else { false }