// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use std::{collections::VecDeque, ops::Range};

use crate::{
    bit_reader::BitReader,
    entropy_coding::decode::{Histograms, SymbolReader},
    error::Result,
    frame::modular::{
        ModularChannel, Predictor, Tree,
        decode::{
            channel::ModularChannelDecoder,
            common::{make_pixel, precompute_references},
        },
        predict::{PredictionData, WeightedPredictorState, clamped_gradient},
        tree::{
            FlatTreeNode, NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode, predict_flat,
        },
    },
    headers::modular::GroupHeader,
    image::Image,
};

pub struct NoWpTree {
    flat_nodes: Vec<FlatTreeNode>,
    references: Image<i32>,
    property_buffer: Vec<i32>,
}

impl NoWpTree {
    fn new(
        nodes: Vec<TreeNode>,
        max_property_count: usize,
        channel: usize,
        stream: usize,
        xsize: usize,
    ) -> Result<Self> {
        let num_ref_props = max_property_count
            .saturating_sub(NUM_NONREF_PROPERTIES)
            .next_multiple_of(PROPERTIES_PER_PREVCHAN);
        let references = Image::<i32>::new((num_ref_props, xsize))?;
        let num_properties = NUM_NONREF_PROPERTIES + num_ref_props;
        let mut property_buffer: Vec<i32> = vec![0; num_properties];
        property_buffer[0] = channel as i32;
        property_buffer[1] = stream as i32;
        let flat_nodes = Tree::build_flat_tree(&nodes)?;
        Ok(Self {
            flat_nodes,
            references,
            property_buffer,
        })
    }
}

impl ModularChannelDecoder for NoWpTree {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = true;

    fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) {
        precompute_references(buffers, chan, y, &mut self.references);
        self.property_buffer[2..].fill(0);
    }

    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        pos: (usize, usize),
        xsize: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let prediction_result = predict_flat(
            &self.flat_nodes,
            prediction_data,
            xsize,
            None,
            pos.0,
            pos.1,
            &self.references,
            &mut self.property_buffer,
        );
        let dec = reader.read_signed(histograms, br, prediction_result.context as usize);
        make_pixel(dec, prediction_result.multiplier, prediction_result.guess)
    }
}

pub struct GeneralTree {
    no_wp_tree: NoWpTree,
    wp_state: WeightedPredictorState,
}

impl GeneralTree {
    fn new(
        nodes: Vec<TreeNode>,
        max_property_count: usize,
        header: &GroupHeader,
        channel: usize,
        stream: usize,
        xsize: usize,
    ) -> Result<Self> {
        let wp_state = WeightedPredictorState::new(&header.wp_header, xsize);
        Ok(Self {
            no_wp_tree: NoWpTree::new(nodes, max_property_count, channel, stream, xsize)?,
            wp_state,
        })
    }
}

impl ModularChannelDecoder for GeneralTree {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = true;

    fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) {
        self.no_wp_tree.init_row(buffers, chan, y);
    }

    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        pos: (usize, usize),
        xsize: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let prediction_result = predict_flat(
            &self.no_wp_tree.flat_nodes,
            prediction_data,
            xsize,
            Some(&mut self.wp_state),
            pos.0,
            pos.1,
            &self.no_wp_tree.references,
            &mut self.no_wp_tree.property_buffer,
        );
        let dec = reader.read_signed(histograms, br, prediction_result.context as usize);
        let val = make_pixel(dec, prediction_result.multiplier, prediction_result.guess);
        self.wp_state.update_errors(val, pos, xsize);
        val
    }
}
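
// The lookup-table fast paths below (`WpOnlyLookup` and `GradientLookup`) avoid walking the
// tree for every pixel: a tree that splits on a single property is flattened into a table
// mapping the property value directly to a *clustered* context id. The table covers split
// values in [LUT_MIN_SPLITVAL, LUT_MAX_SPLITVAL]; property values outside that range are
// clamped to the table edges (e.g. -1024 maps to index 0 and 1023 to the last index).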
const LUT_MAX_SPLITVAL: i32 = 1023;
const LUT_MIN_SPLITVAL: i32 = -1024;
const LUT_TABLE_SIZE: usize = (LUT_MAX_SPLITVAL - LUT_MIN_SPLITVAL + 1) as usize;
const _: () = assert!(LUT_TABLE_SIZE.is_power_of_two());

pub struct WpOnlyLookup {
    lut: [u8; LUT_TABLE_SIZE], // Lookup (wp value -> *clustered* context id)
    wp_state: WeightedPredictorState,
}

fn make_lut(tree: &[TreeNode], histograms: &Histograms) -> Option<[u8; LUT_TABLE_SIZE]> {
    struct RangeAndNode {
        range: Range<i32>,
        node: u32,
    }
    let mut stack = vec![RangeAndNode {
        range: LUT_MIN_SPLITVAL..LUT_MAX_SPLITVAL + 1,
        node: 0,
    }];
    let mut ans = [0u8; LUT_TABLE_SIZE];
    while let Some(RangeAndNode { range, node }) = stack.pop() {
        let v = tree[node as usize];
        match v {
            TreeNode::Split {
                val, left, right, ..
            } => {
                let first_left = val + 1;
                if first_left >= range.end || first_left <= range.start {
                    return None;
                }
                stack.push(RangeAndNode {
                    range: first_left..range.end,
                    node: left,
                });
                stack.push(RangeAndNode {
                    range: range.start..first_left,
                    node: right,
                });
            }
            TreeNode::Leaf {
                offset,
                multiplier,
                id,
                ..
            } => {
                if offset != 0 || multiplier != 1 {
                    return None;
                }
                let start = range.start - LUT_MIN_SPLITVAL;
                let end = range.end - LUT_MIN_SPLITVAL;
                ans[start as usize..end as usize]
                    .fill(histograms.map_context_to_cluster(id as usize) as u8);
            }
        }
    }
    Some(ans)
}

impl WpOnlyLookup {
    fn new(
        tree: &[TreeNode],
        histograms: &Histograms,
        header: &GroupHeader,
        xsize: usize,
    ) -> Option<Self> {
        let wp_state = WeightedPredictorState::new(&header.wp_header, xsize);
        let lut = make_lut(tree, histograms)?;
        Some(Self { lut, wp_state })
    }
}

impl ModularChannelDecoder for WpOnlyLookup {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = true;

    fn init_row(&mut self, _buffers: &mut [&mut ModularChannel], _chan: usize, _y: usize) {
        // nothing to do
    }

    #[inline(always)]
    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        pos: (usize, usize),
        xsize: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let (wp_pred, property) = self
            .wp_state
            .predict_and_property(pos, xsize, &prediction_data);
        let ctx = self.lut[(property as i64 - LUT_MIN_SPLITVAL as i64)
            .clamp(0, LUT_TABLE_SIZE as i64 - 1) as usize];
        let dec = reader.read_signed_clustered(histograms, br, ctx as usize);
        let val = dec.wrapping_add(wp_pred as i32);
        self.wp_state.update_errors(val, pos, xsize);
        val
    }
}

/// Fast path for trees that split only on property 9 (gradient: left + top - topleft)
/// with the Gradient predictor, offset=0, multiplier=1.
/// Maps property 9 values directly to cluster IDs via a LUT.
/// This targets libjxl effort 2 encoding.
pub struct GradientLookup {
    lut: [u8; LUT_TABLE_SIZE],
}

/// Property 9 is the "gradient property": left + top - topleft.
const GRADIENT_PROPERTY: u8 = 9;

fn make_gradient_lut(tree: &[TreeNode], histograms: &Histograms) -> Option<GradientLookup> {
    // Verify that all splits are on property 9 and all leaves use the Gradient predictor.
    for node in tree {
        match node {
            TreeNode::Split { property, .. } => {
                if *property != GRADIENT_PROPERTY {
                    return None;
                }
            }
            TreeNode::Leaf { predictor, .. } => {
                if *predictor != Predictor::Gradient {
                    return None;
                }
            }
        }
    }
    // Use the existing make_lut, which handles the offset=0, multiplier=1 checks.
    let lut = make_lut(tree, histograms)?;
    Some(GradientLookup { lut })
}
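
// `GradientLookup::decode_one` below recomputes property 9 (left + top - topleft) directly
// from the neighbouring samples, looks the clustered context up in the LUT, and predicts
// with `clamped_gradient`, so it needs neither the property buffer nor any weighted-predictor
// state. Only the current and previous rows are read, hence NEEDS_TOPTOP is false.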
impl ModularChannelDecoder for GradientLookup {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = false;

    fn init_row(&mut self, _: &mut [&mut ModularChannel], _: usize, _: usize) {}

    #[inline(always)]
    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        _: (usize, usize),
        _: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let prop9 = prediction_data
            .left
            .wrapping_add(prediction_data.top)
            .wrapping_sub(prediction_data.topleft);
        let index = (prop9 as i64 - LUT_MIN_SPLITVAL as i64).clamp(0, LUT_TABLE_SIZE as i64 - 1)
            as usize;
        let cluster = self.lut[index];
        let pred = clamped_gradient(
            prediction_data.left as i64,
            prediction_data.top as i64,
            prediction_data.topleft as i64,
        );
        let dec = reader.read_signed_clustered(histograms, br, cluster as usize);
        dec.wrapping_add(pred as i32)
    }
}

pub struct SingleGradientOnly {
    ctx: usize,
}

impl ModularChannelDecoder for SingleGradientOnly {
    const NEEDS_TOP: bool = true;
    const NEEDS_TOPTOP: bool = false;

    fn init_row(&mut self, _: &mut [&mut ModularChannel], _: usize, _: usize) {}

    #[inline(always)]
    fn decode_one(
        &mut self,
        prediction_data: PredictionData,
        _: (usize, usize),
        _: usize,
        reader: &mut SymbolReader,
        br: &mut BitReader,
        histograms: &Histograms,
    ) -> i32 {
        let pred = Predictor::Gradient.predict_one(prediction_data, 0);
        let dec = reader.read_signed(histograms, br, self.ctx);
        make_pixel(dec, 1, pred)
    }
}

#[allow(clippy::large_enum_variant)]
pub enum TreeSpecialCase {
    NoWp(NoWpTree),
    WpOnly(WpOnlyLookup),
    GradientLookup(GradientLookup),
    SingleGradientOnly(SingleGradientOnly),
    General(GeneralTree),
}

pub fn specialize_tree(
    tree: &Tree,
    channel: usize,
    stream: usize,
    xsize: usize,
    header: &GroupHeader,
) -> Result<TreeSpecialCase> {
    // TODO(veluca): consider skipping the pruning if header.uses_global_tree is true.
    let mut pruned_tree = Vec::new();
    let mut queue = VecDeque::new();
    pruned_tree.try_reserve(tree.nodes.len())?;
    queue.try_reserve(tree.nodes.len())?;
    queue.push_front(0);
    let mut uses_wp = false;
    let mut uses_non_wp = false;
    // Build a pruned tree that keeps only the nodes relevant to the current channel and stream.
    // Proceed in BFS order, so that we know that the children of a node will be adjacent.
    while let Some(v) = queue.pop_front() {
        let node = tree.nodes[v as usize];
        match node {
            TreeNode::Split {
                property,
                val,
                left,
                right,
            } if property < 2 => {
                // If the node splits on static properties, re-enqueue its correct child immediately.
                let vv = if property == 0 { channel } else { stream };
                queue.push_front(if vv as i32 > val { left } else { right });
                continue;
            }
            TreeNode::Split {
                property,
                val,
                left,
                right,
            } => {
                // Property 15 is the weighted-predictor property.
                uses_wp |= property == 15;
                uses_non_wp |= property != 15;
                let base = (queue.len() + pruned_tree.len() + 1) as u32;
                pruned_tree.push(TreeNode::Split {
                    property,
                    val,
                    left: base,
                    right: base + 1,
                });
                queue.push_back(left);
                queue.push_back(right);
            }
            TreeNode::Leaf { predictor, .. } => {
                uses_wp |= predictor == Predictor::Weighted;
                uses_non_wp |= predictor != Predictor::Weighted;
                pruned_tree.push(node);
            }
        }
    }
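
    // Pick the most specialized decoder that matches the pruned tree, from most to least
    // restrictive: a single Gradient leaf, a weighted-predictor-only LUT, a gradient LUT,
    // a tree that never uses the weighted predictor, and finally the general case.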
    if let [
        TreeNode::Leaf {
            predictor: Predictor::Gradient,
            multiplier: 1,
            offset: 0,
            id,
        },
    ] = &*pruned_tree
    {
        return Ok(TreeSpecialCase::SingleGradientOnly(SingleGradientOnly {
            ctx: *id as usize,
        }));
    }
    if !uses_non_wp
        && let Some(wp) = WpOnlyLookup::new(&pruned_tree, &tree.histograms, header, xsize)
    {
        return Ok(TreeSpecialCase::WpOnly(wp));
    }
    // Try gradient LUT for non-WP trees (targets effort 2 encoding)
    if !uses_wp {
        if let Some(gl) = make_gradient_lut(&pruned_tree, &tree.histograms) {
            return Ok(TreeSpecialCase::GradientLookup(gl));
        }
        return Ok(TreeSpecialCase::NoWp(NoWpTree::new(
            pruned_tree,
            tree.max_property_count(),
            channel,
            stream,
            xsize,
        )?));
    }
    Ok(TreeSpecialCase::General(GeneralTree::new(
        pruned_tree,
        tree.max_property_count(),
        header,
        channel,
        stream,
        xsize,
    )?))
}
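
// A minimal, self-contained sketch of the splitval-to-index mapping used by the LUT fast
// paths above. The `lut_index` helper and this test module are illustrative additions, not
// part of the decoder API; they only exercise the clamping arithmetic on the constants
// defined in this file.
#[cfg(test)]
mod lut_index_sketch {
    use super::{LUT_MIN_SPLITVAL, LUT_TABLE_SIZE};

    // Mirrors the index computation in `decode_one`: shift by LUT_MIN_SPLITVAL, then clamp
    // out-of-range property values to the edges of the table.
    fn lut_index(property: i64) -> usize {
        (property - LUT_MIN_SPLITVAL as i64).clamp(0, LUT_TABLE_SIZE as i64 - 1) as usize
    }

    #[test]
    fn in_range_values_map_without_clamping() {
        assert_eq!(lut_index(LUT_MIN_SPLITVAL as i64), 0);
        assert_eq!(lut_index(0), (-LUT_MIN_SPLITVAL) as usize);
        assert_eq!(
            lut_index(LUT_MIN_SPLITVAL as i64 + LUT_TABLE_SIZE as i64 - 1),
            LUT_TABLE_SIZE - 1
        );
    }

    #[test]
    fn out_of_range_values_clamp_to_edges() {
        assert_eq!(lut_index(LUT_MIN_SPLITVAL as i64 - 1000), 0);
        assert_eq!(lut_index(1 << 30), LUT_TABLE_SIZE - 1);
    }
}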