//! `WebSocket` support. //! //! To setup a `WebSocket`, first do web socket handshake then on success //! convert `Payload` into a `WsStream` stream and then use `WsWriter` to //! communicate with the peer. //! ``` use bytes::Bytes; use futures::{Async, Poll, Stream}; use http::{header, Method, StatusCode}; use body::Binary; use error::{PayloadError, ResponseError}; use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; use payload::PayloadBuffer; use request::Request; mod frame; mod mask; mod proto; pub use self::frame::{Frame, FramedMessage}; pub use self::proto::{CloseCode, CloseReason, OpCode}; /// Websocket protocol errors #[derive(Fail, Debug)] pub enum ProtocolError { /// Received an unmasked frame from client #[fail(display = "Received an unmasked frame from client")] UnmaskedFrame, /// Received a masked frame from server #[fail(display = "Received a masked frame from server")] MaskedFrame, /// Encountered invalid opcode #[fail(display = "Invalid opcode: {}", _0)] InvalidOpcode(u8), /// Invalid control frame length #[fail(display = "Invalid control frame length: {}", _0)] InvalidLength(usize), /// Bad web socket op code #[fail(display = "Bad web socket op code")] BadOpCode, /// A payload reached size limit. #[fail(display = "A payload reached size limit.")] Overflow, /// Continuation is not supported #[fail(display = "Continuation is not supported.")] NoContinuation, /// Bad utf-8 encoding #[fail(display = "Bad utf-8 encoding.")] BadEncoding, /// Payload error #[fail(display = "Payload error: {}", _0)] Payload(#[cause] PayloadError), } impl ResponseError for ProtocolError {} impl From for ProtocolError { fn from(err: PayloadError) -> ProtocolError { ProtocolError::Payload(err) } } /// Websocket handshake errors #[derive(Fail, PartialEq, Debug)] pub enum HandshakeError { /// Only get method is allowed #[fail(display = "Method not allowed")] GetMethodRequired, /// Upgrade header if not set to websocket #[fail(display = "Websocket upgrade is expected")] NoWebsocketUpgrade, /// Connection header is not set to upgrade #[fail(display = "Connection upgrade is expected")] NoConnectionUpgrade, /// Websocket version header is not set #[fail(display = "Websocket version header is required")] NoVersionHeader, /// Unsupported websocket version #[fail(display = "Unsupported version")] UnsupportedVersion, /// Websocket key is not set or wrong #[fail(display = "Unknown websocket key")] BadWebsocketKey, } impl ResponseError for HandshakeError { fn error_response(&self) -> HttpResponse { match *self { HandshakeError::GetMethodRequired => HttpResponse::MethodNotAllowed() .header(header::ALLOW, "GET") .finish(), HandshakeError::NoWebsocketUpgrade => HttpResponse::BadRequest() .reason("No WebSocket UPGRADE header found") .finish(), HandshakeError::NoConnectionUpgrade => HttpResponse::BadRequest() .reason("No CONNECTION upgrade") .finish(), HandshakeError::NoVersionHeader => HttpResponse::BadRequest() .reason("Websocket version header is required") .finish(), HandshakeError::UnsupportedVersion => HttpResponse::BadRequest() .reason("Unsupported version") .finish(), HandshakeError::BadWebsocketKey => HttpResponse::BadRequest() .reason("Handshake error") .finish(), } } } /// `WebSocket` Message #[derive(Debug, PartialEq)] pub enum Message { /// Text message Text(String), /// Binary message Binary(Binary), /// Ping message Ping(String), /// Pong message Pong(String), /// Close message with optional reason Close(Option), } /// Prepare `WebSocket` handshake response. /// /// This function returns handshake `HttpResponse`, ready to send to peer. /// It does not perform any IO. /// // /// `protocols` is a sequence of known protocols. On successful handshake, // /// the returned response headers contain the first protocol in this list // /// which the server also knows. pub fn handshake(req: &Request) -> Result { // WebSocket accepts only GET if *req.method() != Method::GET { return Err(HandshakeError::GetMethodRequired); } // Check for "UPGRADE" to websocket header let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { if let Ok(s) = hdr.to_str() { s.to_lowercase().contains("websocket") } else { false } } else { false }; if !has_hdr { return Err(HandshakeError::NoWebsocketUpgrade); } // Upgrade connection if !req.upgrade() { return Err(HandshakeError::NoConnectionUpgrade); } // check supported version if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) { return Err(HandshakeError::NoVersionHeader); } let supported_ver = { if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) { hdr == "13" || hdr == "8" || hdr == "7" } else { false } }; if !supported_ver { return Err(HandshakeError::UnsupportedVersion); } // check client handshake for validity if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { return Err(HandshakeError::BadWebsocketKey); } let key = { let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); proto::hash_key(key.as_ref()) }; Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) .connection_type(ConnectionType::Upgrade) .header(header::UPGRADE, "websocket") .header(header::TRANSFER_ENCODING, "chunked") .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) .take()) } /// Maps `Payload` stream into stream of `ws::Message` items pub struct WsStream { rx: PayloadBuffer, closed: bool, max_size: usize, } impl WsStream where S: Stream, { /// Create new websocket frames stream pub fn new(stream: S) -> WsStream { WsStream { rx: PayloadBuffer::new(stream), closed: false, max_size: 65_536, } } /// Set max frame size /// /// By default max size is set to 64kb pub fn max_size(mut self, size: usize) -> Self { self.max_size = size; self } } impl Stream for WsStream where S: Stream, { type Item = Message; type Error = ProtocolError; fn poll(&mut self) -> Poll, Self::Error> { if self.closed { return Ok(Async::Ready(None)); } match Frame::parse(&mut self.rx, true, self.max_size) { Ok(Async::Ready(Some(frame))) => { let (finished, opcode, payload) = frame.unpack(); // continuation is not supported if !finished { self.closed = true; return Err(ProtocolError::NoContinuation); } match opcode { OpCode::Continue => Err(ProtocolError::NoContinuation), OpCode::Bad => { self.closed = true; Err(ProtocolError::BadOpCode) } OpCode::Close => { self.closed = true; let close_reason = Frame::parse_close_payload(&payload); Ok(Async::Ready(Some(Message::Close(close_reason)))) } OpCode::Ping => Ok(Async::Ready(Some(Message::Ping( String::from_utf8_lossy(payload.as_ref()).into(), )))), OpCode::Pong => Ok(Async::Ready(Some(Message::Pong( String::from_utf8_lossy(payload.as_ref()).into(), )))), OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))), OpCode::Text => { let tmp = Vec::from(payload.as_ref()); match String::from_utf8(tmp) { Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), Err(_) => { self.closed = true; Err(ProtocolError::BadEncoding) } } } } } Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Ok(Async::NotReady) => Ok(Async::NotReady), Err(e) => { self.closed = true; Err(e) } } } } /// Common writing methods for a websocket. pub trait WsWriter { /// Send a text fn send_text>(&mut self, text: T); /// Send a binary fn send_binary>(&mut self, data: B); /// Send a ping message fn send_ping(&mut self, message: &str); /// Send a pong message fn send_pong(&mut self, message: &str); /// Close the connection fn send_close(&mut self, reason: Option); } #[cfg(test)] mod tests { use super::*; use http::{header, Method}; use test::TestRequest; #[test] fn test_handshake() { let req = TestRequest::default().method(Method::POST).finish(); assert_eq!( HandshakeError::GetMethodRequired, handshake(&req).err().unwrap() ); let req = TestRequest::default().finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header(header::UPGRADE, header::HeaderValue::from_static("test")) .finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ).finish(); assert_eq!( HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ).finish(); assert_eq!( HandshakeError::NoVersionHeader, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ).header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("5"), ).finish(); assert_eq!( HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ).header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), ).finish(); assert_eq!( HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ).header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), ).header( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), ).finish(); assert_eq!( StatusCode::SWITCHING_PROTOCOLS, handshake(&req).unwrap().finish().status() ); } #[test] fn test_wserror_http_response() { let resp: HttpResponse = HandshakeError::GetMethodRequired.error_response(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); let resp: HttpResponse = HandshakeError::NoWebsocketUpgrade.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let resp: HttpResponse = HandshakeError::NoConnectionUpgrade.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let resp: HttpResponse = HandshakeError::NoVersionHeader.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let resp: HttpResponse = HandshakeError::UnsupportedVersion.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let resp: HttpResponse = HandshakeError::BadWebsocketKey.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } }