From 882cd8c9ecf3031233e3401df93f161e463aa99d Mon Sep 17 00:00:00 2001 From: baldurk Date: Thu, 16 Nov 2023 17:20:01 +0000 Subject: [PATCH] Add function to patch amp. shader DXIL to feed from buffer --- renderdoc/driver/d3d12/d3d12_postvs.cpp | 415 ++++++++++++++++++++++++ 1 file changed, 415 insertions(+) diff --git a/renderdoc/driver/d3d12/d3d12_postvs.cpp b/renderdoc/driver/d3d12/d3d12_postvs.cpp index d33aa4793..8504556a0 100644 --- a/renderdoc/driver/d3d12/d3d12_postvs.cpp +++ b/renderdoc/driver/d3d12/d3d12_postvs.cpp @@ -660,6 +660,420 @@ static void AddDXILAmpShaderPayloadStores(const DXBC::DXBCContainer *dxbc, uint3 } } +static void ConvertToFixedDXILAmpFeeder(const DXBC::DXBCContainer *dxbc, uint32_t space, + rdcfixedarray dispatchDim, bytebuf &editedBlob) +{ + using namespace DXIL; + + ProgramEditor editor(dxbc, editedBlob); + bool isShaderModel6_6OrAbove = + dxbc->m_Version.Major > 6 || (dxbc->m_Version.Major == 6 && dxbc->m_Version.Minor >= 6); + + const Type *i32 = editor.GetInt32Type(); + const Type *i8 = editor.GetInt8Type(); + const Type *i1 = editor.GetBoolType(); + const Type *voidType = editor.GetVoidType(); + + const Type *handleType = editor.CreateNamedStructType( + "dx.types.Handle", {editor.CreatePointerType(i8, Type::PointerAddrSpace::Default)}); + + // this function is named differently based on the payload struct name, so search by prefix, we + // expect the actual type to be the same as we're just modifying the payload in place + const Function *dispatchMesh = editor.GetFunctionByPrefix("dx.op.dispatchMesh"); + + const Function *createHandle = NULL; + const Function *createHandleFromBinding = NULL; + const Function *annotateHandle = NULL; + + // reading from a binding uses a different function in SM6.6+ + if(isShaderModel6_6OrAbove) + { + const Type *resBindType = editor.CreateNamedStructType("dx.types.ResBind", {i32, i32, i32, i8}); + createHandleFromBinding = editor.DeclareFunction("dx.op.createHandleFromBinding", handleType, + {i32, resBindType, i32, i1}, + Attribute::NoUnwind | Attribute::ReadNone); + + const Type *resourcePropertiesType = + editor.CreateNamedStructType("dx.types.ResourceProperties", {i32, i32}); + annotateHandle = editor.DeclareFunction("dx.op.annotateHandle", handleType, + {i32, handleType, resourcePropertiesType}, + Attribute::NoUnwind | Attribute::ReadNone); + } + else if(!createHandle && !isShaderModel6_6OrAbove) + { + createHandle = editor.DeclareFunction("dx.op.createHandle", handleType, {i32, i8, i32, i32, i1}, + Attribute::NoUnwind | Attribute::ReadNone); + } + + const Function *groupId = editor.DeclareFunction("dx.op.groupId.i32", i32, {i32, i32}, + Attribute::NoUnwind | Attribute::ReadNone); + const Type *resRet_i32 = + editor.CreateNamedStructType("dx.types.ResRet.i32", {i32, i32, i32, i32, i32}); + const Function *rawBufferLoad = editor.DeclareFunction("dx.op.rawBufferLoad.i32", resRet_i32, + {i32, handleType, i32, i32, i8, i32}, + Attribute::NoUnwind | Attribute::ReadOnly); + + // declare the resource, this happens purely in metadata but we need to store the slot + uint32_t regSlot = 0; + Metadata *reslist = NULL; + { + const Type *rw = editor.CreateNamedStructType("struct.RWByteAddressBuffer", {i32}); + const Type *rwptr = editor.CreatePointerType(rw, Type::PointerAddrSpace::Default); + + Metadata *resources = editor.CreateNamedMetadata("dx.resources"); + if(resources->children.empty()) + resources->children.push_back(editor.CreateMetadata()); + + reslist = resources->children[0]; + + if(reslist->children.empty()) + reslist->children.resize(4); + + Metadata *uavs = reslist->children[1]; + // if there isn't a UAV list, create an empty one so we can add our own + if(!uavs) + uavs = reslist->children[1] = editor.CreateMetadata(); + + for(size_t i = 0; i < uavs->children.size(); i++) + { + // each UAV child should have a fixed format, [0] is the reg ID and I think this should always + // be == the index + const Metadata *uav = uavs->children[i]; + const Constant *slot = cast(uav->children[(size_t)ResField::ID]->value); + + if(!slot) + { + RDCWARN("Unexpected non-constant slot ID in UAV"); + continue; + } + + RDCASSERT(slot->getU32() == i); + + uint32_t id = slot->getU32(); + regSlot = RDCMAX(id + 1, regSlot); + } + + Constant rwundef; + rwundef.type = rwptr; + rwundef.setUndef(true); + + // create the new UAV record + Metadata *uav = editor.CreateMetadata(); + uav->children = { + editor.CreateConstantMetadata(regSlot), + editor.CreateConstantMetadata(editor.CreateConstant(rwundef)), + editor.CreateConstantMetadata(""), + editor.CreateConstantMetadata(space), + editor.CreateConstantMetadata(1U), // reg base + editor.CreateConstantMetadata(1U), // reg count + editor.CreateConstantMetadata(uint32_t(ResourceKind::RawBuffer)), // shape + editor.CreateConstantMetadata(false), // globally coherent + editor.CreateConstantMetadata(false), // hidden counter + editor.CreateConstantMetadata(false), // raster order + NULL, // UAV tags + }; + + uavs->children.push_back(uav); + } + + uint32_t payloadSize = 0; + + rdcstr entryName; + // add the entry point tags + { + Metadata *entryPoints = editor.GetMetadataByName("dx.entryPoints"); + + if(!entryPoints) + { + RDCERR("Couldn't find entry point list"); + return; + } + + // TODO select the entry point for multiple entry points? RT only for now + Metadata *entry = entryPoints->children[0]; + + entryName = entry->children[1]->str; + + Metadata *taglist = entry->children[4]; + if(!taglist) + taglist = entry->children[4] = editor.CreateMetadata(); + + // find existing shader flags tag, if there is one + Metadata *shaderFlagsTag = NULL; + Metadata *shaderFlagsData = NULL; + Metadata *ampData = NULL; + size_t flagsIndex = 0; + for(size_t t = 0; taglist && t < taglist->children.size(); t += 2) + { + RDCASSERT(taglist->children[t]->isConstant); + if(cast(taglist->children[t]->value)->getU32() == + (uint32_t)ShaderEntryTag::ShaderFlags) + { + shaderFlagsTag = taglist->children[t]; + shaderFlagsData = taglist->children[t + 1]; + flagsIndex = t + 1; + } + else if(cast(taglist->children[t]->value)->getU32() == + (uint32_t)ShaderEntryTag::Amplification) + { + ampData = taglist->children[t + 1]; + } + } + + uint32_t shaderFlagsValue = + shaderFlagsData ? cast(shaderFlagsData->value)->getU32() : 0U; + + // raw and structured buffers + shaderFlagsValue |= 0x10; + + // UAVs on non-PS/CS stages + shaderFlagsValue |= 0x10000; + + // REMOVE wave ops flag as we don't use it but the original shader might have. DXIL requires + // flags to be strictly minimum :( + shaderFlagsValue &= ~0x80000; + + // (re-)create shader flags tag + Type *i64 = editor.CreateScalarType(Type::Int, 64); + shaderFlagsData = + editor.CreateConstantMetadata(editor.CreateConstant(Constant(i64, shaderFlagsValue))); + // shaderFlagsData = editor.CreateConstantMetadata(shaderFlagsValue); + + // if we didn't have a shader tags entry at all, create the metadata node for the shader flags + // tag + if(!shaderFlagsTag) + shaderFlagsTag = editor.CreateConstantMetadata((uint32_t)ShaderEntryTag::ShaderFlags); + + // if we had a tag already, we can just re-use that tag node and replace the data node. + // Otherwise we need to add both, and we insert them first + if(flagsIndex) + { + taglist->children[flagsIndex] = shaderFlagsData; + } + else + { + taglist->children.insert(0, shaderFlagsTag); + taglist->children.insert(1, shaderFlagsData); + } + + // set reslist and taglist in case they were null before + entry->children[3] = reslist; + entry->children[4] = taglist; + + // we must have found an amplification tag. Patch the number of threads and payload size here + ampData->children[0] = editor.CreateMetadata(); + ampData->children[0]->children.push_back(editor.CreateConstantMetadata((uint32_t)1)); + ampData->children[0]->children.push_back(editor.CreateConstantMetadata((uint32_t)1)); + ampData->children[0]->children.push_back(editor.CreateConstantMetadata((uint32_t)1)); + + payloadSize = cast(ampData->children[1]->value)->getU32(); + // add room for our dimensions + offset + ampData->children[1] = editor.CreateConstantMetadata(payloadSize + 16); + } + + // get the editor to patch PSV0 with our extra UAV + editor.RegisterUAV(DXILResourceType::ByteAddressUAV, space, 1, 1, ResourceKind::RawBuffer); + uint32_t dim[] = {1, 1, 1}; + editor.SetNumThreads(dim); + editor.SetASPayloadSize(payloadSize + 16); + + // remove some flags that will no longer be valid + editor.PatchGlobalShaderFlags( + [](DXBC::GlobalShaderFlags &flags) { flags &= ~DXBC::GlobalShaderFlags::WaveOps; }); + + Function *f = editor.GetFunctionByName(entryName); + + if(!f) + { + RDCERR("Couldn't find entry point function '%s'", entryName.c_str()); + return; + } + + // find the dispatchMesh call, and from there the global groupshared variable that's the payload + GlobalVar *payloadVariable = NULL; + Type *payloadType = NULL; + for(size_t i = 0; i < f->instructions.size(); i++) + { + const Instruction &inst = *f->instructions[i]; + + if(inst.op == Operation::Call && inst.getFuncCall()->name == dispatchMesh->name) + { + if(inst.args.size() != 5) + { + RDCERR("Unexpected number of arguments to dispatchMesh"); + continue; + } + payloadVariable = cast(inst.args[4]); + if(!payloadVariable) + { + RDCERR("Unexpected non-variable payload argument to dispatchMesh"); + continue; + } + + payloadType = (Type *)payloadVariable->type; + + RDCASSERT(payloadType->type == Type::Pointer); + payloadType = (Type *)payloadType->inner; + + break; + } + } + + // add the dimensions and offset to the payload type, at the end so we don't have to patch any + // GEPs in future. We'll swizzle these to the start when copying to/from buffers still + RDCASSERT(payloadType && payloadType->type == Type::Struct); + payloadType->members.append({i32, i32, i32, i32}); + + // recreate the function with our own instructions + f->instructions.clear(); + f->blocks.resize(1); + + // create our handle first thing + Constant *annotateConstant = NULL; + Instruction *handle = NULL; + if(createHandle) + { + RDCASSERT(!isShaderModel6_6OrAbove); + handle = editor.AddInstruction( + f, editor.CreateInstruction(createHandle, DXOp::createHandle, + { + // kind = UAV + editor.CreateConstant((uint8_t)HandleKind::UAV), + // ID/slot + editor.CreateConstant(regSlot), + // register + editor.CreateConstant(1U), + // non-uniform + editor.CreateConstant(false), + })); + } + else if(createHandleFromBinding) + { + RDCASSERT(isShaderModel6_6OrAbove); + const Type *resBindType = editor.CreateNamedStructType("dx.types.ResBind", {}); + Constant *resBindConstant = + editor.CreateConstant(resBindType, { + // Lower id bound + editor.CreateConstant(1U), + // Upper id bound + editor.CreateConstant(1U), + // Space ID + editor.CreateConstant(space), + // kind = UAV + editor.CreateConstant((uint8_t)HandleKind::UAV), + }); + + Instruction *unannotatedHandle = editor.AddInstruction( + f, editor.CreateInstruction(createHandleFromBinding, DXOp::createHandleFromBinding, + { + // resBind + resBindConstant, + // ID/slot + editor.CreateConstant(1U), + // non-uniform + editor.CreateConstant(false), + })); + + annotateConstant = editor.CreateConstant( + editor.CreateNamedStructType("dx.types.ResourceProperties", {}), + { + // IsUav : (1 << 12) + editor.CreateConstant(uint32_t((1 << 12) | (uint32_t)ResourceKind::RawBuffer)), + // + editor.CreateConstant(0U), + }); + + handle = editor.AddInstruction(f, editor.CreateInstruction(annotateHandle, DXOp::annotateHandle, + { + // Resource handle + unannotatedHandle, + // Resource properties + annotateConstant, + })); + } + + RDCASSERT(handle); + + Constant *i32_0 = editor.CreateConstant(0U); + Constant *i32_1 = editor.CreateConstant(1U); + Constant *i32_2 = editor.CreateConstant(2U); + Constant *i32_4 = editor.CreateConstant(4U); + + // get our output location from group ID + Instruction *groupX = + editor.AddInstruction(f, editor.CreateInstruction(groupId, DXOp::groupId, {i32_0})); + Instruction *groupY = + editor.AddInstruction(f, editor.CreateInstruction(groupId, DXOp::groupId, {i32_1})); + Instruction *groupZ = + editor.AddInstruction(f, editor.CreateInstruction(groupId, DXOp::groupId, {i32_2})); + + // linearise it based on the number of dispatches + Instruction *groupYMul = editor.AddInstruction( + f, editor.CreateInstruction(Operation::Mul, i32, + {groupY, editor.CreateConstant(dispatchDim[0])})); + Instruction *groupZMul = editor.AddInstruction( + f, editor.CreateInstruction(Operation::Mul, i32, + {groupZ, editor.CreateConstant(dispatchDim[0] * dispatchDim[1])})); + Instruction *groupYZAdd = editor.AddInstruction( + f, editor.CreateInstruction(Operation::Add, i32, {groupYMul, groupZMul})); + Instruction *flatIndex = + editor.AddInstruction(f, editor.CreateInstruction(Operation::Add, i32, {groupX, groupYZAdd})); + + Instruction *baseOffset = editor.AddInstruction( + f, editor.CreateInstruction(Operation::Mul, i32, + {flatIndex, editor.CreateConstant(payloadSize + 16)})); + + Instruction *dimAndOffset = editor.AddInstruction( + f, editor.CreateInstruction(rawBufferLoad, DXOp::rawBufferLoad, + {handle, baseOffset, editor.CreateUndef(i32), + editor.CreateConstant((uint8_t)0xf), i32_4})); + + Instruction *dimX = + editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32, + {dimAndOffset, editor.CreateLiteral(0)})); + Instruction *dimY = + editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32, + {dimAndOffset, editor.CreateLiteral(1)})); + Instruction *dimZ = + editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32, + {dimAndOffset, editor.CreateLiteral(2)})); + Instruction *offset = + editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32, + {dimAndOffset, editor.CreateLiteral(3)})); + + size_t curInst = f->instructions.size(); + // start at 16 bytes, to account for our own data + uint32_t uavByteOffset = 16; + for(uint32_t i = 0; i < payloadType->members.size() - 4; i++) + { + PayloadBufferCopy(BufferToPayload, editor, f, curInst, baseOffset, handle, + payloadType->members[i], uavByteOffset, + {payloadVariable, i32_0, editor.CreateConstant(i)}); + } + + for(size_t i = 0; i < 4; i++) + { + Value *srcs[] = {dimX, dimY, dimZ, offset}; + + Constant *dst = editor.CreateConstantGEP( + editor.GetPointerType(i32, payloadVariable->type->addrSpace), + {payloadVariable, i32_0, + editor.CreateConstant(uint32_t(payloadType->members.size() - 4 + i))}); + + DXIL::Instruction *store = editor.CreateInstruction(Operation::Store); + + store->type = voidType; + store->op = Operation::Store; + store->align = 4; + store->args = {dst, srcs[i]}; + + editor.AddInstruction(f, store); + } + + editor.AddInstruction(f, editor.CreateInstruction(dispatchMesh, DXOp::dispatchMesh, + {dimX, dimY, dimZ, payloadVariable})); + editor.AddInstruction(f, editor.CreateInstruction(Operation::Ret, voidType, {})); +} bool D3D12Replay::CreateSOBuffers() { HRESULT hr = S_OK; @@ -780,6 +1194,7 @@ void D3D12Replay::ClearPostVSCache() { // temporary to avoid a warning (void)&AddDXILAmpShaderPayloadStores; + (void)&ConvertToFixedDXILAmpFeeder; for(auto it = m_PostVSData.begin(); it != m_PostVSData.end(); ++it) {