SPIR-V OpSwitch support for 64-bit selectors #3140

This commit is contained in:
Jake Turner
2023-12-07 11:14:08 +00:00
parent 18a28eb8dc
commit 5bfb844643
8 changed files with 211 additions and 132 deletions
@@ -450,12 +450,7 @@ rdcstr DoStringise(const rdcspv::{name} &el)
operand_kind['kind'] == 'LiteralSpecConstantOpInteger'):
operand_kind['size'] = None
elif (operand_kind['kind'] == 'PairLiteralIntegerIdRef'):
operand_kind['size'] = 2
operand_kind['def_name'] = name[0].lower() + name[1:]
operand_kind['def_value'] = '{0, Id()}'
operand_kind['type'] = name
operand_kind['push_words'] = lambda name: 'words.push_back((uint32_t){0}.first); words.push_back({0}.second.value());'.format(name)
ops_header.write('struct {} {{ uint32_t first; Id second; }};\n\n'.format(name))
operand_kind['size'] = None
elif (operand_kind['kind'] == 'PairIdRefLiteralInteger'):
operand_kind['size'] = 2
operand_kind['def_name'] = name[0].lower() + name[1:]
@@ -488,16 +483,6 @@ inline PairIdRefIdRef DecodeParam(const ConstIter &it, uint32_t &word)
return ret;
}
template<>
inline PairLiteralIntegerIdRef DecodeParam(const ConstIter &it, uint32_t &word)
{
if(word >= it.size()) return {};
PairLiteralIntegerIdRef ret = { it.word(word), Id::fromWord(it.word(word+1)) };
word += 2;
return ret;
}
template<>
inline PairIdRefLiteralInteger DecodeParam(const ConstIter &it, uint32_t &word)
{
@@ -835,11 +820,6 @@ inline uint16_t OptionalWordCount(const PairIdRefLiteralInteger &val)
return val.first != Id() ? 2 : 0;
}
inline uint16_t OptionalWordCount(const PairLiteralIntegerIdRef &val)
{
return val.second != Id() ? 2 : 0;
}
inline uint16_t OptionalWordCount(const PairIdRefIdRef &val)
{
return val.first != Id() ? 2 : 0;
@@ -1196,8 +1176,6 @@ rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const Id &el)
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcstr &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairLiteralIntegerIdRef &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefLiteralInteger &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefIdRef &el);
@@ -1218,6 +1196,8 @@ inline rdcstr ParamsToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const
return ret;
}}
extern bool ManualForEachID(const ConstIter &it, const std::function<void(Id,bool)> &callback);
struct OpDecoder
{{
OpDecoder(const ConstIter &it);
@@ -1257,12 +1237,6 @@ rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcstr
return "\\"" + el + "\\"";
}}
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairLiteralIntegerIdRef &el)
{{
return StringFormat::Fmt("[%u, %s]", el.first, idName(el.second).c_str());
}}
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefLiteralInteger &el)
{{
@@ -1279,6 +1253,8 @@ rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdR
void OpDecoder::ForEachID(const ConstIter &it, const std::function<void(Id,bool)> &callback)
{{
if (rdcspv::ManualForEachID(it, callback))
return;
size_t size = it.size();
uint32_t word = 0;
(void)word;
@@ -192,3 +192,22 @@ ShaderBuiltin MakeShaderBuiltin(ShaderStage stage, const rdcspv::BuiltIn el)
return ShaderBuiltin::Undefined;
}
namespace rdcspv
{
bool ManualForEachID(const ConstIter &it, const std::function<void(Id, bool)> &callback)
{
switch(it.opcode())
{
case rdcspv::Op::Switch:
// Include just the selector
callback(Id::fromWord(it.word(1)), false);
return true;
default:
// unhandled
return false;
}
}
}; // namespace rdcspv
@@ -328,6 +328,115 @@ struct OpShaderDbg : public OpExtInstGeneric<rdcspv::ShaderDbg>
OpShaderDbg(const ConstIter &it) : OpExtInstGeneric(it) {}
};
template <typename T>
struct SwitchPairLiteralId
{
T literal;
Id target;
};
typedef SwitchPairLiteralId<uint32_t> SwitchPairU32LiteralId;
typedef SwitchPairLiteralId<uint64_t> SwitchPairU64LiteralId;
// helpers for OpSwitch 32-bit and 64-bit versions in the style of the auto-generated helpers
struct OpSwitch32
{
OpSwitch32(const ConstIter &it)
{
this->op = OpCode;
this->wordCount = (uint16_t)it.size();
this->selector = Id::fromWord(it.word(1));
this->def = Id::fromWord(it.word(2));
uint32_t word = 3;
while(word < it.size())
{
uint32_t literal(it.word(word));
word += 1;
rdcspv::Id target(Id::fromWord(it.word(word)));
word += 1;
this->targets.push_back({literal, target});
}
}
OpSwitch32(Id selector, Id def, const rdcarray<SwitchPairLiteralId<uint32_t>> &targets)
: op(Op::Switch)
{
this->wordCount = MinWordSize + 2 * (uint16_t)targets.count();
this->selector = selector;
this->def = def;
this->targets = targets;
}
operator Operation() const
{
rdcarray<uint32_t> words;
words.push_back(selector.value());
words.push_back(def.value());
for(size_t i = 0; i < targets.size(); i++)
{
words.push_back(targets[i].literal);
words.push_back(targets[i].target.value());
}
return rdcspv::Operation(Op::Switch, words);
}
static constexpr Op OpCode = Op::Switch;
static constexpr uint16_t MinWordSize = 3U;
Op op;
uint16_t wordCount;
Id selector;
Id def;
rdcarray<SwitchPairU32LiteralId> targets;
};
struct OpSwitch64
{
OpSwitch64(const ConstIter &it)
{
this->op = OpCode;
this->wordCount = (uint16_t)it.size();
this->selector = Id::fromWord(it.word(1));
this->def = Id::fromWord(it.word(2));
uint32_t word = 3;
while(word < it.size())
{
uint64_t literal(*(uint64_t *)(it.words() + word));
word += 2;
rdcspv::Id target(Id::fromWord(it.word(word)));
word += 1;
this->targets.push_back({literal, target});
}
}
OpSwitch64(Id selector, Id def, const rdcarray<SwitchPairU64LiteralId> &targets) : op(Op::Switch)
{
this->wordCount = MinWordSize + 3 * (uint16_t)targets.count();
this->selector = selector;
this->def = def;
this->targets = targets;
}
operator Operation() const
{
rdcarray<uint32_t> words;
words.push_back(selector.value());
words.push_back(def.value());
for(size_t i = 0; i < targets.size(); i++)
{
uint32_t *literal = (uint32_t *)&(targets[i].literal);
words.push_back(literal[0]);
words.push_back(literal[1]);
words.push_back(targets[i].target.value());
}
return rdcspv::Operation(Op::Switch, words);
}
static constexpr Op OpCode = Op::Switch;
static constexpr uint16_t MinWordSize = 3U;
Op op;
uint16_t wordCount;
Id selector;
Id def;
rdcarray<SwitchPairU64LiteralId> targets;
};
}; // namespace rdcspv
static const uint32_t SpecializationConstantBindSet = 1234567;
+29 -9
View File
@@ -2848,18 +2848,38 @@ void ThreadState::StepNext(ShaderDebugState *state, const rdcarray<ThreadState>
}
case Op::Switch:
{
OpSwitch switch_(it);
OpSwitch32 switch32(it);
// selector and default are common beteen 32-bit and 64-bit versions of OpSwitch
Id selectorId = switch32.selector;
Id targetLabel = switch32.def;
ShaderVariable selector = GetSrc(switch_.selector);
Id targetLabel = switch_.def;
for(const PairLiteralIntegerIdRef &case_ : switch_.target)
ShaderVariable selector = GetSrc(selectorId);
bool longLiterals = ((selector.type == VarType::SLong) || (selector.type == VarType::ULong));
if(!longLiterals)
{
if(uintComp(selector, 0) == case_.first)
const uint32_t selectorVal = uintComp(selector, 0);
for(size_t i = 0; i < switch32.targets.size(); ++i)
{
targetLabel = case_.second;
break;
SwitchPairU32LiteralId target = switch32.targets[i];
if(selectorVal == target.literal)
{
targetLabel = target.target;
break;
}
}
}
else
{
OpSwitch64 switch64(it);
const uint64_t selectorVal = selector.value.u64v[0];
for(size_t i = 0; i < switch64.targets.size(); ++i)
{
SwitchPairU64LiteralId target = switch64.targets[i];
if(selectorVal == target.literal)
{
targetLabel = target.target;
break;
}
}
}
@@ -52,7 +52,7 @@ struct StructuredCFG
rdcspv::Id continueTarget;
// only valid for switches
rdcarray<rdcspv::PairLiteralIntegerIdRef> caseTargets;
rdcarray<rdcspv::SwitchPairU64LiteralId> caseTargets;
rdcspv::Id defaultTarget;
};
@@ -713,13 +713,36 @@ rdcstr Reflector::Disassemble(const rdcstr &entryPoint,
{
cfg.type = StructuredCFG::Switch;
OpSwitch decodedswitch(it);
OpSwitch32 switch32(it);
// selector and default are common beteen 32-bit and 64-bit versions of OpSwitch
Id selector = switch32.selector;
cfg.defaultTarget = switch32.def;
cfg.caseTargets = decodedswitch.target;
cfg.defaultTarget = decodedswitch.def;
const DataType &type = dataTypes[idTypes[selector]];
RDCASSERT(type.type == DataType::ScalarType);
const uint32_t selectorWidth = type.scalar().width;
const bool longLiterals = (selectorWidth == 64);
if(!longLiterals)
{
for(size_t i = 0; i < switch32.targets.size(); ++i)
{
SwitchPairU32LiteralId target = switch32.targets[i];
cfg.caseTargets.push_back({target.literal, target.target});
}
}
else
{
OpSwitch64 switch64(it);
for(size_t i = 0; i < switch64.targets.size(); ++i)
{
SwitchPairU64LiteralId target = switch64.targets[i];
cfg.caseTargets.push_back({target.literal, target.target});
}
}
ret += indent;
ret += StringFormat::Fmt("switch(%s) {\n", idName(decodedswitch.selector).c_str());
ret += StringFormat::Fmt("switch(%s) {\n", idName(selector).c_str());
lineNum++;
// add another level - each case label will be un-intended.
@@ -926,13 +949,13 @@ rdcstr Reflector::Disassemble(const rdcstr &entryPoint,
if(!cfgStack.empty() && cfgStack.back().type == StructuredCFG::Switch)
{
for(const PairLiteralIntegerIdRef &caseTarget : cfgStack.back().caseTargets)
for(const SwitchPairU64LiteralId &caseTarget : cfgStack.back().caseTargets)
{
if(caseTarget.second == decoded.result)
if(caseTarget.target == decoded.result)
{
// if this is the current switch's default: then print it
ret += indent.substr(0, indent.size() - 2);
ret += StringFormat::Fmt("case %u:\n", caseTarget.first);
ret += StringFormat::Fmt("case %llu:\n", caseTarget.literal);
lineNum++;
break;
}
@@ -1002,13 +1025,13 @@ rdcstr Reflector::Disassemble(const rdcstr &entryPoint,
}
// if we're falling through to the next case, print a comment
for(const PairLiteralIntegerIdRef &caseTarget : cfgStack.back().caseTargets)
for(const SwitchPairU64LiteralId &caseTarget : cfgStack.back().caseTargets)
{
if(caseTarget.second == decoded.targetLabel)
if(caseTarget.target == decoded.targetLabel)
{
ret += indent;
ret +=
StringFormat::Fmt("// deliberate fallthrough to case %u\n", caseTarget.first);
ret += StringFormat::Fmt("// deliberate fallthrough to case %llu\n",
caseTarget.literal);
lineNum++;
break;
}
@@ -1067,11 +1090,11 @@ rdcstr Reflector::Disassemble(const rdcstr &entryPoint,
{
bool printed = false;
for(const PairLiteralIntegerIdRef &caseTarget : lastLoopSwitch->caseTargets)
for(const SwitchPairU64LiteralId &caseTarget : lastLoopSwitch->caseTargets)
{
if(caseTarget.second == decoded.targetLabel)
if(caseTarget.target == decoded.targetLabel)
{
ret += StringFormat::Fmt("goto case %u;\n", caseTarget.first);
ret += StringFormat::Fmt("goto case %llu;\n", caseTarget.literal);
lineNum++;
printed = true;
break;
+4 -16
View File
@@ -1808,12 +1808,6 @@ rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcstr
return "\"" + el + "\"";
}
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairLiteralIntegerIdRef &el)
{
return StringFormat::Fmt("[%u, %s]", el.first, idName(el.second).c_str());
}
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefLiteralInteger &el)
{
@@ -2151,6 +2145,8 @@ rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcspv:
void OpDecoder::ForEachID(const ConstIter &it, const std::function<void(Id,bool)> &callback)
{
if (rdcspv::ManualForEachID(it, callback))
return;
size_t size = it.size();
uint32_t word = 0;
(void)word;
@@ -3338,8 +3334,6 @@ void OpDecoder::ForEachID(const ConstIter &it, const std::function<void(Id,bool)
callback(Id::fromWord(it.word(3)), false);
break;
case rdcspv::Op::Switch:
callback(Id::fromWord(it.word(1)), false);
callback(Id::fromWord(it.word(2)), false);
break;
case rdcspv::Op::Kill:
break;
@@ -7518,14 +7512,8 @@ rdcstr OpDecoder::Disassemble(const ConstIter &it, const std::function<rdcstr(Id
}
case rdcspv::Op::Switch:
{
OpSwitch decoded(it);
ret += rdcstr("Switch("_lit)
+ ParamToStr(idName, decoded.selector)
+ ", "
+ ParamToStr(idName, decoded.def)
+ ", "
+ ParamsToStr(idName, decoded.target)
+ ")";
OpDecoder decoded(it);
ret += "Switch(...)";
break;
}
case rdcspv::Op::Kill:
@@ -140,8 +140,6 @@ inline void EncodeParam(rdcarray<uint32_t> &words, const rdcstr &str)
}
}
struct PairLiteralIntegerIdRef { uint32_t first; Id second; };
struct PairIdRefLiteralInteger { Id first; uint32_t second; };
struct PairIdRefIdRef { Id first, second; };
@@ -157,16 +155,6 @@ inline PairIdRefIdRef DecodeParam(const ConstIter &it, uint32_t &word)
return ret;
}
template<>
inline PairLiteralIntegerIdRef DecodeParam(const ConstIter &it, uint32_t &word)
{
if(word >= it.size()) return {};
PairLiteralIntegerIdRef ret = { it.word(word), Id::fromWord(it.word(word+1)) };
word += 2;
return ret;
}
template<>
inline PairIdRefLiteralInteger DecodeParam(const ConstIter &it, uint32_t &word)
{
@@ -2480,11 +2468,6 @@ inline uint16_t OptionalWordCount(const PairIdRefLiteralInteger &val)
return val.first != Id() ? 2 : 0;
}
inline uint16_t OptionalWordCount(const PairLiteralIntegerIdRef &val)
{
return val.second != Id() ? 2 : 0;
}
inline uint16_t OptionalWordCount(const PairIdRefIdRef &val)
{
return val.first != Id() ? 2 : 0;
@@ -8994,46 +8977,7 @@ struct OpBranchConditional
rdcarray<uint32_t> branchweights;
};
struct OpSwitch
{
OpSwitch(const ConstIter &it)
{
uint32_t word = 0;(void)word;
this->op = OpCode;
this->wordCount = (uint16_t)it.size();
this->selector = Id::fromWord(it.word(1));
this->def = Id::fromWord(it.word(2));
word = 3;
this->target = MultiParam<PairLiteralIntegerIdRef>(it, word);
}
OpSwitch(Id selector, Id def, const rdcarray<PairLiteralIntegerIdRef> &target = {})
: op(Op::Switch)
, wordCount(MinWordSize + MultiWordCount(target))
{
this->selector = selector;
this->def = def;
this->target = target;
}
operator Operation() const
{
rdcarray<uint32_t> words;
words.push_back(selector.value());
words.push_back(def.value());
for(size_t i=0; i < target.size(); i++)
{
words.push_back((uint32_t)target[i].first); words.push_back(target[i].second.value());
}
return Operation(OpCode, words);
}
static constexpr Op OpCode = Op::Switch;
static constexpr uint16_t MinWordSize = 3U;
Op op;
uint16_t wordCount;
Id selector;
Id def;
rdcarray<PairLiteralIntegerIdRef> target;
};
struct OpSwitch; // has operands with variable sizes
struct OpKill
{
@@ -17718,8 +17662,6 @@ rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const Id &el)
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const rdcstr &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairLiteralIntegerIdRef &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefLiteralInteger &el);
template<>
rdcstr ParamToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const PairIdRefIdRef &el);
@@ -17745,6 +17687,8 @@ inline rdcstr ParamsToStr(const std::function<rdcstr(rdcspv::Id)> &idName, const
return ret;
}
extern bool ManualForEachID(const ConstIter &it, const std::function<void(Id,bool)> &callback);
struct OpDecoder
{
OpDecoder(const ConstIter &it);
+4 -4
View File
@@ -2103,7 +2103,7 @@ private:
rdcspv::Id breakLabel = editor.MakeId();
rdcspv::Id defaultLabel = editor.MakeId();
rdcarray<rdcspv::PairLiteralIntegerIdRef> targets;
rdcarray<rdcspv::SwitchPairU32LiteralId> targets;
rdcspv::OperationList cases;
@@ -2209,7 +2209,7 @@ private:
}
func.add(rdcspv::OpSelectionMerge(breakLabel, rdcspv::SelectionControl::None));
func.add(rdcspv::OpSwitch(opParam, defaultLabel, targets));
func.add(rdcspv::OpSwitch32(opParam, defaultLabel, targets));
func.append(cases);
@@ -2544,7 +2544,7 @@ private:
switchVal = func.add(rdcspv::OpIAdd(u32, editor.MakeId(), switchVal, dim));
// switch on the combined operation and image type value
rdcarray<rdcspv::PairLiteralIntegerIdRef> targets;
rdcarray<rdcspv::SwitchPairU32LiteralId> targets;
rdcspv::OperationList cases;
@@ -2925,7 +2925,7 @@ private:
}
func.add(rdcspv::OpSelectionMerge(breakLabel, rdcspv::SelectionControl::None));
func.add(rdcspv::OpSwitch(switchVal, defaultLabel, targets));
func.add(rdcspv::OpSwitch32(switchVal, defaultLabel, targets));
func.append(cases);