use alloc::{sync::Arc, vec::Vec};
use core::{
    cmp::max,
    num::NonZeroU64,
    ops::{Deref, Range},
};

use wgt::{math::align_to, BufferUsages, BufferUses, Features};

use crate::{
    command::encoder::EncodingState,
    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
    resource::InvalidResourceError,
};
use crate::{command::EncoderStateError, device::resource::CommandIndices};
use crate::{
    command::{ArcCommand, ArcReferences, CommandBufferMutable},
    device::queue::TempResource,
    global::Global,
    id::CommandEncoderId,
    init_tracker::MemoryInitKind,
    ray_tracing::{
        ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
        ArcTlasPackage, BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError,
        OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
    },
    resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
    scratch::ScratchBuffer,
    snatch::SnatchGuard,
};
use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
use crate::id::{BlasId, TlasId};

struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}

impl Global {
    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
        self.hub.blas_s.get(blas_id).get()
    }

    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
        self.hub.tlas_s.get(tlas_id).get()
    }

    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;
        let cmd_enc = hub.command_encoders.get(command_encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.with_buffer(
            crate::command::EncodingApi::Raw,
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());
                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }
                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        dependencies: Vec::new(),
                    });
                }
                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;
        let cmd_enc = hub.command_encoders.get(command_encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
            let blas = blas_iter
                .map(|blas_entry| {
                    let geometries = match blas_entry.geometries {
                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
                            let tri_geo = triangle_geometries
                                .map(|tg| {
                                    Ok(ArcBlasTriangleGeometry {
                                        size: tg.size.clone(),
                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
                                        index_buffer: tg
                                            .index_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        transform_buffer: tg
                                            .transform_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
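                                        // The remaining fields are plain values
                                        // copied through unchanged; only buffer
                                        // ids need resolving to `Arc`'d resources.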
                                        first_vertex: tg.first_vertex,
                                        vertex_stride: tg.vertex_stride,
                                        first_index: tg.first_index,
                                        transform_buffer_offset: tg.transform_buffer_offset,
                                    })
                                })
                                .collect::<Result<Vec<_>, BuildAccelerationStructureError>>()?;
                            ArcBlasGeometries::TriangleGeometries(tri_geo)
                        }
                    };
                    Ok(ArcBlasBuildEntry {
                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
                        geometries,
                    })
                })
                .collect::<Result<Vec<_>, BuildAccelerationStructureError>>()?;

            let tlas = tlas_iter
                .map(|tlas_package| {
                    let instances = tlas_package
                        .instances
                        .map(|instance| {
                            instance
                                .as_ref()
                                .map(|instance| {
                                    Ok(ArcTlasInstance {
                                        blas: self.resolve_blas_id(instance.blas_id)?,
                                        transform: *instance.transform,
                                        custom_data: instance.custom_data,
                                        mask: instance.mask,
                                    })
                                })
                                .transpose()
                        })
                        .collect::<Result<Vec<_>, BuildAccelerationStructureError>>()?;
                    Ok(ArcTlasPackage {
                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
                        instances,
                        lowest_unmodified: tlas_package.lowest_unmodified,
                    })
                })
                .collect::<Result<Vec<_>, BuildAccelerationStructureError>>()?;

            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
        })
    }
}

pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState<'_, '_>,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());

    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();

    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            state.device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.iter().flatten() {
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;

            state.tracker.blas_s.insert_single(blas.clone());

            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            if tlas
                .flags
                .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }
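
    // BLAS and TLAS builds share a single scratch allocation: the BLAS builds
    // run first and a barrier covers the reuse, so the buffer only needs to be
    // as large as the bigger of the two size requirements.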
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        // if the size is zero there is nothing to build
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());
    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
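        // For each TLAS that has instances: transition its instance buffer to
        // COPY_DST, copy the staged instance bytes in, and queue the barrier
        // back to TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT so all copies can be
        // made visible at once before the TLAS builds are recorded.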
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);
            raw_encoder.build_acceleration_structures(&tlas_descriptors);
            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}

impl CommandBufferMutable {
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();
                    command_index_guard.next_acceleration_structure_build_command_index += 1;
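
                    // Stamp every structure built by this action with the
                    // submission's build index; the checks below compare these
                    // stamps to reject a TLAS whose BLAS dependencies were
                    // rebuilt after it was.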
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            // Reset the compacted state to idle. This means any prepares,
                            // before mapping their internal buffer, will terminate.
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);

                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies);
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }

    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");

        let tlas_dependencies_locks: Vec<_> = self
            .as_actions
            .iter()
            .filter_map(|action| {
                if let AsAction::UseTlas(tlas) = action {
                    Some(tlas.dependencies.read())
                } else {
                    None
                }
            })
            .collect();
        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();

        let mut dependencies = Vec::new();
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    for tlas_build in build.tlas_s_built.iter() {
                        for dependency in &tlas_build.dependencies {
                            if let Some(dependency) = dependency.raw(snatch_guard) {
                                dependencies.push(dependency);
                            }
                        }
                    }
                }
                AsAction::UseTlas(_tlas) => {
                    // This lock was taken, in action order, from `_tlas.dependencies` above.
                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap();
                    for dependency in tlas_dependencies.iter() {
                        if let Some(dependency) = dependency.raw(snatch_guard) {
                            dependencies.push(dependency);
                        }
                    }
                }
            }
        }
        if !dependencies.is_empty() {
            unsafe {
                self.encoder
                    .raw
                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
            }
        }
    }
}

/// Iterates over the BLAS entries and their geometries, validating each entry
/// and pushing its buffers into a storage vector.
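///
/// Each geometry is checked against the size descriptors the BLAS was created
/// with, so an incompatible build fails here before any barriers or copies are
/// encoded.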
fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
    blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
    build_command: &mut AsBuild,
    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'buffers>>,
    state: &mut EncodingState<'snatch_guard, '_>,
) -> Result<(), BuildAccelerationStructureError> {
    for entry in blas_iter {
        let blas = &entry.blas;
        state.tracker.blas_s.insert_single(blas.clone());
        build_command.blas_s_built.push(blas.clone());

        match &entry.geometries {
            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
                let mut triangle_entries =
                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
                for (i, mesh) in triangle_geometries.iter().enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != mesh.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            mesh.size.flags,
                        ));
                    }

                    if size_desc.vertex_count < mesh.size.vertex_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
                                blas.error_ident(),
                                size_desc.vertex_count,
                                mesh.size.vertex_count,
                            ),
                        );
                    }

                    if size_desc.vertex_format != mesh.size.vertex_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
                            blas.error_ident(),
                            size_desc.vertex_format,
                            mesh.size.vertex_format,
                        ));
                    }

                    if size_desc
                        .vertex_format
                        .min_acceleration_structure_vertex_stride()
                        > mesh.vertex_stride
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .min_acceleration_structure_vertex_stride(),
                            mesh.vertex_stride,
                        ));
                    }

                    if mesh.vertex_stride
                        % size_desc
                            .vertex_format
                            .acceleration_structure_stride_alignment()
                        != 0
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .acceleration_structure_stride_alignment(),
                            mesh.vertex_stride,
                        ));
                    }

                    match (size_desc.index_count, mesh.size.index_count) {
                        (Some(_), None) | (None, Some(_)) => {
                            return Err(
                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
                                    blas.error_ident(),
                                ),
                            )
                        }
                        (Some(create), Some(build)) if create < build => {
                            return Err(
                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
                                    blas.error_ident(),
                                    create,
                                    build,
                                ),
                            )
                        }
                        _ => {}
                    }

                    if size_desc.index_format != mesh.size.index_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
                            blas.error_ident(),
                            size_desc.index_format,
                            mesh.size.index_format,
                        ));
                    }

                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
                            blas.error_ident(),
                        ));
                    }

                    let vertex_buffer = mesh.vertex_buffer.clone();
                    let vertex_pending = state.tracker.buffers.set_single(
                        &vertex_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );

                    let vertex_buffer = {
                        let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
                        let vertex_buffer = &mesh.vertex_buffer;
                        vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
                        if let Some(barrier) = vertex_pending.map(|pending| {
                            pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }
                        if vertex_buffer.size
                            < (mesh.size.vertex_count + mesh.first_vertex) as u64
                                * mesh.vertex_stride
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                                vertex_buffer.error_ident(),
                                vertex_buffer.size,
                                (mesh.size.vertex_count + mesh.first_vertex) as u64
                                    * mesh.vertex_stride,
                            ));
                        }
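
                        // The build reads vertices starting at `first_vertex`,
                        // so that whole byte range must hold initialized memory
                        // by the time this command buffer executes.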
                        let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
                        state.buffer_memory_init_actions.extend(
                            vertex_buffer.initialization_status.read().create_action(
                                vertex_buffer,
                                vertex_buffer_offset
                                    ..(vertex_buffer_offset
                                        + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        vertex_raw
                    };
                    let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
                        if mesh.first_index.is_none()
                            || mesh.size.index_count.is_none()
                            || mesh.size.index_format.is_none()
                        {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                index_buffer.error_ident(),
                            ));
                        }
                        let index_pending = state.tracker.buffers.set_single(
                            index_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        let index_raw = index_buffer.try_raw(state.snatch_guard)?;
                        index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
                        if let Some(barrier) = index_pending.map(|pending| {
                            pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }
                        let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
                        let offset = mesh.first_index.unwrap() as u64 * index_stride;
                        let index_buffer_size =
                            mesh.size.index_count.unwrap() as u64 * index_stride;

                        if mesh.size.index_count.unwrap() % 3 != 0 {
                            return Err(BuildAccelerationStructureError::InvalidIndexCount(
                                index_buffer.error_ident(),
                                mesh.size.index_count.unwrap(),
                            ));
                        }
                        if index_buffer.size
                            < mesh.size.index_count.unwrap() as u64 * index_stride + offset
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                                index_buffer.error_ident(),
                                index_buffer.size,
                                mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                            ));
                        }

                        state.buffer_memory_init_actions.extend(
                            index_buffer.initialization_status.read().create_action(
                                index_buffer,
                                offset..(offset + index_buffer_size),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );

                        Some(index_raw)
                    } else {
                        None
                    };
                    let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
                    {
                        if !blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::UseTransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        if mesh.transform_buffer_offset.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                transform_buffer.error_ident(),
                            ));
                        }
                        let transform_pending = state.tracker.buffers.set_single(
                            transform_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
                        transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
                        if let Some(barrier) = transform_pending.map(|pending| {
                            pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }
                        let offset = mesh.transform_buffer_offset.unwrap();
                        if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                            return Err(
                                BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                                    transform_buffer.error_ident(),
                                ),
                            );
                        }
                        if transform_buffer.size < 48 + offset {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                                transform_buffer.error_ident(),
                                transform_buffer.size,
                                48 + offset,
                            ));
                        }
                        state.buffer_memory_init_actions.extend(
                            transform_buffer.initialization_status.read().create_action(
                                transform_buffer,
                                offset..(offset + 48),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        Some(transform_raw)
                    } else {
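                        // No transform buffer was supplied; that is only an
                        // error if the BLAS was created expecting one.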
.contains(wgt::AccelerationStructureFlags::USE_TRANSFORM) { return Err(BuildAccelerationStructureError::TransformMissing( blas.error_ident(), )); } None }; let triangles = hal::AccelerationStructureTriangles { vertex_buffer: Some(vertex_buffer), vertex_format: mesh.size.vertex_format, first_vertex: mesh.first_vertex, vertex_count: mesh.size.vertex_count, vertex_stride: mesh.vertex_stride, indices: index_buffer.map(|index_buffer| { let index_stride = mesh.size.index_format.unwrap().byte_size() as u32; hal::AccelerationStructureTriangleIndices:: { format: mesh.size.index_format.unwrap(), buffer: Some(index_buffer), offset: mesh.first_index.unwrap() * index_stride, count: mesh.size.index_count.unwrap(), } }), transform: transform_buffer.map(|transform_buffer| { hal::AccelerationStructureTriangleTransform { buffer: transform_buffer, offset: mesh.transform_buffer_offset.unwrap() as u32, } }), flags: mesh.size.flags, }; triangle_entries.push(triangles); } { let scratch_buffer_offset = *scratch_buffer_blas_size; *scratch_buffer_blas_size += align_to( blas.size_info.build_scratch_size as u32, state.device.alignments.ray_tracing_scratch_buffer_alignment, ) as u64; blas_storage.push(BlasStore { blas: blas.clone(), entries: hal::AccelerationStructureEntries::Triangles(triangle_entries), scratch_buffer_offset, }); } } } } Ok(()) } fn map_blas<'a>( storage: &'a BlasStore<'_>, scratch_buffer: &'a dyn hal::DynBuffer, snatch_guard: &'a SnatchGuard, blases_compactable: &mut Vec<( &'a dyn hal::DynBuffer, &'a dyn hal::DynAccelerationStructure, )>, ) -> Result< hal::BuildAccelerationStructureDescriptor< 'a, dyn hal::DynBuffer, dyn hal::DynAccelerationStructure, >, BuildAccelerationStructureError, > { let BlasStore { blas, entries, scratch_buffer_offset, } = storage; if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate { log::debug!("only rebuild implemented") } let raw = blas.try_raw(snatch_guard)?; let state_lock = blas.compacted_state.lock(); if let BlasCompactState::Compacted = *state_lock { return Err(BuildAccelerationStructureError::CompactedBlas( blas.error_ident(), )); } if blas .flags .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION) { blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw)); } Ok(hal::BuildAccelerationStructureDescriptor { entries, mode: hal::AccelerationStructureBuildMode::Build, flags: blas.flags, source_acceleration_structure: None, destination_acceleration_structure: raw, scratch_buffer, scratch_buffer_offset: *scratch_buffer_offset, }) } fn build_blas<'a>( cmd_buf_raw: &mut dyn hal::DynCommandEncoder, blas_present: bool, tlas_present: bool, input_barriers: Vec>, blas_descriptors: &[hal::BuildAccelerationStructureDescriptor< 'a, dyn hal::DynBuffer, dyn hal::DynAccelerationStructure, >], scratch_buffer_barrier: hal::BufferBarrier, blas_s_for_compaction: Vec<( &'a dyn hal::DynBuffer, &'a dyn hal::DynAccelerationStructure, )>, ) { unsafe { cmd_buf_raw.transition_buffers(&input_barriers); } if blas_present { unsafe { cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier { usage: hal::StateTransition { from: hal::AccelerationStructureUses::BUILD_INPUT, to: hal::AccelerationStructureUses::BUILD_OUTPUT, }, }); cmd_buf_raw.build_acceleration_structures(blas_descriptors); } } if blas_present && tlas_present { unsafe { cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]); } } let mut source_usage = hal::AccelerationStructureUses::empty(); let mut destination_usage = 
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT;
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}