// Copyright (c) the JPEG XL Project Authors. All rights reserved. // // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. extern crate proc_macro; use proc_macro::TokenStream; use proc_macro_error2::{abort, proc_macro_error}; use proc_macro2::TokenStream as TokenStream2; use quote::quote; use syn::{DeriveInput, Meta, parse_macro_input}; fn get_bits(expr_call: &syn::ExprCall) -> syn::Expr { if let syn::Expr::Path(ep) = &*expr_call.func { if !ep.path.is_ident("Bits") { abort!( expr_call, "Unexpected function name in coder: {}", ep.path.get_ident().unwrap() ); } if expr_call.args.len() != 1 { abort!( expr_call, "Unexpected number of arguments for Bits() in coder: {}", expr_call.args.len() ); } return expr_call.args[0].clone(); } abort!(expr_call, "Unexpected function call in coder"); } fn parse_single_coder(input: &syn::Expr, extra_lit: Option<&syn::ExprLit>) -> TokenStream2 { match &input { syn::Expr::Lit(lit) => match extra_lit { None => quote! {U32::Val(#lit)}, Some(elit) => quote! {U32::Val(#lit + #elit)}, }, syn::Expr::Call(expr_call) => { let bits = get_bits(expr_call); match extra_lit { None => quote! {U32::Bits(#bits)}, Some(elit) => quote! {U32::BitsOffset{n: #bits, off: #elit}}, } } syn::Expr::Binary(syn::ExprBinary { attrs: _, left, op: syn::BinOp::Add(_), right, }) => { let (left, right) = if let syn::Expr::Lit(_) = **left { (right, left) } else { (left, right) }; match (&**left, &**right) { (syn::Expr::Call(expr_call), syn::Expr::Lit(lit)) => { let bits = get_bits(expr_call); match extra_lit { None => quote! {U32::BitsOffset{n: #bits, off: #lit}}, Some(elit) => quote! {U32::BitsOffset{n: #bits, off: #lit + #elit}}, } } _ => abort!( input, "Unexpected expression in coder, must be Bits(a) + b, Bits(a), or b" ), } } _ => abort!( input, "Unexpected expression in coder, must be Bits(a) + b, Bits(a), or b" ), } } fn parse_coder(input: &syn::Expr) -> TokenStream2 { let parse_u2s = |expr_call: &syn::ExprCall, lit: Option<&syn::ExprLit>| { if let syn::Expr::Path(ep) = &*expr_call.func { if !ep.path.is_ident("u2S") { let coder = parse_single_coder(input, None); return quote! {U32Coder::Direct(#coder)}; } if expr_call.args.len() != 4 { abort!( input, "Unexpected number of arguments for U32() in coder: {}", expr_call.args.len() ); } let args = vec![ parse_single_coder(&expr_call.args[0], lit), parse_single_coder(&expr_call.args[1], lit), parse_single_coder(&expr_call.args[2], lit), parse_single_coder(&expr_call.args[3], lit), ]; return quote! {U32Coder::Select(#(#args),*)}; } abort!(input, "Unexpected function call in coder"); }; match &input { syn::Expr::Call(expr_call) => parse_u2s(expr_call, None), syn::Expr::Binary(syn::ExprBinary { attrs: _, left, op: syn::BinOp::Add(_), right, }) => { let (left, right) = if let syn::Expr::Lit(_) = **left { (right, left) } else { (left, right) }; match (&**left, &**right) { (syn::Expr::Call(expr_call), syn::Expr::Lit(lit)) => { parse_u2s(expr_call, Some(lit)) } _ => abort!( input, "Unexpected expression in coder, must be (u2S|Bits)(a) + b, (u2S|Bits)(a), or b" ), } } _ => parse_single_coder(input, None), } } fn parse_size_coder(mut input: syn::Expr) -> TokenStream2 { match input { syn::Expr::Call(syn::ExprCall { ref func, ref mut args, .. }) => { if args.len() != 1 { abort!(input, "Expected 1 argument in sized_coder inner call"); } match &**func { syn::Expr::Path(expr_path) if expr_path.path.is_ident("implicit") => { let arg = args.first().unwrap().clone(); parse_coder(&arg) } syn::Expr::Path(expr_path) if expr_path.path.is_ident("explicit") => { quote! { U32Coder::Direct(U32::Val(#args)) } } _ => abort!( input, "Unexpected expression in size_coder, must be 'implicit()' or 'explicit()'" ), } } _ => abort!( input, "Unexpected expression in size_coder, must be 'implicit()' or 'explicit()'" ), } } fn prettify_condition(cond: &syn::Expr) -> String { (quote! {#cond}) .to_string() .replace(" . ", ".") .replace("! ", "!") .replace(" :: ", "::") } #[derive(Debug)] struct Condition { expr: Option, has_all_default: bool, pretty: String, } impl Condition { fn get_expr(&self, all_default_field: &Option) -> Option { if self.has_all_default { let all_default = all_default_field.as_ref().unwrap(); match &self.expr { Some(expr) => Some(quote! { !#all_default && (#expr) }), None => Some(quote! { !#all_default }), } } else { self.expr.as_ref().map(|expr| quote! {#expr}) } } fn get_pretty(&self, all_default_field: &Option) -> String { if self.has_all_default { let all_default = all_default_field.as_ref().unwrap(); let all_default = "!".to_owned() + "e! {#all_default}.to_string(); match &self.expr { Some(_) => all_default + " && (" + &self.pretty + ")", None => all_default, } } else { self.pretty.clone() } } } #[derive(Debug, Clone)] struct U32 { coder: TokenStream2, } #[derive(Debug)] #[allow(clippy::large_enum_variant)] enum Coder { WithoutConfig, U32(U32), Select(Condition, U32, U32), Vector(U32, Box), } impl Coder { fn ty(&self) -> TokenStream2 { match self { Coder::WithoutConfig => quote! {()}, Coder::U32(..) => quote! {U32Coder}, Coder::Select(..) => quote! {U32Coder}, Coder::Vector(_, value_coder) => { let value_coder_ty = value_coder.ty(); quote! {VectorCoder<#value_coder_ty>} } } } fn config(&self, all_default_field: &Option) -> TokenStream2 { match self { Coder::WithoutConfig => quote! { () }, Coder::U32(U32 { coder }) => quote! { #coder }, Coder::Select(condition, U32 { coder: coder_true }, U32 { coder: coder_false }) => { let cnd = condition.get_expr(all_default_field).unwrap(); quote! { if #cnd { #coder_true } else { #coder_false } } } Coder::Vector(U32 { coder }, value_coder) => { let value_coder = value_coder.config(all_default_field); quote! {VectorCoder{size_coder: #coder, value_coder: #value_coder}} } } } } #[derive(Debug)] enum FieldKind { Unconditional(Coder), Conditional(Condition, Coder), Defaulted(Condition, Coder), } #[derive(Debug)] struct Field { name: proc_macro2::Ident, kind: FieldKind, ty: syn::Type, default: Option, default_element: Option, nonserialized_inits: Vec, } impl Field { fn parse(f: &syn::Field, num: usize, all_default_field: &mut Option) -> Field { let mut condition = None; let mut default = None; let mut coder = None; let mut select_coder = None; let mut coder_true = None; let mut coder_false = None; let mut is_all_default = false; let mut size_coder = None; let mut nonserialized = vec![]; let mut default_element = None; // Parse attributes. for a in &f.attrs { match a.path().get_ident().map(syn::Ident::to_string).as_deref() { Some("coder") => { if coder.is_some() { abort!(f, "Repeated coder"); } let coder_ast = a.parse_args::().unwrap(); coder = Some(Coder::U32(U32 { coder: parse_coder(&coder_ast), })); } Some("default") => { if default.is_some() { abort!(f, "Repeated default"); } let default_expr = a.parse_args::().unwrap(); default = Some(quote! {#default_expr}); } Some("default_element") => { if default_element.is_some() { abort!(f, "Repeated default_element") } let default_element_expr = a.parse_args::().unwrap(); default_element = Some(quote! { #default_element_expr }) } Some("condition") => { if condition.is_some() { abort!(f, "Repeated condition"); } let condition_ast = a.parse_args::().unwrap(); let pretty_cond = prettify_condition(&condition_ast); condition = Some(Condition { expr: Some(condition_ast), has_all_default: all_default_field.is_some(), pretty: pretty_cond, }); } Some("all_default") => { if num != 0 { abort!(f, "all_default is not the first field"); } if default.is_some() { abort!(f, "all_default has an implicit default"); } is_all_default = true; default = Some(quote! { true }); } Some("select_coder") => { if select_coder.is_some() { abort!(f, "Repeated select_coder"); } let condition_ast = a.parse_args::().unwrap(); let pretty_cond = prettify_condition(&condition_ast); select_coder = Some(Condition { expr: Some(condition_ast), has_all_default: false, pretty: pretty_cond, }); } Some("coder_false") => { if coder_false.is_some() { abort!(f, "Repeated coder_false"); } let coder_ast = a.parse_args::().unwrap(); coder_false = Some(U32 { coder: parse_coder(&coder_ast), }); } Some("coder_true") => { if coder_true.is_some() { abort!(f, "Repeated coder_true"); } let coder_ast = a.parse_args::().unwrap(); coder_true = Some(U32 { coder: parse_coder(&coder_ast), }); } Some("size_coder") => { if size_coder.is_some() { abort!(f, "Repeated size_coder"); } let coder_ast = a.parse_args::().unwrap(); size_coder = Some(U32 { coder: parse_size_coder(coder_ast), }); } Some("nonserialized") => { let Meta::List(ns) = &a.meta else { abort!(a, "Invalid attribute"); }; let stream = &ns.tokens; nonserialized.push(quote! {#stream}); } _ => {} } } if default.is_some() && default_element.is_some() { abort!(f, "default is incompatible with default_element"); } if let Some(select_coder) = select_coder { if coder_true.is_none() || coder_false.is_none() { abort!( f, "Invalid field, select_coder is set but coder_true or coder_false are not" ) } if coder.is_some() { abort!(f, "Invalid field, select_coder and coder are both present") } coder = Some(Coder::Select( select_coder, coder_true.unwrap(), coder_false.unwrap(), )) } let condition = if condition.is_some() || all_default_field.is_none() { condition } else { Some(Condition { expr: None, has_all_default: true, pretty: String::new(), }) }; // Assume nested field if no coder. let mut coder = coder.unwrap_or_else(|| Coder::WithoutConfig); if let Some(c) = size_coder { if default.is_none() { default = Some(quote! { Vec::new() }); } coder = Coder::Vector(c, Box::new(coder)) } let ident = f.ident.as_ref().unwrap(); let kind = match (condition, default.is_some()) { (None, _) => FieldKind::Unconditional(coder), (Some(cond), false) => FieldKind::Conditional(cond, coder), (Some(cond), true) => FieldKind::Defaulted(cond, coder), }; if is_all_default { *all_default_field = Some(f.ident.as_ref().unwrap().clone()); } Field { name: ident.clone(), kind, ty: f.ty.clone(), default, default_element, nonserialized_inits: nonserialized, } } // Produces reading code (possibly with tracing). fn read_fun(&self, all_default_field: &Option) -> TokenStream2 { let ident = &self.name; let ty = &self.ty; let nonserialized_inits = &self.nonserialized_inits; match &self.kind { FieldKind::Unconditional(coder) => { let cfg_ty = coder.ty(); let cfg = coder.config(all_default_field); let trc = quote! { crate::util::tracing_wrappers::trace!("Setting {} to {:?}. total_bits_read: {}, peek: {:08b}", stringify!(#ident), #ident, br.total_bits_read(), br.peek(8)); }; quote! { let #ident = { let cfg = #cfg; type NS = <#ty as UnconditionalCoder<#cfg_ty>>::Nonserialized; let nonserialized = NS { #(#nonserialized_inits),* }; <#ty>::read_unconditional(&cfg, br, &nonserialized)? }; #trc } } FieldKind::Conditional(condition, coder) => { let cfg_ty = coder.ty(); let cfg = coder.config(all_default_field); let cnd = condition.get_expr(all_default_field).unwrap(); let pretty_cnd = condition.get_pretty(all_default_field); let trc = quote! { crate::util::tracing_wrappers::trace!("{} is {}, setting {} to {:?}. total_bits_read: {}, peek {:08b}", #pretty_cnd, #cnd, stringify!(#ident), #ident, br.total_bits_read(), br.peek(8)); }; quote! { let #ident = { let cond = #cnd; let cfg = #cfg; type NS = <#ty as ConditionalCoder<#cfg_ty>>::Nonserialized; let nonserialized = NS { #(#nonserialized_inits),* }; <#ty>::read_conditional(&cfg, cond, br, &nonserialized)? }; #trc } } FieldKind::Defaulted(condition, coder) => { let cfg_ty = coder.ty(); let cfg = coder.config(all_default_field); let cnd = condition.get_expr(all_default_field).unwrap(); let pretty_cnd = condition.get_pretty(all_default_field); let default = &self.default; let trc = quote! { crate::util::tracing_wrappers::trace!("{} is {}, setting {} to {:?}. total_bits_read: {}, peek {:08b}", #pretty_cnd, #cnd, stringify!(#ident), #ident, br.total_bits_read(), br.peek(8)); }; let (read_fn, default) = if let Some(def) = &self.default_element { (quote! { read_defaulted_element }, Some(def)) } else { (quote! { read_defaulted }, default.as_ref()) }; quote! { let #ident = { let cond = #cnd; let cfg = #cfg; type NS = <#ty as DefaultedCoder<#cfg_ty>>::Nonserialized; let field_nonserialized = NS { #(#nonserialized_inits),* }; let default = #default; <#ty>::#read_fn(&cfg, cond, default, br, &field_nonserialized)? }; #trc } } } } // Produces default code. fn default_code(&self) -> TokenStream2 { let ident = &self.name; let ty = &self.ty; let nonserialized_inits = &self.nonserialized_inits; let default = &self.default; match &self.kind { FieldKind::Defaulted(_, coder) => { let cfg_ty = coder.ty(); let default = &self.default; quote! { let #ident = { type NS = <#ty as DefaultedCoder<#cfg_ty>>::Nonserialized; let field_nonserialized = NS { #(#nonserialized_inits),* }; #default }; } } _ => quote! { let #ident = #default; }, } } } fn derive_struct(input: &DeriveInput) -> TokenStream2 { let name = &input.ident; let validate = input.attrs.iter().any(|a| a.path().is_ident("validate")); let nonserialized: Vec<_> = input .attrs .iter() .filter_map(|a| { if a.path().is_ident("nonserialized") { Some(a.parse_args::().unwrap()) } else { None } }) .collect(); if nonserialized.len() > 1 { abort!(input, "repeated nonserialized"); } let nonserialized = if nonserialized.is_empty() { quote! {Empty} } else { let v = &nonserialized[0]; quote! {#v} }; let data = if let syn::Data::Struct(struct_data) = &input.data { struct_data } else { abort!(input, "derive_struct didn't get a struct"); }; let fields = if let syn::Fields::Named(syn::FieldsNamed { brace_token: _, named, }) = &data.fields { named } else { abort!(data.fields, "only named fields are supported (for now?)"); }; let mut all_default_field = None; let fields: Vec<_> = fields .iter() .enumerate() .map(|(n, f)| Field::parse(f, n, &mut all_default_field)) .collect(); let fields_read = fields.iter().map(|x| x.read_fun(&all_default_field)); let fields_names = fields.iter().map(|x| &x.name); let impl_default = if fields.iter().all(|x| x.default.is_some()) { let field_init = fields.iter().map(Field::default_code); let struct_init = fields.iter().map(|f| { let ident = &f.name; quote! { #ident } }); quote! { impl #name { pub fn default(nonserialized: &#nonserialized) -> #name { #(#field_init)* #name { #(#struct_init),* } } } } } else { quote! {} }; let impl_validate = if validate { quote! { return_value.check(nonserialized)?; } } else { quote! {} }; let align = match input.attrs.iter().any(|a| a.path().is_ident("aligned")) { true => quote! { br.jump_to_byte_boundary()?; }, false => quote! {}, }; quote! { #impl_default impl crate::headers::encodings::UnconditionalCoder<()> for #name { type Nonserialized = #nonserialized; #[cold] #[inline(never)] fn read_unconditional(_: &(), br: &mut BitReader, nonserialized: &Self::Nonserialized) -> Result<#name, Error> { use crate::headers::encodings::UnconditionalCoder; use crate::headers::encodings::ConditionalCoder; use crate::headers::encodings::DefaultedCoder; use crate::headers::encodings::DefaultedElementCoder; #align #(#fields_read)* let return_value = #name { #(#fields_names),* }; #impl_validate Ok(return_value) } } } } fn derive_enum(input: &DeriveInput) -> TokenStream2 { let name = &input.ident; quote! { impl crate::headers::encodings::UnconditionalCoder for #name { type Nonserialized = Empty; fn read_unconditional(config: &U32Coder, br: &mut BitReader, _: &Empty) -> Result<#name, Error> { use num_traits::FromPrimitive; let u = u32::read_unconditional(config, br, &Empty{})?; if let Some(e) = #name::from_u32(u) { Ok(e) } else { Err(Error::InvalidEnum(u, stringify!(#name).to_string())) } } } impl crate::headers::encodings::UnconditionalCoder<()> for #name { type Nonserialized = Empty; fn read_unconditional(config: &(), br: &mut BitReader, nonserialized: &Empty) -> Result<#name, Error> { #name::read_unconditional( &U32Coder::Select( U32::Val(0), U32::Val(1), U32::BitsOffset{n: 4, off: 2}, U32::BitsOffset{n: 6, off: 18}), br, nonserialized) } } } } #[proc_macro_error] #[proc_macro_derive( UnconditionalCoder, attributes( coder, condition, default, default_element, all_default, select_coder, coder_true, coder_false, validate, size_coder, nonserialized, aligned, ) )] pub fn derive_jxl_headers(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); match &input.data { syn::Data::Struct(_) => derive_struct(&input).into(), syn::Data::Enum(_) => derive_enum(&input).into(), _ => abort!(input, "Only implemented for struct"), } } #[proc_macro_attribute] pub fn noop(_attr: TokenStream, item: TokenStream) -> TokenStream { item } #[cfg(feature = "test")] #[proc_macro] pub fn for_each_test_file(input: TokenStream) -> TokenStream { use std::{fs, path::Path}; use syn::Ident; let fn_name = parse_macro_input!(input as Ident); let root_test_dir = Path::new(env!("CARGO_MANIFEST_DIR")) .join("..") .join("jxl") .join("resources") .join("test"); let conformance_test_dir = root_test_dir.join("conformance_test_images"); let mut tests = vec![]; for test_dir in [root_test_dir, conformance_test_dir] { for entry in fs::read_dir(&test_dir).unwrap() { let entry = entry.unwrap(); let path = entry.path(); if path.extension().is_some_and(|ext| ext == "jxl") { let pathname = path.to_string_lossy(); let relative_path = path .strip_prefix(&test_dir) .unwrap() .to_string_lossy() .replace('/', "_slash_"); let test_name = format!( "{}_{}", fn_name, relative_path.strip_suffix(".jxl").unwrap() ); let test_name = Ident::new(&test_name, fn_name.span()); tests.push(quote! { #[test] fn #test_name() { #fn_name(&Path::new(#pathname)).unwrap() } }); } } } quote! { #(#tests)* } .into() }