mirror of
https://github.com/actix/actix-extras.git
synced 2025-06-29 11:14:58 +02:00
Merge actix-http project
This commit is contained in:
109
actix-http/src/ws/client/connect.rs
Normal file
109
actix-http/src/ws/client/connect.rs
Normal file
@ -0,0 +1,109 @@
|
||||
//! Http client request
|
||||
use std::str;
|
||||
|
||||
#[cfg(feature = "cookies")]
|
||||
use cookie::Cookie;
|
||||
use http::header::{HeaderName, HeaderValue};
|
||||
use http::{Error as HttpError, HttpTryFrom, Uri};
|
||||
|
||||
use super::ClientError;
|
||||
use crate::header::IntoHeaderValue;
|
||||
use crate::message::RequestHead;
|
||||
|
||||
/// `WebSocket` connection
|
||||
pub struct Connect {
|
||||
pub(super) head: RequestHead,
|
||||
pub(super) err: Option<ClientError>,
|
||||
pub(super) http_err: Option<HttpError>,
|
||||
pub(super) origin: Option<HeaderValue>,
|
||||
pub(super) protocols: Option<String>,
|
||||
pub(super) max_size: usize,
|
||||
pub(super) server_mode: bool,
|
||||
}
|
||||
|
||||
impl Connect {
|
||||
/// Create new websocket connection
|
||||
pub fn new<S: AsRef<str>>(uri: S) -> Connect {
|
||||
let mut cl = Connect {
|
||||
head: RequestHead::default(),
|
||||
err: None,
|
||||
http_err: None,
|
||||
origin: None,
|
||||
protocols: None,
|
||||
max_size: 65_536,
|
||||
server_mode: false,
|
||||
};
|
||||
|
||||
match Uri::try_from(uri.as_ref()) {
|
||||
Ok(uri) => cl.head.uri = uri,
|
||||
Err(e) => cl.http_err = Some(e.into()),
|
||||
}
|
||||
|
||||
cl
|
||||
}
|
||||
|
||||
/// Set supported websocket protocols
|
||||
pub fn protocols<U, V>(mut self, protos: U) -> Self
|
||||
where
|
||||
U: IntoIterator<Item = V> + 'static,
|
||||
V: AsRef<str>,
|
||||
{
|
||||
let mut protos = protos
|
||||
.into_iter()
|
||||
.fold(String::new(), |acc, s| acc + s.as_ref() + ",");
|
||||
protos.pop();
|
||||
self.protocols = Some(protos);
|
||||
self
|
||||
}
|
||||
|
||||
// #[cfg(feature = "cookies")]
|
||||
// /// Set cookie for handshake request
|
||||
// pub fn cookie(mut self, cookie: Cookie) -> Self {
|
||||
// self.request.cookie(cookie);
|
||||
// self
|
||||
// }
|
||||
|
||||
/// Set request Origin
|
||||
pub fn origin<V>(mut self, origin: V) -> Self
|
||||
where
|
||||
HeaderValue: HttpTryFrom<V>,
|
||||
{
|
||||
match HeaderValue::try_from(origin) {
|
||||
Ok(value) => self.origin = Some(value),
|
||||
Err(e) => self.http_err = Some(e.into()),
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set max frame size
|
||||
///
|
||||
/// By default max size is set to 64kb
|
||||
pub fn max_frame_size(mut self, size: usize) -> Self {
|
||||
self.max_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Disable payload masking. By default ws client masks frame payload.
|
||||
pub fn server_mode(mut self) -> Self {
|
||||
self.server_mode = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set request header
|
||||
pub fn header<K, V>(mut self, key: K, value: V) -> Self
|
||||
where
|
||||
HeaderName: HttpTryFrom<K>,
|
||||
V: IntoHeaderValue,
|
||||
{
|
||||
match HeaderName::try_from(key) {
|
||||
Ok(key) => match value.try_into() {
|
||||
Ok(value) => {
|
||||
self.head.headers.append(key, value);
|
||||
}
|
||||
Err(e) => self.http_err = Some(e.into()),
|
||||
},
|
||||
Err(e) => self.http_err = Some(e.into()),
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
53
actix-http/src/ws/client/error.rs
Normal file
53
actix-http/src/ws/client/error.rs
Normal file
@ -0,0 +1,53 @@
|
||||
//! Http client request
|
||||
use std::io;
|
||||
|
||||
use actix_connect::ConnectError;
|
||||
use derive_more::{Display, From};
|
||||
use http::{header::HeaderValue, Error as HttpError, StatusCode};
|
||||
|
||||
use crate::error::ParseError;
|
||||
use crate::ws::ProtocolError;
|
||||
|
||||
/// Websocket client error
|
||||
#[derive(Debug, Display, From)]
|
||||
pub enum ClientError {
|
||||
/// Invalid url
|
||||
#[display(fmt = "Invalid url")]
|
||||
InvalidUrl,
|
||||
/// Invalid response status
|
||||
#[display(fmt = "Invalid response status")]
|
||||
InvalidResponseStatus(StatusCode),
|
||||
/// Invalid upgrade header
|
||||
#[display(fmt = "Invalid upgrade header")]
|
||||
InvalidUpgradeHeader,
|
||||
/// Invalid connection header
|
||||
#[display(fmt = "Invalid connection header")]
|
||||
InvalidConnectionHeader(HeaderValue),
|
||||
/// Missing CONNECTION header
|
||||
#[display(fmt = "Missing CONNECTION header")]
|
||||
MissingConnectionHeader,
|
||||
/// Missing SEC-WEBSOCKET-ACCEPT header
|
||||
#[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")]
|
||||
MissingWebSocketAcceptHeader,
|
||||
/// Invalid challenge response
|
||||
#[display(fmt = "Invalid challenge response")]
|
||||
InvalidChallengeResponse(String, HeaderValue),
|
||||
/// Http parsing error
|
||||
#[display(fmt = "Http parsing error")]
|
||||
Http(HttpError),
|
||||
/// Response parsing error
|
||||
#[display(fmt = "Response parsing error: {}", _0)]
|
||||
ParseError(ParseError),
|
||||
/// Protocol error
|
||||
#[display(fmt = "{}", _0)]
|
||||
Protocol(ProtocolError),
|
||||
/// Connect error
|
||||
#[display(fmt = "Connector error: {:?}", _0)]
|
||||
Connect(ConnectError),
|
||||
/// IO Error
|
||||
#[display(fmt = "{}", _0)]
|
||||
Io(io::Error),
|
||||
/// "Disconnected"
|
||||
#[display(fmt = "Disconnected")]
|
||||
Disconnected,
|
||||
}
|
48
actix-http/src/ws/client/mod.rs
Normal file
48
actix-http/src/ws/client/mod.rs
Normal file
@ -0,0 +1,48 @@
|
||||
mod connect;
|
||||
mod error;
|
||||
mod service;
|
||||
|
||||
pub use self::connect::Connect;
|
||||
pub use self::error::ClientError;
|
||||
pub use self::service::Client;
|
||||
|
||||
#[derive(PartialEq, Hash, Debug, Clone, Copy)]
|
||||
pub(crate) enum Protocol {
|
||||
Http,
|
||||
Https,
|
||||
Ws,
|
||||
Wss,
|
||||
}
|
||||
|
||||
impl Protocol {
|
||||
fn from(s: &str) -> Option<Protocol> {
|
||||
match s {
|
||||
"http" => Some(Protocol::Http),
|
||||
"https" => Some(Protocol::Https),
|
||||
"ws" => Some(Protocol::Ws),
|
||||
"wss" => Some(Protocol::Wss),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// fn is_http(self) -> bool {
|
||||
// match self {
|
||||
// Protocol::Https | Protocol::Http => true,
|
||||
// _ => false,
|
||||
// }
|
||||
// }
|
||||
|
||||
// fn is_secure(self) -> bool {
|
||||
// match self {
|
||||
// Protocol::Https | Protocol::Wss => true,
|
||||
// _ => false,
|
||||
// }
|
||||
// }
|
||||
|
||||
fn port(self) -> u16 {
|
||||
match self {
|
||||
Protocol::Http | Protocol::Ws => 80,
|
||||
Protocol::Https | Protocol::Wss => 443,
|
||||
}
|
||||
}
|
||||
}
|
272
actix-http/src/ws/client/service.rs
Normal file
272
actix-http/src/ws/client/service.rs
Normal file
@ -0,0 +1,272 @@
|
||||
//! websockets client
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use actix_codec::{AsyncRead, AsyncWrite, Framed};
|
||||
use actix_connect::{default_connector, Connect as TcpConnect, ConnectError};
|
||||
use actix_service::{apply_fn, Service};
|
||||
use base64;
|
||||
use futures::future::{err, Either, FutureResult};
|
||||
use futures::{try_ready, Async, Future, Poll, Sink, Stream};
|
||||
use http::header::{self, HeaderValue};
|
||||
use http::{HttpTryFrom, StatusCode};
|
||||
use log::trace;
|
||||
use rand;
|
||||
use sha1::Sha1;
|
||||
|
||||
use crate::body::BodyLength;
|
||||
use crate::h1;
|
||||
use crate::message::{ConnectionType, Head, ResponseHead};
|
||||
use crate::ws::Codec;
|
||||
|
||||
use super::{ClientError, Connect, Protocol};
|
||||
|
||||
/// WebSocket's client
|
||||
pub struct Client<T> {
|
||||
connector: T,
|
||||
}
|
||||
|
||||
impl Client<()> {
|
||||
/// Create client with default connector.
|
||||
pub fn default() -> Client<
|
||||
impl Service<
|
||||
Request = TcpConnect<String>,
|
||||
Response = impl AsyncRead + AsyncWrite,
|
||||
Error = ConnectError,
|
||||
> + Clone,
|
||||
> {
|
||||
Client::new(apply_fn(default_connector(), |msg: TcpConnect<_>, srv| {
|
||||
srv.call(msg).map(|stream| stream.into_parts().0)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Client<T>
|
||||
where
|
||||
T: Service<Request = TcpConnect<String>, Error = ConnectError>,
|
||||
T::Response: AsyncRead + AsyncWrite,
|
||||
{
|
||||
/// Create new websocket's client factory
|
||||
pub fn new(connector: T) -> Self {
|
||||
Client { connector }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for Client<T>
|
||||
where
|
||||
T: Service<Request = TcpConnect<String>, Error = ConnectError> + Clone,
|
||||
T::Response: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Client {
|
||||
connector: self.connector.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Service for Client<T>
|
||||
where
|
||||
T: Service<Request = TcpConnect<String>, Error = ConnectError>,
|
||||
T::Response: AsyncRead + AsyncWrite + 'static,
|
||||
T::Future: 'static,
|
||||
{
|
||||
type Request = Connect;
|
||||
type Response = Framed<T::Response, Codec>;
|
||||
type Error = ClientError;
|
||||
type Future = Either<
|
||||
FutureResult<Self::Response, Self::Error>,
|
||||
ClientResponseFut<T::Response>,
|
||||
>;
|
||||
|
||||
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
|
||||
self.connector.poll_ready().map_err(ClientError::from)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Connect) -> Self::Future {
|
||||
if let Some(e) = req.err.take() {
|
||||
Either::A(err(e))
|
||||
} else if let Some(e) = req.http_err.take() {
|
||||
Either::A(err(e.into()))
|
||||
} else {
|
||||
// origin
|
||||
if let Some(origin) = req.origin.take() {
|
||||
req.head.headers.insert(header::ORIGIN, origin);
|
||||
}
|
||||
|
||||
req.head.set_connection_type(ConnectionType::Upgrade);
|
||||
req.head
|
||||
.headers
|
||||
.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
|
||||
req.head.headers.insert(
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
HeaderValue::from_static("13"),
|
||||
);
|
||||
|
||||
if let Some(protocols) = req.protocols.take() {
|
||||
req.head.headers.insert(
|
||||
header::SEC_WEBSOCKET_PROTOCOL,
|
||||
HeaderValue::try_from(protocols.as_str()).unwrap(),
|
||||
);
|
||||
}
|
||||
if let Some(e) = req.http_err {
|
||||
return Either::A(err(e.into()));
|
||||
};
|
||||
|
||||
let mut request = req.head;
|
||||
if request.uri.host().is_none() {
|
||||
return Either::A(err(ClientError::InvalidUrl));
|
||||
}
|
||||
|
||||
// supported protocols
|
||||
let proto = if let Some(scheme) = request.uri.scheme_part() {
|
||||
match Protocol::from(scheme.as_str()) {
|
||||
Some(proto) => proto,
|
||||
None => return Either::A(err(ClientError::InvalidUrl)),
|
||||
}
|
||||
} else {
|
||||
return Either::A(err(ClientError::InvalidUrl));
|
||||
};
|
||||
|
||||
// Generate a random key for the `Sec-WebSocket-Key` header.
|
||||
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
||||
// when decoded, is 16 bytes in length (RFC 6455)
|
||||
let sec_key: [u8; 16] = rand::random();
|
||||
let key = base64::encode(&sec_key);
|
||||
|
||||
request.headers.insert(
|
||||
header::SEC_WEBSOCKET_KEY,
|
||||
HeaderValue::try_from(key.as_str()).unwrap(),
|
||||
);
|
||||
|
||||
// prep connection
|
||||
let connect = TcpConnect::new(request.uri.host().unwrap().to_string())
|
||||
.set_port(request.uri.port_u16().unwrap_or_else(|| proto.port()));
|
||||
|
||||
let fut = Box::new(
|
||||
self.connector
|
||||
.call(connect)
|
||||
.map_err(ClientError::from)
|
||||
.and_then(move |io| {
|
||||
// h1 protocol
|
||||
let framed = Framed::new(io, h1::ClientCodec::default());
|
||||
framed
|
||||
.send((request, BodyLength::None).into())
|
||||
.map_err(ClientError::from)
|
||||
.and_then(|framed| {
|
||||
framed
|
||||
.into_future()
|
||||
.map_err(|(e, _)| ClientError::from(e))
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
// start handshake
|
||||
Either::B(ClientResponseFut {
|
||||
key,
|
||||
fut,
|
||||
max_size: req.max_size,
|
||||
server_mode: req.server_mode,
|
||||
_t: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Future that implementes client websocket handshake process.
|
||||
///
|
||||
/// It resolves to a `Framed<T, ws::Codec>` instance.
|
||||
pub struct ClientResponseFut<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fut: Box<
|
||||
Future<
|
||||
Item = (Option<ResponseHead>, Framed<T, h1::ClientCodec>),
|
||||
Error = ClientError,
|
||||
>,
|
||||
>,
|
||||
key: String,
|
||||
max_size: usize,
|
||||
server_mode: bool,
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Future for ClientResponseFut<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = Framed<T, Codec>;
|
||||
type Error = ClientError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
let (item, framed) = try_ready!(self.fut.poll());
|
||||
|
||||
let res = match item {
|
||||
Some(res) => res,
|
||||
None => return Err(ClientError::Disconnected),
|
||||
};
|
||||
|
||||
// verify response
|
||||
if res.status != StatusCode::SWITCHING_PROTOCOLS {
|
||||
return Err(ClientError::InvalidResponseStatus(res.status));
|
||||
}
|
||||
// Check for "UPGRADE" to websocket header
|
||||
let has_hdr = if let Some(hdr) = res.headers.get(header::UPGRADE) {
|
||||
if let Ok(s) = hdr.to_str() {
|
||||
s.to_lowercase().contains("websocket")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if !has_hdr {
|
||||
trace!("Invalid upgrade header");
|
||||
return Err(ClientError::InvalidUpgradeHeader);
|
||||
}
|
||||
// Check for "CONNECTION" header
|
||||
if let Some(conn) = res.headers.get(header::CONNECTION) {
|
||||
if let Ok(s) = conn.to_str() {
|
||||
if !s.to_lowercase().contains("upgrade") {
|
||||
trace!("Invalid connection header: {}", s);
|
||||
return Err(ClientError::InvalidConnectionHeader(conn.clone()));
|
||||
}
|
||||
} else {
|
||||
trace!("Invalid connection header: {:?}", conn);
|
||||
return Err(ClientError::InvalidConnectionHeader(conn.clone()));
|
||||
}
|
||||
} else {
|
||||
trace!("Missing connection header");
|
||||
return Err(ClientError::MissingConnectionHeader);
|
||||
}
|
||||
|
||||
if let Some(key) = res.headers.get(header::SEC_WEBSOCKET_ACCEPT) {
|
||||
// field is constructed by concatenating /key/
|
||||
// with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
|
||||
const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(self.key.as_ref());
|
||||
sha1.update(WS_GUID);
|
||||
let encoded = base64::encode(&sha1.digest().bytes());
|
||||
if key.as_bytes() != encoded.as_bytes() {
|
||||
trace!(
|
||||
"Invalid challenge response: expected: {} received: {:?}",
|
||||
encoded,
|
||||
key
|
||||
);
|
||||
return Err(ClientError::InvalidChallengeResponse(encoded, key.clone()));
|
||||
}
|
||||
} else {
|
||||
trace!("Missing SEC-WEBSOCKET-ACCEPT header");
|
||||
return Err(ClientError::MissingWebSocketAcceptHeader);
|
||||
};
|
||||
|
||||
// websockets codec
|
||||
let codec = if self.server_mode {
|
||||
Codec::new().max_size(self.max_size)
|
||||
} else {
|
||||
Codec::new().max_size(self.max_size).client_mode()
|
||||
};
|
||||
|
||||
Ok(Async::Ready(framed.into_framed(codec)))
|
||||
}
|
||||
}
|
147
actix-http/src/ws/codec.rs
Normal file
147
actix-http/src/ws/codec.rs
Normal file
@ -0,0 +1,147 @@
|
||||
use actix_codec::{Decoder, Encoder};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
|
||||
use super::frame::Parser;
|
||||
use super::proto::{CloseReason, OpCode};
|
||||
use super::ProtocolError;
|
||||
|
||||
/// `WebSocket` Message
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Message {
|
||||
/// Text message
|
||||
Text(String),
|
||||
/// Binary message
|
||||
Binary(Bytes),
|
||||
/// Ping message
|
||||
Ping(String),
|
||||
/// Pong message
|
||||
Pong(String),
|
||||
/// Close message with optional reason
|
||||
Close(Option<CloseReason>),
|
||||
}
|
||||
|
||||
/// `WebSocket` frame
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Frame {
|
||||
/// Text frame, codec does not verify utf8 encoding
|
||||
Text(Option<BytesMut>),
|
||||
/// Binary frame
|
||||
Binary(Option<BytesMut>),
|
||||
/// Ping message
|
||||
Ping(String),
|
||||
/// Pong message
|
||||
Pong(String),
|
||||
/// Close message with optional reason
|
||||
Close(Option<CloseReason>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// WebSockets protocol codec
|
||||
pub struct Codec {
|
||||
max_size: usize,
|
||||
server: bool,
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
/// Create new websocket frames decoder
|
||||
pub fn new() -> Codec {
|
||||
Codec {
|
||||
max_size: 65_536,
|
||||
server: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set max frame size
|
||||
///
|
||||
/// By default max size is set to 64kb
|
||||
pub fn max_size(mut self, size: usize) -> Self {
|
||||
self.max_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set decoder to client mode.
|
||||
///
|
||||
/// By default decoder works in server mode.
|
||||
pub fn client_mode(mut self) -> Self {
|
||||
self.server = false;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for Codec {
|
||||
type Item = Message;
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
Message::Text(txt) => {
|
||||
Parser::write_message(dst, txt, OpCode::Text, true, !self.server)
|
||||
}
|
||||
Message::Binary(bin) => {
|
||||
Parser::write_message(dst, bin, OpCode::Binary, true, !self.server)
|
||||
}
|
||||
Message::Ping(txt) => {
|
||||
Parser::write_message(dst, txt, OpCode::Ping, true, !self.server)
|
||||
}
|
||||
Message::Pong(txt) => {
|
||||
Parser::write_message(dst, txt, OpCode::Pong, true, !self.server)
|
||||
}
|
||||
Message::Close(reason) => Parser::write_close(dst, reason, !self.server),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for Codec {
|
||||
type Item = Frame;
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
match Parser::parse(src, self.server, self.max_size) {
|
||||
Ok(Some((finished, opcode, payload))) => {
|
||||
// continuation is not supported
|
||||
if !finished {
|
||||
return Err(ProtocolError::NoContinuation);
|
||||
}
|
||||
|
||||
match opcode {
|
||||
OpCode::Continue => Err(ProtocolError::NoContinuation),
|
||||
OpCode::Bad => Err(ProtocolError::BadOpCode),
|
||||
OpCode::Close => {
|
||||
if let Some(ref pl) = payload {
|
||||
let close_reason = Parser::parse_close_payload(pl);
|
||||
Ok(Some(Frame::Close(close_reason)))
|
||||
} else {
|
||||
Ok(Some(Frame::Close(None)))
|
||||
}
|
||||
}
|
||||
OpCode::Ping => {
|
||||
if let Some(ref pl) = payload {
|
||||
Ok(Some(Frame::Ping(String::from_utf8_lossy(pl).into())))
|
||||
} else {
|
||||
Ok(Some(Frame::Ping(String::new())))
|
||||
}
|
||||
}
|
||||
OpCode::Pong => {
|
||||
if let Some(ref pl) = payload {
|
||||
Ok(Some(Frame::Pong(String::from_utf8_lossy(pl).into())))
|
||||
} else {
|
||||
Ok(Some(Frame::Pong(String::new())))
|
||||
}
|
||||
}
|
||||
OpCode::Binary => Ok(Some(Frame::Binary(payload))),
|
||||
OpCode::Text => {
|
||||
Ok(Some(Frame::Text(payload)))
|
||||
//let tmp = Vec::from(payload.as_ref());
|
||||
//match String::from_utf8(tmp) {
|
||||
// Ok(s) => Ok(Some(Message::Text(s))),
|
||||
// Err(_) => Err(ProtocolError::BadEncoding),
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
383
actix-http/src/ws/frame.rs
Normal file
383
actix-http/src/ws/frame.rs
Normal file
@ -0,0 +1,383 @@
|
||||
use byteorder::{ByteOrder, LittleEndian, NetworkEndian};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use log::debug;
|
||||
use rand;
|
||||
|
||||
use crate::ws::mask::apply_mask;
|
||||
use crate::ws::proto::{CloseCode, CloseReason, OpCode};
|
||||
use crate::ws::ProtocolError;
|
||||
|
||||
/// A struct representing a `WebSocket` frame.
|
||||
#[derive(Debug)]
|
||||
pub struct Parser;
|
||||
|
||||
impl Parser {
|
||||
fn parse_metadata(
|
||||
src: &[u8],
|
||||
server: bool,
|
||||
max_size: usize,
|
||||
) -> Result<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> {
|
||||
let chunk_len = src.len();
|
||||
|
||||
let mut idx = 2;
|
||||
if chunk_len < 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let first = src[0];
|
||||
let second = src[1];
|
||||
let finished = first & 0x80 != 0;
|
||||
|
||||
// check masking
|
||||
let masked = second & 0x80 != 0;
|
||||
if !masked && server {
|
||||
return Err(ProtocolError::UnmaskedFrame);
|
||||
} else if masked && !server {
|
||||
return Err(ProtocolError::MaskedFrame);
|
||||
}
|
||||
|
||||
// Op code
|
||||
let opcode = OpCode::from(first & 0x0F);
|
||||
|
||||
if let OpCode::Bad = opcode {
|
||||
return Err(ProtocolError::InvalidOpcode(first & 0x0F));
|
||||
}
|
||||
|
||||
let len = second & 0x7F;
|
||||
let length = if len == 126 {
|
||||
if chunk_len < 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
let len = NetworkEndian::read_uint(&src[idx..], 2) as usize;
|
||||
idx += 2;
|
||||
len
|
||||
} else if len == 127 {
|
||||
if chunk_len < 10 {
|
||||
return Ok(None);
|
||||
}
|
||||
let len = NetworkEndian::read_uint(&src[idx..], 8);
|
||||
if len > max_size as u64 {
|
||||
return Err(ProtocolError::Overflow);
|
||||
}
|
||||
idx += 8;
|
||||
len as usize
|
||||
} else {
|
||||
len as usize
|
||||
};
|
||||
|
||||
// check for max allowed size
|
||||
if length > max_size {
|
||||
return Err(ProtocolError::Overflow);
|
||||
}
|
||||
|
||||
let mask = if server {
|
||||
if chunk_len < idx + 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mask: &[u8] = &src[idx..idx + 4];
|
||||
let mask_u32 = LittleEndian::read_u32(mask);
|
||||
idx += 4;
|
||||
Some(mask_u32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Some((idx, finished, opcode, length, mask)))
|
||||
}
|
||||
|
||||
/// Parse the input stream into a frame.
|
||||
pub fn parse(
|
||||
src: &mut BytesMut,
|
||||
server: bool,
|
||||
max_size: usize,
|
||||
) -> Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError> {
|
||||
// try to parse ws frame metadata
|
||||
let (idx, finished, opcode, length, mask) =
|
||||
match Parser::parse_metadata(src, server, max_size)? {
|
||||
None => return Ok(None),
|
||||
Some(res) => res,
|
||||
};
|
||||
|
||||
// not enough data
|
||||
if src.len() < idx + length {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// remove prefix
|
||||
src.split_to(idx);
|
||||
|
||||
// no need for body
|
||||
if length == 0 {
|
||||
return Ok(Some((finished, opcode, None)));
|
||||
}
|
||||
|
||||
let mut data = src.split_to(length);
|
||||
|
||||
// control frames must have length <= 125
|
||||
match opcode {
|
||||
OpCode::Ping | OpCode::Pong if length > 125 => {
|
||||
return Err(ProtocolError::InvalidLength(length));
|
||||
}
|
||||
OpCode::Close if length > 125 => {
|
||||
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
|
||||
return Ok(Some((true, OpCode::Close, None)));
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
// unmask
|
||||
if let Some(mask) = mask {
|
||||
apply_mask(&mut data, mask);
|
||||
}
|
||||
|
||||
Ok(Some((finished, opcode, Some(data))))
|
||||
}
|
||||
|
||||
/// Parse the payload of a close frame.
|
||||
pub fn parse_close_payload(payload: &[u8]) -> Option<CloseReason> {
|
||||
if payload.len() >= 2 {
|
||||
let raw_code = NetworkEndian::read_u16(payload);
|
||||
let code = CloseCode::from(raw_code);
|
||||
let description = if payload.len() > 2 {
|
||||
Some(String::from_utf8_lossy(&payload[2..]).into())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Some(CloseReason { code, description })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate binary representation
|
||||
pub fn write_message<B: Into<Bytes>>(
|
||||
dst: &mut BytesMut,
|
||||
pl: B,
|
||||
op: OpCode,
|
||||
fin: bool,
|
||||
mask: bool,
|
||||
) {
|
||||
let payload = pl.into();
|
||||
let one: u8 = if fin {
|
||||
0x80 | Into::<u8>::into(op)
|
||||
} else {
|
||||
op.into()
|
||||
};
|
||||
let payload_len = payload.len();
|
||||
let (two, p_len) = if mask {
|
||||
(0x80, payload_len + 4)
|
||||
} else {
|
||||
(0, payload_len)
|
||||
};
|
||||
|
||||
if payload_len < 126 {
|
||||
dst.put_slice(&[one, two | payload_len as u8]);
|
||||
} else if payload_len <= 65_535 {
|
||||
dst.reserve(p_len + 4);
|
||||
dst.put_slice(&[one, two | 126]);
|
||||
dst.put_u16_be(payload_len as u16);
|
||||
} else {
|
||||
dst.reserve(p_len + 10);
|
||||
dst.put_slice(&[one, two | 127]);
|
||||
dst.put_u64_be(payload_len as u64);
|
||||
};
|
||||
|
||||
if mask {
|
||||
let mask = rand::random::<u32>();
|
||||
dst.put_u32_le(mask);
|
||||
dst.extend_from_slice(payload.as_ref());
|
||||
let pos = dst.len() - payload_len;
|
||||
apply_mask(&mut dst[pos..], mask);
|
||||
} else {
|
||||
dst.put_slice(payload.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Close control frame.
|
||||
#[inline]
|
||||
pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
|
||||
let payload = match reason {
|
||||
None => Vec::new(),
|
||||
Some(reason) => {
|
||||
let mut code_bytes = [0; 2];
|
||||
NetworkEndian::write_u16(&mut code_bytes, reason.code.into());
|
||||
|
||||
let mut payload = Vec::from(&code_bytes[..]);
|
||||
if let Some(description) = reason.description {
|
||||
payload.extend(description.as_bytes());
|
||||
}
|
||||
payload
|
||||
}
|
||||
};
|
||||
|
||||
Parser::write_message(dst, payload, OpCode::Close, true, mask)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use bytes::Bytes;
|
||||
|
||||
struct F {
|
||||
finished: bool,
|
||||
opcode: OpCode,
|
||||
payload: Bytes,
|
||||
}
|
||||
|
||||
fn is_none(
|
||||
frm: &Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
|
||||
) -> bool {
|
||||
match *frm {
|
||||
Ok(None) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn extract(
|
||||
frm: Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
|
||||
) -> F {
|
||||
match frm {
|
||||
Ok(Some((finished, opcode, payload))) => F {
|
||||
finished,
|
||||
opcode,
|
||||
payload: payload
|
||||
.map(|b| b.freeze())
|
||||
.unwrap_or_else(|| Bytes::from("")),
|
||||
},
|
||||
_ => unreachable!("error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse() {
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
|
||||
assert!(is_none(&Parser::parse(&mut buf, false, 1024)));
|
||||
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
|
||||
buf.extend(b"1");
|
||||
|
||||
let frame = extract(Parser::parse(&mut buf, false, 1024));
|
||||
assert!(!frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert_eq!(frame.payload.as_ref(), &b"1"[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_length0() {
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]);
|
||||
let frame = extract(Parser::parse(&mut buf, false, 1024));
|
||||
assert!(!frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert!(frame.payload.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_length2() {
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
|
||||
assert!(is_none(&Parser::parse(&mut buf, false, 1024)));
|
||||
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
|
||||
buf.extend(&[0u8, 4u8][..]);
|
||||
buf.extend(b"1234");
|
||||
|
||||
let frame = extract(Parser::parse(&mut buf, false, 1024));
|
||||
assert!(!frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_length4() {
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
|
||||
assert!(is_none(&Parser::parse(&mut buf, false, 1024)));
|
||||
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
|
||||
buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]);
|
||||
buf.extend(b"1234");
|
||||
|
||||
let frame = extract(Parser::parse(&mut buf, false, 1024));
|
||||
assert!(!frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frame_mask() {
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]);
|
||||
buf.extend(b"0001");
|
||||
buf.extend(b"1");
|
||||
|
||||
assert!(Parser::parse(&mut buf, false, 1024).is_err());
|
||||
|
||||
let frame = extract(Parser::parse(&mut buf, true, 1024));
|
||||
assert!(!frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert_eq!(frame.payload, Bytes::from(vec![1u8]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frame_no_mask() {
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
|
||||
buf.extend(&[1u8]);
|
||||
|
||||
assert!(Parser::parse(&mut buf, true, 1024).is_err());
|
||||
|
||||
let frame = extract(Parser::parse(&mut buf, false, 1024));
|
||||
assert!(!frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert_eq!(frame.payload, Bytes::from(vec![1u8]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_frame_max_size() {
|
||||
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]);
|
||||
buf.extend(&[1u8, 1u8]);
|
||||
|
||||
assert!(Parser::parse(&mut buf, true, 1).is_err());
|
||||
|
||||
if let Err(ProtocolError::Overflow) = Parser::parse(&mut buf, false, 0) {
|
||||
} else {
|
||||
unreachable!("error");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ping_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);
|
||||
|
||||
let mut v = vec![137u8, 4u8];
|
||||
v.extend(b"data");
|
||||
assert_eq!(&buf[..], &v[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pong_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);
|
||||
|
||||
let mut v = vec![138u8, 4u8];
|
||||
v.extend(b"data");
|
||||
assert_eq!(&buf[..], &v[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_close_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
let reason = (CloseCode::Normal, "data");
|
||||
Parser::write_close(&mut buf, Some(reason.into()), false);
|
||||
|
||||
let mut v = vec![136u8, 6u8, 3u8, 232u8];
|
||||
v.extend(b"data");
|
||||
assert_eq!(&buf[..], &v[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_close_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
Parser::write_close(&mut buf, None, false);
|
||||
assert_eq!(&buf[..], &vec![0x88, 0x00][..]);
|
||||
}
|
||||
}
|
149
actix-http/src/ws/mask.rs
Normal file
149
actix-http/src/ws/mask.rs
Normal file
@ -0,0 +1,149 @@
|
||||
//! This is code from [Tungstenite project](https://github.com/snapview/tungstenite-rs)
|
||||
#![allow(clippy::cast_ptr_alignment)]
|
||||
use std::ptr::copy_nonoverlapping;
|
||||
use std::slice;
|
||||
|
||||
// Holds a slice guaranteed to be shorter than 8 bytes
|
||||
struct ShortSlice<'a>(&'a mut [u8]);
|
||||
|
||||
impl<'a> ShortSlice<'a> {
|
||||
unsafe fn new(slice: &'a mut [u8]) -> Self {
|
||||
// Sanity check for debug builds
|
||||
debug_assert!(slice.len() < 8);
|
||||
ShortSlice(slice)
|
||||
}
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Faster version of `apply_mask()` which operates on 8-byte blocks.
|
||||
#[inline]
|
||||
#[allow(clippy::cast_lossless)]
|
||||
pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) {
|
||||
// Extend the mask to 64 bits
|
||||
let mut mask_u64 = ((mask_u32 as u64) << 32) | (mask_u32 as u64);
|
||||
// Split the buffer into three segments
|
||||
let (head, mid, tail) = align_buf(buf);
|
||||
|
||||
// Initial unaligned segment
|
||||
let head_len = head.len();
|
||||
if head_len > 0 {
|
||||
xor_short(head, mask_u64);
|
||||
if cfg!(target_endian = "big") {
|
||||
mask_u64 = mask_u64.rotate_left(8 * head_len as u32);
|
||||
} else {
|
||||
mask_u64 = mask_u64.rotate_right(8 * head_len as u32);
|
||||
}
|
||||
}
|
||||
// Aligned segment
|
||||
for v in mid {
|
||||
*v ^= mask_u64;
|
||||
}
|
||||
// Final unaligned segment
|
||||
if tail.len() > 0 {
|
||||
xor_short(tail, mask_u64);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so
|
||||
// inefficient, it could be done better. The compiler does not understand that
|
||||
// a `ShortSlice` must be smaller than a u64.
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn xor_short(buf: ShortSlice, mask: u64) {
|
||||
// Unsafe: we know that a `ShortSlice` fits in a u64
|
||||
unsafe {
|
||||
let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len());
|
||||
let mut b: u64 = 0;
|
||||
#[allow(trivial_casts)]
|
||||
copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
|
||||
b ^= mask;
|
||||
#[allow(trivial_casts)]
|
||||
copy_nonoverlapping(&b as *const _ as *const u8, ptr, len);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
// Unsafe: caller must ensure the buffer has the correct size and alignment
|
||||
unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] {
|
||||
// Assert correct size and alignment in debug builds
|
||||
debug_assert!(buf.len().trailing_zeros() >= 3);
|
||||
debug_assert!((buf.as_ptr() as usize).trailing_zeros() >= 3);
|
||||
|
||||
slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u64, buf.len() >> 3)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
// Splits a slice into three parts: an unaligned short head and tail, plus an aligned
|
||||
// u64 mid section.
|
||||
fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) {
|
||||
let start_ptr = buf.as_ptr() as usize;
|
||||
let end_ptr = start_ptr + buf.len();
|
||||
|
||||
// Round *up* to next aligned boundary for start
|
||||
let start_aligned = (start_ptr + 7) & !0x7;
|
||||
// Round *down* to last aligned boundary for end
|
||||
let end_aligned = end_ptr & !0x7;
|
||||
|
||||
if end_aligned >= start_aligned {
|
||||
// We have our three segments (head, mid, tail)
|
||||
let (tmp, tail) = buf.split_at_mut(end_aligned - start_ptr);
|
||||
let (head, mid) = tmp.split_at_mut(start_aligned - start_ptr);
|
||||
|
||||
// Unsafe: we know the middle section is correctly aligned, and the outer
|
||||
// sections are smaller than 8 bytes
|
||||
unsafe { (ShortSlice::new(head), cast_slice(mid), ShortSlice(tail)) }
|
||||
} else {
|
||||
// We didn't cross even one aligned boundary!
|
||||
|
||||
// Unsafe: The outer sections are smaller than 8 bytes
|
||||
unsafe { (ShortSlice::new(buf), &mut [], ShortSlice::new(&mut [])) }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::apply_mask;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
/// A safe unoptimized mask application.
|
||||
fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) {
|
||||
for (i, byte) in buf.iter_mut().enumerate() {
|
||||
*byte ^= mask[i & 3];
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_mask() {
|
||||
let mask = [0x6d, 0xb6, 0xb2, 0x80];
|
||||
let mask_u32: u32 = LittleEndian::read_u32(&mask);
|
||||
|
||||
let unmasked = vec![
|
||||
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17,
|
||||
0x74, 0xf9, 0x12, 0x03,
|
||||
];
|
||||
|
||||
// Check masking with proper alignment.
|
||||
{
|
||||
let mut masked = unmasked.clone();
|
||||
apply_mask_fallback(&mut masked, &mask);
|
||||
|
||||
let mut masked_fast = unmasked.clone();
|
||||
apply_mask(&mut masked_fast, mask_u32);
|
||||
|
||||
assert_eq!(masked, masked_fast);
|
||||
}
|
||||
|
||||
// Check masking without alignment.
|
||||
{
|
||||
let mut masked = unmasked.clone();
|
||||
apply_mask_fallback(&mut masked[1..], &mask);
|
||||
|
||||
let mut masked_fast = unmasked.clone();
|
||||
apply_mask(&mut masked_fast[1..], mask_u32);
|
||||
|
||||
assert_eq!(masked, masked_fast);
|
||||
}
|
||||
}
|
||||
}
|
320
actix-http/src/ws/mod.rs
Normal file
320
actix-http/src/ws/mod.rs
Normal file
@ -0,0 +1,320 @@
|
||||
//! WebSocket protocol support.
|
||||
//!
|
||||
//! To setup a `WebSocket`, first do web socket handshake then on success
|
||||
//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to
|
||||
//! communicate with the peer.
|
||||
use std::io;
|
||||
|
||||
use derive_more::{Display, From};
|
||||
use http::{header, Method, StatusCode};
|
||||
|
||||
use crate::error::ResponseError;
|
||||
use crate::httpmessage::HttpMessage;
|
||||
use crate::request::Request;
|
||||
use crate::response::{Response, ResponseBuilder};
|
||||
|
||||
mod client;
|
||||
mod codec;
|
||||
mod frame;
|
||||
mod mask;
|
||||
mod proto;
|
||||
mod service;
|
||||
mod transport;
|
||||
|
||||
pub use self::client::{Client, ClientError, Connect};
|
||||
pub use self::codec::{Codec, Frame, Message};
|
||||
pub use self::frame::Parser;
|
||||
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
|
||||
pub use self::service::VerifyWebSockets;
|
||||
pub use self::transport::Transport;
|
||||
|
||||
/// Websocket protocol errors
|
||||
#[derive(Debug, Display, From)]
|
||||
pub enum ProtocolError {
|
||||
/// Received an unmasked frame from client
|
||||
#[display(fmt = "Received an unmasked frame from client")]
|
||||
UnmaskedFrame,
|
||||
/// Received a masked frame from server
|
||||
#[display(fmt = "Received a masked frame from server")]
|
||||
MaskedFrame,
|
||||
/// Encountered invalid opcode
|
||||
#[display(fmt = "Invalid opcode: {}", _0)]
|
||||
InvalidOpcode(u8),
|
||||
/// Invalid control frame length
|
||||
#[display(fmt = "Invalid control frame length: {}", _0)]
|
||||
InvalidLength(usize),
|
||||
/// Bad web socket op code
|
||||
#[display(fmt = "Bad web socket op code")]
|
||||
BadOpCode,
|
||||
/// A payload reached size limit.
|
||||
#[display(fmt = "A payload reached size limit.")]
|
||||
Overflow,
|
||||
/// Continuation is not supported
|
||||
#[display(fmt = "Continuation is not supported.")]
|
||||
NoContinuation,
|
||||
/// Bad utf-8 encoding
|
||||
#[display(fmt = "Bad utf-8 encoding.")]
|
||||
BadEncoding,
|
||||
/// Io error
|
||||
#[display(fmt = "io error: {}", _0)]
|
||||
Io(io::Error),
|
||||
}
|
||||
|
||||
impl ResponseError for ProtocolError {}
|
||||
|
||||
/// Websocket handshake errors
|
||||
#[derive(PartialEq, Debug, Display)]
|
||||
pub enum HandshakeError {
|
||||
/// Only get method is allowed
|
||||
#[display(fmt = "Method not allowed")]
|
||||
GetMethodRequired,
|
||||
/// Upgrade header if not set to websocket
|
||||
#[display(fmt = "Websocket upgrade is expected")]
|
||||
NoWebsocketUpgrade,
|
||||
/// Connection header is not set to upgrade
|
||||
#[display(fmt = "Connection upgrade is expected")]
|
||||
NoConnectionUpgrade,
|
||||
/// Websocket version header is not set
|
||||
#[display(fmt = "Websocket version header is required")]
|
||||
NoVersionHeader,
|
||||
/// Unsupported websocket version
|
||||
#[display(fmt = "Unsupported version")]
|
||||
UnsupportedVersion,
|
||||
/// Websocket key is not set or wrong
|
||||
#[display(fmt = "Unknown websocket key")]
|
||||
BadWebsocketKey,
|
||||
}
|
||||
|
||||
impl ResponseError for HandshakeError {
|
||||
fn error_response(&self) -> Response {
|
||||
match *self {
|
||||
HandshakeError::GetMethodRequired => Response::MethodNotAllowed()
|
||||
.header(header::ALLOW, "GET")
|
||||
.finish(),
|
||||
HandshakeError::NoWebsocketUpgrade => Response::BadRequest()
|
||||
.reason("No WebSocket UPGRADE header found")
|
||||
.finish(),
|
||||
HandshakeError::NoConnectionUpgrade => Response::BadRequest()
|
||||
.reason("No CONNECTION upgrade")
|
||||
.finish(),
|
||||
HandshakeError::NoVersionHeader => Response::BadRequest()
|
||||
.reason("Websocket version header is required")
|
||||
.finish(),
|
||||
HandshakeError::UnsupportedVersion => Response::BadRequest()
|
||||
.reason("Unsupported version")
|
||||
.finish(),
|
||||
HandshakeError::BadWebsocketKey => {
|
||||
Response::BadRequest().reason("Handshake error").finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify `WebSocket` handshake request and create handshake reponse.
|
||||
// /// `protocols` is a sequence of known protocols. On successful handshake,
|
||||
// /// the returned response headers contain the first protocol in this list
|
||||
// /// which the server also knows.
|
||||
pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
|
||||
verify_handshake(req)?;
|
||||
Ok(handshake_response(req))
|
||||
}
|
||||
|
||||
/// Verify `WebSocket` handshake request.
|
||||
// /// `protocols` is a sequence of known protocols. On successful handshake,
|
||||
// /// the returned response headers contain the first protocol in this list
|
||||
// /// which the server also knows.
|
||||
pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> {
|
||||
// WebSocket accepts only GET
|
||||
if *req.method() != Method::GET {
|
||||
return Err(HandshakeError::GetMethodRequired);
|
||||
}
|
||||
|
||||
// Check for "UPGRADE" to websocket header
|
||||
let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) {
|
||||
if let Ok(s) = hdr.to_str() {
|
||||
s.to_ascii_lowercase().contains("websocket")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if !has_hdr {
|
||||
return Err(HandshakeError::NoWebsocketUpgrade);
|
||||
}
|
||||
|
||||
// Upgrade connection
|
||||
if !req.upgrade() {
|
||||
return Err(HandshakeError::NoConnectionUpgrade);
|
||||
}
|
||||
|
||||
// check supported version
|
||||
if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
|
||||
return Err(HandshakeError::NoVersionHeader);
|
||||
}
|
||||
let supported_ver = {
|
||||
if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
|
||||
hdr == "13" || hdr == "8" || hdr == "7"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
if !supported_ver {
|
||||
return Err(HandshakeError::UnsupportedVersion);
|
||||
}
|
||||
|
||||
// check client handshake for validity
|
||||
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
|
||||
return Err(HandshakeError::BadWebsocketKey);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create websocket's handshake response
|
||||
///
|
||||
/// This function returns handshake `Response`, ready to send to peer.
|
||||
pub fn handshake_response(req: &Request) -> ResponseBuilder {
|
||||
let key = {
|
||||
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
|
||||
proto::hash_key(key.as_ref())
|
||||
};
|
||||
|
||||
Response::build(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.upgrade("websocket")
|
||||
.header(header::TRANSFER_ENCODING, "chunked")
|
||||
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
|
||||
.take()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::test::TestRequest;
|
||||
use http::{header, Method};
|
||||
|
||||
#[test]
|
||||
fn test_handshake() {
|
||||
let req = TestRequest::default().method(Method::POST).finish();
|
||||
assert_eq!(
|
||||
HandshakeError::GetMethodRequired,
|
||||
verify_handshake(&req).err().unwrap()
|
||||
);
|
||||
|
||||
let req = TestRequest::default().finish();
|
||||
assert_eq!(
|
||||
HandshakeError::NoWebsocketUpgrade,
|
||||
verify_handshake(&req).err().unwrap()
|
||||
);
|
||||
|
||||
let req = TestRequest::default()
|
||||
.header(header::UPGRADE, header::HeaderValue::from_static("test"))
|
||||
.finish();
|
||||
assert_eq!(
|
||||
HandshakeError::NoWebsocketUpgrade,
|
||||
verify_handshake(&req).err().unwrap()
|
||||
);
|
||||
|
||||
let req = TestRequest::default()
|
||||
.header(
|
||||
header::UPGRADE,
|
||||
header::HeaderValue::from_static("websocket"),
|
||||
)
|
||||
.finish();
|
||||
assert_eq!(
|
||||
HandshakeError::NoConnectionUpgrade,
|
||||
verify_handshake(&req).err().unwrap()
|
||||
);
|
||||
|
||||
let req = TestRequest::default()
|
||||
.header(
|
||||
header::UPGRADE,
|
||||
header::HeaderValue::from_static("websocket"),
|
||||
)
|
||||
.header(
|
||||
header::CONNECTION,
|
||||
header::HeaderValue::from_static("upgrade"),
|
||||
)
|
||||
.finish();
|
||||
assert_eq!(
|
||||
HandshakeError::NoVersionHeader,
|
||||
verify_handshake(&req).err().unwrap()
|
||||
);
|
||||
|
||||
let req = TestRequest::default()
|
||||
.header(
|
||||
header::UPGRADE,
|
||||
header::HeaderValue::from_static("websocket"),
|
||||
)
|
||||
.header(
|
||||
header::CONNECTION,
|
||||
header::HeaderValue::from_static("upgrade"),
|
||||
)
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
header::HeaderValue::from_static("5"),
|
||||
)
|
||||
.finish();
|
||||
assert_eq!(
|
||||
HandshakeError::UnsupportedVersion,
|
||||
verify_handshake(&req).err().unwrap()
|
||||
);
|
||||
|
||||
let req = TestRequest::default()
|
||||
.header(
|
||||
header::UPGRADE,
|
||||
header::HeaderValue::from_static("websocket"),
|
||||
)
|
||||
.header(
|
||||
header::CONNECTION,
|
||||
header::HeaderValue::from_static("upgrade"),
|
||||
)
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
header::HeaderValue::from_static("13"),
|
||||
)
|
||||
.finish();
|
||||
assert_eq!(
|
||||
HandshakeError::BadWebsocketKey,
|
||||
verify_handshake(&req).err().unwrap()
|
||||
);
|
||||
|
||||
let req = TestRequest::default()
|
||||
.header(
|
||||
header::UPGRADE,
|
||||
header::HeaderValue::from_static("websocket"),
|
||||
)
|
||||
.header(
|
||||
header::CONNECTION,
|
||||
header::HeaderValue::from_static("upgrade"),
|
||||
)
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
header::HeaderValue::from_static("13"),
|
||||
)
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_KEY,
|
||||
header::HeaderValue::from_static("13"),
|
||||
)
|
||||
.finish();
|
||||
assert_eq!(
|
||||
StatusCode::SWITCHING_PROTOCOLS,
|
||||
handshake_response(&req).finish().status()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wserror_http_response() {
|
||||
let resp: Response = HandshakeError::GetMethodRequired.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
|
||||
let resp: Response = HandshakeError::NoWebsocketUpgrade.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: Response = HandshakeError::NoConnectionUpgrade.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: Response = HandshakeError::NoVersionHeader.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: Response = HandshakeError::UnsupportedVersion.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
let resp: Response = HandshakeError::BadWebsocketKey.error_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
}
|
318
actix-http/src/ws/proto.rs
Normal file
318
actix-http/src/ws/proto.rs
Normal file
@ -0,0 +1,318 @@
|
||||
use base64;
|
||||
use sha1;
|
||||
use std::convert::{From, Into};
|
||||
use std::fmt;
|
||||
|
||||
use self::OpCode::*;
|
||||
/// Operation codes as part of rfc6455.
|
||||
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
|
||||
pub enum OpCode {
|
||||
/// Indicates a continuation frame of a fragmented message.
|
||||
Continue,
|
||||
/// Indicates a text data frame.
|
||||
Text,
|
||||
/// Indicates a binary data frame.
|
||||
Binary,
|
||||
/// Indicates a close control frame.
|
||||
Close,
|
||||
/// Indicates a ping control frame.
|
||||
Ping,
|
||||
/// Indicates a pong control frame.
|
||||
Pong,
|
||||
/// Indicates an invalid opcode was received.
|
||||
Bad,
|
||||
}
|
||||
|
||||
impl fmt::Display for OpCode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
Continue => write!(f, "CONTINUE"),
|
||||
Text => write!(f, "TEXT"),
|
||||
Binary => write!(f, "BINARY"),
|
||||
Close => write!(f, "CLOSE"),
|
||||
Ping => write!(f, "PING"),
|
||||
Pong => write!(f, "PONG"),
|
||||
Bad => write!(f, "BAD"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<u8> for OpCode {
|
||||
fn into(self) -> u8 {
|
||||
match self {
|
||||
Continue => 0,
|
||||
Text => 1,
|
||||
Binary => 2,
|
||||
Close => 8,
|
||||
Ping => 9,
|
||||
Pong => 10,
|
||||
Bad => {
|
||||
debug_assert!(
|
||||
false,
|
||||
"Attempted to convert invalid opcode to u8. This is a bug."
|
||||
);
|
||||
8 // if this somehow happens, a close frame will help us tear down quickly
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u8> for OpCode {
|
||||
fn from(byte: u8) -> OpCode {
|
||||
match byte {
|
||||
0 => Continue,
|
||||
1 => Text,
|
||||
2 => Binary,
|
||||
8 => Close,
|
||||
9 => Ping,
|
||||
10 => Pong,
|
||||
_ => Bad,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use self::CloseCode::*;
|
||||
/// Status code used to indicate why an endpoint is closing the `WebSocket`
|
||||
/// connection.
|
||||
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
|
||||
pub enum CloseCode {
|
||||
/// Indicates a normal closure, meaning that the purpose for
|
||||
/// which the connection was established has been fulfilled.
|
||||
Normal,
|
||||
/// Indicates that an endpoint is "going away", such as a server
|
||||
/// going down or a browser having navigated away from a page.
|
||||
Away,
|
||||
/// Indicates that an endpoint is terminating the connection due
|
||||
/// to a protocol error.
|
||||
Protocol,
|
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received a type of data it cannot accept (e.g., an
|
||||
/// endpoint that understands only text data MAY send this if it
|
||||
/// receives a binary message).
|
||||
Unsupported,
|
||||
/// Indicates an abnormal closure. If the abnormal closure was due to an
|
||||
/// error, this close code will not be used. Instead, the `on_error` method
|
||||
/// of the handler will be called with the error. However, if the connection
|
||||
/// is simply dropped, without an error, this close code will be sent to the
|
||||
/// handler.
|
||||
Abnormal,
|
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received data within a message that was not
|
||||
/// consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
|
||||
/// data within a text message).
|
||||
Invalid,
|
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received a message that violates its policy. This
|
||||
/// is a generic status code that can be returned when there is no
|
||||
/// other more suitable status code (e.g., Unsupported or Size) or if there
|
||||
/// is a need to hide specific details about the policy.
|
||||
Policy,
|
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received a message that is too big for it to
|
||||
/// process.
|
||||
Size,
|
||||
/// Indicates that an endpoint (client) is terminating the
|
||||
/// connection because it has expected the server to negotiate one or
|
||||
/// more extension, but the server didn't return them in the response
|
||||
/// message of the WebSocket handshake. The list of extensions that
|
||||
/// are needed should be given as the reason for closing.
|
||||
/// Note that this status code is not used by the server, because it
|
||||
/// can fail the WebSocket handshake instead.
|
||||
Extension,
|
||||
/// Indicates that a server is terminating the connection because
|
||||
/// it encountered an unexpected condition that prevented it from
|
||||
/// fulfilling the request.
|
||||
Error,
|
||||
/// Indicates that the server is restarting. A client may choose to
|
||||
/// reconnect, and if it does, it should use a randomized delay of 5-30
|
||||
/// seconds between attempts.
|
||||
Restart,
|
||||
/// Indicates that the server is overloaded and the client should either
|
||||
/// connect to a different IP (when multiple targets exist), or
|
||||
/// reconnect to the same IP when a user has performed an action.
|
||||
Again,
|
||||
#[doc(hidden)]
|
||||
Tls,
|
||||
#[doc(hidden)]
|
||||
Other(u16),
|
||||
}
|
||||
|
||||
impl Into<u16> for CloseCode {
|
||||
fn into(self) -> u16 {
|
||||
match self {
|
||||
Normal => 1000,
|
||||
Away => 1001,
|
||||
Protocol => 1002,
|
||||
Unsupported => 1003,
|
||||
Abnormal => 1006,
|
||||
Invalid => 1007,
|
||||
Policy => 1008,
|
||||
Size => 1009,
|
||||
Extension => 1010,
|
||||
Error => 1011,
|
||||
Restart => 1012,
|
||||
Again => 1013,
|
||||
Tls => 1015,
|
||||
Other(code) => code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u16> for CloseCode {
|
||||
fn from(code: u16) -> CloseCode {
|
||||
match code {
|
||||
1000 => Normal,
|
||||
1001 => Away,
|
||||
1002 => Protocol,
|
||||
1003 => Unsupported,
|
||||
1006 => Abnormal,
|
||||
1007 => Invalid,
|
||||
1008 => Policy,
|
||||
1009 => Size,
|
||||
1010 => Extension,
|
||||
1011 => Error,
|
||||
1012 => Restart,
|
||||
1013 => Again,
|
||||
1015 => Tls,
|
||||
_ => Other(code),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
/// Reason for closing the connection
|
||||
pub struct CloseReason {
|
||||
/// Exit code
|
||||
pub code: CloseCode,
|
||||
/// Optional description of the exit code
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
impl From<CloseCode> for CloseReason {
|
||||
fn from(code: CloseCode) -> Self {
|
||||
CloseReason {
|
||||
code,
|
||||
description: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Into<String>> From<(CloseCode, T)> for CloseReason {
|
||||
fn from(info: (CloseCode, T)) -> Self {
|
||||
CloseReason {
|
||||
code: info.0,
|
||||
description: Some(info.1.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
|
||||
// TODO: hash is always same size, we dont need String
|
||||
pub fn hash_key(key: &[u8]) -> String {
|
||||
let mut hasher = sha1::Sha1::new();
|
||||
|
||||
hasher.update(key);
|
||||
hasher.update(WS_GUID.as_bytes());
|
||||
|
||||
base64::encode(&hasher.digest().bytes())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#![allow(unused_imports, unused_variables, dead_code)]
|
||||
use super::*;
|
||||
|
||||
macro_rules! opcode_into {
|
||||
($from:expr => $opcode:pat) => {
|
||||
match OpCode::from($from) {
|
||||
e @ $opcode => (),
|
||||
e => unreachable!("{:?}", e),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! opcode_from {
|
||||
($from:expr => $opcode:pat) => {
|
||||
let res: u8 = $from.into();
|
||||
match res {
|
||||
e @ $opcode => (),
|
||||
e => unreachable!("{:?}", e),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_opcode() {
|
||||
opcode_into!(0 => OpCode::Continue);
|
||||
opcode_into!(1 => OpCode::Text);
|
||||
opcode_into!(2 => OpCode::Binary);
|
||||
opcode_into!(8 => OpCode::Close);
|
||||
opcode_into!(9 => OpCode::Ping);
|
||||
opcode_into!(10 => OpCode::Pong);
|
||||
opcode_into!(99 => OpCode::Bad);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_opcode() {
|
||||
opcode_from!(OpCode::Continue => 0);
|
||||
opcode_from!(OpCode::Text => 1);
|
||||
opcode_from!(OpCode::Binary => 2);
|
||||
opcode_from!(OpCode::Close => 8);
|
||||
opcode_from!(OpCode::Ping => 9);
|
||||
opcode_from!(OpCode::Pong => 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_from_opcode_debug() {
|
||||
opcode_from!(OpCode::Bad => 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_opcode_display() {
|
||||
assert_eq!(format!("{}", OpCode::Continue), "CONTINUE");
|
||||
assert_eq!(format!("{}", OpCode::Text), "TEXT");
|
||||
assert_eq!(format!("{}", OpCode::Binary), "BINARY");
|
||||
assert_eq!(format!("{}", OpCode::Close), "CLOSE");
|
||||
assert_eq!(format!("{}", OpCode::Ping), "PING");
|
||||
assert_eq!(format!("{}", OpCode::Pong), "PONG");
|
||||
assert_eq!(format!("{}", OpCode::Bad), "BAD");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn closecode_from_u16() {
|
||||
assert_eq!(CloseCode::from(1000u16), CloseCode::Normal);
|
||||
assert_eq!(CloseCode::from(1001u16), CloseCode::Away);
|
||||
assert_eq!(CloseCode::from(1002u16), CloseCode::Protocol);
|
||||
assert_eq!(CloseCode::from(1003u16), CloseCode::Unsupported);
|
||||
assert_eq!(CloseCode::from(1006u16), CloseCode::Abnormal);
|
||||
assert_eq!(CloseCode::from(1007u16), CloseCode::Invalid);
|
||||
assert_eq!(CloseCode::from(1008u16), CloseCode::Policy);
|
||||
assert_eq!(CloseCode::from(1009u16), CloseCode::Size);
|
||||
assert_eq!(CloseCode::from(1010u16), CloseCode::Extension);
|
||||
assert_eq!(CloseCode::from(1011u16), CloseCode::Error);
|
||||
assert_eq!(CloseCode::from(1012u16), CloseCode::Restart);
|
||||
assert_eq!(CloseCode::from(1013u16), CloseCode::Again);
|
||||
assert_eq!(CloseCode::from(1015u16), CloseCode::Tls);
|
||||
assert_eq!(CloseCode::from(2000u16), CloseCode::Other(2000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn closecode_into_u16() {
|
||||
assert_eq!(1000u16, Into::<u16>::into(CloseCode::Normal));
|
||||
assert_eq!(1001u16, Into::<u16>::into(CloseCode::Away));
|
||||
assert_eq!(1002u16, Into::<u16>::into(CloseCode::Protocol));
|
||||
assert_eq!(1003u16, Into::<u16>::into(CloseCode::Unsupported));
|
||||
assert_eq!(1006u16, Into::<u16>::into(CloseCode::Abnormal));
|
||||
assert_eq!(1007u16, Into::<u16>::into(CloseCode::Invalid));
|
||||
assert_eq!(1008u16, Into::<u16>::into(CloseCode::Policy));
|
||||
assert_eq!(1009u16, Into::<u16>::into(CloseCode::Size));
|
||||
assert_eq!(1010u16, Into::<u16>::into(CloseCode::Extension));
|
||||
assert_eq!(1011u16, Into::<u16>::into(CloseCode::Error));
|
||||
assert_eq!(1012u16, Into::<u16>::into(CloseCode::Restart));
|
||||
assert_eq!(1013u16, Into::<u16>::into(CloseCode::Again));
|
||||
assert_eq!(1015u16, Into::<u16>::into(CloseCode::Tls));
|
||||
assert_eq!(2000u16, Into::<u16>::into(CloseCode::Other(2000)));
|
||||
}
|
||||
}
|
52
actix-http/src/ws/service.rs
Normal file
52
actix-http/src/ws/service.rs
Normal file
@ -0,0 +1,52 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use actix_codec::Framed;
|
||||
use actix_service::{NewService, Service};
|
||||
use futures::future::{ok, FutureResult};
|
||||
use futures::{Async, IntoFuture, Poll};
|
||||
|
||||
use crate::h1::Codec;
|
||||
use crate::request::Request;
|
||||
|
||||
use super::{verify_handshake, HandshakeError};
|
||||
|
||||
pub struct VerifyWebSockets<T> {
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Default for VerifyWebSockets<T> {
|
||||
fn default() -> Self {
|
||||
VerifyWebSockets { _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> NewService for VerifyWebSockets<T> {
|
||||
type Request = (Request, Framed<T, Codec>);
|
||||
type Response = (Request, Framed<T, Codec>);
|
||||
type Error = (HandshakeError, Framed<T, Codec>);
|
||||
type InitError = ();
|
||||
type Service = VerifyWebSockets<T>;
|
||||
type Future = FutureResult<Self::Service, Self::InitError>;
|
||||
|
||||
fn new_service(&self, _: &()) -> Self::Future {
|
||||
ok(VerifyWebSockets { _t: PhantomData })
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Service for VerifyWebSockets<T> {
|
||||
type Request = (Request, Framed<T, Codec>);
|
||||
type Response = (Request, Framed<T, Codec>);
|
||||
type Error = (HandshakeError, Framed<T, Codec>);
|
||||
type Future = FutureResult<Self::Response, Self::Error>;
|
||||
|
||||
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
|
||||
match verify_handshake(&req) {
|
||||
Err(e) => Err((e, framed)).into_future(),
|
||||
Ok(_) => Ok((req, framed)).into_future(),
|
||||
}
|
||||
}
|
||||
}
|
49
actix-http/src/ws/transport.rs
Normal file
49
actix-http/src/ws/transport.rs
Normal file
@ -0,0 +1,49 @@
|
||||
use actix_codec::{AsyncRead, AsyncWrite, Framed};
|
||||
use actix_service::{IntoService, Service};
|
||||
use actix_utils::framed::{FramedTransport, FramedTransportError};
|
||||
use futures::{Future, Poll};
|
||||
|
||||
use super::{Codec, Frame, Message};
|
||||
|
||||
pub struct Transport<S, T>
|
||||
where
|
||||
S: Service<Request = Frame, Response = Message> + 'static,
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
inner: FramedTransport<S, T, Codec>,
|
||||
}
|
||||
|
||||
impl<S, T> Transport<S, T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
S: Service<Request = Frame, Response = Message>,
|
||||
S::Future: 'static,
|
||||
S::Error: 'static,
|
||||
{
|
||||
pub fn new<F: IntoService<S>>(io: T, service: F) -> Self {
|
||||
Transport {
|
||||
inner: FramedTransport::new(Framed::new(io, Codec::new()), service),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with<F: IntoService<S>>(framed: Framed<T, Codec>, service: F) -> Self {
|
||||
Transport {
|
||||
inner: FramedTransport::new(framed, service),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Future for Transport<S, T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
S: Service<Request = Frame, Response = Message>,
|
||||
S::Future: 'static,
|
||||
S::Error: 'static,
|
||||
{
|
||||
type Item = ();
|
||||
type Error = FramedTransportError<S::Error, Codec>;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
self.inner.poll()
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user