Add handling for binop DXIL constants

This commit is contained in:
baldurk
2023-03-28 13:26:06 +01:00
parent f36a9f8845
commit e9708d909e
5 changed files with 142 additions and 25 deletions
+13 -24
View File
@@ -225,6 +225,17 @@ void Program::ParseConstant(ValueList &values, const LLVMBC::BlockOrRecord &cons
c->setInner(values.getOrCreatePlaceholder((size_t)constant.ops[2]));
values.addValue();
}
else if(IS_KNOWN(constant.id, ConstantsRecord::EVAL_BINOP))
{
Constant *c = values.nextValue<Constant>();
c->type = m_CurParseType;
c->op = DecodeBinOp(c->type, constant.ops[0]);
rdcarray<Value *> members;
members.push_back(values.getOrCreatePlaceholder((size_t)constant.ops[1]));
members.push_back(values.getOrCreatePlaceholder((size_t)constant.ops[2]));
c->setCompound(alloc, std::move(members));
values.addValue();
}
else if(IS_KNOWN(constant.id, ConstantsRecord::EVAL_GEP))
{
Constant *c = values.nextValue<Constant>();
@@ -1332,29 +1343,7 @@ Program::Program(const byte *bytes, size_t length) : alloc(32 * 1024)
inst->type = inst->args.back()->type;
inst->args.push_back(op.getSymbol(false));
bool isFloatOp = (inst->type->scalarType == Type::Float);
uint64_t opcode = op.get<uint64_t>();
switch(opcode)
{
case 0: inst->op = isFloatOp ? Operation::FAdd : Operation::Add; break;
case 1: inst->op = isFloatOp ? Operation::FSub : Operation::Sub; break;
case 2: inst->op = isFloatOp ? Operation::FMul : Operation::Mul; break;
case 3: inst->op = Operation::UDiv; break;
case 4: inst->op = isFloatOp ? Operation::FDiv : Operation::SDiv; break;
case 5: inst->op = Operation::URem; break;
case 6: inst->op = isFloatOp ? Operation::FRem : Operation::SRem; break;
case 7: inst->op = Operation::ShiftLeft; break;
case 8: inst->op = Operation::LogicalShiftRight; break;
case 9: inst->op = Operation::ArithShiftRight; break;
case 10: inst->op = Operation::And; break;
case 11: inst->op = Operation::Or; break;
case 12: inst->op = Operation::Xor; break;
default:
inst->op = Operation::And;
RDCERR("Unhandled binop type %llu", opcode);
break;
}
inst->op = DecodeBinOp(inst->type, op.get<uint64_t>());
if(op.remaining() > 0)
{
@@ -1374,7 +1363,7 @@ Program::Program(const byte *bytes, size_t length) : alloc(32 * 1024)
if(flags & 0x1)
inst->opFlags() |= InstructionFlags::Exact;
}
else if(isFloatOp)
else if(inst->type->scalarType == Type::Float)
{
// fast math flags overlap
inst->opFlags() = InstructionFlags(flags);
@@ -307,6 +307,55 @@ enum class Operation : uint8_t
AtomicUMin,
};
inline Operation DecodeBinOp(const Type *type, uint64_t opcode)
{
bool isFloatOp = (type->scalarType == Type::Float);
switch(opcode)
{
case 0: return isFloatOp ? Operation::FAdd : Operation::Add; break;
case 1: return isFloatOp ? Operation::FSub : Operation::Sub; break;
case 2: return isFloatOp ? Operation::FMul : Operation::Mul; break;
case 3: return Operation::UDiv; break;
case 4: return isFloatOp ? Operation::FDiv : Operation::SDiv; break;
case 5: return Operation::URem; break;
case 6: return isFloatOp ? Operation::FRem : Operation::SRem; break;
case 7: return Operation::ShiftLeft; break;
case 8: return Operation::LogicalShiftRight; break;
case 9: return Operation::ArithShiftRight; break;
case 10: return Operation::And; break;
case 11: return Operation::Or; break;
case 12: return Operation::Xor; break;
default: RDCERR("Unhandled binop type %llu", opcode); return Operation::And;
}
}
inline uint64_t EncodeBinOp(Operation op)
{
switch(op)
{
case Operation::FAdd:
case Operation::Add: return 0; break;
case Operation::FSub:
case Operation::Sub: return 1; break;
case Operation::FMul:
case Operation::Mul: return 2; break;
case Operation::UDiv: return 3; break;
case Operation::FDiv:
case Operation::SDiv: return 4; break;
case Operation::URem: return 5; break;
case Operation::FRem:
case Operation::SRem: return 6; break;
case Operation::ShiftLeft: return 7; break;
case Operation::LogicalShiftRight: return 8; break;
case Operation::ArithShiftRight: return 9; break;
case Operation::And: return 10; break;
case Operation::Or: return 11; break;
case Operation::Xor: return 12; break;
default: return ~0U;
}
}
inline Operation DecodeCast(uint64_t opcode)
{
switch(opcode)
@@ -349,6 +398,11 @@ inline uint64_t EncodeCast(Operation op)
}
}
inline bool IsCast(Operation op)
{
return EncodeCast(op) != ~0U;
}
enum class ValueKind : uint32_t
{
ForwardReferencePlaceholder,
@@ -1667,7 +1667,7 @@ void ProgramEditor::EncodeConstants(LLVMBC::BitcodeWriter &writer,
writer.Record(LLVMBC::ConstantsRecord::EVAL_GEP, vals);
}
else if(c->op != Operation::NoOp)
else if(IsCast(c->op))
{
uint64_t cast = EncodeCast(c->op);
RDCASSERT(cast != ~0U);
@@ -1675,6 +1675,14 @@ void ProgramEditor::EncodeConstants(LLVMBC::BitcodeWriter &writer,
writer.Record(LLVMBC::ConstantsRecord::EVAL_CAST,
{cast, getTypeID(c->getInner()->type), getValueID(c->getInner())});
}
else if(c->op != Operation::NoOp)
{
uint64_t binop = EncodeBinOp(c->op);
RDCASSERT(binop != ~0U);
writer.Record(LLVMBC::ConstantsRecord::EVAL_BINOP,
{binop, getValueID(c->getMembers()[0]), getValueID(c->getMembers()[1])});
}
else if(c->isData())
{
rdcarray<uint64_t> vals;
@@ -1913,6 +1913,71 @@ rdcstr Constant::toString(bool withType) const
ret += ")";
break;
}
case Operation::FAdd:
case Operation::FSub:
case Operation::FMul:
case Operation::FDiv:
case Operation::FRem:
case Operation::Add:
case Operation::Sub:
case Operation::Mul:
case Operation::UDiv:
case Operation::SDiv:
case Operation::URem:
case Operation::SRem:
case Operation::ShiftLeft:
case Operation::LogicalShiftRight:
case Operation::ArithShiftRight:
case Operation::And:
case Operation::Or:
case Operation::Xor:
{
switch(op)
{
case Operation::FAdd: ret += "fadd "; break;
case Operation::FSub: ret += "fsub "; break;
case Operation::FMul: ret += "fmul "; break;
case Operation::FDiv: ret += "fdiv "; break;
case Operation::FRem: ret += "frem "; break;
case Operation::Add: ret += "add "; break;
case Operation::Sub: ret += "sub "; break;
case Operation::Mul: ret += "mul "; break;
case Operation::UDiv: ret += "udiv "; break;
case Operation::SDiv: ret += "sdiv "; break;
case Operation::URem: ret += "urem "; break;
case Operation::SRem: ret += "srem "; break;
case Operation::ShiftLeft: ret += "shl "; break;
case Operation::LogicalShiftRight: ret += "lshr "; break;
case Operation::ArithShiftRight: ret += "ashr "; break;
case Operation::And: ret += "and "; break;
case Operation::Or: ret += "or "; break;
case Operation::Xor: ret += "xor "; break;
default: break;
}
ret += "(";
for(size_t i = 0; i < members->size(); i++)
{
if(i > 0)
ret += ", ";
if(Literal *l = cast<Literal>(members->at(i)))
{
ShaderValue v;
v.u64v[0] = l->literal;
shaderValAppendToString(members->at(i)->type, v, 0, ret);
}
else
{
ret += members->at(i)->toString(withType);
}
}
ret += ")";
break;
}
}
}
else if(type->type == Type::Scalar)
@@ -109,6 +109,7 @@ enum class ConstantsRecord : uint32_t
AGGREGATE = 7,
STRING = 8,
CSTRING = 9,
EVAL_BINOP = 10,
EVAL_CAST = 11,
EVAL_GEP = 20,
DATA = 22,