//! Unidirectional byte-oriented channel. use crate::util::poll_proceed; use bytes::Buf; use bytes::BytesMut; use futures_core::ready; use std::io::Error as IoError; use std::io::ErrorKind as IoErrorKind; use std::io::IoSlice; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll, Waker}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; type IoResult = Result; const CLOSED_ERROR_MSG: &str = "simplex has been closed"; #[derive(Debug)] struct Inner { /// `poll_write` will return [`Poll::Pending`] if the backpressure boundary is reached backpressure_boundary: usize, /// either [`Sender`] or [`Receiver`] is closed is_closed: bool, /// Waker used to wake the [`Receiver`] receiver_waker: Option, /// Waker used to wake the [`Sender`] sender_waker: Option, /// Buffer used to read and write data buf: BytesMut, } impl Inner { fn with_capacity(capacity: usize) -> Self { Self { backpressure_boundary: capacity, is_closed: false, receiver_waker: None, sender_waker: None, buf: BytesMut::with_capacity(capacity), } } fn register_receiver_waker(&mut self, waker: &Waker) -> Option { match self.receiver_waker.as_mut() { Some(old) if old.will_wake(waker) => None, _ => self.receiver_waker.replace(waker.clone()), } } fn register_sender_waker(&mut self, waker: &Waker) -> Option { match self.sender_waker.as_mut() { Some(old) if old.will_wake(waker) => None, _ => self.sender_waker.replace(waker.clone()), } } fn take_receiver_waker(&mut self) -> Option { self.receiver_waker.take() } fn take_sender_waker(&mut self) -> Option { self.sender_waker.take() } fn is_closed(&self) -> bool { self.is_closed } fn close_receiver(&mut self) -> Option { self.is_closed = true; self.take_sender_waker() } fn close_sender(&mut self) -> Option { self.is_closed = true; self.take_receiver_waker() } } /// Receiver of the simplex channel. /// /// # Cancellation safety /// /// The `Receiver` is cancel safe. 
If it is used as the event in a /// [`tokio::select!`](macro@tokio::select) statement and some other branch /// completes first, it is guaranteed that no bytes were received on this /// channel. /// /// You can still read the remaining data from the buffer /// even if the write half has been dropped. /// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details. #[derive(Debug)] pub struct Receiver { inner: Arc>, } impl Drop for Receiver { /// This also wakes up the [`Sender`]. fn drop(&mut self) { let maybe_waker = { let mut inner = self.inner.lock().unwrap(); inner.close_receiver() }; if let Some(waker) = maybe_waker { waker.wake(); } } } impl AsyncRead for Receiver { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let coop = ready!(poll_proceed(cx)); let mut inner = self.inner.lock().unwrap(); let to_read = buf.remaining().min(inner.buf.remaining()); if to_read == 0 { if inner.is_closed() || buf.remaining() == 0 { return Poll::Ready(Ok(())); } let old_waker = inner.register_receiver_waker(cx.waker()); let maybe_waker = inner.take_sender_waker(); // unlock before waking up and dropping old waker drop(inner); drop(old_waker); if let Some(waker) = maybe_waker { waker.wake(); } return Poll::Pending; } // this is to avoid starving other tasks coop.made_progress(); buf.put_slice(&inner.buf[..to_read]); inner.buf.advance(to_read); let waker = inner.take_sender_waker(); drop(inner); // unlock before waking up if let Some(waker) = waker { waker.wake(); } Poll::Ready(Ok(())) } } /// Sender of the simplex channel. /// /// # Cancellation safety /// /// The `Sender` is cancel safe. If it is used as the event in a /// [`tokio::select!`](macro@tokio::select) statement and some other branch /// completes first, it is guaranteed that no bytes were sent on this /// channel. /// /// # Shutdown /// /// See [`Sender::poll_shutdown`]. 
#[derive(Debug)] pub struct Sender { inner: Arc>, } impl Drop for Sender { /// This also wakes up the [`Receiver`]. fn drop(&mut self) { let maybe_waker = { let mut inner = self.inner.lock().unwrap(); inner.close_sender() }; if let Some(waker) = maybe_waker { waker.wake(); } } } impl AsyncWrite for Sender { /// # Errors /// /// This method will return [`IoErrorKind::BrokenPipe`] /// if the channel has been closed. fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let coop = ready!(poll_proceed(cx)); let mut inner = self.inner.lock().unwrap(); if inner.is_closed() { return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG))); } let free = inner .backpressure_boundary .checked_sub(inner.buf.len()) .expect("backpressure boundary overflow"); let to_write = buf.len().min(free); if to_write == 0 { if buf.is_empty() { return Poll::Ready(Ok(0)); } let old_waker = inner.register_sender_waker(cx.waker()); let waker = inner.take_receiver_waker(); // unlock before waking up and dropping old waker drop(inner); drop(old_waker); if let Some(waker) = waker { waker.wake(); } return Poll::Pending; } // this is to avoid starving other tasks coop.made_progress(); inner.buf.extend_from_slice(&buf[..to_write]); let waker = inner.take_receiver_waker(); drop(inner); // unlock before waking up if let Some(waker) = waker { waker.wake(); } Poll::Ready(Ok(to_write)) } /// # Errors /// /// This method will return [`IoErrorKind::BrokenPipe`] /// if the channel has been closed. fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { let inner = self.inner.lock().unwrap(); if inner.is_closed() { Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG))) } else { Poll::Ready(Ok(())) } } /// After returns [`Poll::Ready`], all the following call to /// [`Sender::poll_write`] and [`Sender::poll_flush`] /// will return error. 
/// /// The [`Receiver`] can still be used to read remaining data /// until all bytes have been consumed. fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { let maybe_waker = { let mut inner = self.inner.lock().unwrap(); inner.close_sender() }; if let Some(waker) = maybe_waker { waker.wake(); } Poll::Ready(Ok(())) } fn is_write_vectored(&self) -> bool { true } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll> { let coop = ready!(poll_proceed(cx)); let mut inner = self.inner.lock().unwrap(); if inner.is_closed() { return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG))); } let free = inner .backpressure_boundary .checked_sub(inner.buf.len()) .expect("backpressure boundary overflow"); if free == 0 { let old_waker = inner.register_sender_waker(cx.waker()); let maybe_waker = inner.take_receiver_waker(); // unlock before waking up and dropping old waker drop(inner); drop(old_waker); if let Some(waker) = maybe_waker { waker.wake(); } return Poll::Pending; } // this is to avoid starving other tasks coop.made_progress(); let mut rem = free; for buf in bufs { if rem == 0 { break; } let to_write = buf.len().min(rem); if to_write == 0 { assert_ne!(rem, 0); assert_eq!(buf.len(), 0); continue; } inner.buf.extend_from_slice(&buf[..to_write]); rem -= to_write; } let waker = inner.take_receiver_waker(); drop(inner); // unlock before waking up if let Some(waker) = waker { waker.wake(); } Poll::Ready(Ok(free - rem)) } } /// Create a simplex channel. /// /// The `capacity` parameter specifies the maximum number of bytes that can be /// stored in the channel without making the [`Sender::poll_write`] /// return [`Poll::Pending`]. /// /// # Panics /// /// This function will panic if `capacity` is zero. 
pub fn new(capacity: usize) -> (Sender, Receiver) {
    // a zero-capacity channel could never accept a byte
    assert_ne!(capacity, 0, "capacity must be greater than zero");

    // both halves share one lock-protected state
    let shared = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
    let rx = Receiver {
        inner: Arc::clone(&shared),
    };
    let tx = Sender { inner: shared };

    (tx, rx)
}