Add DXIL Debugger Support for Wave Reduction ops

DXOp::WaveActiveAllEqual
DXOp::WaveActiveBit
DXOp::WaveAllBitCount
DXOp::WaveActiveOp (WaveOpCode::Product, WaveOpCode::Min, WaveOpCode::Max)
This commit is contained in:
Jake Turner
2025-04-07 14:26:27 +01:00
parent cc32e24d81
commit f0936cdf1b
2 changed files with 264 additions and 64 deletions
+261 -61
View File
@@ -3794,7 +3794,7 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
RDCASSERT(GetShaderVariable(inst.args[3], opCode, dxOpCode, arg));
bool isUnsigned = (arg.value.u32v[0] != (uint32_t)SignedOpKind::Signed);
// set the identity
// set the initial value
ShaderVariable accum(result);
switch(waveOpCode)
{
@@ -3875,17 +3875,22 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
break;
}
case DXOp::WavePrefixBitCount:
case DXOp::WaveAllBitCount:
{
// WavePrefixBitCount(cond)
// WaveAllBitCount(cond)
// determine active lane indices in our subgroup
rdcarray<uint32_t> activeLanes;
GetSubgroupActiveLanes(activeMask, workgroup, activeLanes);
uint32_t maxLane = (dxOpCode == DXOp::WavePrefixBitCount) ? m_WorkgroupIndex : UINT32_MAX;
uint32_t count = 0;
for(uint32_t lane : activeLanes)
{
// stop before processing our lane
if(lane == m_WorkgroupIndex)
if(lane == maxLane)
break;
ShaderVariable x;
@@ -3899,34 +3904,14 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
case DXOp::WaveAnyTrue:
case DXOp::WaveAllTrue:
case DXOp::WaveActiveBallot:
case DXOp::WaveActiveOp:
case DXOp::WaveActiveAllEqual:
{
ShaderVariable accum(result);
bool isUnsigned = true;
WaveOpCode waveOpCode = WaveOpCode::Sum;
if(dxOpCode == DXOp::WaveActiveOp)
{
// WaveActiveOp(value,op,sop)
ShaderVariable arg;
RDCASSERT(GetShaderVariable(inst.args[2], opCode, dxOpCode, arg));
waveOpCode = (WaveOpCode)arg.value.u32v[0];
ShaderVariable refValue;
RDCASSERT(GetShaderVariable(inst.args[1], opCode, dxOpCode, refValue));
RDCASSERT(GetShaderVariable(inst.args[3], opCode, dxOpCode, arg));
isUnsigned = (arg.value.u32v[0] != (uint32_t)SignedOpKind::Signed);
// set the identity
switch(waveOpCode)
{
case WaveOpCode::Sum: SetShaderValueZero(accum); break;
case WaveOpCode::Product: SetShaderValueOne(accum); break;
default:
RDCERR("Unhandled ActiveOp wave opcode");
accum.value = {};
break;
}
}
else if(dxOpCode == DXOp::WaveAnyTrue)
if(dxOpCode == DXOp::WaveAnyTrue)
{
// WaveAnyTrue(cond)
accum.value.u32v[0] = 0;
@@ -3936,6 +3921,15 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
// WaveAllTrue(cond)
accum.value.u32v[0] = 1;
}
else if(dxOpCode == DXOp::WaveActiveAllEqual)
{
// WaveActiveAllEqual(value)
accum.value.u32v[0] = 1;
}
else
{
RDCERR("Unhandled dxOpCode %s", ToStr(dxOpCode).c_str());
}
// determine active lane indices in our subgroup
rdcarray<uint32_t> activeLanes;
@@ -3947,38 +3941,7 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
ShaderVariable x;
RDCASSERT(workgroup[lane].GetShaderVariable(inst.args[1], opCode, dxOpCode, x));
if(dxOpCode == DXOp::WaveActiveOp)
{
switch(waveOpCode)
{
case WaveOpCode::Sum:
{
for(uint8_t c = 0; c < x.columns; c++)
{
if(isUnsigned)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(accum, c) = comp<U>(accum, c) + comp<U>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
else
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<S>(accum, c) + comp<S>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) comp<T>(accum, c) = comp<T>(accum, c) + comp<T>(x, c)
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
}
}
break;
}
default: RDCERR("Unhandled ActiveOp wave opcode"); break;
}
}
else if(dxOpCode == DXOp::WaveAnyTrue)
if(dxOpCode == DXOp::WaveAnyTrue)
{
accum.value.u32v[0] |= x.value.u32v[0];
}
@@ -3994,6 +3957,246 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
if(x.value.u32v[0])
accum.value.u32v[c] |= bit;
}
else if(dxOpCode == DXOp::WaveActiveAllEqual)
{
for(uint8_t c = 0; c < x.columns; c++)
{
bool matches = false;
#undef _IMPL
#define _IMPL(I, S, U) matches = (comp<I>(x, c) == comp<I>(refValue, c));
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) matches = (comp<T>(x, c) == comp<T>(refValue, c));
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
accum.value.u32v[c] &= (matches ? 1 : 0);
}
}
}
result.value = accum.value;
break;
}
case DXOp::WaveActiveOp:
{
// WaveActiveOp(value,op,sop)
// Folds the per-lane `value` operand (args[1]) of every active lane in the subgroup into a
// single accumulator, using the operation selected by `op` (args[2]) and the signedness
// selected by `sop` (args[3]). The accumulated value is written to result.value.
ShaderVariable accum(result);
// refValue is args[1] read through this thread's own GetShaderVariable (not a per-lane lookup);
// it is only consumed below as the starting value for Min/Max
ShaderVariable refValue;
RDCASSERT(GetShaderVariable(inst.args[1], opCode, dxOpCode, refValue));
// `arg` is a scratch variable, reused for both the op and sop operands
ShaderVariable arg;
RDCASSERT(GetShaderVariable(inst.args[2], opCode, dxOpCode, arg));
WaveOpCode waveOpCode = (WaveOpCode)arg.value.u32v[0];
RDCASSERT(GetShaderVariable(inst.args[3], opCode, dxOpCode, arg));
// anything other than SignedOpKind::Signed is treated as unsigned
bool isUnsigned = (arg.value.u32v[0] != (uint32_t)SignedOpKind::Signed);
// set the initial value
// Sum and Product start from SetShaderValueZero / SetShaderValueOne respectively.
// Min and Max start from this thread's own value (refValue) instead of a type-dependent extreme.
// An unrecognised opcode logs an error and starts from a value-initialised accumulator.
switch(waveOpCode)
{
case WaveOpCode::Sum: SetShaderValueZero(accum); break;
case WaveOpCode::Product: SetShaderValueOne(accum); break;
case WaveOpCode::Min:
case WaveOpCode::Max:
{
accum.value = refValue.value;
break;
}
default:
RDCERR("Unhandled ActiveOp wave opcode");
accum.value = {};
break;
}
// determine active lane indices in our subgroup
rdcarray<uint32_t> activeLanes;
GetSubgroupActiveLanes(activeMask, workgroup, activeLanes);
// fold every active lane's value into the accumulator, one component (column) at a time
for(uint32_t lane : activeLanes)
{
// x is args[1] as evaluated by the thread state stored at workgroup[lane]
ShaderVariable x;
RDCASSERT(workgroup[lane].GetShaderVariable(inst.args[1], opCode, dxOpCode, x));
// Every case below follows the same pattern: _IMPL is redefined immediately before each
// IMPL_FOR_*_TYPES_FOR_TYPE use so the macro applies the wanted expression for x.type.
// For integer types _IMPL receives (I, S, U); the unsigned branch operates through comp<U>
// and the signed branch through comp<S>.
// NOTE(review): the float macro is only invoked in the signed (else) branch, so a float
// operand arriving with isUnsigned == true would leave accum unchanged - presumably float
// operations are always emitted with sop == Signed; confirm.
switch(waveOpCode)
{
case WaveOpCode::Sum:
{
// accum += x, per component
for(uint8_t c = 0; c < x.columns; c++)
{
if(isUnsigned)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(accum, c) = comp<U>(accum, c) + comp<U>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
else
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<S>(accum, c) + comp<S>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) comp<T>(accum, c) = comp<T>(accum, c) + comp<T>(x, c)
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
}
}
break;
}
case WaveOpCode::Product:
{
// accum *= x, per component
for(uint8_t c = 0; c < x.columns; c++)
{
if(isUnsigned)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(accum, c) = comp<U>(accum, c) * comp<U>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
else
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<S>(accum, c) * comp<S>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) comp<T>(accum, c) = comp<T>(accum, c) * comp<T>(x, c)
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
}
}
break;
}
case WaveOpCode::Min:
{
// accum = RDCMIN(accum, x), per component
for(uint8_t c = 0; c < x.columns; c++)
{
if(isUnsigned)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(accum, c) = RDCMIN(comp<U>(accum, c), comp<U>(x, c))
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
else
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = RDCMIN(comp<S>(accum, c), comp<S>(x, c))
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) comp<T>(accum, c) = RDCMIN(comp<T>(accum, c), comp<T>(x, c))
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
}
}
break;
}
case WaveOpCode::Max:
{
// accum = RDCMAX(accum, x), per component
for(uint8_t c = 0; c < x.columns; c++)
{
if(isUnsigned)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<U>(accum, c) = RDCMAX(comp<U>(accum, c), comp<U>(x, c))
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
else
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = RDCMAX(comp<S>(accum, c), comp<S>(x, c))
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
#undef _IMPL
#define _IMPL(T) comp<T>(accum, c) = RDCMAX(comp<T>(accum, c), comp<T>(x, c))
IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, x.type);
}
}
break;
}
// an unrecognised opcode is reported again here, once per active lane, and accum is left as-is
default: RDCERR("Unhandled ActiveOp wave opcode"); break;
}
}
// only the value payload is copied back; the rest of result is left untouched
result.value = accum.value;
break;
}
case DXOp::WaveActiveBit:
{
// WaveActiveBit(value,op)
ShaderVariable accum(result);
ShaderVariable refValue;
RDCASSERT(GetShaderVariable(inst.args[1], opCode, dxOpCode, refValue));
ShaderVariable arg;
RDCASSERT(GetShaderVariable(inst.args[2], opCode, dxOpCode, arg));
WaveBitOpCode waveBitOpCode = (WaveBitOpCode)arg.value.u32v[0];
// set the initial value
switch(waveBitOpCode)
{
case WaveBitOpCode::Or:
case WaveBitOpCode::Xor: SetShaderValueZero(accum); break;
case WaveBitOpCode::And:
{
accum.value = refValue.value;
break;
}
default:
RDCERR("Unhandled ActiveBitOp wave opcode");
accum.value = {};
break;
}
// determine active lane indices in our subgroup
rdcarray<uint32_t> activeLanes;
GetSubgroupActiveLanes(activeMask, workgroup, activeLanes);
for(uint32_t lane : activeLanes)
{
ShaderVariable x;
RDCASSERT(workgroup[lane].GetShaderVariable(inst.args[1], opCode, dxOpCode, x));
switch(waveBitOpCode)
{
case WaveBitOpCode::And:
{
for(uint8_t c = 0; c < x.columns; c++)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<I>(accum, c) & comp<I>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
break;
}
case WaveBitOpCode::Or:
{
for(uint8_t c = 0; c < x.columns; c++)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<I>(accum, c) | comp<I>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
break;
}
case WaveBitOpCode::Xor:
{
for(uint8_t c = 0; c < x.columns; c++)
{
#undef _IMPL
#define _IMPL(I, S, U) comp<S>(accum, c) = comp<I>(accum, c) ^ comp<I>(x, c)
IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, x.type);
}
break;
}
default: RDCERR("Unhandled ActiveBitOp wave opcode"); break;
}
}
result.value = accum.value;
@@ -4306,9 +4509,6 @@ bool ThreadState::ExecuteInstruction(DebugAPIWrapper *apiWrapper,
case DXOp::EmitThenCutStream:
// Wave Operations
case DXOp::WaveActiveAllEqual:
case DXOp::WaveActiveBit:
case DXOp::WaveAllBitCount:
case DXOp::WaveMatch:
case DXOp::WaveMultiPrefixOp:
case DXOp::WaveMultiPrefixBitCount:
@@ -1875,12 +1875,15 @@ rdcstr Program::GetDebugStatus()
case DXOp::WaveGetLaneCount:
case DXOp::WaveAnyTrue:
case DXOp::WaveAllTrue:
case DXOp::WaveActiveAllEqual:
case DXOp::WaveActiveBallot:
case DXOp::WaveReadLaneAt:
case DXOp::WaveReadLaneFirst:
case DXOp::WaveActiveOp:
case DXOp::WaveActiveBit:
case DXOp::WavePrefixOp:
case DXOp::WavePrefixBitCount:
case DXOp::WaveAllBitCount:
if(!D3D_Hack_EnableGroups())
return StringFormat::Fmt("Unsupported dx.op call `%s` %s", callFunc->name.c_str(),
ToStr(dxOpCode).c_str());
@@ -1905,9 +1908,6 @@ rdcstr Program::GetDebugStatus()
case DXOp::StorePatchConstant:
case DXOp::OutputControlPointID:
case DXOp::CycleCounterLegacy:
case DXOp::WaveActiveAllEqual:
case DXOp::WaveActiveBit:
case DXOp::WaveAllBitCount:
case DXOp::AttributeAtVertex:
case DXOp::InstanceID:
case DXOp::InstanceIndex: