Add function to patch amp. shader DXIL to feed from buffer

This commit is contained in:
baldurk
2023-11-16 17:20:01 +00:00
parent 980677f50c
commit 882cd8c9ec
+415
View File
@@ -660,6 +660,420 @@ static void AddDXILAmpShaderPayloadStores(const DXBC::DXBCContainer *dxbc, uint3
}
}
// Replaces the body of the amplification shader entry point in dxbc's DXIL with a fixed "feeder":
// each group loads a record from a raw buffer UAV and re-issues dispatchMesh with the dimensions
// and payload contents stored in that record.
//
// The record for a group lives at byte offset (flat group index * (payloadSize + 16)) where
// payloadSize is the payload size read from the shader's amplification metadata. The first 16
// bytes are loaded as four i32 values (dispatch X, Y, Z and an offset), and the original payload
// members are copied from the buffer starting 16 bytes into the record. The four i32 values are
// also appended to the end of the payload struct and stored there before dispatchMesh is called.
//
// dxbc:        container whose DXIL program is edited.
// space:       register space used for the UAV that is declared (register base 1, count 1).
// dispatchDim: dimensions used to linearise the 3D group ID into a flat index.
// editedBlob:  handed to the ProgramEditor; presumably receives the edited container - confirm
//              against ProgramEditor.
//
// On any failure an error is logged and the function returns early. Metadata edits made before the
// failure point are not rolled back.
static void ConvertToFixedDXILAmpFeeder(const DXBC::DXBCContainer *dxbc, uint32_t space,
                                        rdcfixedarray<uint32_t, 3> dispatchDim, bytebuf &editedBlob)
{
  using namespace DXIL;

  ProgramEditor editor(dxbc, editedBlob);

  bool isShaderModel6_6OrAbove =
      dxbc->m_Version.Major > 6 || (dxbc->m_Version.Major == 6 && dxbc->m_Version.Minor >= 6);

  const Type *i32 = editor.GetInt32Type();
  const Type *i8 = editor.GetInt8Type();
  const Type *i1 = editor.GetBoolType();
  const Type *voidType = editor.GetVoidType();

  const Type *handleType = editor.CreateNamedStructType(
      "dx.types.Handle", {editor.CreatePointerType(i8, Type::PointerAddrSpace::Default)});

  // this function is named differently based on the payload struct name, so search by prefix, we
  // expect the actual type to be the same as we're just modifying the payload in place
  const Function *dispatchMesh = editor.GetFunctionByPrefix("dx.op.dispatchMesh");

  // without the dispatchMesh declaration we can neither locate the payload nor emit the final
  // call, so bail out before touching anything
  if(!dispatchMesh)
  {
    RDCERR("Couldn't find dispatchMesh function declaration");
    return;
  }

  const Function *createHandle = NULL;
  const Function *createHandleFromBinding = NULL;
  const Function *annotateHandle = NULL;

  // reading from a binding uses a different function in SM6.6+
  if(isShaderModel6_6OrAbove)
  {
    const Type *resBindType = editor.CreateNamedStructType("dx.types.ResBind", {i32, i32, i32, i8});
    createHandleFromBinding = editor.DeclareFunction("dx.op.createHandleFromBinding", handleType,
                                                     {i32, resBindType, i32, i1},
                                                     Attribute::NoUnwind | Attribute::ReadNone);

    const Type *resourcePropertiesType =
        editor.CreateNamedStructType("dx.types.ResourceProperties", {i32, i32});
    annotateHandle = editor.DeclareFunction("dx.op.annotateHandle", handleType,
                                            {i32, handleType, resourcePropertiesType},
                                            Attribute::NoUnwind | Attribute::ReadNone);
  }
  else
  {
    createHandle = editor.DeclareFunction("dx.op.createHandle", handleType, {i32, i8, i32, i32, i1},
                                          Attribute::NoUnwind | Attribute::ReadNone);
  }

  const Function *groupId = editor.DeclareFunction("dx.op.groupId.i32", i32, {i32, i32},
                                                   Attribute::NoUnwind | Attribute::ReadNone);

  const Type *resRet_i32 =
      editor.CreateNamedStructType("dx.types.ResRet.i32", {i32, i32, i32, i32, i32});
  const Function *rawBufferLoad = editor.DeclareFunction("dx.op.rawBufferLoad.i32", resRet_i32,
                                                         {i32, handleType, i32, i32, i8, i32},
                                                         Attribute::NoUnwind | Attribute::ReadOnly);

  // declare the resource, this happens purely in metadata but we need to store the slot
  uint32_t regSlot = 0;
  Metadata *reslist = NULL;
  {
    const Type *rw = editor.CreateNamedStructType("struct.RWByteAddressBuffer", {i32});
    const Type *rwptr = editor.CreatePointerType(rw, Type::PointerAddrSpace::Default);

    Metadata *resources = editor.CreateNamedMetadata("dx.resources");
    if(resources->children.empty())
      resources->children.push_back(editor.CreateMetadata());

    reslist = resources->children[0];

    if(reslist->children.empty())
      reslist->children.resize(4);

    Metadata *uavs = reslist->children[1];
    // if there isn't a UAV list, create an empty one so we can add our own
    if(!uavs)
      uavs = reslist->children[1] = editor.CreateMetadata();

    // pick the first ID after every existing UAV record
    for(size_t i = 0; i < uavs->children.size(); i++)
    {
      // each UAV child should have a fixed format, [0] is the reg ID and I think this should always
      // be == the index
      const Metadata *uav = uavs->children[i];
      const Constant *slot = cast<Constant>(uav->children[(size_t)ResField::ID]->value);

      if(!slot)
      {
        RDCWARN("Unexpected non-constant slot ID in UAV");
        continue;
      }

      RDCASSERT(slot->getU32() == i);

      uint32_t id = slot->getU32();
      regSlot = RDCMAX(id + 1, regSlot);
    }

    Constant rwundef;
    rwundef.type = rwptr;
    rwundef.setUndef(true);

    // create the new UAV record
    Metadata *uav = editor.CreateMetadata();
    uav->children = {
        editor.CreateConstantMetadata(regSlot),
        editor.CreateConstantMetadata(editor.CreateConstant(rwundef)),
        editor.CreateConstantMetadata(""),
        editor.CreateConstantMetadata(space),
        editor.CreateConstantMetadata(1U),                                   // reg base
        editor.CreateConstantMetadata(1U),                                   // reg count
        editor.CreateConstantMetadata(uint32_t(ResourceKind::RawBuffer)),    // shape
        editor.CreateConstantMetadata(false),                                // globally coherent
        editor.CreateConstantMetadata(false),                                // hidden counter
        editor.CreateConstantMetadata(false),                                // raster order
        NULL,                                                                // UAV tags
    };

    uavs->children.push_back(uav);
  }

  uint32_t payloadSize = 0;
  rdcstr entryName;
  // add the entry point tags
  {
    Metadata *entryPoints = editor.GetMetadataByName("dx.entryPoints");

    if(!entryPoints || entryPoints->children.empty())
    {
      RDCERR("Couldn't find entry point list");
      return;
    }

    // TODO select the entry point for multiple entry points? RT only for now
    Metadata *entry = entryPoints->children[0];

    entryName = entry->children[1]->str;

    Metadata *taglist = entry->children[4];
    if(!taglist)
      taglist = entry->children[4] = editor.CreateMetadata();

    // find existing shader flags tag, if there is one
    Metadata *shaderFlagsTag = NULL;
    Metadata *shaderFlagsData = NULL;
    Metadata *ampData = NULL;
    size_t flagsIndex = 0;
    // tags are stored as (tag, data) pairs, so only visit tags that have a data node after them
    for(size_t t = 0; t + 1 < taglist->children.size(); t += 2)
    {
      RDCASSERT(taglist->children[t]->isConstant);
      if(cast<Constant>(taglist->children[t]->value)->getU32() ==
         (uint32_t)ShaderEntryTag::ShaderFlags)
      {
        shaderFlagsTag = taglist->children[t];
        shaderFlagsData = taglist->children[t + 1];
        flagsIndex = t + 1;
      }
      else if(cast<Constant>(taglist->children[t]->value)->getU32() ==
              (uint32_t)ShaderEntryTag::Amplification)
      {
        ampData = taglist->children[t + 1];
      }
    }

    // we must have found an amplification tag holding at least the thread count and payload size,
    // as both are patched below. Check before modifying the tag list
    if(!ampData || ampData->children.size() < 2)
    {
      RDCERR("Couldn't find amplification shader tag in entry point");
      return;
    }

    uint32_t shaderFlagsValue =
        shaderFlagsData ? cast<Constant>(shaderFlagsData->value)->getU32() : 0U;

    // raw and structured buffers
    shaderFlagsValue |= 0x10;
    // UAVs on non-PS/CS stages
    shaderFlagsValue |= 0x10000;

    // REMOVE wave ops flag as we don't use it but the original shader might have. DXIL requires
    // flags to be strictly minimum :(
    shaderFlagsValue &= ~0x80000;

    // (re-)create shader flags tag, the flags value is stored as a 64-bit integer constant
    Type *i64 = editor.CreateScalarType(Type::Int, 64);
    shaderFlagsData =
        editor.CreateConstantMetadata(editor.CreateConstant(Constant(i64, shaderFlagsValue)));

    // if we didn't have a shader tags entry at all, create the metadata node for the shader flags
    // tag
    if(!shaderFlagsTag)
      shaderFlagsTag = editor.CreateConstantMetadata((uint32_t)ShaderEntryTag::ShaderFlags);

    // if we had a tag already, we can just re-use that tag node and replace the data node.
    // Otherwise we need to add both, and we insert them first.
    // flagsIndex is the index of the data node so it is never 0 when a tag was found
    if(flagsIndex)
    {
      taglist->children[flagsIndex] = shaderFlagsData;
    }
    else
    {
      taglist->children.insert(0, shaderFlagsTag);
      taglist->children.insert(1, shaderFlagsData);
    }

    // set reslist and taglist in case they were null before
    entry->children[3] = reslist;
    entry->children[4] = taglist;

    // patch the number of threads to 1x1x1 and grow the payload size
    ampData->children[0] = editor.CreateMetadata();
    ampData->children[0]->children.push_back(editor.CreateConstantMetadata((uint32_t)1));
    ampData->children[0]->children.push_back(editor.CreateConstantMetadata((uint32_t)1));
    ampData->children[0]->children.push_back(editor.CreateConstantMetadata((uint32_t)1));

    payloadSize = cast<Constant>(ampData->children[1]->value)->getU32();

    // add room for our dimensions + offset
    ampData->children[1] = editor.CreateConstantMetadata(payloadSize + 16);
  }

  // get the editor to patch PSV0 with our extra UAV
  editor.RegisterUAV(DXILResourceType::ByteAddressUAV, space, 1, 1, ResourceKind::RawBuffer);
  uint32_t dim[] = {1, 1, 1};
  editor.SetNumThreads(dim);
  editor.SetASPayloadSize(payloadSize + 16);

  // remove some flags that will no longer be valid
  editor.PatchGlobalShaderFlags(
      [](DXBC::GlobalShaderFlags &flags) { flags &= ~DXBC::GlobalShaderFlags::WaveOps; });

  Function *f = editor.GetFunctionByName(entryName);

  if(!f)
  {
    RDCERR("Couldn't find entry point function '%s'", entryName.c_str());
    return;
  }

  // find the dispatchMesh call, and from there the global groupshared variable that's the payload
  GlobalVar *payloadVariable = NULL;
  Type *payloadType = NULL;
  for(size_t i = 0; i < f->instructions.size(); i++)
  {
    const Instruction &inst = *f->instructions[i];

    if(inst.op == Operation::Call && inst.getFuncCall()->name == dispatchMesh->name)
    {
      if(inst.args.size() != 5)
      {
        RDCERR("Unexpected number of arguments to dispatchMesh");
        continue;
      }

      payloadVariable = cast<GlobalVar>(inst.args[4]);
      if(!payloadVariable)
      {
        RDCERR("Unexpected non-variable payload argument to dispatchMesh");
        continue;
      }

      // the variable's type is a pointer to the payload struct
      payloadType = (Type *)payloadVariable->type;

      RDCASSERT(payloadType->type == Type::Pointer);
      payloadType = (Type *)payloadType->inner;

      break;
    }
  }

  // everything below dereferences the payload variable and appends to its struct type, so don't
  // continue (and don't wipe the function) if we couldn't find a usable one
  if(!payloadVariable || !payloadType || payloadType->type != Type::Struct)
  {
    RDCERR("Couldn't find struct payload variable passed to dispatchMesh");
    return;
  }

  // add the dimensions and offset to the payload type, at the end so we don't have to patch any
  // GEPs in future. We'll swizzle these to the start when copying to/from buffers still
  payloadType->members.append({i32, i32, i32, i32});

  // recreate the function with our own instructions
  f->instructions.clear();
  f->blocks.resize(1);

  // create our handle first thing
  Constant *annotateConstant = NULL;
  Instruction *handle = NULL;
  if(createHandle)
  {
    RDCASSERT(!isShaderModel6_6OrAbove);
    handle = editor.AddInstruction(
        f, editor.CreateInstruction(createHandle, DXOp::createHandle,
                                    {
                                        // kind = UAV
                                        editor.CreateConstant((uint8_t)HandleKind::UAV),
                                        // ID/slot
                                        editor.CreateConstant(regSlot),
                                        // register
                                        editor.CreateConstant(1U),
                                        // non-uniform
                                        editor.CreateConstant(false),
                                    }));
  }
  else if(createHandleFromBinding)
  {
    RDCASSERT(isShaderModel6_6OrAbove);
    const Type *resBindType = editor.CreateNamedStructType("dx.types.ResBind", {});
    Constant *resBindConstant =
        editor.CreateConstant(resBindType, {
                                               // Lower id bound
                                               editor.CreateConstant(1U),
                                               // Upper id bound
                                               editor.CreateConstant(1U),
                                               // Space ID
                                               editor.CreateConstant(space),
                                               // kind = UAV
                                               editor.CreateConstant((uint8_t)HandleKind::UAV),
                                           });

    Instruction *unannotatedHandle = editor.AddInstruction(
        f, editor.CreateInstruction(createHandleFromBinding, DXOp::createHandleFromBinding,
                                    {
                                        // resBind
                                        resBindConstant,
                                        // ID/slot
                                        editor.CreateConstant(1U),
                                        // non-uniform
                                        editor.CreateConstant(false),
                                    }));

    annotateConstant = editor.CreateConstant(
        editor.CreateNamedStructType("dx.types.ResourceProperties", {}),
        {
            // IsUav : (1 << 12)
            editor.CreateConstant(uint32_t((1 << 12) | (uint32_t)ResourceKind::RawBuffer)),
            //
            editor.CreateConstant(0U),
        });

    handle = editor.AddInstruction(f, editor.CreateInstruction(annotateHandle, DXOp::annotateHandle,
                                                               {
                                                                   // Resource handle
                                                                   unannotatedHandle,
                                                                   // Resource properties
                                                                   annotateConstant,
                                                               }));
  }

  RDCASSERT(handle);

  Constant *i32_0 = editor.CreateConstant(0U);
  Constant *i32_1 = editor.CreateConstant(1U);
  Constant *i32_2 = editor.CreateConstant(2U);
  Constant *i32_4 = editor.CreateConstant(4U);

  // get our output location from group ID
  Instruction *groupX =
      editor.AddInstruction(f, editor.CreateInstruction(groupId, DXOp::groupId, {i32_0}));
  Instruction *groupY =
      editor.AddInstruction(f, editor.CreateInstruction(groupId, DXOp::groupId, {i32_1}));
  Instruction *groupZ =
      editor.AddInstruction(f, editor.CreateInstruction(groupId, DXOp::groupId, {i32_2}));

  // linearise it based on the number of dispatches:
  // flatIndex = x + y * dispatchDim[0] + z * dispatchDim[0] * dispatchDim[1]
  Instruction *groupYMul = editor.AddInstruction(
      f, editor.CreateInstruction(Operation::Mul, i32,
                                  {groupY, editor.CreateConstant(dispatchDim[0])}));
  Instruction *groupZMul = editor.AddInstruction(
      f, editor.CreateInstruction(Operation::Mul, i32,
                                  {groupZ, editor.CreateConstant(dispatchDim[0] * dispatchDim[1])}));
  Instruction *groupYZAdd = editor.AddInstruction(
      f, editor.CreateInstruction(Operation::Add, i32, {groupYMul, groupZMul}));
  Instruction *flatIndex =
      editor.AddInstruction(f, editor.CreateInstruction(Operation::Add, i32, {groupX, groupYZAdd}));

  // each group's record is the original payload plus our 16 bytes of dimensions + offset
  Instruction *baseOffset = editor.AddInstruction(
      f, editor.CreateInstruction(Operation::Mul, i32,
                                  {flatIndex, editor.CreateConstant(payloadSize + 16)}));

  // load all four components (mask 0xf) from the start of the record
  Instruction *dimAndOffset = editor.AddInstruction(
      f, editor.CreateInstruction(rawBufferLoad, DXOp::rawBufferLoad,
                                  {handle, baseOffset, editor.CreateUndef(i32),
                                   editor.CreateConstant((uint8_t)0xf), i32_4}));

  Instruction *dimX =
      editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32,
                                                        {dimAndOffset, editor.CreateLiteral(0)}));
  Instruction *dimY =
      editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32,
                                                        {dimAndOffset, editor.CreateLiteral(1)}));
  Instruction *dimZ =
      editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32,
                                                        {dimAndOffset, editor.CreateLiteral(2)}));
  Instruction *offset =
      editor.AddInstruction(f, editor.CreateInstruction(Operation::ExtractVal, i32,
                                                        {dimAndOffset, editor.CreateLiteral(3)}));

  size_t curInst = f->instructions.size();

  // start at 16 bytes, to account for our own data
  uint32_t uavByteOffset = 16;
  // copy every original member (all but the four we appended) from the buffer into the payload
  for(uint32_t i = 0; i < payloadType->members.size() - 4; i++)
  {
    PayloadBufferCopy(BufferToPayload, editor, f, curInst, baseOffset, handle,
                      payloadType->members[i], uavByteOffset,
                      {payloadVariable, i32_0, editor.CreateConstant(i)});
  }

  // store the loaded dimensions and offset into the four members appended to the payload struct
  Value *srcs[] = {dimX, dimY, dimZ, offset};
  for(size_t i = 0; i < 4; i++)
  {
    Constant *dst = editor.CreateConstantGEP(
        editor.GetPointerType(i32, payloadVariable->type->addrSpace),
        {payloadVariable, i32_0,
         editor.CreateConstant(uint32_t(payloadType->members.size() - 4 + i))});

    DXIL::Instruction *store = editor.CreateInstruction(Operation::Store);

    store->type = voidType;
    store->op = Operation::Store;
    store->align = 4;
    store->args = {dst, srcs[i]};

    editor.AddInstruction(f, store);
  }

  // re-issue the mesh dispatch with the dimensions read from the buffer, then return
  editor.AddInstruction(f, editor.CreateInstruction(dispatchMesh, DXOp::dispatchMesh,
                                                    {dimX, dimY, dimZ, payloadVariable}));
  editor.AddInstruction(f, editor.CreateInstruction(Operation::Ret, voidType, {}));
}
bool D3D12Replay::CreateSOBuffers()
{
HRESULT hr = S_OK;
@@ -780,6 +1194,7 @@ void D3D12Replay::ClearPostVSCache()
{
// temporary to avoid a warning
(void)&AddDXILAmpShaderPayloadStores;
(void)&ConvertToFixedDXILAmpFeeder;
for(auto it = m_PostVSData.begin(); it != m_PostVSData.end(); ++it)
{