use std::{
collections::VecDeque,
future::poll_fn,
io,
pin::Pin,
task::{Context, Poll},
};
use actix_codec::{Decoder, Encoder};
use actix_http::{
ws::{Codec, Frame, Message, ProtocolError},
Payload,
};
use actix_web::{
web::{Bytes, BytesMut},
Error,
};
use futures_core::stream::Stream;
use tokio::sync::mpsc::Receiver;
/// Stream of encoded WebSocket frame bytes, fed by [`Message`]s received on a
/// session channel.
///
/// Each poll drains the channel into `messages`, encodes them with `codec` into
/// `buf`, and yields the accumulated bytes as one [`Bytes`] chunk.
pub struct StreamingBody {
    /// Receiving end of the channel that outgoing messages arrive on.
    session_rx: Receiver<Message>,
    /// Messages received from `session_rx` that have not been encoded yet.
    messages: VecDeque<Message>,
    /// Encoded frame bytes waiting to be yielded as the next chunk.
    buf: BytesMut,
    /// WebSocket codec used to encode messages into frames.
    codec: Codec,
    /// Set once `session_rx` has been observed closed (`poll_recv` returned `None`).
    closing: bool,
}
/// Stream of WebSocket [`Message`]s decoded from an incoming [`Payload`].
///
/// Raw payload bytes are accumulated in `buf` and decoded into frames with `codec`.
pub struct MessageStream {
    /// Byte stream the WebSocket frames are read from.
    payload: Payload,
    /// Decoded messages waiting to be yielded.
    messages: VecDeque<Message>,
    /// Bytes read from `payload` that have not been decoded into frames yet.
    buf: BytesMut,
    /// WebSocket codec used to decode frames.
    codec: Codec,
    /// Set once `payload` has reported end-of-stream.
    closing: bool,
}
impl StreamingBody {
    /// Builds a body stream that encodes every message received on `session_rx`.
    ///
    /// The stream starts with an empty message queue, an empty output buffer,
    /// a freshly constructed codec, and is not yet closing.
    pub(super) fn new(session_rx: Receiver<Message>) -> Self {
        let messages = VecDeque::new();
        let buf = BytesMut::new();

        Self {
            session_rx,
            messages,
            buf,
            codec: Codec::new(),
            closing: false,
        }
    }
}
impl MessageStream {
pub(super) fn new(payload: Payload) -> Self {
MessageStream {
payload,
messages: VecDeque::new(),
buf: BytesMut::new(),
codec: Codec::new(),
closing: false,
}
}
pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
}
}
impl Stream for StreamingBody {
    type Item = Result<Bytes, Error>;

    /// Yields the next chunk of encoded WebSocket frames.
    ///
    /// Every message currently available on the session channel is encoded into
    /// a single chunk. After the channel closes, any queued or already-encoded
    /// data is still flushed. Only then does the stream yield `Ready(None)`.
    ///
    /// # Errors
    ///
    /// Yields `Err` if the codec fails to encode a message. Messages queued
    /// behind the failing one are encoded on the next poll.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Drain the channel until it is empty or closed. A `Pending` result
        // registers the waker so we are polled again when a new message arrives.
        if !this.closing {
            loop {
                match Pin::new(&mut this.session_rx).poll_recv(cx) {
                    Poll::Ready(Some(msg)) => this.messages.push_back(msg),
                    Poll::Ready(None) => {
                        this.closing = true;
                        break;
                    }
                    Poll::Pending => break,
                }
            }
        }

        while let Some(msg) = this.messages.pop_front() {
            if let Err(e) = this.codec.encode(msg, &mut this.buf) {
                return Poll::Ready(Some(Err(e.into())));
            }
        }

        if !this.buf.is_empty() {
            return Poll::Ready(Some(Ok(std::mem::take(&mut this.buf).freeze())));
        }

        // Once the channel is closed and everything has been flushed, end the
        // body. Returning `Pending` here would stall forever: the last
        // `poll_recv` returned `Ready(None)`, not `Pending`, so no waker is
        // registered to poll us again.
        if this.closing {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }
}
impl Stream for MessageStream {
type Item = Result<Message, ProtocolError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if let Some(msg) = this.messages.pop_front() {
return Poll::Ready(Some(Ok(msg)));
}
if !this.closing {
loop {
match Pin::new(&mut this.payload).poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.buf.extend_from_slice(&bytes);
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new(
io::ErrorKind::Other,
e.to_string(),
)))));
}
Poll::Ready(None) => {
this.closing = true;
break;
}
Poll::Pending => break,
}
}
}
while let Some(frame) = this.codec.decode(&mut this.buf)? {
let message = match frame {
Frame::Text(bytes) => {
let s = std::str::from_utf8(&bytes)
.map_err(|e| {
ProtocolError::Io(io::Error::new(io::ErrorKind::Other, e.to_string()))
})?
.to_string();
Message::Text(s.into())
}
Frame::Binary(bytes) => Message::Binary(bytes),
Frame::Ping(bytes) => Message::Ping(bytes),
Frame::Pong(bytes) => Message::Pong(bytes),
Frame::Close(reason) => Message::Close(reason),
Frame::Continuation(item) => Message::Continuation(item),
};
this.messages.push_back(message);
}
if let Some(msg) = this.messages.pop_front() {
return Poll::Ready(Some(Ok(msg)));
}
if this.closing {
return Poll::Ready(None);
}
Poll::Pending
}
}