mirror of
https://github.com/fafhrd91/actix-web
synced 2025-08-02 21:35:42 +02:00
Compare commits
12 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
6acb6dd4e7 | ||
|
791a980e2d | ||
|
c2d8abcee7 | ||
|
16c05f07ba | ||
|
2158ad29ee | ||
|
feba5aeffd | ||
|
343888017e | ||
|
3a5d445b2f | ||
|
e60acb7607 | ||
|
bebfc6c9b5 | ||
|
3b2928a391 | ||
|
10f57dac31 |
@@ -47,7 +47,7 @@ script:
|
|||||||
USE_SKEPTIC=1 cargo test --features=alpn
|
USE_SKEPTIC=1 cargo test --features=alpn
|
||||||
else
|
else
|
||||||
cargo clean
|
cargo clean
|
||||||
cargo test
|
cargo test -- --nocapture
|
||||||
# --features=alpn
|
# --features=alpn
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
11
CHANGES.md
11
CHANGES.md
@@ -1,5 +1,16 @@
|
|||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
|
## 0.4.2 (2018-03-02)
|
||||||
|
|
||||||
|
* Better naming for websockets implementation
|
||||||
|
|
||||||
|
* Add `Pattern::with_prefix()`, make it more usable outside of actix
|
||||||
|
|
||||||
|
* Add csrf middleware for filter for cross-site request forgery #89
|
||||||
|
|
||||||
|
* Fix disconnect on idle connections
|
||||||
|
|
||||||
|
|
||||||
## 0.4.1 (2018-03-01)
|
## 0.4.1 (2018-03-01)
|
||||||
|
|
||||||
* Rename `Route::p()` to `Route::filter()`
|
* Rename `Route::p()` to `Route::filter()`
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "actix-web"
|
name = "actix-web"
|
||||||
version = "0.4.1"
|
version = "0.4.2"
|
||||||
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
|
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
|
||||||
description = "Actix web is a small, pragmatic, extremely fast, web framework for Rust."
|
description = "Actix web is a small, pragmatic, extremely fast, web framework for Rust."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@@ -81,7 +81,7 @@ openssl = { version="0.10", optional = true }
|
|||||||
tokio-openssl = { version="0.2", optional = true }
|
tokio-openssl = { version="0.2", optional = true }
|
||||||
|
|
||||||
[dependencies.actix]
|
[dependencies.actix]
|
||||||
version = "0.5"
|
version = "^0.5.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
env_logger = "0.5"
|
env_logger = "0.5"
|
||||||
|
@@ -139,7 +139,7 @@ fn main() {
|
|||||||
// default
|
// default
|
||||||
.default_resource(|r| {
|
.default_resource(|r| {
|
||||||
r.method(Method::GET).f(p404);
|
r.method(Method::GET).f(p404);
|
||||||
r.route().p(pred::Not(pred::Get())).f(|req| httpcodes::HTTPMethodNotAllowed);
|
r.route().filter(pred::Not(pred::Get())).f(|req| httpcodes::HTTPMethodNotAllowed);
|
||||||
}))
|
}))
|
||||||
|
|
||||||
.bind("127.0.0.1:8080").expect("Can not bind to 127.0.0.1:8080")
|
.bind("127.0.0.1:8080").expect("Can not bind to 127.0.0.1:8080")
|
||||||
|
@@ -36,7 +36,7 @@ impl Actor for MyWebSocket {
|
|||||||
type Context = ws::WebsocketContext<Self, AppState>;
|
type Context = ws::WebsocketContext<Self, AppState>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket {
|
impl StreamHandler<ws::Message, ws::ProtocolError> for MyWebSocket {
|
||||||
|
|
||||||
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
||||||
self.counter += 1;
|
self.counter += 1;
|
||||||
|
@@ -92,7 +92,7 @@ impl Handler<session::Message> for WsChatSession {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// WebSocket message handler
|
/// WebSocket message handler
|
||||||
impl StreamHandler<ws::Message, ws::WsError> for WsChatSession {
|
impl StreamHandler<ws::Message, ws::ProtocolError> for WsChatSession {
|
||||||
|
|
||||||
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
||||||
println!("WEBSOCKET MESSAGE: {:?}", msg);
|
println!("WEBSOCKET MESSAGE: {:?}", msg);
|
||||||
|
@@ -12,7 +12,7 @@ use std::time::Duration;
|
|||||||
|
|
||||||
use actix::*;
|
use actix::*;
|
||||||
use futures::Future;
|
use futures::Future;
|
||||||
use actix_web::ws::{Message, WsError, WsClient, WsClientWriter};
|
use actix_web::ws::{Message, ProtocolError, Client, ClientWriter};
|
||||||
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
@@ -21,7 +21,7 @@ fn main() {
|
|||||||
let sys = actix::System::new("ws-example");
|
let sys = actix::System::new("ws-example");
|
||||||
|
|
||||||
Arbiter::handle().spawn(
|
Arbiter::handle().spawn(
|
||||||
WsClient::new("http://127.0.0.1:8080/ws/")
|
Client::new("http://127.0.0.1:8080/ws/")
|
||||||
.connect()
|
.connect()
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
println!("Error: {}", e);
|
println!("Error: {}", e);
|
||||||
@@ -53,7 +53,7 @@ fn main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
struct ChatClient(WsClientWriter);
|
struct ChatClient(ClientWriter);
|
||||||
|
|
||||||
#[derive(Message)]
|
#[derive(Message)]
|
||||||
struct ClientCommand(String);
|
struct ClientCommand(String);
|
||||||
@@ -93,7 +93,7 @@ impl Handler<ClientCommand> for ChatClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handle server websocket messages
|
/// Handle server websocket messages
|
||||||
impl StreamHandler<Message, WsError> for ChatClient {
|
impl StreamHandler<Message, ProtocolError> for ChatClient {
|
||||||
|
|
||||||
fn handle(&mut self, msg: Message, ctx: &mut Context<Self>) {
|
fn handle(&mut self, msg: Message, ctx: &mut Context<Self>) {
|
||||||
match msg {
|
match msg {
|
||||||
|
@@ -25,7 +25,7 @@ impl Actor for MyWebSocket {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handler for `ws::Message`
|
/// Handler for `ws::Message`
|
||||||
impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket {
|
impl StreamHandler<ws::Message, ws::ProtocolError> for MyWebSocket {
|
||||||
|
|
||||||
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
||||||
// process websocket messages
|
// process websocket messages
|
||||||
|
@@ -22,7 +22,7 @@ impl Actor for Ws {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handler for ws::Message message
|
/// Handler for ws::Message message
|
||||||
impl StreamHandler<ws::Message, ws::WsError> for Ws {
|
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
|
||||||
|
|
||||||
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
||||||
match msg {
|
match msg {
|
||||||
|
@@ -39,10 +39,8 @@ impl HttpResponseParser {
|
|||||||
// if buf is empty parse_message will always return NotReady, let's avoid that
|
// if buf is empty parse_message will always return NotReady, let's avoid that
|
||||||
let read = if buf.is_empty() {
|
let read = if buf.is_empty() {
|
||||||
match utils::read_from_io(io, buf) {
|
match utils::read_from_io(io, buf) {
|
||||||
Ok(Async::Ready(0)) => {
|
Ok(Async::Ready(0)) =>
|
||||||
// debug!("Ignored premature client disconnection");
|
return Err(HttpResponseParserError::Disconnect),
|
||||||
return Err(HttpResponseParserError::Disconnect);
|
|
||||||
},
|
|
||||||
Ok(Async::Ready(_)) => (),
|
Ok(Async::Ready(_)) => (),
|
||||||
Ok(Async::NotReady) =>
|
Ok(Async::NotReady) =>
|
||||||
return Ok(Async::NotReady),
|
return Ok(Async::NotReady),
|
||||||
@@ -66,10 +64,12 @@ impl HttpResponseParser {
|
|||||||
}
|
}
|
||||||
if read || buf.remaining_mut() == 0 {
|
if read || buf.remaining_mut() == 0 {
|
||||||
match utils::read_from_io(io, buf) {
|
match utils::read_from_io(io, buf) {
|
||||||
Ok(Async::Ready(0)) => return Err(HttpResponseParserError::Disconnect),
|
Ok(Async::Ready(0)) =>
|
||||||
|
return Err(HttpResponseParserError::Disconnect),
|
||||||
Ok(Async::Ready(_)) => (),
|
Ok(Async::Ready(_)) => (),
|
||||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||||
Err(err) => return Err(HttpResponseParserError::Error(err.into())),
|
Err(err) =>
|
||||||
|
return Err(HttpResponseParserError::Error(err.into())),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Ok(Async::NotReady)
|
return Ok(Async::NotReady)
|
||||||
@@ -109,7 +109,8 @@ impl HttpResponseParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_message(buf: &mut BytesMut) -> Poll<(ClientResponse, Option<Decoder>), ParseError>
|
fn parse_message(buf: &mut BytesMut)
|
||||||
|
-> Poll<(ClientResponse, Option<Decoder>), ParseError>
|
||||||
{
|
{
|
||||||
// Parse http message
|
// Parse http message
|
||||||
let bytes_ptr = buf.as_ref().as_ptr() as usize;
|
let bytes_ptr = buf.as_ref().as_ptr() as usize;
|
||||||
|
@@ -111,7 +111,7 @@ impl Future for SendRequest {
|
|||||||
_ => IoBody::Done,
|
_ => IoBody::Done,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut pl = Box::new(Pipeline {
|
let pl = Box::new(Pipeline {
|
||||||
body, conn, writer,
|
body, conn, writer,
|
||||||
parser: Some(HttpResponseParser::default()),
|
parser: Some(HttpResponseParser::default()),
|
||||||
parser_buf: BytesMut::new(),
|
parser_buf: BytesMut::new(),
|
||||||
|
@@ -365,7 +365,7 @@ impl<S> HttpRequest<S> {
|
|||||||
|
|
||||||
/// Get mutable reference to request's Params.
|
/// Get mutable reference to request's Params.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub(crate) fn match_info_mut(&mut self) -> &mut Params {
|
pub fn match_info_mut(&mut self) -> &mut Params {
|
||||||
unsafe{ mem::transmute(&mut self.as_mut().params) }
|
unsafe{ mem::transmute(&mut self.as_mut().params) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
266
src/middleware/csrf.rs
Normal file
266
src/middleware/csrf.rs
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
//! A filter for cross-site request forgery (CSRF).
|
||||||
|
//!
|
||||||
|
//! This middleware is stateless and [based on request
|
||||||
|
//! headers](https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)_Prevention_Cheat_Sheet#Verifying_Same_Origin_with_Standard_Headers).
|
||||||
|
//!
|
||||||
|
//! By default requests are allowed only if one of these is true:
|
||||||
|
//!
|
||||||
|
//! * The request method is safe (`GET`, `HEAD`, `OPTIONS`). It is the
|
||||||
|
//! applications responsibility to ensure these methods cannot be used to
|
||||||
|
//! execute unwanted actions. Note that upgrade requests for websockets are
|
||||||
|
//! also considered safe.
|
||||||
|
//! * The `Origin` header (added automatically by the browser) matches one
|
||||||
|
//! of the allowed origins.
|
||||||
|
//! * There is no `Origin` header but the `Referer` header matches one of
|
||||||
|
//! the allowed origins.
|
||||||
|
//!
|
||||||
|
//! Use [`CsrfFilterBuilder::allow_xhr()`](struct.CsrfFilterBuilder.html#method.allow_xhr)
|
||||||
|
//! if you want to allow requests with unsafe methods via
|
||||||
|
//! [CORS](../cors/struct.Cors.html).
|
||||||
|
//!
|
||||||
|
//! # Example
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! # extern crate actix_web;
|
||||||
|
//! # use actix_web::*;
|
||||||
|
//!
|
||||||
|
//! use actix_web::middleware::csrf;
|
||||||
|
//!
|
||||||
|
//! fn handle_post(_req: HttpRequest) -> &'static str {
|
||||||
|
//! "This action should only be triggered with requests from the same site"
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! fn main() {
|
||||||
|
//! let app = Application::new()
|
||||||
|
//! .middleware(
|
||||||
|
//! csrf::CsrfFilter::build()
|
||||||
|
//! .allowed_origin("https://www.example.com")
|
||||||
|
//! .finish())
|
||||||
|
//! .resource("/", |r| {
|
||||||
|
//! r.method(Method::GET).f(|_| httpcodes::HttpOk);
|
||||||
|
//! r.method(Method::POST).f(handle_post);
|
||||||
|
//! })
|
||||||
|
//! .finish();
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! In this example the entire application is protected from CSRF.
|
||||||
|
|
||||||
|
use std::borrow::Cow;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
|
use error::{Result, ResponseError};
|
||||||
|
use http::{HeaderMap, HttpTryFrom, Uri, header};
|
||||||
|
use httprequest::HttpRequest;
|
||||||
|
use httpresponse::HttpResponse;
|
||||||
|
use httpmessage::HttpMessage;
|
||||||
|
use httpcodes::HttpForbidden;
|
||||||
|
use middleware::{Middleware, Started};
|
||||||
|
|
||||||
|
/// Potential cross-site request forgery detected.
|
||||||
|
#[derive(Debug, Fail)]
|
||||||
|
pub enum CsrfError {
|
||||||
|
/// The HTTP request header `Origin` was required but not provided.
|
||||||
|
#[fail(display="Origin header required")]
|
||||||
|
MissingOrigin,
|
||||||
|
/// The HTTP request header `Origin` could not be parsed correctly.
|
||||||
|
#[fail(display="Could not parse Origin header")]
|
||||||
|
BadOrigin,
|
||||||
|
/// The cross-site request was denied.
|
||||||
|
#[fail(display="Cross-site request denied")]
|
||||||
|
CsrDenied,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResponseError for CsrfError {
|
||||||
|
fn error_response(&self) -> HttpResponse {
|
||||||
|
HttpForbidden.build().body(self.to_string()).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn uri_origin(uri: &Uri) -> Option<String> {
|
||||||
|
match (uri.scheme_part(), uri.host(), uri.port()) {
|
||||||
|
(Some(scheme), Some(host), Some(port)) => {
|
||||||
|
Some(format!("{}://{}:{}", scheme, host, port))
|
||||||
|
}
|
||||||
|
(Some(scheme), Some(host), None) => {
|
||||||
|
Some(format!("{}://{}", scheme, host))
|
||||||
|
}
|
||||||
|
_ => None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
|
||||||
|
headers.get(header::ORIGIN)
|
||||||
|
.map(|origin| {
|
||||||
|
origin
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| CsrfError::BadOrigin)
|
||||||
|
.map(|o| o.into())
|
||||||
|
})
|
||||||
|
.or_else(|| {
|
||||||
|
headers.get(header::REFERER)
|
||||||
|
.map(|referer| {
|
||||||
|
Uri::try_from(Bytes::from(referer.as_bytes()))
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(uri_origin)
|
||||||
|
.ok_or(CsrfError::BadOrigin)
|
||||||
|
.map(|o| o.into())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A middleware that filters cross-site requests.
|
||||||
|
pub struct CsrfFilter {
|
||||||
|
origins: HashSet<String>,
|
||||||
|
allow_xhr: bool,
|
||||||
|
allow_missing_origin: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CsrfFilter {
|
||||||
|
/// Start building a `CsrfFilter`.
|
||||||
|
pub fn build() -> CsrfFilterBuilder {
|
||||||
|
CsrfFilterBuilder {
|
||||||
|
cors: CsrfFilter {
|
||||||
|
origins: HashSet::new(),
|
||||||
|
allow_xhr: false,
|
||||||
|
allow_missing_origin: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
|
||||||
|
if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
|
||||||
|
Ok(())
|
||||||
|
} else if let Some(header) = origin(req.headers()) {
|
||||||
|
match header {
|
||||||
|
Ok(ref origin) if self.origins.contains(origin.as_ref()) => Ok(()),
|
||||||
|
Ok(_) => Err(CsrfError::CsrDenied),
|
||||||
|
Err(err) => Err(err),
|
||||||
|
}
|
||||||
|
} else if self.allow_missing_origin {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(CsrfError::MissingOrigin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Middleware<S> for CsrfFilter {
|
||||||
|
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
|
||||||
|
self.validate(req)?;
|
||||||
|
Ok(Started::Done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Used to build a `CsrfFilter`.
|
||||||
|
///
|
||||||
|
/// To construct a CSRF filter:
|
||||||
|
///
|
||||||
|
/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to
|
||||||
|
/// start building.
|
||||||
|
/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed
|
||||||
|
/// origins.
|
||||||
|
/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve
|
||||||
|
/// the constructed filter.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use actix_web::middleware::csrf;
|
||||||
|
///
|
||||||
|
/// let csrf = csrf::CsrfFilter::build()
|
||||||
|
/// .allowed_origin("https://www.example.com")
|
||||||
|
/// .finish();
|
||||||
|
/// ```
|
||||||
|
pub struct CsrfFilterBuilder {
|
||||||
|
cors: CsrfFilter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CsrfFilterBuilder {
|
||||||
|
/// Add an origin that is allowed to make requests. Will be verified
|
||||||
|
/// against the `Origin` request header.
|
||||||
|
pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder {
|
||||||
|
self.cors.origins.insert(origin.to_owned());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allow all requests with an `X-Requested-With` header.
|
||||||
|
///
|
||||||
|
/// A cross-site attacker should not be able to send requests with custom
|
||||||
|
/// headers unless a CORS policy whitelists them. Therefore it should be
|
||||||
|
/// safe to allow requests with an `X-Requested-With` header (added
|
||||||
|
/// automatically by many JavaScript libraries).
|
||||||
|
///
|
||||||
|
/// This is disabled by default, because in Safari it is possible to
|
||||||
|
/// circumvent this using redirects and Flash.
|
||||||
|
///
|
||||||
|
/// Use this method to enable more lax filtering.
|
||||||
|
pub fn allow_xhr(mut self) -> CsrfFilterBuilder {
|
||||||
|
self.cors.allow_xhr = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allow requests if the expected `Origin` header is missing (and
|
||||||
|
/// there is no `Referer` to fall back on).
|
||||||
|
///
|
||||||
|
/// The filter is conservative by default, but it should be safe to allow
|
||||||
|
/// missing `Origin` headers because a cross-site attacker cannot prevent
|
||||||
|
/// the browser from sending `Origin` on unsafe requests.
|
||||||
|
pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder {
|
||||||
|
self.cors.allow_missing_origin = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Finishes building the `CsrfFilter` instance.
|
||||||
|
pub fn finish(self) -> CsrfFilter {
|
||||||
|
self.cors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use http::Method;
|
||||||
|
use test::TestRequest;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_safe() {
|
||||||
|
let csrf = CsrfFilter::build()
|
||||||
|
.allowed_origin("https://www.example.com")
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
|
||||||
|
.method(Method::HEAD)
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
assert!(csrf.start(&mut req).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_csrf() {
|
||||||
|
let csrf = CsrfFilter::build()
|
||||||
|
.allowed_origin("https://www.example.com")
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
|
||||||
|
.method(Method::POST)
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
assert!(csrf.start(&mut req).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_referer() {
|
||||||
|
let csrf = CsrfFilter::build()
|
||||||
|
.allowed_origin("https://www.example.com")
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param")
|
||||||
|
.method(Method::POST)
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
assert!(csrf.start(&mut req).is_ok());
|
||||||
|
}
|
||||||
|
}
|
@@ -9,6 +9,7 @@ mod logger;
|
|||||||
mod session;
|
mod session;
|
||||||
mod defaultheaders;
|
mod defaultheaders;
|
||||||
pub mod cors;
|
pub mod cors;
|
||||||
|
pub mod csrf;
|
||||||
pub use self::logger::Logger;
|
pub use self::logger::Logger;
|
||||||
pub use self::defaultheaders::{DefaultHeaders, DefaultHeadersBuilder};
|
pub use self::defaultheaders::{DefaultHeaders, DefaultHeadersBuilder};
|
||||||
pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage,
|
pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage,
|
||||||
|
@@ -168,7 +168,7 @@ impl Inner {
|
|||||||
len: 0,
|
len: 0,
|
||||||
err: None,
|
err: None,
|
||||||
items: VecDeque::new(),
|
items: VecDeque::new(),
|
||||||
need_read: false,
|
need_read: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -173,7 +173,8 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
|
|||||||
PipelineState::None =>
|
PipelineState::None =>
|
||||||
return Ok(Async::Ready(true)),
|
return Ok(Async::Ready(true)),
|
||||||
PipelineState::Error =>
|
PipelineState::Error =>
|
||||||
return Err(io::Error::new(io::ErrorKind::Other, "Internal error").into()),
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::Other, "Internal error").into()),
|
||||||
_ => (),
|
_ => (),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -148,7 +148,12 @@ impl Pattern {
|
|||||||
///
|
///
|
||||||
/// Panics if path pattern is wrong.
|
/// Panics if path pattern is wrong.
|
||||||
pub fn new(name: &str, path: &str) -> Self {
|
pub fn new(name: &str, path: &str) -> Self {
|
||||||
let (pattern, elements, is_dynamic) = Pattern::parse(path);
|
Pattern::with_prefix(name, path, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse path pattern and create new `Pattern` instance with custom prefix
|
||||||
|
pub fn with_prefix(name: &str, path: &str, prefix: &str) -> Self {
|
||||||
|
let (pattern, elements, is_dynamic) = Pattern::parse(path, prefix);
|
||||||
|
|
||||||
let tp = if is_dynamic {
|
let tp = if is_dynamic {
|
||||||
let re = match Regex::new(&pattern) {
|
let re = match Regex::new(&pattern) {
|
||||||
@@ -188,7 +193,9 @@ impl Pattern {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>) -> bool {
|
pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>)
|
||||||
|
-> bool
|
||||||
|
{
|
||||||
match self.tp {
|
match self.tp {
|
||||||
PatternType::Static(ref s) => s == path,
|
PatternType::Static(ref s) => s == path,
|
||||||
PatternType::Dynamic(ref re, ref names) => {
|
PatternType::Dynamic(ref re, ref names) => {
|
||||||
@@ -236,11 +243,11 @@ impl Pattern {
|
|||||||
Ok(path)
|
Ok(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse(pattern: &str) -> (String, Vec<PatternElement>, bool) {
|
fn parse(pattern: &str, prefix: &str) -> (String, Vec<PatternElement>, bool) {
|
||||||
const DEFAULT_PATTERN: &str = "[^/]+";
|
const DEFAULT_PATTERN: &str = "[^/]+";
|
||||||
|
|
||||||
let mut re1 = String::from("^/");
|
let mut re1 = String::from("^") + prefix;
|
||||||
let mut re2 = String::from("/");
|
let mut re2 = String::from(prefix);
|
||||||
let mut el = String::new();
|
let mut el = String::new();
|
||||||
let mut in_param = false;
|
let mut in_param = false;
|
||||||
let mut in_param_pattern = false;
|
let mut in_param_pattern = false;
|
||||||
|
@@ -125,8 +125,8 @@ impl<T, H> Http1<T, H>
|
|||||||
// TODO: refactor
|
// TODO: refactor
|
||||||
pub fn poll_io(&mut self) -> Poll<bool, ()> {
|
pub fn poll_io(&mut self) -> Poll<bool, ()> {
|
||||||
// read incoming data
|
// read incoming data
|
||||||
let need_read =
|
let need_read = if !self.flags.intersects(Flags::ERROR) &&
|
||||||
if !self.flags.contains(Flags::ERROR) && self.tasks.len() < MAX_PIPELINED_MESSAGES
|
self.tasks.len() < MAX_PIPELINED_MESSAGES
|
||||||
{
|
{
|
||||||
'outer: loop {
|
'outer: loop {
|
||||||
match self.reader.parse(self.stream.get_mut(),
|
match self.reader.parse(self.stream.get_mut(),
|
||||||
@@ -1413,6 +1413,10 @@ mod tests {
|
|||||||
assert!(req.chunked().unwrap());
|
assert!(req.chunked().unwrap());
|
||||||
assert!(!req.payload().eof());
|
assert!(!req.payload().eof());
|
||||||
|
|
||||||
|
buf.feed_data("4\r\n1111\r\n");
|
||||||
|
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
|
||||||
|
assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"1111");
|
||||||
|
|
||||||
buf.feed_data("4\r\ndata\r");
|
buf.feed_data("4\r\ndata\r");
|
||||||
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
|
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
|
||||||
|
|
||||||
@@ -1430,6 +1434,7 @@ mod tests {
|
|||||||
buf.feed_data("ne\r\n0\r\n");
|
buf.feed_data("ne\r\n0\r\n");
|
||||||
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
|
not_ready!(reader.parse(&mut buf, &mut readbuf, &settings));
|
||||||
|
|
||||||
|
//trailers
|
||||||
//buf.feed_data("test: test\r\n");
|
//buf.feed_data("test: test\r\n");
|
||||||
//not_ready!(reader.parse(&mut buf, &mut readbuf));
|
//not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||||
|
|
||||||
|
@@ -14,6 +14,7 @@ use tokio_core::net::TcpListener;
|
|||||||
use tokio_core::reactor::Core;
|
use tokio_core::reactor::Core;
|
||||||
use net2::TcpBuilder;
|
use net2::TcpBuilder;
|
||||||
|
|
||||||
|
use ws;
|
||||||
use body::Binary;
|
use body::Binary;
|
||||||
use error::Error;
|
use error::Error;
|
||||||
use handler::{Handler, Responder, ReplyItem};
|
use handler::{Handler, Responder, ReplyItem};
|
||||||
@@ -25,7 +26,6 @@ use payload::Payload;
|
|||||||
use httprequest::HttpRequest;
|
use httprequest::HttpRequest;
|
||||||
use httpresponse::HttpResponse;
|
use httpresponse::HttpResponse;
|
||||||
use server::{HttpServer, IntoHttpHandler, ServerSettings};
|
use server::{HttpServer, IntoHttpHandler, ServerSettings};
|
||||||
use ws::{WsClient, WsClientError, WsClientReader, WsClientWriter};
|
|
||||||
use client::{ClientRequest, ClientRequestBuilder};
|
use client::{ClientRequest, ClientRequestBuilder};
|
||||||
|
|
||||||
/// The `TestServer` type.
|
/// The `TestServer` type.
|
||||||
@@ -180,9 +180,9 @@ impl TestServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to websocket server
|
/// Connect to websocket server
|
||||||
pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> {
|
pub fn ws(&mut self) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> {
|
||||||
let url = self.url("/");
|
let url = self.url("/");
|
||||||
self.system.run_until_complete(WsClient::new(url).connect())
|
self.system.run_until_complete(ws::Client::new(url).connect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create `GET` request
|
/// Create `GET` request
|
||||||
|
164
src/ws/client.rs
164
src/ws/client.rs
@@ -25,14 +25,32 @@ use client::{ClientRequest, ClientRequestBuilder, ClientResponse,
|
|||||||
ClientConnector, SendRequest, SendRequestError,
|
ClientConnector, SendRequest, SendRequestError,
|
||||||
HttpResponseParserError};
|
HttpResponseParserError};
|
||||||
|
|
||||||
use super::{Message, WsError};
|
use super::{Message, ProtocolError};
|
||||||
use super::frame::Frame;
|
use super::frame::Frame;
|
||||||
use super::proto::{CloseCode, OpCode};
|
use super::proto::{CloseCode, OpCode};
|
||||||
|
|
||||||
|
|
||||||
|
/// Backward compatibility
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[deprecated(since="0.4.2", note="please use `ws::Client` instead")]
|
||||||
|
pub type WsClient = Client;
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[deprecated(since="0.4.2", note="please use `ws::ClientError` instead")]
|
||||||
|
pub type WsClientError = ClientError;
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[deprecated(since="0.4.2", note="please use `ws::ClientReader` instead")]
|
||||||
|
pub type WsClientReader = ClientReader;
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[deprecated(since="0.4.2", note="please use `ws::ClientWriter` instead")]
|
||||||
|
pub type WsClientWriter = ClientWriter;
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[deprecated(since="0.4.2", note="please use `ws::ClientHandshake` instead")]
|
||||||
|
pub type WsClientHandshake = ClientHandshake;
|
||||||
|
|
||||||
|
|
||||||
/// Websocket client error
|
/// Websocket client error
|
||||||
#[derive(Fail, Debug)]
|
#[derive(Fail, Debug)]
|
||||||
pub enum WsClientError {
|
pub enum ClientError {
|
||||||
#[fail(display="Invalid url")]
|
#[fail(display="Invalid url")]
|
||||||
InvalidUrl,
|
InvalidUrl,
|
||||||
#[fail(display="Invalid response status")]
|
#[fail(display="Invalid response status")]
|
||||||
@@ -56,46 +74,46 @@ pub enum WsClientError {
|
|||||||
#[fail(display="{}", _0)]
|
#[fail(display="{}", _0)]
|
||||||
SendRequest(SendRequestError),
|
SendRequest(SendRequestError),
|
||||||
#[fail(display="{}", _0)]
|
#[fail(display="{}", _0)]
|
||||||
Protocol(#[cause] WsError),
|
Protocol(#[cause] ProtocolError),
|
||||||
#[fail(display="{}", _0)]
|
#[fail(display="{}", _0)]
|
||||||
Io(io::Error),
|
Io(io::Error),
|
||||||
#[fail(display="Disconnected")]
|
#[fail(display="Disconnected")]
|
||||||
Disconnected,
|
Disconnected,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<HttpError> for WsClientError {
|
impl From<HttpError> for ClientError {
|
||||||
fn from(err: HttpError) -> WsClientError {
|
fn from(err: HttpError) -> ClientError {
|
||||||
WsClientError::Http(err)
|
ClientError::Http(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<UrlParseError> for WsClientError {
|
impl From<UrlParseError> for ClientError {
|
||||||
fn from(err: UrlParseError) -> WsClientError {
|
fn from(err: UrlParseError) -> ClientError {
|
||||||
WsClientError::Url(err)
|
ClientError::Url(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<SendRequestError> for WsClientError {
|
impl From<SendRequestError> for ClientError {
|
||||||
fn from(err: SendRequestError) -> WsClientError {
|
fn from(err: SendRequestError) -> ClientError {
|
||||||
WsClientError::SendRequest(err)
|
ClientError::SendRequest(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<WsError> for WsClientError {
|
impl From<ProtocolError> for ClientError {
|
||||||
fn from(err: WsError) -> WsClientError {
|
fn from(err: ProtocolError) -> ClientError {
|
||||||
WsClientError::Protocol(err)
|
ClientError::Protocol(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<io::Error> for WsClientError {
|
impl From<io::Error> for ClientError {
|
||||||
fn from(err: io::Error) -> WsClientError {
|
fn from(err: io::Error) -> ClientError {
|
||||||
WsClientError::Io(err)
|
ClientError::Io(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<HttpResponseParserError> for WsClientError {
|
impl From<HttpResponseParserError> for ClientError {
|
||||||
fn from(err: HttpResponseParserError) -> WsClientError {
|
fn from(err: HttpResponseParserError) -> ClientError {
|
||||||
WsClientError::ResponseParseError(err)
|
ClientError::ResponseParseError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,9 +122,9 @@ impl From<HttpResponseParserError> for WsClientError {
|
|||||||
/// Example of `WebSocket` client usage is available in
|
/// Example of `WebSocket` client usage is available in
|
||||||
/// [websocket example](
|
/// [websocket example](
|
||||||
/// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24)
|
/// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24)
|
||||||
pub struct WsClient {
|
pub struct Client {
|
||||||
request: ClientRequestBuilder,
|
request: ClientRequestBuilder,
|
||||||
err: Option<WsClientError>,
|
err: Option<ClientError>,
|
||||||
http_err: Option<HttpError>,
|
http_err: Option<HttpError>,
|
||||||
origin: Option<HeaderValue>,
|
origin: Option<HeaderValue>,
|
||||||
protocols: Option<String>,
|
protocols: Option<String>,
|
||||||
@@ -114,16 +132,16 @@ pub struct WsClient {
|
|||||||
max_size: usize,
|
max_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsClient {
|
impl Client {
|
||||||
|
|
||||||
/// Create new websocket connection
|
/// Create new websocket connection
|
||||||
pub fn new<S: AsRef<str>>(uri: S) -> WsClient {
|
pub fn new<S: AsRef<str>>(uri: S) -> Client {
|
||||||
WsClient::with_connector(uri, ClientConnector::from_registry())
|
Client::with_connector(uri, ClientConnector::from_registry())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new websocket connection with custom `ClientConnector`
|
/// Create new websocket connection with custom `ClientConnector`
|
||||||
pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> WsClient {
|
pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> Client {
|
||||||
let mut cl = WsClient {
|
let mut cl = Client {
|
||||||
request: ClientRequest::build(),
|
request: ClientRequest::build(),
|
||||||
err: None,
|
err: None,
|
||||||
http_err: None,
|
http_err: None,
|
||||||
@@ -182,12 +200,12 @@ impl WsClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to websocket server and do ws handshake
|
/// Connect to websocket server and do ws handshake
|
||||||
pub fn connect(&mut self) -> WsClientHandshake {
|
pub fn connect(&mut self) -> ClientHandshake {
|
||||||
if let Some(e) = self.err.take() {
|
if let Some(e) = self.err.take() {
|
||||||
WsClientHandshake::error(e)
|
ClientHandshake::error(e)
|
||||||
}
|
}
|
||||||
else if let Some(e) = self.http_err.take() {
|
else if let Some(e) = self.http_err.take() {
|
||||||
WsClientHandshake::error(e.into())
|
ClientHandshake::error(e.into())
|
||||||
} else {
|
} else {
|
||||||
// origin
|
// origin
|
||||||
if let Some(origin) = self.origin.take() {
|
if let Some(origin) = self.origin.take() {
|
||||||
@@ -205,42 +223,42 @@ impl WsClient {
|
|||||||
}
|
}
|
||||||
let request = match self.request.finish() {
|
let request = match self.request.finish() {
|
||||||
Ok(req) => req,
|
Ok(req) => req,
|
||||||
Err(err) => return WsClientHandshake::error(err.into()),
|
Err(err) => return ClientHandshake::error(err.into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
if request.uri().host().is_none() {
|
if request.uri().host().is_none() {
|
||||||
return WsClientHandshake::error(WsClientError::InvalidUrl)
|
return ClientHandshake::error(ClientError::InvalidUrl)
|
||||||
}
|
}
|
||||||
if let Some(scheme) = request.uri().scheme_part() {
|
if let Some(scheme) = request.uri().scheme_part() {
|
||||||
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
|
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
|
||||||
return WsClientHandshake::error(WsClientError::InvalidUrl)
|
return ClientHandshake::error(ClientError::InvalidUrl)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return WsClientHandshake::error(WsClientError::InvalidUrl)
|
return ClientHandshake::error(ClientError::InvalidUrl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// start handshake
|
// start handshake
|
||||||
WsClientHandshake::new(request, self.max_size)
|
ClientHandshake::new(request, self.max_size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct WsInner {
|
struct Inner {
|
||||||
tx: UnboundedSender<Bytes>,
|
tx: UnboundedSender<Bytes>,
|
||||||
rx: PayloadHelper<ClientResponse>,
|
rx: PayloadHelper<ClientResponse>,
|
||||||
closed: bool,
|
closed: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct WsClientHandshake {
|
pub struct ClientHandshake {
|
||||||
request: Option<SendRequest>,
|
request: Option<SendRequest>,
|
||||||
tx: Option<UnboundedSender<Bytes>>,
|
tx: Option<UnboundedSender<Bytes>>,
|
||||||
key: String,
|
key: String,
|
||||||
error: Option<WsClientError>,
|
error: Option<ClientError>,
|
||||||
max_size: usize,
|
max_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsClientHandshake {
|
impl ClientHandshake {
|
||||||
fn new(mut request: ClientRequest, max_size: usize) -> WsClientHandshake
|
fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake
|
||||||
{
|
{
|
||||||
// Generate a random key for the `Sec-WebSocket-Key` header.
|
// Generate a random key for the `Sec-WebSocket-Key` header.
|
||||||
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
||||||
@@ -257,7 +275,7 @@ impl WsClientHandshake {
|
|||||||
Box::new(rx.map_err(|_| io::Error::new(
|
Box::new(rx.map_err(|_| io::Error::new(
|
||||||
io::ErrorKind::Other, "disconnected").into()))));
|
io::ErrorKind::Other, "disconnected").into()))));
|
||||||
|
|
||||||
WsClientHandshake {
|
ClientHandshake {
|
||||||
key,
|
key,
|
||||||
max_size,
|
max_size,
|
||||||
request: Some(request.send()),
|
request: Some(request.send()),
|
||||||
@@ -266,8 +284,8 @@ impl WsClientHandshake {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn error(err: WsClientError) -> WsClientHandshake {
|
fn error(err: ClientError) -> ClientHandshake {
|
||||||
WsClientHandshake {
|
ClientHandshake {
|
||||||
key: String::new(),
|
key: String::new(),
|
||||||
request: None,
|
request: None,
|
||||||
tx: None,
|
tx: None,
|
||||||
@@ -277,9 +295,9 @@ impl WsClientHandshake {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Future for WsClientHandshake {
|
impl Future for ClientHandshake {
|
||||||
type Item = (WsClientReader, WsClientWriter);
|
type Item = (ClientReader, ClientWriter);
|
||||||
type Error = WsClientError;
|
type Error = ClientError;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||||
if let Some(err) = self.error.take() {
|
if let Some(err) = self.error.take() {
|
||||||
@@ -296,7 +314,7 @@ impl Future for WsClientHandshake {
|
|||||||
|
|
||||||
// verify response
|
// verify response
|
||||||
if resp.status() != StatusCode::SWITCHING_PROTOCOLS {
|
if resp.status() != StatusCode::SWITCHING_PROTOCOLS {
|
||||||
return Err(WsClientError::InvalidResponseStatus(resp.status()))
|
return Err(ClientError::InvalidResponseStatus(resp.status()))
|
||||||
}
|
}
|
||||||
// Check for "UPGRADE" to websocket header
|
// Check for "UPGRADE" to websocket header
|
||||||
let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) {
|
let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) {
|
||||||
@@ -310,22 +328,22 @@ impl Future for WsClientHandshake {
|
|||||||
};
|
};
|
||||||
if !has_hdr {
|
if !has_hdr {
|
||||||
trace!("Invalid upgrade header");
|
trace!("Invalid upgrade header");
|
||||||
return Err(WsClientError::InvalidUpgradeHeader)
|
return Err(ClientError::InvalidUpgradeHeader)
|
||||||
}
|
}
|
||||||
// Check for "CONNECTION" header
|
// Check for "CONNECTION" header
|
||||||
if let Some(conn) = resp.headers().get(header::CONNECTION) {
|
if let Some(conn) = resp.headers().get(header::CONNECTION) {
|
||||||
if let Ok(s) = conn.to_str() {
|
if let Ok(s) = conn.to_str() {
|
||||||
if !s.to_lowercase().contains("upgrade") {
|
if !s.to_lowercase().contains("upgrade") {
|
||||||
trace!("Invalid connection header: {}", s);
|
trace!("Invalid connection header: {}", s);
|
||||||
return Err(WsClientError::InvalidConnectionHeader(conn.clone()))
|
return Err(ClientError::InvalidConnectionHeader(conn.clone()))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
trace!("Invalid connection header: {:?}", conn);
|
trace!("Invalid connection header: {:?}", conn);
|
||||||
return Err(WsClientError::InvalidConnectionHeader(conn.clone()))
|
return Err(ClientError::InvalidConnectionHeader(conn.clone()))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
trace!("Missing connection header");
|
trace!("Missing connection header");
|
||||||
return Err(WsClientError::MissingConnectionHeader)
|
return Err(ClientError::MissingConnectionHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT)
|
if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT)
|
||||||
@@ -341,14 +359,14 @@ impl Future for WsClientHandshake {
|
|||||||
trace!(
|
trace!(
|
||||||
"Invalid challenge response: expected: {} received: {:?}",
|
"Invalid challenge response: expected: {} received: {:?}",
|
||||||
encoded, key);
|
encoded, key);
|
||||||
return Err(WsClientError::InvalidChallengeResponse(encoded, key.clone()));
|
return Err(ClientError::InvalidChallengeResponse(encoded, key.clone()));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
trace!("Missing SEC-WEBSOCKET-ACCEPT header");
|
trace!("Missing SEC-WEBSOCKET-ACCEPT header");
|
||||||
return Err(WsClientError::MissingWebSocketAcceptHeader)
|
return Err(ClientError::MissingWebSocketAcceptHeader)
|
||||||
};
|
};
|
||||||
|
|
||||||
let inner = WsInner {
|
let inner = Inner {
|
||||||
tx: self.tx.take().unwrap(),
|
tx: self.tx.take().unwrap(),
|
||||||
rx: PayloadHelper::new(resp),
|
rx: PayloadHelper::new(resp),
|
||||||
closed: false,
|
closed: false,
|
||||||
@@ -356,33 +374,33 @@ impl Future for WsClientHandshake {
|
|||||||
|
|
||||||
let inner = Rc::new(UnsafeCell::new(inner));
|
let inner = Rc::new(UnsafeCell::new(inner));
|
||||||
Ok(Async::Ready(
|
Ok(Async::Ready(
|
||||||
(WsClientReader{inner: Rc::clone(&inner), max_size: self.max_size},
|
(ClientReader{inner: Rc::clone(&inner), max_size: self.max_size},
|
||||||
WsClientWriter{inner})))
|
ClientWriter{inner})))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub struct WsClientReader {
|
pub struct ClientReader {
|
||||||
inner: Rc<UnsafeCell<WsInner>>,
|
inner: Rc<UnsafeCell<Inner>>,
|
||||||
max_size: usize,
|
max_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Debug for WsClientReader {
|
impl fmt::Debug for ClientReader {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
write!(f, "WsClientReader()")
|
write!(f, "ws::ClientReader()")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsClientReader {
|
impl ClientReader {
|
||||||
#[inline]
|
#[inline]
|
||||||
fn as_mut(&mut self) -> &mut WsInner {
|
fn as_mut(&mut self) -> &mut Inner {
|
||||||
unsafe{ &mut *self.inner.get() }
|
unsafe{ &mut *self.inner.get() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Stream for WsClientReader {
|
impl Stream for ClientReader {
|
||||||
type Item = Message;
|
type Item = Message;
|
||||||
type Error = WsError;
|
type Error = ProtocolError;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||||
let max_size = self.max_size;
|
let max_size = self.max_size;
|
||||||
@@ -399,14 +417,14 @@ impl Stream for WsClientReader {
|
|||||||
// continuation is not supported
|
// continuation is not supported
|
||||||
if !finished {
|
if !finished {
|
||||||
inner.closed = true;
|
inner.closed = true;
|
||||||
return Err(WsError::NoContinuation)
|
return Err(ProtocolError::NoContinuation)
|
||||||
}
|
}
|
||||||
|
|
||||||
match opcode {
|
match opcode {
|
||||||
OpCode::Continue => unimplemented!(),
|
OpCode::Continue => unimplemented!(),
|
||||||
OpCode::Bad => {
|
OpCode::Bad => {
|
||||||
inner.closed = true;
|
inner.closed = true;
|
||||||
Err(WsError::BadOpCode)
|
Err(ProtocolError::BadOpCode)
|
||||||
},
|
},
|
||||||
OpCode::Close => {
|
OpCode::Close => {
|
||||||
inner.closed = true;
|
inner.closed = true;
|
||||||
@@ -430,7 +448,7 @@ impl Stream for WsClientReader {
|
|||||||
Ok(Async::Ready(Some(Message::Text(s)))),
|
Ok(Async::Ready(Some(Message::Text(s)))),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
inner.closed = true;
|
inner.closed = true;
|
||||||
Err(WsError::BadEncoding)
|
Err(ProtocolError::BadEncoding)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -446,18 +464,18 @@ impl Stream for WsClientReader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct WsClientWriter {
|
pub struct ClientWriter {
|
||||||
inner: Rc<UnsafeCell<WsInner>>
|
inner: Rc<UnsafeCell<Inner>>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsClientWriter {
|
impl ClientWriter {
|
||||||
#[inline]
|
#[inline]
|
||||||
fn as_mut(&mut self) -> &mut WsInner {
|
fn as_mut(&mut self) -> &mut Inner {
|
||||||
unsafe{ &mut *self.inner.get() }
|
unsafe{ &mut *self.inner.get() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsClientWriter {
|
impl ClientWriter {
|
||||||
|
|
||||||
/// Write payload
|
/// Write payload
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@@ -9,7 +9,7 @@ use body::Binary;
|
|||||||
use error::{PayloadError};
|
use error::{PayloadError};
|
||||||
use payload::PayloadHelper;
|
use payload::PayloadHelper;
|
||||||
|
|
||||||
use ws::WsError;
|
use ws::ProtocolError;
|
||||||
use ws::proto::{OpCode, CloseCode};
|
use ws::proto::{OpCode, CloseCode};
|
||||||
use ws::mask::apply_mask;
|
use ws::mask::apply_mask;
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ impl Frame {
|
|||||||
|
|
||||||
/// Parse the input stream into a frame.
|
/// Parse the input stream into a frame.
|
||||||
pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
|
pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
|
||||||
-> Poll<Option<Frame>, WsError>
|
-> Poll<Option<Frame>, ProtocolError>
|
||||||
where S: Stream<Item=Bytes, Error=PayloadError>
|
where S: Stream<Item=Bytes, Error=PayloadError>
|
||||||
{
|
{
|
||||||
let mut idx = 2;
|
let mut idx = 2;
|
||||||
@@ -69,9 +69,9 @@ impl Frame {
|
|||||||
// check masking
|
// check masking
|
||||||
let masked = second & 0x80 != 0;
|
let masked = second & 0x80 != 0;
|
||||||
if !masked && server {
|
if !masked && server {
|
||||||
return Err(WsError::UnmaskedFrame)
|
return Err(ProtocolError::UnmaskedFrame)
|
||||||
} else if masked && !server {
|
} else if masked && !server {
|
||||||
return Err(WsError::MaskedFrame)
|
return Err(ProtocolError::MaskedFrame)
|
||||||
}
|
}
|
||||||
|
|
||||||
let rsv1 = first & 0x40 != 0;
|
let rsv1 = first & 0x40 != 0;
|
||||||
@@ -104,7 +104,7 @@ impl Frame {
|
|||||||
|
|
||||||
// check for max allowed size
|
// check for max allowed size
|
||||||
if length > max_size {
|
if length > max_size {
|
||||||
return Err(WsError::Overflow)
|
return Err(ProtocolError::Overflow)
|
||||||
}
|
}
|
||||||
|
|
||||||
let mask = if server {
|
let mask = if server {
|
||||||
@@ -133,13 +133,13 @@ impl Frame {
|
|||||||
|
|
||||||
// Disallow bad opcode
|
// Disallow bad opcode
|
||||||
if let OpCode::Bad = opcode {
|
if let OpCode::Bad = opcode {
|
||||||
return Err(WsError::InvalidOpcode(first & 0x0F))
|
return Err(ProtocolError::InvalidOpcode(first & 0x0F))
|
||||||
}
|
}
|
||||||
|
|
||||||
// control frames must have length <= 125
|
// control frames must have length <= 125
|
||||||
match opcode {
|
match opcode {
|
||||||
OpCode::Ping | OpCode::Pong if length > 125 => {
|
OpCode::Ping | OpCode::Pong if length > 125 => {
|
||||||
return Err(WsError::InvalidLength(length))
|
return Err(ProtocolError::InvalidLength(length))
|
||||||
}
|
}
|
||||||
OpCode::Close if length > 125 => {
|
OpCode::Close if length > 125 => {
|
||||||
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
|
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
|
||||||
@@ -257,14 +257,14 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use futures::stream::once;
|
use futures::stream::once;
|
||||||
|
|
||||||
fn is_none(frm: Poll<Option<Frame>, WsError>) -> bool {
|
fn is_none(frm: Poll<Option<Frame>, ProtocolError>) -> bool {
|
||||||
match frm {
|
match frm {
|
||||||
Ok(Async::Ready(None)) => true,
|
Ok(Async::Ready(None)) => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract(frm: Poll<Option<Frame>, WsError>) -> Frame {
|
fn extract(frm: Poll<Option<Frame>, ProtocolError>) -> Frame {
|
||||||
match frm {
|
match frm {
|
||||||
Ok(Async::Ready(Some(frame))) => frame,
|
Ok(Async::Ready(Some(frame))) => frame,
|
||||||
_ => panic!("error"),
|
_ => panic!("error"),
|
||||||
@@ -370,7 +370,7 @@ mod tests {
|
|||||||
|
|
||||||
assert!(Frame::parse(&mut buf, true, 1).is_err());
|
assert!(Frame::parse(&mut buf, true, 1).is_err());
|
||||||
|
|
||||||
if let Err(WsError::Overflow) = Frame::parse(&mut buf, false, 0) {
|
if let Err(ProtocolError::Overflow) = Frame::parse(&mut buf, false, 0) {
|
||||||
} else {
|
} else {
|
||||||
panic!("error");
|
panic!("error");
|
||||||
}
|
}
|
||||||
|
@@ -67,13 +67,24 @@ use self::frame::Frame;
|
|||||||
use self::proto::{hash_key, OpCode};
|
use self::proto::{hash_key, OpCode};
|
||||||
pub use self::proto::CloseCode;
|
pub use self::proto::CloseCode;
|
||||||
pub use self::context::WebsocketContext;
|
pub use self::context::WebsocketContext;
|
||||||
|
pub use self::client::{Client, ClientError,
|
||||||
|
ClientReader, ClientWriter, ClientHandshake};
|
||||||
|
|
||||||
|
#[allow(deprecated)]
|
||||||
pub use self::client::{WsClient, WsClientError,
|
pub use self::client::{WsClient, WsClientError,
|
||||||
WsClientReader, WsClientWriter, WsClientHandshake};
|
WsClientReader, WsClientWriter, WsClientHandshake};
|
||||||
|
|
||||||
|
/// Backward compatibility
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[deprecated(since="0.4.2", note="please use `ws::ProtocolError` instead")]
|
||||||
|
pub type WsError = ProtocolError;
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[deprecated(since="0.4.2", note="please use `ws::HandshakeError` instead")]
|
||||||
|
pub type WsHandshakeError = HandshakeError;
|
||||||
|
|
||||||
/// Websocket errors
|
/// Websocket errors
|
||||||
#[derive(Fail, Debug)]
|
#[derive(Fail, Debug)]
|
||||||
pub enum WsError {
|
pub enum ProtocolError {
|
||||||
/// Received an unmasked frame from client
|
/// Received an unmasked frame from client
|
||||||
#[fail(display="Received an unmasked frame from client")]
|
#[fail(display="Received an unmasked frame from client")]
|
||||||
UnmaskedFrame,
|
UnmaskedFrame,
|
||||||
@@ -103,17 +114,17 @@ pub enum WsError {
|
|||||||
Payload(#[cause] PayloadError),
|
Payload(#[cause] PayloadError),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseError for WsError {}
|
impl ResponseError for ProtocolError {}
|
||||||
|
|
||||||
impl From<PayloadError> for WsError {
|
impl From<PayloadError> for ProtocolError {
|
||||||
fn from(err: PayloadError) -> WsError {
|
fn from(err: PayloadError) -> ProtocolError {
|
||||||
WsError::Payload(err)
|
ProtocolError::Payload(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Websocket handshake errors
|
/// Websocket handshake errors
|
||||||
#[derive(Fail, PartialEq, Debug)]
|
#[derive(Fail, PartialEq, Debug)]
|
||||||
pub enum WsHandshakeError {
|
pub enum HandshakeError {
|
||||||
/// Only get method is allowed
|
/// Only get method is allowed
|
||||||
#[fail(display="Method not allowed")]
|
#[fail(display="Method not allowed")]
|
||||||
GetMethodRequired,
|
GetMethodRequired,
|
||||||
@@ -134,26 +145,26 @@ pub enum WsHandshakeError {
|
|||||||
BadWebsocketKey,
|
BadWebsocketKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseError for WsHandshakeError {
|
impl ResponseError for HandshakeError {
|
||||||
|
|
||||||
fn error_response(&self) -> HttpResponse {
|
fn error_response(&self) -> HttpResponse {
|
||||||
match *self {
|
match *self {
|
||||||
WsHandshakeError::GetMethodRequired => {
|
HandshakeError::GetMethodRequired => {
|
||||||
HttpMethodNotAllowed
|
HttpMethodNotAllowed
|
||||||
.build()
|
.build()
|
||||||
.header(header::ALLOW, "GET")
|
.header(header::ALLOW, "GET")
|
||||||
.finish()
|
.finish()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
WsHandshakeError::NoWebsocketUpgrade =>
|
HandshakeError::NoWebsocketUpgrade =>
|
||||||
HttpBadRequest.with_reason("No WebSocket UPGRADE header found"),
|
HttpBadRequest.with_reason("No WebSocket UPGRADE header found"),
|
||||||
WsHandshakeError::NoConnectionUpgrade =>
|
HandshakeError::NoConnectionUpgrade =>
|
||||||
HttpBadRequest.with_reason("No CONNECTION upgrade"),
|
HttpBadRequest.with_reason("No CONNECTION upgrade"),
|
||||||
WsHandshakeError::NoVersionHeader =>
|
HandshakeError::NoVersionHeader =>
|
||||||
HttpBadRequest.with_reason("Websocket version header is required"),
|
HttpBadRequest.with_reason("Websocket version header is required"),
|
||||||
WsHandshakeError::UnsupportedVersion =>
|
HandshakeError::UnsupportedVersion =>
|
||||||
HttpBadRequest.with_reason("Unsupported version"),
|
HttpBadRequest.with_reason("Unsupported version"),
|
||||||
WsHandshakeError::BadWebsocketKey =>
|
HandshakeError::BadWebsocketKey =>
|
||||||
HttpBadRequest.with_reason("Handshake error"),
|
HttpBadRequest.with_reason("Handshake error"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,7 +182,7 @@ pub enum Message {
|
|||||||
|
|
||||||
/// Do websocket handshake and start actor
|
/// Do websocket handshake and start actor
|
||||||
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
|
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
|
||||||
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, WsError>,
|
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, ProtocolError>,
|
||||||
S: 'static
|
S: 'static
|
||||||
{
|
{
|
||||||
let mut resp = handshake(&req)?;
|
let mut resp = handshake(&req)?;
|
||||||
@@ -191,10 +202,10 @@ pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
|
|||||||
// /// `protocols` is a sequence of known protocols. On successful handshake,
|
// /// `protocols` is a sequence of known protocols. On successful handshake,
|
||||||
// /// the returned response headers contain the first protocol in this list
|
// /// the returned response headers contain the first protocol in this list
|
||||||
// /// which the server also knows.
|
// /// which the server also knows.
|
||||||
pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHandshakeError> {
|
pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, HandshakeError> {
|
||||||
// WebSocket accepts only GET
|
// WebSocket accepts only GET
|
||||||
if *req.method() != Method::GET {
|
if *req.method() != Method::GET {
|
||||||
return Err(WsHandshakeError::GetMethodRequired)
|
return Err(HandshakeError::GetMethodRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for "UPGRADE" to websocket header
|
// Check for "UPGRADE" to websocket header
|
||||||
@@ -208,17 +219,17 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
|
|||||||
false
|
false
|
||||||
};
|
};
|
||||||
if !has_hdr {
|
if !has_hdr {
|
||||||
return Err(WsHandshakeError::NoWebsocketUpgrade)
|
return Err(HandshakeError::NoWebsocketUpgrade)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upgrade connection
|
// Upgrade connection
|
||||||
if !req.upgrade() {
|
if !req.upgrade() {
|
||||||
return Err(WsHandshakeError::NoConnectionUpgrade)
|
return Err(HandshakeError::NoConnectionUpgrade)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check supported version
|
// check supported version
|
||||||
if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
|
if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
|
||||||
return Err(WsHandshakeError::NoVersionHeader)
|
return Err(HandshakeError::NoVersionHeader)
|
||||||
}
|
}
|
||||||
let supported_ver = {
|
let supported_ver = {
|
||||||
if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
|
if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
|
||||||
@@ -228,12 +239,12 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
if !supported_ver {
|
if !supported_ver {
|
||||||
return Err(WsHandshakeError::UnsupportedVersion)
|
return Err(HandshakeError::UnsupportedVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check client handshake for validity
|
// check client handshake for validity
|
||||||
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
|
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
|
||||||
return Err(WsHandshakeError::BadWebsocketKey)
|
return Err(HandshakeError::BadWebsocketKey)
|
||||||
}
|
}
|
||||||
let key = {
|
let key = {
|
||||||
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
|
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
|
||||||
@@ -275,7 +286,7 @@ impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
|||||||
|
|
||||||
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
||||||
type Item = Message;
|
type Item = Message;
|
||||||
type Error = WsError;
|
type Error = ProtocolError;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||||
if self.closed {
|
if self.closed {
|
||||||
@@ -289,14 +300,14 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
|||||||
// continuation is not supported
|
// continuation is not supported
|
||||||
if !finished {
|
if !finished {
|
||||||
self.closed = true;
|
self.closed = true;
|
||||||
return Err(WsError::NoContinuation)
|
return Err(ProtocolError::NoContinuation)
|
||||||
}
|
}
|
||||||
|
|
||||||
match opcode {
|
match opcode {
|
||||||
OpCode::Continue => unimplemented!(),
|
OpCode::Continue => unimplemented!(),
|
||||||
OpCode::Bad => {
|
OpCode::Bad => {
|
||||||
self.closed = true;
|
self.closed = true;
|
||||||
Err(WsError::BadOpCode)
|
Err(ProtocolError::BadOpCode)
|
||||||
}
|
}
|
||||||
OpCode::Close => {
|
OpCode::Close => {
|
||||||
self.closed = true;
|
self.closed = true;
|
||||||
@@ -320,7 +331,7 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
|||||||
Ok(Async::Ready(Some(Message::Text(s)))),
|
Ok(Async::Ready(Some(Message::Text(s)))),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
self.closed = true;
|
self.closed = true;
|
||||||
Err(WsError::BadEncoding)
|
Err(ProtocolError::BadEncoding)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -346,25 +357,25 @@ mod tests {
|
|||||||
fn test_handshake() {
|
fn test_handshake() {
|
||||||
let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(),
|
let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(),
|
||||||
Version::HTTP_11, HeaderMap::new(), None);
|
Version::HTTP_11, HeaderMap::new(), None);
|
||||||
assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap());
|
assert_eq!(HandshakeError::GetMethodRequired, handshake(&req).err().unwrap());
|
||||||
|
|
||||||
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
||||||
Version::HTTP_11, HeaderMap::new(), None);
|
Version::HTTP_11, HeaderMap::new(), None);
|
||||||
assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
|
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(header::UPGRADE,
|
headers.insert(header::UPGRADE,
|
||||||
header::HeaderValue::from_static("test"));
|
header::HeaderValue::from_static("test"));
|
||||||
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
||||||
Version::HTTP_11, headers, None);
|
Version::HTTP_11, headers, None);
|
||||||
assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
|
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(header::UPGRADE,
|
headers.insert(header::UPGRADE,
|
||||||
header::HeaderValue::from_static("websocket"));
|
header::HeaderValue::from_static("websocket"));
|
||||||
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
||||||
Version::HTTP_11, headers, None);
|
Version::HTTP_11, headers, None);
|
||||||
assert_eq!(WsHandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap());
|
assert_eq!(HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap());
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(header::UPGRADE,
|
headers.insert(header::UPGRADE,
|
||||||
@@ -373,7 +384,7 @@ mod tests {
|
|||||||
header::HeaderValue::from_static("upgrade"));
|
header::HeaderValue::from_static("upgrade"));
|
||||||
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
||||||
Version::HTTP_11, headers, None);
|
Version::HTTP_11, headers, None);
|
||||||
assert_eq!(WsHandshakeError::NoVersionHeader, handshake(&req).err().unwrap());
|
assert_eq!(HandshakeError::NoVersionHeader, handshake(&req).err().unwrap());
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(header::UPGRADE,
|
headers.insert(header::UPGRADE,
|
||||||
@@ -384,7 +395,7 @@ mod tests {
|
|||||||
header::HeaderValue::from_static("5"));
|
header::HeaderValue::from_static("5"));
|
||||||
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
||||||
Version::HTTP_11, headers, None);
|
Version::HTTP_11, headers, None);
|
||||||
assert_eq!(WsHandshakeError::UnsupportedVersion, handshake(&req).err().unwrap());
|
assert_eq!(HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap());
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(header::UPGRADE,
|
headers.insert(header::UPGRADE,
|
||||||
@@ -395,7 +406,7 @@ mod tests {
|
|||||||
header::HeaderValue::from_static("13"));
|
header::HeaderValue::from_static("13"));
|
||||||
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
|
||||||
Version::HTTP_11, headers, None);
|
Version::HTTP_11, headers, None);
|
||||||
assert_eq!(WsHandshakeError::BadWebsocketKey, handshake(&req).err().unwrap());
|
assert_eq!(HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap());
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(header::UPGRADE,
|
headers.insert(header::UPGRADE,
|
||||||
@@ -414,17 +425,17 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_wserror_http_response() {
|
fn test_wserror_http_response() {
|
||||||
let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response();
|
let resp: HttpResponse = HandshakeError::GetMethodRequired.error_response();
|
||||||
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
|
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
|
||||||
let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response();
|
let resp: HttpResponse = HandshakeError::NoWebsocketUpgrade.error_response();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response();
|
let resp: HttpResponse = HandshakeError::NoConnectionUpgrade.error_response();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response();
|
let resp: HttpResponse = HandshakeError::NoVersionHeader.error_response();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response();
|
let resp: HttpResponse = HandshakeError::UnsupportedVersion.error_response();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response();
|
let resp: HttpResponse = HandshakeError::BadWebsocketKey.error_response();
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -121,7 +121,7 @@ fn test_shutdown() {
|
|||||||
assert!(response.status().is_success());
|
assert!(response.status().is_success());
|
||||||
}
|
}
|
||||||
|
|
||||||
thread::sleep(time::Duration::from_millis(100));
|
thread::sleep(time::Duration::from_millis(1000));
|
||||||
assert!(net::TcpStream::connect(addr).is_err());
|
assert!(net::TcpStream::connect(addr).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +163,7 @@ fn test_headers() {
|
|||||||
|
|
||||||
// read response
|
// read response
|
||||||
let bytes = srv.execute(response.body()).unwrap();
|
let bytes = srv.execute(response.body()).unwrap();
|
||||||
assert_eq!(Bytes::from(bytes), Bytes::from_static(STR.as_ref()));
|
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@@ -16,7 +16,7 @@ impl Actor for Ws {
|
|||||||
type Context = ws::WebsocketContext<Self>;
|
type Context = ws::WebsocketContext<Self>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamHandler<ws::Message, ws::WsError> for Ws {
|
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
|
||||||
|
|
||||||
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
||||||
match msg {
|
match msg {
|
||||||
|
Reference in New Issue
Block a user