Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <algorithm>
- #include <cassert>
- #include <iomanip>
- #include <iostream>
- #include <locale>
- #include <map>
- #include <memory>
- #include <sstream>
- #include <string>
- #include <unordered_map>
- #include <unordered_set>
- #include <vector>
/* Kind of sample stored by a format; referenced by VSFormat::sampleType. */
enum VSSampleType {
    stInteger = 0,
    stFloat = 1
};
/* Description of a pixel format. */
struct VSFormat {
    char name[32];
    int id;
    int colorFamily;    /* see VSColorFamily */
    int sampleType;     /* see VSSampleType */
    int bitsPerSample;  /* number of significant bits */
    int bytesPerSample; /* actual storage is always in a power of 2 and the smallest possible that can fit the number of bits used per sample */
    int subSamplingW;   /* log2 subsampling factor, applied to second and third plane */
    int subSamplingH;
    int numPlanes;      /* implicit from colorFamily */
};

/* Properties of a clip: its format, frame rate (as a fraction), dimensions and length. */
struct VSVideoInfo {
    const VSFormat *format;
    int64_t fpsNum;
    int64_t fpsDen;
    int width;
    int height;
    int numFrames; /* api 3.2 - no longer allowed to be 0 */
    int flags;
};
// Operation codes of the expression tree. The order of the enumerators must
// stay in sync with op_names (checked by the static_assert below).
enum class ExprOpType {
    // Terminals.
    MEM_LOAD_U8,
    MEM_LOAD_U16,
    MEM_LOAD_F16,
    MEM_LOAD_F32,
    CONSTANT,
    MEM_STORE_U8,
    MEM_STORE_U16,
    MEM_STORE_F16,
    MEM_STORE_F32,

    // Arithmetic primitives.
    ADD,
    SUB,
    MUL,
    DIV,
    FMA,
    SQRT,
    ABS,
    NEG,
    MAX,
    MIN,
    CMP,

    // Logical operators.
    AND,
    OR,
    XOR,
    NOT,

    // Transcendental functions.
    EXP,
    LOG,
    POW,

    // Ternary operator
    TERNARY,

    // Meta-node holding true/false branches of ternary.
    MUX,

    // Stack helpers.
    DUP,
    SWAP,
};

// Printable name of every ExprOpType, indexed by the enumerator's value.
static const char *op_names[] = {
    // Terminals.
    "loadu8", "loadu16", "loadf16", "loadf32",
    "constant",
    "storeu8", "storeu16", "storef16", "storef32",
    // Arithmetic primitives.
    "add", "sub", "mul", "div", "fma",
    "sqrt", "abs", "neg", "max", "min", "cmp",
    // Logical operators.
    "and", "or", "xor", "not",
    // Transcendental functions.
    "exp", "log", "pow",
    // Ternary operator and its meta-node.
    "ternary",
    "mux",
    // Stack helpers.
    "dup", "swap",
};

static_assert(sizeof(op_names) / sizeof(*op_names) == static_cast<size_t>(ExprOpType::SWAP) + 1, "");
// Variants of the fused multiply-add operation, written in terms of its three
// operands a, b and c. Values are consecutive starting at 0.
enum class FMAType {
    FMADD,  // (b * c) + a
    FMSUB,  // (b * c) - a
    FNMADD, // -(b * c) + a
    FNMSUB, // -(b * c) - a
};
// Comparison predicates. Each negated predicate (N*) has the value of its
// positive counterpart plus 4; values 3 and 7 are unused.
enum class ComparisonType {
    EQ = 0,
    LT = 1,
    LE = 2,

    NEQ = 4,
    NLT = 5,
    NLE = 6,
};

// Printable name of every ComparisonType value; unused slots read "?".
static const char *cmp_names[8] = {
    "EQ",
    "LT",
    "LE",
    "?",
    "NEQ",
    "NLT",
    "NLE",
    "?",
};
// Immediate operand of an operation. The same storage can be viewed as a
// signed integer, an unsigned integer or a float; the implicit constructors
// let callers pass any of the three directly.
union ExprUnion {
    int32_t i;
    uint32_t u;
    float f;

    // Default: all-zero unsigned view.
    constexpr ExprUnion() : u{} {}
    constexpr ExprUnion(int32_t value) : i(value) {}
    constexpr ExprUnion(uint32_t value) : u(value) {}
    constexpr ExprUnion(float value) : f(value) {}
};
- struct ExprOp {
- ExprOpType type;
- ExprUnion imm;
- ExprOp(ExprOpType type, ExprUnion param = {}) : type(type), imm(param) {}
- };
- bool operator==(const ExprOp &lhs, const ExprOp &rhs) { return lhs.type == rhs.type && lhs.imm.u == rhs.imm.u; }
- bool operator!=(const ExprOp &lhs, const ExprOp &rhs) { return !(lhs == rhs); }
- struct ExprInstruction {
- ExprOp op;
- int dst;
- int src1;
- int src2;
- int src3;
- ExprInstruction(ExprOp op) : op(op), dst(-1), src1(-1), src2(-1), src3(-1) {}
- };
- struct ExpressionTreeNode {
- ExprOp op;
- ExpressionTreeNode *left;
- ExpressionTreeNode *right;
- ExpressionTreeNode *parent;
- int valueNum;
- explicit ExpressionTreeNode(ExprOp op) : op(op), left(nullptr), right(nullptr), parent(nullptr), valueNum(-1) {}
- template <class T>
- void preorder(T visitor)
- {
- if (visitor(*this))
- return;
- if (left)
- left->preorder(visitor);
- if (right)
- right->preorder(visitor);
- }
- template <class T>
- void postorder(T visitor)
- {
- if (left)
- left->postorder(visitor);
- if (right)
- right->postorder(visitor);
- visitor(*this);
- }
- };
- class ExpressionTree {
- std::vector<std::unique_ptr<ExpressionTreeNode>> nodes;
- ExpressionTreeNode *root;
- public:
- ExpressionTree() : root() {}
- ExpressionTreeNode *getRoot() { return root; }
- const ExpressionTreeNode *getRoot() const { return root; }
- void setRoot(ExpressionTreeNode *node) { root = node; }
- ExpressionTreeNode *makeNode(ExprOp data)
- {
- nodes.push_back(std::unique_ptr<ExpressionTreeNode>(new ExpressionTreeNode(data)));
- return nodes.back().get();
- }
- ExpressionTreeNode *clone(const ExpressionTreeNode *node)
- {
- if (!node)
- return nullptr;
- ExpressionTreeNode *newnode = makeNode(node->op);
- ExpressionTreeNode *newleft = clone(node->left);
- ExpressionTreeNode *newright = clone(node->right);
- if (newleft) {
- newnode->left = newleft;
- newnode->left->parent = newnode;
- }
- if (newright) {
- newnode->right = newright;
- newnode->right->parent = newnode;
- }
- return newnode;
- }
- };
// Splits `expr` into whitespace-separated tokens.
//
// Whitespace is whatever std::isspace reports; runs of whitespace (including
// leading and trailing ones) never produce empty tokens. An empty or
// all-whitespace expression yields an empty vector.
std::vector<std::string> tokenize(const std::string &expr)
{
    // std::isspace has undefined behaviour for negative arguments other than
    // EOF, so bytes >= 0x80 must go through unsigned char first.
    const auto isSpace = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };

    std::vector<std::string> tokens;
    const auto end = expr.end();
    auto it = expr.begin();

    while (true) {
        // Skip separators, then consume one token.
        it = std::find_if_not(it, end, isSpace);
        if (it == end)
            break;
        const auto tokenEnd = std::find_if(it, end, isSpace);
        tokens.emplace_back(it, tokenEnd);
        it = tokenEnd;
    }
    return tokens;
}
- ExprOp decodeToken(const std::string &token)
- {
- static const std::unordered_map<std::string, ExprOp> simple{
- { "+", { ExprOpType::ADD } },
- { "-", { ExprOpType::SUB } },
- { "*", { ExprOpType::MUL } },
- { "/", { ExprOpType::DIV } } ,
- { "sqrt", { ExprOpType::SQRT } },
- { "abs", { ExprOpType::ABS } },
- { "max", { ExprOpType::MAX } },
- { "min", { ExprOpType::MIN } },
- { "<", { ExprOpType::CMP, static_cast<int>(ComparisonType::LT) } },
- { ">", { ExprOpType::CMP, static_cast<int>(ComparisonType::NLE) } },
- { "=", { ExprOpType::CMP, static_cast<int>(ComparisonType::EQ) } },
- { ">=", { ExprOpType::CMP, static_cast<int>(ComparisonType::NLT) } },
- { "<=", { ExprOpType::CMP, static_cast<int>(ComparisonType::LE) } },
- { "and", { ExprOpType::AND } },
- { "or", { ExprOpType::OR } },
- { "xor", { ExprOpType::XOR } },
- { "not", { ExprOpType::NOT } },
- { "?", { ExprOpType::TERNARY } },
- { "exp", { ExprOpType::EXP } },
- { "log", { ExprOpType::LOG } },
- { "pow", { ExprOpType::POW } },
- { "dup", { ExprOpType::DUP, 0 } },
- { "swap", { ExprOpType::SWAP, 1 } },
- };
- auto it = simple.find(token);
- if (it != simple.end()) {
- return it->second;
- } else if (token.size() == 1 && token[0] >= 'a' && token[0] <= 'z') {
- return{ ExprOpType::MEM_LOAD_U8, token[0] >= 'x' ? token[0] - 'x' : token[0] - 'a' + 3 };
- } else if (token.substr(0, 3) == "dup" || token.substr(0, 4) == "swap") {
- size_t count;
- int idx = -1;
- try {
- idx = std::stoi(token.substr(token[0] == 'd' ? 3 : 4), &count);
- } catch (...) {
- // ...
- }
- if (idx < 0)
- throw std::runtime_error("illegal token: " + token);
- return{ token[0] == 'd' ? ExprOpType::DUP : ExprOpType::SWAP, idx };
- } else {
- float f;
- std::string s;
- std::istringstream numStream(token);
- numStream.imbue(std::locale::classic());
- if (!(numStream >> f))
- throw std::runtime_error("failed to convert '" + token + "' to float");
- if (numStream >> s)
- throw std::runtime_error("failed to convert '" + token + "' to float, not the whole token could be converted");
- return{ ExprOpType::CONSTANT, f };
- }
- }
- ExpressionTree parseExpr(const std::string &expr, const VSVideoInfo * const *vi, int numInputs)
- {
- constexpr unsigned char numOperands[] = {
- 0, // MEM_LOAD_U8
- 0, // MEM_LOAD_U16
- 0, // MEM_LOAD_F16
- 0, // MEM_LOAD_F32
- 0, // CONSTANT
- 0, // MEM_STORE_U8
- 0, // MEM_STORE_U16
- 0, // MEM_STORE_F16
- 0, // MEM_STORE_F32
- 2, // ADD
- 2, // SUB
- 2, // MUL
- 2, // DIV
- 3, // FMA
- 1, // SQRT
- 1, // ABS
- 1, // NEG
- 2, // MAX
- 2, // MIN
- 2, // CMP
- 2, // AND
- 2, // OR
- 2, // XOR
- 2, // NOT
- 1, // EXP
- 1, // LOG
- 2, // POW
- 3, // TERNARY
- 0, // MUX
- 0, // DUP
- 0, // SWAP
- };
- static_assert(sizeof(numOperands) == static_cast<unsigned>(ExprOpType::SWAP) + 1, "invalid table");
- auto tokens = tokenize(expr);
- ExpressionTree tree;
- std::vector<ExpressionTreeNode *> stack;
- for (const std::string &tok : tokens) {
- ExprOp op = decodeToken(tok);
- // Check validity.
- if (op.type == ExprOpType::MEM_LOAD_U8 && op.imm.i >= numInputs)
- throw std::runtime_error("reference to undefined clip: " + tok);
- if ((op.type == ExprOpType::DUP || op.type == ExprOpType::SWAP) && op.imm.u >= stack.size())
- throw std::runtime_error("insufficient values on stack: " + tok);
- if (stack.size() < numOperands[static_cast<size_t>(op.type)])
- throw std::runtime_error("insufficient values on stack: " + tok);
- // Rename load operations with the correct data type.
- if (op.type == ExprOpType::MEM_LOAD_U8) {
- const VSFormat *format = vi[op.imm.i]->format;
- if (format->sampleType == stInteger && format->bytesPerSample == 1)
- op.type = ExprOpType::MEM_LOAD_U8;
- else if (format->sampleType == stInteger && format->bytesPerSample == 2)
- op.type = ExprOpType::MEM_LOAD_U16;
- else if (format->sampleType == stFloat && format->bytesPerSample == 2)
- op.type = ExprOpType::MEM_LOAD_F16;
- else if (format->sampleType == stFloat && format->bytesPerSample == 4)
- op.type = ExprOpType::MEM_LOAD_F32;
- }
- // Apply DUP and SWAP in the frontend.
- if (op.type == ExprOpType::DUP) {
- stack.push_back(tree.clone(stack[stack.size() - 1 - op.imm.u]));
- } else if (op.type == ExprOpType::SWAP) {
- std::swap(stack.back(), stack[stack.size() - 1 - op.imm.u]);
- } else {
- size_t operands = numOperands[static_cast<size_t>(op.type)];
- if (operands == 0) {
- stack.push_back(tree.makeNode(op));
- } else if (operands == 1) {
- ExpressionTreeNode *child = stack.back();
- stack.pop_back();
- ExpressionTreeNode *node = tree.makeNode(op);
- node->left = child;
- node->left->parent = node;
- stack.push_back(node);
- } else if (operands == 2) {
- ExpressionTreeNode *left = stack[stack.size() - 2];
- ExpressionTreeNode *right = stack[stack.size() - 1];
- stack.resize(stack.size() - 2);
- ExpressionTreeNode *node = tree.makeNode(op);
- node->left = left;
- node->left->parent = node;
- node->right = right;
- node->right->parent = node;
- stack.push_back(node);
- } else if (operands == 3) {
- ExpressionTreeNode *arg1 = stack[stack.size() - 3];
- ExpressionTreeNode *arg2 = stack[stack.size() - 2];
- ExpressionTreeNode *arg3 = stack[stack.size() - 1];
- stack.resize(stack.size() - 3);
- ExpressionTreeNode *mux = tree.makeNode(ExprOpType::MUX);
- mux->left = arg2;
- mux->left->parent = mux;
- mux->right = arg3;
- mux->right->parent = mux;
- ExpressionTreeNode *node= tree.makeNode(op);
- node->left = arg1;
- node->left->parent = node;
- node->right = mux;
- node->right->parent = node;
- stack.push_back(node);
- }
- }
- }
- if (stack.empty())
- throw std::runtime_error("empty expression: " + expr);
- if (stack.size() > 1)
- throw std::runtime_error("unconsumed values on stack: " + expr);
- tree.setRoot(stack.back());
- return tree;
- }
- bool equalSubTree(const ExpressionTreeNode *lhs, const ExpressionTreeNode *rhs)
- {
- if (lhs->valueNum >= 0 && rhs->valueNum >= 0)
- return lhs->valueNum == rhs->valueNum;
- if (lhs->op.type != rhs->op.type || lhs->op.imm.u != rhs->op.imm.u)
- return false;
- if (!!lhs->left != !!rhs->left || !!lhs->right != !!rhs->right)
- return false;
- if (lhs->left && !equalSubTree(lhs->left, rhs->left))
- return false;
- if (lhs->right && !equalSubTree(lhs->right, rhs->right))
- return false;
- return true;
- }
- bool isConstantExpr(const ExpressionTreeNode &node)
- {
- switch (node.op.type) {
- case ExprOpType::MEM_LOAD_U8:
- case ExprOpType::MEM_LOAD_U16:
- case ExprOpType::MEM_LOAD_F16:
- case ExprOpType::MEM_LOAD_F32:
- return false;
- case ExprOpType::CONSTANT:
- return true;
- default:
- return (!node.left || isConstantExpr(*node.left)) && (!node.right || isConstantExpr(*node.right));
- }
- }
- bool isConstant(const ExpressionTreeNode &node)
- {
- return node.op.type == ExprOpType::CONSTANT;
- }
- bool isConstant(const ExpressionTreeNode &node, float val)
- {
- return node.op.type == ExprOpType::CONSTANT && node.op.imm.f == val;
- }
- float evalConstantExpr(const ExpressionTreeNode &node)
- {
- switch (node.op.type) {
- case ExprOpType::CONSTANT: return node.op.imm.f;
- case ExprOpType::ADD: return evalConstantExpr(*node.left) + evalConstantExpr(*node.right);
- case ExprOpType::SUB: return evalConstantExpr(*node.left) - evalConstantExpr(*node.right);
- case ExprOpType::MUL: return evalConstantExpr(*node.left) * evalConstantExpr(*node.right);
- case ExprOpType::DIV: return evalConstantExpr(*node.left) / evalConstantExpr(*node.right);
- case ExprOpType::FMA:
- switch (static_cast<FMAType>(node.op.imm.u)) {
- case FMAType::FMADD: return evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) + evalConstantExpr(*node.left);
- case FMAType::FMSUB: return evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) - evalConstantExpr(*node.left);
- case FMAType::FNMADD: return -evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) + evalConstantExpr(*node.left);
- case FMAType::FNMSUB: return -evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) - evalConstantExpr(*node.left);
- }
- return NAN;
- case ExprOpType::SQRT: return std::sqrt(evalConstantExpr(*node.left));
- case ExprOpType::ABS: return std::fabs(evalConstantExpr(*node.left));
- case ExprOpType::NEG: return -evalConstantExpr(*node.left);
- case ExprOpType::MAX: return std::max(evalConstantExpr(*node.left), evalConstantExpr(*node.right));
- case ExprOpType::MIN: return std::min(evalConstantExpr(*node.left), evalConstantExpr(*node.right));
- case ExprOpType::CMP:
- switch (static_cast<ComparisonType>(node.op.imm.u)) {
- case ComparisonType::EQ: return evalConstantExpr(*node.left) == evalConstantExpr(*node.right) ? 1.0f : 0.0f;
- case ComparisonType::LT: return evalConstantExpr(*node.left) < evalConstantExpr(*node.right) ? 1.0f : 0.0f;
- case ComparisonType::LE: return evalConstantExpr(*node.left) <= evalConstantExpr(*node.right) ? 1.0f : 0.0f;
- case ComparisonType::NEQ: return evalConstantExpr(*node.left) != evalConstantExpr(*node.right) ? 1.0f : 0.0f;
- case ComparisonType::NLT: return evalConstantExpr(*node.left) >= evalConstantExpr(*node.right) ? 1.0f : 0.0f;
- case ComparisonType::NLE: return evalConstantExpr(*node.left) > evalConstantExpr(*node.right) ? 1.0f : 0.0f;
- }
- return NAN;
- case ExprOpType::AND: return evalConstantExpr(*node.left) > 0.0f && evalConstantExpr(*node.right) > 0.0f ? 1.0f : 0.0f;
- case ExprOpType::OR: return evalConstantExpr(*node.left) > 0.0f || evalConstantExpr(*node.right) > 0.0f ? 1.0f : 0.0f;
- case ExprOpType::XOR: return evalConstantExpr(*node.left) > 0.0f != evalConstantExpr(*node.right) > 0.0f ? 1.0f : 0.0f;
- case ExprOpType::NOT: return evalConstantExpr(*node.left) > 0.0f ? 0.0f : 1.0f;
- case ExprOpType::EXP: return std::exp(evalConstantExpr(*node.left));
- case ExprOpType::LOG: return std::log(evalConstantExpr(*node.left));
- case ExprOpType::POW: return std::pow(evalConstantExpr(*node.left), evalConstantExpr(*node.right));
- case ExprOpType::TERNARY: return evalConstantExpr(*node.left) > 0.0f ? evalConstantExpr(*node.right->left) : evalConstantExpr(*node.right->right);
- default: return NAN;
- }
- }
- bool isOpCode(const ExpressionTreeNode &node, std::initializer_list<ExprOpType> types)
- {
- for (ExprOpType type : types) {
- if (node.op.type == type)
- return true;
- }
- return false;
- }
// True when `x` has no fractional part, i.e. rounding it down leaves it unchanged.
bool isInteger(float x)
{
    const float roundedDown = std::floor(x);
    return roundedDown == x;
}
- void replaceNode(ExpressionTreeNode &node, const ExpressionTreeNode &replacement)
- {
- node.op = replacement.op;
- if (node.left)
- node.left->parent = nullptr;
- if (node.right)
- node.right->parent = nullptr;
- node.left = replacement.left;
- node.right = replacement.right;
- if (node.left)
- node.left->parent = &node;
- if (node.right)
- node.right->parent = &node;
- }
- void applyValueNumbering(ExpressionTree &tree)
- {
- std::vector<ExpressionTreeNode *> numbered;
- int valueNum = 0;
- tree.getRoot()->postorder([&](ExpressionTreeNode &node)
- {
- node.valueNum = -1;
- });
- tree.getRoot()->postorder([&](ExpressionTreeNode &node)
- {
- if (node.op.type == ExprOpType::MUX)
- return;
- for (ExpressionTreeNode *testnode : numbered) {
- if (equalSubTree(&node, testnode)) {
- node.valueNum = testnode->valueNum;
- return;
- }
- }
- node.valueNum = valueNum++;
- numbered.push_back(&node);
- });
- }
- ExpressionTreeNode *integerPower(ExpressionTree &tree, const ExpressionTreeNode &node, int exponent)
- {
- if (exponent == 1)
- return tree.clone(&node);
- ExpressionTreeNode *lhs = integerPower(tree, node, (exponent + 1) / 2);
- ExpressionTreeNode *rhs = integerPower(tree, node, exponent - (exponent + 1) / 2);
- ExpressionTreeNode *mulNode = tree.makeNode({ ExprOpType::MUL });
- mulNode->left = lhs;
- mulNode->right = rhs;
- mulNode->left->parent = mulNode;
- mulNode->right->parent = mulNode;
- return mulNode;
- }
- bool applyLocalOptimizations(ExpressionTree &tree)
- {
- bool changed = false;
- tree.getRoot()->postorder([&](ExpressionTreeNode &node)
- {
- if (node.op.type == ExprOpType::MUX)
- return;
- // Constant folding.
- if (node.op.type != ExprOpType::CONSTANT && isConstantExpr(node)) {
- float val = evalConstantExpr(node);
- replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, val } });
- changed = true;
- }
- // Move constants to right-hand side to simplify identities.
- if (isOpCode(node, { ExprOpType::ADD, ExprOpType::MUL }) && isConstant(*node.left) && !isConstant(*node.right)) {
- std::swap(node.left, node.right);
- changed = true;
- }
- // x + 0 = x x - 0 = x
- if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && isConstant(*node.right, 0.0f)) {
- replaceNode(node, *node.left);
- changed = true;
- }
- // x * 0 = 0 0 / x = 0
- if ((node.op == ExprOpType::MUL && isConstant(*node.right, 0.0f)) || (node.op == ExprOpType::DIV && isConstant(*node.left, 0.0f))) {
- replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
- changed = true;
- }
- // x * 1 = x x / 1 = x
- if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, 1.0f)) {
- replaceNode(node, *node.left);
- changed = true;
- }
- // sqrt(x) = x ** 0.5
- if (node.op == ExprOpType::SQRT) {
- node.op = ExprOpType::POW;
- node.right = tree.makeNode({ ExprOpType::CONSTANT, 0.5f });
- node.right->parent = &node;
- changed = true;
- }
- // log(exp(x)) = x exp(log(x)) = x
- if ((node.op == ExprOpType::LOG && node.left->op == ExprOpType::EXP) || (node.op == ExprOpType::EXP && node.left->op == ExprOpType::LOG)) {
- replaceNode(node, *node.left->left);
- changed = true;
- }
- // x ** 0 = 1
- if (node.op == ExprOpType::POW && isConstant(*node.right, 0.0f)) {
- replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
- changed = true;
- }
- // x ** 1 = x
- if (node.op == ExprOpType::POW && isConstant(*node.right, 1.0f)) {
- replaceNode(node, *node.left);
- changed = true;
- }
- // (a ** b) ** c = a ** (b * c)
- if (node.op == ExprOpType::POW && node.left->op == ExprOpType::POW) {
- ExpressionTreeNode *a = node.left->left;
- ExpressionTreeNode *b = node.left->right;
- ExpressionTreeNode *c = node.right;
- replaceNode(*node.left, *a);
- node.right = tree.makeNode(ExprOpType::MUL);
- node.right->left = b;
- node.right->right = c;
- node.right->left->parent = node.right;
- node.right->right->parent = node.right;
- changed = true;
- }
- // 0 ? x y = y 1 ? x y = x
- if (node.op == ExprOpType::TERNARY && isConstant(*node.left)) {
- ExpressionTreeNode *replacement = node.left->op.imm.f > 0.0f ? node.right->left : node.right->right;
- replaceNode(node, *replacement);
- changed = true;
- }
- });
- return changed;
- }
// Maps a value number to the exponent it carries inside one product.
using exponentMap = std::map<int, float>;
// A sum of products: each entry pairs a product with its coefficient.
using additiveTermList = std::vector<std::pair<exponentMap, float>>;

// True when both products contain exactly the same value numbers with exactly
// equal exponents.
bool isEqualTerm(const exponentMap &lhs, const exponentMap &rhs)
{
    if (lhs.size() != rhs.size())
        return false;

    const auto sameFactor = [](const exponentMap::value_type &a, const exponentMap::value_type &b) {
        return a.first == b.first && a.second == b.second;
    };
    return std::equal(lhs.begin(), lhs.end(), rhs.begin(), sameFactor);
}
- bool sortTerms(additiveTermList &list, const std::unordered_map<int, const ExpressionTreeNode *> &values)
- {
- auto pred = [&](const std::pair<exponentMap, float> &lhs, const std::pair<exponentMap, float> &rhs)
- {
- std::vector<std::pair<int, float>> lhsTerms(lhs.first.begin(), lhs.first.end());
- std::vector<std::pair<int, float>> rhsTerms(rhs.first.begin(), rhs.first.end());
- auto pred2 = [&](const std::pair<int, float> &lhs2, const std::pair<int, float> &rhs2)
- {
- std::initializer_list<ExprOpType> memOpCodes = { ExprOpType::MEM_LOAD_U8, ExprOpType::MEM_LOAD_U16, ExprOpType::MEM_LOAD_F16, ExprOpType::MEM_LOAD_F32 };
- if (lhs2.first == rhs2.first)
- return lhs2.second < rhs2.second;
- const ExpressionTreeNode *lhsValue = values.at(lhs2.first);
- const ExpressionTreeNode *rhsValue = values.at(rhs2.first);
- int lhsCategory = isConstant(*lhsValue) ? 0 : isOpCode(*lhsValue, memOpCodes) ? 1 : 2;
- int rhsCategory = isConstant(*rhsValue) ? 0 : isOpCode(*rhsValue, memOpCodes) ? 1 : 2;
- // Simpler terms towards the right.
- if (lhsCategory != rhsCategory)
- return lhsCategory > rhsCategory;
- if (lhsCategory == 0)
- return lhsValue->op.imm.f < rhsValue->op.imm.f;
- else if (lhsCategory == 1)
- return lhsValue->op.imm.u < rhsValue->op.imm.u;
- else
- return lhs2.first < rhs2.first;
- };
- std::sort(lhsTerms.begin(), lhsTerms.end(), pred2);
- std::sort(rhsTerms.begin(), rhsTerms.end(), pred2);
- return std::lexicographical_compare(lhsTerms.begin(), lhsTerms.end(), rhsTerms.begin(), rhsTerms.end(), pred2);
- };
- if (std::is_sorted(list.begin(), list.end(), pred))
- return true;
- std::sort(list.begin(), list.end(), pred);
- return false;
- }
- void expandMultiplies(exponentMap &term, std::unordered_map<int, const ExpressionTreeNode *> &values)
- {
- bool changed = true;
- while (changed) {
- changed = false;
- for (auto it = term.begin(); it != term.end();) {
- const ExpressionTreeNode *value = values.at(it->first);
- if (value->op == ExprOpType::POW && isConstant(*value->right)) {
- values[value->left->valueNum] = value->left;
- term[value->left->valueNum] += it->second * value->right->op.imm.f;
- it = term.erase(it);
- changed = true;
- continue;
- } else if (value->op == ExprOpType::MUL) {
- values[value->left->valueNum] = value->left;
- values[value->right->valueNum] = value->right;
- term[value->left->valueNum] += it->second;
- term[value->right->valueNum] += it->second;
- it = term.erase(it);
- changed = true;
- continue;
- } else if (value->op == ExprOpType::DIV) {
- values[value->left->valueNum] = value->left;
- values[value->right->valueNum] = value->right;
- term[value->left->valueNum] += it->second;
- term[value->right->valueNum] -= it->second;
- it = term.erase(it);
- changed = true;
- continue;
- }
- ++it;
- }
- }
- }
- std::pair<float, size_t> addConstants(additiveTermList &terms, const std::unordered_map<int, const ExpressionTreeNode *> &values)
- {
- float scalarTerm = 0.0f;
- size_t numScalarEliminated = 0;
- bool nonTerminalScalar = false;
- for (auto it1 = terms.begin(); it1 < terms.end();) {
- for (auto it2 = it1->first.begin(); it2 != it1->first.end(); ++it2) {
- const ExpressionTreeNode *value = values.at(it2->first);
- if (isConstant(*value)) {
- it1->second *= std::pow(value->op.imm.f, it2->second);
- it2->second = 0.0f;
- }
- }
- for (auto it2 = it1->first.begin(); it2 != it1->first.end();) {
- if (it2->second == 0.0f) {
- it2 = it1->first.erase(it2);
- continue;
- }
- ++it2;
- }
- if (it1->first.empty()) {
- scalarTerm += it1->second;
- it1->second = 0.0f;
- nonTerminalScalar = nonTerminalScalar || (it1 + 1 != terms.end());
- it1 = terms.erase(it1);
- ++numScalarEliminated;
- continue;
- }
- ++it1;
- }
- return{ scalarTerm, numScalarEliminated + nonTerminalScalar };
- }
- size_t addIdenticalTerms(additiveTermList &terms)
- {
- size_t numCanceled = 0;
- for (auto it1 = terms.begin(); it1 < terms.end();) {
- for (auto it2 = it1 + 1; it2 < terms.end(); ++it2) {
- if (isEqualTerm(it1->first, it2->first)) {
- it1->second += it2->second;
- it2->second = 0.0f;
- }
- }
- if (it1->second == 0.0f) {
- it1 = terms.erase(it1);
- ++numCanceled;
- continue;
- }
- ++it1;
- }
- return numCanceled;
- }
- ExpressionTreeNode *emitMultiplicativeSequence(ExpressionTree &tree, const exponentMap &terms, float scalarTerm, const std::unordered_map<int, const ExpressionTreeNode *> &values)
- {
- ExpressionTreeNode *node = nullptr;
- for (auto &t : terms) {
- ExpressionTreeNode *powNode = tree.makeNode(ExprOpType::POW);
- powNode->left = tree.clone(values.at(t.first));
- powNode->right = tree.makeNode({ ExprOpType::CONSTANT, t.second });
- powNode->left->parent = powNode;
- powNode->right->parent = powNode;
- if (node) {
- ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
- mulNode->left = node;
- mulNode->right = powNode;
- mulNode->left->parent = mulNode;
- mulNode->right->parent = mulNode;
- node = mulNode;
- } else {
- node = powNode;
- }
- }
- if (node) {
- ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
- mulNode->left = node;
- mulNode->right = tree.makeNode({ ExprOpType::CONSTANT, scalarTerm });
- mulNode->left->parent = mulNode;
- mulNode->right->parent = mulNode;
- node = mulNode;
- } else {
- node = tree.makeNode({ ExprOpType::CONSTANT, scalarTerm });
- }
- return node;
- }
- ExpressionTreeNode *emitAdditiveSequence(ExpressionTree &tree, const additiveTermList &terms, float scalarTerm, const std::unordered_map<int, const ExpressionTreeNode *> &values)
- {
- ExpressionTreeNode *head = nullptr;
- for (const auto &term : terms) {
- assert(!term.first.empty());
- ExpressionTreeNode *node = emitMultiplicativeSequence(tree, term.first, term.second, values);
- if (head) {
- ExpressionTreeNode *addNode = tree.makeNode(term.second < 0 ? ExprOpType::SUB : ExprOpType::ADD);
- addNode->left = head;
- addNode->right = node;
- addNode->left->parent = addNode;
- addNode->right->parent = addNode;
- head = addNode;
- } else {
- head = node;
- }
- }
- if (head) {
- ExpressionTreeNode *addNode = tree.makeNode(scalarTerm < 0 ? ExprOpType::SUB : ExprOpType::ADD);
- addNode->left = head;
- addNode->right = tree.makeNode({ ExprOpType::CONSTANT, std::fabs(scalarTerm) });
- addNode->left->parent = addNode;
- addNode->right->parent = addNode;
- head = addNode;
- } else {
- head = tree.makeNode({ ExprOpType::CONSTANT, 0.0f });
- }
- return head;
- }
- bool analyzeAdditiveExpression(ExpressionTree &tree, ExpressionTreeNode &node)
- {
- // Stores the exponent of each term in a multiplicative tuple.
- // e.g. 3 * v0^1.5 * v1^2 * v2 => { 0: 1.5, 1: 2, 2: 1 }.
- additiveTermList terms;
- std::unordered_map<int, const ExpressionTreeNode *> values;
- node.preorder([&](ExpressionTreeNode &node)
- {
- if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }))
- return false;
- // Deduce net sign of term.
- const ExpressionTreeNode *parent = node.parent;
- const ExpressionTreeNode *cur = &node;
- int polarity = 1;
- while (parent && isOpCode(*parent, { ExprOpType::ADD, ExprOpType::SUB })) {
- if (parent->op == ExprOpType::SUB && cur == parent->right)
- polarity = -polarity;
- cur = parent;
- parent = parent->parent;
- }
- exponentMap term{ { node.valueNum, 1.0f } };
- terms.emplace_back(std::move(term), static_cast<float>(polarity));
- values[node.valueNum] = &node;
- return true;
- });
- for (auto &term : terms) {
- expandMultiplies(term.first, values);
- }
- // Combine constant terms.
- float scalarTerm = 0.0f;
- size_t numScalarEliminated = 0;
- {
- auto result = addConstants(terms, values);
- scalarTerm += result.first;
- numScalarEliminated += result.second;
- }
- // Cancel identical terms.
- size_t numCanceled = addIdenticalTerms(terms);
- // Normalize order of terms to assist multiplicative analysis.
- bool wasSorted = sortTerms(terms, values);
- if (numCanceled > 0 || numScalarEliminated > 1 || !wasSorted) {
- ExpressionTreeNode *seq = emitAdditiveSequence(tree, terms, scalarTerm, values);
- replaceNode(node, *seq);
- return true;
- }
- return false;
- }
- bool analyzeMultiplicativeExpression(ExpressionTree &tree, ExpressionTreeNode &node)
- {
- std::vector<int> termOrder;
- exponentMap term;
- std::unordered_map<int, const ExpressionTreeNode *> values;
- float scalarTerm = 1.0f;
- size_t numDivs = 0;
- node.preorder([&](ExpressionTreeNode &node)
- {
- if (node.op == ExprOpType::DIV)
- ++numDivs;
- if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }))
- return false;
- // Deduce net sign of term.
- const ExpressionTreeNode *parent = node.parent;
- const ExpressionTreeNode *cur = &node;
- int polarity = 1;
- while (parent && isOpCode(*parent, { ExprOpType::MUL, ExprOpType::DIV })) {
- if (parent->op == ExprOpType::DIV && cur == parent->right)
- polarity = -polarity;
- cur = parent;
- parent = parent->parent;
- }
- term[node.valueNum] += static_cast<float>(polarity);
- termOrder.push_back(node.valueNum);
- values[node.valueNum] = &node;
- return true;
- });
- expandMultiplies(term, values);
- // Combine constants.
- for (auto it = term.begin(); it != term.end();) {
- const ExpressionTreeNode *node = values[it->first];
- if (isConstant(*node)) {
- scalarTerm *= std::powf(node->op.imm.f, it->second);
- it = term.erase(it);
- continue;
- }
- ++it;
- }
- size_t origScalarTerms = 0;
- bool nonTerminalScalar = false;
- for (auto it = termOrder.begin(); it != termOrder.end();) {
- if (isConstant(*values[*it])) {
- nonTerminalScalar = nonTerminalScalar || it + 1 != termOrder.end();
- it = termOrder.erase(it);
- ++origScalarTerms;
- continue;
- }
- ++it;
- }
- if (term.size() + (scalarTerm != 1.0f) < termOrder.size() + origScalarTerms || !std::is_sorted(termOrder.begin(), termOrder.end()) || nonTerminalScalar || numDivs) {
- ExpressionTreeNode *seq = emitMultiplicativeSequence(tree, term, scalarTerm, values);
- replaceNode(node, *seq);
- return true;
- }
- return false;
- }
- bool applyAlgebraicOptimizations(ExpressionTree &tree)
- {
- bool changed = false;
- applyValueNumbering(tree);
- tree.getRoot()->preorder([&](ExpressionTreeNode &node)
- {
- if (node.op.type == ExprOpType::CMP && node.left->valueNum == node.right->valueNum) {
- ComparisonType type = static_cast<ComparisonType>(node.op.imm.u);
- if (type == ComparisonType::EQ || type == ComparisonType::LE || type == ComparisonType::NLT)
- replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
- else
- replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
- changed = true;
- return changed;
- }
- if (node.op.type == ExprOpType::TERNARY && node.right->left->valueNum == node.right->right->valueNum) {
- replaceNode(node, *node.right->left);
- changed = true;
- return changed;
- }
- if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::ADD, ExprOpType::SUB }))) {
- changed = changed || analyzeAdditiveExpression(tree, node);
- return changed;
- }
- if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::MUL, ExprOpType::DIV }))) {
- changed = changed || analyzeMultiplicativeExpression(tree, node);
- return changed;
- }
- return false;
- });
- return changed;
- }
- // Performs one postorder pass of local rewrites on the expression tree:
- // negation folding, division/multiplication reassociation, x * 2 -> x + x,
- // division by a constant -> multiplication by its reciprocal, and expansion
- // of POW with constant exponents into SQRT / DIV / whatever integerPower()
- // builds.
- //
- // Several rules may fire on the same node within one visit; each rule is
- // guarded by a check of the node's current operator, so a rule only applies
- // to whatever the preceding rules left behind.
- //
- // Returns true if any rule fired, so the caller can iterate to a fixed point.
- bool applyStrengthReduction(ExpressionTree &tree)
- {
-     bool changed = false;
-     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
-     {
-         // MUX nodes are only operand holders; they are never rewritten themselves.
-         if (node.op == ExprOpType::MUX)
-             return;
-         // 0 - x = -x
-         if (node.op == ExprOpType::SUB && isConstant(*node.left, 0.0f)) {
-             ExpressionTreeNode *tmp = node.right;
-             replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
-             node.left = tmp;
-             node.left->parent = &node;
-             changed = true;
-         }
-         // x * -1 = -x    x / -1 = -x
-         if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, -1.0f)) {
-             ExpressionTreeNode *tmp = node.left;
-             replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
-             node.left = tmp;
-             node.left->parent = &node;
-             changed = true;
-         }
-         // a + -b = a - b    a - -b = a + b
-         if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && node.right->op.type == ExprOpType::NEG) {
-             node.op = node.op == ExprOpType::ADD ? ExprOpType::SUB : ExprOpType::ADD;
-             replaceNode(*node.right, *node.right->left);
-             changed = true;
-         }
-         // -a + b = b - a
-         if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::NEG) {
-             node.op = ExprOpType::SUB;
-             replaceNode(*node.left, *node.left->left);
-             std::swap(node.left, node.right);
-             // This rule rewrites the tree like every other one, so it must be
-             // reported; otherwise the caller's fixed-point loop may stop early.
-             changed = true;
-         }
-         // -(a - b) = b - a
-         if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::SUB) {
-             replaceNode(node, *node.left);
-             std::swap(node.left, node.right);
-             changed = true;
-         }
-         // x * 2 = x + x (skipped when the parent is an ADD)
-         if (node.op == ExprOpType::MUL && isConstant(*node.right, 2.0f) && (!node.parent || node.parent->op != ExprOpType::ADD)) {
-             ExpressionTreeNode *replacement = tree.clone(node.left);
-             node.op = ExprOpType::ADD;
-             replaceNode(*node.right, *replacement);
-             changed = true;
-         }
-         // x / c = x * (1 / c) for a constant divisor c
-         if (node.op == ExprOpType::DIV && isConstant(*node.right)) {
-             node.op = ExprOpType::MUL;
-             node.right->op.imm.f = 1.0f / node.right->op.imm.f;
-             changed = true;
-         }
-         // (1 / x) * y = y / x
-         if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV && isConstant(*node.left->left, 1.0f)) {
-             node.op = ExprOpType::DIV;
-             replaceNode(*node.left, *node.left->right);
-             std::swap(node.left, node.right);
-             changed = true;
-         }
-         // x * (1 / y) = x / y
-         if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV && isConstant(*node.right->left, 1.0f)) {
-             node.op = ExprOpType::DIV;
-             replaceNode(*node.right, *node.right->right);
-             changed = true;
-         }
-         // (a / b) * c = (a * c) / b
-         if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV) {
-             node.op = ExprOpType::DIV;
-             node.left->op = ExprOpType::MUL;
-             std::swap(node.left->right, node.right);
-             // The two swapped subtrees changed owners; fix their parent links.
-             node.left->right->parent = node.left;
-             node.right->parent = &node;
-             changed = true;
-         }
-         // a * (b / c) = (a * b) / c
-         if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV) {
-             node.op = ExprOpType::DIV;
-             node.right->op = ExprOpType::MUL;
-             std::swap(node.left, node.right); // (b * c) / a
-             std::swap(node.left->left, node.left->right); // (c * b) / a
-             std::swap(node.left->left, node.right); // (a * b) / c
-             node.left->left->parent = node.left;
-             node.right->parent = &node;
-             changed = true;
-         }
-         // a / (b / c) = (a * c) / b
-         if (node.op == ExprOpType::DIV && node.right->op == ExprOpType::DIV) {
-             node.right->op = ExprOpType::MUL; // a / (b * c)
-             std::swap(node.left, node.right); // (b * c) / a
-             std::swap(node.left->left, node.right); // (a * c) / b
-             node.left->left->parent = node.left;
-             node.right->parent = &node;
-             changed = true;
-         }
-         // (a / b) / c = a / (b * c)
-         if (node.op == ExprOpType::DIV && node.left->op == ExprOpType::DIV) {
-             node.left->op = ExprOpType::MUL; // (a * b) / c
-             std::swap(node.left, node.right); // c / (a * b)
-             std::swap(node.left, node.right->left); // a / (c * b)
-             std::swap(node.right->left, node.right->right); // a / (b * c)
-             node.left->parent = &node;
-             node.right->left->parent = node.right;
-             node.right->right->parent = node.right;
-             changed = true;
-         }
-         // x ** (n / 2) = sqrt(x ** n), for a non-integer exponent whose double is an integer
-         if (node.op == ExprOpType::POW && isConstant(*node.right) && !isInteger(node.right->op.imm.f) && isInteger(node.right->op.imm.f * 2.0f)) {
-             ExpressionTreeNode *dup = tree.clone(&node);
-             replaceNode(node, ExpressionTreeNode{ ExprOpType::SQRT });
-             node.left = dup;
-             node.left->parent = &node;
-             // The cloned POW now carries the doubled (integer) exponent.
-             node.left->right->op.imm.f *= 2.0f;
-             changed = true;
-         }
-         // x ** -N = 1 / (x ** N)
-         if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f < 0) {
-             ExpressionTreeNode *dup = tree.clone(&node);
-             replaceNode(node, ExpressionTreeNode{ ExprOpType::DIV });
-             node.left = tree.makeNode({ ExprOpType::CONSTANT, 1.0f });
-             node.right = dup;
-             node.left->parent = &node;
-             node.right->parent = &node;
-             // Flip the sign of the cloned POW's exponent.
-             node.right->right->op.imm.f = -node.right->right->op.imm.f;
-             changed = true;
-         }
-         // x ** N = x * x * x * ...
-         if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f > 0) {
-             ExpressionTreeNode *replacement = integerPower(tree, *node.left, static_cast<int>(node.right->op.imm.f));
-             replaceNode(node, *replacement);
-             changed = true;
-         }
-     });
-     return changed;
- }
- // Fuses multiply-then-add/subtract pairs into single FMA nodes.
- //
- // Value numbers are recomputed, then every non-MUX node's value number is
- // counted. A MUL is only fused when its value number was counted at most once,
- // i.e. when no other node in the tree shares that value.
- //
- // A fused node has op FMA, the addend/minuend as its left child, and the
- // former MUL node (retagged as MUX) as its right child holding the two
- // multiplicands. A NEG directly above a singly-used FMA is folded into the
- // FMA by switching its FMAType.
- //
- // Returns true if any node was fused.
- bool applyOpFusion(ExpressionTree &tree)
- {
-     std::unordered_map<int, size_t> refCount;
-     bool changed = false;
-     applyValueNumbering(tree);
-     // First pass: count how many nodes carry each value number.
-     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
-     {
-         if (node.op == ExprOpType::MUX)
-             return;
-         refCount[node.valueNum]++;
-     });
-     // Second pass: fuse bottom-up.
-     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
-     {
-         if (node.op == ExprOpType::MUX)
-             return;
-         // FMA.
-         // c + a * b -> FMADD
-         if (node.op == ExprOpType::ADD && node.right->op == ExprOpType::MUL && refCount[node.right->valueNum] <= 1) {
-             node.right->op = ExprOpType::MUX;
-             node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
-             changed = true;
-         }
-         // a * b + c -> FMADD (the product is moved to the right-hand side first)
-         if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::MUL && refCount[node.left->valueNum] <= 1) {
-             std::swap(node.left, node.right);
-             node.right->op = ExprOpType::MUX;
-             node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
-             changed = true;
-         }
-         // c - a * b -> FNMADD
-         if (node.op == ExprOpType::SUB && node.right->op == ExprOpType::MUL && refCount[node.right->valueNum] <= 1) {
-             node.right->op = ExprOpType::MUX;
-             node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FNMADD) };
-             changed = true;
-         }
-         // a * b - c -> FMSUB (the product is moved to the right-hand side first)
-         if (node.op == ExprOpType::SUB && node.left->op == ExprOpType::MUL && refCount[node.left->valueNum] <= 1) {
-             std::swap(node.left, node.right);
-             node.right->op = ExprOpType::MUX;
-             node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMSUB) };
-             changed = true;
-         }
-         // -(fma) -> the FMA variant with the opposite sign, absorbed into this node.
-         if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::FMA && refCount[node.left->valueNum] <= 1) {
-             ExpressionTreeNode *replacement = node.left;
-             node.op = replacement->op;
-             node.left = replacement->left;
-             node.right = replacement->right;
-             // The adopted children still point at the now-orphaned FMA node;
-             // re-parent them so later passes that consult `parent` see this node.
-             if (node.left)
-                 node.left->parent = &node;
-             if (node.right)
-                 node.right->parent = &node;
-             switch (static_cast<FMAType>(node.op.imm.u)) {
-             case FMAType::FMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FNMSUB); break;
-             case FMAType::FMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FNMADD); break;
-             case FMAType::FNMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FMSUB); break;
-             case FMAType::FNMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FMADD); break;
-             }
-             changed = true;
-         }
-     });
-     return changed;
- }
- // Optimizes the expression tree and lowers it to a linear instruction list.
- //
- // Each distinct value number yields exactly one instruction whose destination
- // is that value number; operands are referenced by their value numbers. A
- // final store instruction, chosen from the output format's sample type and
- // byte width, writes the value produced by the last emitted instruction.
- //
- // Returns an empty vector when the tree has no root.
- std::vector<ExprInstruction> compile(ExpressionTree &tree, const VSFormat *format)
- {
-     std::vector<ExprInstruction> program;
-     std::unordered_set<int> emitted;
-     if (!tree.getRoot())
-         return program;
-     // Local and algebraic simplification, repeated until neither changes the tree.
-     bool again;
-     do {
-         again = applyLocalOptimizations(tree) || applyAlgebraicOptimizations(tree);
-     } while (again);
-     // Substitution rules can hide algebraic expressions from the optimizer, so they run in a separate pass.
-     do {
-         again = applyStrengthReduction(tree) || applyOpFusion(tree);
-     } while (again);
-     applyValueNumbering(tree);
-     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
-     {
-         // MUX nodes produce no instruction; their children become src2/src3 of the parent.
-         if (node.op.type == ExprOpType::MUX)
-             return;
-         // Emit each value only once; insert() reports whether it was new.
-         if (!emitted.insert(node.valueNum).second)
-             return;
-         ExprInstruction insn(node.op);
-         insn.dst = node.valueNum;
-         if (node.left) {
-             assert(node.left->valueNum >= 0);
-             insn.src1 = node.left->valueNum;
-         }
-         if (node.right) {
-             const bool packedOperands = node.right->op.type == ExprOpType::MUX;
-             if (!packedOperands) {
-                 assert(node.right->valueNum >= 0);
-                 insn.src2 = node.right->valueNum;
-             } else {
-                 assert(node.right->left->valueNum >= 0);
-                 assert(node.right->right->valueNum >= 0);
-                 insn.src2 = node.right->left->valueNum;
-                 insn.src3 = node.right->right->valueNum;
-             }
-         }
-         program.push_back(insn);
-     });
-     // Pick the store opcode; any unrecognized format keeps the 8-bit integer default.
-     ExprInstruction store(ExprOpType::MEM_STORE_U8);
-     const bool integerSamples = format->sampleType == stInteger;
-     const bool floatSamples = format->sampleType == stFloat;
-     switch (format->bytesPerSample) {
-     case 1:
-         if (integerSamples)
-             store.op.type = ExprOpType::MEM_STORE_U8;
-         break;
-     case 2:
-         if (integerSamples)
-             store.op.type = ExprOpType::MEM_STORE_U16;
-         else if (floatSamples)
-             store.op.type = ExprOpType::MEM_STORE_F16;
-         break;
-     case 4:
-         if (floatSamples)
-             store.op.type = ExprOpType::MEM_STORE_F32;
-         break;
-     default:
-         break;
-     }
-     store.src1 = program.back().dst;
-     program.push_back(store);
-     return program;
- }
- // Test driver: compiles the expression given as the first command-line
- // argument for an 8-bit integer format and prints the resulting instruction
- // list, one instruction per line.
- //
- // Exits with status 1 and a usage message when no expression is supplied.
- int main(int argc, char **argv)
- {
-     // Without this guard argv[1] is a null pointer, and both streaming it to
-     // std::cout and handing it to parseExpr would be invalid.
-     if (argc < 2) {
-         std::cerr << "usage: " << (argc > 0 && argv[0] ? argv[0] : "expr") << " <expression>\n";
-         return 1;
-     }
-     VSFormat format{};
-     VSVideoInfo realvi{};
-     const VSVideoInfo *vi[26];
-     format.bytesPerSample = 1;
-     format.sampleType = stInteger;
-     realvi.format = &format;
-     // Every one of the 26 clip slots shares the same video info.
-     for (int i = 0; i < 26; ++i) {
-         vi[i] = &realvi;
-     }
-     // Only the first argument is processed.
-     for (int i = 1; i < 2; ++i) {
-         std::cout << argv[i] << '\n';
-         ExpressionTree tree = parseExpr(argv[i], vi, 26);
-         std::vector<ExprInstruction> code = compile(tree, &format);
-         for (auto &insn : code) {
-             std::cout << std::setw(12) << std::left << op_names[static_cast<size_t>(insn.op.type)];
-             // Stores have no destination register; print only their source.
-             if (insn.op.type == ExprOpType::MEM_STORE_U8 || insn.op.type == ExprOpType::MEM_STORE_U16 || insn.op.type == ExprOpType::MEM_STORE_F16 || insn.op.type == ExprOpType::MEM_STORE_F32) {
-                 std::cout << " r" << insn.src1 << '\n';
-                 continue;
-             }
-             std::cout << " r" << insn.dst;
-             if (insn.src1 >= 0)
-                 std::cout << ",r" << insn.src1;
-             if (insn.src2 >= 0)
-                 std::cout << ",r" << insn.src2;
-             if (insn.src3 >= 0)
-                 std::cout << ",r" << insn.src3;
-             // Append the immediate for the opcodes that carry one.
-             switch (insn.op.type) {
-             case ExprOpType::MEM_LOAD_U8:
-             case ExprOpType::MEM_LOAD_U16:
-             case ExprOpType::MEM_LOAD_F16:
-             case ExprOpType::MEM_LOAD_F32:
-                 // Clip indices 0..2 print as x, y, z; index 3 and up print as a, b, c, ...
-                 std::cout << ',' << static_cast<char>(insn.op.imm.u < 3 ? 'x' + insn.op.imm.u : 'a' + insn.op.imm.u - 3);
-                 break;
-             case ExprOpType::CONSTANT:
-                 std::cout << ',' << insn.op.imm.f;
-                 break;
-             case ExprOpType::FMA:
-                 std::cout << "," << insn.op.imm.u;
-                 break;
-             case ExprOpType::CMP:
-                 std::cout << ',' << cmp_names[insn.op.imm.u];
-                 break;
-             default:
-                 // All other opcodes have no immediate to print.
-                 break;
-             }
-             std::cout << '\n';
-         }
-     }
-     return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement