// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use super::*;
use jxl_simd::{test_all_instruction_sets, ScalarDescriptor, SimdDescriptor};
use rand::Rng;
use rand::SeedableRng;
use rand_chacha::ChaCha12Rng;
use test_log::test;

use std::f64::consts::FRAC_1_SQRT_2;
use std::f64::consts::PI;
use std::f64::consts::SQRT_2;

#[inline(always)]
fn alpha(u: usize) -> f64 {
    if u == 0 {
        FRAC_1_SQRT_2
    } else {
        1.0
    }
}

/// Slow reference 1D DCT, applied independently to each column of `input_matrix`.
pub fn dct1d(input_matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let num_rows = input_matrix.len();
    if num_rows == 0 {
        return Vec::new();
    }
    let num_cols = input_matrix[0].len();
    let mut output_matrix = vec![vec![0.0f64; num_cols]; num_rows];
    let scale: f64 = SQRT_2;

    // Precompute the DCT matrix (size: num_rows x num_rows).
    let mut dct_coeff_matrix = vec![vec![0.0f64; num_rows]; num_rows];
    for (u_freq, row) in dct_coeff_matrix.iter_mut().enumerate() {
        let alpha_u_val = alpha(u_freq);
        for (y_spatial, coeff) in row.iter_mut().enumerate() {
            *coeff = alpha_u_val
                * ((y_spatial as f64 + 0.5) * u_freq as f64 * PI / num_rows as f64).cos()
                * scale;
        }
    }

    // Perform the DCT calculation column by column.
    for x_col_idx in 0..num_cols {
        for u_freq_idx in 0..num_rows {
            let mut sum = 0.0;
            for (y_spatial_idx, col) in input_matrix.iter().enumerate() {
                // This access `input_matrix[y_spatial_idx][x_col_idx]` assumes the input_matrix
                // is rectangular. If not, it might panic here.
                sum += dct_coeff_matrix[u_freq_idx][y_spatial_idx] * col[x_col_idx];
            }
            output_matrix[u_freq_idx][x_col_idx] = sum;
        }
    }
    output_matrix
}

/// Slow reference 1D inverse DCT (transpose of the `dct1d` basis), applied independently to
/// each column of `input_matrix`.
pub fn idct1d(input_matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let num_rows = input_matrix.len();
    if num_rows == 0 {
        return Vec::new();
    }
    let num_cols = input_matrix[0].len();
    let mut output_matrix = vec![vec![0.0f64; num_cols]; num_rows];
    let scale: f64 = SQRT_2;

    // Precompute the DCT matrix (size: num_rows x num_rows).
    let mut dct_coeff_matrix = vec![vec![0.0f64; num_rows]; num_rows];
    for (u_freq, row) in dct_coeff_matrix.iter_mut().enumerate() {
        let alpha_u_val = alpha(u_freq);
        for (y_def_idx, coeff) in row.iter_mut().enumerate() {
            *coeff = alpha_u_val
                * ((y_def_idx as f64 + 0.5) * u_freq as f64 * PI / num_rows as f64).cos()
                * scale;
        }
    }

    // Perform the IDCT calculation column by column.
    for x_col_idx in 0..num_cols {
        for (y_row_idx, row) in output_matrix.iter_mut().enumerate() {
            let mut sum = 0.0;
            for (u_freq_idx, col) in input_matrix.iter().enumerate() {
                // This access `input_matrix[u_freq_idx][x_col_idx]` assumes the input_matrix
                // is rectangular. If not, it might panic here.
                sum += dct_coeff_matrix[u_freq_idx][y_row_idx] * col[x_col_idx];
            }
            row[x_col_idx] = sum;
        }
    }
    output_matrix
}

fn transpose_f64(matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if matrix.is_empty() {
        return Vec::new();
    }
    let num_rows = matrix.len();
    let num_cols = matrix[0].len();
    let mut transposed = vec![vec![0.0; num_rows]; num_cols];
    for i in 0..num_rows {
        for j in 0..num_cols {
            transposed[j][i] = matrix[i][j];
        }
    }
    transposed
}

/// Slow reference 2D IDCT: two `idct1d` passes with a transpose in between.
pub fn slow_idct2d(input: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let rows = input.len();
    let cols = input[0].len();
    let idct1 = if rows < cols {
        let transposed = transpose_f64(input);
        idct1d(&transposed)
    } else {
        let input: Vec<_> = input.iter().flat_map(|x| x.iter()).copied().collect();
        let input: Vec<_> = input.chunks_exact(rows).map(|x| x.to_vec()).collect();
        idct1d(&input)
    };
    let transposed1 = transpose_f64(&idct1);
    idct1d(&transposed1)
}

/// Per-frequency scale factors that relate the `dct1d` output to the reinterpreting DCT.
pub fn scales(n: usize) -> Vec<f64> {
    (0..n)
        .map(|i| {
            (i as f64 / (16 * n) as f64 * PI).cos()
                * (i as f64 / (8 * n) as f64 * PI).cos()
                * (i as f64 / (4 * n) as f64 * PI).cos()
                * n as f64
        })
        .collect()
}

/// Slow reference for the 2D reinterpreting DCT: two `dct1d` passes, then division by the
/// per-axis `scales` factors.
pub fn slow_reinterpreting_dct2d(input: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let rows = input.len();
    let cols = input[0].len();
    let dct1 = dct1d(input);
    let tdct1 = transpose_f64(&dct1);
    let dct2 = dct1d(&tdct1);
    let mut res = if rows < cols {
        transpose_f64(&dct2)
    } else {
        dct2
    };
    let row_scales = scales(rows);
    let col_scales = scales(cols);
    if rows < cols {
        for y in 0..rows {
            for x in 0..cols {
                res[y][x] /= row_scales[y] * col_scales[x];
            }
        }
    } else {
        for y in 0..cols {
            for x in 0..rows {
                res[y][x] /= row_scales[x] * col_scales[y];
            }
        }
    }
    res
}

#[track_caller]
fn check_close(a: f64, b: f64, max_err: f64) {
    let abs = (a - b).abs();
    let rel = abs / a.abs().max(b.abs());
    assert!(
        abs < max_err || rel < max_err,
        "a: {a} b: {b} abs diff: {abs:?} rel diff: {rel:?}"
    );
}

macro_rules! test_reinterpreting_dct1d_eq_slow_n {
    ($test_name:ident, $n_val:expr, $do_idct_fun:path, $tolerance:expr) => {
        #[test]
        fn $test_name() {
            const N: usize = $n_val;
            let input_matrix_for_ref = random_matrix(N, 1);
            let output_matrix_slow: Vec<Vec<f64>> = dct1d(&input_matrix_for_ref);
            let mut output: Vec<_> = input_matrix_for_ref.iter().map(|x| x[0] as f32).collect();
            let d = ScalarDescriptor {};
            let (output_chunks, remainder) = output.as_chunks_mut::<1>();
            assert!(remainder.is_empty());
            $do_idct_fun(d, output_chunks, 1);
            let scales = scales(N);
            for i in 0..N {
                check_close(
                    output[i] as f64,
                    output_matrix_slow[i][0] / scales[i],
                    $tolerance,
                );
            }
        }
    };
}

test_reinterpreting_dct1d_eq_slow_n!(
    test_reinterpreting_dct1d_2_eq_slow,
    2,
    do_reinterpreting_dct_2,
    1e-6
);
test_reinterpreting_dct1d_eq_slow_n!(
    test_reinterpreting_dct1d_4_eq_slow,
    4,
    do_reinterpreting_dct_4,
    1e-6
);
test_reinterpreting_dct1d_eq_slow_n!(
    test_reinterpreting_dct1d_8_eq_slow,
    8,
    do_reinterpreting_dct_8,
    1e-6
);
test_reinterpreting_dct1d_eq_slow_n!(
    test_reinterpreting_dct1d_16_eq_slow,
    16,
    do_reinterpreting_dct_16,
    5e-6
);
test_reinterpreting_dct1d_eq_slow_n!(
    test_reinterpreting_dct1d_32_eq_slow,
    32,
    do_reinterpreting_dct_32,
    5e-6
);

fn random_matrix(n: usize, m: usize) -> Vec<Vec<f64>> {
    let mut rng = ChaCha12Rng::seed_from_u64(0);
    let mut data = vec![vec![0.0; m]; n];
    data.iter_mut()
        .flat_map(|x| x.iter_mut())
        .for_each(|x| *x = rng.random_range(-1.0..1.0));
    data
}

macro_rules! test_idct1d_eq_slow_n {
    ($test_name:ident, $n_val:expr, $do_idct_fun:path, $tolerance:expr) => {
        #[test]
        fn $test_name() {
            const N: usize = $n_val;
            let input_matrix_for_ref = random_matrix(N, 1);
            let output_matrix_slow: Vec<Vec<f64>> = idct1d(&input_matrix_for_ref);
            let mut output: Vec<_> = input_matrix_for_ref.iter().map(|x| x[0] as f32).collect();
            let d = ScalarDescriptor {};
            let (output_chunks, remainder) = output.as_chunks_mut::<1>();
            assert!(remainder.is_empty());
            $do_idct_fun(d, output_chunks, 1);
            for i in 0..N {
                check_close(output[i] as f64, output_matrix_slow[i][0], $tolerance);
            }
        }
    };
}

test_idct1d_eq_slow_n!(test_idct1d_2_eq_slow, 2, do_idct_2, 1e-6);
test_idct1d_eq_slow_n!(test_idct1d_4_eq_slow, 4, do_idct_4, 1e-6);
test_idct1d_eq_slow_n!(test_idct1d_8_eq_slow, 8, do_idct_8, 1e-6);
test_idct1d_eq_slow_n!(test_idct1d_16_eq_slow, 16, do_idct_16, 1e-6);
test_idct1d_eq_slow_n!(test_idct1d_32_eq_slow, 32, do_idct_32, 5e-6);
test_idct1d_eq_slow_n!(test_idct1d_64_eq_slow, 64, do_idct_64, 5e-6);
test_idct1d_eq_slow_n!(test_idct1d_128_eq_slow, 128, do_idct_128, 5e-5);
test_idct1d_eq_slow_n!(test_idct1d_256_eq_slow, 256, do_idct_256, 5e-5);

macro_rules! test_idct2d_eq_slow {
    ($test_name:ident, $rows:expr, $cols:expr, $fast_idct:ident, $tol:expr) => {
        fn $test_name<D: SimdDescriptor>(d: D) {
            const N: usize = $rows;
            const M: usize = $cols;
            let slow_input = random_matrix(N, M);
            let slow_output = slow_idct2d(&slow_input);
            let mut fast_input: Vec<_> = slow_input
                .iter()
                .flat_map(|x| x.iter())
                .map(|x| *x as f32)
                .collect();
            $fast_idct(d, &mut fast_input);
            for r in 0..N {
                for c in 0..M {
                    check_close(fast_input[r * M + c] as f64, slow_output[r][c], $tol);
                }
            }
        }
        test_all_instruction_sets!($test_name);
    };
}

test_idct2d_eq_slow!(test_idct2d_2_2_eq_slow, 2, 2, idct2d_2_2, 1e-6);
test_idct2d_eq_slow!(test_idct2d_4_4_eq_slow, 4, 4, idct2d_4_4, 1e-6);
test_idct2d_eq_slow!(test_idct2d_4_8_eq_slow, 4, 8, idct2d_4_8, 1e-6);
test_idct2d_eq_slow!(test_idct2d_8_4_eq_slow, 8, 4, idct2d_8_4, 1e-6);
test_idct2d_eq_slow!(test_idct2d_8_8_eq_slow, 8, 8, idct2d_8_8, 5e-6);
test_idct2d_eq_slow!(test_idct2d_16_8_eq_slow, 16, 8, idct2d_16_8, 5e-6);
test_idct2d_eq_slow!(test_idct2d_8_16_eq_slow, 8, 16, idct2d_8_16, 5e-6);
test_idct2d_eq_slow!(test_idct2d_16_16_eq_slow, 16, 16, idct2d_16_16, 1e-5);
test_idct2d_eq_slow!(test_idct2d_32_8_eq_slow, 32, 8, idct2d_32_8, 5e-6);
test_idct2d_eq_slow!(test_idct2d_8_32_eq_slow, 8, 32, idct2d_8_32, 5e-6);
test_idct2d_eq_slow!(test_idct2d_32_16_eq_slow, 32, 16, idct2d_32_16, 1e-5);
test_idct2d_eq_slow!(test_idct2d_16_32_eq_slow, 16, 32, idct2d_16_32, 1e-5);
test_idct2d_eq_slow!(test_idct2d_32_32_eq_slow, 32, 32, idct2d_32_32, 5e-5);
test_idct2d_eq_slow!(test_idct2d_64_32_eq_slow, 64, 32, idct2d_64_32, 1e-4);
test_idct2d_eq_slow!(test_idct2d_32_64_eq_slow, 32, 64, idct2d_32_64, 1e-4);
test_idct2d_eq_slow!(test_idct2d_64_64_eq_slow, 64, 64, idct2d_64_64, 1e-4);
test_idct2d_eq_slow!(test_idct2d_128_64_eq_slow, 128, 64, idct2d_128_64, 5e-4);
test_idct2d_eq_slow!(test_idct2d_64_128_eq_slow, 64, 128, idct2d_64_128, 5e-4);
test_idct2d_eq_slow!(test_idct2d_128_128_eq_slow, 128, 128, idct2d_128_128, 5e-4);
test_idct2d_eq_slow!(test_idct2d_256_128_eq_slow, 256, 128, idct2d_256_128, 1e-3);
test_idct2d_eq_slow!(test_idct2d_128_256_eq_slow, 128, 256, idct2d_128_256, 1e-3);
test_idct2d_eq_slow!(test_idct2d_256_256_eq_slow, 256, 256, idct2d_256_256, 5e-3);

macro_rules! test_reinterpreting_dct_eq_slow {
    ($test_name:ident, $fun:ident, $rows:expr, $cols:expr, $tol:expr) => {
        fn $test_name<D: SimdDescriptor>(d: D) {
            const N: usize = $rows;
            const M: usize = $cols;
            let slow_input = random_matrix(N, M);
            let slow_output = slow_reinterpreting_dct2d(&slow_input);
            let mut fast_input: Vec<_> = slow_input
                .iter()
                .flat_map(|x| x.iter())
                .map(|x| *x as f32)
                .collect();
            let mut output = [0.0; $rows * $cols * 64];
            $fun(d, &mut fast_input, &mut output);
            let on = slow_output.len();
            let om = slow_output[0].len();
            for r in 0..on {
                for c in 0..om {
                    check_close(output[r * om * 8 + c] as f64, slow_output[r][c], $tol);
                }
            }
        }
        test_all_instruction_sets!($test_name);
    };
}

test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_1x2_eq_slow,
    reinterpreting_dct2d_1_2,
    1,
    2,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_2x1_eq_slow,
    reinterpreting_dct2d_2_1,
    2,
    1,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_2x2_eq_slow,
    reinterpreting_dct2d_2_2,
    2,
    2,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_1x4_eq_slow,
    reinterpreting_dct2d_1_4,
    1,
    4,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_4x1_eq_slow,
    reinterpreting_dct2d_4_1,
    4,
    1,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_2x4_eq_slow,
    reinterpreting_dct2d_2_4,
    2,
    4,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_4x2_eq_slow,
    reinterpreting_dct2d_4_2,
    4,
    2,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_4x4_eq_slow,
    reinterpreting_dct2d_4_4,
    4,
    4,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_8x4_eq_slow,
    reinterpreting_dct2d_8_4,
    8,
    4,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_4x8_eq_slow,
    reinterpreting_dct2d_4_8,
    4,
    8,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_8x8_eq_slow,
    reinterpreting_dct2d_8_8,
    8,
    8,
    1e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_8x16_eq_slow,
    reinterpreting_dct2d_8_16,
    8,
    16,
    5e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_16x8_eq_slow,
    reinterpreting_dct2d_16_8,
    16,
    8,
    5e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_16x16_eq_slow,
    reinterpreting_dct2d_16_16,
    16,
    16,
    5e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_32x16_eq_slow,
    reinterpreting_dct2d_32_16,
    32,
    16,
    5e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_16x32_eq_slow,
    reinterpreting_dct2d_16_32,
    16,
    32,
    5e-6
);
test_reinterpreting_dct_eq_slow!(
    test_reinterpreting_dct_32x32_eq_slow,
    reinterpreting_dct2d_32_32,
    32,
    32,
    5e-6
);
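
// Supplementary round-trip sanity check (illustrative, not tied to any fast kernel above):
// with the sqrt(2)-scaled cosine basis used by `dct1d`/`idct1d`, applying `idct1d` after
// `dct1d` should reproduce the input multiplied by the transform length N.
#[test]
fn test_dct1d_idct1d_roundtrip_scaled() {
    const N: usize = 8;
    let input = random_matrix(N, 1);
    let roundtrip = idct1d(&dct1d(&input));
    for i in 0..N {
        check_close(roundtrip[i][0], input[i][0] * N as f64, 1e-9);
    }
}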