mirror of
https://github.com/fafhrd91/actix-web
synced 2025-08-30 16:40:21 +02:00
refactor websockets handling
This commit is contained in:
165
src/ws/mod.rs
165
src/ws/mod.rs
@@ -23,9 +23,8 @@
|
||||
//! type Context = ws::WebsocketContext<Self>;
|
||||
//! }
|
||||
//!
|
||||
//! // Define Handler for ws::Message message
|
||||
//! impl Handler<ws::Message> for Ws {
|
||||
//! type Result = ();
|
||||
//! // Handler for ws::Message messages
|
||||
//! impl StreamHandler<ws::Message, ws::WsError> for Ws {
|
||||
//!
|
||||
//! fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
|
||||
//! match msg {
|
||||
@@ -48,13 +47,14 @@ use http::{Method, StatusCode, header};
|
||||
use futures::{Async, Poll, Stream};
|
||||
use byteorder::{ByteOrder, NetworkEndian};
|
||||
|
||||
use actix::{Actor, AsyncContext, Handler};
|
||||
use actix::{Actor, AsyncContext, StreamHandler};
|
||||
|
||||
use body::Binary;
|
||||
use payload::PayloadHelper;
|
||||
use error::{Error, WsHandshakeError, PayloadError};
|
||||
use error::{Error, PayloadError, ResponseError};
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder};
|
||||
use httpcodes::{HTTPBadRequest, HTTPMethodNotAllowed};
|
||||
|
||||
mod frame;
|
||||
mod proto;
|
||||
@@ -66,7 +66,8 @@ use self::frame::Frame;
|
||||
use self::proto::{hash_key, OpCode};
|
||||
pub use self::proto::CloseCode;
|
||||
pub use self::context::WebsocketContext;
|
||||
pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake};
|
||||
pub use self::client::{WsClient, WsClientError,
|
||||
WsClientReader, WsClientWriter, WsClientHandshake};
|
||||
|
||||
const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
|
||||
const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";
|
||||
@@ -74,6 +75,94 @@ const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION";
|
||||
// const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL";
|
||||
|
||||
|
||||
/// Websocket errors
|
||||
#[derive(Fail, Debug)]
|
||||
pub enum WsError {
|
||||
/// Received an unmasked frame from client
|
||||
#[fail(display="Received an unmasked frame from client")]
|
||||
UnmaskedFrame,
|
||||
/// Received a masked frame from server
|
||||
#[fail(display="Received a masked frame from server")]
|
||||
MaskedFrame,
|
||||
/// Encountered invalid opcode
|
||||
#[fail(display="Invalid opcode: {}", _0)]
|
||||
InvalidOpcode(u8),
|
||||
/// Invalid control frame length
|
||||
#[fail(display="Invalid control frame length: {}", _0)]
|
||||
InvalidLength(usize),
|
||||
/// Bad web socket op code
|
||||
#[fail(display="Bad web socket op code")]
|
||||
BadOpCode,
|
||||
/// A payload reached size limit.
|
||||
#[fail(display="A payload reached size limit.")]
|
||||
Overflow,
|
||||
/// Continuation is not supproted
|
||||
#[fail(display="Continuation is not supproted.")]
|
||||
NoContinuation,
|
||||
/// Bad utf-8 encoding
|
||||
#[fail(display="Bad utf-8 encoding.")]
|
||||
BadEncoding,
|
||||
/// Payload error
|
||||
#[fail(display="Payload error: {}", _0)]
|
||||
Payload(#[cause] PayloadError),
|
||||
}
|
||||
|
||||
impl ResponseError for WsError {}
|
||||
|
||||
impl From<PayloadError> for WsError {
|
||||
fn from(err: PayloadError) -> WsError {
|
||||
WsError::Payload(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Websocket handshake errors
|
||||
#[derive(Fail, PartialEq, Debug)]
|
||||
pub enum WsHandshakeError {
|
||||
/// Only get method is allowed
|
||||
#[fail(display="Method not allowed")]
|
||||
GetMethodRequired,
|
||||
/// Upgrade header if not set to websocket
|
||||
#[fail(display="Websocket upgrade is expected")]
|
||||
NoWebsocketUpgrade,
|
||||
/// Connection header is not set to upgrade
|
||||
#[fail(display="Connection upgrade is expected")]
|
||||
NoConnectionUpgrade,
|
||||
/// Websocket version header is not set
|
||||
#[fail(display="Websocket version header is required")]
|
||||
NoVersionHeader,
|
||||
/// Unsupported websocket version
|
||||
#[fail(display="Unsupported version")]
|
||||
UnsupportedVersion,
|
||||
/// Websocket key is not set or wrong
|
||||
#[fail(display="Unknown websocket key")]
|
||||
BadWebsocketKey,
|
||||
}
|
||||
|
||||
impl ResponseError for WsHandshakeError {
|
||||
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
match *self {
|
||||
WsHandshakeError::GetMethodRequired => {
|
||||
HTTPMethodNotAllowed
|
||||
.build()
|
||||
.header(header::ALLOW, "GET")
|
||||
.finish()
|
||||
.unwrap()
|
||||
}
|
||||
WsHandshakeError::NoWebsocketUpgrade =>
|
||||
HTTPBadRequest.with_reason("No WebSocket UPGRADE header found"),
|
||||
WsHandshakeError::NoConnectionUpgrade =>
|
||||
HTTPBadRequest.with_reason("No CONNECTION upgrade"),
|
||||
WsHandshakeError::NoVersionHeader =>
|
||||
HTTPBadRequest.with_reason("Websocket version header is required"),
|
||||
WsHandshakeError::UnsupportedVersion =>
|
||||
HTTPBadRequest.with_reason("Unsupported version"),
|
||||
WsHandshakeError::BadWebsocketKey =>
|
||||
HTTPBadRequest.with_reason("Handshake error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `WebSocket` Message
|
||||
#[derive(Debug, PartialEq, Message)]
|
||||
pub enum Message {
|
||||
@@ -82,19 +171,18 @@ pub enum Message {
|
||||
Ping(String),
|
||||
Pong(String),
|
||||
Close(CloseCode),
|
||||
Error
|
||||
}
|
||||
|
||||
/// Do websocket handshake and start actor
|
||||
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
|
||||
where A: Actor<Context=WebsocketContext<A, S>> + Handler<Message>,
|
||||
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, WsError>,
|
||||
S: 'static
|
||||
{
|
||||
let mut resp = handshake(&req)?;
|
||||
let stream = WsStream::new(req.clone());
|
||||
|
||||
let mut ctx = WebsocketContext::new(req, actor);
|
||||
ctx.add_message_stream(stream);
|
||||
ctx.add_stream(stream);
|
||||
|
||||
Ok(resp.body(ctx)?)
|
||||
}
|
||||
@@ -168,33 +256,52 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
|
||||
pub struct WsStream<S> {
|
||||
rx: PayloadHelper<S>,
|
||||
closed: bool,
|
||||
max_size: usize,
|
||||
}
|
||||
|
||||
impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
||||
/// Create new websocket frames stream
|
||||
pub fn new(stream: S) -> WsStream<S> {
|
||||
WsStream { rx: PayloadHelper::new(stream),
|
||||
closed: false }
|
||||
closed: false,
|
||||
max_size: 65_536,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set max frame size
|
||||
///
|
||||
/// By default max size is set to 64kb
|
||||
pub fn max_size(mut self, size: usize) -> Self {
|
||||
self.max_size = size;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
||||
type Item = Message;
|
||||
type Error = ();
|
||||
type Error = WsError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
if self.closed {
|
||||
return Ok(Async::Ready(None))
|
||||
}
|
||||
|
||||
match Frame::parse(&mut self.rx, true) {
|
||||
match Frame::parse(&mut self.rx, true, self.max_size) {
|
||||
Ok(Async::Ready(Some(frame))) => {
|
||||
// trace!("WsFrame {}", frame);
|
||||
let (_finished, opcode, payload) = frame.unpack();
|
||||
let (finished, opcode, payload) = frame.unpack();
|
||||
|
||||
// continuation is not supported
|
||||
if !finished {
|
||||
self.closed = true;
|
||||
return Err(WsError::NoContinuation)
|
||||
}
|
||||
|
||||
match opcode {
|
||||
OpCode::Continue => unimplemented!(),
|
||||
OpCode::Bad =>
|
||||
Ok(Async::Ready(Some(Message::Error))),
|
||||
OpCode::Bad => {
|
||||
self.closed = true;
|
||||
Err(WsError::BadOpCode)
|
||||
}
|
||||
OpCode::Close => {
|
||||
self.closed = true;
|
||||
let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16;
|
||||
@@ -215,17 +322,19 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
|
||||
match String::from_utf8(tmp) {
|
||||
Ok(s) =>
|
||||
Ok(Async::Ready(Some(Message::Text(s)))),
|
||||
Err(_) =>
|
||||
Ok(Async::Ready(Some(Message::Error))),
|
||||
Err(_) => {
|
||||
self.closed = true;
|
||||
Err(WsError::BadEncoding)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(_) => {
|
||||
Err(e) => {
|
||||
self.closed = true;
|
||||
Ok(Async::Ready(Some(Message::Error)))
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -306,4 +415,20 @@ mod tests {
|
||||
assert_eq!(StatusCode::SWITCHING_PROTOCOLS,
|
||||
handshake(&req).unwrap().finish().unwrap().status());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wserror_http_response() {
|
||||
let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
|
||||
let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user