// Copyright (c) the JPEG XL Project Authors. All rights reserved. // // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. use std::sync::Arc; use super::render::pipeline; use super::{ block_context_map::BlockContextMap, coeff_order::decode_coeff_orders, color_correlation_map::ColorCorrelationParams, group::{VarDctBuffers, decode_vardct_group}, modular::{FullModularImage, ModularStreamId, Tree, decode_hf_metadata, decode_vardct_lf}, quant_weights::DequantMatrices, quantizer::{LfQuantFactors, QuantizerParams}, }; use crate::error::Error; #[cfg(test)] use crate::render::SimpleRenderPipeline; use crate::render::buffer_splitter::BufferSplitter; use crate::{ GROUP_DIM, bit_reader::BitReader, entropy_coding::decode::Histograms, error::Result, features::{noise::Noise, patches::PatchesDictionary, spline::Splines}, frame::{ DecoderState, Frame, HfGlobalState, HfMetadata, LfGlobalState, PassState, coeff_order, }, headers::{ color_encoding::ColorSpace, frame_header::{Encoding, FrameHeader}, toc::Toc, }, image::Image, render::RenderPipeline, util::{CeilLog2, Xorshift128Plus, tracing_wrappers::*}, }; use jxl_transforms::transform_map::*; impl Frame { pub fn from_header_and_toc( frame_header: FrameHeader, toc: Toc, mut decoder_state: DecoderState, ) -> Result { if frame_header.is_visible() { decoder_state.visible_frame_index += 1; decoder_state.nonvisible_frame_index = 0; } else { decoder_state.nonvisible_frame_index += 1; } let image_metadata = &decoder_state.file_header.image_metadata; let is_gray = !frame_header.do_ycbcr && !image_metadata.xyb_encoded && image_metadata.color_encoding.color_space == ColorSpace::Gray; let color_channels = if is_gray { 1 } else { 3 }; let size_blocks = frame_header.size_blocks(); let lf_image = if frame_header.encoding == Encoding::VarDCT { if frame_header.has_lf_frame() { decoder_state.lf_frames[frame_header.lf_level as usize] .as_ref() .map(|[a, b, c]| { Ok::<_, Error>([a.try_clone()?, b.try_clone()?, c.try_clone()?]) }) .transpose()? } else { Some([ Image::new(size_blocks)?, Image::new(size_blocks)?, Image::new(size_blocks)?, ]) } } else { None }; let quant_lf = Image::new(size_blocks)?; let size_color_tiles = (size_blocks.0.div_ceil(8), size_blocks.1.div_ceil(8)); let hf_meta = if frame_header.encoding == Encoding::VarDCT { Some(HfMetadata { ytox_map: Image::new(size_color_tiles)?, ytob_map: Image::new(size_color_tiles)?, raw_quant_map: Image::new(size_blocks)?, transform_map: Image::new_with_value( size_blocks, HfTransformType::INVALID_TRANSFORM, )?, epf_map: Image::new(size_blocks)?, used_hf_types: 0, }) } else { None }; let reference_frame_data = if frame_header.can_be_referenced { let image_size = &decoder_state.file_header.size; let image_size = (image_size.xsize() as usize, image_size.ysize() as usize); let sz = if frame_header.save_before_ct { frame_header.size_upsampled() } else { image_size }; let num_ref_channels = 3 + image_metadata.extra_channel_info.len(); Some( (0..num_ref_channels) .map(|_| Image::new(sz)) .collect::>>()?, ) } else { None }; let lf_frame_data = if frame_header.lf_level != 0 { Some( (0..3) .map(|_| Image::new(frame_header.size_upsampled())) .collect::, _>>()? .try_into() .unwrap(), ) } else { None }; Ok(Self { #[cfg(test)] use_simple_pipeline: decoder_state.use_simple_pipeline, header: frame_header, color_channels, toc, lf_global: None, hf_global: None, lf_image, quant_lf, hf_meta, decoder_state, render_pipeline: None, reference_frame_data, lf_frame_data, lf_global_was_rendered: false, vardct_buffers: None, }) } /// Given a bit reader pointing at the end of the TOC, returns a vector of `BitReader`s, each /// of which reads a specific section. pub fn sections<'a>(&self, br: &'a mut BitReader) -> Result>> { debug!(toc = ?self.toc); let ret = self .toc .entries .iter() .scan(br, |br, count| Some(br.split_at(*count as usize))) .collect::>>()?; if !self.toc.permuted { return Ok(ret); } let mut inv_perm = vec![0; ret.len()]; for (i, pos) in self.toc.permutation.iter().enumerate() { inv_perm[*pos as usize] = i; } let mut shuffled_ret = ret.clone(); for (br, pos) in ret.into_iter().zip(inv_perm.into_iter()) { shuffled_ret[pos] = br; } Ok(shuffled_ret) } #[instrument(level = "debug", skip_all)] pub fn decode_lf_global(&mut self, br: &mut BitReader) -> Result<()> { debug!(section_size = br.total_bits_available()); assert!(self.lf_global.is_none()); trace!(pos = br.total_bits_read()); let patches = if self.header.has_patches() { info!("decoding patches"); Some(PatchesDictionary::read( br, self.header.size_padded().0, self.header.size_padded().1, self.decoder_state.extra_channel_info().len(), &self.decoder_state.reference_frames[..], )?) } else { None }; let splines = if self.header.has_splines() { info!("decoding splines"); Some(Splines::read(br, self.header.width * self.header.height)?) } else { None }; let noise = if self.header.has_noise() { info!("decoding noise"); Some(Noise::read(br)?) } else { None }; let lf_quant = LfQuantFactors::new(br)?; debug!(?lf_quant); let quant_params = if self.header.encoding == Encoding::VarDCT { info!("decoding VarDCT quantizer params"); Some(QuantizerParams::read(br)?) } else { None }; debug!(?quant_params); let block_context_map = if self.header.encoding == Encoding::VarDCT { info!("decoding block context map"); Some(BlockContextMap::read(br)?) } else { None }; debug!(?block_context_map); let color_correlation_params = if self.header.encoding == Encoding::VarDCT { info!("decoding color correlation params"); Some(ColorCorrelationParams::read(br)?) } else { None }; debug!(?color_correlation_params); let tree = if br.read(1)? == 1 { let size_limit = (1024 + self.header.width as usize * self.header.height as usize * (self.color_channels + self.decoder_state.extra_channel_info().len()) / 16) .min(1 << 22); Some(Tree::read(br, size_limit)?) } else { None }; let modular_global = FullModularImage::read( &self.header, &self.decoder_state.file_header.image_metadata, self.modular_color_channels(), &tree, br, )?; self.lf_global = Some(LfGlobalState { patches: patches.map(Arc::new), splines, noise, lf_quant, quant_params, block_context_map, color_correlation_params, tree, modular_global, }); Ok(()) } #[instrument(level = "debug", skip(self, br))] pub fn decode_lf_group(&mut self, group: usize, br: &mut BitReader) -> Result<()> { debug!(section_size = br.total_bits_available()); let lf_global = self.lf_global.as_mut().unwrap(); if self.header.encoding == Encoding::VarDCT && !self.header.has_lf_frame() { info!("decoding VarDCT LF with group id {}", group); decode_vardct_lf( group, &self.header, &self.decoder_state.file_header.image_metadata, &lf_global.tree, lf_global.color_correlation_params.as_ref().unwrap(), lf_global.quant_params.as_ref().unwrap(), &lf_global.lf_quant, lf_global.block_context_map.as_ref().unwrap(), self.lf_image.as_mut().unwrap(), &mut self.quant_lf, br, )?; } lf_global.modular_global.read_stream( ModularStreamId::ModularLF(group), &self.header, &lf_global.tree, br, )?; if self.header.encoding == Encoding::VarDCT { info!("decoding HF metadata with group id {}", group); let hf_meta = self.hf_meta.as_mut().unwrap(); decode_hf_metadata( group, &self.header, &self.decoder_state.file_header.image_metadata, &lf_global.tree, hf_meta, br, )?; } Ok(()) } #[instrument(level = "debug", skip_all)] pub fn decode_hf_global(&mut self, br: &mut BitReader) -> Result<()> { debug!(section_size = br.total_bits_available()); if self.header.encoding == Encoding::Modular { return Ok(()); } let lf_global = self.lf_global.as_mut().unwrap(); let dequant_matrices = DequantMatrices::decode(&self.header, lf_global, br)?; let block_context_map = lf_global.block_context_map.as_mut().unwrap(); let num_histo_bits = self.header.num_groups().ceil_log2(); let num_histograms: u32 = br.read(num_histo_bits)? as u32 + 1; info!( "Processing HFGlobal section with {} passes and {} histograms", self.header.passes.num_passes, num_histograms ); let mut passes: Vec = vec![]; #[allow(unused_variables)] for i in 0..self.header.passes.num_passes as usize { let used_orders = match br.read(2)? { 0 => 0x5f, 1 => 0x13, 2 => 0, _ => br.read(coeff_order::NUM_ORDERS)?, } as u32; debug!(used_orders); let coeff_orders = decode_coeff_orders(used_orders, br)?; assert_eq!(coeff_orders.len(), 3 * coeff_order::NUM_ORDERS); let num_contexts = num_histograms as usize * block_context_map.num_ac_contexts(); info!( "Deconding histograms for pass {} with {} contexts", i, num_contexts ); let histograms = Histograms::decode(num_contexts, br, true)?; debug!("Found {} histograms", histograms.num_histograms()); passes.push(PassState { coeff_orders, histograms, }); } let hf_coefficients = if passes.len() <= 1 { None } else { let xs = GROUP_DIM * GROUP_DIM; let ys = self.header.num_groups(); Some(( Image::new((xs, ys))?, Image::new((xs, ys))?, Image::new((xs, ys))?, )) }; self.hf_global = Some(HfGlobalState { num_histograms, passes, dequant_matrices, hf_coefficients, }); Ok(()) } #[instrument(level = "debug", skip(self, br, buffer_splitter))] pub fn decode_hf_group( &mut self, group: usize, pass: usize, mut br: BitReader, buffer_splitter: &mut BufferSplitter, ) -> Result<()> { debug!(section_size = br.total_bits_available()); if self.header.has_noise() { // TODO(sboukortt): consider making this a dedicated stage let num_channels = self.header.num_extra_channels as usize + 3; let group_dim = self.header.group_dim() as u32; let xsize_groups = self.header.size_groups().0; let gx = (group % xsize_groups) as u32; let gy = (group / xsize_groups) as u32; // TODO(sboukortt): test upsampling+noise let upsampling = self.header.upsampling; let x0 = gx * upsampling * group_dim; let y0 = gy * upsampling * group_dim; let x1 = ((x0 + upsampling * group_dim) as usize).min(self.header.size_upsampled().0); let y1 = ((y0 + upsampling * group_dim) as usize).min(self.header.size_upsampled().1); let xsize = x1 - x0 as usize; let ysize = y1 - y0 as usize; let mut rng = Xorshift128Plus::new_with_seeds( self.decoder_state.visible_frame_index as u32, self.decoder_state.nonvisible_frame_index as u32, x0, y0, ); let bits_to_float = |bits: u32| f32::from_bits((bits >> 9) | 0x3F800000); for i in 0..3 { let mut buf = pipeline!(self, p, p.get_buffer(num_channels + i)?); const FLOATS_PER_BATCH: usize = Xorshift128Plus::N * std::mem::size_of::() / std::mem::size_of::(); let mut batch = [0u64; Xorshift128Plus::N]; for y in 0..ysize { let row = buf.row_mut(y); for batch_index in 0..xsize.div_ceil(FLOATS_PER_BATCH) { rng.fill(&mut batch); let batch_size = (xsize - batch_index * FLOATS_PER_BATCH).min(FLOATS_PER_BATCH); for i in 0..batch_size { let x = FLOATS_PER_BATCH * batch_index + i; let k = i / 2; let high_bytes = i % 2 != 0; let bits = if high_bytes { ((batch[k] & 0xFFFFFFFF00000000) >> 32) as u32 } else { (batch[k] & 0xFFFFFFFF) as u32 }; row[x] = bits_to_float(bits); } } } pipeline!( self, p, p.set_buffer_for_group(num_channels + i, group, 1, buf, buffer_splitter)? ) } } let lf_global = self.lf_global.as_mut().unwrap(); if self.header.encoding == Encoding::VarDCT { info!("Decoding VarDCT group {group}, pass {pass}"); let hf_global = self.hf_global.as_mut().unwrap(); let hf_meta = self.hf_meta.as_mut().unwrap(); let mut pixels = [ pipeline!(self, p, p.get_buffer(0))?, pipeline!(self, p, p.get_buffer(1))?, pipeline!(self, p, p.get_buffer(2))?, ]; let buffers = self.vardct_buffers.get_or_insert_with(VarDctBuffers::new); decode_vardct_group( group, pass, &self.header, lf_global, hf_global, hf_meta, &self.lf_image, &self.quant_lf, &self .decoder_state .file_header .transform_data .opsin_inverse_matrix .quant_biases, &mut pixels, &mut br, buffers, )?; if self.decoder_state.enable_output && pass + 1 == self.header.passes.num_passes as usize { for (c, img) in pixels.into_iter().enumerate() { pipeline!( self, p, p.set_buffer_for_group(c, group, 1, img, buffer_splitter)? ); } } } lf_global.modular_global.read_stream( ModularStreamId::ModularHF { group, pass }, &self.header, &lf_global.tree, &mut br, )?; lf_global.modular_global.process_output( 2 + pass, group, &self.header, &mut |chan, group, num_passes, image| { pipeline!( self, p, p.set_buffer_for_group(chan, group, num_passes, image, buffer_splitter)? ); Ok(()) }, )?; Ok(()) } }