mirror of
https://github.com/fafhrd91/actix-web
synced 2025-03-26 08:53:16 +01:00
554 lines
17 KiB
Rust
554 lines
17 KiB
Rust
use bitflags::bitflags;
|
|
use bytes::{Bytes, BytesMut};
|
|
use bytestring::ByteString;
|
|
use tokio_util::codec;
|
|
use tracing::error;
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
use super::deflate::{
|
|
DeflateCompressionContext, DeflateContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG,
|
|
};
|
|
use super::{
|
|
frame::Parser,
|
|
proto::{CloseReason, OpCode, RsvBits},
|
|
ProtocolError,
|
|
};
|
|
|
|
/// A WebSocket message.
|
|
#[derive(Debug, PartialEq, Eq)]
|
|
pub enum Message {
|
|
/// Text message.
|
|
Text(ByteString),
|
|
|
|
/// Binary message.
|
|
Binary(Bytes),
|
|
|
|
/// Continuation.
|
|
Continuation(Item),
|
|
|
|
/// Ping message.
|
|
Ping(Bytes),
|
|
|
|
/// Pong message.
|
|
Pong(Bytes),
|
|
|
|
/// Close message with optional reason.
|
|
Close(Option<CloseReason>),
|
|
|
|
/// No-op. Useful for low-level services.
|
|
Nop,
|
|
}
|
|
|
|
/// A WebSocket frame.
|
|
#[derive(Debug, PartialEq, Eq)]
|
|
pub enum Frame {
|
|
/// Text frame. Note that the codec does not validate UTF-8 encoding.
|
|
Text(Bytes),
|
|
|
|
/// Binary frame.
|
|
Binary(Bytes),
|
|
|
|
/// Continuation.
|
|
Continuation(Item),
|
|
|
|
/// Ping message.
|
|
Ping(Bytes),
|
|
|
|
/// Pong message.
|
|
Pong(Bytes),
|
|
|
|
/// Close message with optional reason.
|
|
Close(Option<CloseReason>),
|
|
}
|
|
|
|
/// A WebSocket continuation item.
|
|
#[derive(Debug, PartialEq, Eq)]
|
|
pub enum Item {
|
|
FirstText(Bytes),
|
|
FirstBinary(Bytes),
|
|
Continue(Bytes),
|
|
Last(Bytes),
|
|
}
|
|
|
|
bitflags! {
    /// Mode and fragmentation state shared by [`Encoder`] and [`Decoder`].
    #[derive(Debug, Clone, Copy)]
    struct Flags: u8 {
        /// Set while operating in server mode; cleared by `set_client_mode`.
        const SERVER = 1 << 0;
        /// Decoder: a fragmented message is currently being received.
        const CONTINUATION = 1 << 1;
        /// Encoder: a fragmented message is currently being written.
        const W_CONTINUATION = 1 << 2;
    }
}
|
|
|
|
/// WebSocket message encoder.
|
|
#[derive(Debug, Clone)]
|
|
pub struct Encoder {
|
|
flags: Flags,
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
deflate_compress: Option<DeflateCompressionContext>,
|
|
}
|
|
|
|
impl Encoder {
|
|
/// Create new WebSocket frames encoder.
|
|
pub const fn new() -> Encoder {
|
|
Encoder {
|
|
flags: Flags::SERVER,
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
deflate_compress: None,
|
|
}
|
|
}
|
|
|
|
/// Create new WebSocket frames encoder with `permessage-deflate` extension support.
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
pub fn new_deflate(compress: DeflateCompressionContext) -> Encoder {
|
|
Encoder {
|
|
flags: Flags::SERVER,
|
|
|
|
deflate_compress: Some(compress),
|
|
}
|
|
}
|
|
|
|
fn set_client_mode(mut self) -> Self {
|
|
self.flags = Flags::empty();
|
|
self
|
|
}
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
fn set_client_mode_deflate(
|
|
mut self,
|
|
remote_no_context_takeover: bool,
|
|
remote_max_window_bits: u8,
|
|
) -> Self {
|
|
self.deflate_compress = self
|
|
.deflate_compress
|
|
.map(|c| c.reset_with(remote_no_context_takeover, remote_max_window_bits));
|
|
self
|
|
}
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
fn process_payload(
|
|
&mut self,
|
|
fin: bool,
|
|
bytes: Bytes,
|
|
) -> Result<(Bytes, RsvBits), ProtocolError> {
|
|
if let Some(compress) = &mut self.deflate_compress {
|
|
Ok((compress.compress(fin, bytes)?, RSV_BIT_DEFLATE_FLAG))
|
|
} else {
|
|
Ok((bytes, RsvBits::empty()))
|
|
}
|
|
}
|
|
|
|
#[cfg(not(feature = "compress-ws-deflate"))]
|
|
fn process_payload(
|
|
&mut self,
|
|
_fin: bool,
|
|
bytes: Bytes,
|
|
) -> Result<(Bytes, RsvBits), ProtocolError> {
|
|
Ok((bytes, RsvBits::empty()))
|
|
}
|
|
}
|
|
|
|
impl Default for Encoder {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl codec::Encoder<Message> for Encoder {
|
|
type Error = ProtocolError;
|
|
|
|
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
|
match item {
|
|
Message::Text(txt) => {
|
|
let (bytes, rsv_bits) = self.process_payload(true, txt.into_bytes())?;
|
|
|
|
Parser::write_message(
|
|
dst,
|
|
bytes,
|
|
OpCode::Text,
|
|
rsv_bits,
|
|
true,
|
|
!self.flags.contains(Flags::SERVER),
|
|
)
|
|
}
|
|
Message::Binary(bin) => {
|
|
let (bin, rsv_bits) = self.process_payload(true, bin)?;
|
|
|
|
Parser::write_message(
|
|
dst,
|
|
bin,
|
|
OpCode::Binary,
|
|
rsv_bits,
|
|
true,
|
|
!self.flags.contains(Flags::SERVER),
|
|
)
|
|
}
|
|
Message::Ping(txt) => Parser::write_message(
|
|
dst,
|
|
txt,
|
|
OpCode::Ping,
|
|
RsvBits::empty(),
|
|
true,
|
|
!self.flags.contains(Flags::SERVER),
|
|
),
|
|
Message::Pong(txt) => Parser::write_message(
|
|
dst,
|
|
txt,
|
|
OpCode::Pong,
|
|
RsvBits::empty(),
|
|
true,
|
|
!self.flags.contains(Flags::SERVER),
|
|
),
|
|
Message::Close(reason) => Parser::write_close(
|
|
dst,
|
|
reason,
|
|
RsvBits::empty(),
|
|
!self.flags.contains(Flags::SERVER),
|
|
),
|
|
Message::Continuation(cont) => match cont {
|
|
Item::FirstText(data) => {
|
|
if self.flags.contains(Flags::W_CONTINUATION) {
|
|
return Err(ProtocolError::ContinuationStarted);
|
|
} else {
|
|
let (data, rsv_bits) = self.process_payload(false, data)?;
|
|
|
|
self.flags.insert(Flags::W_CONTINUATION);
|
|
Parser::write_message(
|
|
dst,
|
|
data,
|
|
OpCode::Text,
|
|
rsv_bits,
|
|
false,
|
|
!self.flags.contains(Flags::SERVER),
|
|
)
|
|
}
|
|
}
|
|
Item::FirstBinary(data) => {
|
|
if self.flags.contains(Flags::W_CONTINUATION) {
|
|
return Err(ProtocolError::ContinuationStarted);
|
|
} else {
|
|
let (data, rsv_bits) = self.process_payload(false, data)?;
|
|
|
|
self.flags.insert(Flags::W_CONTINUATION);
|
|
Parser::write_message(
|
|
dst,
|
|
data,
|
|
OpCode::Binary,
|
|
rsv_bits,
|
|
false,
|
|
!self.flags.contains(Flags::SERVER),
|
|
)
|
|
}
|
|
}
|
|
Item::Continue(data) => {
|
|
if self.flags.contains(Flags::W_CONTINUATION) {
|
|
let (data, rsv_bits) = self.process_payload(false, data)?;
|
|
|
|
Parser::write_message(
|
|
dst,
|
|
data,
|
|
OpCode::Continue,
|
|
rsv_bits,
|
|
false,
|
|
!self.flags.contains(Flags::SERVER),
|
|
)
|
|
} else {
|
|
return Err(ProtocolError::ContinuationNotStarted);
|
|
}
|
|
}
|
|
Item::Last(data) => {
|
|
if self.flags.contains(Flags::W_CONTINUATION) {
|
|
self.flags.remove(Flags::W_CONTINUATION);
|
|
|
|
let (data, rsv_bits) = self.process_payload(true, data)?;
|
|
|
|
Parser::write_message(
|
|
dst,
|
|
data,
|
|
OpCode::Continue,
|
|
rsv_bits,
|
|
true,
|
|
!self.flags.contains(Flags::SERVER),
|
|
)
|
|
} else {
|
|
return Err(ProtocolError::ContinuationNotStarted);
|
|
}
|
|
}
|
|
},
|
|
Message::Nop => {}
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// WebSocket message decoder.
|
|
#[derive(Debug, Clone)]
|
|
pub struct Decoder {
|
|
flags: Flags,
|
|
max_size: usize,
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
deflate_decompress: Option<DeflateDecompressionContext>,
|
|
}
|
|
|
|
impl Decoder {
|
|
/// Create new WebSocket frames decoder.
|
|
pub const fn new() -> Decoder {
|
|
Decoder {
|
|
flags: Flags::SERVER,
|
|
max_size: 65_536,
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
deflate_decompress: None,
|
|
}
|
|
}
|
|
|
|
/// Create new WebSocket frames decoder with `permessage-deflate` extension support.
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
pub fn new_deflate(decompress: DeflateDecompressionContext) -> Decoder {
|
|
Decoder {
|
|
flags: Flags::SERVER,
|
|
max_size: 65_536,
|
|
|
|
deflate_decompress: Some(decompress),
|
|
}
|
|
}
|
|
|
|
fn set_client_mode(mut self) -> Self {
|
|
self.flags = Flags::empty();
|
|
self
|
|
}
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
fn set_client_mode_deflate(
|
|
mut self,
|
|
local_no_context_takeover: bool,
|
|
local_max_window_bits: u8,
|
|
) -> Self {
|
|
if let Some(decompress) = &mut self.deflate_decompress {
|
|
decompress.reset_with(local_no_context_takeover, local_max_window_bits);
|
|
}
|
|
|
|
self
|
|
}
|
|
|
|
fn set_max_size(mut self, size: usize) -> Self {
|
|
self.max_size = size;
|
|
self
|
|
}
|
|
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
fn process_payload(
|
|
&mut self,
|
|
fin: bool,
|
|
opcode: OpCode,
|
|
rsv_bits: RsvBits,
|
|
bytes: Option<Bytes>,
|
|
) -> Result<Option<Bytes>, ProtocolError> {
|
|
if let Some(bytes) = bytes {
|
|
if let Some(decompress) = &mut self.deflate_decompress {
|
|
Ok(Some(decompress.decompress(fin, opcode, rsv_bits, bytes)?))
|
|
} else {
|
|
Ok(Some(bytes))
|
|
}
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
#[cfg(not(feature = "compress-ws-deflate"))]
|
|
fn process_payload(
|
|
&mut self,
|
|
_fin: bool,
|
|
_opcode: OpCode,
|
|
_rsv_bits: RsvBits,
|
|
bytes: Option<Bytes>,
|
|
) -> Result<Option<Bytes>, ProtocolError> {
|
|
Ok(bytes)
|
|
}
|
|
}
|
|
|
|
impl Default for Decoder {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl codec::Decoder for Decoder {
|
|
type Item = Frame;
|
|
type Error = ProtocolError;
|
|
|
|
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
|
match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
|
|
Ok(Some((finished, opcode, rsv_bits, payload))) => {
|
|
let payload = self.process_payload(
|
|
finished,
|
|
opcode,
|
|
rsv_bits,
|
|
payload.map(BytesMut::freeze),
|
|
)?;
|
|
|
|
// continuation is not supported
|
|
if !finished {
|
|
return match opcode {
|
|
OpCode::Continue => {
|
|
if self.flags.contains(Flags::CONTINUATION) {
|
|
Ok(Some(Frame::Continuation(Item::Continue(
|
|
payload.unwrap_or_else(Bytes::new),
|
|
))))
|
|
} else {
|
|
Err(ProtocolError::ContinuationNotStarted)
|
|
}
|
|
}
|
|
OpCode::Binary => {
|
|
if !self.flags.contains(Flags::CONTINUATION) {
|
|
self.flags.insert(Flags::CONTINUATION);
|
|
Ok(Some(Frame::Continuation(Item::FirstBinary(
|
|
payload.unwrap_or_else(Bytes::new),
|
|
))))
|
|
} else {
|
|
Err(ProtocolError::ContinuationStarted)
|
|
}
|
|
}
|
|
OpCode::Text => {
|
|
if !self.flags.contains(Flags::CONTINUATION) {
|
|
self.flags.insert(Flags::CONTINUATION);
|
|
Ok(Some(Frame::Continuation(Item::FirstText(
|
|
payload.unwrap_or_else(Bytes::new),
|
|
))))
|
|
} else {
|
|
Err(ProtocolError::ContinuationStarted)
|
|
}
|
|
}
|
|
_ => {
|
|
error!("Unfinished fragment {:?}", opcode);
|
|
Err(ProtocolError::ContinuationFragment(opcode))
|
|
}
|
|
};
|
|
}
|
|
|
|
match opcode {
|
|
OpCode::Continue => {
|
|
if self.flags.contains(Flags::CONTINUATION) {
|
|
self.flags.remove(Flags::CONTINUATION);
|
|
Ok(Some(Frame::Continuation(Item::Last(
|
|
payload.unwrap_or_else(Bytes::new),
|
|
))))
|
|
} else {
|
|
Err(ProtocolError::ContinuationNotStarted)
|
|
}
|
|
}
|
|
OpCode::Bad => Err(ProtocolError::BadOpCode),
|
|
OpCode::Close => {
|
|
if let Some(ref pl) = payload {
|
|
let close_reason = Parser::parse_close_payload(pl);
|
|
Ok(Some(Frame::Close(close_reason)))
|
|
} else {
|
|
Ok(Some(Frame::Close(None)))
|
|
}
|
|
}
|
|
OpCode::Ping => Ok(Some(Frame::Ping(payload.unwrap_or_else(Bytes::new)))),
|
|
OpCode::Pong => Ok(Some(Frame::Pong(payload.unwrap_or_else(Bytes::new)))),
|
|
OpCode::Binary => Ok(Some(Frame::Binary(payload.unwrap_or_else(Bytes::new)))),
|
|
OpCode::Text => Ok(Some(Frame::Text(payload.unwrap_or_else(Bytes::new)))),
|
|
}
|
|
}
|
|
Ok(None) => Ok(None),
|
|
Err(err) => Err(err),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// WebSocket protocol codec.
|
|
#[derive(Debug, Default, Clone)]
|
|
pub struct Codec {
|
|
encoder: Encoder,
|
|
decoder: Decoder,
|
|
}
|
|
|
|
impl Codec {
|
|
/// Create new WebSocket frames codec.
|
|
pub fn new() -> Codec {
|
|
Codec {
|
|
encoder: Encoder::new(),
|
|
decoder: Decoder::new(),
|
|
}
|
|
}
|
|
|
|
/// Create new WebSocket frames codec with DEFLATE compression.
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
pub fn new_deflate(context: DeflateContext) -> Codec {
|
|
let DeflateContext {
|
|
compress,
|
|
decompress,
|
|
} = context;
|
|
|
|
Codec {
|
|
encoder: Encoder::new_deflate(compress),
|
|
decoder: Decoder::new_deflate(decompress),
|
|
}
|
|
}
|
|
|
|
/// Set max frame size.
|
|
///
|
|
/// By default max size is set to 64KiB.
|
|
#[must_use = "This returns the a new Codec, without modifying the original."]
|
|
pub fn max_size(self, size: usize) -> Self {
|
|
let Self { encoder, decoder } = self;
|
|
|
|
Codec {
|
|
encoder,
|
|
decoder: decoder.set_max_size(size),
|
|
}
|
|
}
|
|
|
|
/// Set decoder to client mode.
|
|
///
|
|
/// By default decoder works in server mode.
|
|
#[must_use = "This returns the a new Codec, without modifying the original."]
|
|
pub fn client_mode(self) -> Self {
|
|
let Self {
|
|
mut encoder,
|
|
mut decoder,
|
|
} = self;
|
|
|
|
encoder = encoder.set_client_mode();
|
|
decoder = decoder.set_client_mode();
|
|
#[cfg(feature = "compress-ws-deflate")]
|
|
{
|
|
if let Some(decoder) = &decoder.deflate_decompress {
|
|
encoder = encoder.set_client_mode_deflate(
|
|
decoder.local_no_context_takeover,
|
|
decoder.local_max_window_bits,
|
|
);
|
|
}
|
|
if let Some(encoder) = &encoder.deflate_compress {
|
|
decoder = decoder.set_client_mode_deflate(
|
|
encoder.remote_no_context_takeover,
|
|
encoder.remote_max_window_bits,
|
|
);
|
|
}
|
|
}
|
|
|
|
Self { encoder, decoder }
|
|
}
|
|
}
|
|
|
|
impl codec::Decoder for Codec {
|
|
type Item = Frame;
|
|
type Error = ProtocolError;
|
|
|
|
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
|
self.decoder.decode(src)
|
|
}
|
|
}
|
|
|
|
impl codec::Encoder<Message> for Codec {
|
|
type Error = ProtocolError;
|
|
|
|
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
|
self.encoder.encode(item, dst)
|
|
}
|
|
}
|