Extend VK Groupshared GSM atomic tests to match D3D12 Groupshared

This commit is contained in:
Jake Turner
2025-10-10 11:12:02 +01:00
parent 4f8954ca1a
commit 2fda1f79eb
+137
View File
@@ -49,6 +49,8 @@ layout(binding = 1, std430) buffer outdataBuf
};
shared float gsmData[MAX_THREADS];
shared int gsmIntData[MAX_THREADS];
shared int gInt;
#define IsTest(x) (push.test == x)
@@ -57,6 +59,11 @@ float GetGSMValue(uint i)
return gsmData[i % MAX_THREADS];
}
int GetGSMIntValue(uint i)
{
return gsmIntData[i % MAX_THREADS];
}
layout(local_size_x = MAX_THREADS, local_size_y = 1, local_size_z = 1) in;
#define GroupMemoryBarrierWithGroupSync() memoryBarrierShared();groupMemoryBarrier();barrier();
@@ -68,11 +75,14 @@ void main()
if(gl_LocalInvocationID.x == 0)
{
for(int i=0; i < MAX_THREADS; i++) gsmData[i] = 1.25f;
for(int i=0; i < MAX_THREADS; i++) gsmIntData[i] = 125;
gInt = 25;
}
GroupMemoryBarrierWithGroupSync();
vec4 outval = vec4(0.0);
int u = int(gid.x);
if (IsTest(0))
{
@@ -129,6 +139,133 @@ void main()
outval.z = GetGSMValue(gid.x + 2);
outval.w = GetGSMValue(gid.x + 3);
}
else if (IsTest(3))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
atomicAdd(gsmIntData[u], value);
atomicAdd(gsmIntData[u], -value);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(4))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
atomicAnd(gsmIntData[u], value);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(5))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
atomicOr(gsmIntData[u], value);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(6))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
atomicXor(gsmIntData[u], value);
atomicXor(gsmIntData[u], value);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(7))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
atomicMin(gsmIntData[u], value);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(8))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
atomicMax(gsmIntData[u], value);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(9))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
int original = atomicExchange(gsmIntData[u], value);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(10))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
int original = atomicCompSwap(gsmIntData[u], value, value+1);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(11))
{
int value = int(indata[gid.x] * 100.0);
gsmIntData[gid.x] = u;
GroupMemoryBarrierWithGroupSync();
atomicCompSwap(gsmIntData[u], value, value+1);
GroupMemoryBarrierWithGroupSync();
outval.x = float(GetGSMIntValue(u+0));
outval.y = float(GetGSMIntValue(u+1));
outval.z = float(GetGSMIntValue(u+2));
outval.w = float(GetGSMIntValue(u+3));
}
else if (IsTest(12))
{
GroupMemoryBarrierWithGroupSync();
outval.x = gInt;
GroupMemoryBarrierWithGroupSync();
atomicAdd(gInt,1);
GroupMemoryBarrierWithGroupSync();
outval.y = gInt;
GroupMemoryBarrierWithGroupSync();
atomicAdd(gInt,1);
GroupMemoryBarrierWithGroupSync();
outval.z = gInt;
GroupMemoryBarrierWithGroupSync();
atomicAdd(gInt,1);
GroupMemoryBarrierWithGroupSync();
outval.w = gInt;
}
outdata[gid.x] = outval;
}