diff --git a/actix-utils/CHANGES.md b/actix-utils/CHANGES.md index 69f9e514..3d397e9d 100644 --- a/actix-utils/CHANGES.md +++ b/actix-utils/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [1.0.5] - 2020-01-08 + +* Add `Condition` type. + +* Add `Pool` of one-shot's. + ## [1.0.4] - 2019-12-20 * Add methods to check `LocalWaker` registration state. diff --git a/actix-utils/Cargo.toml b/actix-utils/Cargo.toml index 9b942062..5470ef01 100644 --- a/actix-utils/Cargo.toml +++ b/actix-utils/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-utils" -version = "1.0.4" +version = "1.0.5" authors = ["Nikolay Kim "] description = "Actix utils - various actix net related services" keywords = ["network", "framework", "async", "futures"] @@ -16,11 +16,13 @@ name = "actix_utils" path = "src/lib.rs" [dependencies] -actix-service = "1.0.0" +actix-service = "1.0.1" actix-rt = "1.0.0" actix-codec = "0.2.0" +bitflags = "1.2" bytes = "0.5.3" either = "1.5.3" futures = "0.3.1" pin-project = "0.4.6" log = "0.4" +slab = "0.4" diff --git a/actix-utils/src/condition.rs b/actix-utils/src/condition.rs new file mode 100644 index 00000000..4d32f010 --- /dev/null +++ b/actix-utils/src/condition.rs @@ -0,0 +1,104 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use slab::Slab; + +use crate::cell::Cell; +use crate::task::LocalWaker; + +/// Condition allows to notify multiple receivers at the same time +pub struct Condition(Cell); + +struct Inner { + data: Slab>, +} + +impl Condition { + pub fn new() -> Condition { + Condition(Cell::new(Inner { data: Slab::new() })) + } + + /// Get condition waiter + pub fn wait(&mut self) -> Waiter { + let token = self.0.get_mut().data.insert(None); + Waiter { + token, + inner: self.0.clone(), + } + } + + /// Notify all waiters + pub fn notify(&self) { + let inner = self.0.get_ref(); + for item in inner.data.iter() { + if let Some(waker) = item.1 { + waker.wake(); + } + } + } +} + +impl Drop for Condition { + fn drop(&mut self) { + self.notify() + } +} + +#[must_use = "Waiter do nothing unless polled"] +pub struct Waiter { + token: usize, + inner: Cell, +} + +impl Future for Waiter { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let inner = unsafe { this.inner.get_mut().data.get_unchecked_mut(this.token) }; + if inner.is_none() { + let waker = LocalWaker::default(); + waker.register(cx.waker()); + *inner = Some(waker); + Poll::Pending + } else if inner.as_mut().unwrap().register(cx.waker()) { + Poll::Pending + } else { + Poll::Ready(()) + } + } +} + +impl Drop for Waiter { + fn drop(&mut self) { + self.inner.get_mut().data.remove(self.token); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::future::lazy; + + #[actix_rt::test] + async fn test_condition() { + let mut cond = Condition::new(); + let mut waiter = cond.wait(); + assert_eq!( + lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, + Poll::Pending + ); + cond.notify(); + assert_eq!(waiter.await, ()); + + let mut waiter = cond.wait(); + assert_eq!( + lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, + Poll::Pending + ); + drop(cond); + assert_eq!(waiter.await, ()); + } +} diff --git a/actix-utils/src/lib.rs b/actix-utils/src/lib.rs index 78467e9c..c4d56c56 100644 --- a/actix-utils/src/lib.rs +++ b/actix-utils/src/lib.rs @@ -3,6 +3,7 @@ #![allow(clippy::type_complexity)] mod cell; +pub mod condition; pub mod counter; pub mod either; pub mod framed; diff --git a/actix-utils/src/oneshot.rs b/actix-utils/src/oneshot.rs index 9ec5f218..533167c9 100644 --- a/actix-utils/src/oneshot.rs +++ b/actix-utils/src/oneshot.rs @@ -4,6 +4,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; pub use futures::channel::oneshot::Canceled; +use slab::Slab; use crate::cell::Cell; use crate::task::LocalWaker; @@ -21,6 +22,11 @@ pub fn channel() -> (Sender, Receiver) { (tx, rx) } +/// Creates a new futures-aware, pool of one-shot's. +pub fn pool() -> Pool { + Pool(Cell::new(Slab::new())) +} + /// Represents the completion half of a oneshot through which the result of a /// computation is signaled. #[derive(Debug)] @@ -77,9 +83,7 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { - if self.inner.strong_count() == 2 { - self.inner.get_ref().rx_task.wake(); - }; + self.inner.get_ref().rx_task.wake(); } } @@ -104,6 +108,148 @@ impl Future for Receiver { } } +/// Futures-aware, pool of one-shot's. +pub struct Pool(Cell>>); + +bitflags::bitflags! { + pub struct Flags: u8 { + const SENDER = 0b0000_0001; + const RECEIVER = 0b0000_0010; + } +} + +#[derive(Debug)] +struct PoolInner { + flags: Flags, + value: Option, + waker: LocalWaker, +} + +impl Pool { + pub fn channel(&mut self) -> (PSender, PReceiver) { + let token = self.0.get_mut().insert(PoolInner { + flags: Flags::all(), + value: None, + waker: LocalWaker::default(), + }); + + ( + PSender { + token, + inner: self.0.clone(), + }, + PReceiver { + token, + inner: self.0.clone(), + }, + ) + } +} + +impl Clone for Pool { + fn clone(&self) -> Self { + Pool(self.0.clone()) + } +} + +/// Represents the completion half of a oneshot through which the result of a +/// computation is signaled. +#[derive(Debug)] +pub struct PSender { + token: usize, + inner: Cell>>, +} + +/// A future representing the completion of a computation happening elsewhere in +/// memory. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct PReceiver { + token: usize, + inner: Cell>>, +} + +// The oneshots do not ever project Pin to the inner T +impl Unpin for PReceiver {} +impl Unpin for PSender {} + +impl PSender { + /// Completes this oneshot with a successful result. + /// + /// This function will consume `self` and indicate to the other end, the + /// `Receiver`, that the error provided is the result of the computation this + /// represents. + /// + /// If the value is successfully enqueued for the remote end to receive, + /// then `Ok(())` is returned. If the receiving end was dropped before + /// this function was called, however, then `Err` is returned with the value + /// provided. + pub fn send(mut self, val: T) -> Result<(), T> { + let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) }; + + if inner.flags.contains(Flags::RECEIVER) { + inner.value = Some(val); + inner.waker.wake(); + Ok(()) + } else { + Err(val) + } + } + + /// Tests to see whether this `Sender`'s corresponding `Receiver` + /// has gone away. + pub fn is_canceled(&self) -> bool { + !unsafe { self.inner.get_ref().get_unchecked(self.token) } + .flags + .contains(Flags::RECEIVER) + } +} + +impl Drop for PSender { + fn drop(&mut self) { + let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) }; + if inner.flags.contains(Flags::RECEIVER) { + inner.waker.wake(); + inner.flags.remove(Flags::SENDER); + } else { + self.inner.get_mut().remove(self.token); + } + } +} + +impl Drop for PReceiver { + fn drop(&mut self) { + let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) }; + if inner.flags.contains(Flags::SENDER) { + inner.flags.remove(Flags::RECEIVER); + } else { + self.inner.get_mut().remove(self.token); + } + } +} + +impl Future for PReceiver { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let inner = unsafe { this.inner.get_mut().get_unchecked_mut(this.token) }; + + // If we've got a value, then skip the logic below as we're done. + if let Some(val) = inner.value.take() { + return Poll::Ready(Ok(val)); + } + + // Check if sender is dropped and return error if it is. + if !inner.flags.contains(Flags::SENDER) { + Poll::Ready(Err(Canceled)) + } else { + inner.waker.register(cx.waker()); + Poll::Pending + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -135,4 +281,31 @@ mod tests { drop(tx); assert!(rx.await.is_err()); } + + #[actix_rt::test] + async fn test_pool() { + let (tx, rx) = pool().channel(); + tx.send("test").unwrap(); + assert_eq!(rx.await.unwrap(), "test"); + + let (tx, rx) = pool().channel(); + assert!(!tx.is_canceled()); + drop(rx); + assert!(tx.is_canceled()); + assert!(tx.send("test").is_err()); + + let (tx, rx) = pool::<&'static str>().channel(); + drop(tx); + assert!(rx.await.is_err()); + + let (tx, mut rx) = pool::<&'static str>().channel(); + assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending); + tx.send("test").unwrap(); + assert_eq!(rx.await.unwrap(), "test"); + + let (tx, mut rx) = pool::<&'static str>().channel(); + assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending); + drop(tx); + assert!(rx.await.is_err()); + } }