1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-18 05:41:50 +01:00

create custom WebsocketContext for websocket connection

This commit is contained in:
Nikolay Kim 2018-01-10 10:12:34 -08:00
parent 8aae2daafa
commit 4b72a1b325
8 changed files with 326 additions and 136 deletions

View File

@ -6,10 +6,11 @@ a [*ws::WsStream*](../actix_web/ws/struct.WsStream.html) and then use stream
combinators to handle actual messages. But it is simplier to handle websocket communications combinators to handle actual messages. But it is simplier to handle websocket communications
with http actor. with http actor.
```rust This is example of simple websocket echo server:
extern crate actix;
extern crate actix_web;
```rust
# extern crate actix;
# extern crate actix_web;
use actix::*; use actix::*;
use actix_web::*; use actix_web::*;
@ -17,18 +18,18 @@ use actix_web::*;
struct Ws; struct Ws;
impl Actor for Ws { impl Actor for Ws {
type Context = HttpContext<Self>; type Context = ws::WebsocketContext<Self>;
} }
/// Define Handler for ws::Message message /// Define Handler for ws::Message message
impl Handler<ws::Message> for Ws { impl Handler<ws::Message> for Ws {
type Result=(); type Result=();
fn handle(&mut self, msg: ws::Message, ctx: &mut HttpContext<Self>) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {
ws::Message::Ping(msg) => ws::WsWriter::pong(ctx, &msg), ws::Message::Ping(msg) => ctx.pong(&msg),
ws::Message::Text(text) => ws::WsWriter::text(ctx, &text), ws::Message::Text(text) => ctx.text(&text),
ws::Message::Binary(bin) => ws::WsWriter::binary(ctx, bin), ws::Message::Binary(bin) => ctx.binary(bin),
_ => (), _ => (),
} }
} }

View File

@ -38,17 +38,12 @@ pub struct HttpContext<A, S=()> where A: Actor<Context=HttpContext<A, S>>,
impl<A, S> ActorContext for HttpContext<A, S> where A: Actor<Context=Self> impl<A, S> ActorContext for HttpContext<A, S> where A: Actor<Context=Self>
{ {
/// Stop actor execution
fn stop(&mut self) { fn stop(&mut self) {
self.inner.stop(); self.inner.stop();
} }
/// Terminate actor execution
fn terminate(&mut self) { fn terminate(&mut self) {
self.inner.terminate() self.inner.terminate()
} }
/// Actor execution state
fn state(&self) -> ActorState { fn state(&self) -> ActorState {
self.inner.state() self.inner.state()
} }
@ -61,13 +56,11 @@ impl<A, S> AsyncContext<A> for HttpContext<A, S> where A: Actor<Context=Self>
{ {
self.inner.spawn(fut) self.inner.spawn(fut)
} }
fn wait<F>(&mut self, fut: F) fn wait<F>(&mut self, fut: F)
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static
{ {
self.inner.wait(fut) self.inner.wait(fut)
} }
fn cancel_future(&mut self, handle: SpawnHandle) -> bool { fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
self.inner.cancel_future(handle) self.inner.cancel_future(handle)
} }
@ -79,12 +72,10 @@ impl<A, S> AsyncContextApi<A> for HttpContext<A, S> where A: Actor<Context=Self>
fn unsync_sender(&mut self) -> queue::unsync::UnboundedSender<ContextProtocol<A>> { fn unsync_sender(&mut self) -> queue::unsync::UnboundedSender<ContextProtocol<A>> {
self.inner.unsync_sender() self.inner.unsync_sender()
} }
#[inline] #[inline]
fn unsync_address(&mut self) -> Address<A> { fn unsync_address(&mut self) -> Address<A> {
self.inner.unsync_address() self.inner.unsync_address()
} }
#[inline] #[inline]
fn sync_address(&mut self) -> SyncAddress<A> { fn sync_address(&mut self) -> SyncAddress<A> {
self.inner.sync_address() self.inner.sync_address()
@ -97,7 +88,6 @@ impl<A, S: 'static> HttpContext<A, S> where A: Actor<Context=Self> {
pub fn new(req: HttpRequest<S>, actor: A) -> HttpContext<A, S> { pub fn new(req: HttpRequest<S>, actor: A) -> HttpContext<A, S> {
HttpContext::from_request(req).actor(actor) HttpContext::from_request(req).actor(actor)
} }
pub fn from_request(req: HttpRequest<S>) -> HttpContext<A, S> { pub fn from_request(req: HttpRequest<S>) -> HttpContext<A, S> {
HttpContext { HttpContext {
inner: ContextImpl::new(None), inner: ContextImpl::new(None),
@ -106,7 +96,6 @@ impl<A, S: 'static> HttpContext<A, S> where A: Actor<Context=Self> {
disconnected: false, disconnected: false,
} }
} }
#[inline] #[inline]
pub fn actor(mut self, actor: A) -> HttpContext<A, S> { pub fn actor(mut self, actor: A) -> HttpContext<A, S> {
self.inner.set_actor(actor); self.inner.set_actor(actor);
@ -217,9 +206,7 @@ impl<A, S> ToEnvelope<A> for HttpContext<A, S>
fn pack<M>(msg: M, tx: Option<Sender<Result<M::Item, M::Error>>>, fn pack<M>(msg: M, tx: Option<Sender<Result<M::Item, M::Error>>>,
channel_on_drop: bool) -> Envelope<A> channel_on_drop: bool) -> Envelope<A>
where A: Handler<M>, where A: Handler<M>,
M: ResponseType + Send + 'static, M: ResponseType + Send + 'static, M::Item: Send, M::Error: Send
M::Item: Send,
M::Error: Send
{ {
RemoteEnvelope::new(msg, tx, channel_on_drop).into() RemoteEnvelope::new(msg, tx, channel_on_drop).into()
} }
@ -240,7 +227,7 @@ pub struct Drain<A> {
} }
impl<A> Drain<A> { impl<A> Drain<A> {
fn new(fut: oneshot::Receiver<()>) -> Self { pub fn new(fut: oneshot::Receiver<()>) -> Self {
Drain { Drain {
fut: fut, fut: fut,
_a: PhantomData _a: PhantomData

View File

@ -109,6 +109,7 @@ mod worker;
mod channel; mod channel;
mod wsframe; mod wsframe;
mod wsproto; mod wsproto;
mod wscontext;
mod h1; mod h1;
mod h2; mod h2;
mod h1writer; mod h1writer;

View File

@ -34,7 +34,7 @@
//! .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) //! .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
//! .allowed_header(header::CONTENT_TYPE) //! .allowed_header(header::CONTENT_TYPE)
//! .max_age(3600) //! .max_age(3600)
//! .finish().expect("Can not create CORS middleware")) //! .finish().expect("Can not create CORS middleware"));
//! r.method(Method::GET).f(|_| httpcodes::HTTPOk); //! r.method(Method::GET).f(|_| httpcodes::HTTPOk);
//! r.method(Method::HEAD).f(|_| httpcodes::HTTPMethodNotAllowed); //! r.method(Method::HEAD).f(|_| httpcodes::HTTPMethodNotAllowed);
//! }) //! })
@ -96,10 +96,7 @@ pub enum Error {
impl ResponseError for Error { impl ResponseError for Error {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
match *self { HTTPBadRequest.into()
Error::BadOrigin => HTTPBadRequest.into(),
_ => HTTPBadRequest.into()
}
} }
} }
@ -355,7 +352,7 @@ impl CorsBuilder {
{ {
self.methods = true; self.methods = true;
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
for m in methods.into_iter() { for m in methods {
match Method::try_from(m) { match Method::try_from(m) {
Ok(method) => { Ok(method) => {
cors.methods.insert(method); cors.methods.insert(method);
@ -404,7 +401,7 @@ impl CorsBuilder {
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H> where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H>
{ {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
for h in headers.into_iter() { for h in headers {
match HeaderName::try_from(h) { match HeaderName::try_from(h) {
Ok(method) => { Ok(method) => {
if cors.headers.is_all() { if cors.headers.is_all() {

View File

@ -103,7 +103,7 @@ impl<S: 'static> Route<S> {
} }
} }
/// RouteHandler wrapper. This struct is required because it needs to be shared /// `RouteHandler` wrapper. This struct is required because it needs to be shared
/// for resource level middlewares. /// for resource level middlewares.
struct InnerHandler<S>(Rc<Box<RouteHandler<S>>>); struct InnerHandler<S>(Rc<Box<RouteHandler<S>>>);

View File

@ -8,8 +8,9 @@
//! ```rust //! ```rust
//! # extern crate actix; //! # extern crate actix;
//! # extern crate actix_web; //! # extern crate actix_web;
//! use actix::*; //! # use actix::*;
//! use actix_web::*; //! # use actix_web::*;
//! use actix_web::ws;
//! //!
//! // do websocket handshake and start actor //! // do websocket handshake and start actor
//! fn ws_index(req: HttpRequest) -> Result<HttpResponse> { //! fn ws_index(req: HttpRequest) -> Result<HttpResponse> {
@ -19,18 +20,18 @@
//! struct Ws; //! struct Ws;
//! //!
//! impl Actor for Ws { //! impl Actor for Ws {
//! type Context = HttpContext<Self>; //! type Context = ws::WebsocketContext<Self>;
//! } //! }
//! //!
//! // Define Handler for ws::Message message //! // Define Handler for ws::Message message
//! impl Handler<ws::Message> for Ws { //! impl Handler<ws::Message> for Ws {
//! type Result = (); //! type Result = ();
//! //!
//! fn handle(&mut self, msg: ws::Message, ctx: &mut HttpContext<Self>) { //! fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
//! match msg { //! match msg {
//! ws::Message::Ping(msg) => ws::WsWriter::pong(ctx, &msg), //! ws::Message::Ping(msg) => ctx.pong(&msg),
//! ws::Message::Text(text) => ws::WsWriter::text(ctx, &text), //! ws::Message::Text(text) => ctx.text(&text),
//! ws::Message::Binary(bin) => ws::WsWriter::binary(ctx, bin), //! ws::Message::Binary(bin) => ctx.binary(bin),
//! _ => (), //! _ => (),
//! } //! }
//! } //! }
@ -42,22 +43,22 @@
//! # .finish(); //! # .finish();
//! # } //! # }
//! ``` //! ```
use std::vec::Vec;
use http::{Method, StatusCode, header};
use bytes::BytesMut; use bytes::BytesMut;
use http::{Method, StatusCode, header};
use futures::{Async, Poll, Stream}; use futures::{Async, Poll, Stream};
use actix::{Actor, AsyncContext, ResponseType, Handler}; use actix::{Actor, AsyncContext, ResponseType, Handler};
use body::Binary;
use payload::ReadAny; use payload::ReadAny;
use error::{Error, WsHandshakeError}; use error::{Error, WsHandshakeError};
use context::HttpContext;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder};
use wsframe; use wsframe;
use wsproto::*; use wsproto::*;
pub use wsproto::CloseCode; pub use wsproto::CloseCode;
pub use wscontext::WebsocketContext;
const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";
@ -69,7 +70,7 @@ const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION";
#[derive(Debug)] #[derive(Debug)]
pub enum Message { pub enum Message {
Text(String), Text(String),
Binary(Vec<u8>), Binary(Binary),
Ping(String), Ping(String),
Pong(String), Pong(String),
Close, Close,
@ -84,13 +85,13 @@ impl ResponseType for Message {
/// Do websocket handshake and start actor /// Do websocket handshake and start actor
pub fn start<A, S>(mut req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error> pub fn start<A, S>(mut req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
where A: Actor<Context=HttpContext<A, S>> + Handler<Message>, where A: Actor<Context=WebsocketContext<A, S>> + Handler<Message>,
S: 'static S: 'static
{ {
let mut resp = handshake(&req)?; let mut resp = handshake(&req)?;
let stream = WsStream::new(req.payload_mut().readany()); let stream = WsStream::new(req.payload_mut().readany());
let mut ctx = HttpContext::new(req, actor); let mut ctx = WebsocketContext::new(req, actor);
ctx.add_message_stream(stream); ctx.add_message_stream(stream);
Ok(resp.body(ctx)?) Ok(resp.body(ctx)?)
@ -222,14 +223,17 @@ impl Stream for WsStream {
}, },
OpCode::Ping => OpCode::Ping =>
return Ok(Async::Ready(Some( return Ok(Async::Ready(Some(
Message::Ping(String::from_utf8_lossy(&payload).into())))), Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into())))),
OpCode::Pong => OpCode::Pong =>
return Ok(Async::Ready(Some( return Ok(Async::Ready(Some(
Message::Pong(String::from_utf8_lossy(&payload).into())))), Message::Pong(
String::from_utf8_lossy(payload.as_ref()).into())))),
OpCode::Binary => OpCode::Binary =>
return Ok(Async::Ready(Some(Message::Binary(payload)))), return Ok(Async::Ready(Some(Message::Binary(payload)))),
OpCode::Text => { OpCode::Text => {
match String::from_utf8(payload) { let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) {
Ok(s) => Ok(s) =>
return Ok(Async::Ready(Some(Message::Text(s)))), return Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => Err(_) =>
@ -262,67 +266,6 @@ impl Stream for WsStream {
} }
} }
/// `WebSocket` writer
pub struct WsWriter;
impl WsWriter {
/// Send text frame
pub fn text<A, S>(ctx: &mut HttpContext<A, S>, text: &str)
where A: Actor<Context=HttpContext<A, S>>
{
let mut frame = wsframe::Frame::message(Vec::from(text), OpCode::Text, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
ctx.write(buf);
}
/// Send binary frame
pub fn binary<A, S>(ctx: &mut HttpContext<A, S>, data: Vec<u8>)
where A: Actor<Context=HttpContext<A, S>>
{
let mut frame = wsframe::Frame::message(data, OpCode::Binary, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
ctx.write(buf);
}
/// Send ping frame
pub fn ping<A, S>(ctx: &mut HttpContext<A, S>, message: &str)
where A: Actor<Context=HttpContext<A, S>>
{
let mut frame = wsframe::Frame::message(Vec::from(message), OpCode::Ping, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
ctx.write(buf);
}
/// Send pong frame
pub fn pong<A, S>(ctx: &mut HttpContext<A, S>, message: &str)
where A: Actor<Context=HttpContext<A, S>>
{
let mut frame = wsframe::Frame::message(Vec::from(message), OpCode::Pong, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
ctx.write(buf);
}
/// Send close frame
pub fn close<A, S>(ctx: &mut HttpContext<A, S>, code: CloseCode, reason: &str)
where A: Actor<Context=HttpContext<A, S>>
{
let mut frame = wsframe::Frame::close(code, reason);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
ctx.write(buf);
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

257
src/wscontext.rs Normal file
View File

@ -0,0 +1,257 @@
use std::mem;
use std::collections::VecDeque;
use futures::{Async, Poll};
use futures::sync::oneshot::Sender;
use futures::unsync::oneshot;
use actix::{Actor, ActorState, ActorContext, AsyncContext,
Address, SyncAddress, Handler, Subscriber, ResponseType, SpawnHandle};
use actix::fut::ActorFuture;
use actix::dev::{queue, AsyncContextApi,
ContextImpl, ContextProtocol, Envelope, ToEnvelope, RemoteEnvelope};
use body::{Body, Binary};
use error::{Error, Result, ErrorInternalServerError};
use httprequest::HttpRequest;
use context::{Frame, ActorHttpContext, Drain};
use wsframe;
use wsproto::*;
pub use wsproto::CloseCode;
/// Http actor execution context
pub struct WebsocketContext<A, S=()> where A: Actor<Context=WebsocketContext<A, S>>,
{
inner: ContextImpl<A>,
stream: VecDeque<Frame>,
request: HttpRequest<S>,
disconnected: bool,
}
impl<A, S> ActorContext for WebsocketContext<A, S> where A: Actor<Context=Self>
{
fn stop(&mut self) {
self.inner.stop();
}
fn terminate(&mut self) {
self.inner.terminate()
}
fn state(&self) -> ActorState {
self.inner.state()
}
}
impl<A, S> AsyncContext<A> for WebsocketContext<A, S> where A: Actor<Context=Self>
{
fn spawn<F>(&mut self, fut: F) -> SpawnHandle
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static
{
self.inner.spawn(fut)
}
fn wait<F>(&mut self, fut: F)
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static
{
self.inner.wait(fut)
}
fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
self.inner.cancel_future(handle)
}
}
#[doc(hidden)]
impl<A, S> AsyncContextApi<A> for WebsocketContext<A, S> where A: Actor<Context=Self> {
#[inline]
fn unsync_sender(&mut self) -> queue::unsync::UnboundedSender<ContextProtocol<A>> {
self.inner.unsync_sender()
}
#[inline]
fn unsync_address(&mut self) -> Address<A> {
self.inner.unsync_address()
}
#[inline]
fn sync_address(&mut self) -> SyncAddress<A> {
self.inner.sync_address()
}
}
impl<A, S: 'static> WebsocketContext<A, S> where A: Actor<Context=Self> {
#[inline]
pub fn new(req: HttpRequest<S>, actor: A) -> WebsocketContext<A, S> {
WebsocketContext::from_request(req).actor(actor)
}
pub fn from_request(req: HttpRequest<S>) -> WebsocketContext<A, S> {
WebsocketContext {
inner: ContextImpl::new(None),
stream: VecDeque::new(),
request: req,
disconnected: false,
}
}
#[inline]
pub fn actor(mut self, actor: A) -> WebsocketContext<A, S> {
self.inner.set_actor(actor);
self
}
}
impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
/// Write payload
#[inline]
fn write<B: Into<Binary>>(&mut self, data: B) {
if !self.disconnected {
self.stream.push_back(Frame::Payload(Some(data.into())));
} else {
warn!("Trying to write to disconnected response");
}
}
/// Shared application state
#[inline]
pub fn state(&self) -> &S {
self.request.state()
}
/// Incoming request
#[inline]
pub fn request(&mut self) -> &mut HttpRequest<S> {
&mut self.request
}
/// Send text frame
pub fn text(&mut self, text: &str) {
let mut frame = wsframe::Frame::message(Vec::from(text), OpCode::Text, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
self.write(buf);
}
/// Send binary frame
pub fn binary<B: Into<Binary>>(&mut self, data: B) {
let mut frame = wsframe::Frame::message(data, OpCode::Binary, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
self.write(buf);
}
/// Send ping frame
pub fn ping(&mut self, message: &str) {
let mut frame = wsframe::Frame::message(Vec::from(message), OpCode::Ping, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
self.write(buf);
}
/// Send pong frame
pub fn pong(&mut self, message: &str) {
let mut frame = wsframe::Frame::message(Vec::from(message), OpCode::Pong, true);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
self.write(buf);
}
/// Send close frame
pub fn close(&mut self, code: CloseCode, reason: &str) {
let mut frame = wsframe::Frame::close(code, reason);
let mut buf = Vec::new();
frame.format(&mut buf).unwrap();
self.write(buf);
}
/// Returns drain future
pub fn drain(&mut self) -> Drain<A> {
let (tx, rx) = oneshot::channel();
self.inner.modify();
self.stream.push_back(Frame::Drain(tx));
Drain::new(rx)
}
/// Check if connection still open
#[inline]
pub fn connected(&self) -> bool {
!self.disconnected
}
}
impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
#[inline]
#[doc(hidden)]
pub fn subscriber<M>(&mut self) -> Box<Subscriber<M>>
where A: Handler<M>, M: ResponseType + 'static
{
self.inner.subscriber()
}
#[inline]
#[doc(hidden)]
pub fn sync_subscriber<M>(&mut self) -> Box<Subscriber<M> + Send>
where A: Handler<M>,
M: ResponseType + Send + 'static, M::Item: Send, M::Error: Send,
{
self.inner.sync_subscriber()
}
}
impl<A, S> ActorHttpContext for WebsocketContext<A, S> where A: Actor<Context=Self>, S: 'static {
#[inline]
fn disconnected(&mut self) {
self.disconnected = true;
self.stop();
}
fn poll(&mut self) -> Poll<Option<Frame>, Error> {
let ctx: &mut WebsocketContext<A, S> = unsafe {
mem::transmute(self as &mut WebsocketContext<A, S>)
};
if self.inner.alive() {
match self.inner.poll(ctx) {
Ok(Async::NotReady) | Ok(Async::Ready(())) => (),
Err(_) => return Err(ErrorInternalServerError("error").into()),
}
}
// frames
if let Some(frame) = self.stream.pop_front() {
Ok(Async::Ready(Some(frame)))
} else if self.inner.alive() {
Ok(Async::NotReady)
} else {
Ok(Async::Ready(None))
}
}
}
impl<A, S> ToEnvelope<A> for WebsocketContext<A, S>
where A: Actor<Context=WebsocketContext<A, S>>,
{
#[inline]
fn pack<M>(msg: M, tx: Option<Sender<Result<M::Item, M::Error>>>,
channel_on_drop: bool) -> Envelope<A>
where A: Handler<M>,
M: ResponseType + Send + 'static, M::Item: Send, M::Error: Send {
RemoteEnvelope::new(msg, tx, channel_on_drop).into()
}
}
impl<A, S> From<WebsocketContext<A, S>> for Body
where A: Actor<Context=WebsocketContext<A, S>>, S: 'static
{
fn from(ctx: WebsocketContext<A, S>) -> Body {
Body::Actor(Box::new(ctx))
}
}

View File

@ -3,6 +3,7 @@ use std::io::{Write, Error, ErrorKind};
use std::iter::FromIterator; use std::iter::FromIterator;
use bytes::BytesMut; use bytes::BytesMut;
use body::Binary;
use wsproto::{OpCode, CloseCode}; use wsproto::{OpCode, CloseCode};
@ -14,7 +15,7 @@ fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
} }
/// A struct representing a `WebSocket` frame. /// A struct representing a `WebSocket` frame.
#[derive(Debug, Clone)] #[derive(Debug)]
pub(crate) struct Frame { pub(crate) struct Frame {
finished: bool, finished: bool,
rsv1: bool, rsv1: bool,
@ -22,13 +23,13 @@ pub(crate) struct Frame {
rsv3: bool, rsv3: bool,
opcode: OpCode, opcode: OpCode,
mask: Option<[u8; 4]>, mask: Option<[u8; 4]>,
payload: Vec<u8>, payload: Binary,
} }
impl Frame { impl Frame {
/// Desctructe frame /// Desctructe frame
pub fn unpack(self) -> (bool, OpCode, Vec<u8>) { pub fn unpack(self) -> (bool, OpCode, Binary) {
(self.finished, self.opcode, self.payload) (self.finished, self.opcode, self.payload)
} }
@ -55,11 +56,11 @@ impl Frame {
/// Create a new data frame. /// Create a new data frame.
#[inline] #[inline]
pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame { pub fn message<B: Into<Binary>>(data: B, code: OpCode, finished: bool) -> Frame {
Frame { Frame {
finished: finished, finished: finished,
opcode: code, opcode: code,
payload: data, payload: data.into(),
.. Frame::default() .. Frame::default()
} }
} }
@ -82,7 +83,7 @@ impl Frame {
}; };
Frame { Frame {
payload: payload, payload: payload.into(),
.. Frame::default() .. Frame::default()
} }
} }
@ -212,7 +213,7 @@ impl Frame {
rsv3: rsv3, rsv3: rsv3,
opcode: opcode, opcode: opcode,
mask: mask, mask: mask,
payload: data, payload: data.into(),
}; };
(frame, header_length + length) (frame, header_length + length)
@ -251,7 +252,7 @@ impl Frame {
if self.payload.len() < 126 { if self.payload.len() < 126 {
two |= self.payload.len() as u8; two |= self.payload.len() as u8;
let headers = [one, two]; let headers = [one, two];
try!(w.write_all(&headers)); w.write_all(&headers)?;
} else if self.payload.len() <= 65_535 { } else if self.payload.len() <= 65_535 {
two |= 126; two |= 126;
let length_bytes: [u8; 2] = unsafe { let length_bytes: [u8; 2] = unsafe {
@ -259,7 +260,7 @@ impl Frame {
mem::transmute(short.to_be()) mem::transmute(short.to_be())
}; };
let headers = [one, two, length_bytes[0], length_bytes[1]]; let headers = [one, two, length_bytes[0], length_bytes[1]];
try!(w.write_all(&headers)); w.write_all(&headers)?;
} else { } else {
two |= 127; two |= 127;
let length_bytes: [u8; 8] = unsafe { let length_bytes: [u8; 8] = unsafe {
@ -278,16 +279,18 @@ impl Frame {
length_bytes[6], length_bytes[6],
length_bytes[7], length_bytes[7],
]; ];
try!(w.write_all(&headers)); w.write_all(&headers)?;
} }
if self.mask.is_some() { if self.mask.is_some() {
let mask = self.mask.take().unwrap(); let mask = self.mask.take().unwrap();
apply_mask(&mut self.payload, &mask); let mut payload = Vec::from(self.payload.as_ref());
try!(w.write_all(&mask)); apply_mask(&mut payload, &mask);
w.write_all(&mask)?;
w.write_all(payload.as_ref())?;
} else {
w.write_all(self.payload.as_ref())?;
} }
try!(w.write_all(&self.payload));
Ok(()) Ok(())
} }
} }
@ -301,7 +304,7 @@ impl Default for Frame {
rsv3: false, rsv3: false,
opcode: OpCode::Close, opcode: OpCode::Close,
mask: None, mask: None,
payload: Vec::new(), payload: Binary::from(&b""[..]),
} }
} }
} }
@ -326,7 +329,8 @@ impl fmt::Display for Frame {
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(), self.len(),
self.payload.len(), self.payload.len(),
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>()) self.payload.as_ref().iter().map(
|byte| format!("{:x}", byte)).collect::<String>())
} }
} }
@ -343,7 +347,7 @@ mod tests {
println!("FRAME: {}", frame); println!("FRAME: {}", frame);
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, &b"1"[..]); assert_eq!(frame.payload.as_ref(), &b"1"[..]);
} }
#[test] #[test]
@ -365,7 +369,7 @@ mod tests {
let frame = Frame::parse(&mut buf).unwrap().unwrap(); let frame = Frame::parse(&mut buf).unwrap().unwrap();
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, &b"1234"[..]); assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
} }
#[test] #[test]
@ -378,7 +382,7 @@ mod tests {
let frame = Frame::parse(&mut buf).unwrap().unwrap(); let frame = Frame::parse(&mut buf).unwrap().unwrap();
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, &b"1234"[..]); assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
} }
#[test] #[test]
@ -390,7 +394,7 @@ mod tests {
let frame = Frame::parse(&mut buf).unwrap().unwrap(); let frame = Frame::parse(&mut buf).unwrap().unwrap();
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, vec![1u8]); assert_eq!(frame.payload, vec![1u8].into());
} }
#[test] #[test]