use std::{ cell::Cell, convert::Infallible, task::{Context, Poll}, }; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::{ body::{BodySize, BoxBody}, h1, ws::{self, CloseCode, Frame, Item, Message}, Error, HttpService, Request, Response, }; use actix_http_test::test_server; use actix_service::{fn_factory, Service}; use bytes::Bytes; use derive_more::{Display, Error, From}; use futures_core::future::LocalBoxFuture; use futures_util::{SinkExt as _, StreamExt as _}; #[derive(Clone)] struct WsService(Cell<bool>); impl WsService { fn new() -> Self { WsService(Cell::new(false)) } fn set_polled(&self) { self.0.set(true); } fn was_polled(&self) -> bool { self.0.get() } } #[derive(Debug, Display, Error, From)] enum WsServiceError { #[display(fmt = "HTTP error")] Http(actix_http::Error), #[display(fmt = "WS handshake error")] Ws(actix_http::ws::HandshakeError), #[display(fmt = "I/O error")] Io(std::io::Error), #[display(fmt = "dispatcher error")] Dispatcher, } impl From<WsServiceError> for Response<BoxBody> { fn from(err: WsServiceError) -> Self { match err { WsServiceError::Http(err) => err.into(), WsServiceError::Ws(err) => err.into(), WsServiceError::Io(_err) => unreachable!(), WsServiceError::Dispatcher => { Response::internal_server_error().set_body(BoxBody::new(format!("{}", err))) } } } } impl<T> Service<(Request, Framed<T, h1::Codec>)> for WsService where T: AsyncRead + AsyncWrite + Unpin + 'static, { type Response = (); type Error = WsServiceError; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { self.set_polled(); Poll::Ready(Ok(())) } fn call(&self, (req, mut framed): (Request, Framed<T, h1::Codec>)) -> Self::Future { assert!(self.was_polled()); Box::pin(async move { let res = ws::handshake(req.head())?.message_body(())?; framed.send((res, BodySize::None).into()).await?; let framed = framed.replace_codec(ws::Codec::new()); ws::Dispatcher::with(framed, service) .await .map_err(|_| WsServiceError::Dispatcher)?; Ok(()) }) } } async fn service(msg: Frame) -> Result<Message, Error> { let msg = match msg { Frame::Ping(msg) => Message::Pong(msg), Frame::Text(text) => Message::Text(String::from_utf8_lossy(&text).into_owned().into()), Frame::Binary(bin) => Message::Binary(bin), Frame::Continuation(item) => Message::Continuation(item), Frame::Close(reason) => Message::Close(reason), _ => return Err(ws::ProtocolError::BadOpCode.into()), }; Ok(msg) } #[actix_rt::test] async fn simple() { let mut srv = test_server(|| { HttpService::build() .upgrade(fn_factory(|| async { Ok::<_, Infallible>(WsService::new()) })) .finish(|_| async { Ok::<_, Infallible>(Response::not_found()) }) .tcp() }) .await; // client service let mut framed = srv.ws().await.unwrap(); framed.send(Message::Text("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, Frame::Text(Bytes::from_static(b"text"))); framed.send(Message::Binary("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, Frame::Binary(Bytes::from_static(&b"text"[..]))); framed.send(Message::Ping("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, Frame::Pong("text".to_string().into())); framed .send(Message::Continuation(Item::FirstText("text".into()))) .await .unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!( item, Frame::Continuation(Item::FirstText(Bytes::from_static(b"text"))) ); assert!(framed .send(Message::Continuation(Item::FirstText("text".into()))) .await .is_err()); assert!(framed .send(Message::Continuation(Item::FirstBinary("text".into()))) .await .is_err()); framed .send(Message::Continuation(Item::Continue("text".into()))) .await .unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!( item, Frame::Continuation(Item::Continue(Bytes::from_static(b"text"))) ); framed .send(Message::Continuation(Item::Last("text".into()))) .await .unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!( item, Frame::Continuation(Item::Last(Bytes::from_static(b"text"))) ); assert!(framed .send(Message::Continuation(Item::Continue("text".into()))) .await .is_err()); assert!(framed .send(Message::Continuation(Item::Last("text".into()))) .await .is_err()); framed .send(Message::Close(Some(CloseCode::Normal.into()))) .await .unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, Frame::Close(Some(CloseCode::Normal.into()))); }