diff --git a/util/test/demos/d3d12/d3d12_subgroup_zoo.cpp b/util/test/demos/d3d12/d3d12_subgroup_zoo.cpp index 0e15f12e1..08f7e4711 100644 --- a/util/test/demos/d3d12/d3d12_subgroup_zoo.cpp +++ b/util/test/demos/d3d12/d3d12_subgroup_zoo.cpp @@ -173,46 +173,6 @@ float4 main(IN input) : SV_Target0 const std::string comp = compCommon + R"EOSHADER( -float4 funcD(uint id) -{ - return WaveActiveSum(id/2).xxxx; -} - -float4 nestedFunc(uint id) -{ - float4 ret = funcD(id/3); - ret.w = WaveActiveSum(id); - return ret; -} - -float4 funcA(uint id) -{ - return nestedFunc(id*2); -} - -float4 funcB(uint id) -{ - return nestedFunc(id*4); -} - -float4 funcTest(uint id) -{ - if ((id % 2) == 0) - { - return 0.xxxx; - } - else - { - float value = WaveActiveSum(id); - if (id < 10) - { - return value.xxxx; - } - value += WaveActiveSum(id/2); - return value.xxxx; - } -} - [numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)] void main(uint3 inTid : SV_DispatchThreadID) { @@ -224,100 +184,13 @@ void main(uint3 inTid : SV_DispatchThreadID) SetOutput(id); if(IsTest(0)) - { - data.x = id; - } - else if(IsTest(1)) - { - data.x = WaveActiveSum(id); - } - else if(IsTest(2)) - { - // Diverged threads which reconverge - if (id < 10) - { - // active threads 0-9 - data.x = WaveActiveSum(id); - - if ((id % 2) == 0) - data.y = WaveActiveSum(id); - else - data.y = WaveActiveSum(id); - - data.x += WaveActiveSum(id); - } - else - { - // active threads 10... - data.x = WaveActiveSum(id); - } - data.y = WaveActiveSum(id); - } - else if(IsTest(3)) - { - // Converged threads calling a function - data = funcTest(id); - data.y = WaveActiveSum(id); - } - else if(IsTest(4)) - { - // Converged threads calling a function which has a nested function call in it - data = nestedFunc(id); - data.y = WaveActiveSum(id); - } - else if(IsTest(5)) - { - // Diverged threads calling the same function - if (id < 10) - { - data = funcD(id); - } - else - { - data = funcD(id); - } - data.y = WaveActiveSum(id); - } - else if(IsTest(6)) - { - // Diverged threads calling the same function which has a nested function call in it - if (id < 10) - { - data = funcA(id); - } - else - { - data = funcB(id); - } - data.y = WaveActiveSum(id); - } - else if(IsTest(7)) - { - // Diverged threads which early exit - if (id < 10) - { - data.x = WaveActiveSum(id+10); - SetOutput(data); - return; - } - data.x = WaveActiveSum(id); - } - else if(IsTest(8)) - { - // Loops with different number of iterations per thread - for (uint i = 0; i < id; i++) - { - data.x += WaveActiveSum(id); - } - } - else if(IsTest(9)) { // Query functions : unit tests data.x = float(WaveGetLaneCount()); data.y = float(WaveGetLaneIndex()); data.z = float(WaveIsFirstLane()); } - else if(IsTest(10)) + else if(IsTest(1)) { // Vote functions : unit tests data.x = float(WaveActiveAnyTrue(id*2 > id+10)); @@ -335,7 +208,7 @@ void main(uint3 inTid : SV_DispatchThreadID) data.w = countbits(ballot.x) + countbits(ballot.y) + countbits(ballot.z) + countbits(ballot.w); } } - else if(IsTest(11)) + else if(IsTest(2)) { // Broadcast functions : unit tests if (id >= 2 && id <= 20) @@ -346,7 +219,7 @@ void main(uint3 inTid : SV_DispatchThreadID) data.w = WaveReadLaneAt(data.x, 2+id%3); } } - else if(IsTest(12)) + else if(IsTest(3)) { // Scan and Prefix functions : unit tests if (id >= 2 && id <= 20) @@ -364,7 +237,7 @@ void main(uint3 inTid : SV_DispatchThreadID) data.w = WavePrefixSum(data.y); } } - else if(IsTest(13)) + else if(IsTest(4)) { // Reduction functions : unit tests if (id >= 2 && id <= 20) @@ -375,7 +248,7 @@ void main(uint3 inTid : SV_DispatchThreadID) data.w = float(WaveActiveSum(id)); } } - else if(IsTest(14)) + else if(IsTest(5)) { // Reduction functions : unit tests if (id >= 2 && id <= 20) @@ -386,7 +259,7 @@ void main(uint3 inTid : SV_DispatchThreadID) data.w = float(WaveActiveBitXor(id)); } } - else if(IsTest(15)) + else if(IsTest(6)) { // Reduction functions : unit tests if (id > 13) diff --git a/util/test/demos/d3d12/d3d12_workgroup_zoo.cpp b/util/test/demos/d3d12/d3d12_workgroup_zoo.cpp index 8c7555fea..a18e75839 100644 --- a/util/test/demos/d3d12/d3d12_workgroup_zoo.cpp +++ b/util/test/demos/d3d12/d3d12_workgroup_zoo.cpp @@ -98,6 +98,34 @@ float4 funcTest(uint id) } } +float4 ComplexPartialReconvergence(uint id) +{ + float4 result = 0.0.xxxx; + // Loops with different number of iterations per thread + for (uint i = id; i < 23; i++) + { + result.x += WaveActiveSum(id); + } + + if ((result.x < 5) || (id > 20)) + { + result.y += WaveActiveSum(id); + if (id < 25) + result.z += WaveActiveSum(id); + } + else if (result.x < 10) + { + result.y += WaveActiveSum(id); + + if (result.x > 5) + result.z += WaveActiveSum(id); + } + + result.w += WaveActiveSum(id); + + return result; +} + [numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)] void main(uint3 inTid : SV_DispatchThreadID) { @@ -110,10 +138,12 @@ void main(uint3 inTid : SV_DispatchThreadID) if(IsTest(0)) { data.x = id; + AllMemoryBarrierWithGroupSync(); } else if(IsTest(1)) { data.x = WaveActiveSum(id); + AllMemoryBarrierWithGroupSync(); } else if(IsTest(2)) { @@ -136,18 +166,21 @@ void main(uint3 inTid : SV_DispatchThreadID) data.x = WaveActiveSum(id); } data.y = WaveActiveSum(id); + AllMemoryBarrierWithGroupSync(); } else if(IsTest(3)) { // Converged threads calling a function data = funcTest(id); data.y = WaveActiveSum(id); + AllMemoryBarrierWithGroupSync(); } else if(IsTest(4)) { // Converged threads calling a function which has a nested function call in it data = nestedFunc(id); data.y = WaveActiveSum(id); + AllMemoryBarrierWithGroupSync(); } else if(IsTest(5)) { @@ -161,6 +194,7 @@ void main(uint3 inTid : SV_DispatchThreadID) data = funcD(id); } data.y = WaveActiveSum(id); + AllMemoryBarrierWithGroupSync(); } else if(IsTest(6)) { @@ -174,6 +208,7 @@ void main(uint3 inTid : SV_DispatchThreadID) data = funcB(id); } data.y = WaveActiveSum(id); + AllMemoryBarrierWithGroupSync(); } else if(IsTest(7)) { @@ -193,6 +228,7 @@ void main(uint3 inTid : SV_DispatchThreadID) { data.x += WaveActiveSum(id); } + AllMemoryBarrierWithGroupSync(); } else if(IsTest(9)) { @@ -200,91 +236,13 @@ void main(uint3 inTid : SV_DispatchThreadID) data.x = float(WaveGetLaneCount()); data.y = float(WaveGetLaneIndex()); data.z = float(WaveIsFirstLane()); + AllMemoryBarrierWithGroupSync(); } else if(IsTest(10)) { - // Vote functions : unit tests - data.x = float(WaveActiveAnyTrue(id*2 > id+10)); - data.y = float(WaveActiveAllTrue(id < WaveGetLaneCount())); - if (id > 10) - { - data.z = float(WaveActiveAllTrue(id > 10)); - uint4 ballot = WaveActiveBallot(id > 20); - data.w = countbits(ballot.x) + countbits(ballot.y) + countbits(ballot.z) + countbits(ballot.w); - } - else - { - data.z = float(WaveActiveAllTrue(id > 3)); - uint4 ballot = WaveActiveBallot(id > 4); - data.w = countbits(ballot.x) + countbits(ballot.y) + countbits(ballot.z) + countbits(ballot.w); - } - } - else if(IsTest(11)) - { - // Broadcast functions : unit tests - if (id >= 2 && id <= 20) - { - data.x = WaveReadLaneFirst(id); - data.y = WaveReadLaneAt(id, 5); - data.z = WaveReadLaneAt(id, id); - data.w = WaveReadLaneAt(data.x, 2+id%3); - } - } - else if(IsTest(12)) - { - // Scan and Prefix functions : unit tests - if (id >= 2 && id <= 20) - { - data.x = WavePrefixCountBits(id > 4); - data.y = WavePrefixCountBits(id > 10); - data.z = WavePrefixSum(data.x); - data.w = WavePrefixProduct(1 + data.y); - } - else - { - data.x = WavePrefixCountBits(id > 23); - data.y = WavePrefixCountBits(id < 1); - data.z = WavePrefixSum(data.x); - data.w = WavePrefixSum(data.y); - } - } - else if(IsTest(13)) - { - // Reduction functions : unit tests - if (id >= 2 && id <= 20) - { - data.x = float(WaveActiveMax(id)); - data.y = float(WaveActiveMin(id)); - data.z = float(WaveActiveProduct(id)); - data.w = float(WaveActiveSum(id)); - } - } - else if(IsTest(14)) - { - // Reduction functions : unit tests - if (id >= 2 && id <= 20) - { - data.x = float(WaveActiveCountBits(id > 23)); - data.y = float(WaveActiveBitAnd(id)); - data.z = float(WaveActiveBitOr(id)); - data.w = float(WaveActiveBitXor(id)); - } - } - else if(IsTest(15)) - { - // Reduction functions : unit tests - if (id > 13) - { - bool test1 = (id > 15).x; - bool2 test2 = bool2(test1, (id < 23)); - bool3 test3 = bool3(test1, (id < 23), (id >= 25)); - bool4 test4 = bool4(test1, (id < 23), (id >= 25), (id >= 28)); + data = ComplexPartialReconvergence(id); - data.x = float(WaveActiveAllEqual(test1).x); - data.y = float(WaveActiveAllEqual(test2).y); - data.z = float(WaveActiveAllEqual(test3).z); - data.w = float(WaveActiveAllEqual(test4).w); - } + AllMemoryBarrierWithGroupSync(); } SetOutput(data); @@ -396,7 +354,7 @@ void main(uint3 inTid : SV_DispatchThreadID) for(int i = 0; i < numCompTests; i++) { cmd->SetComputeRoot32BitConstant(0, i, 0); - cmd->Dispatch(1, 1, 1); + cmd->Dispatch(2, 1, 1); } popMarker(cmd); diff --git a/util/test/demos/vk/vk_subgroup_zoo.cpp b/util/test/demos/vk/vk_subgroup_zoo.cpp index abe1c2914..826154c9f 100644 --- a/util/test/demos/vk/vk_subgroup_zoo.cpp +++ b/util/test/demos/vk/vk_subgroup_zoo.cpp @@ -151,46 +151,6 @@ layout(binding = 0, std430) buffer outbuftype { layout(local_size_x = GROUP_SIZE_X, local_size_y = GROUP_SIZE_Y, local_size_z = 1) in; -vec4 funcD(uint id) -{ - return vec4(subgroupAdd(id/2)); -} - -vec4 nestedFunc(uint id) -{ - vec4 ret = funcD(id/3); - ret.w = subgroupAdd(id); - return ret; -} - -vec4 funcA(uint id) -{ - return nestedFunc(id*2); -} - -vec4 funcB(uint id) -{ - return nestedFunc(id*4); -} - -vec4 funcTest(uint id) -{ - if ((id % 2) == 0) - { - return vec4(0); - } - else - { - float value = subgroupAdd(id); - if (id < 10) - { - return vec4(value); - } - value += subgroupAdd(id/2); - return vec4(value); - } -} - void SetOutput(vec4 data) { outbuf.data[push.test].vals[gl_LocalInvocationID.y * GROUP_SIZE_X + gl_LocalInvocationID.x] = data; @@ -202,100 +162,13 @@ void main() SetOutput(data); if(IsTest(0)) - { - data.x = id; - } - else if(IsTest(1)) - { - data.x = subgroupAdd(id); - } - else if(IsTest(2)) - { - // Diverged threads which reconverge - if (id < 10) - { - // active threads 0-9 - data.x = subgroupAdd(id); - - if ((id % 2) == 0) - data.y = subgroupAdd(id); - else - data.y = subgroupAdd(id); - - data.x += subgroupAdd(id); - } - else - { - // active threads 10... - data.x = subgroupAdd(id); - } - data.y = subgroupAdd(id); - } - else if(IsTest(3)) - { - // Converged threads calling a function - data = funcTest(id); - data.y = subgroupAdd(id); - } - else if(IsTest(4)) - { - // Converged threads calling a function which has a nested function call in it - data = nestedFunc(id); - data.y = subgroupAdd(id); - } - else if(IsTest(5)) - { - // Diverged threads calling the same function - if (id < 10) - { - data = funcD(id); - } - else - { - data = funcD(id); - } - data.y = subgroupAdd(id); - } - else if(IsTest(6)) - { - // Diverged threads calling the same function which has a nested function call in it - if (id < 10) - { - data = funcA(id); - } - else - { - data = funcB(id); - } - data.y = subgroupAdd(id); - } - else if(IsTest(7)) - { - // Diverged threads which early exit - if (id < 10) - { - data.x = subgroupAdd(id+10); - SetOutput(data); - return; - } - data.x = subgroupAdd(id); - } - else if(IsTest(8)) - { - // Loops with different number of iterations per thread - for (uint i = 0; i < id; i++) - { - data.x += subgroupAdd(id); - } - } - else if(IsTest(9)) { // Query functions : unit tests data.x = float(gl_SubgroupSize); data.y = float(gl_SubgroupInvocationID); data.z = float(subgroupElect()); } - else if(IsTest(10)) + else if(IsTest(1)) { // Vote functions : unit tests data.x = float(subgroupAny(id*2 > id+10)); @@ -313,7 +186,7 @@ void main() data.w = bitCount(ballot.x) + bitCount(ballot.y) + bitCount(ballot.z) + bitCount(ballot.w); } } - else if(IsTest(11)) + else if(IsTest(2)) { // Broadcast functions : unit tests if (id >= 2 && id <= 20) @@ -324,7 +197,7 @@ void main() data.w = subgroupShuffle(data.x, 2+id%3); } } - else if(IsTest(12)) + else if(IsTest(3)) { // Scan and Prefix functions : unit tests if (id >= 2 && id <= 20) @@ -346,7 +219,7 @@ void main() data.w = subgroupExclusiveAdd(data.y); } } - else if(IsTest(13)) + else if(IsTest(4)) { // Reduction functions : unit tests if (id >= 2 && id <= 20) @@ -357,7 +230,7 @@ void main() data.w = float(subgroupAdd(id)); } } - else if(IsTest(14)) + else if(IsTest(5)) { // Reduction functions : unit tests if (id >= 2 && id <= 20) @@ -369,7 +242,7 @@ void main() data.w = float(subgroupXor(id)); } } - else if(IsTest(15)) + else if(IsTest(6)) { // Reduction functions : unit tests if (id > 13) diff --git a/util/test/demos/vk/vk_workgroup_zoo.cpp b/util/test/demos/vk/vk_workgroup_zoo.cpp index d12d61bac..069d03c2d 100644 --- a/util/test/demos/vk/vk_workgroup_zoo.cpp +++ b/util/test/demos/vk/vk_workgroup_zoo.cpp @@ -69,7 +69,7 @@ layout(push_constant) uniform PushData const std::string comp = common + R"EOSHADER( -shared uvec4 gsmUint4[COMP_TESTS]; +shared uvec4 gsmUint4[1024]; struct Output { @@ -126,6 +126,7 @@ void SetOutput(vec4 data) { outbuf.data[push.test].vals[gl_LocalInvocationID.y * GROUP_SIZE_X + gl_LocalInvocationID.x] = data; } + void main() { vec4 data = vec4(0); @@ -136,10 +137,12 @@ void main() if(IsTest(0)) { data.x = id; + barrier(); } else if(IsTest(1)) { data.x = subgroupAdd(id); + barrier(); } else if(IsTest(2)) { @@ -162,18 +165,21 @@ void main() data.x = subgroupAdd(id); } data.y = subgroupAdd(id); + barrier(); } else if(IsTest(3)) { // Converged threads calling a function data = funcTest(id); data.y = subgroupAdd(id); + barrier(); } else if(IsTest(4)) { // Converged threads calling a function which has a nested function call in it data = nestedFunc(id); data.y = subgroupAdd(id); + barrier(); } else if(IsTest(5)) { @@ -187,6 +193,7 @@ void main() data = funcD(id); } data.y = subgroupAdd(id); + barrier(); } else if(IsTest(6)) { @@ -200,6 +207,7 @@ void main() data = funcB(id); } data.y = subgroupAdd(id); + barrier(); } else if(IsTest(7)) { @@ -219,6 +227,7 @@ void main() { data.x += subgroupAdd(id); } + barrier(); } else if(IsTest(9)) { @@ -226,92 +235,10 @@ void main() data.x = float(gl_SubgroupSize); data.y = float(gl_SubgroupInvocationID); data.z = float(subgroupElect()); + + barrier(); } - else if(IsTest(10)) - { - // Vote functions : unit tests - data.x = float(subgroupAny(id*2 > id+10)); - data.y = float(subgroupAll(id < gl_SubgroupSize)); - if (id > 10) - { - data.z = float(subgroupAll(id > 10)); - uvec4 ballot = subgroupBallot(id > 20); - data.w = bitCount(ballot.x) + bitCount(ballot.y) + bitCount(ballot.z) + bitCount(ballot.w); - } - else - { - data.z = float(subgroupAll(id > 3)); - uvec4 ballot = subgroupBallot(id > 4); - data.w = bitCount(ballot.x) + bitCount(ballot.y) + bitCount(ballot.z) + bitCount(ballot.w); - } - } - else if(IsTest(11)) - { - // Broadcast functions : unit tests - if (id >= 2 && id <= 20) - { - data.x = subgroupBroadcastFirst(id); - data.y = subgroupBroadcast(id, 5); - data.z = subgroupShuffle(id, id); - data.w = subgroupShuffle(data.x, 2+id%3); - } - } - else if(IsTest(12)) - { - // Scan and Prefix functions : unit tests - if (id >= 2 && id <= 20) - { - uvec4 bits = subgroupBallot(id > 4); - data.x = subgroupBallotExclusiveBitCount(bits); - bits = subgroupBallot(id > 10); - data.y = subgroupBallotExclusiveBitCount(bits); - data.z = subgroupExclusiveAdd(data.x); - data.w = subgroupExclusiveMul(1 + data.y); - } - else - { - uvec4 bits = subgroupBallot(id > 23); - data.x = subgroupBallotExclusiveBitCount(bits); - bits = subgroupBallot(id < 1); - data.y = subgroupBallotExclusiveBitCount(bits); - data.z = subgroupExclusiveAdd(data.x); - data.w = subgroupExclusiveAdd(data.y); - } - } - else if(IsTest(13)) - { - // Reduction functions : unit tests - if (id >= 2 && id <= 20) - { - data.x = float(subgroupMax(id)); - data.y = float(subgroupMin(id)); - data.z = float(subgroupMul(id)); - data.w = float(subgroupAdd(id)); - } - } - else if(IsTest(14)) - { - // Reduction functions : unit tests - if (id >= 2 && id <= 20) - { - uvec4 bits = subgroupBallot(id > 23); - data.x = float(subgroupBallotBitCount(bits)); - data.y = float(subgroupAnd(id)); - data.z = float(subgroupOr(id)); - data.w = float(subgroupXor(id)); - } - } - else if(IsTest(15)) - { - // Reduction functions : unit tests - if (id > 13) - { - data.x = float(subgroupAllEqual(id > 15)); - data.y = float(subgroupAllEqual(id < 23)); - data.z = float(subgroupAllEqual(id >= 25)); - data.w = float(subgroupAllEqual(id >= 28)); - } - } + SetOutput(data); } @@ -467,7 +394,7 @@ void main() for(int i = 0; i < numCompTests; i++) { vkh::cmdPushConstants(cmd, layout, i); - vkCmdDispatch(cmd, 1, 1, 1); + vkCmdDispatch(cmd, 2, 1, 1); } popMarker(cmd); diff --git a/util/test/rdtest/shared/Subgroup_Zoo.py b/util/test/rdtest/shared/Subgroup_Zoo.py index 72c53801d..99c9dcb83 100644 --- a/util/test/rdtest/shared/Subgroup_Zoo.py +++ b/util/test/rdtest/shared/Subgroup_Zoo.py @@ -6,6 +6,7 @@ import rdtest class Subgroup_Zoo(rdtest.TestCase): internal = True demos_test_name = None + workgroup = (0, 0, 0) def check_support(self, **kwargs): # Only allow this if explicitly run @@ -19,7 +20,7 @@ class Subgroup_Zoo(rdtest.TestCase): "4f", bufdata, 16*y*dim[0] + 16*x) trace = self.controller.DebugThread( - (0, 0, 0), (x, y, z)) + self.workgroup, (x, y, z)) _, variables = self.process_trace(trace) diff --git a/util/test/rdtest/shared/Workgroup_Zoo.py b/util/test/rdtest/shared/Workgroup_Zoo.py index ca9a7aa9a..24e3fbfd7 100644 --- a/util/test/rdtest/shared/Workgroup_Zoo.py +++ b/util/test/rdtest/shared/Workgroup_Zoo.py @@ -24,6 +24,7 @@ class Workgroup_Zoo(rdtest.Subgroup_Zoo): 150 ] + self.workgroup = (1, 0, 0) if self.check_compute_tests(compute_dims, thread_checks): raise rdtest.TestFailureException("Some tests were not as expected")