// Copyright (c) the JPEG XL Project Authors. All rights reserved. // // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. use jxl_simd::{ F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask, U32SimdVec, shl, shr, simd_function, }; use crate::{ error::{Error, Result}, frame::modular::{ChannelInfo, ModularChannel}, headers::modular::SqueezeParams, image::{Image, ImageRect}, util::AtomicRef, }; use crate::util::tracing_wrappers::*; #[instrument(level = "trace", err)] pub fn check_squeeze_params( channels: &[(usize, ChannelInfo)], params: &SqueezeParams, ) -> Result<()> { let end_channel = (params.begin_channel + params.num_channels) as usize; if end_channel > channels.len() { return Err(Error::InvalidChannelRange( params.begin_channel as usize, params.num_channels as usize, channels.len(), )); } if channels[params.begin_channel as usize].1.is_meta() != channels[end_channel - 1].1.is_meta() { return Err(Error::MixingDifferentChannels); } if channels[params.begin_channel as usize].1.is_meta() && !params.in_place { return Err(Error::MetaSqueezeRequiresInPlace); } Ok(()) } pub fn default_squeeze(data_channel_info: &[(usize, ChannelInfo)]) -> Vec { let num_meta_channels = data_channel_info .iter() .take_while(|x| x.1.is_meta()) .count(); let mut w = data_channel_info[num_meta_channels].1.size.0; let mut h = data_channel_info[num_meta_channels].1.size.1; let nc = data_channel_info.len() - num_meta_channels; let mut params = vec![]; if nc > 2 && data_channel_info[num_meta_channels + 1].1.size == (w, h) { // 420 previews let sp = SqueezeParams { horizontal: true, in_place: false, begin_channel: num_meta_channels as u32 + 1, num_channels: 2, }; if w > 1 { params.push(sp); } if h > 1 { params.push(SqueezeParams { horizontal: false, ..sp }); } } const MAX_FIRST_PREVIEW_SIZE: usize = 8; let sp = SqueezeParams { begin_channel: num_meta_channels as u32, num_channels: nc as u32, in_place: true, horizontal: false, }; // vertical first on tall images if w <= h && h > MAX_FIRST_PREVIEW_SIZE { params.push(SqueezeParams { horizontal: false, ..sp }); h = h.div_ceil(2); } while w > MAX_FIRST_PREVIEW_SIZE || h > MAX_FIRST_PREVIEW_SIZE { if w > MAX_FIRST_PREVIEW_SIZE { params.push(SqueezeParams { horizontal: true, ..sp }); w = w.div_ceil(2); } if h > MAX_FIRST_PREVIEW_SIZE { params.push(SqueezeParams { horizontal: false, ..sp }); h = h.div_ceil(2); } } params } #[inline(always)] fn smooth_tendency_impl( d: D, a: D::I32Vec, b: D::I32Vec, c: D::I32Vec, ) -> D::I32Vec { let a_b = a - b; let b_c = b - c; let a_c = a - c; let abs_a_b = a_b.abs(); let abs_b_c = b_c.abs(); let abs_a_c = a_c.abs(); let non_monotonic = (a_b ^ b_c).lt_zero(); let skip = a_b.eq_zero().andnot(non_monotonic); let skip = b_c.eq_zero().andnot(skip); let abs_a_b_3 = abs_a_b.mul_wide_take_high(D::I32Vec::splat(d, 0x55555556)); let x = shr!(D::I32Vec::splat(d, 2) + abs_a_c + abs_a_b_3, 2); let abs_a_b_2_add_x = shl!(abs_a_b, 1) + (x & D::I32Vec::splat(d, 1)); let x = x .gt(abs_a_b_2_add_x) .if_then_else_i32(shl!(abs_a_b, 1) + D::I32Vec::splat(d, 1), x); let abs_b_c_2 = shl!(abs_b_c, 1); let x = (x + (x & D::I32Vec::splat(d, 1))) .gt(abs_b_c_2) .if_then_else_i32(abs_b_c_2, x); let need_neg = a_c.lt_zero(); let x = skip.maskz_i32(x); need_neg.if_then_else_i32(-x, x) } #[inline(always)] fn smooth_tendency_scalar(b: i64, a: i64, n: i64) -> i64 { let mut diff = 0; if b >= a && a >= n { diff = (4 * b - 3 * n - a + 6) / 12; // 2c = a<<1 + diff - diff&1 <= 2b so diff - diff&1 <= 2b - 2a // 2d = a<<1 - diff - diff&1 >= 2n so diff + diff&1 <= 2a - 2n if diff - (diff & 1) > 2 * (b - a) { diff = 2 * (b - a) + 1; } if diff + (diff & 1) > 2 * (a - n) { diff = 2 * (a - n); } } else if b <= a && a <= n { diff = (4 * b - 3 * n - a - 6) / 12; // 2c = a<<1 + diff + diff&1 >= 2b so diff + diff&1 >= 2b - 2a // 2d = a<<1 - diff + diff&1 <= 2n so diff - diff&1 >= 2a - 2n if diff + (diff & 1) < 2 * (b - a) { diff = 2 * (b - a) - 1; } if diff - (diff & 1) < 2 * (a - n) { diff = 2 * (a - n); } } diff } #[inline(always)] fn unsqueeze_impl( d: D, avg: D::I32Vec, res: D::I32Vec, next_avg: D::I32Vec, prev: D::I32Vec, ) -> (D::I32Vec, D::I32Vec) { let tendency = smooth_tendency_impl(d, prev, avg, next_avg); let diff = res + tendency; let sign = shr!(diff.bitcast_to_u32(), 31).bitcast_to_i32(); let diff_2 = shr!(diff + sign, 1); let a = avg + diff_2; let b = a - diff; (a, b) } #[inline(always)] fn unsqueeze_scalar(avg: i32, res: i32, next_avg: i32, prev: i32) -> (i32, i32) { let tendency = smooth_tendency_scalar(prev as i64, avg as i64, next_avg as i64); let diff = (res as i64) + tendency; let a = (avg as i64) + (diff / 2); let b = a - diff; (a as i32, b as i32) } #[inline(always)] fn hsqueeze_impl( d: D, y_start: usize, in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, out: &mut Image, ) { const { assert!(D::I32Vec::LEN <= 16); assert!(D::I32Vec::LEN.is_power_of_two()); } let lanes = D::I32Vec::LEN; assert_eq!(y_start % lanes, 0); let (w, h) = in_res.size(); if lanes == 1 { return hsqueeze_scalar(y_start, in_avg, in_res, in_next_avg, out_prev, out); } let has_tail = out.size().0 & 1 == 1; if has_tail { debug_assert!(in_avg.size().0 == w + 1); debug_assert!(out.size().0 == 2 * w + 1); } let mask = !(lanes - 1); let y_limit = if w >= lanes { h & mask } else { y_start }; let mut buf = [0f32; 512]; for y in (y_start..y_limit).step_by(lanes) { for dy in 0..lanes { buf[dy] = f32::from_bits(in_avg.row(y + dy)[0] as u32); buf[lanes + dy] = f32::from_bits(in_res.row(y + dy)[0] as u32); } let mut avg_first = D::F32Vec::load(d, &buf).bitcast_to_i32(); let mut res_first = D::F32Vec::load(d, &buf[lanes..]).bitcast_to_i32(); let mut prev_b = match out_prev { None => avg_first, Some(mc) => { let mc_w = mc.data.size().0; let mc = &mc.data; for (dy, out) in buf[..lanes].iter_mut().enumerate() { *out = f32::from_bits(mc.row(y + dy)[mc_w - 1] as u32); } D::F32Vec::load(d, &buf).bitcast_to_i32() } }; let remainder_start = ((w - 1) & mask) + 1; let remainder_count = w - remainder_start; for x in (1..remainder_start).step_by(lanes) { let buf_arr = D::F32Vec::make_array_slice_mut(&mut buf); for dy in 0..lanes { let avg_row = &in_avg.row(y + dy)[x..][..lanes]; let res_row = &in_res.row(y + dy)[x..][..lanes]; let avg = D::I32Vec::load(d, avg_row); let res = D::I32Vec::load(d, res_row); avg.bitcast_to_f32().store_array(&mut buf_arr[2 * dy]); res.bitcast_to_f32().store_array(&mut buf_arr[2 * dy + 1]); } D::F32Vec::transpose_square(d, buf_arr, 2); D::F32Vec::transpose_square(d, &mut buf_arr[1..], 2); for idx in 0..lanes { let avg_next = D::F32Vec::load_array(d, &buf_arr[2 * idx]).bitcast_to_i32(); let res_next = D::F32Vec::load_array(d, &buf_arr[2 * idx + 1]).bitcast_to_i32(); let (a, b) = unsqueeze_impl(d, avg_first, res_first, avg_next, prev_b); a.bitcast_to_f32().store_array(&mut buf_arr[2 * idx]); b.bitcast_to_f32().store_array(&mut buf_arr[2 * idx + 1]); avg_first = avg_next; res_first = res_next; prev_b = b; } D::F32Vec::transpose_square(d, buf_arr, 1); D::F32Vec::transpose_square(d, &mut buf_arr[lanes..], 1); for dy in 0..lanes { let out_row = &mut out.row_mut(y + dy)[2 * x - 2..][..2 * lanes]; for group in 0..2 { let v = D::F32Vec::load_array(d, &buf_arr[dy + group * lanes]).bitcast_to_i32(); v.store(&mut out_row[group * lanes..]); } } } let x = remainder_start; if remainder_count == 0 { let avg_last = if has_tail { for (idx, out) in buf[..lanes].iter_mut().enumerate() { *out = f32::from_bits(in_avg.row(y + idx)[w] as u32); } D::F32Vec::load(d, &buf).bitcast_to_i32() } else if let Some(mc) = in_next_avg { for (idx, out) in buf[..lanes].iter_mut().enumerate() { *out = f32::from_bits(mc.row(y + idx)[0] as u32); } D::F32Vec::load(d, &buf).bitcast_to_i32() } else { avg_first }; let buf_arr = D::F32Vec::make_array_slice_mut(&mut buf); let (a, b) = unsqueeze_impl(d, avg_first, res_first, avg_last, prev_b); a.bitcast_to_f32().store_array(&mut buf_arr[0]); b.bitcast_to_f32().store_array(&mut buf_arr[1]); } else { for dy in 0..lanes { let avg_row = in_avg.row(y + dy); let res_row = in_res.row(y + dy); for dx in 0..remainder_count { buf[dx + lanes * 2 * dy] = f32::from_bits(avg_row[x + dx] as u32); buf[dx + lanes * (2 * dy + 1)] = f32::from_bits(res_row[x + dx] as u32); } buf[remainder_count + lanes * 2 * dy] = if has_tail { f32::from_bits(avg_row[w] as u32) } else if let Some(mc) = in_next_avg { f32::from_bits(mc.row(y + dy)[0] as u32) } else { buf[remainder_count - 1 + lanes * 2 * dy] }; } let buf_arr = D::F32Vec::make_array_slice_mut(&mut buf); D::F32Vec::transpose_square(d, buf_arr, 2); D::F32Vec::transpose_square(d, &mut buf_arr[1..], 2); for idx in 0..=remainder_count { let avg_next = D::F32Vec::load_array(d, &buf_arr[2 * idx]).bitcast_to_i32(); let res_next = D::F32Vec::load_array(d, &buf_arr[2 * idx + 1]).bitcast_to_i32(); let (a, b) = unsqueeze_impl(d, avg_first, res_first, avg_next, prev_b); a.bitcast_to_f32().store_array(&mut buf_arr[2 * idx]); b.bitcast_to_f32().store_array(&mut buf_arr[2 * idx + 1]); avg_first = avg_next; res_first = res_next; prev_b = b; } } let buf_arr = D::F32Vec::make_array_slice_mut(&mut buf); D::F32Vec::transpose_square(d, buf_arr, 1); D::F32Vec::transpose_square(d, &mut buf_arr[lanes..], 1); let x_limit = 2 * (remainder_count + 1); for dy in 0..lanes { let out_row = &mut out.row_mut(y + dy)[2 * x - 2..]; for (dx, out) in out_row[..x_limit].iter_mut().enumerate() { let group = dx / lanes; let group_x = dx % lanes; *out = buf[(dy + group * lanes) * lanes + group_x].to_bits() as i32; } } if has_tail { for dy in 0..lanes { out.row_mut(y + dy)[2 * w] = in_avg.row(y + dy)[w]; } } } let remainder_rows = h - y_limit; // We need `lanes > N` to convince the compiler that this function does not recurse if lanes > 8 && remainder_rows >= 8 && w >= 8 { return hsqueeze_impl( d.maybe_downgrade_256bit(), y_limit, in_avg, in_res, in_next_avg, out_prev, out, ); } if lanes > 4 && remainder_rows >= 4 && w >= 4 { return hsqueeze_impl( d.maybe_downgrade_128bit(), y_limit, in_avg, in_res, in_next_avg, out_prev, out, ); } hsqueeze_scalar(y_limit, in_avg, in_res, in_next_avg, out_prev, out) } #[inline(always)] fn hsqueeze_scalar( y_start: usize, in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, out: &mut Image, ) { let (w, h) = in_res.size(); debug_assert!(w >= 1); let has_tail = out.size().0 & 1 == 1; if has_tail { debug_assert!(in_avg.size().0 == w + 1); debug_assert!(out.size().0 == 2 * w + 1); } for y in y_start..h { let avg_row = in_avg.row(y); let res_row = in_res.row(y); let mut prev_b = match out_prev { None => avg_row[0], Some(mc) => mc.data.row(y)[mc.data.size().0 - 1], }; // Guarantee that `avg_row[x + 1]` is available. let x_end = if has_tail { w } else { w - 1 }; for x in 0..x_end { let (a, b) = unsqueeze_scalar(avg_row[x], res_row[x], avg_row[x + 1], prev_b); out.row_mut(y)[2 * x] = a; out.row_mut(y)[2 * x + 1] = b; prev_b = b; } if !has_tail { let last_avg = match in_next_avg { None => avg_row[w - 1], Some(mc) => mc.row(y)[0], }; let (a, b) = unsqueeze_scalar(avg_row[w - 1], res_row[w - 1], last_avg, prev_b); out.row_mut(y)[2 * w - 2] = a; out.row_mut(y)[2 * w - 1] = b; } else { // 1 last pixel out.row_mut(y)[2 * w] = in_avg.row(y)[w]; } } } simd_function!( hsqueeze, d: D, pub fn hsqueeze_fwd( in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, out: &mut Image, ) { hsqueeze_impl(d, 0, in_avg, in_res, in_next_avg, out_prev, out) } ); pub fn do_hsqueeze_step( in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, buffers: &mut [&mut ModularChannel], ) { trace!("hsqueeze step in_avg: {in_avg:?} in_res: {in_res:?} in_next_avg: {in_next_avg:?}"); let out = buffers.first_mut().unwrap(); // Shortcut: guarantees that row is at least 1px in the main loop if out.data.size().0 == 0 || out.data.size().1 == 0 { return; } let (w, h) = in_res.size(); // Another shortcut: when output row has just 1px if w == 0 { for y in 0..h { out.data.row_mut(y)[0] = in_avg.row(y)[0]; } return; } // Otherwise: 2 or more in in row hsqueeze(in_avg, in_res, in_next_avg, out_prev, &mut out.data); } #[inline(always)] fn vsqueeze_impl( d: D, x_start: usize, in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, out: &mut Image, ) { const { assert!(D::I32Vec::LEN.is_power_of_two()) }; let lanes = D::I32Vec::LEN; assert_eq!(x_start % lanes, 0); let (w, h) = in_res.size(); if lanes == 1 { return vsqueeze_scalar(x_start, in_avg, in_res, in_next_avg, out_prev, out); } let has_tail = out.size().1 & 1 == 1; if has_tail { debug_assert!(in_avg.size().1 == h + 1); debug_assert!(out.size().1 == 2 * h + 1); } let mask = !(lanes - 1); let x_limit = if w >= lanes { w & mask } else { x_start }; let prev_b_row = match out_prev { None => in_avg.row(0), Some(mc) => mc.data.row(mc.data.size().1 - 1), }; for x in (x_start..x_limit).step_by(lanes) { let mut prev_b = D::I32Vec::load(d, &prev_b_row[x..]); let mut avg_first = D::I32Vec::load(d, &in_avg.row(0)[x..]); let mut res_first = D::I32Vec::load(d, &in_res.row(0)[x..]); for y in 0..h - 1 { let avg_next = D::I32Vec::load(d, &in_avg.row(y + 1)[x..]); let (a, b) = unsqueeze_impl(d, avg_first, res_first, avg_next, prev_b); a.store(&mut out.row_mut(2 * y)[x..]); b.store(&mut out.row_mut(2 * y + 1)[x..]); prev_b = b; avg_first = avg_next; res_first = D::I32Vec::load(d, &in_res.row(y + 1)[x..]); } let avg_last = if has_tail { D::I32Vec::load(d, &in_avg.row(h)[x..]) } else if let Some(mc) = in_next_avg { D::I32Vec::load(d, &mc.row(0)[x..]) } else { avg_first }; let (a, b) = unsqueeze_impl(d, avg_first, res_first, avg_last, prev_b); a.store(&mut out.row_mut(2 * h - 2)[x..]); b.store(&mut out.row_mut(2 * h - 1)[x..]); if has_tail { avg_last.store(&mut out.row_mut(2 * h)[x..]); } } let remainder_cols = w - x_limit; // We need `lanes > N` to convince the compiler that this function does not recurse if lanes > 8 && remainder_cols >= 8 { return vsqueeze_impl( d.maybe_downgrade_256bit(), x_limit, in_avg, in_res, in_next_avg, out_prev, out, ); } if lanes > 4 && remainder_cols >= 4 { return vsqueeze_impl( d.maybe_downgrade_128bit(), x_limit, in_avg, in_res, in_next_avg, out_prev, out, ); } vsqueeze_scalar(x_limit, in_avg, in_res, in_next_avg, out_prev, out) } #[inline(always)] fn vsqueeze_scalar( x_start: usize, in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, out: &mut Image, ) { let (w, h) = in_res.size(); let has_tail = out.size().1 & 1 == 1; if has_tail { debug_assert!(in_avg.size().1 == h + 1); debug_assert!(out.size().1 == 2 * h + 1); } { let prev_b_row = match out_prev { None => in_avg.row(0), Some(mc) => mc.data.row(mc.data.size().1 - 1), }; let avg_row = in_avg.row(0); let res_row = in_res.row(0); let avg_row_next = if !has_tail && (h == 1) { debug_assert!(in_next_avg.is_none()); in_avg.row(0) } else { in_avg.row(1) }; for x in x_start..w { let (a, b) = unsqueeze_scalar(avg_row[x], res_row[x], avg_row_next[x], prev_b_row[x]); out.row_mut(0)[x] = a; out.row_mut(1)[x] = b; } } for y in 1..h { let avg_row = in_avg.row(y); let res_row = in_res.row(y); let avg_row_next = if has_tail || y < h - 1 { in_avg.row(y + 1) } else { match in_next_avg { None => avg_row, Some(mc) => mc.row(0), } }; for x in x_start..w { let (a, b) = unsqueeze_scalar( avg_row[x], res_row[x], avg_row_next[x], out.row(2 * y - 1)[x], ); out.row_mut(2 * y)[x] = a; out.row_mut(2 * y + 1)[x] = b; } } if has_tail { out.row_mut(2 * h)[x_start..].copy_from_slice(&in_avg.row(h)[x_start..]); } } simd_function!( vsqueeze, d: D, pub fn vsqueeze_fwd( in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, out: &mut Image, ) { vsqueeze_impl(d, 0, in_avg, in_res, in_next_avg, out_prev, out) } ); pub fn do_vsqueeze_step( in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, in_next_avg: &Option>, out_prev: &Option>, buffers: &mut [&mut ModularChannel], ) { trace!("vsqueeze step in_avg: {in_avg:?} in_res: {in_res:?} in_next_avg: {in_next_avg:?}"); let out = &mut buffers.first_mut().unwrap().data; // Shortcut: guarantees that there at least 1 output row if out.size().1 == 0 || out.size().0 == 0 { return; } // Another shortcut: when there is one output row if in_res.size().1 == 0 { out.row_mut(0).copy_from_slice(in_avg.row(0)); return; } // Otherwise: 2 or more rows vsqueeze(in_avg, in_res, in_next_avg, out_prev, out); }