Fix incorrect command buffer used to restore state after AS build

This commit is contained in:
baldurk
2024-09-02 17:25:25 +01:00
parent 7940e86ec3
commit 08ec78aace
2 changed files with 25 additions and 24 deletions
+1 -1
View File
@@ -589,7 +589,7 @@ public:
_In_ SIZE_T ExecutionParametersDataSizeInBytes);
bool PatchAccStructBlasAddress(const D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC *accStructInput,
ID3D12GraphicsCommandList4 *list,
ID3D12GraphicsCommandList4 *unwrappedList,
BakedCmdListInfo::PatchRaytracing *patchRaytracing);
bool ProcessASBuildAfterSubmission(ResourceId asbWrappedResourceId,
@@ -795,7 +795,7 @@ bool WrappedID3D12GraphicsCommandList::ProcessASBuildAfterSubmission(ResourceId
bool WrappedID3D12GraphicsCommandList::PatchAccStructBlasAddress(
const D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC *accStructInput,
ID3D12GraphicsCommandList4 *list, BakedCmdListInfo::PatchRaytracing *patchRaytracing)
ID3D12GraphicsCommandList4 *unwrappedList, BakedCmdListInfo::PatchRaytracing *patchRaytracing)
{
if(accStructInput->Inputs.Type == D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL &&
accStructInput->Inputs.NumDescs > 0)
@@ -857,12 +857,12 @@ bool WrappedID3D12GraphicsCommandList::PatchAccStructBlasAddress(
resBarriers.push_back(resBarrier);
}
list->ResourceBarrier((UINT)resBarriers.size(), resBarriers.data());
unwrappedList->ResourceBarrier((UINT)resBarriers.size(), resBarriers.data());
}
list->CopyBufferRegion(patchRaytracing->m_patchedInstanceBuffer->Resource(),
patchRaytracing->m_patchedInstanceBuffer->Offset(), instanceResource,
instanceResOffset, totalInstancesSize);
unwrappedList->CopyBufferRegion(patchRaytracing->m_patchedInstanceBuffer->Resource(),
patchRaytracing->m_patchedInstanceBuffer->Offset(),
instanceResource, instanceResOffset, totalInstancesSize);
D3D12AccStructPatchInfo patchInfo = rtHandler->GetAccStructPatchInfo();
@@ -891,7 +891,7 @@ bool WrappedID3D12GraphicsCommandList::PatchAccStructBlasAddress(
resBarriers.push_back(resBarrier);
}
list->ResourceBarrier((UINT)resBarriers.size(), resBarriers.data());
unwrappedList->ResourceBarrier((UINT)resBarriers.size(), resBarriers.data());
}
RDCCOMPILE_ASSERT(sizeof(D3D12_RAYTRACING_INSTANCE_DESC) == sizeof(InstanceDesc),
@@ -908,7 +908,7 @@ bool WrappedID3D12GraphicsCommandList::PatchAccStructBlasAddress(
resBarrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV;
resBarrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
resBarrier.UAV.pResource = patchRaytracing->m_patchedInstanceBuffer->Resource();
list->ResourceBarrier(1, &resBarrier);
unwrappedList->ResourceBarrier(1, &resBarrier);
}
ID3D12Resource *addressPairRes = m_pDevice->GetBLASAddressBufferResource();
@@ -916,22 +916,23 @@ bool WrappedID3D12GraphicsCommandList::PatchAccStructBlasAddress(
uint64_t addressCount = m_pDevice->GetBLASAddressCount();
list->SetPipelineState(patchInfo.m_pipeline);
list->SetComputeRootSignature(patchInfo.m_rootSignature);
list->SetComputeRoot32BitConstant((UINT)D3D12PatchTLASBuildParam::RootConstantBuffer,
(UINT)addressCount, 0);
list->SetComputeRootShaderResourceView((UINT)D3D12PatchTLASBuildParam::RootAddressPairSrv,
addressPairResAddress);
list->SetComputeRootUnorderedAccessView((UINT)D3D12PatchTLASBuildParam::RootPatchedAddressUav,
patchRaytracing->m_patchedInstanceBuffer->Address());
list->Dispatch(accStructInput->Inputs.NumDescs, 1, 1);
unwrappedList->SetPipelineState(patchInfo.m_pipeline);
unwrappedList->SetComputeRootSignature(patchInfo.m_rootSignature);
unwrappedList->SetComputeRoot32BitConstant((UINT)D3D12PatchTLASBuildParam::RootConstantBuffer,
(UINT)addressCount, 0);
unwrappedList->SetComputeRootShaderResourceView(
(UINT)D3D12PatchTLASBuildParam::RootAddressPairSrv, addressPairResAddress);
unwrappedList->SetComputeRootUnorderedAccessView(
(UINT)D3D12PatchTLASBuildParam::RootPatchedAddressUav,
patchRaytracing->m_patchedInstanceBuffer->Address());
unwrappedList->Dispatch(accStructInput->Inputs.NumDescs, 1, 1);
{
D3D12_RESOURCE_BARRIER resBarrier;
resBarrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV;
resBarrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
resBarrier.UAV.pResource = patchRaytracing->m_patchedInstanceBuffer->Resource();
list->ResourceBarrier(1, &resBarrier);
unwrappedList->ResourceBarrier(1, &resBarrier);
}
{
@@ -942,7 +943,7 @@ bool WrappedID3D12GraphicsCommandList::PatchAccStructBlasAddress(
resBarrier.Transition.pResource = patchRaytracing->m_patchedInstanceBuffer->Resource();
resBarrier.Transition.StateBefore = D3D12_RESOURCE_STATE_UNORDERED_ACCESS;
resBarrier.Transition.StateAfter = D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE;
list->ResourceBarrier(1, &resBarrier);
unwrappedList->ResourceBarrier(1, &resBarrier);
}
patchRaytracing->m_patched = true;
@@ -980,13 +981,13 @@ bool WrappedID3D12GraphicsCommandList::Serialise_BuildRaytracingAccelerationStru
{
if(m_Cmd->InRerecordRange(m_Cmd->m_LastCmdListID))
{
ID3D12GraphicsCommandList4 *list = Unwrap4(m_Cmd->RerecordCmdList(m_Cmd->m_LastCmdListID));
ID3D12GraphicsCommandListX *list = m_Cmd->RerecordCmdList(m_Cmd->m_LastCmdListID);
if(AccStructDesc.Inputs.Type == D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL &&
AccStructDesc.Inputs.NumDescs > 0)
{
patchInfo.m_patched = false;
PatchAccStructBlasAddress(&AccStructDesc, list, &patchInfo);
PatchAccStructBlasAddress(&AccStructDesc, Unwrap4(list), &patchInfo);
if(patchInfo.m_patched)
{
AccStructDesc.Inputs.InstanceDescs = patchInfo.m_patchedInstanceBuffer->Address();
@@ -998,11 +999,11 @@ bool WrappedID3D12GraphicsCommandList::Serialise_BuildRaytracingAccelerationStru
}
// Switch back to previous state
bakedCmdInfo.state.ApplyState(m_pDevice, (ID3D12GraphicsCommandListX *)pCommandList);
bakedCmdInfo.state.ApplyState(m_pDevice, list);
}
list->BuildRaytracingAccelerationStructure(&AccStructDesc, NumPostbuildInfoDescs,
pPostbuildInfoDescs);
Unwrap4(list)->BuildRaytracingAccelerationStructure(&AccStructDesc, NumPostbuildInfoDescs,
pPostbuildInfoDescs);
}
}
else