Add function to enumerate the entry points in a given shader

This commit is contained in:
baldurk
2024-04-22 17:43:02 +01:00
parent 89bea3ea8b
commit 476fed06d6
4 changed files with 54 additions and 0 deletions
@@ -1792,6 +1792,20 @@ DXBCContainer::DXBCContainer(const bytebuf &ByteCode, const rdcstr &debugInfoPat
m_Reflection = new Reflection;
}
if(dxilReflectProgram)
{
m_EntryPoints = dxilReflectProgram->GetEntryPoints();
}
else if(m_EntryPoints.empty())
{
rdcstr entry;
if(GetDebugInfo())
entry = GetDebugInfo()->GetEntryFunction();
if(entry.empty())
entry = "main";
m_EntryPoints = {ShaderEntryPoint(entry, GetShaderStage(m_Type))};
}
SAFE_DELETE(dxilSTATProgram);
for(uint32_t chunkIdx = 0; chunkIdx < header->numChunks; chunkIdx++)
@@ -184,6 +184,8 @@ public:
const Reflection *GetReflection() const { return m_Reflection; }
D3D_PRIMITIVE_TOPOLOGY GetOutputTopology();
rdcarray<ShaderEntryPoint> GetEntryPoints() const { return m_EntryPoints; }
const rdcstr &GetDisassembly(bool dxcStyle);
void FillTraceLineInfo(ShaderDebugTrace &trace) const;
@@ -249,6 +251,7 @@ private:
DXIL::Program *m_DXILByteCode = NULL;
IDebugInfo *m_DebugInfo = NULL;
Reflection *m_Reflection = NULL;
rdcarray<ShaderEntryPoint> m_EntryPoints;
};
}; // namespace DXBC
@@ -1229,6 +1229,7 @@ public:
const bytebuf &GetBytes() const { return m_Bytes; }
void FetchComputeProperties(DXBC::Reflection *reflection);
DXBC::Reflection *GetReflection();
rdcarray<ShaderEntryPoint> GetEntryPoints();
DXBC::ShaderType GetShaderType() const { return m_Type; }
uint32_t GetMajorVersion() const { return m_Major; }
@@ -1242,6 +1242,42 @@ static void AddResourceBind(DXBC::Reflection *refl, const TypeInfo &typeInfo, co
refl->UAVs.push_back(bind);
}
rdcarray<ShaderEntryPoint> Program::GetEntryPoints()
{
rdcarray<ShaderEntryPoint> ret;
DXMeta dx(m_NamedMeta);
if(dx.entryPoints)
{
for(Metadata *entry : dx.entryPoints->children)
{
if(entry->children.size() > 2 && entry->children[0] != NULL)
{
ShaderEntryPoint entryPoint;
entryPoint.name = entry->children[1]->str;
Metadata *tags = entry->children[4];
for(size_t i = 0; i < tags->children.size(); i += 2)
{
// 8 is the type tag
if(getival<uint32_t>(tags->children[i]) == 8U)
{
entryPoint.stage =
GetShaderStage((DXBC::ShaderType)getival<uint32_t>(tags->children[i + 1]));
break;
}
}
ret.push_back(entryPoint);
}
}
}
return ret;
}
DXBC::Reflection *Program::GetReflection()
{
using namespace DXBC;