Add DXIL Debugger Support for Wave Scan and Prefix ops

DXOp::WavePrefixOp
DXOp::WavePrefixBitCount
This commit is contained in:
Jake Turner
2025-04-07 13:19:18 +01:00
parent cf6e3d68fa
commit 6f949b1fa9
2 changed files with 182 additions and 8 deletions
+180 -6
View File
@@ -124,6 +124,68 @@ static bool DecodePointer(DXILDebug::Id &ptrId, uint64_t &offset, uint64_t &size
return true;
}
// Writes the floating-point constant `val` into each of the var.columns
// components of `var`.
// IMPL_FOR_FLOAT_TYPES and comp<T> are defined elsewhere; presumably the macro
// picks the concrete float type T from var.type - confirm against its definition.
static void SetFloatValue(float val, ShaderVariable &var)
{
// per-type store, expanded once per column by the loop below
#undef _IMPL
#define _IMPL(T) comp<T>(var, c) = val;

  const auto columnCount = var.columns;
  for(uint8_t c = 0; c < columnCount; ++c)
  {
    IMPL_FOR_FLOAT_TYPES(_IMPL);
  }
}
// Writes the unsigned constant `val` into each of the var.columns components of
// `var`, converting it to the unsigned type U supplied by IMPL_FOR_INT_TYPES.
// IMPL_FOR_INT_TYPES is defined elsewhere; presumably it selects (I, S, U) from
// var.type - confirm against its definition.
static void SetUIntValue(uint64_t val, ShaderVariable &var)
{
// per-type store, expanded once per column by the loop below
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(var, c) = static_cast<U>(val);

  const auto columnCount = var.columns;
  for(uint8_t c = 0; c < columnCount; ++c)
  {
    IMPL_FOR_INT_TYPES(_IMPL);
  }
}
// Writes the signed constant `val` into each of the var.columns components of
// `var`, converting it to the type I supplied by IMPL_FOR_INT_TYPES.
// IMPL_FOR_INT_TYPES is defined elsewhere; presumably it selects (I, S, U) from
// var.type - confirm against its definition.
static void SetIntValue(int64_t val, ShaderVariable &var)
{
// per-type store, expanded once per column by the loop below
#undef _IMPL
#define _IMPL(I, S, U) comp<I>(var, c) = static_cast<I>(val);

  const auto columnCount = var.columns;
  for(uint8_t c = 0; c < columnCount; ++c)
  {
    IMPL_FOR_INT_TYPES(_IMPL);
  }
}
// Fills every column of `var` with a constant chosen by var.type:
//   - Half / Float / Double          -> fVal
//   - SByte / SShort / SInt / SLong  -> iVal
//   - UByte / UShort / UInt / ULong  -> uVal
// Any other type leaves `var` untouched and logs an error.
static void SetShaderValue(float fVal, uint64_t uVal, int64_t iVal, ShaderVariable &var)
{
  switch(var.type)
  {
    // unsigned integer types
    case VarType::UByte:
    case VarType::UShort:
    case VarType::UInt:
    case VarType::ULong:
      SetUIntValue(uVal, var);
      return;

    // signed integer types
    case VarType::SByte:
    case VarType::SShort:
    case VarType::SInt:
    case VarType::SLong:
      SetIntValue(iVal, var);
      return;

    // floating-point types
    case VarType::Half:
    case VarType::Float:
    case VarType::Double:
      SetFloatValue(fVal, var);
      return;

    default: break;
  }

  // only reached for a type that none of the groups above handle
  RDCERR("Unknown type %s", ToStr(var.type).c_str());
}
// Sets every column of `var` to zero in whichever representation var.type
// selects inside SetShaderValue.
static void SetShaderValueZero(ShaderVariable &var)
{
  const float zeroFloat = 0.0f;
  const uint64_t zeroUnsigned = 0;
  const int64_t zeroSigned = 0;
  SetShaderValue(zeroFloat, zeroUnsigned, zeroSigned, var);
}
// Sets every column of `var` to one in whichever representation var.type
// selects inside SetShaderValue.
static void SetShaderValueOne(ShaderVariable &var)
{
  const float oneFloat = 1.0f;
  const uint64_t oneUnsigned = 1;
  const int64_t oneSigned = 1;
  SetShaderValue(oneFloat, oneUnsigned, oneSigned, var);
}
static bool OperationFlushing(const Operation op, DXOp dxOpCode)
{
if(dxOpCode != DXOp::NumOpCodes)
@@ -3721,12 +3783,125 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
}
break;
}
// Emulates the exclusive prefix form of a wave reduction: the result for this
// thread is the identity combined with the `value` operand of each active lane
// visited before this thread's own lane. Only Sum and Product are implemented.
case DXOp::WavePrefixOp:
{
// WavePrefixOp(value,op,sop)
// NOTE(review): the operand reads below are performed inside RDCASSERT; if that
// macro can ever be compiled out the reads would disappear with it - confirm.
ShaderVariable arg;
// args[2] is read as the WaveOpCode selecting the reduction
RDCASSERT(GetShaderVariable(inst.args[2], opCode, dxOpCode, arg));
WaveOpCode waveOpCode = (WaveOpCode)arg.value.u32v[0];
// args[3] is the signed-op kind; anything other than Signed is treated as unsigned
RDCASSERT(GetShaderVariable(inst.args[3], opCode, dxOpCode, arg));
bool isUnsigned = (arg.value.u32v[0] != (uint32_t)SignedOpKind::Signed);
// set the identity
// accum starts as a copy of result so it carries result's type and column count
ShaderVariable accum(result);
switch(waveOpCode)
{
case WaveOpCode::Sum: SetShaderValueZero(accum); break;
case WaveOpCode::Product: SetShaderValueOne(accum); break;
default:
// unsupported opcodes log an error and fall back to an all-zero value
RDCERR("Unhandled PrefixOp wave opcode");
accum.value = {};
break;
}
// determine active lane indices in our subgroup
rdcarray<uint32_t> activeLanes;
GetSubgroupActiveLanes(activeMask, workgroup, activeLanes);
// presumably activeLanes is in ascending lane order, which is what makes the
// early break below yield an exclusive prefix - confirm against
// GetSubgroupActiveLanes
for(uint32_t lane : activeLanes)
{
// stop before processing our lane
if(lane == m_WorkgroupIndex)
break;
// fetch the `value` operand (args[1]) as seen by the other lane
ShaderVariable x;
RDCASSERT(workgroup[lane].GetShaderVariable(inst.args[1], opCode, dxOpCode, x));
switch(waveOpCode)
{
case WaveOpCode::Sum:
{
// component-wise accumulate, dispatching on x.type via the IMPL_* macros
for(uint8_t c = 0; c < x.columns; c++)
{
if(isUnsigned)
{
// unsigned path only expands the integer macro; float types are handled
// solely in the else branch below
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(accum, c) = comp<U>(accum, c) + comp<U>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
else
{
// signed path expands both the signed-integer and the float macro for x.type
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<S>(accum, c) + comp<S>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) comp<T>(accum, c) = comp<T>(accum, c) + comp<T>(x, c)
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
}
}
break;
}
case WaveOpCode::Product:
{
// same structure as Sum, multiplying instead of adding
for(uint8_t c = 0; c < x.columns; c++)
{
if(isUnsigned)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(accum, c) = comp<U>(accum, c) * comp<U>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
else
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<S>(accum, c) * comp<S>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) comp<T>(accum, c) = comp<T>(accum, c) * comp<T>(x, c)
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
}
}
break;
}
default: RDCERR("Unhandled PrefixOp wave opcode"); break;
}
}
// only the value payload is copied back into result
result.value = accum.value;
break;
}
// Emulates WavePrefixBitCount: sums the `cond` operand of each active lane
// visited before this thread's own lane and stores the total in result.
case DXOp::WavePrefixBitCount:
{
// WavePrefixBitCount(cond)
// determine active lane indices in our subgroup
rdcarray<uint32_t> activeLanes;
GetSubgroupActiveLanes(activeMask, workgroup, activeLanes);
uint32_t count = 0;
// presumably activeLanes is in ascending lane order, so breaking at our own
// index leaves only the lower lanes counted - confirm against
// GetSubgroupActiveLanes
for(uint32_t lane : activeLanes)
{
// stop before processing our lane
if(lane == m_WorkgroupIndex)
break;
// fetch the `cond` operand (args[1]) as seen by the other lane
ShaderVariable x;
RDCASSERT(workgroup[lane].GetShaderVariable(inst.args[1], opCode, dxOpCode, x));
// adds the raw 32-bit value; assumes cond is stored as exactly 0 or 1 - TODO confirm
count += x.value.u32v[0];
}
result.value.u32v[0] = count;
break;
}
case DXOp::WaveAnyTrue:
case DXOp::WaveAllTrue:
case DXOp::WaveActiveBallot:
case DXOp::WaveActiveOp:
{
ShaderVariable accum;
ShaderVariable accum(result);
bool isUnsigned = true;
WaveOpCode waveOpCode = WaveOpCode::Sum;
@@ -3743,11 +3918,12 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
// set the identity
switch(waveOpCode)
{
case WaveOpCode::Sum: SetShaderValueZero(accum); break;
case WaveOpCode::Product: SetShaderValueOne(accum); break;
default:
RDCERR("Unhandled wave opcode");
RDCERR("Unhandled ActiveOp wave opcode");
accum.value = {};
break;
case WaveOpCode::Sum: accum.value = {}; break;
}
}
else if(dxOpCode == DXOp::WaveAnyTrue)
@@ -3775,7 +3951,6 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
{
switch(waveOpCode)
{
default: RDCERR("Unhandled wave opcode"); break;
case WaveOpCode::Sum:
{
for(uint8_t c = 0; c < x.columns; c++)
@@ -3800,6 +3975,7 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
}
break;
}
default: RDCERR("Unhandled ActiveOp wave opcode"); break;
}
}
else if(dxOpCode == DXOp::WaveAnyTrue)
@@ -4132,9 +4308,7 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
// Wave Operations
case DXOp::WaveActiveAllEqual:
case DXOp::WaveActiveBit:
case DXOp::WavePrefixOp:
case DXOp::WaveAllBitCount:
case DXOp::WavePrefixBitCount:
case DXOp::WaveMatch:
case DXOp::WaveMultiPrefixOp:
case DXOp::WaveMultiPrefixBitCount:
@@ -1879,6 +1879,8 @@ rdcstr Program::GetDebugStatus()
case DXOp::WaveReadLaneAt:
case DXOp::WaveReadLaneFirst:
case DXOp::WaveActiveOp:
case DXOp::WavePrefixOp:
case DXOp::WavePrefixBitCount:
if(!D3D_Hack_EnableGroups())
return StringFormat::Fmt("Unsupported dx.op call `%s` %s", callFunc->name.c_str(),
ToStr(dxOpCode).c_str());
@@ -1905,9 +1907,7 @@ rdcstr Program::GetDebugStatus()
case DXOp::CycleCounterLegacy:
case DXOp::WaveActiveAllEqual:
case DXOp::WaveActiveBit:
case DXOp::WavePrefixOp:
case DXOp::WaveAllBitCount:
case DXOp::WavePrefixBitCount:
case DXOp::AttributeAtVertex:
case DXOp::InstanceID:
case DXOp::InstanceIndex: