// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use num_traits::Float;

use jxl_transforms::{transform::*, transform_map::*};

use crate::{
    BLOCK_DIM, BLOCK_SIZE, GROUP_DIM,
    bit_reader::BitReader,
    entropy_coding::decode::SymbolReader,
    error::{Error, Result},
    frame::{
        HfGlobalState, HfMetadata, LfGlobalState, block_context_map::*,
        color_correlation_map::COLOR_TILE_DIM_IN_BLOCKS, quant_weights::DequantMatrices,
    },
    headers::frame_header::FrameHeader,
    image::{Image, ImageRect, Rect},
    util::{CeilLog2, ShiftRightCeil, tracing_wrappers::*},
};
use jxl_simd::{F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask, simd_function};

/// Number of elements allocated for [`VarDctBuffers::scratch`].
const LF_BUFFER_SIZE: usize = 32 * 32;

/// Reusable buffers for VarDCT group decoding to avoid repeated allocations.
pub struct VarDctBuffers {
    /// Staging area that `dequant_and_transform_to_pixels` fills with the LF
    /// values of the current block before calling `transform_to_pixels`.
    pub scratch: Vec,
    /// One buffer per channel. `dequant_lane` stores dequantized coefficients
    /// here; after `transform_to_pixels` the same buffer is read back row by
    /// row and copied into the output image.
    pub transform_buffer: [Vec; 3],
    /// Coefficient storage for single-pass decoding (when hf_coefficients is None)
    pub coeffs_storage: Vec,
}

impl VarDctBuffers {
    /// Allocates zero-filled buffers: `scratch` with `LF_BUFFER_SIZE` elements,
    /// three transform buffers of `MAX_COEFF_AREA` elements each, and
    /// `coeffs_storage` with `3 * GROUP_DIM * GROUP_DIM` elements.
    pub fn new() -> Self {
        Self {
            scratch: vec![0.0; LF_BUFFER_SIZE],
            transform_buffer: [
                vec![0.0; MAX_COEFF_AREA],
                vec![0.0; MAX_COEFF_AREA],
                vec![0.0; MAX_COEFF_AREA],
            ],
            coeffs_storage: vec![0; 3 * GROUP_DIM * GROUP_DIM],
        }
    }

    /// Reset buffers to zero for reuse.
    ///
    /// All three buffers are overwritten in full; their lengths are unchanged.
    pub fn reset(&mut self) {
        self.scratch.fill(0.0);
        for buf in &mut self.transform_buffer {
            buf.fill(0.0);
        }
        self.coeffs_storage.fill(0);
    }
}

/// `Default` simply forwards to [`VarDctBuffers::new`].
impl Default for VarDctBuffers {
    fn default() -> Self {
        Self::new()
    }
}

/// Predicts the number of non-zero coefficients of the block at `(bx, by)`
/// from the counts already stored in `nzeros_map`.
///
/// * Top-left block (`bx == 0 && by == 0`): the constant `32`.
/// * First column: the value directly above, read from column 0.
/// * First row: the value directly to the left.
/// * Otherwise: the sum of the above and left values divided by two,
///   rounded up.
#[inline]
fn predict_num_nonzeros(nzeros_map: &Image, bx: usize, by: usize) -> usize {
    if bx == 0 {
        if by == 0 {
            32
        } else {
            nzeros_map.row(by - 1)[0] as usize
        }
    } else if by == 0 {
        nzeros_map.row(by)[bx - 1] as usize
    } else {
        (nzeros_map.row(by - 1)[bx] + nzeros_map.row(by)[bx - 1]).div_ceil(2) as usize
    }
}

/// Converts a vector of quantized integer coefficients to floats and applies
/// the quantization bias for channel `c`.
///
/// Per lane, with `q` the quantized value:
/// * where `2 > |q|`, the result is `q * biases[c]`;
/// * elsewhere, the result is `q - biases[3] / q`.
///
/// `biases[0..3]` are therefore the per-channel factors and `biases[3]` is
/// the shared term used for larger magnitudes.
#[inline(always)]
fn adjust_quant_bias(
    d: D,
    c: usize,
    quant_i: D::I32Vec,
    biases: &[f32; 4],
) -> D::F32Vec {
    let quant = quant_i.as_f32();
    // Computed for every lane; only selected where the mask below is false.
    let adjusted = quant - D::F32Vec::splat(d, biases[3]) / quant;
    D::I32Vec::splat(d, 2)
        .gt(quant_i.abs())
        .if_then_else_f32(quant_i.as_f32() * D::F32Vec::splat(d, biases[c]), adjusted)
}

/// Dequantizes one SIMD vector of coefficients, starting at index `k`, for all
/// three channels and writes the results into `block`.
///
/// `dequant_matrices` is read at three offsets: `k` for the first channel,
/// `size + k` for the second and `2 * size + k` for the third. Each loaded
/// vector is multiplied by the matching `scaled_dequant_*` factor and by the
/// bias-adjusted quantized value from `qblock`.
///
/// The second channel's result (`dequant_y`) is then folded into the other two
/// through `mul_add` with `x_cc_mul` and `b_cc_mul` respectively
/// (presumably `cc_mul * y + value` — confirm against `F32SimdVec::mul_add`).
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn dequant_lane(
    d: D,
    scaled_dequant_x: f32,
    scaled_dequant_y: f32,
    scaled_dequant_b: f32,
    dequant_matrices: &[f32],
    size: usize,
    k: usize,
    x_cc_mul: f32,
    b_cc_mul: f32,
    biases: &[f32; 4],
    qblock: &[&[i32]; 3],
    block: &mut [Vec; 3],
) {
    // Per-channel multipliers: matrix entry times the scaled dequant factor.
    let x_mul = D::F32Vec::load(d, &dequant_matrices[k..]) * D::F32Vec::splat(d, scaled_dequant_x);
    let y_mul =
        D::F32Vec::load(d, &dequant_matrices[size + k..]) * D::F32Vec::splat(d, scaled_dequant_y);
    let b_mul = D::F32Vec::load(d, &dequant_matrices[2 * size + k..])
        * D::F32Vec::splat(d, scaled_dequant_b);

    let quantized_x = D::I32Vec::load(d, &qblock[0][k..]);
    let quantized_y = D::I32Vec::load(d, &qblock[1][k..]);
    let quantized_b = D::I32Vec::load(d, &qblock[2][k..]);

    // Values before the cross-channel term is added (x and b) and the final
    // value for the middle channel (y).
    let dequant_x_cc = adjust_quant_bias(d, 0, quantized_x, biases) * x_mul;
    let dequant_y = adjust_quant_bias(d, 1, quantized_y, biases) * y_mul;
    let dequant_b_cc = adjust_quant_bias(d, 2, quantized_b, biases) * b_mul;

    let dequant_x = D::F32Vec::splat(d, x_cc_mul).mul_add(dequant_y, dequant_x_cc);
    let dequant_b = D::F32Vec::splat(d, b_cc_mul).mul_add(dequant_y, dequant_b_cc);
    dequant_x.store(&mut block[0][k..]);
    dequant_y.store(&mut block[1][k..]);
    dequant_b.store(&mut block[2][k..]);
}

/// Dequantizes every coefficient of one transform block for all three
/// channels by calling [`dequant_lane`] once per SIMD vector.
///
/// The base factor is `inv_global_scale / quant`; the first and third channel
/// factors additionally include `x_dm_multiplier` and `b_dm_multiplier`.
/// The dequantization table is obtained from
/// `dequant_matrices.matrix(hf_type, 0)` and `size` is forwarded as the
/// per-channel stride inside that table.
///
/// # Panics
///
/// Panics if `BLOCK_SIZE` is not a multiple of the SIMD vector length.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn dequant_block(
    d: D,
    hf_type: HfTransformType,
    inv_global_scale: f32,
    quant: u32,
    x_dm_multiplier: f32,
    b_dm_multiplier: f32,
    x_cc_mul: f32,
    b_cc_mul: f32,
    size: usize,
    dequant_matrices: &DequantMatrices,
    covered_blocks: usize,
    biases: &[f32; 4],
    qblock: &[&[i32]; 3],
    block: &mut [Vec; 3],
) {
    let scaled_dequant_y = inv_global_scale / (quant as f32);

    let scaled_dequant_x = scaled_dequant_y * x_dm_multiplier;
    let scaled_dequant_b = scaled_dequant_y * b_dm_multiplier;

    let matrices = dequant_matrices.matrix(hf_type, 0);

    // Guarantees the stepped loop below covers the range exactly, with no
    // partial vector at the end.
    assert!(BLOCK_SIZE.is_multiple_of(D::F32Vec::LEN));

    for k in (0..covered_blocks * BLOCK_SIZE).step_by(D::F32Vec::LEN) {
        dequant_lane(
            d,
            scaled_dequant_x,
            scaled_dequant_y,
            scaled_dequant_b,
            matrices,
            size,
            k,
            x_cc_mul,
            b_cc_mul,
            biases,
            qblock,
            block,
        );
    }
}

/// Dequantizes one transform block and writes the reconstructed samples into
/// `pixels`.
///
/// Steps:
/// 1. [`dequant_block`] fills `transform_buffer` for all three channels.
/// 2. For each channel, visited in the order `[1, 0, 2]`, the channel is
///    skipped unless `(bx, by)` equals the block's subsampled coordinates
///    `(sbx[c], sby[c])` shifted back up by `hshift[c]` / `vshift[c]`.
/// 3. The block's LF values (`covered_blocks_x` wide, `covered_blocks_y` rows)
///    are copied from `lf_rects[c]` into the start of `scratch`.
/// 4. `transform_to_pixels` is run on `scratch` and `transform_buffer[c]`.
/// 5. The rows of `transform_buffer[c]` are copied into `pixels[c]` at
///    `block_rect` with its origin shifted right by the channel shifts. The
///    rectangle size is taken from `block_rect` unchanged.
///
/// As written, this function always returns `Ok(())`.
///
/// # Panics
///
/// Panics if `lf_rects` is `None` and at least one channel is processed.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn dequant_and_transform_to_pixels(
    d: D,
    quant_biases: &[f32; 4],
    x_dm_multiplier: f32,
    b_dm_multiplier: f32,
    pixels: &mut [Image; 3],
    scratch: &mut [f32],
    inv_global_scale: f32,
    transform_buffer: &mut [Vec; 3],
    hshift: [usize; 3],
    vshift: [usize; 3],
    by: usize,
    sby: [usize; 3],
    bx: usize,
    sbx: [usize; 3],
    x_cc_mul: f32,
    b_cc_mul: f32,
    raw_quant: u32,
    lf_rects: &Option<[ImageRect; 3]>,
    transform_type: HfTransformType,
    block_rect: Rect,
    num_blocks: usize,
    num_coeffs: usize,
    qblock: &[&[i32]; 3],
    dequant_matrices: &DequantMatrices,
) -> Result<(), Error> {
    dequant_block::(
        d,
        transform_type,
        inv_global_scale,
        raw_quant,
        x_dm_multiplier,
        b_dm_multiplier,
        x_cc_mul,
        b_cc_mul,
        num_coeffs,
        dequant_matrices,
        num_blocks,
        quant_biases,
        qblock,
        transform_buffer,
    );
    for c in [1, 0, 2] {
        // Skip channels whose subsampled grid does not start a block here.
        if (sbx[c] << hshift[c]) != bx || (sby[c] << vshift[c] != by) {
            continue;
        }
        let lf = &mut scratch[..];
        {
            let xs = covered_blocks_x(transform_type) as usize;
            let ys = covered_blocks_y(transform_type) as usize;
            let rect = lf_rects.as_ref().unwrap()[c];
            // Pack the LF values tightly: `ys` rows of `xs` values each. The
            // inner `lf` shadows the outer slice for the duration of the loop.
            for (y, lf) in lf.chunks_exact_mut(xs).enumerate().take(ys) {
                lf.copy_from_slice(&rect.row(y)[0..xs]);
            }
        }
        transform_to_pixels(transform_type, lf, &mut transform_buffer[c]);
        let downsampled_rect = Rect {
            origin: (
                block_rect.origin.0 >> hshift[c],
                block_rect.origin.1 >> vshift[c],
            ),
            size: block_rect.size,
        };
        let mut output_rect = pixels[c].get_rect_mut(downsampled_rect);
        // `transform_buffer[c]` is read as rows of `downsampled_rect.size.0`
        // consecutive samples.
        for i in 0..downsampled_rect.size.1 {
            let offset = i * downsampled_rect.size.0;
            output_rect
                .row(i)
                .copy_from_slice(&transform_buffer[c][offset..offset + downsampled_rect.size.0]);
        }
    }
    Ok(())
}

// Declares `dequant_and_transform_to_pixels_fwd`, which forwards every
// argument to `dequant_and_transform_to_pixels` together with the SIMD
// descriptor `d`. `decode_vardct_group` calls
// `dequant_and_transform_to_pixels_dispatch` without passing a descriptor, so
// the macro presumably generates that entry point and selects `d` itself —
// confirm against the `simd_function!` definition in `jxl_simd`.
simd_function!(
    dequant_and_transform_to_pixels_dispatch,
    d: D,
    #[allow(clippy::too_many_arguments)]
    pub fn dequant_and_transform_to_pixels_fwd(
        quant_biases: &[f32; 4],
        x_dm_multiplier: f32,
        b_dm_multiplier: f32,
        pixels: &mut [Image; 3],
        scratch: &mut [f32],
        inv_global_scale: f32,
        transform_buffer: &mut [Vec; 3],
        hshift: [usize; 3],
        vshift: [usize; 3],
        by: usize,
        sby: [usize; 3],
        bx: usize,
        sbx: [usize; 3],
        x_cc_mul: f32,
        b_cc_mul: f32,
        raw_quant: u32,
        lf_rects: &Option<[ImageRect; 3]>,
        transform_type: HfTransformType,
        block_rect: Rect,
        num_blocks: usize,
        num_coeffs: usize,
        qblock: &[&[i32]; 3],
        dequant_matrices: &DequantMatrices,
    ) -> Result<(), Error> {
        dequant_and_transform_to_pixels(
            d,
            quant_biases,
            x_dm_multiplier,
            b_dm_multiplier,
            pixels,
            scratch,
            inv_global_scale,
            transform_buffer,
            hshift,
            vshift,
            by,
            sby,
            bx,
            sbx,
            x_cc_mul,
            b_cc_mul,
            raw_quant,
            lf_rects,
            transform_type,
            block_rect,
            num_blocks,
            num_coeffs,
            qblock,
            dequant_matrices,
        )
    }
);

/// Decodes the HF coefficients of one group for one pass and reconstructs the
/// group's samples into `pixels`.
///
/// Outline:
/// 1. Reads the histogram index from `br` and creates a [`SymbolReader`] for
///    `hf_global.passes[pass].histograms`.
/// 2. Zeroes `buffers` and picks coefficient storage: row `group` of
///    `hf_global.hf_coefficients` when present, otherwise three
///    `GROUP_DIM * GROUP_DIM` slices of `buffers.coeffs_storage`.
/// 3. Visits every block of the group in row-major order. Blocks that are not
///    marked as the first block of a transform (raw transform id `< 128`) are
///    skipped.
/// 4. For each remaining block and each channel (order `[1, 0, 2]`) aligned to
///    the block, reads the non-zero count, records it in a per-channel map
///    used for later predictions, and reads the coefficients, adding each
///    (shifted left by the pass shift) into the storage at the position given
///    by the pass's coefficient order.
/// 5. Calls `dequant_and_transform_to_pixels_dispatch` for the block and
///    advances the coefficient offset by the block's coefficient count.
/// 6. After all blocks, checks the final state of the symbol reader.
///
/// # Errors
///
/// * [`Error::InvalidVarDCTTransform`] if a transform id does not map to an
///   `HfTransformType`.
/// * [`Error::InvalidNumNonZeros`] if `nonzeros + num_blocks > num_coeffs`.
/// * [`Error::EndOfBlockResidualNonZeros`] if non-zeros remain after the last
///   coefficient position of a block.
/// * Any error propagated from `br.read`, `SymbolReader::new`, `Image::new`,
///   the dispatch call, or `reader.check_final_state`.
///
/// # Panics
///
/// Panics if `lf_global.color_correlation_params`, `lf_global.quant_params`
/// or `lf_global.block_context_map` is `None`.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::type_complexity)]
pub fn decode_vardct_group(
    group: usize,
    pass: usize,
    frame_header: &FrameHeader,
    lf_global: &mut LfGlobalState,
    hf_global: &mut HfGlobalState,
    hf_meta: &HfMetadata,
    lf_image: &Option<[Image; 3]>,
    quant_lf: &Image,
    quant_biases: &[f32; 4],
    pixels: &mut [Image; 3],
    br: &mut BitReader,
    buffers: &mut VarDctBuffers,
) -> Result<(), Error> {
    // Both multipliers are 0.8 raised to (scale - 2).
    let x_dm_multiplier = (1.0 / (1.25)).powf(frame_header.x_qm_scale as f32 - 2.0);
    let b_dm_multiplier = (1.0 / (1.25)).powf(frame_header.b_qm_scale as f32 - 2.0);

    let num_histo_bits = hf_global.num_histograms.ceil_log2();
    let histogram_index: usize = br.read(num_histo_bits as usize)? as usize;
    debug!(?histogram_index);
    let mut reader = SymbolReader::new(&hf_global.passes[pass].histograms, br, None)?;
    let block_group_rect = frame_header.block_group_rect(group);
    debug!(?block_group_rect);
    // Reset and use pooled buffers
    buffers.reset();
    let scratch = &mut buffers.scratch;
    let color_correlation_params = lf_global.color_correlation_params.as_ref().unwrap();
    // The colour-correlation maps are indexed in tiles of
    // COLOR_TILE_DIM_IN_BLOCKS blocks, so convert the block rectangle.
    let cmap_rect = Rect {
        origin: (
            block_group_rect.origin.0 / COLOR_TILE_DIM_IN_BLOCKS,
            block_group_rect.origin.1 / COLOR_TILE_DIM_IN_BLOCKS,
        ),
        size: (
            block_group_rect.size.0.div_ceil(COLOR_TILE_DIM_IN_BLOCKS),
            block_group_rect.size.1.div_ceil(COLOR_TILE_DIM_IN_BLOCKS),
        ),
    };
    let quant_params = lf_global.quant_params.as_ref().unwrap();
    let inv_global_scale = quant_params.inv_global_scale();
    let ytox_map = hf_meta.ytox_map.get_rect(cmap_rect);
    let ytob_map = hf_meta.ytob_map.get_rect(cmap_rect);
    let transform_map = hf_meta.transform_map.get_rect(block_group_rect);
    let raw_quant_map = hf_meta.raw_quant_map.get_rect(block_group_rect);
    // Per-channel maps of non-zero counts, sized to the group's block
    // dimensions after applying each channel's shifts.
    let mut num_nzeros: [Image; 3] = [
        Image::new((
            block_group_rect.size.0 >> frame_header.hshift(0),
            block_group_rect.size.1 >> frame_header.vshift(0),
        ))?,
        Image::new((
            block_group_rect.size.0 >> frame_header.hshift(1),
            block_group_rect.size.1 >> frame_header.vshift(1),
        ))?,
        Image::new((
            block_group_rect.size.0 >> frame_header.hshift(2),
            block_group_rect.size.1 >> frame_header.vshift(2),
        ))?,
    ];
    let quant_lf_rect = quant_lf.get_rect(block_group_rect);
    let block_context_map = lf_global.block_context_map.as_mut().unwrap();
    // Every entropy-coding context used below is offset by the selected
    // histogram's share of the AC contexts.
    let context_offset = histogram_index * block_context_map.num_ac_contexts();
    let coeffs = match hf_global.hf_coefficients.as_mut() {
        Some(hf_coefficients) => [
            hf_coefficients.0.row_mut(group),
            hf_coefficients.1.row_mut(group),
            hf_coefficients.2.row_mut(group),
        ],
        None => {
            // Use pooled buffer (already reset to zero in buffers.reset() above)
            let (coeffs_x, coeffs_y_b) = buffers.coeffs_storage.split_at_mut(GROUP_DIM * GROUP_DIM);
            let (coeffs_y, coeffs_b) = coeffs_y_b.split_at_mut(GROUP_DIM * GROUP_DIM);
            [coeffs_x, coeffs_y, coeffs_b]
        }
    };
    // Passes without an entry in the shift table use a shift of 0.
    let shift_for_pass = if pass < frame_header.passes.shift.len() {
        frame_header.passes.shift[pass]
    } else {
        0
    };
    // Start of the current block's coefficients within each channel's storage.
    let mut coeffs_offset = 0;
    let transform_buffer = &mut buffers.transform_buffer;

    let hshift = [
        frame_header.hshift(0),
        frame_header.hshift(1),
        frame_header.hshift(2),
    ];
    let vshift = [
        frame_header.vshift(0),
        frame_header.vshift(1),
        frame_header.vshift(2),
    ];
    // Views of the LF planes restricted to this group, in per-channel
    // (shifted) block coordinates.
    let lf = match lf_image.as_ref() {
        None => None,
        Some(lf_planes) => {
            let r: [Rect; 3] = core::array::from_fn(|i| Rect {
                origin: (
                    block_group_rect.origin.0 >> hshift[i],
                    block_group_rect.origin.1 >> vshift[i],
                ),
                size: (
                    block_group_rect.size.0 >> hshift[i],
                    block_group_rect.size.1 >> vshift[i],
                ),
            });

            let [lf_x, lf_y, lf_b] = lf_planes.each_ref();
            Some([
                lf_x.get_rect(r[0]),
                lf_y.get_rect(r[1]),
                lf_b.get_rect(r[2]),
            ])
        }
    };
    for by in 0..block_group_rect.size.1 {
        let sby = [by >> vshift[0], by >> vshift[1], by >> vshift[2]];
        let ty = by / COLOR_TILE_DIM_IN_BLOCKS;

        let row_cmap_x = ytox_map.row(ty);
        let row_cmap_b = ytob_map.row(ty);

        for bx in 0..block_group_rect.size.0 {
            let sbx = [bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]];
            let tx = bx / COLOR_TILE_DIM_IN_BLOCKS;
            let x_cc_mul = color_correlation_params.y_to_x(row_cmap_x[tx] as i32);
            let b_cc_mul = color_correlation_params.y_to_b(row_cmap_b[tx] as i32);
            let raw_quant = raw_quant_map.row(by)[bx] as u32;
            let quant_lf = quant_lf_rect.row(by)[bx] as usize;
            // The low 7 bits carry the transform id; values >= 128 mark the
            // first block of a transform. Other blocks are skipped.
            let raw_transform_id = transform_map.row(by)[bx];
            let transform_id = raw_transform_id & 127;
            let is_first_block = raw_transform_id >= 128;
            if !is_first_block {
                continue;
            }
            // LF views starting at this block and extending to the far edge
            // of each channel's group view.
            let lf_rects = match lf.as_ref() {
                None => None,
                Some(lf) => {
                    let [lf_x, lf_y, lf_b] = lf.each_ref();
                    Some([
                        lf_x.rect(Rect {
                            origin: (sbx[0], sby[0]),
                            size: (lf_x.size().0 - sbx[0], lf_x.size().1 - sby[0]),
                        }),
                        lf_y.rect(Rect {
                            origin: (sbx[1], sby[1]),
                            size: (lf_y.size().0 - sbx[1], lf_y.size().1 - sby[1]),
                        }),
                        lf_b.rect(Rect {
                            origin: (sbx[2], sby[2]),
                            size: (lf_b.size().0 - sbx[2], lf_b.size().1 - sby[2]),
                        }),
                    ])
                }
            };

            let transform_type = HfTransformType::from_usize(transform_id as usize)
                .ok_or(Error::InvalidVarDCTTransform(transform_id as usize))?;
            let cx = covered_blocks_x(transform_type) as usize;
            let cy = covered_blocks_y(transform_type) as usize;
            let shape_id = block_shape_id(transform_type) as usize;
            let block_size = (cx * BLOCK_DIM, cy * BLOCK_DIM);
            let block_rect = Rect {
                origin: (bx * BLOCK_DIM, by * BLOCK_DIM),
                size: block_size,
            };
            let num_blocks = cx * cy;
            let num_coeffs = num_blocks * BLOCK_SIZE;
            let log_num_blocks = num_blocks.ilog2() as usize;
            let pass_info = &hf_global.passes[pass];
            for c in [1, 0, 2] {
                // Skip channels whose subsampled grid does not start a block
                // at (bx, by).
                if (sbx[c] << hshift[c]) != bx || (sby[c] << vshift[c] != by) {
                    continue;
                }
                trace!(
                    "Decoding block ({},{}) channel {} with {}x{} block transform {} (shape id {})",
                    sbx[c], sby[c], c, cx, cy, transform_id, shape_id
                );
                let predicted_nzeros = predict_num_nonzeros(&num_nzeros[c], sbx[c], sby[c]);
                let block_context =
                    block_context_map.block_context(quant_lf, raw_quant, shape_id, c);
                let nonzero_context = block_context_map
                    .nonzero_context(predicted_nzeros, block_context)
                    + context_offset;
                let mut nonzeros =
                    reader.read_unsigned(&pass_info.histograms, br, nonzero_context) as usize;
                trace!(
                    "block ({},{},{c}) predicted_nzeros: {predicted_nzeros} \ nzero_ctx: {nonzero_context} (offset: {context_offset}) \ nzeros: {nonzeros}",
                    sbx[c], sby[c]
                );
                // The loop below only visits positions num_blocks..num_coeffs,
                // so a larger count could never be satisfied.
                if nonzeros + num_blocks > num_coeffs {
                    return Err(Error::InvalidNumNonZeros(nonzeros, num_blocks));
                }
                // Store the count, divided by the number of covered blocks
                // (via `shrc`), in every cell the transform covers so later
                // predictions can read it.
                for iy in 0..cy {
                    let nzrow = num_nzeros[c].row_mut(sby[c] + iy);
                    for ix in 0..cx {
                        nzrow[sbx[c] + ix] = nonzeros.shrc(log_num_blocks) as u32;
                    }
                }
                let histo_offset =
                    block_context_map.zero_density_context_offset(block_context) + context_offset;
                // `prev` is 1 when the previously read coefficient was
                // non-zero; its initial value depends on how dense the block is.
                let mut prev = if nonzeros > num_coeffs / 16 { 0 } else { 1 };
                let permutation = &pass_info.coeff_orders[shape_id * 3 + c];
                let current_coeffs = &mut coeffs[c][coeffs_offset..coeffs_offset + num_coeffs];
                for k in num_blocks..num_coeffs {
                    // Stop reading as soon as all announced non-zeros are seen.
                    if nonzeros == 0 {
                        break;
                    }
                    let ctx =
                        histo_offset + zero_density_context(nonzeros, k, log_num_blocks, prev);
                    let coeff =
                        reader.read_signed(&pass_info.histograms, br, ctx) << shift_for_pass;
                    prev = if coeff != 0 { 1 } else { 0 };
                    nonzeros -= prev;
                    // `k` is the position in coded order; the permutation
                    // gives the index in the stored block. Values are added to
                    // whatever the storage already holds.
                    let coeff_index = permutation[k] as usize;
                    current_coeffs[coeff_index] += coeff;
                }
                if nonzeros != 0 {
                    return Err(Error::EndOfBlockResidualNonZeros(nonzeros));
                }
            }
            // Read-only views of each channel's coefficients from this block
            // onwards.
            let qblock = [
                &coeffs[0][coeffs_offset..],
                &coeffs[1][coeffs_offset..],
                &coeffs[2][coeffs_offset..],
            ];
            let dequant_matrices = &hf_global.dequant_matrices;
            dequant_and_transform_to_pixels_dispatch(
                quant_biases,
                x_dm_multiplier,
                b_dm_multiplier,
                pixels,
                scratch,
                inv_global_scale,
                transform_buffer,
                hshift,
                vshift,
                by,
                sby,
                bx,
                sbx,
                x_cc_mul,
                b_cc_mul,
                raw_quant,
                &lf_rects,
                transform_type,
                block_rect,
                num_blocks,
                num_coeffs,
                &qblock,
                dequant_matrices,
            )?;
            coeffs_offset += num_coeffs;
        }
    }
    reader.check_final_state(&hf_global.passes[pass].histograms, br)?;
    Ok(())
}