Implement Atomic ops

This commit is contained in:
baldurk
2016-04-22 18:32:47 +02:00
parent c7512e15e2
commit 7663ff248a
@@ -630,6 +630,10 @@ struct SPVOperation
// OpLoad/OpStore/OpCopyMemory
spv::MemoryAccessMask access;
// OpAtomic*
spv::Scope scope;
spv::MemorySemanticsMask semantics, semanticsUnequal;
// OpExtInst
vector<uint32_t> literals;
@@ -1436,6 +1440,20 @@ struct SPVInstruction
case spv::OpImageSparseFetch:
case spv::OpImageSparseGather:
case spv::OpImageSparseDrefGather:
case spv::OpAtomicStore:
case spv::OpAtomicExchange:
case spv::OpAtomicCompareExchange:
case spv::OpAtomicIIncrement:
case spv::OpAtomicIDecrement:
case spv::OpAtomicIAdd:
case spv::OpAtomicISub:
case spv::OpAtomicSMin:
case spv::OpAtomicUMin:
case spv::OpAtomicSMax:
case spv::OpAtomicUMax:
case spv::OpAtomicAnd:
case spv::OpAtomicOr:
case spv::OpAtomicXor:
case spv::OpConvertFToS:
case spv::OpConvertFToU:
case spv::OpConvertUToF:
@@ -1477,7 +1495,7 @@ struct SPVInstruction
string ret = "";
if(!inlineOp && op->type && op->type->type != SPVTypeData::eVoid)
if(!inlineOp && op->type && op->type->type != SPVTypeData::eVoid && opcode != spv::OpAtomicStore)
ret = StringFormat::Fmt("%s %s = ", op->type->GetName().c_str(), GetIDName().c_str());
size_t numArgs = op->arguments.size();
@@ -1539,6 +1557,32 @@ struct SPVInstruction
ret += ")";
// for atomic operations, print the execution scope and memory semantics
switch(opcode)
{
case spv::OpAtomicStore:
case spv::OpAtomicExchange:
case spv::OpAtomicIIncrement:
case spv::OpAtomicIDecrement:
case spv::OpAtomicIAdd:
case spv::OpAtomicISub:
case spv::OpAtomicSMin:
case spv::OpAtomicUMin:
case spv::OpAtomicSMax:
case spv::OpAtomicUMax:
case spv::OpAtomicAnd:
case spv::OpAtomicOr:
case spv::OpAtomicXor:
ret += StringFormat::Fmt(" Scope=%s Semantics=%s", ToStr::Get(op->scope).c_str(), ToStr::Get(op->semantics).c_str());
break;
case spv::OpAtomicCompareExchange:
ret += StringFormat::Fmt(" Scope=%s Semantics=(equal: %s unequal: %s)",
ToStr::Get(op->scope).c_str(), ToStr::Get(op->semantics).c_str(), ToStr::Get(op->semanticsUnequal).c_str());
break;
default:
break;
}
return ret;
}
case spv::OpEmitVertex:
@@ -5306,6 +5350,105 @@ void ParseSPIRV(uint32_t *spirv, size_t spirvLength, SPVModule &module)
curBlock->instructions.push_back(&op);
break;
}
case spv::OpAtomicStore:
case spv::OpAtomicExchange:
case spv::OpAtomicCompareExchange:
case spv::OpAtomicIIncrement:
case spv::OpAtomicIDecrement:
case spv::OpAtomicIAdd:
case spv::OpAtomicISub:
case spv::OpAtomicSMin:
case spv::OpAtomicUMin:
case spv::OpAtomicSMax:
case spv::OpAtomicUMax:
case spv::OpAtomicAnd:
case spv::OpAtomicOr:
case spv::OpAtomicXor:
{
int word = 1;
op.op = new SPVOperation();
// all atomic operations but store return a new ID of a given type
if(op.opcode != spv::OpAtomicStore)
{
SPVInstruction *typeInst = module.GetByID(spirv[it+word]);
RDCASSERT(typeInst && typeInst->type);
op.op->type = typeInst->type;
word++;
op.id = spirv[it+word];
module.ids[spirv[it+word]] = &op;
word++;
}
SPVInstruction *ptrInst = module.GetByID(spirv[it+word]);
RDCASSERT(ptrInst);
op.op->arguments.push_back(ptrInst);
word++;
SPVInstruction *scopeInst = module.GetByID(spirv[it+word]);
RDCASSERT(scopeInst && scopeInst->constant); // shader capability requires this to be a constant
if(scopeInst && scopeInst->constant)
op.op->scope = (spv::Scope)scopeInst->constant->u32;
word++;
SPVInstruction *semanticsInst = module.GetByID(spirv[it+word]);
RDCASSERT(semanticsInst && semanticsInst->constant); // shader capability requires this to be a constant
if(semanticsInst && semanticsInst->constant)
op.op->semantics = (spv::MemorySemanticsMask)semanticsInst->constant->u32;
word++;
// compare-exchange operations define an additional semantics for the unequal case
if(op.opcode == spv::OpAtomicCompareExchange)
{
semanticsInst = module.GetByID(spirv[it+word]);
RDCASSERT(semanticsInst && semanticsInst->constant); // shader capability requires this to be a constant
if(semanticsInst && semanticsInst->constant)
op.op->semanticsUnequal = (spv::MemorySemanticsMask)semanticsInst->constant->u32;
word++;
}
// all but increment/decrement and load then take a value
if(op.opcode != spv::OpAtomicIIncrement &&
op.opcode != spv::OpAtomicIDecrement &&
op.opcode != spv::OpAtomicLoad)
{
SPVInstruction *valueInst = module.GetByID(spirv[it+word]);
RDCASSERT(valueInst);
op.op->arguments.push_back(valueInst);
word++;
}
// compare exchange then takes a comparison value
if(op.opcode == spv::OpAtomicCompareExchange)
{
SPVInstruction *compareInst = module.GetByID(spirv[it+word]);
RDCASSERT(compareInst);
op.op->arguments.push_back(compareInst);
word++;
}
// never combine atomic operations
op.op->complexity = NEVER_INLINE_COMPLEXITY;
curBlock->instructions.push_back(&op);
break;
}
case spv::OpName:
case spv::OpMemberName:
case spv::OpLine:
@@ -6110,6 +6253,21 @@ string ToStrHelper<false, spv::BuiltIn>::Get(const spv::BuiltIn &el)
}
return StringFormat::Fmt("Unrecognised{%u}", (uint32_t)el);
template<>
string ToStrHelper<false, spv::Scope>::Get(const spv::Scope &el)
{
switch(el)
{
case spv::ScopeCrossDevice: return "CrossDevice";
case spv::ScopeDevice: return "Device";
case spv::ScopeWorkgroup: return "Workgroup";
case spv::ScopeSubgroup: return "Subgroup";
case spv::ScopeInvocation: return "Invocation";
default: break;
}
return StringFormat::Fmt("UnrecognisedScope{%u}", (uint32_t)el);
}
template<>
@@ -6169,3 +6327,28 @@ string ToStrHelper<false, spv::MemoryAccessMask>::Get(const spv::MemoryAccessMas
return ret;
}
template<>
string ToStrHelper<false, spv::MemorySemanticsMask>::Get(const spv::MemorySemanticsMask &el)
{
string ret;
if(el == spv::MemorySemanticsMaskNone)
return "None";
if(el & spv::MemorySemanticsAcquireMask) ret += ", Acquire";
if(el & spv::MemorySemanticsReleaseMask) ret += ", Release";
if(el & spv::MemorySemanticsAcquireReleaseMask) ret += ", Acquire/Release";
if(el & spv::MemorySemanticsSequentiallyConsistentMask) ret += ", Sequentially Consistent";
if(el & spv::MemorySemanticsUniformMemoryMask) ret += ", Uniform Memory";
if(el & spv::MemorySemanticsSubgroupMemoryMask) ret += ", Subgroup Memory";
if(el & spv::MemorySemanticsWorkgroupMemoryMask) ret += ", Workgroup Memory";
if(el & spv::MemorySemanticsCrossWorkgroupMemoryMask) ret += ", Cross Workgroup Memory";
if(el & spv::MemorySemanticsAtomicCounterMemoryMask) ret += ", Atomic Counter Memory";
if(el & spv::MemorySemanticsImageMemoryMask) ret += ", Image Memory";
if(!ret.empty())
ret = ret.substr(2);
return ret;
}