#![cfg(loom)]

use concurrent_queue::{ConcurrentQueue, ForcePushError, PopError, PushError};
use loom::sync::atomic::{AtomicUsize, Ordering};
use loom::sync::{Arc, Condvar, Mutex};
use loom::thread;

#[cfg(target_family = "wasm")]
use wasm_bindgen_test::wasm_bindgen_test as test;

/// A basic MPMC channel based on a [`ConcurrentQueue`] and loom primitives.
struct Channel<T> {
    /// The queue used to contain items.
    queue: ConcurrentQueue<T>,

    /// The number of live senders.
    senders: AtomicUsize,

    /// The number of live receivers.
    receivers: AtomicUsize,

    /// The event that is signaled when a new item is pushed.
    push_event: Event,

    /// The event that is signaled when a new item is popped.
    pop_event: Event,
}

/// The sending side of a channel.
struct Sender<T> {
    /// The shared channel state.
    channel: Arc<Channel<T>>,
}

/// The receiving side of a channel.
struct Receiver<T> {
    /// The shared channel state.
    channel: Arc<Channel<T>>,
}

/// Create a new sender/receiver pair backed by `queue`.
///
/// Both the sender count and the receiver count start at one.
fn pair<T>(queue: ConcurrentQueue<T>) -> (Sender<T>, Receiver<T>) {
    let channel = Arc::new(Channel {
        queue,
        senders: AtomicUsize::new(1),
        receivers: AtomicUsize::new(1),
        push_event: Event::new(),
        pop_event: Event::new(),
    });

    (
        Sender {
            channel: channel.clone(),
        },
        Receiver { channel },
    )
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        self.channel.senders.fetch_add(1, Ordering::SeqCst);
        Sender {
            channel: self.channel.clone(),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous count, so 1 means this was the last sender.
        if self.channel.senders.fetch_sub(1, Ordering::SeqCst) == 1 {
            // Close the channel and wake every receiver waiting on `push_event`.
            self.channel.queue.close();
            self.channel.push_event.signal_all();
        }
    }
}

impl<T> Clone for Receiver<T> {
    fn clone(&self) -> Self {
        self.channel.receivers.fetch_add(1, Ordering::SeqCst);
        Receiver {
            channel: self.channel.clone(),
        }
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous count, so 1 means this was the last receiver.
        if self.channel.receivers.fetch_sub(1, Ordering::SeqCst) == 1 {
            // Close the channel and wake every sender waiting on `pop_event`.
            self.channel.queue.close();
            self.channel.pop_event.signal_all();
        }
    }
}

impl<T> Sender<T> {
    /// Send a value, waiting on `pop_event` while the queue reports full.
    ///
    /// Returns an error with the value if the channel is closed.
    fn send(&self, mut value: T) -> Result<(), T> {
        loop {
            match self.channel.queue.push(value) {
                Ok(()) => {
                    // Notify a single receiver.
                    self.channel.push_event.signal();
                    return Ok(());
                }
                Err(PushError::Closed(val)) => return Err(val),
                Err(PushError::Full(val)) => {
                    // Take the value back and wait for a receiver to pop an item.
                    value = val;
                    self.channel.pop_event.wait();
                }
            }
        }
    }

    /// Send a value forcefully via `force_push`.
    ///
    /// `Ok(Some(item))` carries whatever item `force_push` displaced to make room;
    /// returns an error with the value if the channel is closed.
    fn force_send(&self, value: T) -> Result<Option<T>, T> {
        match self.channel.queue.force_push(value) {
            Ok(bumped) => {
                self.channel.push_event.signal();
                Ok(bumped)
            }
            Err(ForcePushError(val)) => Err(val),
        }
    }
}

impl<T> Receiver<T> {
    /// Channel capacity, or `None` for an unbounded queue.
    fn capacity(&self) -> Option<usize> {
        self.channel.queue.capacity()
    }

    /// Receive a value, waiting on `push_event` while the queue reports empty.
    ///
    /// Returns an error if the channel is closed.
    fn recv(&self) -> Result<T, ()> {
        loop {
            match self.channel.queue.pop() {
                Ok(value) => {
                    // Notify a single sender.
                    self.channel.pop_event.signal();
                    return Ok(value);
                }
                Err(PopError::Closed) => return Err(()),
                Err(PopError::Empty) => {
                    // Wait for a sender to push an item.
                    self.channel.push_event.wait();
                }
            }
        }
    }
}

/// An event that can be waited on and then signaled.
struct Event {
    /// The condition variable used to wait on the event.
    condvar: Condvar,

    /// The mutex protecting the event's state.
    ///
    /// Bit 0 is set by [`Event::signal`] and cleared by the waiter that consumes it.
    /// Bit 1 is set by [`Event::signal_all`] and is never cleared, so once it is set
    /// every subsequent [`Event::wait`] returns immediately.
    mutex: Mutex<usize>,
}

impl Event {
    /// Create a new, unsignaled event.
    fn new() -> Self {
        Self {
            condvar: Condvar::new(),
            mutex: Mutex::new(0),
        }
    }

    /// Wait for the event to be signaled.
    fn wait(&self) {
        let mut state = self.mutex.lock().unwrap();

        loop {
            if *state & 0b11 != 0 {
                // The event was signaled: consume a single-shot signal, but leave
                // the sticky broadcast bit in place.
                *state &= !0b01;
                return;
            }

            // Wait for the event to be signaled; the loop re-checks the state on wakeup.
            state = self.condvar.wait(state).unwrap();
        }
    }

    /// Signal the event, waking one waiter.
    fn signal(&self) {
        let mut state = self.mutex.lock().unwrap();
        *state |= 1;
        drop(state);

        self.condvar.notify_one();
    }

    /// Signal the event, but notify all waiters.
    fn signal_all(&self) {
        let mut state = self.mutex.lock().unwrap();
        *state |= 3;
        drop(state);

        self.condvar.notify_all();
    }
}

/// Wrapper to run a test body under loom against three queues:
/// `bounded(1)`, `bounded(LIMIT / 2)` and `unbounded()`.
fn run_test<F: Fn(ConcurrentQueue<usize>, usize) + Send + Sync + Clone + 'static>(f: F) {
    // The length of a loom test seems to increase exponentially the higher this number is.
    const LIMIT: usize = 4;

    let fc = f.clone();
    loom::model(move || {
        fc(ConcurrentQueue::bounded(1), LIMIT);
    });

    let fc = f.clone();
    loom::model(move || {
        fc(ConcurrentQueue::bounded(LIMIT / 2), LIMIT);
    });

    loom::model(move || {
        f(ConcurrentQueue::unbounded(), LIMIT);
    });
}

#[test]
fn spsc() {
    run_test(|q, limit| {
        // Create a new pair of senders/receivers.
        let (tx, rx) = pair(q);

        // Run the sender on its own thread.
        let handle = thread::spawn(move || {
            for i in 0..limit {
                if tx.send(i).is_err() {
                    break;
                }
            }
        });

        // Drain until the sender is dropped and the channel reports closed.
        let mut recv_values = vec![];
        while let Ok(value) = rx.recv() {
            recv_values.push(value);
        }

        // Values may not be in order.
        recv_values.sort_unstable();
        assert_eq!(recv_values, (0..limit).collect::<Vec<_>>());

        // Join the handle before we exit.
        handle.join().unwrap();
    });
}

#[test]
fn spsc_force() {
    run_test(|q, limit| {
        // Create a new pair of senders/receivers.
        let (tx, rx) = pair(q);

        // Run the sender on its own thread.
        let handle = thread::spawn(move || {
            for i in 0..limit {
                if tx.force_send(i).is_err() {
                    break;
                }
            }
        });

        // Drain until the sender is dropped and the channel reports closed.
        let mut recv_values = vec![];
        while let Ok(value) = rx.recv() {
            recv_values.push(value);
        }

        // Values may not be in order.
        recv_values.sort_unstable();

        // Assuming `force_push` only displaces the oldest item, the newest
        // `min(cap, limit)` values can never be bumped, so all of them must arrive.
        // Without this check the `zip` below would pass vacuously on an empty result.
        let cap = rx.capacity().unwrap_or(usize::MAX);
        let kept = cap.min(limit);
        assert!(
            recv_values.len() >= kept,
            "expected at least {} values, got {:?}",
            kept,
            recv_values
        );

        for (left, right) in (0..limit)
            .rev()
            .take(cap)
            .zip(recv_values.into_iter().rev())
        {
            assert_eq!(left, right);
        }

        // Join the handle before we exit.
        handle.join().unwrap();
    });
}