use super::AbortOnDropHandle; use std::{ collections::VecDeque, future::Future, pin::Pin, task::{Context, Poll}, }; use tokio::{ runtime::Handle, task::{AbortHandle, Id, JoinError, JoinHandle}, }; /// A FIFO queue for of tasks spawned on a Tokio runtime. /// /// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO /// order. That is, if tasks are spawned in the order A, B, C, then /// awaiting the next completed task will always return A first, then B, /// then C, regardless of the order in which the tasks actually complete. /// /// All of the tasks must have the same return type `T`. /// /// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are /// immediately aborted. pub struct JoinQueue(VecDeque>); impl JoinQueue { /// Create a new empty [`JoinQueue`]. pub const fn new() -> Self { Self(VecDeque::new()) } /// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks. pub fn with_capacity(capacity: usize) -> Self { Self(VecDeque::with_capacity(capacity)) } /// Returns the number of tasks currently in the [`JoinQueue`]. /// /// This includes both tasks that are currently running and tasks that have /// completed but not yet been removed from the queue because outputting of /// them waits for FIFO order. pub fn len(&self) -> usize { self.0.len() } /// Returns whether the [`JoinQueue`] is empty. pub fn is_empty(&self) -> bool { self.0.is_empty() } /// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`] /// that can be used to remotely cancel the task. /// /// The provided future will start running in the background immediately /// when this method is called, even if you don't await anything on this /// [`JoinQueue`]. /// /// # Panics /// /// This method panics if called outside of a Tokio runtime. /// /// [`AbortHandle`]: tokio::task::AbortHandle #[track_caller] pub fn spawn(&mut self, task: F) -> AbortHandle where F: Future + Send + 'static, T: Send + 'static, { self.push_back(tokio::spawn(task)) } /// Spawn the provided task on the provided runtime and store it in this /// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely /// cancel the task. /// /// The provided future will start running in the background immediately /// when this method is called, even if you don't await anything on this /// [`JoinQueue`]. /// /// [`AbortHandle`]: tokio::task::AbortHandle #[track_caller] pub fn spawn_on(&mut self, task: F, handle: &Handle) -> AbortHandle where F: Future + Send + 'static, T: Send + 'static, { self.push_back(handle.spawn(task)) } /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`] /// and store it in this [`JoinQueue`], returning an [`AbortHandle`] that /// can be used to remotely cancel the task. /// /// The provided future will start running in the background immediately /// when this method is called, even if you don't await anything on this /// [`JoinQueue`]. /// /// # Panics /// /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`. /// /// [`LocalSet`]: tokio::task::LocalSet /// [`LocalRuntime`]: tokio::runtime::LocalRuntime /// [`AbortHandle`]: tokio::task::AbortHandle #[track_caller] pub fn spawn_local(&mut self, task: F) -> AbortHandle where F: Future + 'static, T: 'static, { self.push_back(tokio::task::spawn_local(task)) } /// Spawn the blocking code on the blocking threadpool and store /// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be /// used to remotely cancel the task. /// /// # Panics /// /// This method panics if called outside of a Tokio runtime. /// /// [`AbortHandle`]: tokio::task::AbortHandle #[track_caller] pub fn spawn_blocking(&mut self, f: F) -> AbortHandle where F: FnOnce() -> T + Send + 'static, T: Send + 'static, { self.push_back(tokio::task::spawn_blocking(f)) } /// Spawn the blocking code on the blocking threadpool of the /// provided runtime and store it in this [`JoinQueue`], returning an /// [`AbortHandle`] that can be used to remotely cancel the task. /// /// [`AbortHandle`]: tokio::task::AbortHandle #[track_caller] pub fn spawn_blocking_on(&mut self, f: F, handle: &Handle) -> AbortHandle where F: FnOnce() -> T + Send + 'static, T: Send + 'static, { self.push_back(handle.spawn_blocking(f)) } fn push_back(&mut self, jh: JoinHandle) -> AbortHandle { let jh = AbortOnDropHandle::new(jh); let abort_handle = jh.abort_handle(); self.0.push_back(jh); abort_handle } /// Waits until the next task in FIFO order completes and returns its output. /// /// Returns `None` if the queue is empty. /// /// # Cancel Safety /// /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!` /// statement and some other branch completes first, it is guaranteed that no tasks were /// removed from this [`JoinQueue`]. pub async fn join_next(&mut self) -> Option> { std::future::poll_fn(|cx| self.poll_join_next(cx)).await } /// Waits until the next task in FIFO order completes and returns its output, /// along with the [task ID] of the completed task. /// /// Returns `None` if the queue is empty. /// /// When this method returns an error, then the id of the task that failed can be accessed /// using the [`JoinError::id`] method. /// /// # Cancel Safety /// /// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!` /// statement and some other branch completes first, it is guaranteed that no tasks were /// removed from this [`JoinQueue`]. /// /// [task ID]: tokio::task::Id /// [`JoinError::id`]: fn@tokio::task::JoinError::id pub async fn join_next_with_id(&mut self) -> Option> { std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await } /// Tries to poll an `AbortOnDropHandle` without blocking or yielding. /// /// Note that on success the handle will panic on subsequent polls /// since it becomes consumed. fn try_poll_handle(jh: &mut AbortOnDropHandle) -> Option> { let waker = futures_util::task::noop_waker(); let mut cx = Context::from_waker(&waker); // Since this function is not async and cannot be forced to yield, we should // disable budgeting when we want to check for the `JoinHandle` readiness. let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh)); if let Poll::Ready(res) = jh.poll(&mut cx) { Some(res) } else { None } } /// Tries to join the next task in FIFO order if it has completed. /// /// Returns `None` if the queue is empty or if the next task is not yet ready. pub fn try_join_next(&mut self) -> Option> { let jh = self.0.front_mut()?; let res = Self::try_poll_handle(jh)?; // Use `detach` to avoid calling `abort` on a task that has already completed. // Dropping `AbortOnDropHandle` would abort the task, but since it is finished, // we only need to drop the `JoinHandle` for cleanup. drop(self.0.pop_front().unwrap().detach()); Some(res) } /// Tries to join the next task in FIFO order if it has completed and return its output, /// along with its [task ID]. /// /// Returns `None` if the queue is empty or if the next task is not yet ready. /// /// When this method returns an error, then the id of the task that failed can be accessed /// using the [`JoinError::id`] method. /// /// [task ID]: tokio::task::Id /// [`JoinError::id`]: fn@tokio::task::JoinError::id pub fn try_join_next_with_id(&mut self) -> Option> { let jh = self.0.front_mut()?; let res = Self::try_poll_handle(jh)?; // Use `detach` to avoid calling `abort` on a task that has already completed. // Dropping `AbortOnDropHandle` would abort the task, but since it is finished, // we only need to drop the `JoinHandle` for cleanup. let jh = self.0.pop_front().unwrap().detach(); let id = jh.id(); drop(jh); Some(res.map(|output| (id, output))) } /// Aborts all tasks and waits for them to finish shutting down. /// /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in /// a loop until it returns `None`. /// /// This method ignores any panics in the tasks shutting down. When this call returns, the /// [`JoinQueue`] will be empty. /// /// [`abort_all`]: fn@Self::abort_all /// [`join_next`]: fn@Self::join_next pub async fn shutdown(&mut self) { self.abort_all(); while self.join_next().await.is_some() {} } /// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results. /// /// The results will be stored in the order they were spawned, not the order they completed. /// This is a convenience method that is equivalent to calling [`join_next`] in /// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call /// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are /// cancelled. To handle errors in any other way, manually call [`join_next`] /// in a loop. /// /// # Cancel Safety /// /// This method is not cancel safe as it calls `join_next` in a loop. If you need /// cancel safety, manually call `join_next` in a loop with `Vec` accumulator. /// /// [`join_next`]: fn@Self::join_next /// [`JoinError::id`]: fn@tokio::task::JoinError::id pub async fn join_all(mut self) -> Vec { let mut output = Vec::with_capacity(self.len()); while let Some(res) = self.join_next().await { match res { Ok(t) => output.push(t), Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()), Err(err) => panic!("{err}"), } } output } /// Aborts all tasks on this [`JoinQueue`]. /// /// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete /// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty. pub fn abort_all(&mut self) { self.0.iter().for_each(|jh| jh.abort()); } /// Removes all tasks from this [`JoinQueue`] without aborting them. /// /// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`] /// is dropped. pub fn detach_all(&mut self) { self.0.drain(..).for_each(|jh| drop(jh.detach())); } /// Polls for the next task in [`JoinQueue`] to complete. /// /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue. /// /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. /// /// # Returns /// /// This function returns: /// /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is /// available right now. /// * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed. /// The `value` is the return value that task. /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been /// aborted. The `err` is the `JoinError` from the panicked/aborted task. /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty. pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll>> { let jh = match self.0.front_mut() { None => return Poll::Ready(None), Some(jh) => jh, }; if let Poll::Ready(res) = Pin::new(jh).poll(cx) { // Use `detach` to avoid calling `abort` on a task that has already completed. // Dropping `AbortOnDropHandle` would abort the task, but since it is finished, // we only need to drop the `JoinHandle` for cleanup. drop(self.0.pop_front().unwrap().detach()); Poll::Ready(Some(res)) } else { Poll::Pending } } /// Polls for the next task in [`JoinQueue`] to complete. /// /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue. /// /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. /// /// # Returns /// /// This function returns: /// /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is /// available right now. /// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed. /// The `value` is the return value that task, and `id` is its [task ID]. /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been /// aborted. The `err` is the `JoinError` from the panicked/aborted task. /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty. /// /// [task ID]: tokio::task::Id pub fn poll_join_next_with_id( &mut self, cx: &mut Context<'_>, ) -> Poll>> { let jh = match self.0.front_mut() { None => return Poll::Ready(None), Some(jh) => jh, }; if let Poll::Ready(res) = Pin::new(jh).poll(cx) { // Use `detach` to avoid calling `abort` on a task that has already completed. // Dropping `AbortOnDropHandle` would abort the task, but since it is finished, // we only need to drop the `JoinHandle` for cleanup. let jh = self.0.pop_front().unwrap().detach(); let id = jh.id(); drop(jh); // If the task succeeded, add the task ID to the output. Otherwise, the // `JoinError` will already have the task's ID. Poll::Ready(Some(res.map(|output| (id, output)))) } else { Poll::Pending } } } impl std::fmt::Debug for JoinQueue { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_list() .entries(self.0.iter().map(|jh| JoinHandle::id(jh.as_ref()))) .finish() } } impl Default for JoinQueue { fn default() -> Self { Self::new() } } /// Collect an iterator of futures into a [`JoinQueue`]. /// /// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator. impl std::iter::FromIterator for JoinQueue where F: Future + Send + 'static, T: Send + 'static, { fn from_iter>(iter: I) -> Self { let mut set = Self::new(); iter.into_iter().for_each(|task| { set.spawn(task); }); set } } #[cfg(test)] mod tests { use super::*; /// A simple type that does not implement [`std::fmt::Debug`]. struct NotDebug; fn is_debug() {} #[test] fn assert_debug() { is_debug::>(); } }