use alloc::{ format, string::{String, ToString}, vec::Vec, }; use crate::{ back::{ self, msl::{ writer::{TypeContext, TypedGlobalVariable}, BackendResult, EntryPointArgument, Error, NAMESPACE, WRAPPED_ARRAY_FIELD, }, }, proc::NameKey, }; pub(super) struct MeshOutputInfo { out_vertex_ty_name: String, out_primitive_ty_name: String, out_vertex_member_names: Vec>, out_primitive_member_names: Vec>, } pub(super) struct NestedFunctionInfo<'a> { pub(super) options: &'a super::Options, pub(super) ep: &'a crate::EntryPoint, pub(super) module: &'a crate::Module, pub(super) mod_info: &'a crate::valid::ModuleInfo, pub(super) fun_info: &'a crate::valid::FunctionInfo, pub(super) args: Vec, pub(super) local_invocation_index: Option<&'a NameKey>, pub(super) nested_name: &'a str, pub(super) outer_name: &'a str, pub(super) out_mesh_info: Option, } impl super::Writer { /// This writes the output vertex and primitive structs given the reflection information about them. pub(super) fn write_mesh_output_types( &mut self, mesh_info: &crate::MeshStageInfo, fun_name: &str, module: &crate::Module, // See `PipelineOptions::allow_and_force_point_size` allow_and_force_point_size: bool, options: &super::Options, ) -> Result { let mut vertex_member_names = Vec::new(); let mut primitive_member_names = Vec::new(); let vertex_out_name = self.namer.call(&format!("{fun_name}VertexOutput")); let primitive_out_name = self.namer.call(&format!("{fun_name}PrimitiveOutput")); let mut existing_names = Vec::new(); for (out_name, struct_ty, is_primitive, member_names) in [ ( &vertex_out_name, mesh_info.vertex_output_type, false, &mut vertex_member_names, ), ( &primitive_out_name, mesh_info.primitive_output_type, true, &mut primitive_member_names, ), ] { writeln!(self.out, "struct {out_name} {{")?; // Mesh output types are guaranteed to be user defined structs. This is validated by naga. let crate::TypeInner::Struct { ref members, .. } = module.types[struct_ty].inner else { unreachable!() }; let mut has_point_size = false; for (index, member) in members.iter().enumerate() { member_names.push(None); let ty_name = TypeContext { handle: member.ty, gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), first_time: true, }; let binding = member .binding .clone() .ok_or_else(|| Error::GenericValidation("Expected binding, got None".into()))?; if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = binding { has_point_size = true; if !allow_and_force_point_size { continue; } } if let crate::Binding::BuiltIn( crate::BuiltIn::PointIndex | crate::BuiltIn::LineIndices | crate::BuiltIn::TriangleIndices, ) = binding { continue; } // Names of struct members must be unique across vertex and primitive output. // Therefore, when writing the primitive output struct, we might need to rename some fields. let mut name = self.names[&NameKey::StructMember(struct_ty, index as u32)].clone(); if existing_names.contains(&name) { name = self.namer.call(&name); } else { // Let the namer know this is illegal to use again let _ = self.namer.call(&name); } let array_len = match module.types[member.ty].inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size), .. } => Some(size), _ => None, }; let resolved = options.resolve_local_binding(&binding, back::msl::LocationMode::MeshOutput)?; write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; if let Some(array_len) = array_len { write!(self.out, " [{array_len}]")?; } resolved.try_fmt(&mut self.out)?; writeln!(self.out, ";")?; *member_names.last_mut().unwrap() = Some(name.clone()); existing_names.push(name); } if allow_and_force_point_size && !has_point_size && !is_primitive { // inject the point size output last writeln!( self.out, "{}float _point_size [[point_size]];", back::INDENT )?; } writeln!(self.out, "}};")?; } Ok(MeshOutputInfo { out_vertex_ty_name: vertex_out_name, out_primitive_ty_name: primitive_out_name, out_vertex_member_names: vertex_member_names, out_primitive_member_names: primitive_member_names, }) } pub(super) fn write_wrapper_function(&mut self, info: NestedFunctionInfo<'_>) -> BackendResult { let NestedFunctionInfo { options, ep, module, mod_info, fun_info, args, local_invocation_index: local_invocation_index_key, nested_name, outer_name, out_mesh_info, } = info; let indent = back::INDENT; let em_str = match ep.stage { crate::ShaderStage::Mesh => "[[mesh]]", crate::ShaderStage::Task => "[[object]]", _ => unreachable!(), }; writeln!(self.out, "{em_str} void {outer_name}(")?; // Arguments let mut mesh_out_name: Option = None; let mut mesh_variable_name = None; let mut task_grid_name = None; if let Some(ref info) = ep.mesh_info { let mesh_out = out_mesh_info.as_ref().unwrap(); let mesh_name = self.namer.call("meshOutput"); let topology_name = match info.topology { crate::MeshOutputTopology::Points => "point", crate::MeshOutputTopology::Lines => "line", crate::MeshOutputTopology::Triangles => "triangle", }; let num_verts = info.max_vertices; let num_prims = info.max_primitives; writeln!(self.out, " {NAMESPACE}::mesh<{}, {}, {num_verts}, {num_prims}, metal::topology::{topology_name}> {mesh_name}", mesh_out.out_vertex_ty_name, mesh_out.out_primitive_ty_name, )?; mesh_out_name = Some(mesh_name); mesh_variable_name = Some( self.names [&NameKey::GlobalVariable(ep.mesh_info.as_ref().unwrap().output_variable)] .clone(), ); } else if ep.stage == crate::ShaderStage::Task { let grid_name = self.namer.call("nagaMeshGrid"); writeln!(self.out, " {NAMESPACE}::mesh_grid_properties {grid_name}")?; task_grid_name = Some(grid_name); } let local_invocation_index = if let Some(key) = local_invocation_index_key { self.names[key].clone() } else { "__local_invocation_index".to_string() }; for arg in &args { write!(self.out, ", {} {}{}", arg.ty_name, arg.name, arg.binding)?; if let Some(init) = arg.init { write!(self.out, " = ")?; self.put_const_expression(init, module, mod_info, &module.global_expressions)?; } writeln!(self.out)?; } writeln!(self.out, ") {{")?; // Function body if ep.stage == crate::ShaderStage::Mesh { for (handle, var) in module.global_variables.iter() { if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() { continue; } let tyvar = TypedGlobalVariable { module, names: &self.names, handle, usage: crate::valid::GlobalUse::WRITE | crate::valid::GlobalUse::READ, reference: false, }; write!(self.out, "{}", back::INDENT)?; tyvar.try_fmt(&mut self.out)?; writeln!(self.out, ";")?; } } write!(self.out, "{indent}")?; let result_name = if ep.stage == crate::ShaderStage::Task { let name = self.namer.call("nagaGridSize"); write!(self.out, "uint3 {} = ", name)?; Some(name) } else { None }; write!(self.out, "{nested_name}(")?; { let mut is_first = true; for arg in &args { if !is_first { write!(self.out, ", ")?; } is_first = false; write!(self.out, "{}", arg.name)?; } if ep.stage == crate::ShaderStage::Mesh { for (handle, var) in module.global_variables.iter() { if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() { continue; } if !is_first { write!(self.out, ", ")?; } let name = &self.names[&NameKey::GlobalVariable(handle)]; write!(self.out, "{name}")?; } } } writeln!(self.out, ");")?; self.write_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?; if let Some(grid_name) = task_grid_name { let result_name = result_name.unwrap(); writeln!(self.out, "{indent}if ({local_invocation_index} == 0u) {{")?; { let level2 = back::Level(2); if let Some(limits) = options.task_dispatch_limits { let level3 = back::Level(3); let max_per_dim = limits.max_mesh_workgroups_per_dim; let max_total = limits.max_mesh_workgroups_total; writeln!(self.out, "{level2}if (")?; writeln!(self.out, "{level3}{result_name}.x > {max_per_dim}u ||")?; writeln!(self.out, "{level3}{result_name}.y > {max_per_dim}u ||")?; writeln!(self.out, "{level3}{result_name}.z > {max_per_dim}u ||")?; writeln!( self.out, "{level3}{NAMESPACE}::mulhi({result_name}.x, {result_name}.y) != 0u ||" )?; writeln!( self.out, "{level3}{NAMESPACE}::mulhi({result_name}.x * {result_name}.y, {result_name}.z) != 0u ||" )?; writeln!(self.out, "{level3}({result_name}.x * {result_name}.y * {result_name}.z) > {max_total}u")?; writeln!(self.out, "{level2}) {{")?; writeln!(self.out, "{level3}{result_name} = {NAMESPACE}::uint3(0u);")?; writeln!(self.out, "{level2}}}")?; } writeln!( self.out, "{level2}{grid_name}.set_threadgroups_per_grid({result_name});" )?; } writeln!(self.out, "{indent}}}")?; writeln!(self.out, "{indent}return;")?; } else if let Some(ref info) = ep.mesh_info { let mesh_out = out_mesh_info.as_ref().unwrap(); let out_ty = module.global_variables[info.output_variable].ty; let mesh_out_name = mesh_out_name.unwrap(); let mesh_variable_name = mesh_variable_name.unwrap(); // The output type is guaranteed to be a struct with exactly 4 members let crate::TypeInner::Struct { ref members, .. } = module.types[out_ty].inner else { unreachable!(); }; let get_out_value = |bi| { let member_idx = members .iter() .position(|a| a.binding == Some(crate::Binding::BuiltIn(bi))) .unwrap() as u32; format!( "{}.{}", mesh_variable_name, self.names[&NameKey::StructMember(out_ty, member_idx)] ) }; let vert_count = format!( "{NAMESPACE}::min({}, {}u)", get_out_value(crate::BuiltIn::VertexCount), info.max_vertices ); let prim_count = format!( "{NAMESPACE}::min({}, {}u)", get_out_value(crate::BuiltIn::PrimitiveCount), info.max_primitives ); let workgroup_size: u32 = ep.workgroup_size.iter().product(); { let vert_index = self.namer.call("vertexIndex"); let in_array = get_out_value(crate::BuiltIn::Vertices); writeln!( self.out, "{indent}for(uint {vert_index} = {local_invocation_index}; {vert_index} < {vert_count}; {vert_index} += {workgroup_size}) {{" )?; let out_vert = self.namer.call("vertex"); writeln!( self.out, "{indent}{indent}{} {out_vert};", mesh_out.out_vertex_ty_name, )?; for (member_idx, new_name) in mesh_out.out_vertex_member_names.iter().enumerate() { let in_value = format!( "{in_array}.{WRAPPED_ARRAY_FIELD}[{vert_index}].{}", self.names [&NameKey::StructMember(info.vertex_output_type, member_idx as u32)] ); let out_value = format!("{out_vert}.{}", new_name.as_ref().unwrap()); writeln!(self.out, "{indent}{indent}{out_value} = {in_value};")?; } writeln!( self.out, "{indent}{indent}{}.set_vertex({vert_index}, {out_vert});", mesh_out_name )?; writeln!(self.out, "{indent}}}")?; } { let prim_index = self.namer.call("primitiveIndex"); let in_array = get_out_value(crate::BuiltIn::Primitives); writeln!( self.out, "{indent}for(uint {prim_index} = {local_invocation_index}; {prim_index} < {prim_count}; {prim_index} += {workgroup_size}) {{" )?; let out_prim = self.namer.call("primitive"); writeln!( self.out, "{indent}{indent}{} {out_prim};", mesh_out.out_primitive_ty_name )?; for (member_idx, new_name) in mesh_out.out_primitive_member_names.iter().enumerate() { let in_value = format!( "{in_array}.{WRAPPED_ARRAY_FIELD}[{prim_index}].{}", self.names [&NameKey::StructMember(info.primitive_output_type, member_idx as u32)] ); if let Some(new_name) = new_name.as_ref() { let out_value = format!("{out_prim}.{new_name}"); writeln!( self.out, "{indent}{}{out_value} = {in_value};", back::INDENT )?; } else { let num_indices = match info.topology { crate::MeshOutputTopology::Points => 1, crate::MeshOutputTopology::Lines => 2, crate::MeshOutputTopology::Triangles => 3, }; for i in 0..num_indices { let component = if num_indices == 1 { "".to_string() } else { format!(".{}", back::COMPONENTS[i]) }; writeln!( self.out, "{indent}{}{}.set_index({prim_index} * {num_indices} + {i}, {in_value}{component});", back::INDENT, mesh_out_name, )?; } } } writeln!( self.out, "{indent}{}{}.set_primitive({prim_index}, {out_prim});", back::INDENT, mesh_out_name )?; writeln!(self.out, "{indent}}}")?; } writeln!(self.out, "{indent}if ({local_invocation_index} == 0u) {{")?; writeln!( self.out, "{indent}{indent}{}.set_primitive_count({prim_count});", mesh_out_name, )?; writeln!(self.out, "{indent}}}")?; } else { // Must either have task output grid (task shader) or mesh output info (mesh shader) unreachable!() } writeln!(self.out, "}}")?; Ok(()) } }