// Copyright (c) the JPEG XL Project Authors. All rights reserved. // // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. use std::borrow::Cow; use crate::bit_reader::BitReader; use crate::entropy_coding::decode::Histograms; use crate::entropy_coding::decode::SymbolReader; use crate::error::{Error, Result}; use crate::util::{CeilLog2, NewWithCapacity, tracing_wrappers::instrument, value_of_lowest_1_bit}; #[derive(Debug, PartialEq, Default, Clone)] pub struct Permutation(pub Cow<'static, [u32]>); impl std::ops::Deref for Permutation { type Target = [u32]; fn deref(&self) -> &[u32] { &self.0 } } impl Permutation { /// Decode a permutation from entropy-coded stream. pub fn decode( size: u32, skip: u32, histograms: &Histograms, br: &mut BitReader, entropy_reader: &mut SymbolReader, ) -> Result { let end = entropy_reader.read_unsigned(histograms, br, get_context(size)); Self::decode_inner(size, skip, end, |ctx| -> Result { let r = entropy_reader.read_unsigned(histograms, br, ctx); br.check_for_error()?; Ok(r) }) } fn decode_inner( size: u32, skip: u32, end: u32, mut read: impl FnMut(usize) -> Result, ) -> Result { if end > size - skip { return Err(Error::InvalidPermutationSize { size, skip, end }); } let mut lehmer = Vec::new_with_capacity(end as usize)?; let mut prev_val = 0u32; for idx in skip..(skip + end) { let val = match read(get_context(prev_val)) { Ok(val) => val, Err(Error::OutOfBounds(_)) => { // Estimate 1.5 bits for each remaining code let bits = (((skip + end) - idx) as usize).saturating_mul(3) / 2; return Err(Error::OutOfBounds(bits)); } Err(e) => return Err(e), }; if val >= size - idx { return Err(Error::InvalidPermutationLehmerCode { size, idx, lehmer: val, }); } lehmer.push(val); prev_val = val; } // Initialize the full permutation vector with skipped elements intact let mut permutation = Vec::new_with_capacity((size - skip) as usize)?; permutation.extend(0..size); // Decode the Lehmer code into the slice starting at `skip` let permuted_slice = decode_lehmer_code(&lehmer, &permutation[skip as usize..])?; // Replace the target slice in `permutation` permutation[skip as usize..].copy_from_slice(&permuted_slice); // Ensure the permutation has the correct size assert_eq!(permutation.len(), size as usize); Ok(Self(Cow::Owned(permutation))) } pub fn compose(&mut self, other: &Permutation) { assert_eq!(self.0.len(), other.0.len()); let mut tmp: Vec = vec![0; self.0.len()]; for (i, val) in tmp.iter_mut().enumerate().take(self.0.len()) { *val = self.0[other.0[i] as usize] } self.0.to_mut().copy_from_slice(&tmp[..]); } } // Decodes the Lehmer code in `code` and returns the permuted slice. #[instrument(level = "debug", ret, err)] fn decode_lehmer_code(code: &[u32], permutation_slice: &[u32]) -> Result> { let n = permutation_slice.len(); if n == 0 { return Err(Error::InvalidPermutationLehmerCode { size: 0, idx: 0, lehmer: 0, }); } let mut permuted = Vec::new_with_capacity(n)?; permuted.extend_from_slice(permutation_slice); let padded_n = (n as u32).next_power_of_two() as usize; // Allocate temp array inside the function let mut temp = Vec::new_with_capacity(padded_n)?; temp.extend((0..padded_n as u32).map(|x| value_of_lowest_1_bit(x + 1))); for (i, permuted_item) in permuted.iter_mut().enumerate() { let code_i = *code.get(i).unwrap_or(&0); // Adjust the maximum allowed value for code_i if code_i as usize > n - i - 1 { return Err(Error::InvalidPermutationLehmerCode { size: n as u32, idx: i as u32, lehmer: code_i, }); } let mut rank = code_i + 1; // Extract i-th unused element via implicit order-statistics tree. let mut bit = padded_n; let mut next = 0usize; while bit != 0 { let cand = next + bit; if cand == 0 || cand > padded_n { return Err(Error::InvalidPermutationLehmerCode { size: n as u32, idx: i as u32, lehmer: code_i, }); } bit >>= 1; if temp[cand - 1] < rank { next = cand; rank -= temp[cand - 1]; } } *permuted_item = permutation_slice[next]; next += 1; while next <= padded_n { temp[next - 1] -= 1; next += value_of_lowest_1_bit(next as u32) as usize; } } Ok(permuted) } // Decodes the Lehmer code in `code` and returns the permuted vector. #[cfg(test)] fn decode_lehmer_code_naive(code: &[u32], permutation_slice: &[u32]) -> Result> { let n = code.len(); if n == 0 { return Err(Error::InvalidPermutationLehmerCode { size: 0, idx: 0, lehmer: 0, }); } // Ensure permutation_slice has sufficient length if permutation_slice.len() < n { return Err(Error::InvalidPermutationLehmerCode { size: n as u32, idx: 0, lehmer: 0, }); } // Create temp array with values from permutation_slice let mut temp = permutation_slice.to_vec(); let mut permuted = Vec::new_with_capacity(n)?; // Iterate over the Lehmer code for (i, &idx) in code.iter().enumerate() { if idx as usize >= temp.len() { return Err(Error::InvalidPermutationLehmerCode { size: n as u32, idx: i as u32, lehmer: idx, }); } // Assign temp[idx] to permuted vector permuted.push(temp.remove(idx as usize)); } // Append any remaining elements from temp to permuted permuted.extend(temp); Ok(permuted) } fn get_context(x: u32) -> usize { (x + 1).ceil_log2().min(7) as usize } #[cfg(test)] mod test { use super::*; use arbtest::arbitrary::{self, Arbitrary, Unstructured}; use core::assert_eq; use test_log::test; #[test] fn generate_permutation_arbtest() { arbtest::arbtest(|u| { let input = PermutationInput::arbitrary(u)?; let permutation_slice = input.permutation.as_slice(); let perm1 = decode_lehmer_code(&input.code, permutation_slice); let perm2 = decode_lehmer_code_naive(&input.code, permutation_slice); assert_eq!( perm1.map_err(|x| x.to_string()), perm2.map_err(|x| x.to_string()) ); Ok(()) }); } #[derive(Debug)] struct PermutationInput { code: Vec, permutation: Vec, } impl<'a> Arbitrary<'a> for PermutationInput { fn arbitrary(u: &mut Unstructured<'a>) -> Result { // Generate a reasonable size to prevent tests from taking too long let size_lehmer = u.int_in_range(1..=1000)?; let mut lehmer: Vec = Vec::with_capacity(size_lehmer as usize); for i in 0..size_lehmer { let max_val = size_lehmer - i - 1; let val = if max_val > 0 { u.int_in_range(0..=max_val)? } else { 0 }; lehmer.push(val); } let mut permutation = Vec::new(); let size_permutation = u.int_in_range(size_lehmer..=1000)?; permutation.extend(0..size_permutation); let num_of_swaps = u.int_in_range(0..=100)?; for _ in 0..num_of_swaps { // Randomly swap two positions let pos1 = u.int_in_range(0..=size_permutation - 1)?; let pos2 = u.int_in_range(0..=size_permutation - 1)?; permutation.swap(pos1 as usize, pos2 as usize); } Ok(PermutationInput { code: lehmer, permutation, }) } } #[test] fn simple() { // Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1] let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1]; let skip = 4; let size = 16; let permutation_slice: Vec = (skip..size).collect(); let permuted = decode_lehmer_code(&code, &permutation_slice).unwrap(); let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice).unwrap(); let mut permutation = Vec::with_capacity(size as usize); permutation.extend(0..skip); // Add skipped elements permutation.extend(permuted.iter()); let expected_permutation = vec![0, 1, 2, 3, 5, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14]; assert_eq!(permutation, expected_permutation); assert_eq!(permuted, permuted_naive); } #[test] fn decode_lehmer_compare_different_length() -> Result<(), Box> { // Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1] let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1]; let skip = 4; let size = 16; let permutation_slice: Vec = (skip..size).collect(); let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?; let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?; let expected_permuted = vec![5u32, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14]; assert_eq!(permuted_optimized, expected_permuted); assert_eq!(permuted_naive, expected_permuted); assert_eq!(permuted_optimized, permuted_naive); Ok(()) } #[test] fn decode_lehmer_compare_same_length() -> Result<(), Box> { // Lehmer code: [2, 3, 0, 0, 0] let code = vec![2u32, 3, 0, 0, 0]; let n = code.len(); let permutation_slice: Vec = (0..n as u32).collect(); let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?; let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?; let expected_permutation = vec![2u32, 4, 0, 1, 3]; assert_eq!(permuted_optimized, expected_permutation); assert_eq!(permuted_naive, expected_permutation); assert_eq!(permuted_optimized, permuted_naive); Ok(()) } #[test] fn lehmer_out_of_bounds() { let code = vec![4]; let permutation_slice: Vec = (4..8).collect(); let result = decode_lehmer_code(&code, &permutation_slice); assert!(result.is_err()); } }