1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 16:02:59 +01:00

refactor ws to a websocket codec

This commit is contained in:
Nikolay Kim 2018-10-05 12:47:22 -07:00
parent 8c2244dd88
commit 5c0a2066cc
4 changed files with 238 additions and 398 deletions

View File

@ -17,8 +17,6 @@ pub enum Body {
/// Unspecified streaming response. Developer is responsible for setting /// Unspecified streaming response. Developer is responsible for setting
/// right `Content-Length` or `Transfer-Encoding` headers. /// right `Content-Length` or `Transfer-Encoding` headers.
Streaming(BodyStream), Streaming(BodyStream),
// /// Special body type for actor response.
// Actor(Box<ActorHttpContext>),
} }
/// Represents various types of binary body. /// Represents various types of binary body.

119
src/ws/codec.rs Normal file
View File

@ -0,0 +1,119 @@
use bytes::BytesMut;
use tokio_codec::{Decoder, Encoder};
use super::frame::Frame;
use super::proto::{CloseReason, OpCode};
use super::ProtocolError;
use body::Binary;
/// `WebSocket` Message
#[derive(Debug, PartialEq)]
pub enum Message {
/// Text message
Text(String),
/// Binary message
Binary(Binary),
/// Ping message
Ping(String),
/// Pong message
Pong(String),
/// Close message with optional reason
Close(Option<CloseReason>),
}
/// WebSockets protocol codec
pub struct Codec {
max_size: usize,
server: bool,
}
impl Codec {
/// Create new websocket frames decoder
pub fn new() -> Codec {
Codec {
max_size: 65_536,
server: true,
}
}
/// Set max frame size
///
/// By default max size is set to 64kb
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
/// Set decoder to client mode.
///
/// By default decoder works in server mode.
pub fn client_mode(mut self) -> Self {
self.server = false;
self
}
}
impl Encoder for Codec {
type Item = Message;
type Error = ProtocolError;
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Text(txt) => {
Frame::write_message(dst, txt, OpCode::Text, true, !self.server)
}
Message::Binary(bin) => {
Frame::write_message(dst, bin, OpCode::Binary, true, !self.server)
}
Message::Ping(txt) => {
Frame::write_message(dst, txt, OpCode::Ping, true, !self.server)
}
Message::Pong(txt) => {
Frame::write_message(dst, txt, OpCode::Pong, true, !self.server)
}
Message::Close(reason) => Frame::write_close(dst, reason, !self.server),
}
Ok(())
}
}
impl Decoder for Codec {
type Item = Message;
type Error = ProtocolError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match Frame::parse(src, self.server, self.max_size) {
Ok(Some((finished, opcode, payload))) => {
// continuation is not supported
if !finished {
return Err(ProtocolError::NoContinuation);
}
match opcode {
OpCode::Continue => Err(ProtocolError::NoContinuation),
OpCode::Bad => Err(ProtocolError::BadOpCode),
OpCode::Close => {
let close_reason = Frame::parse_close_payload(&payload);
Ok(Some(Message::Close(close_reason)))
}
OpCode::Ping => Ok(Some(Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into(),
))),
OpCode::Pong => Ok(Some(Message::Pong(
String::from_utf8_lossy(payload.as_ref()).into(),
))),
OpCode::Binary => Ok(Some(Message::Binary(payload))),
OpCode::Text => {
let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) {
Ok(s) => Ok(Some(Message::Text(s))),
Err(_) => Err(ProtocolError::BadEncoding),
}
}
}
}
Ok(None) => Ok(None),
Err(e) => Err(e),
}
}
}

View File

@ -1,144 +1,29 @@
use byteorder::{ByteOrder, LittleEndian, NetworkEndian}; use byteorder::{ByteOrder, LittleEndian, NetworkEndian};
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, BytesMut};
use futures::{Async, Poll, Stream};
use rand; use rand;
use std::fmt;
use body::Binary; use body::Binary;
use error::PayloadError;
use payload::PayloadBuffer;
use ws::mask::apply_mask; use ws::mask::apply_mask;
use ws::proto::{CloseCode, CloseReason, OpCode}; use ws::proto::{CloseCode, CloseReason, OpCode};
use ws::ProtocolError; use ws::ProtocolError;
/// A struct representing a `WebSocket` frame. /// A struct representing a `WebSocket` frame.
#[derive(Debug)] #[derive(Debug)]
pub struct Frame { pub struct Frame;
finished: bool,
opcode: OpCode,
payload: Binary,
}
impl Frame { impl Frame {
/// Destruct frame fn parse_metadata(
pub fn unpack(self) -> (bool, OpCode, Binary) { src: &[u8], server: bool, max_size: usize,
(self.finished, self.opcode, self.payload) ) -> Result<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> {
} let chunk_len = src.len();
/// Create a new Close control frame.
#[inline]
pub fn close(reason: Option<CloseReason>, genmask: bool) -> FramedMessage {
let payload = match reason {
None => Vec::new(),
Some(reason) => {
let mut code_bytes = [0; 2];
NetworkEndian::write_u16(&mut code_bytes, reason.code.into());
let mut payload = Vec::from(&code_bytes[..]);
if let Some(description) = reason.description {
payload.extend(description.as_bytes());
}
payload
}
};
Frame::message(payload, OpCode::Close, true, genmask)
}
#[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
fn read_copy_md<S>(
pl: &mut PayloadBuffer<S>, server: bool, max_size: usize,
) -> Poll<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
let mut idx = 2;
let buf = match pl.copy(2)? {
Async::Ready(Some(buf)) => buf,
Async::Ready(None) => return Ok(Async::Ready(None)),
Async::NotReady => return Ok(Async::NotReady),
};
let first = buf[0];
let second = buf[1];
let finished = first & 0x80 != 0;
// check masking
let masked = second & 0x80 != 0;
if !masked && server {
return Err(ProtocolError::UnmaskedFrame);
} else if masked && !server {
return Err(ProtocolError::MaskedFrame);
}
// Op code
let opcode = OpCode::from(first & 0x0F);
if let OpCode::Bad = opcode {
return Err(ProtocolError::InvalidOpcode(first & 0x0F));
}
let len = second & 0x7F;
let length = if len == 126 {
let buf = match pl.copy(4)? {
Async::Ready(Some(buf)) => buf,
Async::Ready(None) => return Ok(Async::Ready(None)),
Async::NotReady => return Ok(Async::NotReady),
};
let len = NetworkEndian::read_uint(&buf[idx..], 2) as usize;
idx += 2;
len
} else if len == 127 {
let buf = match pl.copy(10)? {
Async::Ready(Some(buf)) => buf,
Async::Ready(None) => return Ok(Async::Ready(None)),
Async::NotReady => return Ok(Async::NotReady),
};
let len = NetworkEndian::read_uint(&buf[idx..], 8);
if len > max_size as u64 {
return Err(ProtocolError::Overflow);
}
idx += 8;
len as usize
} else {
len as usize
};
// check for max allowed size
if length > max_size {
return Err(ProtocolError::Overflow);
}
let mask = if server {
let buf = match pl.copy(idx + 4)? {
Async::Ready(Some(buf)) => buf,
Async::Ready(None) => return Ok(Async::Ready(None)),
Async::NotReady => return Ok(Async::NotReady),
};
let mask: &[u8] = &buf[idx..idx + 4];
let mask_u32 = LittleEndian::read_u32(mask);
idx += 4;
Some(mask_u32)
} else {
None
};
Ok(Async::Ready(Some((idx, finished, opcode, length, mask))))
}
fn read_chunk_md(
chunk: &[u8], server: bool, max_size: usize,
) -> Poll<(usize, bool, OpCode, usize, Option<u32>), ProtocolError> {
let chunk_len = chunk.len();
let mut idx = 2; let mut idx = 2;
if chunk_len < 2 { if chunk_len < 2 {
return Ok(Async::NotReady); return Ok(None);
} }
let first = chunk[0]; let first = src[0];
let second = chunk[1]; let second = src[1];
let finished = first & 0x80 != 0; let finished = first & 0x80 != 0;
// check masking // check masking
@ -159,16 +44,16 @@ impl Frame {
let len = second & 0x7F; let len = second & 0x7F;
let length = if len == 126 { let length = if len == 126 {
if chunk_len < 4 { if chunk_len < 4 {
return Ok(Async::NotReady); return Ok(None);
} }
let len = NetworkEndian::read_uint(&chunk[idx..], 2) as usize; let len = NetworkEndian::read_uint(&src[idx..], 2) as usize;
idx += 2; idx += 2;
len len
} else if len == 127 { } else if len == 127 {
if chunk_len < 10 { if chunk_len < 10 {
return Ok(Async::NotReady); return Ok(None);
} }
let len = NetworkEndian::read_uint(&chunk[idx..], 8); let len = NetworkEndian::read_uint(&src[idx..], 8);
if len > max_size as u64 { if len > max_size as u64 {
return Err(ProtocolError::Overflow); return Err(ProtocolError::Overflow);
} }
@ -185,10 +70,10 @@ impl Frame {
let mask = if server { let mask = if server {
if chunk_len < idx + 4 { if chunk_len < idx + 4 {
return Ok(Async::NotReady); return Ok(None);
} }
let mask: &[u8] = &chunk[idx..idx + 4]; let mask: &[u8] = &src[idx..idx + 4];
let mask_u32 = LittleEndian::read_u32(mask); let mask_u32 = LittleEndian::read_u32(mask);
idx += 4; idx += 4;
Some(mask_u32) Some(mask_u32)
@ -196,56 +81,34 @@ impl Frame {
None None
}; };
Ok(Async::Ready((idx, finished, opcode, length, mask))) Ok(Some((idx, finished, opcode, length, mask)))
} }
/// Parse the input stream into a frame. /// Parse the input stream into a frame.
pub fn parse<S>( pub fn parse(
pl: &mut PayloadBuffer<S>, server: bool, max_size: usize, src: &mut BytesMut, server: bool, max_size: usize,
) -> Poll<Option<Frame>, ProtocolError> ) -> Result<Option<(bool, OpCode, Binary)>, ProtocolError> {
where // try to parse ws frame metadata
S: Stream<Item = Bytes, Error = PayloadError>, let (idx, finished, opcode, length, mask) =
{ match Frame::parse_metadata(src, server, max_size)? {
// try to parse ws frame md from one chunk None => return Ok(None),
let result = match pl.get_chunk()? { Some(res) => res,
Async::NotReady => return Ok(Async::NotReady), };
Async::Ready(None) => return Ok(Async::Ready(None)),
Async::Ready(Some(chunk)) => Frame::read_chunk_md(chunk, server, max_size)?,
};
let (idx, finished, opcode, length, mask) = match result { // not enough data
// we may need to join several chunks if src.len() < idx + length {
Async::NotReady => match Frame::read_copy_md(pl, server, max_size)? { return Ok(None);
Async::Ready(Some(item)) => item,
Async::NotReady => return Ok(Async::NotReady),
Async::Ready(None) => return Ok(Async::Ready(None)),
},
Async::Ready(item) => item,
};
match pl.can_read(idx + length)? {
Async::Ready(Some(true)) => (),
Async::Ready(None) => return Ok(Async::Ready(None)),
Async::Ready(Some(false)) | Async::NotReady => return Ok(Async::NotReady),
} }
// remove prefix // remove prefix
pl.drop_bytes(idx); src.split_to(idx);
// no need for body // no need for body
if length == 0 { if length == 0 {
return Ok(Async::Ready(Some(Frame { return Ok(Some((finished, opcode, Binary::from(""))));
finished,
opcode,
payload: Binary::from(""),
})));
} }
let data = match pl.read_exact(length)? { let mut data = src.split_to(length);
Async::Ready(Some(buf)) => buf,
Async::Ready(None) => return Ok(Async::Ready(None)),
Async::NotReady => panic!(),
};
// control frames must have length <= 125 // control frames must have length <= 125
match opcode { match opcode {
@ -254,26 +117,17 @@ impl Frame {
} }
OpCode::Close if length > 125 => { OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
return Ok(Async::Ready(Some(Frame::default()))); return Ok(Some((true, OpCode::Close, Binary::from(""))));
} }
_ => (), _ => (),
} }
// unmask // unmask
let data = if let Some(mask) = mask { if let Some(mask) = mask {
let mut buf = BytesMut::new(); apply_mask(&mut data, mask);
buf.extend_from_slice(&data); }
apply_mask(&mut buf, mask);
buf.freeze()
} else {
data
};
Ok(Async::Ready(Some(Frame { Ok(Some((finished, opcode, data.into())))
finished,
opcode,
payload: data.into(),
})))
} }
/// Parse the payload of a close frame. /// Parse the payload of a close frame.
@ -293,120 +147,101 @@ impl Frame {
} }
/// Generate binary representation /// Generate binary representation
pub fn message<B: Into<Binary>>( pub fn write_message<B: Into<Binary>>(
data: B, code: OpCode, finished: bool, genmask: bool, dst: &mut BytesMut, pl: B, op: OpCode, fin: bool, mask: bool,
) -> FramedMessage { ) {
let payload = data.into(); let payload = pl.into();
let one: u8 = if finished { let one: u8 = if fin {
0x80 | Into::<u8>::into(code) 0x80 | Into::<u8>::into(op)
} else { } else {
code.into() op.into()
}; };
let payload_len = payload.len(); let payload_len = payload.len();
let (two, p_len) = if genmask { let (two, p_len) = if mask {
(0x80, payload_len + 4) (0x80, payload_len + 4)
} else { } else {
(0, payload_len) (0, payload_len)
}; };
let mut buf = if payload_len < 126 { if payload_len < 126 {
let mut buf = BytesMut::with_capacity(p_len + 2); dst.put_slice(&[one, two | payload_len as u8]);
buf.put_slice(&[one, two | payload_len as u8]);
buf
} else if payload_len <= 65_535 { } else if payload_len <= 65_535 {
let mut buf = BytesMut::with_capacity(p_len + 4); dst.reserve(p_len + 4);
buf.put_slice(&[one, two | 126]); dst.put_slice(&[one, two | 126]);
buf.put_u16_be(payload_len as u16); dst.put_u16_be(payload_len as u16);
buf
} else { } else {
let mut buf = BytesMut::with_capacity(p_len + 10); dst.reserve(p_len + 10);
buf.put_slice(&[one, two | 127]); dst.put_slice(&[one, two | 127]);
buf.put_u64_be(payload_len as u64); dst.put_u64_be(payload_len as u64);
buf
}; };
let binary = if genmask { if mask {
let mask = rand::random::<u32>(); let mask = rand::random::<u32>();
buf.put_u32_le(mask); dst.put_u32_le(mask);
buf.extend_from_slice(payload.as_ref()); dst.extend_from_slice(payload.as_ref());
let pos = buf.len() - payload_len; let pos = dst.len() - payload_len;
apply_mask(&mut buf[pos..], mask); apply_mask(&mut dst[pos..], mask);
buf.into()
} else { } else {
buf.put_slice(payload.as_ref()); dst.put_slice(payload.as_ref());
buf.into()
};
FramedMessage(binary)
}
}
impl Default for Frame {
fn default() -> Frame {
Frame {
finished: true,
opcode: OpCode::Close,
payload: Binary::from(&b""[..]),
} }
} }
}
impl fmt::Display for Frame { /// Create a new Close control frame.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { #[inline]
write!( pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
f, let payload = match reason {
" None => Vec::new(),
<FRAME> Some(reason) => {
final: {} let mut code_bytes = [0; 2];
opcode: {} NetworkEndian::write_u16(&mut code_bytes, reason.code.into());
payload length: {}
payload: 0x{} let mut payload = Vec::from(&code_bytes[..]);
</FRAME>", if let Some(description) = reason.description {
self.finished, payload.extend(description.as_bytes());
self.opcode, }
self.payload.len(), payload
self.payload }
.as_ref() };
.iter()
.map(|byte| format!("{:x}", byte)) Frame::write_message(dst, payload, OpCode::Close, true, mask)
.collect::<String>()
)
} }
} }
/// `WebSocket` message with framing.
#[derive(Debug)]
pub struct FramedMessage(pub(crate) Binary);
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures::stream::once;
fn is_none(frm: &Poll<Option<Frame>, ProtocolError>) -> bool { struct F {
finished: bool,
opcode: OpCode,
payload: Binary,
}
fn is_none(frm: &Result<Option<(bool, OpCode, Binary)>, ProtocolError>) -> bool {
match *frm { match *frm {
Ok(Async::Ready(None)) => true, Ok(None) => true,
_ => false, _ => false,
} }
} }
fn extract(frm: Poll<Option<Frame>, ProtocolError>) -> Frame { fn extract(frm: Result<Option<(bool, OpCode, Binary)>, ProtocolError>) -> F {
match frm { match frm {
Ok(Async::Ready(Some(frame))) => frame, Ok(Some((finished, opcode, payload))) => F {
finished,
opcode,
payload,
},
_ => unreachable!("error"), _ => unreachable!("error"),
} }
} }
#[test] #[test]
fn test_parse() { fn test_parse() {
let mut buf = PayloadBuffer::new(once(Ok(BytesMut::from( let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
&[0b0000_0001u8, 0b0000_0001u8][..],
).freeze())));
assert!(is_none(&Frame::parse(&mut buf, false, 1024))); assert!(is_none(&Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
buf.extend(b"1"); buf.extend(b"1");
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
@ -416,9 +251,7 @@ mod tests {
#[test] #[test]
fn test_parse_length0() { fn test_parse_length0() {
let buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]);
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
@ -427,14 +260,12 @@ mod tests {
#[test] #[test]
fn test_parse_length2() { fn test_parse_length2() {
let buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
assert!(is_none(&Frame::parse(&mut buf, false, 1024))); assert!(is_none(&Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
buf.extend(&[0u8, 4u8][..]); buf.extend(&[0u8, 4u8][..]);
buf.extend(b"1234"); buf.extend(b"1234");
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
@ -444,14 +275,12 @@ mod tests {
#[test] #[test]
fn test_parse_length4() { fn test_parse_length4() {
let buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
assert!(is_none(&Frame::parse(&mut buf, false, 1024))); assert!(is_none(&Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]);
buf.extend(b"1234"); buf.extend(b"1234");
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
@ -464,7 +293,6 @@ mod tests {
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]);
buf.extend(b"0001"); buf.extend(b"0001");
buf.extend(b"1"); buf.extend(b"1");
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
assert!(Frame::parse(&mut buf, false, 1024).is_err()); assert!(Frame::parse(&mut buf, false, 1024).is_err());
@ -478,7 +306,6 @@ mod tests {
fn test_parse_frame_no_mask() { fn test_parse_frame_no_mask() {
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
buf.extend(&[1u8]); buf.extend(&[1u8]);
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
assert!(Frame::parse(&mut buf, true, 1024).is_err()); assert!(Frame::parse(&mut buf, true, 1024).is_err());
@ -492,7 +319,6 @@ mod tests {
fn test_parse_frame_max_size() { fn test_parse_frame_max_size() {
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]);
buf.extend(&[1u8, 1u8]); buf.extend(&[1u8, 1u8]);
let mut buf = PayloadBuffer::new(once(Ok(buf.freeze())));
assert!(Frame::parse(&mut buf, true, 1).is_err()); assert!(Frame::parse(&mut buf, true, 1).is_err());
@ -504,35 +330,39 @@ mod tests {
#[test] #[test]
fn test_ping_frame() { fn test_ping_frame() {
let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false); let mut buf = BytesMut::new();
Frame::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);
let mut v = vec![137u8, 4u8]; let mut v = vec![137u8, 4u8];
v.extend(b"data"); v.extend(b"data");
assert_eq!(frame.0, v.into()); assert_eq!(&buf[..], &v[..]);
} }
#[test] #[test]
fn test_pong_frame() { fn test_pong_frame() {
let frame = Frame::message(Vec::from("data"), OpCode::Pong, true, false); let mut buf = BytesMut::new();
Frame::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);
let mut v = vec![138u8, 4u8]; let mut v = vec![138u8, 4u8];
v.extend(b"data"); v.extend(b"data");
assert_eq!(frame.0, v.into()); assert_eq!(&buf[..], &v[..]);
} }
#[test] #[test]
fn test_close_frame() { fn test_close_frame() {
let mut buf = BytesMut::new();
let reason = (CloseCode::Normal, "data"); let reason = (CloseCode::Normal, "data");
let frame = Frame::close(Some(reason.into()), false); Frame::write_close(&mut buf, Some(reason.into()), false);
let mut v = vec![136u8, 6u8, 3u8, 232u8]; let mut v = vec![136u8, 6u8, 3u8, 232u8];
v.extend(b"data"); v.extend(b"data");
assert_eq!(frame.0, v.into()); assert_eq!(&buf[..], &v[..]);
} }
#[test] #[test]
fn test_empty_close_frame() { fn test_empty_close_frame() {
let frame = Frame::close(None, false); let mut buf = BytesMut::new();
assert_eq!(frame.0, vec![0x88, 0x00].into()); Frame::write_close(&mut buf, None, false);
assert_eq!(&buf[..], &vec![0x88, 0x00][..]);
} }
} }

View File

@ -1,24 +1,23 @@
//! `WebSocket` support. //! WebSocket protocol support.
//! //!
//! To setup a `WebSocket`, first do web socket handshake then on success //! To setup a `WebSocket`, first do web socket handshake then on success
//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to //! convert `Payload` into a `WsStream` stream and then use `WsWriter` to
//! communicate with the peer. //! communicate with the peer.
//! ``` //! ```
use bytes::Bytes; use std::io;
use futures::{Async, Poll, Stream};
use http::{header, Method, StatusCode};
use body::Binary; use error::ResponseError;
use error::{PayloadError, ResponseError}; use http::{header, Method, StatusCode};
use payload::PayloadBuffer;
use request::Request; use request::Request;
use response::{ConnectionType, Response, ResponseBuilder}; use response::{ConnectionType, Response, ResponseBuilder};
mod codec;
mod frame; mod frame;
mod mask; mod mask;
mod proto; mod proto;
pub use self::frame::{Frame, FramedMessage}; pub use self::codec::Message;
pub use self::frame::Frame;
pub use self::proto::{CloseCode, CloseReason, OpCode}; pub use self::proto::{CloseCode, CloseReason, OpCode};
/// Websocket protocol errors /// Websocket protocol errors
@ -48,16 +47,16 @@ pub enum ProtocolError {
/// Bad utf-8 encoding /// Bad utf-8 encoding
#[fail(display = "Bad utf-8 encoding.")] #[fail(display = "Bad utf-8 encoding.")]
BadEncoding, BadEncoding,
/// Payload error /// Io error
#[fail(display = "Payload error: {}", _0)] #[fail(display = "io error: {}", _0)]
Payload(#[cause] PayloadError), Io(#[cause] io::Error),
} }
impl ResponseError for ProtocolError {} impl ResponseError for ProtocolError {}
impl From<PayloadError> for ProtocolError { impl From<io::Error> for ProtocolError {
fn from(err: PayloadError) -> ProtocolError { fn from(err: io::Error) -> ProtocolError {
ProtocolError::Payload(err) ProtocolError::Io(err)
} }
} }
@ -109,21 +108,6 @@ impl ResponseError for HandshakeError {
} }
} }
/// `WebSocket` Message
#[derive(Debug, PartialEq)]
pub enum Message {
/// Text message
Text(String),
/// Binary message
Binary(Binary),
/// Ping message
Ping(String),
/// Pong message
Pong(String),
/// Close message with optional reason
Close(Option<CloseReason>),
}
/// Prepare `WebSocket` handshake response. /// Prepare `WebSocket` handshake response.
/// ///
/// This function returns handshake `Response`, ready to send to peer. /// This function returns handshake `Response`, ready to send to peer.
@ -189,97 +173,6 @@ pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
.take()) .take())
} }
/// Maps `Payload` stream into stream of `ws::Message` items
pub struct WsStream<S> {
rx: PayloadBuffer<S>,
closed: bool,
max_size: usize,
}
impl<S> WsStream<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create new websocket frames stream
pub fn new(stream: S) -> WsStream<S> {
WsStream {
rx: PayloadBuffer::new(stream),
closed: false,
max_size: 65_536,
}
}
/// Set max frame size
///
/// By default max size is set to 64kb
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
}
impl<S> Stream for WsStream<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = Message;
type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.closed {
return Ok(Async::Ready(None));
}
match Frame::parse(&mut self.rx, true, self.max_size) {
Ok(Async::Ready(Some(frame))) => {
let (finished, opcode, payload) = frame.unpack();
// continuation is not supported
if !finished {
self.closed = true;
return Err(ProtocolError::NoContinuation);
}
match opcode {
OpCode::Continue => Err(ProtocolError::NoContinuation),
OpCode::Bad => {
self.closed = true;
Err(ProtocolError::BadOpCode)
}
OpCode::Close => {
self.closed = true;
let close_reason = Frame::parse_close_payload(&payload);
Ok(Async::Ready(Some(Message::Close(close_reason))))
}
OpCode::Ping => Ok(Async::Ready(Some(Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into(),
)))),
OpCode::Pong => Ok(Async::Ready(Some(Message::Pong(
String::from_utf8_lossy(payload.as_ref()).into(),
)))),
OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))),
OpCode::Text => {
let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) {
Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => {
self.closed = true;
Err(ProtocolError::BadEncoding)
}
}
}
}
}
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
self.closed = true;
Err(e)
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;