Implement GPU unwrapping of handles in shader records

This commit is contained in:
baldurk
2024-04-30 15:28:02 +01:00
parent a9849c050b
commit b06e7030a5
7 changed files with 98 additions and 4 deletions
+11
View File
@@ -299,9 +299,20 @@ cbuffer RayDispatchPatchCB REG(b0)
uint raydispatch_calloffs;
uint raydispatch_callstride;
uint raydispatch_callcount;
GPUAddress wrapped_sampHeapBase;
GPUAddress wrapped_srvHeapBase;
GPUAddress unwrapped_sampHeapBase;
GPUAddress unwrapped_srvHeapBase;
uint wrapped_sampHeapSize;
uint wrapped_srvHeapSize;
uint unwrapped_heapStrides; // LSB = sampler, MSB = srv
};
#define MAX_LOCALSIG_HANDLES 31
#define WRAPPED_DESCRIPTOR_STRIDE 64
cbuffer DebugSampleOperation REG(b0)
{
+46 -2
View File
@@ -89,6 +89,16 @@ struct WrappedRecord
RWByteAddressBuffer bufferToPatch : register(u0);
struct DescriptorHeapData
{
GPUAddress wrapped_base;
GPUAddress wrapped_end;
GPUAddress unwrapped_base;
uint unwrapped_stride;
};
void PatchTable(uint byteOffset)
{
// load our wrapped record from the start of the table
@@ -130,9 +140,43 @@ void PatchTable(uint byteOffset)
{
RootSig sig = rootsigs[recordData.rootSigIndex];
for(int i = 0; i < sig.numHandles; i++)
DescriptorHeapData heaps[2];
heaps[0].wrapped_base = wrapped_sampHeapBase;
heaps[1].wrapped_base = wrapped_srvHeapBase;
heaps[0].wrapped_end = add(wrapped_sampHeapBase, GPUAddress(wrapped_sampHeapSize, 0));
heaps[1].wrapped_end = add(wrapped_srvHeapBase, GPUAddress(wrapped_srvHeapSize, 0));
heaps[0].unwrapped_stride = unwrapped_heapStrides & 0xffff;
heaps[1].unwrapped_stride = unwrapped_heapStrides >> 16;
heaps[0].unwrapped_base = unwrapped_sampHeapBase;
heaps[1].unwrapped_base = unwrapped_srvHeapBase;
for(uint i = 0; i < sig.numHandles; i++)
{
// TODO: patch descriptor handle at offset sig.handleOffsets[i]
GPUAddress wrappedHandlePtr = bufferToPatch.Load2(sig.handleOffsets[i]);
bool patched = false;
for(int h = 0; h < 2; h++)
{
if(lessEqual(heaps[h].wrapped_base, wrappedHandlePtr) &&
lessThan(wrappedHandlePtr, heaps[h].wrapped_end))
{
// assume the byte offsets will all fit into the LSB 32-bits
uint index = sub(wrappedHandlePtr, wrapped_sampHeapBase).x / WRAPPED_DESCRIPTOR_STRIDE;
GPUAddress handleOffset = GPUAddress(index * heaps[h].unwrapped_stride, 0);
bufferToPatch.Store2(sig.handleOffsets[i], add(heaps[h].unwrapped_base, handleOffset));
patched = true;
}
}
if(!patched)
{
// won't work but is our best effort
bufferToPatch.Store2(sig.handleOffsets[i], GPUAddress(0, 0));
}
}
}
}
@@ -1279,7 +1279,8 @@ void WrappedID3D12GraphicsCommandList::DispatchRays(_In_ const D3D12_DISPATCH_RA
// reference to the lookup buffer used as well as a reference to the scratch buffer containing the
// patched shader records.
PatchedRayDispatch patchedDispatch =
GetResourceManager()->GetRaytracingResourceAndUtilHandler()->PatchRayDispatch(m_pList4, *pDesc);
GetResourceManager()->GetRaytracingResourceAndUtilHandler()->PatchRayDispatch(
m_pList4, m_CaptureComputeState.heaps, *pDesc);
// restore state that would have been mutated by the patching process
m_pList4->SetComputeRootSignature(Unwrap(GetResourceManager()->GetCurrentAs<ID3D12RootSignature>(
@@ -1060,6 +1060,10 @@ void WrappedID3D12GraphicsCommandList::SetDescriptorHeaps(UINT NumDescriptorHeap
m_ListRecord->AddChunk(scope.Get(m_ListRecord->cmdInfo->alloc));
for(UINT i = 0; i < NumDescriptorHeaps; i++)
m_ListRecord->MarkResourceFrameReferenced(GetResID(ppDescriptorHeaps[i]), eFrameRef_Read);
m_CaptureComputeState.heaps.resize(NumDescriptorHeaps);
for(size_t i = 0; i < m_CaptureComputeState.heaps.size(); i++)
m_CaptureComputeState.heaps[i] = GetResID(ppDescriptorHeaps[i]);
}
}
+32 -1
View File
@@ -783,7 +783,8 @@ void D3D12RaytracingResourceAndUtilHandler::ResizeSerialisationBuffer(UINT64 siz
}
PatchedRayDispatch D3D12RaytracingResourceAndUtilHandler::PatchRayDispatch(
ID3D12GraphicsCommandList4 *unwrappedCmd, const D3D12_DISPATCH_RAYS_DESC &desc)
ID3D12GraphicsCommandList4 *unwrappedCmd, rdcarray<ResourceId> heaps,
const D3D12_DISPATCH_RAYS_DESC &desc)
{
PatchedRayDispatch ret = {};
@@ -1000,6 +1001,31 @@ PatchedRayDispatch D3D12RaytracingResourceAndUtilHandler::PatchRayDispatch(
cbufferData.raydispatch_callcount =
uint32_t(desc.CallableShaderTable.SizeInBytes / desc.CallableShaderTable.StrideInBytes);
RDCCOMPILE_ASSERT(WRAPPED_DESCRIPTOR_STRIDE == sizeof(D3D12Descriptor),
"Shader descriptor stride is wrong");
for(ResourceId heapId : heaps)
{
WrappedID3D12DescriptorHeap *heap =
(WrappedID3D12DescriptorHeap *)m_wrappedDevice->GetResourceManager()
->GetCurrentAs<ID3D12DescriptorHeap>(heapId);
if(heap->GetDescriptors()->GetType() == D3D12DescriptorType::Sampler)
{
cbufferData.wrapped_sampHeapBase = heap->GetCPUDescriptorHandleForHeapStart().ptr;
cbufferData.unwrapped_sampHeapBase = heap->GetGPU(0).ptr;
cbufferData.wrapped_sampHeapSize = heap->GetNumDescriptors() * sizeof(D3D12Descriptor);
cbufferData.unwrapped_heapStrides |= uint16_t(heap->GetUnwrappedIncrement());
}
else
{
cbufferData.wrapped_srvHeapBase = heap->GetCPUDescriptorHandleForHeapStart().ptr;
cbufferData.unwrapped_srvHeapBase = heap->GetGPU(0).ptr;
cbufferData.wrapped_srvHeapSize = heap->GetNumDescriptors() * sizeof(D3D12Descriptor);
cbufferData.unwrapped_heapStrides |= uint32_t(heap->GetUnwrappedIncrement()) << 16;
}
}
unwrappedCmd->SetPipelineState(m_RayPatchingData.pipe);
unwrappedCmd->SetComputeRootSignature(m_RayPatchingData.rootSig);
unwrappedCmd->SetComputeRoot32BitConstants((UINT)D3D12PatchRayDispatchParam::RootConstantBuffer,
@@ -1030,6 +1056,11 @@ PatchedRayDispatch D3D12RaytracingResourceAndUtilHandler::PatchRayDispatch(
void D3D12RaytracingResourceAndUtilHandler::InitRayDispatchPatchingResources()
{
// need 4x 2-DWORD root buffers, the rest we can have for constants.
// this could be made another buffer to track but it fits in push constants so we'll use them
RDCCOMPILE_ASSERT((sizeof(RayDispatchPatchCB) / sizeof(uint32_t)) + 4 * 2 < 64,
"Root signature constnats are too large");
// Root Signature
rdcarray<D3D12_ROOT_PARAMETER1> rootParameters;
rootParameters.reserve((uint16_t)D3D12PatchRayDispatchParam::Count);
+1
View File
@@ -1110,6 +1110,7 @@ public:
void UnregisterExportDatabase(D3D12ShaderExportDatabase *db);
PatchedRayDispatch PatchRayDispatch(ID3D12GraphicsCommandList4 *unwrappedCmd,
rdcarray<ResourceId> heaps,
const D3D12_DISPATCH_RAYS_DESC &desc);
void ResizeSerialisationBuffer(UINT64 size);
+2
View File
@@ -455,6 +455,8 @@ public:
handle.ptr += idx * increment;
return handle;
}
uint32_t GetUnwrappedIncrement() const { return increment; }
};
class WrappedID3D12Fence : public WrappedDeviceChild12<ID3D12Fence, ID3D12Fence1>