use core::cmp::min; use core::num::NonZeroU64; use ::bytes::{Buf, BufMut}; use crate::DecodeError; /// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. /// The buffer must have enough remaining space (maximum 10 bytes). #[inline] pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) { // Varints are never more than 10 bytes for _ in 0..10 { if value < 0x80 { buf.put_u8(value as u8); break; } else { buf.put_u8(((value & 0x7F) | 0x80) as u8); value >>= 7; } } } /// Returns the encoded length of the value in LEB128 variable length format. /// The returned value will be between 1 and 10, inclusive. #[inline] pub const fn encoded_len_varint(value: u64) -> usize { // Based on [VarintSize64][1]. // [1]: https://github.com/protocolbuffers/protobuf/blob/v28.3/src/google/protobuf/io/coded_stream.h#L1744-L1756 // Safety: value | 1 is non-zero. let log2value = unsafe { NonZeroU64::new_unchecked(value | 1) }.ilog2(); ((log2value * 9 + (64 + 9)) / 64) as usize } /// Decodes a LEB128-encoded variable length integer from the buffer. #[inline] pub fn decode_varint(buf: &mut impl Buf) -> Result { let bytes = buf.chunk(); let len = bytes.len(); if len == 0 { return Err(DecodeError::new("invalid varint")); } let byte = bytes[0]; if byte < 0x80 { buf.advance(1); Ok(u64::from(byte)) } else if len > 10 || bytes[len - 1] < 0x80 { let (value, advance) = decode_varint_slice(bytes)?; buf.advance(advance); Ok(value) } else { decode_varint_slow(buf) } } /// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the /// number of bytes read. /// /// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from /// [`ConsumeVarint`][2]. /// /// ## Safety /// /// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last /// element in bytes is < `0x80`. /// /// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406 /// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 #[inline] fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> { // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance. // Use assertions to ensure memory safety, but it should always be optimized after inline. assert!(!bytes.is_empty()); assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80); let mut b: u8 = unsafe { *bytes.get_unchecked(0) }; let mut part0: u32 = u32::from(b); if b < 0x80 { return Ok((u64::from(part0), 1)); }; part0 -= 0x80; b = unsafe { *bytes.get_unchecked(1) }; part0 += u32::from(b) << 7; if b < 0x80 { return Ok((u64::from(part0), 2)); }; part0 -= 0x80 << 7; b = unsafe { *bytes.get_unchecked(2) }; part0 += u32::from(b) << 14; if b < 0x80 { return Ok((u64::from(part0), 3)); }; part0 -= 0x80 << 14; b = unsafe { *bytes.get_unchecked(3) }; part0 += u32::from(b) << 21; if b < 0x80 { return Ok((u64::from(part0), 4)); }; part0 -= 0x80 << 21; let value = u64::from(part0); b = unsafe { *bytes.get_unchecked(4) }; let mut part1: u32 = u32::from(b); if b < 0x80 { return Ok((value + (u64::from(part1) << 28), 5)); }; part1 -= 0x80; b = unsafe { *bytes.get_unchecked(5) }; part1 += u32::from(b) << 7; if b < 0x80 { return Ok((value + (u64::from(part1) << 28), 6)); }; part1 -= 0x80 << 7; b = unsafe { *bytes.get_unchecked(6) }; part1 += u32::from(b) << 14; if b < 0x80 { return Ok((value + (u64::from(part1) << 28), 7)); }; part1 -= 0x80 << 14; b = unsafe { *bytes.get_unchecked(7) }; part1 += u32::from(b) << 21; if b < 0x80 { return Ok((value + (u64::from(part1) << 28), 8)); }; part1 -= 0x80 << 21; let value = value + ((u64::from(part1)) << 28); b = unsafe { *bytes.get_unchecked(8) }; let mut part2: u32 = u32::from(b); if b < 0x80 { return Ok((value + (u64::from(part2) << 56), 9)); }; part2 -= 0x80; b = unsafe { *bytes.get_unchecked(9) }; part2 += u32::from(b) << 7; // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details. // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 if b < 0x02 { return Ok((value + (u64::from(part2) << 56), 10)); }; // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow. // Assume the data is corrupt. Err(DecodeError::new("invalid varint")) } /// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as /// necessary. /// /// Contains a varint overflow check from [`ConsumeVarint`][1]. /// /// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 #[inline(never)] #[cold] fn decode_varint_slow(buf: &mut impl Buf) -> Result { let mut value = 0; for count in 0..min(10, buf.remaining()) { let byte = buf.get_u8(); value |= u64::from(byte & 0x7F) << (count * 7); if byte <= 0x7F { // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details. // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 if count == 9 && byte >= 0x02 { return Err(DecodeError::new("invalid varint")); } else { return Ok(value); } } } Err(DecodeError::new("invalid varint")) } #[cfg(test)] mod test { use super::*; #[test] fn varint() { fn check(value: u64, encoded: &[u8]) { // Small buffer. let mut buf = Vec::with_capacity(1); encode_varint(value, &mut buf); assert_eq!(buf, encoded); // Large buffer. let mut buf = Vec::with_capacity(100); encode_varint(value, &mut buf); assert_eq!(buf, encoded); assert_eq!(encoded_len_varint(value), encoded.len()); // See: https://github.com/tokio-rs/prost/pull/1008 for copying reasoning. let mut encoded_copy = encoded; let roundtrip_value = decode_varint(&mut encoded_copy).expect("decoding failed"); assert_eq!(value, roundtrip_value); let mut encoded_copy = encoded; let roundtrip_value = decode_varint_slow(&mut encoded_copy).expect("slow decoding failed"); assert_eq!(value, roundtrip_value); } check(2u64.pow(0) - 1, &[0x00]); check(2u64.pow(0), &[0x01]); check(2u64.pow(7) - 1, &[0x7F]); check(2u64.pow(7), &[0x80, 0x01]); check(300, &[0xAC, 0x02]); check(2u64.pow(14) - 1, &[0xFF, 0x7F]); check(2u64.pow(14), &[0x80, 0x80, 0x01]); check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]); check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]); check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]); check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]); check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]); check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]); check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]); check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]); check( 2u64.pow(49) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], ); check( 2u64.pow(49), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], ); check( 2u64.pow(56) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], ); check( 2u64.pow(56), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], ); check( 2u64.pow(63) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], ); check( 2u64.pow(63), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], ); check( u64::MAX, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01], ); } const U64_MAX_PLUS_ONE: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02]; #[test] fn varint_overflow() { let mut copy = U64_MAX_PLUS_ONE; decode_varint(&mut copy).expect_err("decoding u64::MAX + 1 succeeded"); } #[test] fn variant_slow_overflow() { let mut copy = U64_MAX_PLUS_ONE; decode_varint_slow(&mut copy).expect_err("slow decoding u64::MAX + 1 succeeded"); } }