From 2fda1f79eb56c1a03413aa9ee0fe3b8c8f587fee Mon Sep 17 00:00:00 2001 From: Jake Turner Date: Fri, 10 Oct 2025 11:12:02 +0100 Subject: [PATCH] Extend VK Groupshared GSM atomic tests to match D3D12 Groupshared --- util/test/demos/vk/vk_groupshared.cpp | 137 ++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/util/test/demos/vk/vk_groupshared.cpp b/util/test/demos/vk/vk_groupshared.cpp index 2844611d8..764c398e1 100644 --- a/util/test/demos/vk/vk_groupshared.cpp +++ b/util/test/demos/vk/vk_groupshared.cpp @@ -49,6 +49,8 @@ layout(binding = 1, std430) buffer outdataBuf }; shared float gsmData[MAX_THREADS]; +shared int gsmIntData[MAX_THREADS]; +shared int gInt; #define IsTest(x) (push.test == x) @@ -57,6 +59,11 @@ float GetGSMValue(uint i) return gsmData[i % MAX_THREADS]; } +int GetGSMIntValue(uint i) +{ + return gsmIntData[i % MAX_THREADS]; +} + layout(local_size_x = MAX_THREADS, local_size_y = 1, local_size_z = 1) in; #define GroupMemoryBarrierWithGroupSync() memoryBarrierShared();groupMemoryBarrier();barrier(); @@ -68,11 +75,14 @@ void main() if(gl_LocalInvocationID.x == 0) { for(int i=0; i < MAX_THREADS; i++) gsmData[i] = 1.25f; + for(int i=0; i < MAX_THREADS; i++) gsmIntData[i] = 125; + gInt = 25; } GroupMemoryBarrierWithGroupSync(); vec4 outval = vec4(0.0); + int u = int(gid.x); if (IsTest(0)) { @@ -129,6 +139,133 @@ void main() outval.z = GetGSMValue(gid.x + 2); outval.w = GetGSMValue(gid.x + 3); } + else if (IsTest(3)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + atomicAdd(gsmIntData[u], value); + atomicAdd(gsmIntData[u], -value); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(4)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + atomicAnd(gsmIntData[u], value); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(5)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + atomicOr(gsmIntData[u], value); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(6)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + atomicXor(gsmIntData[u], value); + atomicXor(gsmIntData[u], value); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(7)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + atomicMin(gsmIntData[u], value); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(8)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + atomicMax(gsmIntData[u], value); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(9)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + int original = atomicExchange(gsmIntData[u], value); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(10)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + int original = atomicCompSwap(gsmIntData[u], value, value+1); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(11)) + { + int value = int(indata[gid.x] * 100.0); + gsmIntData[gid.x] = u; + GroupMemoryBarrierWithGroupSync(); + atomicCompSwap(gsmIntData[u], value, value+1); + GroupMemoryBarrierWithGroupSync(); + outval.x = float(GetGSMIntValue(u+0)); + outval.y = float(GetGSMIntValue(u+1)); + outval.z = float(GetGSMIntValue(u+2)); + outval.w = float(GetGSMIntValue(u+3)); + } + else if (IsTest(12)) + { + GroupMemoryBarrierWithGroupSync(); + outval.x = gInt; + GroupMemoryBarrierWithGroupSync(); + atomicAdd(gInt,1); + GroupMemoryBarrierWithGroupSync(); + outval.y = gInt; + GroupMemoryBarrierWithGroupSync(); + atomicAdd(gInt,1); + GroupMemoryBarrierWithGroupSync(); + outval.z = gInt; + GroupMemoryBarrierWithGroupSync(); + atomicAdd(gInt,1); + GroupMemoryBarrierWithGroupSync(); + outval.w = gInt; + } outdata[gid.x] = outval; }