//! Payload stream use std::{ cell::RefCell, collections::VecDeque, pin::Pin, rc::{Rc, Weak}, task::{Context, Poll, Waker}, }; use bytes::Bytes; use futures_core::Stream; use crate::error::PayloadError; /// max buffer size 32k pub(crate) const MAX_BUFFER_SIZE: usize = 32_768; #[derive(Debug, PartialEq, Eq)] pub enum PayloadStatus { Read, Pause, Dropped, } /// Buffered stream of bytes chunks /// /// Payload stores chunks in a vector. First chunk can be received with `poll_next`. Payload does /// not notify current task when new data is available. /// /// Payload can be used as `Response` body stream. #[derive(Debug)] pub struct Payload { inner: Rc>, } impl Payload { /// Creates a payload stream. /// /// This method construct two objects responsible for bytes stream generation: /// - `PayloadSender` - *Sender* side of the stream /// - `Payload` - *Receiver* side of the stream pub fn create(eof: bool) -> (PayloadSender, Payload) { let shared = Rc::new(RefCell::new(Inner::new(eof))); ( PayloadSender::new(Rc::downgrade(&shared)), Payload { inner: shared }, ) } /// Creates an empty payload. pub(crate) fn empty() -> Payload { Payload { inner: Rc::new(RefCell::new(Inner::new(true))), } } /// Length of the data in this payload #[cfg(test)] pub fn len(&self) -> usize { self.inner.borrow().len() } /// Is payload empty #[cfg(test)] pub fn is_empty(&self) -> bool { self.inner.borrow().len() == 0 } /// Put unused data back to payload #[inline] pub fn unread_data(&mut self, data: Bytes) { self.inner.borrow_mut().unread_data(data); } } impl Stream for Payload { type Item = Result; fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx) } } /// Sender part of the payload stream pub struct PayloadSender { inner: Weak>, } impl PayloadSender { fn new(inner: Weak>) -> Self { Self { inner } } #[inline] pub fn set_error(&mut self, err: PayloadError) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().set_error(err) } } #[inline] pub fn feed_eof(&mut self) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().feed_eof() } } #[inline] pub fn feed_data(&mut self, data: Bytes) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().feed_data(data) } } #[allow(clippy::needless_pass_by_ref_mut)] #[inline] pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus { // we check need_read only if Payload (other side) is alive, // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { if shared.borrow().need_read { PayloadStatus::Read } else { shared.borrow_mut().register_io(cx); PayloadStatus::Pause } } else { PayloadStatus::Dropped } } } #[derive(Debug)] struct Inner { len: usize, eof: bool, err: Option, need_read: bool, items: VecDeque, task: Option, io_task: Option, } impl Inner { fn new(eof: bool) -> Self { Inner { eof, len: 0, err: None, items: VecDeque::new(), need_read: true, task: None, io_task: None, } } /// Wake up future waiting for payload data to be available. fn wake(&mut self) { if let Some(waker) = self.task.take() { waker.wake(); } } /// Wake up future feeding data to Payload. fn wake_io(&mut self) { if let Some(waker) = self.io_task.take() { waker.wake(); } } /// Register future waiting data from payload. /// Waker would be used in `Inner::wake` fn register(&mut self, cx: &Context<'_>) { if self .task .as_ref() .map_or(true, |w| !cx.waker().will_wake(w)) { self.task = Some(cx.waker().clone()); } } // Register future feeding data to payload. /// Waker would be used in `Inner::wake_io` fn register_io(&mut self, cx: &Context<'_>) { if self .io_task .as_ref() .map_or(true, |w| !cx.waker().will_wake(w)) { self.io_task = Some(cx.waker().clone()); } } #[inline] fn set_error(&mut self, err: PayloadError) { self.err = Some(err); self.wake(); } #[inline] fn feed_eof(&mut self) { self.eof = true; self.wake(); } #[inline] fn feed_data(&mut self, data: Bytes) { self.len += data.len(); self.items.push_back(data); self.need_read = self.len < MAX_BUFFER_SIZE; self.wake(); } #[cfg(test)] fn len(&self) -> usize { self.len } fn poll_next( mut self: Pin<&mut Self>, cx: &Context<'_>, ) -> Poll>> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); self.need_read = self.len < MAX_BUFFER_SIZE; if self.need_read && !self.eof { self.register(cx); } self.wake_io(); Poll::Ready(Some(Ok(data))) } else if let Some(err) = self.err.take() { Poll::Ready(Some(Err(err))) } else if self.eof { Poll::Ready(None) } else { self.need_read = true; self.register(cx); self.wake_io(); Poll::Pending } } fn unread_data(&mut self, data: Bytes) { self.len += data.len(); self.items.push_front(data); } } #[cfg(test)] mod tests { use std::{task::Poll, time::Duration}; use actix_rt::time::timeout; use actix_utils::future::poll_fn; use futures_util::{FutureExt, StreamExt}; use static_assertions::{assert_impl_all, assert_not_impl_any}; use tokio::sync::oneshot; use super::*; assert_impl_all!(Payload: Unpin); assert_not_impl_any!(Payload: Send, Sync); assert_impl_all!(Inner: Unpin, Send, Sync); const WAKE_TIMEOUT: Duration = Duration::from_secs(2); fn prepare_waking_test( mut payload: Payload, expected: Option>, ) -> (oneshot::Receiver<()>, actix_rt::task::JoinHandle<()>) { let (tx, rx) = oneshot::channel(); let handle = actix_rt::spawn(async move { // Make sure to poll once to set the waker poll_fn(|cx| { assert!(payload.poll_next_unpin(cx).is_pending()); Poll::Ready(()) }) .await; tx.send(()).unwrap(); // actix-rt is single-threaded, so this won't race with `rx.await` let mut pend_once = false; poll_fn(|_| { if pend_once { Poll::Ready(()) } else { // Return pending without storing wakers, we already did on the previous // `poll_fn`, now this task will only continue if the `sender` wakes us pend_once = true; Poll::Pending } }) .await; let got = payload.next().now_or_never().unwrap(); match expected { Some(Ok(_)) => assert!(got.unwrap().is_ok()), Some(Err(_)) => assert!(got.unwrap().is_err()), None => assert!(got.is_none()), } }); (rx, handle) } #[actix_rt::test] async fn wake_on_error() { let (mut sender, payload) = Payload::create(false); let (rx, handle) = prepare_waking_test(payload, Some(Err(()))); rx.await.unwrap(); sender.set_error(PayloadError::Incomplete(None)); timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap(); } #[actix_rt::test] async fn wake_on_eof() { let (mut sender, payload) = Payload::create(false); let (rx, handle) = prepare_waking_test(payload, None); rx.await.unwrap(); sender.feed_eof(); timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap(); } #[actix_rt::test] async fn test_unread_data() { let (_, mut payload) = Payload::create(false); payload.unread_data(Bytes::from("data")); assert!(!payload.is_empty()); assert_eq!(payload.len(), 4); assert_eq!( Bytes::from("data"), poll_fn(|cx| Pin::new(&mut payload).poll_next(cx)) .await .unwrap() .unwrap() ); } }