D3D12_Subgroup_Zoo tests for Quad Ops in Compute Shader

QuadReadLaneAt
QuadReadAcrossDiagonal
QuadReadAcrossX
QuadReadAcrossY
QuadAny
QuadAll
This commit is contained in:
Jake Turner
2025-12-11 14:57:20 +13:00
parent 823ab380c1
commit e2ca32a348
+107 -17
View File
@@ -280,7 +280,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
)EOSHADER";
const std::string comp65 = compCommon + R"EOSHADER(
const std::string comp66 = compCommon + R"EOSHADER(
[numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)]
void main(uint3 inTid : SV_DispatchThreadID)
@@ -312,6 +312,55 @@ void main(uint3 inTid : SV_DispatchThreadID)
testResult.z = WaveMultiPrefixBitOr(id, mask);
testResult.w = WaveMultiPrefixBitXor(id, mask);
}
if(IsTest(2))
{
// QuadReadLaneAt: unit tests
testResult.x = float(QuadReadLaneAt(id, 0));
testResult.y = float(QuadReadLaneAt(id, 1));
testResult.z = float(QuadReadLaneAt(id, 2));
testResult.w = float(QuadReadLaneAt(id, 3));
}
if(IsTest(3))
{
// QuadReadAcrossDiagonal, QuadReadAcrossX, QuadReadAcrossY: unit tests
testResult.x = float(QuadReadAcrossDiagonal(id));
testResult.y = float(QuadReadAcrossX(id));
testResult.z = float(QuadReadAcrossY(id));
testResult.w = QuadReadLaneAt(testResult.x, 2);
}
if(IsTest(4))
{
// QuadAny, QuadAll: unit tests
testResult.x = float(QuadAny(id > 2));
testResult.y = float(QuadAll(id < 10));
testResult.z = float(QuadAny(testResult.x == 0.0f));
testResult.w = float(QuadAll(testResult.x == 0.0f));
}
SetOutput(testResult);
}
)EOSHADER";
const std::string comp67 = compCommon + R"EOSHADER(
[numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)]
void main(uint3 inTid : SV_DispatchThreadID)
{
float4 testResult = 0.0f.xxxx;
tid = inTid;
uint id = WaveGetLaneIndex();
SetOutput(id);
if(IsTest(0))
{
// SM6.7 functions : QuadAny, QuadAll: unit tests
testResult.x = float(QuadAny(id > 2));
testResult.y = float(QuadAll(id < 10));
testResult.z = float(QuadAny(testResult.x == 0.0f));
testResult.w = float(QuadAll(testResult.x == 0.0f));
}
SetOutput(testResult);
}
@@ -351,7 +400,8 @@ void main(uint3 inTid : SV_DispatchThreadID)
int numPixelTests60 = 0;
int numPixelTests67 = 0;
int numCompTests60 = 0;
int numCompTests65 = 0;
int numCompTests66 = 0;
int numCompTests67 = 0;
{
size_t pos = 0;
@@ -370,7 +420,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
if(pos == std::string::npos)
break;
pos += sizeof("IsTest(") - 1;
numPixelTests67 = std::max(numCompTests65, atoi(pixel67.c_str() + pos) + 1);
numPixelTests67 = std::max(numPixelTests67, atoi(pixel67.c_str() + pos) + 1);
}
pos = 0;
@@ -395,17 +445,26 @@ void main(uint3 inTid : SV_DispatchThreadID)
pos = 0;
while(pos != std::string::npos)
{
pos = comp65.find("IsTest(", pos);
pos = comp66.find("IsTest(", pos);
if(pos == std::string::npos)
break;
pos += sizeof("IsTest(") - 1;
numCompTests65 = std::max(numCompTests65, atoi(comp65.c_str() + pos) + 1);
numCompTests66 = std::max(numCompTests66, atoi(comp66.c_str() + pos) + 1);
}
pos = 0;
while(pos != std::string::npos)
{
pos = comp67.find("IsTest(", pos);
if(pos == std::string::npos)
break;
pos += sizeof("IsTest(") - 1;
numCompTests67 = std::max(numCompTests67, atoi(comp67.c_str() + pos) + 1);
}
}
const uint32_t numGraphicsTests60 = std::max(vertTests, numPixelTests60);
const uint32_t numGraphicsTests67 = numPixelTests67;
const uint32_t numCompTests = std::max(numCompTests60, numCompTests65);
const uint32_t numCompTests = std::max(std::max(numCompTests60, numCompTests66), numCompTests67);
struct
{
@@ -414,16 +473,18 @@ void main(uint3 inTid : SV_DispatchThreadID)
{256, 1},
{128, 2},
{8, 128},
{150, 1},
{152, 1},
};
std::string comppipe_name[ARRAY_COUNT(compsize)];
ID3D12PipelineStatePtr comppipe[ARRAY_COUNT(compsize)];
ID3D12PipelineStatePtr comppipe65[ARRAY_COUNT(compsize)];
ID3D12PipelineStatePtr comppipe66[ARRAY_COUNT(compsize)];
ID3D12PipelineStatePtr comppipe67[ARRAY_COUNT(compsize)];
std::string defines60;
std::string defines65;
std::string defines66;
std::string defines67;
bool supportSM65 = (m_HighestShaderModel >= D3D_SHADER_MODEL_6_5) && m_DXILSupport;
bool supportSM66 = (m_HighestShaderModel >= D3D_SHADER_MODEL_6_6) && m_DXILSupport;
bool supportSM67 = (m_HighestShaderModel >= D3D_SHADER_MODEL_6_7) && m_DXILSupport;
ID3D12PipelineStatePtr graphics60 = MakePSO()
@@ -450,11 +511,17 @@ void main(uint3 inTid : SV_DispatchThreadID)
MakePSO().RootSig(sig).CS(Compile(defines60 + sizedefine + comp, "main", "cs_6_0"));
comppipe[i]->SetName(UTF82Wide(comppipe_name[i]).c_str());
if(supportSM65)
if(supportSM66)
{
comppipe65[i] =
MakePSO().RootSig(sig).CS(Compile(defines65 + sizedefine + comp65, "main", "cs_6_5"));
comppipe65[i]->SetName(UTF82Wide(comppipe_name[i]).c_str());
comppipe66[i] =
MakePSO().RootSig(sig).CS(Compile(defines66 + sizedefine + comp66, "main", "cs_6_6"));
comppipe66[i]->SetName(UTF82Wide(comppipe_name[i]).c_str());
}
if(supportSM67)
{
comppipe67[i] =
MakePSO().RootSig(sig).CS(Compile(defines67 + sizedefine + comp67, "main", "cs_6_7"));
comppipe67[i]->SetName(UTF82Wide(comppipe_name[i]).c_str());
}
}
@@ -543,7 +610,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
popMarker(cmd);
}
for(size_t p = 0; p < ARRAY_COUNT(comppipe65); p++)
for(size_t p = 0; p < ARRAY_COUNT(comppipe66); p++)
{
ResourceBarrier(cmd);
@@ -553,11 +620,34 @@ void main(uint3 inTid : SV_DispatchThreadID)
ResourceBarrier(cmd);
pushMarker(cmd, comppipe_name[p]);
cmd->SetPipelineState(comppipe65[p]);
cmd->SetPipelineState(comppipe66[p]);
cmd->SetComputeRootSignature(sig);
cmd->SetComputeRootUnorderedAccessView(1, bufOut->GetGPUVirtualAddress());
for(int i = 0; i < numCompTests65; i++)
for(int i = 0; i < numCompTests66; i++)
{
cmd->SetComputeRoot32BitConstant(0, i, 0);
cmd->Dispatch(1, 1, 1);
}
popMarker(cmd);
}
for(size_t p = 0; p < ARRAY_COUNT(comppipe67); p++)
{
ResourceBarrier(cmd);
UINT zero[4] = {};
cmd->ClearUnorderedAccessViewUint(uavgpu, uavcpu, bufOut, zero, 0, NULL);
ResourceBarrier(cmd);
pushMarker(cmd, comppipe_name[p]);
cmd->SetPipelineState(comppipe67[p]);
cmd->SetComputeRootSignature(sig);
cmd->SetComputeRootUnorderedAccessView(1, bufOut->GetGPUVirtualAddress());
for(int i = 0; i < numCompTests67; i++)
{
cmd->SetComputeRoot32BitConstant(0, i, 0);
cmd->Dispatch(1, 1, 1);