From fd3e351c318c31e98dbd0354afb0cba654b0cdbf Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 17 Mar 2019 22:02:03 -0700 Subject: [PATCH] add websockets context --- Cargo.toml | 4 +- actix-web-actors/Cargo.toml | 7 +- actix-web-actors/src/context.rs | 3 +- actix-web-actors/src/lib.rs | 6 + actix-web-actors/src/ws.rs | 408 ++++++++++++++++++++++++++++++ actix-web-actors/tests/test_ws.rs | 67 +++++ src/lib.rs | 2 +- 7 files changed, 490 insertions(+), 7 deletions(-) create mode 100644 actix-web-actors/src/ws.rs create mode 100644 actix-web-actors/tests/test_ws.rs diff --git a/Cargo.toml b/Cargo.toml index df06f3a1..d98c926b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,14 +61,14 @@ ssl = ["openssl", "actix-server/ssl"] # rust-tls = ["rustls", "actix-server/rustls"] [dependencies] -actix-codec = "0.1.0" +actix-codec = "0.1.1" actix-service = "0.3.4" actix-utils = "0.3.4" actix-router = "0.1.0" actix-rt = "0.2.1" actix-web-codegen = { path="actix-web-codegen" } actix-http = { git = "https://github.com/actix/actix-http.git" } -actix-server = "0.4.0" +actix-server = "0.4.1" actix-server-config = "0.1.0" bytes = "0.4" diff --git a/actix-web-actors/Cargo.toml b/actix-web-actors/Cargo.toml index 698acc1c..db42a1a2 100644 --- a/actix-web-actors/Cargo.toml +++ b/actix-web-actors/Cargo.toml @@ -20,9 +20,12 @@ path = "src/lib.rs" [dependencies] actix-web = { path=".." } actix = { git = "https://github.com/actix/actix.git" } - +actix-http = { git = "https://github.com/actix/actix-http.git" } +actix-codec = "0.1.1" bytes = "0.4" futures = "0.1" [dev-dependencies] -actix-rt = "0.2.0" +env_logger = "0.6" +actix-http = { git = "https://github.com/actix/actix-http.git", features=["ssl"] } +actix-http-test = { git = "https://github.com/actix/actix-http.git", features=["ssl"] } diff --git a/actix-web-actors/src/context.rs b/actix-web-actors/src/context.rs index 646698ef..da473ff3 100644 --- a/actix-web-actors/src/context.rs +++ b/actix-web-actors/src/context.rs @@ -198,10 +198,9 @@ mod tests { use std::time::Duration; use actix::Actor; - use actix_web::dev::HttpMessageBody; use actix_web::http::StatusCode; use actix_web::test::{block_on, call_success, init_service, TestRequest}; - use actix_web::{web, App, HttpRequest, HttpResponse}; + use actix_web::{web, App, HttpResponse}; use bytes::{Bytes, BytesMut}; use super::*; diff --git a/actix-web-actors/src/lib.rs b/actix-web-actors/src/lib.rs index 47adf5f0..0d447865 100644 --- a/actix-web-actors/src/lib.rs +++ b/actix-web-actors/src/lib.rs @@ -1,4 +1,10 @@ //! Actix actors integration for Actix web framework mod context; +mod ws; pub use self::context::HttpContext; +pub use self::ws::{ws_handshake, ws_start, WebsocketContext}; + +pub use actix_http::ws::CloseCode as WsCloseCode; +pub use actix_http::ws::ProtocolError as WsProtocolError; +pub use actix_http::ws::{Frame as WsFrame, Message as WsMessage}; diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs new file mode 100644 index 00000000..dca0f46d --- /dev/null +++ b/actix-web-actors/src/ws.rs @@ -0,0 +1,408 @@ +use std::collections::VecDeque; +use std::io; + +use actix::dev::{ + AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, + ToEnvelope, +}; +use actix::fut::ActorFuture; +use actix::{ + Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, + Message as ActixMessage, SpawnHandle, +}; +use actix_codec::{Decoder, Encoder}; +use actix_http::ws::{ + hash_key, CloseReason, Codec, Frame, HandshakeError, Message, ProtocolError, +}; +use actix_web::dev::{Head, HttpResponseBuilder}; +use actix_web::error::{Error, ErrorInternalServerError, PayloadError}; +use actix_web::http::{header, Method, StatusCode}; +use actix_web::{HttpMessage, HttpRequest, HttpResponse}; +use bytes::{Bytes, BytesMut}; +use futures::sync::oneshot::Sender; +use futures::{Async, Future, Poll, Stream}; + +/// Do websocket handshake and start ws actor. +pub fn ws_start( + actor: A, + req: &HttpRequest, + stream: T, +) -> Result +where + A: Actor> + StreamHandler, + T: Stream + 'static, +{ + let mut res = ws_handshake(req)?; + Ok(res.streaming(WebsocketContext::create(actor, stream))) +} + +/// Prepare `WebSocket` handshake response. +/// +/// This function returns handshake `HttpResponse`, ready to send to peer. +/// It does not perform any IO. +/// +// /// `protocols` is a sequence of known protocols. On successful handshake, +// /// the returned response headers contain the first protocol in this list +// /// which the server also knows. +pub fn ws_handshake(req: &HttpRequest) -> Result { + // WebSocket accepts only GET + if *req.method() != Method::GET { + return Err(HandshakeError::GetMethodRequired); + } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + return Err(HandshakeError::NoWebsocketUpgrade); + } + + // Upgrade connection + if !req.upgrade() { + return Err(HandshakeError::NoConnectionUpgrade); + } + + // check supported version + if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) { + return Err(HandshakeError::NoVersionHeader); + } + let supported_ver = { + if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) { + hdr == "13" || hdr == "8" || hdr == "7" + } else { + false + } + }; + if !supported_ver { + return Err(HandshakeError::UnsupportedVersion); + } + + // check client handshake for validity + if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { + return Err(HandshakeError::BadWebsocketKey); + } + let key = { + let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); + hash_key(key.as_ref()) + }; + + Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, "upgrade") + .header(header::UPGRADE, "websocket") + .header(header::TRANSFER_ENCODING, "chunked") + .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) + .take()) +} + +/// Execution context for `WebSockets` actors +pub struct WebsocketContext +where + A: Actor>, +{ + inner: ContextParts, + messages: VecDeque>, +} + +impl ActorContext for WebsocketContext +where + A: Actor, +{ + fn stop(&mut self) { + self.inner.stop(); + } + + fn terminate(&mut self) { + self.inner.terminate() + } + + fn state(&self) -> ActorState { + self.inner.state() + } +} + +impl AsyncContext for WebsocketContext +where + A: Actor, +{ + fn spawn(&mut self, fut: F) -> SpawnHandle + where + F: ActorFuture + 'static, + { + self.inner.spawn(fut) + } + + fn wait(&mut self, fut: F) + where + F: ActorFuture + 'static, + { + self.inner.wait(fut) + } + + #[doc(hidden)] + #[inline] + fn waiting(&self) -> bool { + self.inner.waiting() + || self.inner.state() == ActorState::Stopping + || self.inner.state() == ActorState::Stopped + } + + fn cancel_future(&mut self, handle: SpawnHandle) -> bool { + self.inner.cancel_future(handle) + } + + #[inline] + fn address(&self) -> Addr { + self.inner.address() + } +} + +impl WebsocketContext +where + A: Actor, +{ + #[inline] + /// Create a new Websocket context from a request and an actor + pub fn create(actor: A, stream: S) -> impl Stream + where + A: StreamHandler, + S: Stream + 'static, + { + let mb = Mailbox::default(); + let mut ctx = WebsocketContext { + inner: ContextParts::new(mb.sender_producer()), + messages: VecDeque::new(), + }; + ctx.add_stream(WsStream::new(stream)); + + WebsocketContextFut::new(ctx, actor, mb) + } + + /// Create a new Websocket context + pub fn with_factory( + stream: S, + f: F, + ) -> impl Stream + where + F: FnOnce(&mut Self) -> A + 'static, + A: StreamHandler, + S: Stream + 'static, + { + let mb = Mailbox::default(); + let mut ctx = WebsocketContext { + inner: ContextParts::new(mb.sender_producer()), + messages: VecDeque::new(), + }; + ctx.add_stream(WsStream::new(stream)); + + let act = f(&mut ctx); + + WebsocketContextFut::new(ctx, act, mb) + } +} + +impl WebsocketContext +where + A: Actor, +{ + /// Write payload + /// + /// This is a low-level function that accepts framed messages that should + /// be created using `Frame::message()`. If you want to send text or binary + /// data you should prefer the `text()` or `binary()` convenience functions + /// that handle the framing for you. + #[inline] + pub fn write_raw(&mut self, msg: Message) { + self.messages.push_back(Some(msg)); + } + + /// Send text frame + #[inline] + pub fn text>(&mut self, text: T) { + self.write_raw(Message::Text(text.into())); + } + + /// Send binary frame + #[inline] + pub fn binary>(&mut self, data: B) { + self.write_raw(Message::Binary(data.into())); + } + + /// Send ping frame + #[inline] + pub fn ping(&mut self, message: &str) { + self.write_raw(Message::Ping(message.to_string())); + } + + /// Send pong frame + #[inline] + pub fn pong(&mut self, message: &str) { + self.write_raw(Message::Pong(message.to_string())); + } + + /// Send close frame + #[inline] + pub fn close(&mut self, reason: Option) { + self.write_raw(Message::Close(reason)); + } + + /// Handle of the running future + /// + /// SpawnHandle is the handle returned by `AsyncContext::spawn()` method. + pub fn handle(&self) -> SpawnHandle { + self.inner.curr_handle() + } + + /// Set mailbox capacity + /// + /// By default mailbox capacity is 16 messages. + pub fn set_mailbox_capacity(&mut self, cap: usize) { + self.inner.set_mailbox_capacity(cap) + } +} + +impl AsyncContextParts for WebsocketContext +where + A: Actor, +{ + fn parts(&mut self) -> &mut ContextParts { + &mut self.inner + } +} + +struct WebsocketContextFut +where + A: Actor>, +{ + fut: ContextFut>, + encoder: Codec, + buf: BytesMut, + closed: bool, +} + +impl WebsocketContextFut +where + A: Actor>, +{ + fn new(ctx: WebsocketContext, act: A, mailbox: Mailbox) -> Self { + let fut = ContextFut::new(ctx, act, mailbox); + WebsocketContextFut { + fut, + encoder: Codec::new(), + buf: BytesMut::new(), + closed: false, + } + } +} + +impl Stream for WebsocketContextFut +where + A: Actor>, +{ + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + if self.fut.alive() && self.fut.poll().is_err() { + return Err(ErrorInternalServerError("error")); + } + + // encode messages + while let Some(item) = self.fut.ctx().messages.pop_front() { + if let Some(msg) = item { + self.encoder.encode(msg, &mut self.buf)?; + } else { + self.closed = true; + break; + } + } + + if !self.buf.is_empty() { + Ok(Async::Ready(Some(self.buf.take().freeze()))) + } else if self.fut.alive() && !self.closed { + Ok(Async::NotReady) + } else { + Ok(Async::Ready(None)) + } + } +} + +impl ToEnvelope for WebsocketContext +where + A: Actor> + Handler, + M: ActixMessage + Send + 'static, + M::Result: Send, +{ + fn pack(msg: M, tx: Option>) -> Envelope { + Envelope::new(msg, tx) + } +} + +struct WsStream { + stream: S, + decoder: Codec, + buf: BytesMut, + closed: bool, +} + +impl WsStream +where + S: Stream, +{ + fn new(stream: S) -> Self { + Self { + stream, + decoder: Codec::new(), + buf: BytesMut::new(), + closed: false, + } + } +} + +impl Stream for WsStream +where + S: Stream, +{ + type Item = Frame; + type Error = ProtocolError; + + fn poll(&mut self) -> Poll, Self::Error> { + if !self.closed { + loop { + match self.stream.poll() { + Ok(Async::Ready(Some(chunk))) => { + self.buf.extend_from_slice(&chunk[..]); + } + Ok(Async::Ready(None)) => { + self.closed = true; + break; + } + Ok(Async::NotReady) => break, + Err(e) => { + return Err(ProtocolError::Io(io::Error::new( + io::ErrorKind::Other, + format!("{}", e), + ))); + } + } + } + } + + match self.decoder.decode(&mut self.buf)? { + None => { + if self.closed { + Ok(Async::Ready(None)) + } else { + Ok(Async::NotReady) + } + } + Some(frm) => Ok(Async::Ready(Some(frm))), + } + } +} diff --git a/actix-web-actors/tests/test_ws.rs b/actix-web-actors/tests/test_ws.rs new file mode 100644 index 00000000..0a214644 --- /dev/null +++ b/actix-web-actors/tests/test_ws.rs @@ -0,0 +1,67 @@ +use actix::prelude::*; +use actix_http::HttpService; +use actix_http_test::TestServer; +use actix_web::{web, App, HttpRequest}; +use actix_web_actors::*; +use bytes::{Bytes, BytesMut}; +use futures::{Sink, Stream}; + +struct Ws; + +impl Actor for Ws { + type Context = WebsocketContext; +} + +impl StreamHandler for Ws { + fn handle(&mut self, msg: WsFrame, ctx: &mut Self::Context) { + match msg { + WsFrame::Ping(msg) => ctx.pong(&msg), + WsFrame::Text(text) => { + ctx.text(String::from_utf8_lossy(&text.unwrap())).to_owned() + } + WsFrame::Binary(bin) => ctx.binary(bin.unwrap()), + WsFrame::Close(reason) => ctx.close(reason), + _ => (), + } + } +} + +#[test] +fn test_simple() { + let mut srv = + TestServer::new(|| { + HttpService::new(App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload<_>| ws_start(Ws, &req, stream), + ))) + }); + + // client service + let framed = srv.ws().unwrap(); + let framed = srv + .block_on(framed.send(WsMessage::Text("text".to_string()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(WsFrame::Text(Some(BytesMut::from("text"))))); + + let framed = srv + .block_on(framed.send(WsMessage::Binary("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(WsFrame::Binary(Some(Bytes::from_static(b"text").into()))) + ); + + let framed = srv + .block_on(framed.send(WsMessage::Ping("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(WsFrame::Pong("text".to_string().into()))); + + let framed = srv + .block_on(framed.send(WsMessage::Close(Some(WsCloseCode::Normal.into())))) + .unwrap(); + + let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(WsFrame::Close(Some(WsCloseCode::Normal.into())))); +} diff --git a/src/lib.rs b/src/lib.rs index dbf8f743..f59cbc26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,7 +66,7 @@ pub mod dev { pub use actix_http::body::{Body, BodyLength, MessageBody, ResponseBody}; pub use actix_http::ResponseBuilder as HttpResponseBuilder; pub use actix_http::{ - Extensions, Payload, PayloadStream, RequestHead, ResponseHead, + Extensions, Head, Payload, PayloadStream, RequestHead, ResponseHead, }; pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; pub use actix_server::Server;