1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 01:32:57 +01:00

Merge branch 'master' into csrf-middleware

This commit is contained in:
Nikolay Kim 2018-03-02 12:25:05 -08:00 committed by GitHub
commit e60acb7607
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 169 additions and 135 deletions

View File

@ -1,5 +1,10 @@
# Changes
## 0.4.2 (2018-03-xx)
* Better naming for websockets implementation
## 0.4.1 (2018-03-01)
* Rename `Route::p()` to `Route::filter()`

View File

@ -36,7 +36,7 @@ impl Actor for MyWebSocket {
type Context = ws::WebsocketContext<Self, AppState>;
}
impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket {
impl StreamHandler<ws::Message, ws::ProtocolError> for MyWebSocket {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
self.counter += 1;

View File

@ -92,7 +92,7 @@ impl Handler<session::Message> for WsChatSession {
}
/// WebSocket message handler
impl StreamHandler<ws::Message, ws::WsError> for WsChatSession {
impl StreamHandler<ws::Message, ws::ProtocolError> for WsChatSession {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
println!("WEBSOCKET MESSAGE: {:?}", msg);

View File

@ -12,7 +12,7 @@ use std::time::Duration;
use actix::*;
use futures::Future;
use actix_web::ws::{Message, WsError, WsClient, WsClientWriter};
use actix_web::ws::{Message, ProtocolError, Client, ClientWriter};
fn main() {
@ -21,7 +21,7 @@ fn main() {
let sys = actix::System::new("ws-example");
Arbiter::handle().spawn(
WsClient::new("http://127.0.0.1:8080/ws/")
Client::new("http://127.0.0.1:8080/ws/")
.connect()
.map_err(|e| {
println!("Error: {}", e);
@ -53,7 +53,7 @@ fn main() {
}
struct ChatClient(WsClientWriter);
struct ChatClient(ClientWriter);
#[derive(Message)]
struct ClientCommand(String);
@ -93,7 +93,7 @@ impl Handler<ClientCommand> for ChatClient {
}
/// Handle server websocket messages
impl StreamHandler<Message, WsError> for ChatClient {
impl StreamHandler<Message, ProtocolError> for ChatClient {
fn handle(&mut self, msg: Message, ctx: &mut Context<Self>) {
match msg {

View File

@ -25,7 +25,7 @@ impl Actor for MyWebSocket {
}
/// Handler for `ws::Message`
impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket {
impl StreamHandler<ws::Message, ws::ProtocolError> for MyWebSocket {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
// process websocket messages

View File

@ -22,7 +22,7 @@ impl Actor for Ws {
}
/// Handler for ws::Message message
impl StreamHandler<ws::Message, ws::WsError> for Ws {
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg {

View File

@ -14,6 +14,7 @@ use tokio_core::net::TcpListener;
use tokio_core::reactor::Core;
use net2::TcpBuilder;
use ws;
use body::Binary;
use error::Error;
use handler::{Handler, Responder, ReplyItem};
@ -25,7 +26,6 @@ use payload::Payload;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use server::{HttpServer, IntoHttpHandler, ServerSettings};
use ws::{WsClient, WsClientError, WsClientReader, WsClientWriter};
use client::{ClientRequest, ClientRequestBuilder};
/// The `TestServer` type.
@ -180,9 +180,9 @@ impl TestServer {
}
/// Connect to websocket server
pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> {
pub fn ws(&mut self) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> {
let url = self.url("/");
self.system.run_until_complete(WsClient::new(url).connect())
self.system.run_until_complete(ws::Client::new(url).connect())
}
/// Create `GET` request

View File

@ -25,14 +25,32 @@ use client::{ClientRequest, ClientRequestBuilder, ClientResponse,
ClientConnector, SendRequest, SendRequestError,
HttpResponseParserError};
use super::{Message, WsError};
use super::{Message, ProtocolError};
use super::frame::Frame;
use super::proto::{CloseCode, OpCode};
/// Backward compatibility
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::Client` instead")]
pub type WsClient = Client;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientError` instead")]
pub type WsClientError = ClientError;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientReader` instead")]
pub type WsClientReader = ClientReader;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientWriter` instead")]
pub type WsClientWriter = ClientWriter;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ClientHandshake` instead")]
pub type WsClientHandshake = ClientHandshake;
/// Websocket client error
#[derive(Fail, Debug)]
pub enum WsClientError {
pub enum ClientError {
#[fail(display="Invalid url")]
InvalidUrl,
#[fail(display="Invalid response status")]
@ -56,46 +74,46 @@ pub enum WsClientError {
#[fail(display="{}", _0)]
SendRequest(SendRequestError),
#[fail(display="{}", _0)]
Protocol(#[cause] WsError),
Protocol(#[cause] ProtocolError),
#[fail(display="{}", _0)]
Io(io::Error),
#[fail(display="Disconnected")]
Disconnected,
}
impl From<HttpError> for WsClientError {
fn from(err: HttpError) -> WsClientError {
WsClientError::Http(err)
impl From<HttpError> for ClientError {
fn from(err: HttpError) -> ClientError {
ClientError::Http(err)
}
}
impl From<UrlParseError> for WsClientError {
fn from(err: UrlParseError) -> WsClientError {
WsClientError::Url(err)
impl From<UrlParseError> for ClientError {
fn from(err: UrlParseError) -> ClientError {
ClientError::Url(err)
}
}
impl From<SendRequestError> for WsClientError {
fn from(err: SendRequestError) -> WsClientError {
WsClientError::SendRequest(err)
impl From<SendRequestError> for ClientError {
fn from(err: SendRequestError) -> ClientError {
ClientError::SendRequest(err)
}
}
impl From<WsError> for WsClientError {
fn from(err: WsError) -> WsClientError {
WsClientError::Protocol(err)
impl From<ProtocolError> for ClientError {
fn from(err: ProtocolError) -> ClientError {
ClientError::Protocol(err)
}
}
impl From<io::Error> for WsClientError {
fn from(err: io::Error) -> WsClientError {
WsClientError::Io(err)
impl From<io::Error> for ClientError {
fn from(err: io::Error) -> ClientError {
ClientError::Io(err)
}
}
impl From<HttpResponseParserError> for WsClientError {
fn from(err: HttpResponseParserError) -> WsClientError {
WsClientError::ResponseParseError(err)
impl From<HttpResponseParserError> for ClientError {
fn from(err: HttpResponseParserError) -> ClientError {
ClientError::ResponseParseError(err)
}
}
@ -104,9 +122,9 @@ impl From<HttpResponseParserError> for WsClientError {
/// Example of `WebSocket` client usage is available in
/// [websocket example](
/// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24)
pub struct WsClient {
pub struct Client {
request: ClientRequestBuilder,
err: Option<WsClientError>,
err: Option<ClientError>,
http_err: Option<HttpError>,
origin: Option<HeaderValue>,
protocols: Option<String>,
@ -114,16 +132,16 @@ pub struct WsClient {
max_size: usize,
}
impl WsClient {
impl Client {
/// Create new websocket connection
pub fn new<S: AsRef<str>>(uri: S) -> WsClient {
WsClient::with_connector(uri, ClientConnector::from_registry())
pub fn new<S: AsRef<str>>(uri: S) -> Client {
Client::with_connector(uri, ClientConnector::from_registry())
}
/// Create new websocket connection with custom `ClientConnector`
pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> WsClient {
let mut cl = WsClient {
pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> Client {
let mut cl = Client {
request: ClientRequest::build(),
err: None,
http_err: None,
@ -182,12 +200,12 @@ impl WsClient {
}
/// Connect to websocket server and do ws handshake
pub fn connect(&mut self) -> WsClientHandshake {
pub fn connect(&mut self) -> ClientHandshake {
if let Some(e) = self.err.take() {
WsClientHandshake::error(e)
ClientHandshake::error(e)
}
else if let Some(e) = self.http_err.take() {
WsClientHandshake::error(e.into())
ClientHandshake::error(e.into())
} else {
// origin
if let Some(origin) = self.origin.take() {
@ -205,42 +223,42 @@ impl WsClient {
}
let request = match self.request.finish() {
Ok(req) => req,
Err(err) => return WsClientHandshake::error(err.into()),
Err(err) => return ClientHandshake::error(err.into()),
};
if request.uri().host().is_none() {
return WsClientHandshake::error(WsClientError::InvalidUrl)
return ClientHandshake::error(ClientError::InvalidUrl)
}
if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return WsClientHandshake::error(WsClientError::InvalidUrl)
return ClientHandshake::error(ClientError::InvalidUrl)
}
} else {
return WsClientHandshake::error(WsClientError::InvalidUrl)
return ClientHandshake::error(ClientError::InvalidUrl)
}
// start handshake
WsClientHandshake::new(request, self.max_size)
ClientHandshake::new(request, self.max_size)
}
}
}
struct WsInner {
struct Inner {
tx: UnboundedSender<Bytes>,
rx: PayloadHelper<ClientResponse>,
closed: bool,
}
pub struct WsClientHandshake {
pub struct ClientHandshake {
request: Option<SendRequest>,
tx: Option<UnboundedSender<Bytes>>,
key: String,
error: Option<WsClientError>,
error: Option<ClientError>,
max_size: usize,
}
impl WsClientHandshake {
fn new(mut request: ClientRequest, max_size: usize) -> WsClientHandshake
impl ClientHandshake {
fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake
{
// Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that,
@ -257,7 +275,7 @@ impl WsClientHandshake {
Box::new(rx.map_err(|_| io::Error::new(
io::ErrorKind::Other, "disconnected").into()))));
WsClientHandshake {
ClientHandshake {
key,
max_size,
request: Some(request.send()),
@ -266,8 +284,8 @@ impl WsClientHandshake {
}
}
fn error(err: WsClientError) -> WsClientHandshake {
WsClientHandshake {
fn error(err: ClientError) -> ClientHandshake {
ClientHandshake {
key: String::new(),
request: None,
tx: None,
@ -277,9 +295,9 @@ impl WsClientHandshake {
}
}
impl Future for WsClientHandshake {
type Item = (WsClientReader, WsClientWriter);
type Error = WsClientError;
impl Future for ClientHandshake {
type Item = (ClientReader, ClientWriter);
type Error = ClientError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(err) = self.error.take() {
@ -296,7 +314,7 @@ impl Future for WsClientHandshake {
// verify response
if resp.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(resp.status()))
return Err(ClientError::InvalidResponseStatus(resp.status()))
}
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) {
@ -310,22 +328,22 @@ impl Future for WsClientHandshake {
};
if !has_hdr {
trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader)
return Err(ClientError::InvalidUpgradeHeader)
}
// Check for "CONNECTION" header
if let Some(conn) = resp.headers().get(header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_lowercase().contains("upgrade") {
trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()))
return Err(ClientError::InvalidConnectionHeader(conn.clone()))
}
} else {
trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()))
return Err(ClientError::InvalidConnectionHeader(conn.clone()))
}
} else {
trace!("Missing connection header");
return Err(WsClientError::MissingConnectionHeader)
return Err(ClientError::MissingConnectionHeader)
}
if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT)
@ -341,14 +359,14 @@ impl Future for WsClientHandshake {
trace!(
"Invalid challenge response: expected: {} received: {:?}",
encoded, key);
return Err(WsClientError::InvalidChallengeResponse(encoded, key.clone()));
return Err(ClientError::InvalidChallengeResponse(encoded, key.clone()));
}
} else {
trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(WsClientError::MissingWebSocketAcceptHeader)
return Err(ClientError::MissingWebSocketAcceptHeader)
};
let inner = WsInner {
let inner = Inner {
tx: self.tx.take().unwrap(),
rx: PayloadHelper::new(resp),
closed: false,
@ -356,33 +374,33 @@ impl Future for WsClientHandshake {
let inner = Rc::new(UnsafeCell::new(inner));
Ok(Async::Ready(
(WsClientReader{inner: Rc::clone(&inner), max_size: self.max_size},
WsClientWriter{inner})))
(ClientReader{inner: Rc::clone(&inner), max_size: self.max_size},
ClientWriter{inner})))
}
}
pub struct WsClientReader {
inner: Rc<UnsafeCell<WsInner>>,
pub struct ClientReader {
inner: Rc<UnsafeCell<Inner>>,
max_size: usize,
}
impl fmt::Debug for WsClientReader {
impl fmt::Debug for ClientReader {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WsClientReader()")
write!(f, "ws::ClientReader()")
}
}
impl WsClientReader {
impl ClientReader {
#[inline]
fn as_mut(&mut self) -> &mut WsInner {
fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() }
}
}
impl Stream for WsClientReader {
impl Stream for ClientReader {
type Item = Message;
type Error = WsError;
type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let max_size = self.max_size;
@ -399,14 +417,14 @@ impl Stream for WsClientReader {
// continuation is not supported
if !finished {
inner.closed = true;
return Err(WsError::NoContinuation)
return Err(ProtocolError::NoContinuation)
}
match opcode {
OpCode::Continue => unimplemented!(),
OpCode::Bad => {
inner.closed = true;
Err(WsError::BadOpCode)
Err(ProtocolError::BadOpCode)
},
OpCode::Close => {
inner.closed = true;
@ -430,7 +448,7 @@ impl Stream for WsClientReader {
Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => {
inner.closed = true;
Err(WsError::BadEncoding)
Err(ProtocolError::BadEncoding)
}
}
}
@ -446,18 +464,18 @@ impl Stream for WsClientReader {
}
}
pub struct WsClientWriter {
inner: Rc<UnsafeCell<WsInner>>
pub struct ClientWriter {
inner: Rc<UnsafeCell<Inner>>
}
impl WsClientWriter {
impl ClientWriter {
#[inline]
fn as_mut(&mut self) -> &mut WsInner {
fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() }
}
}
impl WsClientWriter {
impl ClientWriter {
/// Write payload
#[inline]

View File

@ -9,7 +9,7 @@ use body::Binary;
use error::{PayloadError};
use payload::PayloadHelper;
use ws::WsError;
use ws::ProtocolError;
use ws::proto::{OpCode, CloseCode};
use ws::mask::apply_mask;
@ -53,7 +53,7 @@ impl Frame {
/// Parse the input stream into a frame.
pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
-> Poll<Option<Frame>, WsError>
-> Poll<Option<Frame>, ProtocolError>
where S: Stream<Item=Bytes, Error=PayloadError>
{
let mut idx = 2;
@ -69,9 +69,9 @@ impl Frame {
// check masking
let masked = second & 0x80 != 0;
if !masked && server {
return Err(WsError::UnmaskedFrame)
return Err(ProtocolError::UnmaskedFrame)
} else if masked && !server {
return Err(WsError::MaskedFrame)
return Err(ProtocolError::MaskedFrame)
}
let rsv1 = first & 0x40 != 0;
@ -104,7 +104,7 @@ impl Frame {
// check for max allowed size
if length > max_size {
return Err(WsError::Overflow)
return Err(ProtocolError::Overflow)
}
let mask = if server {
@ -133,13 +133,13 @@ impl Frame {
// Disallow bad opcode
if let OpCode::Bad = opcode {
return Err(WsError::InvalidOpcode(first & 0x0F))
return Err(ProtocolError::InvalidOpcode(first & 0x0F))
}
// control frames must have length <= 125
match opcode {
OpCode::Ping | OpCode::Pong if length > 125 => {
return Err(WsError::InvalidLength(length))
return Err(ProtocolError::InvalidLength(length))
}
OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
@ -257,14 +257,14 @@ mod tests {
use super::*;
use futures::stream::once;
fn is_none(frm: Poll<Option<Frame>, WsError>) -> bool {
fn is_none(frm: Poll<Option<Frame>, ProtocolError>) -> bool {
match frm {
Ok(Async::Ready(None)) => true,
_ => false,
}
}
fn extract(frm: Poll<Option<Frame>, WsError>) -> Frame {
fn extract(frm: Poll<Option<Frame>, ProtocolError>) -> Frame {
match frm {
Ok(Async::Ready(Some(frame))) => frame,
_ => panic!("error"),
@ -370,7 +370,7 @@ mod tests {
assert!(Frame::parse(&mut buf, true, 1).is_err());
if let Err(WsError::Overflow) = Frame::parse(&mut buf, false, 0) {
if let Err(ProtocolError::Overflow) = Frame::parse(&mut buf, false, 0) {
} else {
panic!("error");
}

View File

@ -67,13 +67,24 @@ use self::frame::Frame;
use self::proto::{hash_key, OpCode};
pub use self::proto::CloseCode;
pub use self::context::WebsocketContext;
pub use self::client::{Client, ClientError,
ClientReader, ClientWriter, ClientHandshake};
#[allow(deprecated)]
pub use self::client::{WsClient, WsClientError,
WsClientReader, WsClientWriter, WsClientHandshake};
/// Backward compatibility
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::ProtocolError` instead")]
pub type WsError = ProtocolError;
#[doc(hidden)]
#[deprecated(since="0.4.2", note="please use `ws::HandshakeError` instead")]
pub type WsHandshakeError = HandshakeError;
/// Websocket errors
#[derive(Fail, Debug)]
pub enum WsError {
pub enum ProtocolError {
/// Received an unmasked frame from client
#[fail(display="Received an unmasked frame from client")]
UnmaskedFrame,
@ -103,17 +114,17 @@ pub enum WsError {
Payload(#[cause] PayloadError),
}
impl ResponseError for WsError {}
impl ResponseError for ProtocolError {}
impl From<PayloadError> for WsError {
fn from(err: PayloadError) -> WsError {
WsError::Payload(err)
impl From<PayloadError> for ProtocolError {
fn from(err: PayloadError) -> ProtocolError {
ProtocolError::Payload(err)
}
}
/// Websocket handshake errors
#[derive(Fail, PartialEq, Debug)]
pub enum WsHandshakeError {
pub enum HandshakeError {
/// Only get method is allowed
#[fail(display="Method not allowed")]
GetMethodRequired,
@ -134,26 +145,26 @@ pub enum WsHandshakeError {
BadWebsocketKey,
}
impl ResponseError for WsHandshakeError {
impl ResponseError for HandshakeError {
fn error_response(&self) -> HttpResponse {
match *self {
WsHandshakeError::GetMethodRequired => {
HandshakeError::GetMethodRequired => {
HttpMethodNotAllowed
.build()
.header(header::ALLOW, "GET")
.finish()
.unwrap()
}
WsHandshakeError::NoWebsocketUpgrade =>
HandshakeError::NoWebsocketUpgrade =>
HttpBadRequest.with_reason("No WebSocket UPGRADE header found"),
WsHandshakeError::NoConnectionUpgrade =>
HandshakeError::NoConnectionUpgrade =>
HttpBadRequest.with_reason("No CONNECTION upgrade"),
WsHandshakeError::NoVersionHeader =>
HandshakeError::NoVersionHeader =>
HttpBadRequest.with_reason("Websocket version header is required"),
WsHandshakeError::UnsupportedVersion =>
HandshakeError::UnsupportedVersion =>
HttpBadRequest.with_reason("Unsupported version"),
WsHandshakeError::BadWebsocketKey =>
HandshakeError::BadWebsocketKey =>
HttpBadRequest.with_reason("Handshake error"),
}
}
@ -171,7 +182,7 @@ pub enum Message {
/// Do websocket handshake and start actor
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, WsError>,
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, ProtocolError>,
S: 'static
{
let mut resp = handshake(&req)?;
@ -191,10 +202,10 @@ pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
// /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list
// /// which the server also knows.
pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHandshakeError> {
pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, HandshakeError> {
// WebSocket accepts only GET
if *req.method() != Method::GET {
return Err(WsHandshakeError::GetMethodRequired)
return Err(HandshakeError::GetMethodRequired)
}
// Check for "UPGRADE" to websocket header
@ -208,17 +219,17 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
false
};
if !has_hdr {
return Err(WsHandshakeError::NoWebsocketUpgrade)
return Err(HandshakeError::NoWebsocketUpgrade)
}
// Upgrade connection
if !req.upgrade() {
return Err(WsHandshakeError::NoConnectionUpgrade)
return Err(HandshakeError::NoConnectionUpgrade)
}
// check supported version
if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
return Err(WsHandshakeError::NoVersionHeader)
return Err(HandshakeError::NoVersionHeader)
}
let supported_ver = {
if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
@ -228,12 +239,12 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
}
};
if !supported_ver {
return Err(WsHandshakeError::UnsupportedVersion)
return Err(HandshakeError::UnsupportedVersion)
}
// check client handshake for validity
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
return Err(WsHandshakeError::BadWebsocketKey)
return Err(HandshakeError::BadWebsocketKey)
}
let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
@ -275,7 +286,7 @@ impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
type Item = Message;
type Error = WsError;
type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.closed {
@ -289,14 +300,14 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// continuation is not supported
if !finished {
self.closed = true;
return Err(WsError::NoContinuation)
return Err(ProtocolError::NoContinuation)
}
match opcode {
OpCode::Continue => unimplemented!(),
OpCode::Bad => {
self.closed = true;
Err(WsError::BadOpCode)
Err(ProtocolError::BadOpCode)
}
OpCode::Close => {
self.closed = true;
@ -320,7 +331,7 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => {
self.closed = true;
Err(WsError::BadEncoding)
Err(ProtocolError::BadEncoding)
}
}
}
@ -346,25 +357,25 @@ mod tests {
fn test_handshake() {
let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap());
assert_eq!(HandshakeError::GetMethodRequired, handshake(&req).err().unwrap());
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("test"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("websocket"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap());
assert_eq!(HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
@ -373,7 +384,7 @@ mod tests {
header::HeaderValue::from_static("upgrade"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::NoVersionHeader, handshake(&req).err().unwrap());
assert_eq!(HandshakeError::NoVersionHeader, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
@ -384,7 +395,7 @@ mod tests {
header::HeaderValue::from_static("5"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::UnsupportedVersion, handshake(&req).err().unwrap());
assert_eq!(HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
@ -395,7 +406,7 @@ mod tests {
header::HeaderValue::from_static("13"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(WsHandshakeError::BadWebsocketKey, handshake(&req).err().unwrap());
assert_eq!(HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap());
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
@ -414,17 +425,17 @@ mod tests {
#[test]
fn test_wserror_http_response() {
let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response();
let resp: HttpResponse = HandshakeError::GetMethodRequired.error_response();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response();
let resp: HttpResponse = HandshakeError::NoWebsocketUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response();
let resp: HttpResponse = HandshakeError::NoConnectionUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response();
let resp: HttpResponse = HandshakeError::NoVersionHeader.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response();
let resp: HttpResponse = HandshakeError::UnsupportedVersion.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response();
let resp: HttpResponse = HandshakeError::BadWebsocketKey.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
}

View File

@ -121,7 +121,7 @@ fn test_shutdown() {
assert!(response.status().is_success());
}
thread::sleep(time::Duration::from_millis(100));
thread::sleep(time::Duration::from_millis(1000));
assert!(net::TcpStream::connect(addr).is_err());
}
@ -163,7 +163,7 @@ fn test_headers() {
// read response
let bytes = srv.execute(response.body()).unwrap();
assert_eq!(Bytes::from(bytes), Bytes::from_static(STR.as_ref()));
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]

View File

@ -16,7 +16,7 @@ impl Actor for Ws {
type Context = ws::WebsocketContext<Self>;
}
impl StreamHandler<ws::Message, ws::WsError> for Ws {
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg {