diff --git a/actix-multipart/src/lib.rs b/actix-multipart/src/lib.rs index 853648beb..63c2e890f 100644 --- a/actix-multipart/src/lib.rs +++ b/actix-multipart/src/lib.rs @@ -60,6 +60,7 @@ extern crate self as actix_multipart; mod error; mod extractor; pub mod form; +pub(crate) mod safety; mod server; pub mod test; diff --git a/actix-multipart/src/safety.rs b/actix-multipart/src/safety.rs new file mode 100644 index 000000000..db6b3b18b --- /dev/null +++ b/actix-multipart/src/safety.rs @@ -0,0 +1,60 @@ +use std::{cell::Cell, marker::PhantomData, rc::Rc, task}; + +use local_waker::LocalWaker; + +/// Counter. It tracks of number of clones of payloads and give access to payload only to top most. +/// +/// - When dropped, parent task is awakened. This is to support the case where `Field` is dropped in +/// a separate task than `Multipart`. +/// - Assumes that parent owners don't move to different tasks; only the top-most is allowed to. +/// - If dropped and is not top most owner, is_clean flag is set to false. +#[derive(Debug)] +pub(crate) struct Safety { + task: LocalWaker, + level: usize, + payload: Rc>, + clean: Rc>, +} + +impl Safety { + pub(crate) fn new() -> Safety { + let payload = Rc::new(PhantomData); + Safety { + task: LocalWaker::new(), + level: Rc::strong_count(&payload), + clean: Rc::new(Cell::new(true)), + payload, + } + } + + pub(crate) fn current(&self) -> bool { + Rc::strong_count(&self.payload) == self.level && self.clean.get() + } + + pub(crate) fn is_clean(&self) -> bool { + self.clean.get() + } + + pub(crate) fn clone(&self, cx: &task::Context<'_>) -> Safety { + let payload = Rc::clone(&self.payload); + let s = Safety { + task: LocalWaker::new(), + level: Rc::strong_count(&payload), + clean: self.clean.clone(), + payload, + }; + s.task.register(cx.waker()); + s + } +} + +impl Drop for Safety { + fn drop(&mut self) { + if Rc::strong_count(&self.payload) != self.level { + // Multipart dropped leaving a Field + self.clean.set(false); + } + + self.task.wake(); + } +} diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index 76eae11ff..1a173f88a 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -1,9 +1,8 @@ //! Multipart response payload support. use std::{ - cell::{Cell, RefCell, RefMut}, + cell::{RefCell, RefMut}, cmp, fmt, - marker::PhantomData, pin::Pin, rc::Rc, task::{Context, Poll}, @@ -17,10 +16,9 @@ use actix_web::{ }; use bytes::{Bytes, BytesMut}; use futures_core::stream::{LocalBoxStream, Stream}; -use local_waker::LocalWaker; use mime::Mime; -use crate::error::MultipartError; +use crate::{error::MultipartError, safety::Safety}; const MAX_HEADERS: usize = 32; @@ -797,63 +795,6 @@ impl Clone for PayloadRef { } } -/// Counter. It tracks of number of clones of payloads and give access to payload only to top most. -/// -/// - When dropped, parent task is awakened. This is to support the case where `Field` is dropped in -/// a separate task than `Multipart`. -/// - Assumes that parent owners don't move to different tasks; only the top-most is allowed to. -/// - If dropped and is not top most owner, is_clean flag is set to false. -#[derive(Debug)] -struct Safety { - task: LocalWaker, - level: usize, - payload: Rc>, - clean: Rc>, -} - -impl Safety { - fn new() -> Safety { - let payload = Rc::new(PhantomData); - Safety { - task: LocalWaker::new(), - level: Rc::strong_count(&payload), - clean: Rc::new(Cell::new(true)), - payload, - } - } - - fn current(&self) -> bool { - Rc::strong_count(&self.payload) == self.level && self.clean.get() - } - - fn is_clean(&self) -> bool { - self.clean.get() - } - - fn clone(&self, cx: &Context<'_>) -> Safety { - let payload = Rc::clone(&self.payload); - let s = Safety { - task: LocalWaker::new(), - level: Rc::strong_count(&payload), - clean: self.clean.clone(), - payload, - }; - s.task.register(cx.waker()); - s - } -} - -impl Drop for Safety { - fn drop(&mut self) { - if Rc::strong_count(&self.payload) != self.level { - // Multipart dropped leaving a Field - self.clean.set(false); - } - - self.task.wake(); - } -} - /// Payload buffer. struct PayloadBuffer { eof: bool,