1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-08-02 21:35:42 +02:00

Compare commits

..

12 Commits

Author SHA1 Message Date
Nikolay Kim
6acb6dd4e7 set release date 2018-03-02 22:31:58 -08:00
Nikolay Kim
791a980e2d update tests 2018-03-02 22:08:56 -08:00
Nikolay Kim
c2d8abcee7 Fix disconnect on idle connections 2018-03-02 20:47:23 -08:00
Nikolay Kim
16c05f07ba make HttpRequest::match_info_mut() public 2018-03-02 20:40:08 -08:00
Nikolay Kim
2158ad29ee add Pattern::with_prefix, make it usable outside of actix 2018-03-02 20:39:22 -08:00
Nikolay Kim
feba5aeffd bump version 2018-03-02 14:31:23 -08:00
Nikolay Kim
343888017e Update CHANGES.md 2018-03-02 12:26:31 -08:00
Nikolay Kim
3a5d445b2f Merge pull request #89 from niklasf/csrf-middleware
add csrf filter middleware
2018-03-02 12:25:23 -08:00
Nikolay Kim
e60acb7607 Merge branch 'master' into csrf-middleware 2018-03-02 12:25:05 -08:00
Nikolay Kim
bebfc6c9b5 sleep for test 2018-03-02 11:32:37 -08:00
Nikolay Kim
3b2928a391 Better naming for websockets implementation 2018-03-02 11:29:55 -08:00
Niklas Fiekas
10f57dac31 add csrf filter middleware 2018-03-02 20:13:43 +01:00
24 changed files with 478 additions and 157 deletions

View File

@@ -47,7 +47,7 @@ script:
USE_SKEPTIC=1 cargo test --features=alpn USE_SKEPTIC=1 cargo test --features=alpn
else else
cargo clean cargo clean
cargo test cargo test -- --nocapture
# --features=alpn # --features=alpn
fi fi

View File

@@ -1,5 +1,16 @@
# Changes # Changes
## 0.4.2 (2018-03-02)
* Better naming for websockets implementation
* Add `Pattern::with_prefix()`, make it more usable outside of actix
* Add csrf middleware for filter for cross-site request forgery #89
* Fix disconnect on idle connections
## 0.4.1 (2018-03-01) ## 0.4.1 (2018-03-01)
* Rename `Route::p()` to `Route::filter()` * Rename `Route::p()` to `Route::filter()`

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-web" name = "actix-web"
version = "0.4.1" version = "0.4.2"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix web is a small, pragmatic, extremely fast, web framework for Rust." description = "Actix web is a small, pragmatic, extremely fast, web framework for Rust."
readme = "README.md" readme = "README.md"
@@ -81,7 +81,7 @@ openssl = { version="0.10", optional = true }
tokio-openssl = { version="0.2", optional = true } tokio-openssl = { version="0.2", optional = true }
[dependencies.actix] [dependencies.actix]
version = "0.5" version = "^0.5.1"
[dev-dependencies] [dev-dependencies]
env_logger = "0.5" env_logger = "0.5"

View File

@@ -139,7 +139,7 @@ fn main() {
// default // default
.default_resource(|r| { .default_resource(|r| {
r.method(Method::GET).f(p404); r.method(Method::GET).f(p404);
r.route().p(pred::Not(pred::Get())).f(|req| httpcodes::HTTPMethodNotAllowed); r.route().filter(pred::Not(pred::Get())).f(|req| httpcodes::HTTPMethodNotAllowed);
})) }))
.bind("127.0.0.1:8080").expect("Can not bind to 127.0.0.1:8080") .bind("127.0.0.1:8080").expect("Can not bind to 127.0.0.1:8080")

View File

@@ -36,7 +36,7 @@ impl Actor for MyWebSocket {
type Context = ws::WebsocketContext<Self, AppState>; type Context = ws::WebsocketContext<Self, AppState>;
} }
impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket { impl StreamHandler<ws::Message, ws::ProtocolError> for MyWebSocket {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
self.counter += 1; self.counter += 1;

View File

@@ -92,7 +92,7 @@ impl Handler<session::Message> for WsChatSession {
} }
/// WebSocket message handler /// WebSocket message handler
impl StreamHandler<ws::Message, ws::WsError> for WsChatSession { impl StreamHandler<ws::Message, ws::ProtocolError> for WsChatSession {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
println!("WEBSOCKET MESSAGE: {:?}", msg); println!("WEBSOCKET MESSAGE: {:?}", msg);

View File

@@ -12,7 +12,7 @@ use std::time::Duration;
use actix::*; use actix::*;
use futures::Future; use futures::Future;
use actix_web::ws::{Message, WsError, WsClient, WsClientWriter}; use actix_web::ws::{Message, ProtocolError, Client, ClientWriter};
fn main() { fn main() {
@@ -21,7 +21,7 @@ fn main() {
let sys = actix::System::new("ws-example"); let sys = actix::System::new("ws-example");
Arbiter::handle().spawn( Arbiter::handle().spawn(
WsClient::new("http://127.0.0.1:8080/ws/") Client::new("http://127.0.0.1:8080/ws/")
.connect() .connect()
.map_err(|e| { .map_err(|e| {
println!("Error: {}", e); println!("Error: {}", e);
@@ -53,7 +53,7 @@ fn main() {
} }
struct ChatClient(WsClientWriter); struct ChatClient(ClientWriter);
#[derive(Message)] #[derive(Message)]
struct ClientCommand(String); struct ClientCommand(String);
@@ -93,7 +93,7 @@ impl Handler<ClientCommand> for ChatClient {
} }
/// Handle server websocket messages /// Handle server websocket messages
impl StreamHandler<Message, WsError> for ChatClient { impl StreamHandler<Message, ProtocolError> for ChatClient {
fn handle(&mut self, msg: Message, ctx: &mut Context<Self>) { fn handle(&mut self, msg: Message, ctx: &mut Context<Self>) {
match msg { match msg {

View File

@@ -25,7 +25,7 @@ impl Actor for MyWebSocket {
} }
/// Handler for `ws::Message` /// Handler for `ws::Message`
impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket { impl StreamHandler<ws::Message, ws::ProtocolError> for MyWebSocket {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
// process websocket messages // process websocket messages

View File

@@ -22,7 +22,7 @@ impl Actor for Ws {
} }
/// Handler for ws::Message message /// Handler for ws::Message message
impl StreamHandler<ws::Message, ws::WsError> for Ws { impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {

View File

@@ -39,10 +39,8 @@ impl HttpResponseParser {
// if buf is empty parse_message will always return NotReady, let's avoid that // if buf is empty parse_message will always return NotReady, let's avoid that
let read = if buf.is_empty() { let read = if buf.is_empty() {
match utils::read_from_io(io, buf) { match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) => { Ok(Async::Ready(0)) =>
// debug!("Ignored premature client disconnection"); return Err(HttpResponseParserError::Disconnect),
return Err(HttpResponseParserError::Disconnect);
},
Ok(Async::Ready(_)) => (), Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => Ok(Async::NotReady) =>
return Ok(Async::NotReady), return Ok(Async::NotReady),
@@ -66,10 +64,12 @@ impl HttpResponseParser {
} }
if read || buf.remaining_mut() == 0 { if read || buf.remaining_mut() == 0 {
match utils::read_from_io(io, buf) { match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) => return Err(HttpResponseParserError::Disconnect), Ok(Async::Ready(0)) =>
return Err(HttpResponseParserError::Disconnect),
Ok(Async::Ready(_)) => (), Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => return Err(HttpResponseParserError::Error(err.into())), Err(err) =>
return Err(HttpResponseParserError::Error(err.into())),
} }
} else { } else {
return Ok(Async::NotReady) return Ok(Async::NotReady)
@@ -109,7 +109,8 @@ impl HttpResponseParser {
} }
} }
fn parse_message(buf: &mut BytesMut) -> Poll<(ClientResponse, Option<Decoder>), ParseError> fn parse_message(buf: &mut BytesMut)
-> Poll<(ClientResponse, Option<Decoder>), ParseError>
{ {
// Parse http message // Parse http message
let bytes_ptr = buf.as_ref().as_ptr() as usize; let bytes_ptr = buf.as_ref().as_ptr() as usize;

View File

@@ -111,7 +111,7 @@ impl Future for SendRequest {
_ => IoBody::Done, _ => IoBody::Done,
}; };
let mut pl = Box::new(Pipeline { let pl = Box::new(Pipeline {
body, conn, writer, body, conn, writer,
parser: Some(HttpResponseParser::default()), parser: Some(HttpResponseParser::default()),
parser_buf: BytesMut::new(), parser_buf: BytesMut::new(),

View File

@@ -365,7 +365,7 @@ impl<S> HttpRequest<S> {
/// Get mutable reference to request's Params. /// Get mutable reference to request's Params.
#[inline] #[inline]
pub(crate) fn match_info_mut(&mut self) -> &mut Params { pub fn match_info_mut(&mut self) -> &mut Params {
unsafe{ mem::transmute(&mut self.as_mut().params) } unsafe{ mem::transmute(&mut self.as_mut().params) }
} }

266
src/middleware/csrf.rs Normal file
View File

@@ -0,0 +1,266 @@
//! A filter for cross-site request forgery (CSRF).
//!
//! This middleware is stateless and [based on request
//! headers](https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)_Prevention_Cheat_Sheet#Verifying_Same_Origin_with_Standard_Headers).
//!
//! By default requests are allowed only if one of these is true:
//!
//! * The request method is safe (`GET`, `HEAD`, `OPTIONS`). It is the
//! applications responsibility to ensure these methods cannot be used to
//! execute unwanted actions. Note that upgrade requests for websockets are
//! also considered safe.
//! * The `Origin` header (added automatically by the browser) matches one
//! of the allowed origins.
//! * There is no `Origin` header but the `Referer` header matches one of
//! the allowed origins.
//!
//! Use [`CsrfFilterBuilder::allow_xhr()`](struct.CsrfFilterBuilder.html#method.allow_xhr)
//! if you want to allow requests with unsafe methods via
//! [CORS](../cors/struct.Cors.html).
//!
//! # Example
//!
//! ```
//! # extern crate actix_web;
//! # use actix_web::*;
//!
//! use actix_web::middleware::csrf;
//!
//! fn handle_post(_req: HttpRequest) -> &'static str {
//! "This action should only be triggered with requests from the same site"
//! }
//!
//! fn main() {
//! let app = Application::new()
//! .middleware(
//! csrf::CsrfFilter::build()
//! .allowed_origin("https://www.example.com")
//! .finish())
//! .resource("/", |r| {
//! r.method(Method::GET).f(|_| httpcodes::HttpOk);
//! r.method(Method::POST).f(handle_post);
//! })
//! .finish();
//! }
//! ```
//!
//! In this example the entire application is protected from CSRF.
use std::borrow::Cow;
use std::collections::HashSet;
use bytes::Bytes;
use error::{Result, ResponseError};
use http::{HeaderMap, HttpTryFrom, Uri, header};
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use httpmessage::HttpMessage;
use httpcodes::HttpForbidden;
use middleware::{Middleware, Started};
/// Potential cross-site request forgery detected.
#[derive(Debug, Fail)]
pub enum CsrfError {
/// The HTTP request header `Origin` was required but not provided.
#[fail(display="Origin header required")]
MissingOrigin,
/// The HTTP request header `Origin` could not be parsed correctly.
#[fail(display="Could not parse Origin header")]
BadOrigin,
/// The cross-site request was denied.
#[fail(display="Cross-site request denied")]
CsrDenied,
}
impl ResponseError for CsrfError {
fn error_response(&self) -> HttpResponse {
HttpForbidden.build().body(self.to_string()).unwrap()
}
}
fn uri_origin(uri: &Uri) -> Option<String> {
match (uri.scheme_part(), uri.host(), uri.port()) {
(Some(scheme), Some(host), Some(port)) => {
Some(format!("{}://{}:{}", scheme, host, port))
}
(Some(scheme), Some(host), None) => {
Some(format!("{}://{}", scheme, host))
}
_ => None
}
}
fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
headers.get(header::ORIGIN)
.map(|origin| {
origin
.to_str()
.map_err(|_| CsrfError::BadOrigin)
.map(|o| o.into())
})
.or_else(|| {
headers.get(header::REFERER)
.map(|referer| {
Uri::try_from(Bytes::from(referer.as_bytes()))
.ok()
.as_ref()
.and_then(uri_origin)
.ok_or(CsrfError::BadOrigin)
.map(|o| o.into())
})
})
}
/// A middleware that filters cross-site requests.
pub struct CsrfFilter {
origins: HashSet<String>,
allow_xhr: bool,
allow_missing_origin: bool,
}
impl CsrfFilter {
/// Start building a `CsrfFilter`.
pub fn build() -> CsrfFilterBuilder {
CsrfFilterBuilder {
cors: CsrfFilter {
origins: HashSet::new(),
allow_xhr: false,
allow_missing_origin: false,
}
}
}
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
Ok(())
} else if let Some(header) = origin(req.headers()) {
match header {
Ok(ref origin) if self.origins.contains(origin.as_ref()) => Ok(()),
Ok(_) => Err(CsrfError::CsrDenied),
Err(err) => Err(err),
}
} else if self.allow_missing_origin {
Ok(())
} else {
Err(CsrfError::MissingOrigin)
}
}
}
impl<S> Middleware<S> for CsrfFilter {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
self.validate(req)?;
Ok(Started::Done)
}
}
/// Used to build a `CsrfFilter`.
///
/// To construct a CSRF filter:
///
/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to
/// start building.
/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed
/// origins.
/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve
/// the constructed filter.
///
/// # Example
///
/// ```
/// use actix_web::middleware::csrf;
///
/// let csrf = csrf::CsrfFilter::build()
/// .allowed_origin("https://www.example.com")
/// .finish();
/// ```
pub struct CsrfFilterBuilder {
cors: CsrfFilter,
}
impl CsrfFilterBuilder {
/// Add an origin that is allowed to make requests. Will be verified
/// against the `Origin` request header.
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
self.cors.origins.insert(origin.to_owned());
self
}
/// Allow all requests with an `X-Requested-With` header.
///
/// A cross-site attacker should not be able to send requests with custom
/// headers unless a CORS policy whitelists them. Therefore it should be
/// safe to allow requests with an `X-Requested-With` header (added
/// automatically by many JavaScript libraries).
///
/// This is disabled by default, because in Safari it is possible to
/// circumvent this using redirects and Flash.
///
/// Use this method to enable more lax filtering.
pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
self.cors.allow_xhr = true;
self
}
/// Allow requests if the expected `Origin` header is missing (and
/// there is no `Referer` to fall back on).
///
/// The filter is conservative by default, but it should be safe to allow
/// missing `Origin` headers because a cross-site attacker cannot prevent
/// the browser from sending `Origin` on unsafe requests.
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
self.cors.allow_missing_origin = true;
self
}
/// Finishes building the `CsrfFilter` instance.
pub fn finish(self) -> CsrfFilter {
self.cors
}
}
#[cfg(test)]
mod tests {
use super::*;
use http::Method;
use test::TestRequest;
#[test]
fn test_safe() {
let csrf = CsrfFilter::build()
.allowed_origin("https://www.example.com")
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::HEAD)
.finish();
assert!(csrf.start(&mut req).is_ok());
}
#[test]
fn test_csrf() {
let csrf = CsrfFilter::build()
.allowed_origin("https://www.example.com")
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::POST)
.finish();
assert!(csrf.start(&mut req).is_err());
}
#[test]
fn test_referer() {
let csrf = CsrfFilter::build()
.allowed_origin("https://www.example.com")
.finish();
let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param")
.method(Method::POST)
.finish();
assert!(csrf.start(&mut req).is_ok());
}
}

View File

@@ -9,6 +9,7 @@ mod logger;
mod session; mod session;
mod defaultheaders; mod defaultheaders;
pub mod cors; pub mod cors;
pub mod csrf;
pub use self::logger::Logger; pub use self::logger::Logger;
pub use self::defaultheaders::{DefaultHeaders, DefaultHeadersBuilder}; pub use self::defaultheaders::{DefaultHeaders, DefaultHeadersBuilder};
pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage, pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage,

View File

@@ -168,7 +168,7 @@ impl Inner {
len: 0, len: 0,
err: None, err: None,
items: VecDeque::new(), items: VecDeque::new(),
need_read: false, need_read: true,
} }
} }

View File

@@ -173,7 +173,8 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
PipelineState::None => PipelineState::None =>
return Ok(Async::Ready(true)), return Ok(Async::Ready(true)),
PipelineState::Error => PipelineState::Error =>
return Err(io::Error::new(io::ErrorKind::Other, "Internal error").into()), return Err(io::Error::new(
io::ErrorKind::Other, "Internal error").into()),
_ => (), _ => (),
} }

View File

@@ -148,7 +148,12 @@ impl Pattern {
/// ///
/// Panics if path pattern is wrong. /// Panics if path pattern is wrong.
pub fn new(name: &str, path: &str) -> Self { pub fn new(name: &str, path: &str) -> Self {
let (pattern, elements, is_dynamic) = Pattern::parse(path); Pattern::with_prefix(name, path, "/")
}
/// Parse path pattern and create new `Pattern` instance with custom prefix
pub fn with_prefix(name: &str, path: &str, prefix: &str) -> Self {
let (pattern, elements, is_dynamic) = Pattern::parse(path, prefix);
let tp = if is_dynamic { let tp = if is_dynamic {
let re = match Regex::new(&pattern) { let re = match Regex::new(&pattern) {
@@ -188,7 +193,9 @@ impl Pattern {
} }
} }
pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>) -> bool { pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>)
-> bool
{
match self.tp { match self.tp {
PatternType::Static(ref s) => s == path, PatternType::Static(ref s) => s == path,
PatternType::Dynamic(ref re, ref names) => { PatternType::Dynamic(ref re, ref names) => {
@@ -236,11 +243,11 @@ impl Pattern {
Ok(path) Ok(path)
} }
fn parse(pattern: &str) -> (String, Vec<PatternElement>, bool) { fn parse(pattern: &str, prefix: &str) -> (String, Vec<PatternElement>, bool) {
const DEFAULT_PATTERN: &str = "[^/]+"; const DEFAULT_PATTERN: &str = "[^/]+";
let mut re1 = String::from("^/"); let mut re1 = String::from("^") + prefix;
let mut re2 = String::from("/"); let mut re2 = String::from(prefix);
let mut el = String::new(); let mut el = String::new();
let mut in_param = false; let mut in_param = false;
let mut in_param_pattern = false; let mut in_param_pattern = false;

View File

@@ -125,8 +125,8 @@ impl<T, H> Http1<T, H>
// TODO: refactor // TODO: refactor
pub fn poll_io(&mut self) -> Poll<bool, ()> { pub fn poll_io(&mut self) -> Poll<bool, ()> {
// read incoming data // read incoming data
let need_read = let need_read = if !self.flags.intersects(Flags::ERROR) &&
if !self.flags.contains(Flags::ERROR) && self.tasks.len() < MAX_PIPELINED_MESSAGES self.tasks.len() < MAX_PIPELINED_MESSAGES
{ {
'outer: loop { 'outer: loop {
match self.reader.parse(self.stream.get_mut(), match self.reader.parse(self.stream.get_mut(),
@@ -1413,6 +1413,10 @@ mod tests {
assert!(req.chunked().unwrap()); assert!(req.chunked().unwrap());
assert!(!req.payload().eof()); assert!(!req.payload().eof());
buf.feed_data("4\r\n1111\r\n");
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"1111");
buf.feed_data("4\r\ndata\r"); buf.feed_data("4\r\ndata\r");
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
@@ -1430,6 +1434,7 @@ mod tests {
buf.feed_data("ne\r\n0\r\n"); buf.feed_data("ne\r\n0\r\n");
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
//trailers
//buf.feed_data("test: test\r\n"); //buf.feed_data("test: test\r\n");
//not_ready!(reader.parse(&mut buf, &mut readbuf)); //not_ready!(reader.parse(&mut buf, &mut readbuf));

View File

@@ -14,6 +14,7 @@ use tokio_core::net::TcpListener;
use tokio_core::reactor::Core; use tokio_core::reactor::Core;
use net2::TcpBuilder; use net2::TcpBuilder;
use ws;
use body::Binary; use body::Binary;
use error::Error; use error::Error;
use handler::{Handler, Responder, ReplyItem}; use handler::{Handler, Responder, ReplyItem};
@@ -25,7 +26,6 @@ use payload::Payload;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use server::{HttpServer, IntoHttpHandler, ServerSettings}; use server::{HttpServer, IntoHttpHandler, ServerSettings};
use ws::{WsClient, WsClientError, WsClientReader, WsClientWriter};
use client::{ClientRequest, ClientRequestBuilder}; use client::{ClientRequest, ClientRequestBuilder};
/// The `TestServer` type. /// The `TestServer` type.
@@ -180,9 +180,9 @@ impl TestServer {
} }
/// Connect to websocket server /// Connect to websocket server
pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> { pub fn ws(&mut self) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> {
let url = self.url("/"); let url = self.url("/");
self.system.run_until_complete(WsClient::new(url).connect()) self.system.run_until_complete(ws::Client::new(url).connect())
} }
/// Create `GET` request /// Create `GET` request

View File

@@ -25,14 +25,32 @@ use client::{ClientRequest, ClientRequestBuilder, ClientResponse,
ClientConnector, SendRequest, SendRequestError, ClientConnector, SendRequest, SendRequestError,
HttpResponseParserError}; HttpResponseParserError};
use super::{Message, WsError}; use super::{Message, ProtocolError};
use super::frame::Frame; use super::frame::Frame;
use super::proto::{CloseCode, OpCode}; use super::proto::{CloseCode, OpCode};
/// Backward compatibility
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::Client` instead")]
pub type WsClient = Client;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientError` instead")]
pub type WsClientError = ClientError;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientReader` instead")]
pub type WsClientReader = ClientReader;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientWriter` instead")]
pub type WsClientWriter = ClientWriter;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientHandshake` instead")]
pub type WsClientHandshake = ClientHandshake;
/// Websocket client error /// Websocket client error
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum WsClientError { pub enum ClientError {
#[fail(display="Invalid url")] #[fail(display="Invalid url")]
InvalidUrl, InvalidUrl,
#[fail(display="Invalid response status")] #[fail(display="Invalid response status")]
@@ -56,46 +74,46 @@ pub enum WsClientError {
#[fail(display="{}", _0)] #[fail(display="{}", _0)]
SendRequest(SendRequestError), SendRequest(SendRequestError),
#[fail(display="{}", _0)] #[fail(display="{}", _0)]
Protocol(#[cause] WsError), Protocol(#[cause] ProtocolError),
#[fail(display="{}", _0)] #[fail(display="{}", _0)]
Io(io::Error), Io(io::Error),
#[fail(display="Disconnected")] #[fail(display="Disconnected")]
Disconnected, Disconnected,
} }
impl From<HttpError> for WsClientError { impl From<HttpError> for ClientError {
fn from(err: HttpError) -> WsClientError { fn from(err: HttpError) -> ClientError {
WsClientError::Http(err) ClientError::Http(err)
} }
} }
impl From<UrlParseError> for WsClientError { impl From<UrlParseError> for ClientError {
fn from(err: UrlParseError) -> WsClientError { fn from(err: UrlParseError) -> ClientError {
WsClientError::Url(err) ClientError::Url(err)
} }
} }
impl From<SendRequestError> for WsClientError { impl From<SendRequestError> for ClientError {
fn from(err: SendRequestError) -> WsClientError { fn from(err: SendRequestError) -> ClientError {
WsClientError::SendRequest(err) ClientError::SendRequest(err)
} }
} }
impl From<WsError> for WsClientError { impl From<ProtocolError> for ClientError {
fn from(err: WsError) -> WsClientError { fn from(err: ProtocolError) -> ClientError {
WsClientError::Protocol(err) ClientError::Protocol(err)
} }
} }
impl From<io::Error> for WsClientError { impl From<io::Error> for ClientError {
fn from(err: io::Error) -> WsClientError { fn from(err: io::Error) -> ClientError {
WsClientError::Io(err) ClientError::Io(err)
} }
} }
impl From<HttpResponseParserError> for WsClientError { impl From<HttpResponseParserError> for ClientError {
fn from(err: HttpResponseParserError) -> WsClientError { fn from(err: HttpResponseParserError) -> ClientError {
WsClientError::ResponseParseError(err) ClientError::ResponseParseError(err)
} }
} }
@@ -104,9 +122,9 @@ impl From<HttpResponseParserError> for WsClientError {
/// Example of `WebSocket` client usage is available in /// Example of `WebSocket` client usage is available in
/// [websocket example]( /// [websocket example](
/// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24) /// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24)
pub struct WsClient { pub struct Client {
request: ClientRequestBuilder, request: ClientRequestBuilder,
err: Option<WsClientError>, err: Option<ClientError>,
http_err: Option<HttpError>, http_err: Option<HttpError>,
origin: Option<HeaderValue>, origin: Option<HeaderValue>,
protocols: Option<String>, protocols: Option<String>,
@@ -114,16 +132,16 @@ pub struct WsClient {
max_size: usize, max_size: usize,
} }
impl WsClient { impl Client {
/// Create new websocket connection /// Create new websocket connection
pub fn new<S: AsRef<str>>(uri: S) -> WsClient { pub fn new<S: AsRef<str>>(uri: S) -> Client {
WsClient::with_connector(uri, ClientConnector::from_registry()) Client::with_connector(uri, ClientConnector::from_registry())
} }
/// Create new websocket connection with custom `ClientConnector` /// Create new websocket connection with custom `ClientConnector`
pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> WsClient { pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> Client {
let mut cl = WsClient { let mut cl = Client {
request: ClientRequest::build(), request: ClientRequest::build(),
err: None, err: None,
http_err: None, http_err: None,
@@ -182,12 +200,12 @@ impl WsClient {
} }
/// Connect to websocket server and do ws handshake /// Connect to websocket server and do ws handshake
pub fn connect(&mut self) -> WsClientHandshake { pub fn connect(&mut self) -> ClientHandshake {
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
WsClientHandshake::error(e) ClientHandshake::error(e)
} }
else if let Some(e) = self.http_err.take() { else if let Some(e) = self.http_err.take() {
WsClientHandshake::error(e.into()) ClientHandshake::error(e.into())
} else { } else {
// origin // origin
if let Some(origin) = self.origin.take() { if let Some(origin) = self.origin.take() {
@@ -205,42 +223,42 @@ impl WsClient {
} }
let request = match self.request.finish() { let request = match self.request.finish() {
Ok(req) => req, Ok(req) => req,
Err(err) => return WsClientHandshake::error(err.into()), Err(err) => return ClientHandshake::error(err.into()),
}; };
if request.uri().host().is_none() { if request.uri().host().is_none() {
return WsClientHandshake::error(WsClientError::InvalidUrl) return ClientHandshake::error(ClientError::InvalidUrl)
} }
if let Some(scheme) = request.uri().scheme_part() { if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return WsClientHandshake::error(WsClientError::InvalidUrl) return ClientHandshake::error(ClientError::InvalidUrl)
} }
} else { } else {
return WsClientHandshake::error(WsClientError::InvalidUrl) return ClientHandshake::error(ClientError::InvalidUrl)
} }
// start handshake // start handshake
WsClientHandshake::new(request, self.max_size) ClientHandshake::new(request, self.max_size)
} }
} }
} }
struct WsInner { struct Inner {
tx: UnboundedSender<Bytes>, tx: UnboundedSender<Bytes>,
rx: PayloadHelper<ClientResponse>, rx: PayloadHelper<ClientResponse>,
closed: bool, closed: bool,
} }
pub struct WsClientHandshake { pub struct ClientHandshake {
request: Option<SendRequest>, request: Option<SendRequest>,
tx: Option<UnboundedSender<Bytes>>, tx: Option<UnboundedSender<Bytes>>,
key: String, key: String,
error: Option<WsClientError>, error: Option<ClientError>,
max_size: usize, max_size: usize,
} }
impl WsClientHandshake { impl ClientHandshake {
fn new(mut request: ClientRequest, max_size: usize) -> WsClientHandshake fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake
{ {
// Generate a random key for the `Sec-WebSocket-Key` header. // Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that, // a base64-encoded (see Section 4 of [RFC4648]) value that,
@@ -257,7 +275,7 @@ impl WsClientHandshake {
Box::new(rx.map_err(|_| io::Error::new( Box::new(rx.map_err(|_| io::Error::new(
io::ErrorKind::Other, "disconnected").into())))); io::ErrorKind::Other, "disconnected").into()))));
WsClientHandshake { ClientHandshake {
key, key,
max_size, max_size,
request: Some(request.send()), request: Some(request.send()),
@@ -266,8 +284,8 @@ impl WsClientHandshake {
} }
} }
fn error(err: WsClientError) -> WsClientHandshake { fn error(err: ClientError) -> ClientHandshake {
WsClientHandshake { ClientHandshake {
key: String::new(), key: String::new(),
request: None, request: None,
tx: None, tx: None,
@@ -277,9 +295,9 @@ impl WsClientHandshake {
} }
} }
impl Future for WsClientHandshake { impl Future for ClientHandshake {
type Item = (WsClientReader, WsClientWriter); type Item = (ClientReader, ClientWriter);
type Error = WsClientError; type Error = ClientError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(err) = self.error.take() { if let Some(err) = self.error.take() {
@@ -296,7 +314,7 @@ impl Future for WsClientHandshake {
// verify response // verify response
if resp.status() != StatusCode::SWITCHING_PROTOCOLS { if resp.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(resp.status())) return Err(ClientError::InvalidResponseStatus(resp.status()))
} }
// Check for "UPGRADE" to websocket header // Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) {
@@ -310,22 +328,22 @@ impl Future for WsClientHandshake {
}; };
if !has_hdr { if !has_hdr {
trace!("Invalid upgrade header"); trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader) return Err(ClientError::InvalidUpgradeHeader)
} }
// Check for "CONNECTION" header // Check for "CONNECTION" header
if let Some(conn) = resp.headers().get(header::CONNECTION) { if let Some(conn) = resp.headers().get(header::CONNECTION) {
if let Ok(s) = conn.to_str() { if let Ok(s) = conn.to_str() {
if !s.to_lowercase().contains("upgrade") { if !s.to_lowercase().contains("upgrade") {
trace!("Invalid connection header: {}", s); trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(conn.clone())) return Err(ClientError::InvalidConnectionHeader(conn.clone()))
} }
} else { } else {
trace!("Invalid connection header: {:?}", conn); trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(conn.clone())) return Err(ClientError::InvalidConnectionHeader(conn.clone()))
} }
} else { } else {
trace!("Missing connection header"); trace!("Missing connection header");
return Err(WsClientError::MissingConnectionHeader) return Err(ClientError::MissingConnectionHeader)
} }
if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT) if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT)
@@ -341,14 +359,14 @@ impl Future for WsClientHandshake {
trace!( trace!(
"Invalid challenge response: expected: {} received: {:?}", "Invalid challenge response: expected: {} received: {:?}",
encoded, key); encoded, key);
return Err(WsClientError::InvalidChallengeResponse(encoded, key.clone())); return Err(ClientError::InvalidChallengeResponse(encoded, key.clone()));
} }
} else { } else {
trace!("Missing SEC-WEBSOCKET-ACCEPT header"); trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(WsClientError::MissingWebSocketAcceptHeader) return Err(ClientError::MissingWebSocketAcceptHeader)
}; };
let inner = WsInner { let inner = Inner {
tx: self.tx.take().unwrap(), tx: self.tx.take().unwrap(),
rx: PayloadHelper::new(resp), rx: PayloadHelper::new(resp),
closed: false, closed: false,
@@ -356,33 +374,33 @@ impl Future for WsClientHandshake {
let inner = Rc::new(UnsafeCell::new(inner)); let inner = Rc::new(UnsafeCell::new(inner));
Ok(Async::Ready( Ok(Async::Ready(
(WsClientReader{inner: Rc::clone(&inner), max_size: self.max_size}, (ClientReader{inner: Rc::clone(&inner), max_size: self.max_size},
WsClientWriter{inner}))) ClientWriter{inner})))
} }
} }
pub struct WsClientReader { pub struct ClientReader {
inner: Rc<UnsafeCell<WsInner>>, inner: Rc<UnsafeCell<Inner>>,
max_size: usize, max_size: usize,
} }
impl fmt::Debug for WsClientReader { impl fmt::Debug for ClientReader {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WsClientReader()") write!(f, "ws::ClientReader()")
} }
} }
impl WsClientReader { impl ClientReader {
#[inline] #[inline]
fn as_mut(&mut self) -> &mut WsInner { fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() } unsafe{ &mut *self.inner.get() }
} }
} }
impl Stream for WsClientReader { impl Stream for ClientReader {
type Item = Message; type Item = Message;
type Error = WsError; type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let max_size = self.max_size; let max_size = self.max_size;
@@ -399,14 +417,14 @@ impl Stream for WsClientReader {
// continuation is not supported // continuation is not supported
if !finished { if !finished {
inner.closed = true; inner.closed = true;
return Err(WsError::NoContinuation) return Err(ProtocolError::NoContinuation)
} }
match opcode { match opcode {
OpCode::Continue => unimplemented!(), OpCode::Continue => unimplemented!(),
OpCode::Bad => { OpCode::Bad => {
inner.closed = true; inner.closed = true;
Err(WsError::BadOpCode) Err(ProtocolError::BadOpCode)
}, },
OpCode::Close => { OpCode::Close => {
inner.closed = true; inner.closed = true;
@@ -430,7 +448,7 @@ impl Stream for WsClientReader {
Ok(Async::Ready(Some(Message::Text(s)))), Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => { Err(_) => {
inner.closed = true; inner.closed = true;
Err(WsError::BadEncoding) Err(ProtocolError::BadEncoding)
} }
} }
} }
@@ -446,18 +464,18 @@ impl Stream for WsClientReader {
} }
} }
pub struct WsClientWriter { pub struct ClientWriter {
inner: Rc<UnsafeCell<WsInner>> inner: Rc<UnsafeCell<Inner>>
} }
impl WsClientWriter { impl ClientWriter {
#[inline] #[inline]
fn as_mut(&mut self) -> &mut WsInner { fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() } unsafe{ &mut *self.inner.get() }
} }
} }
impl WsClientWriter { impl ClientWriter {
/// Write payload /// Write payload
#[inline] #[inline]

View File

@@ -9,7 +9,7 @@ use body::Binary;
use error::{PayloadError}; use error::{PayloadError};
use payload::PayloadHelper; use payload::PayloadHelper;
use ws::WsError; use ws::ProtocolError;
use ws::proto::{OpCode, CloseCode}; use ws::proto::{OpCode, CloseCode};
use ws::mask::apply_mask; use ws::mask::apply_mask;
@@ -53,7 +53,7 @@ impl Frame {
/// Parse the input stream into a frame. /// Parse the input stream into a frame.
pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize) pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
-> Poll<Option<Frame>, WsError> -> Poll<Option<Frame>, ProtocolError>
where S: Stream<Item=Bytes, Error=PayloadError> where S: Stream<Item=Bytes, Error=PayloadError>
{ {
let mut idx = 2; let mut idx = 2;
@@ -69,9 +69,9 @@ impl Frame {
// check masking // check masking
let masked = second & 0x80 != 0; let masked = second & 0x80 != 0;
if !masked && server { if !masked && server {
return Err(WsError::UnmaskedFrame) return Err(ProtocolError::UnmaskedFrame)
} else if masked && !server { } else if masked && !server {
return Err(WsError::MaskedFrame) return Err(ProtocolError::MaskedFrame)
} }
let rsv1 = first & 0x40 != 0; let rsv1 = first & 0x40 != 0;
@@ -104,7 +104,7 @@ impl Frame {
// check for max allowed size // check for max allowed size
if length > max_size { if length > max_size {
return Err(WsError::Overflow) return Err(ProtocolError::Overflow)
} }
let mask = if server { let mask = if server {
@@ -133,13 +133,13 @@ impl Frame {
// Disallow bad opcode // Disallow bad opcode
if let OpCode::Bad = opcode { if let OpCode::Bad = opcode {
return Err(WsError::InvalidOpcode(first & 0x0F)) return Err(ProtocolError::InvalidOpcode(first & 0x0F))
} }
// control frames must have length <= 125 // control frames must have length <= 125
match opcode { match opcode {
OpCode::Ping | OpCode::Pong if length > 125 => { OpCode::Ping | OpCode::Pong if length > 125 => {
return Err(WsError::InvalidLength(length)) return Err(ProtocolError::InvalidLength(length))
} }
OpCode::Close if length > 125 => { OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
@@ -257,14 +257,14 @@ mod tests {
use super::*; use super::*;
use futures::stream::once; use futures::stream::once;
fn is_none(frm: Poll<Option<Frame>, WsError>) -> bool { fn is_none(frm: Poll<Option<Frame>, ProtocolError>) -> bool {
match frm { match frm {
Ok(Async::Ready(None)) => true, Ok(Async::Ready(None)) => true,
_ => false, _ => false,
} }
} }
fn extract(frm: Poll<Option<Frame>, WsError>) -> Frame { fn extract(frm: Poll<Option<Frame>, ProtocolError>) -> Frame {
match frm { match frm {
Ok(Async::Ready(Some(frame))) => frame, Ok(Async::Ready(Some(frame))) => frame,
_ => panic!("error"), _ => panic!("error"),
@@ -370,7 +370,7 @@ mod tests {
assert!(Frame::parse(&mut buf, true, 1).is_err()); assert!(Frame::parse(&mut buf, true, 1).is_err());
if let Err(WsError::Overflow) = Frame::parse(&mut buf, false, 0) { if let Err(ProtocolError::Overflow) = Frame::parse(&mut buf, false, 0) {
} else { } else {
panic!("error"); panic!("error");
} }

View File

@@ -67,13 +67,24 @@ use self::frame::Frame;
use self::proto::{hash_key, OpCode}; use self::proto::{hash_key, OpCode};
pub use self::proto::CloseCode; pub use self::proto::CloseCode;
pub use self::context::WebsocketContext; pub use self::context::WebsocketContext;
pub use self::client::{Client, ClientError,
ClientReader, ClientWriter, ClientHandshake};
#[allow(deprecated)]
pub use self::client::{WsClient, WsClientError, pub use self::client::{WsClient, WsClientError,
WsClientReader, WsClientWriter, WsClientHandshake}; WsClientReader, WsClientWriter, WsClientHandshake};
/// Backward compatibility
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ProtocolError` instead")]
pub type WsError = ProtocolError;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::HandshakeError` instead")]
pub type WsHandshakeError = HandshakeError;
/// Websocket errors /// Websocket errors
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum WsError { pub enum ProtocolError {
/// Received an unmasked frame from client /// Received an unmasked frame from client
#[fail(display="Received an unmasked frame from client")] #[fail(display="Received an unmasked frame from client")]
UnmaskedFrame, UnmaskedFrame,
@@ -103,17 +114,17 @@ pub enum WsError {
Payload(#[cause] PayloadError), Payload(#[cause] PayloadError),
} }
impl ResponseError for WsError {} impl ResponseError for ProtocolError {}
impl From<PayloadError> for WsError { impl From<PayloadError> for ProtocolError {
fn from(err: PayloadError) -> WsError { fn from(err: PayloadError) -> ProtocolError {
WsError::Payload(err) ProtocolError::Payload(err)
} }
} }
/// Websocket handshake errors /// Websocket handshake errors
#[derive(Fail, PartialEq, Debug)] #[derive(Fail, PartialEq, Debug)]
pub enum WsHandshakeError { pub enum HandshakeError {
/// Only get method is allowed /// Only get method is allowed
#[fail(display="Method not allowed")] #[fail(display="Method not allowed")]
GetMethodRequired, GetMethodRequired,
@@ -134,26 +145,26 @@ pub enum WsHandshakeError {
BadWebsocketKey, BadWebsocketKey,
} }
impl ResponseError for WsHandshakeError { impl ResponseError for HandshakeError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
match *self { match *self {
WsHandshakeError::GetMethodRequired => { HandshakeError::GetMethodRequired => {
HttpMethodNotAllowed HttpMethodNotAllowed
.build() .build()
.header(header::ALLOW, "GET") .header(header::ALLOW, "GET")
.finish() .finish()
.unwrap() .unwrap()
} }
WsHandshakeError::NoWebsocketUpgrade => HandshakeError::NoWebsocketUpgrade =>
HttpBadRequest.with_reason("No WebSocket UPGRADE header found"), HttpBadRequest.with_reason("No WebSocket UPGRADE header found"),
WsHandshakeError::NoConnectionUpgrade => HandshakeError::NoConnectionUpgrade =>
HttpBadRequest.with_reason("No CONNECTION upgrade"), HttpBadRequest.with_reason("No CONNECTION upgrade"),
WsHandshakeError::NoVersionHeader => HandshakeError::NoVersionHeader =>
HttpBadRequest.with_reason("Websocket version header is required"), HttpBadRequest.with_reason("Websocket version header is required"),
WsHandshakeError::UnsupportedVersion => HandshakeError::UnsupportedVersion =>
HttpBadRequest.with_reason("Unsupported version"), HttpBadRequest.with_reason("Unsupported version"),
WsHandshakeError::BadWebsocketKey => HandshakeError::BadWebsocketKey =>
HttpBadRequest.with_reason("Handshake error"), HttpBadRequest.with_reason("Handshake error"),
} }
} }
@@ -171,7 +182,7 @@ pub enum Message {
/// Do websocket handshake and start actor /// Do websocket handshake and start actor
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error> pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, WsError>, where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, ProtocolError>,
S: 'static S: 'static
{ {
let mut resp = handshake(&req)?; let mut resp = handshake(&req)?;
@@ -191,10 +202,10 @@ pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
// /// `protocols` is a sequence of known protocols. On successful handshake, // /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list // /// the returned response headers contain the first protocol in this list
// /// which the server also knows. // /// which the server also knows.
pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHandshakeError> { pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, HandshakeError> {
// WebSocket accepts only GET // WebSocket accepts only GET
if *req.method() != Method::GET { if *req.method() != Method::GET {
return Err(WsHandshakeError::GetMethodRequired) return Err(HandshakeError::GetMethodRequired)
} }
// Check for "UPGRADE" to websocket header // Check for "UPGRADE" to websocket header
@@ -208,17 +219,17 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
false false
}; };
if !has_hdr { if !has_hdr {
return Err(WsHandshakeError::NoWebsocketUpgrade) return Err(HandshakeError::NoWebsocketUpgrade)
} }
// Upgrade connection // Upgrade connection
if !req.upgrade() { if !req.upgrade() {
return Err(WsHandshakeError::NoConnectionUpgrade) return Err(HandshakeError::NoConnectionUpgrade)
} }
// check supported version // check supported version
if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) { if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
return Err(WsHandshakeError::NoVersionHeader) return Err(HandshakeError::NoVersionHeader)
} }
let supported_ver = { let supported_ver = {
if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) { if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
@@ -228,12 +239,12 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
} }
}; };
if !supported_ver { if !supported_ver {
return Err(WsHandshakeError::UnsupportedVersion) return Err(HandshakeError::UnsupportedVersion)
} }
// check client handshake for validity // check client handshake for validity
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
return Err(WsHandshakeError::BadWebsocketKey) return Err(HandshakeError::BadWebsocketKey)
} }
let key = { let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
@@ -275,7 +286,7 @@ impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
type Item = Message; type Item = Message;
type Error = WsError; type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.closed { if self.closed {
@@ -289,14 +300,14 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// continuation is not supported // continuation is not supported
if !finished { if !finished {
self.closed = true; self.closed = true;
return Err(WsError::NoContinuation) return Err(ProtocolError::NoContinuation)
} }
match opcode { match opcode {
OpCode::Continue => unimplemented!(), OpCode::Continue => unimplemented!(),
OpCode::Bad => { OpCode::Bad => {
self.closed = true; self.closed = true;
Err(WsError::BadOpCode) Err(ProtocolError::BadOpCode)
} }
OpCode::Close => { OpCode::Close => {
self.closed = true; self.closed = true;
@@ -320,7 +331,7 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Ok(Async::Ready(Some(Message::Text(s)))), Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => { Err(_) => {
self.closed = true; self.closed = true;
Err(WsError::BadEncoding) Err(ProtocolError::BadEncoding)
} }
} }
} }
@@ -346,25 +357,25 @@ mod tests {
fn test_handshake() { fn test_handshake() {
let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(), let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None); Version::HTTP_11, HeaderMap::new(), None);
assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); assert_eq!(HandshakeError::GetMethodRequired, handshake(&req).err().unwrap());
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None); Version::HTTP_11, HeaderMap::new(), None);
assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(header::UPGRADE,
header::HeaderValue::from_static("test")); header::HeaderValue::from_static("test"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None); Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(header::UPGRADE,
header::HeaderValue::from_static("websocket")); header::HeaderValue::from_static("websocket"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None); Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap()); assert_eq!(HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(header::UPGRADE,
@@ -373,7 +384,7 @@ mod tests {
header::HeaderValue::from_static("upgrade")); header::HeaderValue::from_static("upgrade"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None); Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::NoVersionHeader, handshake(&req).err().unwrap()); assert_eq!(HandshakeError::NoVersionHeader, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(header::UPGRADE,
@@ -384,7 +395,7 @@ mod tests {
header::HeaderValue::from_static("5")); header::HeaderValue::from_static("5"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None); Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::UnsupportedVersion, handshake(&req).err().unwrap()); assert_eq!(HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(header::UPGRADE,
@@ -395,7 +406,7 @@ mod tests {
header::HeaderValue::from_static("13")); header::HeaderValue::from_static("13"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None); Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::BadWebsocketKey, handshake(&req).err().unwrap()); assert_eq!(HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(header::UPGRADE,
@@ -414,17 +425,17 @@ mod tests {
#[test] #[test]
fn test_wserror_http_response() { fn test_wserror_http_response() {
let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response(); let resp: HttpResponse = HandshakeError::GetMethodRequired.error_response();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response(); let resp: HttpResponse = HandshakeError::NoWebsocketUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response(); let resp: HttpResponse = HandshakeError::NoConnectionUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response(); let resp: HttpResponse = HandshakeError::NoVersionHeader.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response(); let resp: HttpResponse = HandshakeError::UnsupportedVersion.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response(); let resp: HttpResponse = HandshakeError::BadWebsocketKey.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
} }
} }

View File

@@ -121,7 +121,7 @@ fn test_shutdown() {
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
thread::sleep(time::Duration::from_millis(100)); thread::sleep(time::Duration::from_millis(1000));
assert!(net::TcpStream::connect(addr).is_err()); assert!(net::TcpStream::connect(addr).is_err());
} }
@@ -163,7 +163,7 @@ fn test_headers() {
// read response // read response
let bytes = srv.execute(response.body()).unwrap(); let bytes = srv.execute(response.body()).unwrap();
assert_eq!(Bytes::from(bytes), Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[test] #[test]

View File

@@ -16,7 +16,7 @@ impl Actor for Ws {
type Context = ws::WebsocketContext<Self>; type Context = ws::WebsocketContext<Self>;
} }
impl StreamHandler<ws::Message, ws::WsError> for Ws { impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {