diff --git a/Cargo.toml b/Cargo.toml index fdafc6d21..6d85a8159 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,9 +30,6 @@ sha1 = "0.2" url = "1.5" route-recognizer = "0.1" -hyper = "0.11" -unicase = "2.0" - # tokio bytes = "0.4" futures = "0.1" diff --git a/src/application.rs b/src/application.rs index 797be7ccc..8df6c1482 100644 --- a/src/application.rs +++ b/src/application.rs @@ -16,6 +16,7 @@ use httpmessage::HttpRequest; pub struct Application { state: S, default: Resource, + handlers: HashMap>>, resources: HashMap>, } @@ -23,6 +24,7 @@ impl Application where S: 'static { pub(crate) fn prepare(self, prefix: String) -> Box { let mut router = Router::new(); + let mut handlers = HashMap::new(); let prefix = if prefix.ends_with('/') {prefix } else { prefix + "/" }; for (path, handler) in self.resources { @@ -30,10 +32,16 @@ impl Application where S: 'static router.add(path.as_str(), handler); } + for (path, mut handler) in self.handlers { + let path = prefix.clone() + path.trim_left_matches('/'); + handler.set_prefix(path.clone()); + handlers.insert(path, handler); + } Box::new( InnerApplication { state: Rc::new(self.state), default: self.default, + handlers: handlers, router: router } ) } @@ -46,6 +54,7 @@ impl Default for Application<()> { Application { state: (), default: Resource::default(), + handlers: HashMap::new(), resources: HashMap::new(), } } @@ -60,6 +69,7 @@ impl Application where S: 'static { Application { state: state, default: Resource::default(), + handlers: HashMap::new(), resources: HashMap::new(), } } @@ -77,6 +87,20 @@ impl Application where S: 'static { self.resources.get_mut(&path).unwrap() } + /// Add path handler + pub fn add_handler(&mut self, path: P, h: H) + where H: RouteHandler + 'static, P: ToString + { + let path = path.to_string(); + + // add resource + if self.handlers.contains_key(&path) { + panic!("Handler already registered: {:?}", path); + } + + self.handlers.insert(path, Box::new(h)); + } + /// Default resource is used if no matches route could be found. pub fn default_resource(&mut self) -> &mut Resource { &mut self.default @@ -88,6 +112,7 @@ pub(crate) struct InnerApplication { state: Rc, default: Resource, + handlers: HashMap>>, router: Router>, } @@ -98,6 +123,11 @@ impl Handler for InnerApplication { if let Ok(h) = self.router.recognize(req.path()) { h.handler.handle(req.with_params(h.params), payload, Rc::clone(&self.state)) } else { + for (prefix, handler) in &self.handlers { + if req.path().starts_with(prefix) { + return handler.handle(req, payload, Rc::clone(&self.state)) + } + } self.default.handle(req, payload, Rc::clone(&self.state)) } } diff --git a/src/context.rs b/src/context.rs index 4b03c2b57..0f50e24a6 100644 --- a/src/context.rs +++ b/src/context.rs @@ -69,7 +69,7 @@ impl AsyncContextApi for HttpContext where A: Actor + Rou impl HttpContext where A: Actor + Route { - pub(crate) fn new(state: Rc<::State>) -> HttpContext + pub fn new(state: Rc<::State>) -> HttpContext { HttpContext { act: None, diff --git a/src/dev.rs b/src/dev.rs new file mode 100644 index 000000000..c533bfd64 --- /dev/null +++ b/src/dev.rs @@ -0,0 +1,21 @@ +//! The `actix-http` prelude for library developers +//! +//! The purpose of this module is to alleviate imports of many common actix traits +//! by adding a glob import to the top of actix heavy modules: +//! +//! ``` +//! # #![allow(unused_imports)] +//! use actix_http::dev::*; +//! ``` +pub use ws; +pub use httpcodes; +pub use application::Application; +pub use httpmessage::{HttpRequest, HttpResponse, IntoHttpResponse}; +pub use payload::{Payload, PayloadItem}; +pub use router::RoutingMap; +pub use resource::{Reply, Resource}; +pub use route::{Route, RouteFactory, RouteHandler}; +pub use server::HttpServer; +pub use context::HttpContext; +pub use task::Task; +pub use route_recognizer::Params; diff --git a/src/httpmessage.rs b/src/httpmessage.rs index 754ad5e1c..ac8846be5 100644 --- a/src/httpmessage.rs +++ b/src/httpmessage.rs @@ -1,13 +1,10 @@ //! Pieces pertaining to the HTTP message protocol. use std::{io, mem}; -use std::str::FromStr; use std::convert::Into; use bytes::Bytes; -use http::{Method, StatusCode, Version, Uri}; -use hyper::header::{Header, Headers}; -use hyper::header::{Connection, ConnectionOption, - Expect, Encoding, ContentLength, TransferEncoding}; +use http::{Method, StatusCode, Version, Uri, HeaderMap}; +use http::header::{self, HeaderName, HeaderValue}; use Params; use error::Error; @@ -23,43 +20,44 @@ pub trait Message { fn version(&self) -> Version; - fn headers(&self) -> &Headers; + fn headers(&self) -> &HeaderMap; /// Checks if a connection should be kept alive. - fn should_keep_alive(&self) -> bool { - let ret = match (self.version(), self.headers().get::()) { - (Version::HTTP_10, None) => false, - (Version::HTTP_10, Some(conn)) - if !conn.contains(&ConnectionOption::KeepAlive) => false, - (Version::HTTP_11, Some(conn)) - if conn.contains(&ConnectionOption::Close) => false, - _ => true - }; - trace!("should_keep_alive(version={:?}, header={:?}) = {:?}", - self.version(), self.headers().get::(), ret); - ret + fn keep_alive(&self) -> bool { + if let Some(conn) = self.headers().get(header::CONNECTION) { + if let Ok(conn) = conn.to_str() { + if self.version() == Version::HTTP_10 && !conn.contains("keep-alive") { + false + } else if self.version() == Version::HTTP_11 && conn.contains("close") { + false + } else { + true + } + } else { + false + } + } else { + self.version() != Version::HTTP_10 + } } /// Checks if a connection is expecting a `100 Continue` before sending its body. #[inline] fn expecting_continue(&self) -> bool { - let ret = match (self.version(), self.headers().get::()) { - (Version::HTTP_11, Some(&Expect::Continue)) => true, - _ => false - }; - trace!("expecting_continue(version={:?}, header={:?}) = {:?}", - self.version(), self.headers().get::(), ret); - ret + if self.version() == Version::HTTP_11 { + if let Some(hdr) = self.headers().get(header::EXPECT) { + if let Ok(hdr) = hdr.to_str() { + return hdr.to_lowercase().contains("continue") + } + } + } + false } fn is_chunked(&self) -> Result { - if let Some(&TransferEncoding(ref encodings)) = self.headers().get() { - // https://tools.ietf.org/html/rfc7230#section-3.3.3 - // If Transfer-Encoding header is present, and 'chunked' is - // not the final encoding, and this is a Request, then it is - // mal-formed. A server should responsed with 400 Bad Request. - if encodings.last() == Some(&Encoding::Chunked) { - Ok(true) + if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { + if let Ok(s) = encodings.to_str() { + return Ok(s.to_lowercase().contains("chunked")) } else { debug!("request with transfer-encoding header, but not chunked, bad request"); Err(Error::Header) @@ -77,7 +75,7 @@ pub struct HttpRequest { version: Version, method: Method, uri: Uri, - headers: Headers, + headers: HeaderMap, params: Params, } @@ -85,7 +83,7 @@ impl Message for HttpRequest { fn version(&self) -> Version { self.version } - fn headers(&self) -> &Headers { + fn headers(&self) -> &HeaderMap { &self.headers } } @@ -93,7 +91,7 @@ impl Message for HttpRequest { impl HttpRequest { /// Construct a new Request. #[inline] - pub fn new(method: Method, uri: Uri, version: Version, headers: Headers) -> Self { + pub fn new(method: Method, uri: Uri, version: Version, headers: HeaderMap) -> Self { HttpRequest { method: method, uri: uri, @@ -113,7 +111,7 @@ impl HttpRequest { /// Read the Request headers. #[inline] - pub fn headers(&self) -> &Headers { &self.headers } + pub fn headers(&self) -> &HeaderMap { &self.headers } /// Read the Request method. #[inline] @@ -142,7 +140,7 @@ impl HttpRequest { /// Get a mutable reference to the Request headers. #[inline] - pub fn headers_mut(&mut self) -> &mut Headers { + pub fn headers_mut(&mut self) -> &mut HeaderMap { &mut self.headers } @@ -164,27 +162,13 @@ impl HttpRequest { } } - /// Is keepalive enabled by client? - pub fn keep_alive(&self) -> bool { - let ret = match (self.version(), self.headers().get::()) { - (Version::HTTP_10, None) => false, - (Version::HTTP_10, Some(conn)) - if !conn.contains(&ConnectionOption::KeepAlive) => false, - (Version::HTTP_11, Some(conn)) - if conn.contains(&ConnectionOption::Close) => false, - _ => true - }; - trace!("should_keep_alive(version={:?}, header={:?}) = {:?}", - self.version(), self.headers().get::(), ret); - ret - } - pub(crate) fn is_upgrade(&self) -> bool { - if let Some(&Connection(ref conn)) = self.headers().get() { - conn.contains(&ConnectionOption::from_str("upgrade").unwrap()) - } else { - false + if let Some(ref conn) = self.headers().get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + return s.to_lowercase().contains("upgrade") + } } + false } } @@ -225,12 +209,12 @@ pub trait IntoHttpResponse { pub struct HttpResponse { request: HttpRequest, pub version: Version, - pub headers: Headers, + pub headers: HeaderMap, pub status: StatusCode, reason: Option<&'static str>, body: Body, chunked: bool, - compression: Option, + // compression: Option, connection_type: Option, } @@ -238,7 +222,7 @@ impl Message for HttpResponse { fn version(&self) -> Version { self.version } - fn headers(&self) -> &Headers { + fn headers(&self) -> &HeaderMap { &self.headers } } @@ -256,7 +240,7 @@ impl HttpResponse { reason: None, body: body, chunked: false, - compression: None, + // compression: None, connection_type: None, } } @@ -275,13 +259,13 @@ impl HttpResponse { /// Get the headers from the response. #[inline] - pub fn headers(&self) -> &Headers { + pub fn headers(&self) -> &HeaderMap { &self.headers } /// Get a mutable reference to the headers. #[inline] - pub fn headers_mut(&mut self) -> &mut Headers { + pub fn headers_mut(&mut self) -> &mut HeaderMap { &mut self.headers } @@ -300,14 +284,14 @@ impl HttpResponse { /// Set a header and move the Response. #[inline] - pub fn set_header(mut self, header: H) -> Self { - self.headers.set(header); + pub fn set_header(mut self, name: HeaderName, value: HeaderValue) -> Self { + self.headers.insert(name, value); self } /// Set the headers. #[inline] - pub fn with_headers(mut self, headers: Headers) -> Self { + pub fn with_headers(mut self, headers: HeaderMap) -> Self { self.headers = headers; self } @@ -335,7 +319,7 @@ impl HttpResponse { if let Some(ConnectionType::KeepAlive) = self.connection_type { true } else { - self.request.should_keep_alive() + self.request.keep_alive() } } @@ -351,7 +335,7 @@ impl HttpResponse { /// Enables automatic chunked transfer encoding pub fn enable_chunked_encoding(&mut self) -> Result<(), io::Error> { - if self.headers.has::() { + if self.headers.contains_key(header::CONTENT_LENGTH) { Err(io::Error::new(io::ErrorKind::Other, "You can't enable chunked encoding when a content length is set")) } else { diff --git a/src/lib.rs b/src/lib.rs index 9df803f84..5bdcb0dba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,9 +11,6 @@ extern crate futures; extern crate tokio_core; extern crate tokio_io; extern crate tokio_proto; -#[macro_use] -extern crate hyper; -extern crate unicase; extern crate http; extern crate httparse; @@ -33,11 +30,11 @@ mod router; mod task; mod reader; mod server; - -pub mod ws; mod wsframe; mod wsproto; +pub mod ws; +pub mod dev; pub mod httpcodes; pub use application::Application; pub use httpmessage::{HttpRequest, HttpResponse, IntoHttpResponse}; diff --git a/src/reader.rs b/src/reader.rs index 9fb5e457b..c90592fc7 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,13 +1,12 @@ use std::{self, fmt, io, ptr}; use httparse; -use http::{Method, Version, Uri, HttpTryFrom}; -use bytes::{Bytes, BytesMut, BufMut}; +use http::{Method, Version, Uri, HttpTryFrom, HeaderMap}; +use http::header::{self, HeaderName, HeaderValue}; +use bytes::{BytesMut, BufMut}; use futures::{Async, Poll}; use tokio_io::AsyncRead; -use hyper::header::{Headers, ContentLength}; - use error::{Error, Result}; use decode::Decoder; use payload::{Payload, PayloadSender}; @@ -50,8 +49,7 @@ impl Reader { b'\r' | b'\n' => i += 1, _ => break, } - } - self.read_buf.split_to(i); + } self.read_buf.split_to(i); } } @@ -82,13 +80,10 @@ impl Reader { pub fn parse(&mut self, io: &mut T) -> Poll<(HttpRequest, Payload), Error> where T: AsyncRead { - - loop { match self.decode()? { Decoding::Paused => return Ok(Async::NotReady), Decoding::Ready => { - println!("decode ready"); self.payload = None; break }, @@ -117,7 +112,6 @@ impl Reader { Decoding::Paused => break, Decoding::Ready => { - println!("decoded 3"); self.payload = None; break }, @@ -238,38 +232,56 @@ pub fn parse(buf: &mut BytesMut) -> Result) } }; - let mut headers = Headers::with_capacity(headers_len); let slice = buf.split_to(len).freeze(); let path = slice.slice(path.0, path.1); // path was found to be utf8 by httparse let uri = Uri::from_shared(path).map_err(|_| Error::Uri)?; - headers.extend(HeadersAsBytesIter { - headers: headers_indices[..headers_len].iter(), - slice: slice, - }); + // convert headers + let mut headers = HeaderMap::with_capacity(headers_len); + for header in headers_indices[..headers_len].iter() { + if let Ok(name) = HeaderName::try_from(slice.slice(header.name.0, header.name.1)) { + if let Ok(value) = HeaderValue::try_from( + slice.slice(header.value.0, header.value.1)) + { + headers.insert(name, value); + } else { + return Err(Error::Header) + } + } else { + return Err(Error::Header) + } + } let msg = HttpRequest::new(method, uri, version, headers); let upgrade = msg.is_upgrade() || *msg.method() == Method::CONNECT; let chunked = msg.is_chunked()?; - if upgrade { - Ok(Some((msg, Some(Decoder::eof())))) + let decoder = if upgrade { + Some(Decoder::eof()) } // Content-Length - else if let Some(&ContentLength(len)) = msg.headers().get() { + else if let Some(ref len) = msg.headers().get(header::CONTENT_LENGTH) { if chunked { return Err(Error::Header) } - Ok(Some((msg, Some(Decoder::length(len))))) - } else if msg.headers().has::() { - debug!("illegal Content-Length: {:?}", msg.headers().get_raw("Content-Length")); - Err(Error::Header) + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + Some(Decoder::length(len)) + } else { + debug!("illegal Content-Length: {:?}", len); + return Err(Error::Header) + } + } else { + debug!("illegal Content-Length: {:?}", len); + return Err(Error::Header) + } } else if chunked { - Ok(Some((msg, Some(Decoder::chunked())))) + Some(Decoder::chunked()) } else { - Ok(Some((msg, None))) - } + None + }; + Ok(Some((msg, decoder))) } #[derive(Clone, Copy)] @@ -292,24 +304,3 @@ fn record_header_indices(bytes: &[u8], indices.value = (value_start, value_end); } } - -struct HeadersAsBytesIter<'a> { - headers: ::std::slice::Iter<'a, HeaderIndices>, - slice: Bytes, -} - -impl<'a> Iterator for HeadersAsBytesIter<'a> { - type Item = (&'a str, Bytes); - fn next(&mut self) -> Option { - self.headers.next().map(|header| { - let name = unsafe { - let bytes = ::std::slice::from_raw_parts( - self.slice.as_ref().as_ptr().offset(header.name.0 as isize), - header.name.1 - header.name.0 - ); - ::std::str::from_utf8_unchecked(bytes) - }; - (name, self.slice.slice(header.value.0, header.value.1)) - }) - } -} diff --git a/src/resource.rs b/src/resource.rs index b7cf6e4f8..9ce1f25d8 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -66,25 +66,29 @@ impl Resource where S: 'static { } /// Handler for `GET` method. - pub fn get(&mut self) -> &mut Self where A: Route + pub fn get(&mut self) -> &mut Self + where A: Actor> + Route { self.handler(Method::GET, A::factory()) } /// Handler for `POST` method. - pub fn post(&mut self) -> &mut Self where A: Route + pub fn post(&mut self) -> &mut Self + where A: Actor> + Route { self.handler(Method::POST, A::factory()) } /// Handler for `PUR` method. - pub fn put(&mut self) -> &mut Self where A: Route + pub fn put(&mut self) -> &mut Self + where A: Actor> + Route { self.handler(Method::PUT, A::factory()) } /// Handler for `METHOD` method. - pub fn delete(&mut self) -> &mut Self where A: Route + pub fn delete(&mut self) -> &mut Self + where A: Actor> + Route { self.handler(Method::DELETE, A::factory()) } @@ -104,15 +108,15 @@ impl RouteHandler for Resource { #[cfg_attr(feature="cargo-clippy", allow(large_enum_variant))] -enum ReplyItem where A: Actor> + Route { +enum ReplyItem where A: Actor + Route { Message(HttpResponse), Actor(A), } /// Represents response process. -pub struct Reply> + Route> (ReplyItem); +pub struct Reply (ReplyItem); -impl Reply where A: Actor> + Route +impl Reply where A: Actor + Route { /// Create async response pub fn stream(act: A) -> Self { @@ -129,7 +133,8 @@ impl Reply where A: Actor> + Route Reply(ReplyItem::Message(msg.response(req))) } - pub(crate) fn into(self, mut ctx: HttpContext) -> Task { + pub fn into(self, mut ctx: HttpContext) -> Task where A: Actor> + { match self.0 { ReplyItem::Message(msg) => { Task::reply(msg) diff --git a/src/route.rs b/src/route.rs index e6836415b..8e38c554a 100644 --- a/src/route.rs +++ b/src/route.rs @@ -20,11 +20,15 @@ pub enum Frame { /// Trait defines object that could be regestered as resource route pub trait RouteHandler: 'static { + /// Handle request fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task; + + /// Set route prefix + fn set_prefix(&mut self, _prefix: String) {} } /// Actors with ability to handle http requests -pub trait Route: Actor> { +pub trait Route: Actor { /// Route shared state. State is shared with all routes within same application and could be /// accessed with `HttpContext::state()` method. type State; @@ -33,7 +37,7 @@ pub trait Route: Actor> { /// result immediately with `Reply::reply` or `Reply::with`. /// Actor itself could be returned for handling streaming request/response. /// In that case `HttpContext::start` and `HttpContext::write` has to be used. - fn request(req: HttpRequest, payload: Payload, ctx: &mut HttpContext) -> Reply; + fn request(req: HttpRequest, payload: Payload, ctx: &mut Self::Context) -> Reply; /// This method creates `RouteFactory` for this actor. fn factory() -> RouteFactory { @@ -45,7 +49,7 @@ pub trait Route: Actor> { pub struct RouteFactory, S>(PhantomData); impl RouteHandler for RouteFactory - where A: Route, + where A: Actor> + Route, S: 'static { fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task diff --git a/src/task.rs b/src/task.rs index 0360b19d2..dcf276374 100644 --- a/src/task.rs +++ b/src/task.rs @@ -4,14 +4,12 @@ use std::fmt::Write; use std::collections::VecDeque; use http::{StatusCode, Version}; +use http::header::{HeaderValue, + CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE}; use bytes::BytesMut; use futures::{Async, Future, Poll, Stream}; use tokio_core::net::TcpStream; -use unicase::Ascii; -use hyper::header::{Date, Connection, ConnectionOption, - ContentType, ContentLength, Encoding, TransferEncoding}; - use date; use route::Frame; use httpmessage::{Body, HttpResponse}; @@ -100,22 +98,26 @@ impl Task { if msg.chunked() { error!("Chunked transfer is enabled but body is set to Empty"); } - msg.headers.set(ContentLength(0)); - msg.headers.remove::(); + msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); + msg.headers.remove(TRANSFER_ENCODING); self.encoder = Encoder::length(0); }, Body::Length(n) => { if msg.chunked() { error!("Chunked transfer is enabled but body with specific length is specified"); } - msg.headers.set(ContentLength(n)); - msg.headers.remove::(); + msg.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); + msg.headers.remove(TRANSFER_ENCODING); self.encoder = Encoder::length(n); }, Body::Binary(ref bytes) => { extra = bytes.len(); - msg.headers.set(ContentLength(bytes.len() as u64)); - msg.headers.remove::(); + msg.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); + msg.headers.remove(TRANSFER_ENCODING); self.encoder = Encoder::length(0); } Body::Streaming => { @@ -123,16 +125,15 @@ impl Task { if msg.version < Version::HTTP_11 { error!("Chunked transfer encoding is forbidden for {:?}", msg.version); } - msg.headers.remove::(); - msg.headers.set(TransferEncoding(vec![Encoding::Chunked])); + msg.headers.remove(CONTENT_LENGTH); + msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); self.encoder = Encoder::chunked(); } else { self.encoder = Encoder::eof(); } } Body::Upgrade => { - msg.headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(Ascii::new("upgrade".to_owned()))])); + msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); self.encoder = Encoder::eof(); } } @@ -140,10 +141,10 @@ impl Task { // keep-alive if msg.keep_alive() { if msg.version < Version::HTTP_11 { - msg.headers.set(Connection::keep_alive()); + msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive")); } } else if msg.version >= Version::HTTP_11 { - msg.headers.set(Connection::close()); + msg.headers.insert(CONNECTION, HeaderValue::from_static("close")); } // render message @@ -152,14 +153,20 @@ impl Task { if msg.version == Version::HTTP_11 && msg.status == StatusCode::OK { self.buffer.extend(b"HTTP/1.1 200 OK\r\n"); - let _ = write!(self.buffer, "{}", msg.headers); } else { - let _ = write!(self.buffer, "{:?} {}\r\n{}", msg.version, msg.status, msg.headers); + let _ = write!(self.buffer, "{:?} {}\r\n", msg.version, msg.status); + } + for (key, value) in &msg.headers { + let t: &[u8] = key.as_ref(); + self.buffer.extend(t); + self.buffer.extend(b": "); + self.buffer.extend(value.as_ref()); + self.buffer.extend(b"\r\n"); } // using http::h1::date is quite a lot faster than generating // a unique Date header each time like req/s goes up about 10% - if !msg.headers.has::() { + if !msg.headers.contains_key(DATE) { self.buffer.reserve(date::DATE_VALUE_LENGTH + 8); self.buffer.extend(b"Date: "); date::extend(&mut self.buffer); @@ -167,7 +174,7 @@ impl Task { } // default content-type - if !msg.headers.has::() { + if !msg.headers.contains_key(CONTENT_TYPE) { self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref()); } diff --git a/src/ws.rs b/src/ws.rs index 59442c301..aabebe35f 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -64,10 +64,10 @@ //! fn main() {} //! ``` use std::vec::Vec; -use http::{Method, StatusCode}; +use std::str::FromStr; +use http::{Method, StatusCode, header}; use bytes::{Bytes, BytesMut}; use futures::{Async, Poll, Stream}; -use hyper::header; use actix::Actor; @@ -81,22 +81,13 @@ use wsframe; use wsproto::*; #[doc(hidden)] -header! { - /// SEC-WEBSOCKET-ACCEPT header - (WebSocketAccept, "SEC-WEBSOCKET-ACCEPT") => [String] -} -header! { - /// SEC-WEBSOCKET-KEY header - (WebSocketKey, "SEC-WEBSOCKET-KEY") => [String] -} -header! { - /// SEC-WEBSOCKET-VERSION header - (WebSocketVersion, "SEC-WEBSOCKET-VERSION") => [String] -} -header! { - /// SEC-WEBSOCKET-PROTOCOL header - (WebSocketProtocol, "SEC-WEBSOCKET-PROTOCOL") => [String] -} +const SEC_WEBSOCKET_ACCEPT: &'static str = "SEC-WEBSOCKET-ACCEPT"; +#[doc(hidden)] +const SEC_WEBSOCKET_KEY: &'static str = "SEC-WEBSOCKET-KEY"; +#[doc(hidden)] +const SEC_WEBSOCKET_VERSION: &'static str = "SEC-WEBSOCKET-VERSION"; +// #[doc(hidden)] +// const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL"; /// `WebSocket` Message @@ -126,8 +117,12 @@ pub fn handshake(req: HttpRequest) -> Result { } // Check for "UPGRADE" to websocket header - let has_hdr = if let Some::<&header::Upgrade>(hdr) = req.headers().get() { - hdr.0.contains(&header::Protocol::new(header::ProtocolName::WebSocket, None)) + let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_lowercase().contains("websocket") + } else { + false + } } else { false }; @@ -141,14 +136,14 @@ pub fn handshake(req: HttpRequest) -> Result { } // check supported version - if !req.headers().has::() { + if !req.headers().contains_key(SEC_WEBSOCKET_VERSION) { return Err(HTTPBadRequest.with_reason(req, "No websocket version header is required")) } let supported_ver = { - let hdr = req.headers().get::().unwrap(); - match hdr.0.as_str() { - "13" | "8" | "7" => true, - _ => false, + if let Some(hdr) = req.headers().get(SEC_WEBSOCKET_VERSION) { + hdr == "13" || hdr == "8" || hdr == "7" + } else { + false } }; if !supported_ver { @@ -156,25 +151,20 @@ pub fn handshake(req: HttpRequest) -> Result { } // check client handshake for validity - let key = if let Some::<&WebSocketKey>(hdr) = req.headers().get() { - Some(hash_key(hdr.0.as_bytes())) - } else { - None - }; - let key = if let Some(key) = key { - key - } else { + if !req.headers().contains_key(SEC_WEBSOCKET_KEY) { return Err(HTTPBadRequest.with_reason(req, "Handshake error")); + } + let key = { + let key = req.headers().get(SEC_WEBSOCKET_KEY).unwrap(); + hash_key(key.as_ref()) }; Ok(HttpResponse::new(req, StatusCode::SWITCHING_PROTOCOLS, Body::Empty) .set_connection_type(ConnectionType::Upgrade) - .set_header( - header::Upgrade(vec![header::Protocol::new(header::ProtocolName::WebSocket, None)])) - .set_header( - header::TransferEncoding(vec![header::Encoding::Chunked])) - .set_header( - WebSocketAccept(key)) + .set_header(header::UPGRADE, header::HeaderValue::from_static("websocket")) + .set_header(header::TRANSFER_ENCODING, header::HeaderValue::from_static("chunked")) + .set_header(header::HeaderName::from_str(SEC_WEBSOCKET_ACCEPT).unwrap(), + header::HeaderValue::from_str(key.as_str()).unwrap()) .set_body(Body::Upgrade) ) }