use std::{ future::Future, mem, pin::{pin, Pin}, task::{Context, Poll}, }; use futures::{io::AsyncRead, FutureExt}; use super::int::{read_varint, ReadVarint}; use crate::{Error, Res}; /// A reader for a varint-length-prefixed buffer. #[pin_project::pin_project(project = ReadVecProj)] #[allow(clippy::module_name_repetitions)] pub enum ReadVec { // Invariant: This Option is always Some. ReadLen { src: Option>, cap: u64, }, ReadBody { src: S, buf: Vec, remaining: usize, }, } impl ReadVec { /// # Panics /// If `limit` is more than `usize::MAX` or /// if this is called after the length is read. pub fn limit(&mut self, limit: u64) { usize::try_from(limit).expect("cannot set a limit larger than usize::MAX"); if let Self::ReadLen { ref mut cap, .. } = self { *cap = limit; } else { panic!("cannot set a limit once the size has been read"); } } pub fn stream(self) -> S { match self { Self::ReadLen { mut src, .. } => src.take().unwrap().stream(), Self::ReadBody { src, .. } => src, } } } impl Future for ReadVec { type Output = Res>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut(); if let Self::ReadLen { src, cap } = this.get_mut() { match src.as_mut().unwrap().poll_unpin(cx) { Poll::Ready(Ok(None)) => return Poll::Ready(Ok(None)), Poll::Ready(Ok(Some(0))) => return Poll::Ready(Ok(Some(Vec::new()))), Poll::Ready(Ok(Some(sz))) => { if sz > *cap { return Poll::Ready(Err(Error::LimitExceeded)); } // `cap` cannot exceed min(usize::MAX, u64::MAX). let sz = usize::try_from(sz).unwrap(); let body = Self::ReadBody { src: src.take().unwrap().stream(), buf: vec![0; sz], remaining: sz, }; self.set(body); } Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } } let ReadVecProj::ReadBody { src, buf, remaining, } = self.project() else { return Poll::Pending; }; let offset = buf.len() - *remaining; match pin!(src).poll_read(cx, &mut buf[offset..]) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), Poll::Ready(Ok(0)) => Poll::Ready(Err(Error::Truncated)), Poll::Ready(Ok(c)) => { *remaining -= c; if *remaining > 0 { Poll::Pending } else { Poll::Ready(Ok(Some(mem::take(buf)))) } } } } } #[allow(clippy::module_name_repetitions)] pub fn read_vec(src: S) -> ReadVec { ReadVec::ReadLen { src: Some(read_varint(src)), cap: u64::try_from(usize::MAX).unwrap_or(u64::MAX), } } #[cfg(test)] mod test { use std::{ cmp, fmt::Debug, io::Result, pin::Pin, task::{Context, Poll}, }; use futures::AsyncRead; use sync_async::SyncResolve; use crate::{rw::write_varint as sync_write_varint, stream::vec::read_vec, Error}; const FILL_VALUE: u8 = 90; fn fill(len: T) -> Vec where u64: TryFrom, >::Error: Debug, usize: TryFrom, >::Error: Debug, T: Debug + Copy, { let mut buf = Vec::new(); sync_write_varint(u64::try_from(len).unwrap(), &mut buf).unwrap(); buf.resize(buf.len() + usize::try_from(len).unwrap(), FILL_VALUE); buf } #[test] fn read_vecs() { for len in [0, 1, 2, 3, 64] { let buf = fill(len); let mut buf_ref = &buf[..]; let mut fut = read_vec(&mut buf_ref); if let Ok(Some(out)) = fut.sync_resolve() { assert_eq!(len, out.len()); assert!(out.iter().all(|&v| v == FILL_VALUE)); assert!(fut.stream().is_empty()); } } } #[test] fn exceed_cap() { const LEN: u64 = 20; let buf = fill(LEN); let mut buf_ref = &buf[..]; let mut fut = read_vec(&mut buf_ref); fut.limit(LEN - 1); assert!(matches!(fut.sync_resolve(), Err(Error::LimitExceeded))); } /// This class implements `AsyncRead`, but /// always blocks after returning a fixed value. #[derive(Default)] struct IncompleteRead<'a> { data: &'a [u8], consumed: usize, } impl<'a> IncompleteRead<'a> { fn new(data: &'a [u8]) -> Self { Self { data, consumed: 0 } } } impl AsyncRead for IncompleteRead<'_> { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { let remaining = &self.data[self.consumed..]; if remaining.is_empty() { Poll::Pending } else { let copied = cmp::min(buf.len(), remaining.len()); buf[..copied].copy_from_slice(&remaining[..copied]); self.as_mut().consumed += copied; Poll::Ready(std::io::Result::Ok(copied)) } } } #[test] #[should_panic(expected = "cannot set a limit once the size has been read")] fn late_cap() { let mut buf = IncompleteRead::new(&[2, 1]); _ = read_vec(&mut buf).sync_resolve_with(|mut f| { println!("pending"); f.as_mut().limit(100); }); } #[test] #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] #[should_panic(expected = "cannot set a limit larger than usize::MAX")] fn too_large_cap() { const LEN: u64 = 20; let buf = fill(LEN); let mut buf_ref = &buf[..]; let mut fut = read_vec(&mut buf_ref); fut.limit(u64::try_from(usize::MAX).unwrap() + 1); } }