//! Websocket integration use std::collections::VecDeque; use std::io; use actix::dev::{ AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, ToEnvelope, }; use actix::fut::ActorFuture; use actix::{ Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message as ActixMessage, SpawnHandle, }; use actix_codec::{Decoder, Encoder}; use actix_http::ws::{hash_key, Codec}; pub use actix_http::ws::{ CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError, }; use actix_web::dev::HttpResponseBuilder; use actix_web::error::{Error, ErrorInternalServerError, PayloadError}; use actix_web::http::{header, Method, StatusCode}; use actix_web::{HttpRequest, HttpResponse}; use bytes::{Bytes, BytesMut}; use futures::sync::oneshot::Sender; use futures::{Async, Future, Poll, Stream}; /// Do websocket handshake and start ws actor. pub fn start(actor: A, req: &HttpRequest, stream: T) -> Result where A: Actor> + StreamHandler, T: Stream + 'static, { let mut res = handshake(req)?; Ok(res.streaming(WebsocketContext::create(actor, stream))) } /// Prepare `WebSocket` handshake response. /// /// This function returns handshake `HttpResponse`, ready to send to peer. /// It does not perform any IO. /// // /// `protocols` is a sequence of known protocols. On successful handshake, // /// the returned response headers contain the first protocol in this list // /// which the server also knows. pub fn handshake(req: &HttpRequest) -> Result { // WebSocket accepts only GET if *req.method() != Method::GET { return Err(HandshakeError::GetMethodRequired); } // Check for "UPGRADE" to websocket header let has_hdr = if let Some(hdr) = req.headers().get(&header::UPGRADE) { if let Ok(s) = hdr.to_str() { s.to_ascii_lowercase().contains("websocket") } else { false } } else { false }; if !has_hdr { return Err(HandshakeError::NoWebsocketUpgrade); } // Upgrade connection if !req.head().upgrade() { return Err(HandshakeError::NoConnectionUpgrade); } // check supported version if !req.headers().contains_key(&header::SEC_WEBSOCKET_VERSION) { return Err(HandshakeError::NoVersionHeader); } let supported_ver = { if let Some(hdr) = req.headers().get(&header::SEC_WEBSOCKET_VERSION) { hdr == "13" || hdr == "8" || hdr == "7" } else { false } }; if !supported_ver { return Err(HandshakeError::UnsupportedVersion); } // check client handshake for validity if !req.headers().contains_key(&header::SEC_WEBSOCKET_KEY) { return Err(HandshakeError::BadWebsocketKey); } let key = { let key = req.headers().get(&header::SEC_WEBSOCKET_KEY).unwrap(); hash_key(key.as_ref()) }; Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) .upgrade("websocket") .header(header::TRANSFER_ENCODING, "chunked") .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) .take()) } /// Execution context for `WebSockets` actors pub struct WebsocketContext where A: Actor>, { inner: ContextParts, messages: VecDeque>, } impl ActorContext for WebsocketContext where A: Actor, { fn stop(&mut self) { self.inner.stop(); } fn terminate(&mut self) { self.inner.terminate() } fn state(&self) -> ActorState { self.inner.state() } } impl AsyncContext for WebsocketContext where A: Actor, { fn spawn(&mut self, fut: F) -> SpawnHandle where F: ActorFuture + 'static, { self.inner.spawn(fut) } fn wait(&mut self, fut: F) where F: ActorFuture + 'static, { self.inner.wait(fut) } #[doc(hidden)] #[inline] fn waiting(&self) -> bool { self.inner.waiting() || self.inner.state() == ActorState::Stopping || self.inner.state() == ActorState::Stopped } fn cancel_future(&mut self, handle: SpawnHandle) -> bool { self.inner.cancel_future(handle) } #[inline] fn address(&self) -> Addr { self.inner.address() } } impl WebsocketContext where A: Actor, { #[inline] /// Create a new Websocket context from a request and an actor pub fn create(actor: A, stream: S) -> impl Stream where A: StreamHandler, S: Stream + 'static, { let mb = Mailbox::default(); let mut ctx = WebsocketContext { inner: ContextParts::new(mb.sender_producer()), messages: VecDeque::new(), }; ctx.add_stream(WsStream::new(stream, Codec::new())); WebsocketContextFut::new(ctx, actor, mb, Codec::new()) } #[inline] /// Create a new Websocket context from a request, an actor, and a codec pub fn with_codec( actor: A, stream: S, codec: Codec, ) -> impl Stream where A: StreamHandler, S: Stream + 'static, { let mb = Mailbox::default(); let mut ctx = WebsocketContext { inner: ContextParts::new(mb.sender_producer()), messages: VecDeque::new(), }; ctx.add_stream(WsStream::new(stream, codec)); WebsocketContextFut::new(ctx, actor, mb, codec) } /// Create a new Websocket context pub fn with_factory( stream: S, f: F, ) -> impl Stream where F: FnOnce(&mut Self) -> A + 'static, A: StreamHandler, S: Stream + 'static, { let mb = Mailbox::default(); let mut ctx = WebsocketContext { inner: ContextParts::new(mb.sender_producer()), messages: VecDeque::new(), }; ctx.add_stream(WsStream::new(stream, Codec::new())); let act = f(&mut ctx); WebsocketContextFut::new(ctx, act, mb, Codec::new()) } } impl WebsocketContext where A: Actor, { /// Write payload /// /// This is a low-level function that accepts framed messages that should /// be created using `Frame::message()`. If you want to send text or binary /// data you should prefer the `text()` or `binary()` convenience functions /// that handle the framing for you. #[inline] pub fn write_raw(&mut self, msg: Message) { self.messages.push_back(Some(msg)); } /// Send text frame #[inline] pub fn text>(&mut self, text: T) { self.write_raw(Message::Text(text.into())); } /// Send binary frame #[inline] pub fn binary>(&mut self, data: B) { self.write_raw(Message::Binary(data.into())); } /// Send ping frame #[inline] pub fn ping(&mut self, message: &str) { self.write_raw(Message::Ping(message.to_string())); } /// Send pong frame #[inline] pub fn pong(&mut self, message: &str) { self.write_raw(Message::Pong(message.to_string())); } /// Send close frame #[inline] pub fn close(&mut self, reason: Option) { self.write_raw(Message::Close(reason)); } /// Handle of the running future /// /// SpawnHandle is the handle returned by `AsyncContext::spawn()` method. pub fn handle(&self) -> SpawnHandle { self.inner.curr_handle() } /// Set mailbox capacity /// /// By default mailbox capacity is 16 messages. pub fn set_mailbox_capacity(&mut self, cap: usize) { self.inner.set_mailbox_capacity(cap) } } impl AsyncContextParts for WebsocketContext where A: Actor, { fn parts(&mut self) -> &mut ContextParts { &mut self.inner } } struct WebsocketContextFut where A: Actor>, { fut: ContextFut>, encoder: Codec, buf: BytesMut, closed: bool, } impl WebsocketContextFut where A: Actor>, { fn new(ctx: WebsocketContext, act: A, mailbox: Mailbox, codec: Codec) -> Self { let fut = ContextFut::new(ctx, act, mailbox); WebsocketContextFut { fut, encoder: codec, buf: BytesMut::new(), closed: false, } } } impl Stream for WebsocketContextFut where A: Actor>, { type Item = Bytes; type Error = Error; fn poll(&mut self) -> Poll, Error> { if self.fut.alive() && self.fut.poll().is_err() { return Err(ErrorInternalServerError("error")); } // encode messages while let Some(item) = self.fut.ctx().messages.pop_front() { if let Some(msg) = item { self.encoder.encode(msg, &mut self.buf)?; } else { self.closed = true; break; } } if !self.buf.is_empty() { Ok(Async::Ready(Some(self.buf.take().freeze()))) } else if self.fut.alive() && !self.closed { Ok(Async::NotReady) } else { Ok(Async::Ready(None)) } } } impl ToEnvelope for WebsocketContext where A: Actor> + Handler, M: ActixMessage + Send + 'static, M::Result: Send, { fn pack(msg: M, tx: Option>) -> Envelope { Envelope::new(msg, tx) } } struct WsStream { stream: S, decoder: Codec, buf: BytesMut, closed: bool, } impl WsStream where S: Stream, { fn new(stream: S, codec: Codec) -> Self { Self { stream, decoder: codec, buf: BytesMut::new(), closed: false, } } } impl Stream for WsStream where S: Stream, { type Item = Message; type Error = ProtocolError; fn poll(&mut self) -> Poll, Self::Error> { if !self.closed { loop { match self.stream.poll() { Ok(Async::Ready(Some(chunk))) => { self.buf.extend_from_slice(&chunk[..]); } Ok(Async::Ready(None)) => { self.closed = true; break; } Ok(Async::NotReady) => break, Err(e) => { return Err(ProtocolError::Io(io::Error::new( io::ErrorKind::Other, format!("{}", e), ))); } } } } match self.decoder.decode(&mut self.buf)? { None => { if self.closed { Ok(Async::Ready(None)) } else { Ok(Async::NotReady) } } Some(frm) => { let msg = match frm { Frame::Text(data) => { if let Some(data) = data { Message::Text( std::str::from_utf8(&data) .map_err(|_| ProtocolError::BadEncoding)? .to_string(), ) } else { Message::Text(String::new()) } } Frame::Binary(data) => Message::Binary( data.map(|b| b.freeze()).unwrap_or_else(|| Bytes::new()), ), Frame::Ping(s) => Message::Ping(s), Frame::Pong(s) => Message::Pong(s), Frame::Close(reason) => Message::Close(reason), }; Ok(Async::Ready(Some(msg))) } } } } #[cfg(test)] mod tests { use super::*; use actix_web::http::{header, Method}; use actix_web::test::TestRequest; #[test] fn test_handshake() { let req = TestRequest::default() .method(Method::POST) .to_http_request(); assert_eq!( HandshakeError::GetMethodRequired, handshake(&req).err().unwrap() ); let req = TestRequest::default().to_http_request(); assert_eq!( HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header(header::UPGRADE, header::HeaderValue::from_static("test")) .to_http_request(); assert_eq!( HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ) .to_http_request(); assert_eq!( HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ) .header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ) .to_http_request(); assert_eq!( HandshakeError::NoVersionHeader, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ) .header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ) .header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("5"), ) .to_http_request(); assert_eq!( HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ) .header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ) .header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), ) .to_http_request(); assert_eq!( HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap() ); let req = TestRequest::default() .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), ) .header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), ) .header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), ) .header( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), ) .to_http_request(); assert_eq!( StatusCode::SWITCHING_PROTOCOLS, handshake(&req).unwrap().finish().status() ); } }