use core::fmt; use alloc::{ format, string::{String, ToString}, vec::Vec, }; use crate::{ back::{ self, hlsl::{ writer::{EntryPointBinding, EpStructMember, Io, NestedEntryPointArgs}, BackendResult, Error, }, }, proc::NameKey, Handle, Module, ShaderStage, TypeInner, }; impl NestedEntryPointArgs { pub fn write_call_args(&self, out: &mut impl fmt::Write) -> fmt::Result { let all_args = self .user_args .iter() .map(String::as_str) .chain(self.task_payload.as_deref()) .chain(core::iter::once(self.local_invocation_index.as_str())); for (i, arg) in all_args.enumerate() { if i != 0 { write!(out, ", ")?; } write!(out, "{arg}")?; } Ok(()) } } impl super::Writer<'_, W> { #[expect(clippy::too_many_arguments)] fn write_mesh_shader_wrapper( &mut self, module: &Module, func_ctx: &back::FunctionCtx, need_workgroup_variables_initialization: bool, nested_name: &str, entry_point: &crate::EntryPoint, args: NestedEntryPointArgs, mut separator_if_needed: impl FnMut() -> &'static str, ) -> BackendResult { let Some(ref mesh_info) = entry_point.mesh_info else { unreachable!() }; let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else { unreachable!() }; // Mesh shader wrapper let mesh_interface = self.entry_point_io.get(&(ep_index as usize)).unwrap(); let vert_info = mesh_interface.mesh_vertices.as_ref().unwrap(); let prim_info = mesh_interface.mesh_primitives.as_ref().unwrap(); let indices_info = mesh_interface.mesh_indices.as_ref().unwrap(); // Write something of the form `out indices uint3 indices_var[num_primitives]` write!( self.out, "{}out indices {} {}[{}]", separator_if_needed(), indices_info.ty_name, indices_info.arg_name, mesh_info.max_primitives )?; // Write something of the form `out vertices VertexType vertices_var[num_vertices]` write!( self.out, ", out vertices {} {}[{}]", vert_info.ty_name, vert_info.arg_name, mesh_info.max_vertices )?; // Write something of the form `out primitives PrimitiveType} primitives_var[num_primitives]` write!( self.out, ", out primitives {} {}[{}]", prim_info.ty_name, prim_info.arg_name, mesh_info.max_primitives )?; if let Some(task_payload) = entry_point.task_payload { // Write the outer-function `in payload` arg. The name is already in // args.task_payload, having been collected when the inner function // signature was written in write_function (writer.rs). write!(self.out, ", in payload ")?; let var = &module.global_variables[task_payload]; self.write_type(module, var.ty)?; let name = &self.names[&NameKey::GlobalVariable(task_payload)]; write!(self.out, " {name}")?; if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner { self.write_array_size(module, base, size)?; } } writeln!(self.out, ") {{")?; if need_workgroup_variables_initialization { writeln!( self.out, "{}if ({} == 0) {{", back::INDENT, args.local_invocation_index, )?; self.write_workgroup_variables_initialization( func_ctx, module, module.entry_points[ep_index as usize].stage, )?; writeln!(self.out, "{}}}", back::INDENT)?; self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?; } write!(self.out, "{}{nested_name}(", back::INDENT)?; args.write_call_args(&mut self.out)?; writeln!(self.out, ");")?; writeln!( self.out, "{}GroupMemoryBarrierWithGroupSync();", back::INDENT )?; let ep = &module.entry_points[ep_index as usize]; let mesh_info = ep.mesh_info.as_ref().unwrap(); let io = self.entry_point_io.get(&(ep_index as usize)).unwrap(); let var_name = &self.names[&NameKey::GlobalVariable(mesh_info.output_variable)]; let var_type = module.global_variables[mesh_info.output_variable].ty; let wg_size: u32 = ep.workgroup_size.iter().product(); let get_var_member_name = |bi, var_type| { // The mesh shader output type must be a struct with exactly 4 members. let TypeInner::Struct { ref members, .. } = module.types[var_type].inner else { unreachable!() }; let idx = members .iter() .position(|f| f.binding == Some(crate::Binding::BuiltIn(bi))) .unwrap(); self.names[&NameKey::StructMember(var_type, idx as u32)].clone() }; let vert_count = format!( "{var_name}.{}", get_var_member_name(crate::BuiltIn::VertexCount, var_type), ); let prim_count = format!( "{var_name}.{}", get_var_member_name(crate::BuiltIn::PrimitiveCount, var_type), ); let level = back::Level(1); writeln!( self.out, "{level}SetMeshOutputCounts({vert_count}, {prim_count});" )?; // We need separate loops for vertices and primitives writing struct OutputArray<'a> { array_bi: crate::BuiltIn, count: String, io_interface: &'a EntryPointBinding, is_primitive: bool, index_name: &'static str, ty: Handle, } let output_arrays = [ OutputArray { array_bi: crate::BuiltIn::Vertices, count: vert_count, io_interface: io.mesh_vertices.as_ref().unwrap(), is_primitive: false, index_name: "vertIndex", ty: mesh_info.vertex_output_type, }, OutputArray { array_bi: crate::BuiltIn::Primitives, count: prim_count, io_interface: io.mesh_primitives.as_ref().unwrap(), is_primitive: true, index_name: "primIndex", ty: mesh_info.primitive_output_type, }, ]; for output in output_arrays { let OutputArray { array_bi, count, io_interface, is_primitive, index_name, ty, } = output; let out_var_name = &io_interface.arg_name; let index_name = self.namer.call(index_name); let array_name = get_var_member_name(array_bi, var_type); let item_name = format!("{var_name}.{array_name}[{index_name}]"); writeln!( self.out, "{level}for (int {index_name} = {}; {index_name} < {count}; {index_name} += {}) {{", args.local_invocation_index, wg_size )?; // Loop body, uses more indentation { let level = level.next(); for member in &io_interface.members { let out_member_name = &member.name; let in_member_name = &self.names[&NameKey::StructMember(ty, member.index)]; writeln!(self.out, "{level}{out_var_name}[{index_name}].{out_member_name} = {item_name}.{in_member_name};",)?; } if is_primitive { let indices_member_name = get_var_member_name( mesh_info.topology.to_builtin(), mesh_info.primitive_output_type, ); let indices_var_name = &io.mesh_indices.as_ref().unwrap().arg_name; writeln!( self.out, "{level}{indices_var_name}[{index_name}] = {item_name}.{indices_member_name};", )?; } } writeln!(self.out, "{level}}}")?; } Ok(()) } fn write_task_shader_wrapper( &mut self, module: &Module, func_ctx: &back::FunctionCtx, need_workgroup_variables_initialization: bool, nested_name: &str, entry_point: &crate::EntryPoint, args: NestedEntryPointArgs, ) -> BackendResult { let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else { unreachable!() }; // Task shader wrapper writeln!(self.out, ") {{")?; if need_workgroup_variables_initialization { writeln!( self.out, "{}if ({} == 0) {{", back::INDENT, args.local_invocation_index, )?; self.write_workgroup_variables_initialization( func_ctx, module, module.entry_points[ep_index as usize].stage, )?; writeln!(self.out, "{}}}", back::INDENT)?; self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?; } let grid_size = self.namer.call("gridSize"); write!( self.out, "{}uint3 {grid_size} = {nested_name}(", back::INDENT )?; args.write_call_args(&mut self.out)?; writeln!(self.out, ");")?; writeln!( self.out, "{}GroupMemoryBarrierWithGroupSync();", back::INDENT )?; if let Some(limits) = self.options.task_dispatch_limits { let level = back::Level(2); writeln!(self.out, "{}if (", back::INDENT)?; let max_per_dim = limits.max_mesh_workgroups_per_dim.min(2 << 21); let max_total = limits.max_mesh_workgroups_total; for i in 0..3 { writeln!( self.out, "{level}{grid_size}.{} > {max_per_dim} ||", back::COMPONENTS[i], )?; } writeln!( self.out, "{level}((uint64_t){grid_size}.x) * ((uint64_t){grid_size}.y) > 0xffffffffull ||" )?; writeln!( self.out, "{level}((uint64_t){grid_size}.x) * ((uint64_t){grid_size}.y) * ((uint64_t){grid_size}.z) > {max_total}", )?; writeln!(self.out, "{}) {{", back::INDENT)?; writeln!(self.out, "{level}{grid_size} = uint3(0, 0, 0);")?; writeln!(self.out, "{}}}", back::INDENT)?; } writeln!( self.out, "{}DispatchMesh({grid_size}.x, {grid_size}.y, {grid_size}.z, {});", back::INDENT, self.names[&NameKey::GlobalVariable(entry_point.task_payload.unwrap())] )?; Ok(()) } /// Mesh and task entry points must all return at the same `return` statement, /// so we have a nested function that can return wherever. This writes the caller, /// or the actual entry point. #[expect(clippy::too_many_arguments)] pub(super) fn write_nested_function_outer( &mut self, module: &Module, func_ctx: &back::FunctionCtx, header: &str, name: &str, need_workgroup_variables_initialization: bool, nested_name: &str, entry_point: &crate::EntryPoint, // Built in write_function alongside the inner function signature, so the // call-site argument order is guaranteed to match the declaration order. args: NestedEntryPointArgs, ) -> BackendResult { let mut any_args_written = false; let mut separator_if_needed = || { if any_args_written { ", " } else { any_args_written = true; "" } }; let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else { unreachable!(); }; let stage = module.entry_points[ep_index as usize].stage; write!(self.out, "{header}")?; write!(self.out, "void {name}(")?; // Write the outer function's argument list with full type annotations and // semantics. Arg names come from self.names and are the same names that // were collected into `args` when writing the inner function signature. if let Some(ref ep_input) = self.entry_point_io.get(&(ep_index as usize)).unwrap().input { write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?; } else { for (index, arg) in entry_point.function.arguments.iter().enumerate() { write!(self.out, "{}", separator_if_needed())?; self.write_type(module, arg.ty)?; let argument_name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; write!(self.out, " {argument_name}")?; if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner { self.write_array_size(module, base, size)?; } self.write_semantic(&arg.binding, Some((stage, Io::Input)))?; } } if need_workgroup_variables_initialization || stage == ShaderStage::Mesh { write!( self.out, "{}uint {} : SV_GroupIndex", separator_if_needed(), args.local_invocation_index, )?; } if entry_point.stage == ShaderStage::Mesh { self.write_mesh_shader_wrapper( module, func_ctx, need_workgroup_variables_initialization, nested_name, entry_point, args, separator_if_needed, )?; } else { self.write_task_shader_wrapper( module, func_ctx, need_workgroup_variables_initialization, nested_name, entry_point, args, )?; } writeln!(self.out, "}}")?; Ok(()) } pub(super) fn write_ep_mesh_output_struct( &mut self, module: &Module, entry_point_name: &str, is_primitive: bool, mesh_info: &crate::MeshStageInfo, ) -> Result { let (in_type, io, var_prefix, arg_name) = if is_primitive { ( mesh_info.primitive_output_type, Io::MeshPrimitives, "Primitive", "primitives", ) } else { ( mesh_info.vertex_output_type, Io::MeshVertices, "Vertex", "vertices", ) }; let struct_name = format!("Mesh{var_prefix}Output_{entry_point_name}",); // Mesh shader output types must be structs; this is validated by naga let members = match module.types[in_type].inner { TypeInner::Struct { ref members, .. } => members, _ => unreachable!(), }; let mut out_members = Vec::new(); for (index, member) in members.iter().enumerate() { if matches!( member.binding, Some(crate::Binding::BuiltIn( crate::BuiltIn::PointIndex | crate::BuiltIn::LineIndices | crate::BuiltIn::TriangleIndices )) ) { continue; } let member_name = self.namer.call_or(&member.name, "member"); out_members.push(EpStructMember { name: member_name, ty: member.ty, binding: member.binding.clone(), index: index as u32, }) } self.write_interface_struct( module, (ShaderStage::Mesh, io), struct_name, Some(arg_name), out_members, ) } pub(super) fn write_ep_mesh_output_indices( &mut self, topology: crate::MeshOutputTopology, ) -> Result { let (indices_name, indices_type) = match topology { // Points require a capability that isn't supported in the HLSL writer crate::MeshOutputTopology::Points => unreachable!(), crate::MeshOutputTopology::Lines => (self.namer.call("lineIndices"), "uint2"), crate::MeshOutputTopology::Triangles => (self.namer.call("triangleIndices"), "uint3"), }; Ok(EntryPointBinding { ty_name: indices_type.to_string(), arg_name: indices_name, members: Vec::new(), local_invocation_index_name: None, }) } }