Divide up tests between Workgroup_Zoo and Subgroup_Zoo

Subgroup_Zoo : unit tests, non-trivial convergence tests moved to Workgroup_Zoo
Workgroup_Zoo : convergence tests, small number of unit tests (not full coverage)

Added checks for workgroup convergence in Workgroup_Zoo tests
* Vulkan uses barrier()
* D3D12 uses AllMemoryBarrierWithGroupSync()
* dispatches 2x1x1 workgroups
* test debug results for workgroup 1,0,0
This commit is contained in:
Jake Turner
2025-04-18 14:17:04 +01:00
parent 79e80c8337
commit 2fbf85fad1
6 changed files with 69 additions and 436 deletions
+6 -133
View File
@@ -173,46 +173,6 @@ float4 main(IN input) : SV_Target0
const std::string comp = compCommon + R"EOSHADER(
// Returns WaveActiveSum(id/2) replicated into all four components.
float4 funcD(uint id)
{
return WaveActiveSum(id/2).xxxx;
}
// Calls funcD with id/3, then overwrites the w component with WaveActiveSum(id).
// A wave operation is therefore issued both inside the nested call and after it returns.
float4 nestedFunc(uint id)
{
float4 ret = funcD(id/3);
ret.w = WaveActiveSum(id);
return ret;
}
// Forwards to nestedFunc with id scaled by 2.
float4 funcA(uint id)
{
return nestedFunc(id*2);
}
// Forwards to nestedFunc with id scaled by 4.
float4 funcB(uint id)
{
return nestedFunc(id*4);
}
// Returns from three different points depending on id:
//  - even id: returns zero without issuing any wave operation
//  - odd id below 10: returns WaveActiveSum(id)
//  - odd id of 10 or more: returns WaveActiveSum(id) + WaveActiveSum(id/2)
// The result is replicated into all four components.
float4 funcTest(uint id)
{
if ((id % 2) == 0)
{
return 0.xxxx;
}
else
{
// only threads with an odd id reach this wave operation
float value = WaveActiveSum(id);
if (id < 10)
{
return value.xxxx;
}
// only threads with an odd id >= 10 reach this wave operation
value += WaveActiveSum(id/2);
return value.xxxx;
}
}
[numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)]
void main(uint3 inTid : SV_DispatchThreadID)
{
@@ -224,100 +184,13 @@ void main(uint3 inTid : SV_DispatchThreadID)
SetOutput(id);
if(IsTest(0))
{
data.x = id;
}
else if(IsTest(1))
{
data.x = WaveActiveSum(id);
}
else if(IsTest(2))
{
// Diverged threads which reconverge
if (id < 10)
{
// active threads 0-9
data.x = WaveActiveSum(id);
if ((id % 2) == 0)
data.y = WaveActiveSum(id);
else
data.y = WaveActiveSum(id);
data.x += WaveActiveSum(id);
}
else
{
// active threads 10...
data.x = WaveActiveSum(id);
}
data.y = WaveActiveSum(id);
}
else if(IsTest(3))
{
// Converged threads calling a function
data = funcTest(id);
data.y = WaveActiveSum(id);
}
else if(IsTest(4))
{
// Converged threads calling a function which has a nested function call in it
data = nestedFunc(id);
data.y = WaveActiveSum(id);
}
else if(IsTest(5))
{
// Diverged threads calling the same function
if (id < 10)
{
data = funcD(id);
}
else
{
data = funcD(id);
}
data.y = WaveActiveSum(id);
}
else if(IsTest(6))
{
// Diverged threads calling the same function which has a nested function call in it
if (id < 10)
{
data = funcA(id);
}
else
{
data = funcB(id);
}
data.y = WaveActiveSum(id);
}
else if(IsTest(7))
{
// Diverged threads which early exit
if (id < 10)
{
data.x = WaveActiveSum(id+10);
SetOutput(data);
return;
}
data.x = WaveActiveSum(id);
}
else if(IsTest(8))
{
// Loops with different number of iterations per thread
for (uint i = 0; i < id; i++)
{
data.x += WaveActiveSum(id);
}
}
else if(IsTest(9))
{
// Query functions : unit tests
data.x = float(WaveGetLaneCount());
data.y = float(WaveGetLaneIndex());
data.z = float(WaveIsFirstLane());
}
else if(IsTest(10))
else if(IsTest(1))
{
// Vote functions : unit tests
data.x = float(WaveActiveAnyTrue(id*2 > id+10));
@@ -335,7 +208,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
data.w = countbits(ballot.x) + countbits(ballot.y) + countbits(ballot.z) + countbits(ballot.w);
}
}
else if(IsTest(11))
else if(IsTest(2))
{
// Broadcast functions : unit tests
if (id >= 2 && id <= 20)
@@ -346,7 +219,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
data.w = WaveReadLaneAt(data.x, 2+id%3);
}
}
else if(IsTest(12))
else if(IsTest(3))
{
// Scan and Prefix functions : unit tests
if (id >= 2 && id <= 20)
@@ -364,7 +237,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
data.w = WavePrefixSum(data.y);
}
}
else if(IsTest(13))
else if(IsTest(4))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
@@ -375,7 +248,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
data.w = float(WaveActiveSum(id));
}
}
else if(IsTest(14))
else if(IsTest(5))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
@@ -386,7 +259,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
data.w = float(WaveActiveBitXor(id));
}
}
else if(IsTest(15))
else if(IsTest(6))
{
// Reduction functions : unit tests
if (id > 13)
+40 -82
View File
@@ -98,6 +98,34 @@ float4 funcTest(uint id)
}
}
// Accumulates WaveActiveSum(id) into the components of the result along several
// data-dependent paths: a loop whose trip count depends on id, followed by
// branches that depend on both id and the value accumulated by the loop.
// The w component is accumulated last, after all of the branches.
float4 ComplexPartialReconvergence(uint id)
{
float4 result = 0.0.xxxx;
// Loops with different number of iterations per thread
// i starts at id, so threads with id >= 23 skip the loop and result.x stays 0
for (uint i = id; i < 23; i++)
{
result.x += WaveActiveSum(id);
}
// the branch taken depends on the per-thread loop result as well as id
if ((result.x < 5) || (id > 20))
{
result.y += WaveActiveSum(id);
if (id < 25)
result.z += WaveActiveSum(id);
}
else if (result.x < 10)
{
// reached only when result.x is in [5, 10) and id <= 20
result.y += WaveActiveSum(id);
if (result.x > 5)
result.z += WaveActiveSum(id);
}
// executed by every thread regardless of which branch (if any) was taken above
result.w += WaveActiveSum(id);
return result;
}
[numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)]
void main(uint3 inTid : SV_DispatchThreadID)
{
@@ -110,10 +138,12 @@ void main(uint3 inTid : SV_DispatchThreadID)
if(IsTest(0))
{
data.x = id;
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(1))
{
data.x = WaveActiveSum(id);
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(2))
{
@@ -136,18 +166,21 @@ void main(uint3 inTid : SV_DispatchThreadID)
data.x = WaveActiveSum(id);
}
data.y = WaveActiveSum(id);
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(3))
{
// Converged threads calling a function
data = funcTest(id);
data.y = WaveActiveSum(id);
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(4))
{
// Converged threads calling a function which has a nested function call in it
data = nestedFunc(id);
data.y = WaveActiveSum(id);
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(5))
{
@@ -161,6 +194,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
data = funcD(id);
}
data.y = WaveActiveSum(id);
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(6))
{
@@ -174,6 +208,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
data = funcB(id);
}
data.y = WaveActiveSum(id);
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(7))
{
@@ -193,6 +228,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
{
data.x += WaveActiveSum(id);
}
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(9))
{
@@ -200,91 +236,13 @@ void main(uint3 inTid : SV_DispatchThreadID)
data.x = float(WaveGetLaneCount());
data.y = float(WaveGetLaneIndex());
data.z = float(WaveIsFirstLane());
AllMemoryBarrierWithGroupSync();
}
else if(IsTest(10))
{
// Vote functions : unit tests
data.x = float(WaveActiveAnyTrue(id*2 > id+10));
data.y = float(WaveActiveAllTrue(id < WaveGetLaneCount()));
if (id > 10)
{
data.z = float(WaveActiveAllTrue(id > 10));
uint4 ballot = WaveActiveBallot(id > 20);
data.w = countbits(ballot.x) + countbits(ballot.y) + countbits(ballot.z) + countbits(ballot.w);
}
else
{
data.z = float(WaveActiveAllTrue(id > 3));
uint4 ballot = WaveActiveBallot(id > 4);
data.w = countbits(ballot.x) + countbits(ballot.y) + countbits(ballot.z) + countbits(ballot.w);
}
}
else if(IsTest(11))
{
// Broadcast functions : unit tests
if (id >= 2 && id <= 20)
{
data.x = WaveReadLaneFirst(id);
data.y = WaveReadLaneAt(id, 5);
data.z = WaveReadLaneAt(id, id);
data.w = WaveReadLaneAt(data.x, 2+id%3);
}
}
else if(IsTest(12))
{
// Scan and Prefix functions : unit tests
if (id >= 2 && id <= 20)
{
data.x = WavePrefixCountBits(id > 4);
data.y = WavePrefixCountBits(id > 10);
data.z = WavePrefixSum(data.x);
data.w = WavePrefixProduct(1 + data.y);
}
else
{
data.x = WavePrefixCountBits(id > 23);
data.y = WavePrefixCountBits(id < 1);
data.z = WavePrefixSum(data.x);
data.w = WavePrefixSum(data.y);
}
}
else if(IsTest(13))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
{
data.x = float(WaveActiveMax(id));
data.y = float(WaveActiveMin(id));
data.z = float(WaveActiveProduct(id));
data.w = float(WaveActiveSum(id));
}
}
else if(IsTest(14))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
{
data.x = float(WaveActiveCountBits(id > 23));
data.y = float(WaveActiveBitAnd(id));
data.z = float(WaveActiveBitOr(id));
data.w = float(WaveActiveBitXor(id));
}
}
else if(IsTest(15))
{
// Reduction functions : unit tests
if (id > 13)
{
bool test1 = (id > 15).x;
bool2 test2 = bool2(test1, (id < 23));
bool3 test3 = bool3(test1, (id < 23), (id >= 25));
bool4 test4 = bool4(test1, (id < 23), (id >= 25), (id >= 28));
data = ComplexPartialReconvergence(id);
data.x = float(WaveActiveAllEqual(test1).x);
data.y = float(WaveActiveAllEqual(test2).y);
data.z = float(WaveActiveAllEqual(test3).z);
data.w = float(WaveActiveAllEqual(test4).w);
}
AllMemoryBarrierWithGroupSync();
}
SetOutput(data);
@@ -396,7 +354,7 @@ void main(uint3 inTid : SV_DispatchThreadID)
for(int i = 0; i < numCompTests; i++)
{
cmd->SetComputeRoot32BitConstant(0, i, 0);
cmd->Dispatch(1, 1, 1);
cmd->Dispatch(2, 1, 1);
}
popMarker(cmd);
+6 -133
View File
@@ -151,46 +151,6 @@ layout(binding = 0, std430) buffer outbuftype {
layout(local_size_x = GROUP_SIZE_X, local_size_y = GROUP_SIZE_Y, local_size_z = 1) in;
// Returns subgroupAdd(id/2) replicated into all four components.
vec4 funcD(uint id)
{
return vec4(subgroupAdd(id/2));
}
// Calls funcD with id/3, then overwrites the w component with subgroupAdd(id).
// A subgroup operation is therefore issued both inside the nested call and after it returns.
vec4 nestedFunc(uint id)
{
vec4 ret = funcD(id/3);
ret.w = subgroupAdd(id);
return ret;
}
// Forwards to nestedFunc with id scaled by 2.
vec4 funcA(uint id)
{
return nestedFunc(id*2);
}
// Forwards to nestedFunc with id scaled by 4.
vec4 funcB(uint id)
{
return nestedFunc(id*4);
}
// Returns from three different points depending on id:
//  - even id: returns zero without issuing any subgroup operation
//  - odd id below 10: returns subgroupAdd(id)
//  - odd id of 10 or more: returns subgroupAdd(id) + subgroupAdd(id/2)
// The result is replicated into all four components.
vec4 funcTest(uint id)
{
if ((id % 2) == 0)
{
return vec4(0);
}
else
{
// only invocations with an odd id reach this subgroup operation
float value = subgroupAdd(id);
if (id < 10)
{
return vec4(value);
}
// only invocations with an odd id >= 10 reach this subgroup operation
value += subgroupAdd(id/2);
return vec4(value);
}
}
void SetOutput(vec4 data)
{
outbuf.data[push.test].vals[gl_LocalInvocationID.y * GROUP_SIZE_X + gl_LocalInvocationID.x] = data;
@@ -202,100 +162,13 @@ void main()
SetOutput(data);
if(IsTest(0))
{
data.x = id;
}
else if(IsTest(1))
{
data.x = subgroupAdd(id);
}
else if(IsTest(2))
{
// Diverged threads which reconverge
if (id < 10)
{
// active threads 0-9
data.x = subgroupAdd(id);
if ((id % 2) == 0)
data.y = subgroupAdd(id);
else
data.y = subgroupAdd(id);
data.x += subgroupAdd(id);
}
else
{
// active threads 10...
data.x = subgroupAdd(id);
}
data.y = subgroupAdd(id);
}
else if(IsTest(3))
{
// Converged threads calling a function
data = funcTest(id);
data.y = subgroupAdd(id);
}
else if(IsTest(4))
{
// Converged threads calling a function which has a nested function call in it
data = nestedFunc(id);
data.y = subgroupAdd(id);
}
else if(IsTest(5))
{
// Diverged threads calling the same function
if (id < 10)
{
data = funcD(id);
}
else
{
data = funcD(id);
}
data.y = subgroupAdd(id);
}
else if(IsTest(6))
{
// Diverged threads calling the same function which has a nested function call in it
if (id < 10)
{
data = funcA(id);
}
else
{
data = funcB(id);
}
data.y = subgroupAdd(id);
}
else if(IsTest(7))
{
// Diverged threads which early exit
if (id < 10)
{
data.x = subgroupAdd(id+10);
SetOutput(data);
return;
}
data.x = subgroupAdd(id);
}
else if(IsTest(8))
{
// Loops with different number of iterations per thread
for (uint i = 0; i < id; i++)
{
data.x += subgroupAdd(id);
}
}
else if(IsTest(9))
{
// Query functions : unit tests
data.x = float(gl_SubgroupSize);
data.y = float(gl_SubgroupInvocationID);
data.z = float(subgroupElect());
}
else if(IsTest(10))
else if(IsTest(1))
{
// Vote functions : unit tests
data.x = float(subgroupAny(id*2 > id+10));
@@ -313,7 +186,7 @@ void main()
data.w = bitCount(ballot.x) + bitCount(ballot.y) + bitCount(ballot.z) + bitCount(ballot.w);
}
}
else if(IsTest(11))
else if(IsTest(2))
{
// Broadcast functions : unit tests
if (id >= 2 && id <= 20)
@@ -324,7 +197,7 @@ void main()
data.w = subgroupShuffle(data.x, 2+id%3);
}
}
else if(IsTest(12))
else if(IsTest(3))
{
// Scan and Prefix functions : unit tests
if (id >= 2 && id <= 20)
@@ -346,7 +219,7 @@ void main()
data.w = subgroupExclusiveAdd(data.y);
}
}
else if(IsTest(13))
else if(IsTest(4))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
@@ -357,7 +230,7 @@ void main()
data.w = float(subgroupAdd(id));
}
}
else if(IsTest(14))
else if(IsTest(5))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
@@ -369,7 +242,7 @@ void main()
data.w = float(subgroupXor(id));
}
}
else if(IsTest(15))
else if(IsTest(6))
{
// Reduction functions : unit tests
if (id > 13)
+14 -87
View File
@@ -69,7 +69,7 @@ layout(push_constant) uniform PushData
const std::string comp = common + R"EOSHADER(
shared uvec4 gsmUint4[COMP_TESTS];
shared uvec4 gsmUint4[1024];
struct Output
{
@@ -126,6 +126,7 @@ void SetOutput(vec4 data)
{
outbuf.data[push.test].vals[gl_LocalInvocationID.y * GROUP_SIZE_X + gl_LocalInvocationID.x] = data;
}
void main()
{
vec4 data = vec4(0);
@@ -136,10 +137,12 @@ void main()
if(IsTest(0))
{
data.x = id;
barrier();
}
else if(IsTest(1))
{
data.x = subgroupAdd(id);
barrier();
}
else if(IsTest(2))
{
@@ -162,18 +165,21 @@ void main()
data.x = subgroupAdd(id);
}
data.y = subgroupAdd(id);
barrier();
}
else if(IsTest(3))
{
// Converged threads calling a function
data = funcTest(id);
data.y = subgroupAdd(id);
barrier();
}
else if(IsTest(4))
{
// Converged threads calling a function which has a nested function call in it
data = nestedFunc(id);
data.y = subgroupAdd(id);
barrier();
}
else if(IsTest(5))
{
@@ -187,6 +193,7 @@ void main()
data = funcD(id);
}
data.y = subgroupAdd(id);
barrier();
}
else if(IsTest(6))
{
@@ -200,6 +207,7 @@ void main()
data = funcB(id);
}
data.y = subgroupAdd(id);
barrier();
}
else if(IsTest(7))
{
@@ -219,6 +227,7 @@ void main()
{
data.x += subgroupAdd(id);
}
barrier();
}
else if(IsTest(9))
{
@@ -226,92 +235,10 @@ void main()
data.x = float(gl_SubgroupSize);
data.y = float(gl_SubgroupInvocationID);
data.z = float(subgroupElect());
barrier();
}
else if(IsTest(10))
{
// Vote functions : unit tests
data.x = float(subgroupAny(id*2 > id+10));
data.y = float(subgroupAll(id < gl_SubgroupSize));
if (id > 10)
{
data.z = float(subgroupAll(id > 10));
uvec4 ballot = subgroupBallot(id > 20);
data.w = bitCount(ballot.x) + bitCount(ballot.y) + bitCount(ballot.z) + bitCount(ballot.w);
}
else
{
data.z = float(subgroupAll(id > 3));
uvec4 ballot = subgroupBallot(id > 4);
data.w = bitCount(ballot.x) + bitCount(ballot.y) + bitCount(ballot.z) + bitCount(ballot.w);
}
}
else if(IsTest(11))
{
// Broadcast functions : unit tests
if (id >= 2 && id <= 20)
{
data.x = subgroupBroadcastFirst(id);
data.y = subgroupBroadcast(id, 5);
data.z = subgroupShuffle(id, id);
data.w = subgroupShuffle(data.x, 2+id%3);
}
}
else if(IsTest(12))
{
// Scan and Prefix functions : unit tests
if (id >= 2 && id <= 20)
{
uvec4 bits = subgroupBallot(id > 4);
data.x = subgroupBallotExclusiveBitCount(bits);
bits = subgroupBallot(id > 10);
data.y = subgroupBallotExclusiveBitCount(bits);
data.z = subgroupExclusiveAdd(data.x);
data.w = subgroupExclusiveMul(1 + data.y);
}
else
{
uvec4 bits = subgroupBallot(id > 23);
data.x = subgroupBallotExclusiveBitCount(bits);
bits = subgroupBallot(id < 1);
data.y = subgroupBallotExclusiveBitCount(bits);
data.z = subgroupExclusiveAdd(data.x);
data.w = subgroupExclusiveAdd(data.y);
}
}
else if(IsTest(13))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
{
data.x = float(subgroupMax(id));
data.y = float(subgroupMin(id));
data.z = float(subgroupMul(id));
data.w = float(subgroupAdd(id));
}
}
else if(IsTest(14))
{
// Reduction functions : unit tests
if (id >= 2 && id <= 20)
{
uvec4 bits = subgroupBallot(id > 23);
data.x = float(subgroupBallotBitCount(bits));
data.y = float(subgroupAnd(id));
data.z = float(subgroupOr(id));
data.w = float(subgroupXor(id));
}
}
else if(IsTest(15))
{
// Reduction functions : unit tests
if (id > 13)
{
data.x = float(subgroupAllEqual(id > 15));
data.y = float(subgroupAllEqual(id < 23));
data.z = float(subgroupAllEqual(id >= 25));
data.w = float(subgroupAllEqual(id >= 28));
}
}
SetOutput(data);
}
@@ -467,7 +394,7 @@ void main()
for(int i = 0; i < numCompTests; i++)
{
vkh::cmdPushConstants(cmd, layout, i);
vkCmdDispatch(cmd, 1, 1, 1);
vkCmdDispatch(cmd, 2, 1, 1);
}
popMarker(cmd);
+2 -1
View File
@@ -6,6 +6,7 @@ import rdtest
class Subgroup_Zoo(rdtest.TestCase):
internal = True
demos_test_name = None
workgroup = (0, 0, 0)
def check_support(self, **kwargs):
# Only allow this if explicitly run
@@ -19,7 +20,7 @@ class Subgroup_Zoo(rdtest.TestCase):
"4f", bufdata, 16*y*dim[0] + 16*x)
trace = self.controller.DebugThread(
(0, 0, 0), (x, y, z))
self.workgroup, (x, y, z))
_, variables = self.process_trace(trace)
+1
View File
@@ -24,6 +24,7 @@ class Workgroup_Zoo(rdtest.Subgroup_Zoo):
150
]
self.workgroup = (1, 0, 0)
if self.check_compute_tests(compute_dims, thread_checks):
raise rdtest.TestFailureException("Some tests were not as expected")