Advertisement
Guest User

Untitled

a guest
Aug 17th, 2019
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 42.45 KB | None | 0 0
  #include <algorithm>
  #include <cassert>
  #include <cctype>
  #include <cmath>
  #include <cstdint>
  #include <iomanip>
  #include <iostream>
  #include <locale>
  #include <map>
  #include <memory>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <unordered_map>
  #include <unordered_set>
  #include <vector>
  13.  
  // How the bits of a sample are interpreted: as an integer or as a float.
  enum VSSampleType {
      stInteger = 0,
      stFloat = 1
  };
  18.  
  // Pixel format descriptor of a clip. parseExpr inspects sampleType and
  // bytesPerSample to choose the width/type of each clip load.
  struct VSFormat {
      char name[32];
      int id;
      int colorFamily;   // a VSColorFamily value
      int sampleType;    // a VSSampleType value
      int bitsPerSample; // significant bits per sample
      int bytesPerSample; // storage size: smallest power of two holding bitsPerSample bits

      int subSamplingW;  // log2 chroma subsampling; applies to planes 2 and 3
      int subSamplingH;

      int numPlanes;     // follows from colorFamily
  };
  32.  
  33. typedef struct VSVideoInfo {
  34. const VSFormat *format;
  35. int64_t fpsNum;
  36. int64_t fpsDen;
  37. int width;
  38. int height;
  39. int numFrames; /* api 3.2 - no longer allowed to be 0 */
  40. int flags;
  41. } VSVideoInfo;
  42.  
  // Opcodes of the expression IR. Enumerator order is significant: the numeric
  // values index op_names[] and the operand-count table in parseExpr.
  enum class ExprOpType {
      // Terminals: clip loads, literal constants and stores.
      MEM_LOAD_U8,
      MEM_LOAD_U16,
      MEM_LOAD_F16,
      MEM_LOAD_F32,
      CONSTANT,
      MEM_STORE_U8,
      MEM_STORE_U16,
      MEM_STORE_F16,
      MEM_STORE_F32,

      // Arithmetic.
      ADD,
      SUB,
      MUL,
      DIV,
      FMA,
      SQRT,
      ABS,
      NEG,
      MAX,
      MIN,
      CMP,

      // Logic.
      AND,
      OR,
      XOR,
      NOT,

      // Transcendentals.
      EXP,
      LOG,
      POW,

      // Ternary select; its two branches hang off a MUX meta-node.
      TERNARY,
      MUX,

      // Stack manipulation (resolved while parsing).
      DUP,
      SWAP,
  };
  66.  
  67. static const char *op_names[] = {
  68. "loadu8", "loadu16", "loadf16", "loadf32", "constant",
  69. "storeu8", "storeu16", "storef16", "storef32",
  70. "add", "sub", "mul", "div", "fma", "sqrt", "abs", "neg", "max", "min", "cmp",
  71. "and", "or", "xor", "not",
  72. "exp", "log", "pow",
  73. "ternary",
  74. "mux",
  75. "dup", "swap",
  76. };
  77. static_assert(sizeof(op_names) / sizeof(op_names[0]) == static_cast<size_t>(ExprOpType::SWAP) + 1, "");
  78.  
  // Fused multiply-add variants; `a` is the addend, `b` and `c` the factors.
  enum class FMAType {
      FMADD = 0,  // a + b*c
      FMSUB = 1,  // b*c - a
      FNMADD = 2, // a - b*c
      FNMSUB = 3, // -(b*c) - a
  };
  85.  
  // Comparison predicates stored in a CMP op's immediate. Value 4 + k is the
  // negation of predicate k (3 and 7 are unused, see cmp_names).
  enum class ComparisonType {
      EQ = 0,  // a == b
      LT = 1,  // a <  b
      LE = 2,  // a <= b
      NEQ = 4, // a != b
      NLT = 5, // a >= b
      NLE = 6, // a >  b
  };
  94.  
  // Printable names for ComparisonType values; "?" marks the unused slots 3 and 7.
  static const char *cmp_names[8] = {
      "EQ", "LT", "LE", "?",
      "NEQ", "NLT", "NLE", "?",
  };
  98.  
  // 32-bit immediate payload of an ExprOp, viewable as signed, unsigned or float.
  // The default constructor zero-initializes the unsigned view.
  union ExprUnion {
      int32_t i;
      uint32_t u;
      float f;

      constexpr ExprUnion() : u(0) {}

      constexpr ExprUnion(int32_t value) : i(value) {}
      constexpr ExprUnion(uint32_t value) : u(value) {}
      constexpr ExprUnion(float value) : f(value) {}
  };
  110.  
  111. struct ExprOp {
  112. ExprOpType type;
  113. ExprUnion imm;
  114.  
  115. ExprOp(ExprOpType type, ExprUnion param = {}) : type(type), imm(param) {}
  116. };
  117.  
  118. bool operator==(const ExprOp &lhs, const ExprOp &rhs) { return lhs.type == rhs.type && lhs.imm.u == rhs.imm.u; }
  119. bool operator!=(const ExprOp &lhs, const ExprOp &rhs) { return !(lhs == rhs); }
  120.  
  121. struct ExprInstruction {
  122. ExprOp op;
  123. int dst;
  124. int src1;
  125. int src2;
  126. int src3;
  127.  
  128. ExprInstruction(ExprOp op) : op(op), dst(-1), src1(-1), src2(-1), src3(-1) {}
  129. };
  130.  
  131. struct ExpressionTreeNode {
  132. ExprOp op;
  133. ExpressionTreeNode *left;
  134. ExpressionTreeNode *right;
  135. ExpressionTreeNode *parent;
  136. int valueNum;
  137.  
  138. explicit ExpressionTreeNode(ExprOp op) : op(op), left(nullptr), right(nullptr), parent(nullptr), valueNum(-1) {}
  139.  
  140. template <class T>
  141. void preorder(T visitor)
  142. {
  143. if (visitor(*this))
  144. return;
  145.  
  146. if (left)
  147. left->preorder(visitor);
  148. if (right)
  149. right->preorder(visitor);
  150. }
  151.  
  152. template <class T>
  153. void postorder(T visitor)
  154. {
  155. if (left)
  156. left->postorder(visitor);
  157. if (right)
  158. right->postorder(visitor);
  159. visitor(*this);
  160. }
  161. };
  162.  
  163. class ExpressionTree {
  164. std::vector<std::unique_ptr<ExpressionTreeNode>> nodes;
  165. ExpressionTreeNode *root;
  166. public:
  167. ExpressionTree() : root() {}
  168.  
  169. ExpressionTreeNode *getRoot() { return root; }
  170. const ExpressionTreeNode *getRoot() const { return root; }
  171.  
  172. void setRoot(ExpressionTreeNode *node) { root = node; }
  173.  
  174. ExpressionTreeNode *makeNode(ExprOp data)
  175. {
  176. nodes.push_back(std::unique_ptr<ExpressionTreeNode>(new ExpressionTreeNode(data)));
  177. return nodes.back().get();
  178. }
  179.  
  180. ExpressionTreeNode *clone(const ExpressionTreeNode *node)
  181. {
  182. if (!node)
  183. return nullptr;
  184.  
  185. ExpressionTreeNode *newnode = makeNode(node->op);
  186. ExpressionTreeNode *newleft = clone(node->left);
  187. ExpressionTreeNode *newright = clone(node->right);
  188.  
  189. if (newleft) {
  190. newnode->left = newleft;
  191. newnode->left->parent = newnode;
  192. }
  193. if (newright) {
  194. newnode->right = newright;
  195. newnode->right->parent = newnode;
  196. }
  197.  
  198. return newnode;
  199. }
  200. };
  201.  
  // Split `expr` into whitespace-separated tokens. Runs of whitespace act as a
  // single separator, and leading/trailing whitespace yields no empty tokens.
  std::vector<std::string> tokenize(const std::string &expr)
  {
      std::vector<std::string> tokens;
      auto it = expr.begin();
      auto prev = expr.begin();

      while (it != expr.end()) {
          // std::isspace is undefined for negative values other than EOF, and
          // plain char is signed on common ABIs, so go through unsigned char.
          if (std::isspace(static_cast<unsigned char>(*it))) {
              if (it != prev)
                  tokens.push_back(std::string(prev, it));
              prev = it + 1;
          }
          ++it;
      }
      if (prev != expr.end())
          tokens.push_back(std::string(prev, expr.end()));

      return tokens;
  }
  223.  
  224. ExprOp decodeToken(const std::string &token)
  225. {
  226. static const std::unordered_map<std::string, ExprOp> simple{
  227. { "+", { ExprOpType::ADD } },
  228. { "-", { ExprOpType::SUB } },
  229. { "*", { ExprOpType::MUL } },
  230. { "/", { ExprOpType::DIV } } ,
  231. { "sqrt", { ExprOpType::SQRT } },
  232. { "abs", { ExprOpType::ABS } },
  233. { "max", { ExprOpType::MAX } },
  234. { "min", { ExprOpType::MIN } },
  235. { "<", { ExprOpType::CMP, static_cast<int>(ComparisonType::LT) } },
  236. { ">", { ExprOpType::CMP, static_cast<int>(ComparisonType::NLE) } },
  237. { "=", { ExprOpType::CMP, static_cast<int>(ComparisonType::EQ) } },
  238. { ">=", { ExprOpType::CMP, static_cast<int>(ComparisonType::NLT) } },
  239. { "<=", { ExprOpType::CMP, static_cast<int>(ComparisonType::LE) } },
  240. { "and", { ExprOpType::AND } },
  241. { "or", { ExprOpType::OR } },
  242. { "xor", { ExprOpType::XOR } },
  243. { "not", { ExprOpType::NOT } },
  244. { "?", { ExprOpType::TERNARY } },
  245. { "exp", { ExprOpType::EXP } },
  246. { "log", { ExprOpType::LOG } },
  247. { "pow", { ExprOpType::POW } },
  248. { "dup", { ExprOpType::DUP, 0 } },
  249. { "swap", { ExprOpType::SWAP, 1 } },
  250. };
  251.  
  252. auto it = simple.find(token);
  253. if (it != simple.end()) {
  254. return it->second;
  255. } else if (token.size() == 1 && token[0] >= 'a' && token[0] <= 'z') {
  256. return{ ExprOpType::MEM_LOAD_U8, token[0] >= 'x' ? token[0] - 'x' : token[0] - 'a' + 3 };
  257. } else if (token.substr(0, 3) == "dup" || token.substr(0, 4) == "swap") {
  258. size_t count;
  259. int idx = -1;
  260.  
  261. try {
  262. idx = std::stoi(token.substr(token[0] == 'd' ? 3 : 4), &count);
  263. } catch (...) {
  264. // ...
  265. }
  266.  
  267. if (idx < 0)
  268. throw std::runtime_error("illegal token: " + token);
  269. return{ token[0] == 'd' ? ExprOpType::DUP : ExprOpType::SWAP, idx };
  270. } else {
  271. float f;
  272. std::string s;
  273. std::istringstream numStream(token);
  274. numStream.imbue(std::locale::classic());
  275. if (!(numStream >> f))
  276. throw std::runtime_error("failed to convert '" + token + "' to float");
  277. if (numStream >> s)
  278. throw std::runtime_error("failed to convert '" + token + "' to float, not the whole token could be converted");
  279. return{ ExprOpType::CONSTANT, f };
  280. }
  281. }
  282.  
  283. ExpressionTree parseExpr(const std::string &expr, const VSVideoInfo * const *vi, int numInputs)
  284. {
  285. constexpr unsigned char numOperands[] = {
  286. 0, // MEM_LOAD_U8
  287. 0, // MEM_LOAD_U16
  288. 0, // MEM_LOAD_F16
  289. 0, // MEM_LOAD_F32
  290. 0, // CONSTANT
  291. 0, // MEM_STORE_U8
  292. 0, // MEM_STORE_U16
  293. 0, // MEM_STORE_F16
  294. 0, // MEM_STORE_F32
  295. 2, // ADD
  296. 2, // SUB
  297. 2, // MUL
  298. 2, // DIV
  299. 3, // FMA
  300. 1, // SQRT
  301. 1, // ABS
  302. 1, // NEG
  303. 2, // MAX
  304. 2, // MIN
  305. 2, // CMP
  306. 2, // AND
  307. 2, // OR
  308. 2, // XOR
  309. 2, // NOT
  310. 1, // EXP
  311. 1, // LOG
  312. 2, // POW
  313. 3, // TERNARY
  314. 0, // MUX
  315. 0, // DUP
  316. 0, // SWAP
  317. };
  318. static_assert(sizeof(numOperands) == static_cast<unsigned>(ExprOpType::SWAP) + 1, "invalid table");
  319.  
  320. auto tokens = tokenize(expr);
  321.  
  322. ExpressionTree tree;
  323. std::vector<ExpressionTreeNode *> stack;
  324.  
  325. for (const std::string &tok : tokens) {
  326. ExprOp op = decodeToken(tok);
  327.  
  328. // Check validity.
  329. if (op.type == ExprOpType::MEM_LOAD_U8 && op.imm.i >= numInputs)
  330. throw std::runtime_error("reference to undefined clip: " + tok);
  331. if ((op.type == ExprOpType::DUP || op.type == ExprOpType::SWAP) && op.imm.u >= stack.size())
  332. throw std::runtime_error("insufficient values on stack: " + tok);
  333. if (stack.size() < numOperands[static_cast<size_t>(op.type)])
  334. throw std::runtime_error("insufficient values on stack: " + tok);
  335.  
  336. // Rename load operations with the correct data type.
  337. if (op.type == ExprOpType::MEM_LOAD_U8) {
  338. const VSFormat *format = vi[op.imm.i]->format;
  339.  
  340. if (format->sampleType == stInteger && format->bytesPerSample == 1)
  341. op.type = ExprOpType::MEM_LOAD_U8;
  342. else if (format->sampleType == stInteger && format->bytesPerSample == 2)
  343. op.type = ExprOpType::MEM_LOAD_U16;
  344. else if (format->sampleType == stFloat && format->bytesPerSample == 2)
  345. op.type = ExprOpType::MEM_LOAD_F16;
  346. else if (format->sampleType == stFloat && format->bytesPerSample == 4)
  347. op.type = ExprOpType::MEM_LOAD_F32;
  348. }
  349.  
  350. // Apply DUP and SWAP in the frontend.
  351. if (op.type == ExprOpType::DUP) {
  352. stack.push_back(tree.clone(stack[stack.size() - 1 - op.imm.u]));
  353. } else if (op.type == ExprOpType::SWAP) {
  354. std::swap(stack.back(), stack[stack.size() - 1 - op.imm.u]);
  355. } else {
  356. size_t operands = numOperands[static_cast<size_t>(op.type)];
  357.  
  358. if (operands == 0) {
  359. stack.push_back(tree.makeNode(op));
  360. } else if (operands == 1) {
  361. ExpressionTreeNode *child = stack.back();
  362. stack.pop_back();
  363.  
  364. ExpressionTreeNode *node = tree.makeNode(op);
  365. node->left = child;
  366. node->left->parent = node;
  367. stack.push_back(node);
  368. } else if (operands == 2) {
  369. ExpressionTreeNode *left = stack[stack.size() - 2];
  370. ExpressionTreeNode *right = stack[stack.size() - 1];
  371. stack.resize(stack.size() - 2);
  372.  
  373. ExpressionTreeNode *node = tree.makeNode(op);
  374. node->left = left;
  375. node->left->parent = node;
  376. node->right = right;
  377. node->right->parent = node;
  378. stack.push_back(node);
  379. } else if (operands == 3) {
  380. ExpressionTreeNode *arg1 = stack[stack.size() - 3];
  381. ExpressionTreeNode *arg2 = stack[stack.size() - 2];
  382. ExpressionTreeNode *arg3 = stack[stack.size() - 1];
  383. stack.resize(stack.size() - 3);
  384.  
  385. ExpressionTreeNode *mux = tree.makeNode(ExprOpType::MUX);
  386. mux->left = arg2;
  387. mux->left->parent = mux;
  388. mux->right = arg3;
  389. mux->right->parent = mux;
  390.  
  391. ExpressionTreeNode *node= tree.makeNode(op);
  392. node->left = arg1;
  393. node->left->parent = node;
  394. node->right = mux;
  395. node->right->parent = node;
  396. stack.push_back(node);
  397. }
  398. }
  399. }
  400.  
  401. if (stack.empty())
  402. throw std::runtime_error("empty expression: " + expr);
  403. if (stack.size() > 1)
  404. throw std::runtime_error("unconsumed values on stack: " + expr);
  405.  
  406. tree.setRoot(stack.back());
  407. return tree;
  408. }
  409.  
  410. bool equalSubTree(const ExpressionTreeNode *lhs, const ExpressionTreeNode *rhs)
  411. {
  412. if (lhs->valueNum >= 0 && rhs->valueNum >= 0)
  413. return lhs->valueNum == rhs->valueNum;
  414. if (lhs->op.type != rhs->op.type || lhs->op.imm.u != rhs->op.imm.u)
  415. return false;
  416. if (!!lhs->left != !!rhs->left || !!lhs->right != !!rhs->right)
  417. return false;
  418. if (lhs->left && !equalSubTree(lhs->left, rhs->left))
  419. return false;
  420. if (lhs->right && !equalSubTree(lhs->right, rhs->right))
  421. return false;
  422. return true;
  423. }
  424.  
  425. bool isConstantExpr(const ExpressionTreeNode &node)
  426. {
  427. switch (node.op.type) {
  428. case ExprOpType::MEM_LOAD_U8:
  429. case ExprOpType::MEM_LOAD_U16:
  430. case ExprOpType::MEM_LOAD_F16:
  431. case ExprOpType::MEM_LOAD_F32:
  432. return false;
  433. case ExprOpType::CONSTANT:
  434. return true;
  435. default:
  436. return (!node.left || isConstantExpr(*node.left)) && (!node.right || isConstantExpr(*node.right));
  437. }
  438. }
  439.  
  440. bool isConstant(const ExpressionTreeNode &node)
  441. {
  442. return node.op.type == ExprOpType::CONSTANT;
  443. }
  444.  
  445. bool isConstant(const ExpressionTreeNode &node, float val)
  446. {
  447. return node.op.type == ExprOpType::CONSTANT && node.op.imm.f == val;
  448. }
  449.  
  450. float evalConstantExpr(const ExpressionTreeNode &node)
  451. {
  452. switch (node.op.type) {
  453. case ExprOpType::CONSTANT: return node.op.imm.f;
  454. case ExprOpType::ADD: return evalConstantExpr(*node.left) + evalConstantExpr(*node.right);
  455. case ExprOpType::SUB: return evalConstantExpr(*node.left) - evalConstantExpr(*node.right);
  456. case ExprOpType::MUL: return evalConstantExpr(*node.left) * evalConstantExpr(*node.right);
  457. case ExprOpType::DIV: return evalConstantExpr(*node.left) / evalConstantExpr(*node.right);
  458. case ExprOpType::FMA:
  459. switch (static_cast<FMAType>(node.op.imm.u)) {
  460. case FMAType::FMADD: return evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) + evalConstantExpr(*node.left);
  461. case FMAType::FMSUB: return evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) - evalConstantExpr(*node.left);
  462. case FMAType::FNMADD: return -evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) + evalConstantExpr(*node.left);
  463. case FMAType::FNMSUB: return -evalConstantExpr(*node.right->left) * evalConstantExpr(*node.right->right) - evalConstantExpr(*node.left);
  464. }
  465. return NAN;
  466. case ExprOpType::SQRT: return std::sqrt(evalConstantExpr(*node.left));
  467. case ExprOpType::ABS: return std::fabs(evalConstantExpr(*node.left));
  468. case ExprOpType::NEG: return -evalConstantExpr(*node.left);
  469. case ExprOpType::MAX: return std::max(evalConstantExpr(*node.left), evalConstantExpr(*node.right));
  470. case ExprOpType::MIN: return std::min(evalConstantExpr(*node.left), evalConstantExpr(*node.right));
  471. case ExprOpType::CMP:
  472. switch (static_cast<ComparisonType>(node.op.imm.u)) {
  473. case ComparisonType::EQ: return evalConstantExpr(*node.left) == evalConstantExpr(*node.right) ? 1.0f : 0.0f;
  474. case ComparisonType::LT: return evalConstantExpr(*node.left) < evalConstantExpr(*node.right) ? 1.0f : 0.0f;
  475. case ComparisonType::LE: return evalConstantExpr(*node.left) <= evalConstantExpr(*node.right) ? 1.0f : 0.0f;
  476. case ComparisonType::NEQ: return evalConstantExpr(*node.left) != evalConstantExpr(*node.right) ? 1.0f : 0.0f;
  477. case ComparisonType::NLT: return evalConstantExpr(*node.left) >= evalConstantExpr(*node.right) ? 1.0f : 0.0f;
  478. case ComparisonType::NLE: return evalConstantExpr(*node.left) > evalConstantExpr(*node.right) ? 1.0f : 0.0f;
  479. }
  480. return NAN;
  481. case ExprOpType::AND: return evalConstantExpr(*node.left) > 0.0f && evalConstantExpr(*node.right) > 0.0f ? 1.0f : 0.0f;
  482. case ExprOpType::OR: return evalConstantExpr(*node.left) > 0.0f || evalConstantExpr(*node.right) > 0.0f ? 1.0f : 0.0f;
  483. case ExprOpType::XOR: return evalConstantExpr(*node.left) > 0.0f != evalConstantExpr(*node.right) > 0.0f ? 1.0f : 0.0f;
  484. case ExprOpType::NOT: return evalConstantExpr(*node.left) > 0.0f ? 0.0f : 1.0f;
  485. case ExprOpType::EXP: return std::exp(evalConstantExpr(*node.left));
  486. case ExprOpType::LOG: return std::log(evalConstantExpr(*node.left));
  487. case ExprOpType::POW: return std::pow(evalConstantExpr(*node.left), evalConstantExpr(*node.right));
  488. case ExprOpType::TERNARY: return evalConstantExpr(*node.left) > 0.0f ? evalConstantExpr(*node.right->left) : evalConstantExpr(*node.right->right);
  489. default: return NAN;
  490. }
  491. }
  492.  
  493. bool isOpCode(const ExpressionTreeNode &node, std::initializer_list<ExprOpType> types)
  494. {
  495. for (ExprOpType type : types) {
  496. if (node.op.type == type)
  497. return true;
  498. }
  499. return false;
  500. }
  501.  
  // True when x has no fractional part. Also true for +/-infinity and false for NaN.
  bool isInteger(float x)
  {
      const float wholePart = std::floor(x);
      return x == wholePart;
  }
  506.  
  507. void replaceNode(ExpressionTreeNode &node, const ExpressionTreeNode &replacement)
  508. {
  509. node.op = replacement.op;
  510. if (node.left)
  511. node.left->parent = nullptr;
  512. if (node.right)
  513. node.right->parent = nullptr;
  514. node.left = replacement.left;
  515. node.right = replacement.right;
  516. if (node.left)
  517. node.left->parent = &node;
  518. if (node.right)
  519. node.right->parent = &node;
  520. }
  521.  
  522. void applyValueNumbering(ExpressionTree &tree)
  523. {
  524. std::vector<ExpressionTreeNode *> numbered;
  525. int valueNum = 0;
  526.  
  527. tree.getRoot()->postorder([&](ExpressionTreeNode &node)
  528. {
  529. node.valueNum = -1;
  530. });
  531.  
  532. tree.getRoot()->postorder([&](ExpressionTreeNode &node)
  533. {
  534. if (node.op.type == ExprOpType::MUX)
  535. return;
  536.  
  537. for (ExpressionTreeNode *testnode : numbered) {
  538. if (equalSubTree(&node, testnode)) {
  539. node.valueNum = testnode->valueNum;
  540. return;
  541. }
  542. }
  543.  
  544. node.valueNum = valueNum++;
  545. numbered.push_back(&node);
  546. });
  547. }
  548.  
  549. ExpressionTreeNode *integerPower(ExpressionTree &tree, const ExpressionTreeNode &node, int exponent)
  550. {
  551. if (exponent == 1)
  552. return tree.clone(&node);
  553.  
  554. ExpressionTreeNode *lhs = integerPower(tree, node, (exponent + 1) / 2);
  555. ExpressionTreeNode *rhs = integerPower(tree, node, exponent - (exponent + 1) / 2);
  556. ExpressionTreeNode *mulNode = tree.makeNode({ ExprOpType::MUL });
  557. mulNode->left = lhs;
  558. mulNode->right = rhs;
  559. mulNode->left->parent = mulNode;
  560. mulNode->right->parent = mulNode;
  561. return mulNode;
  562. }
  563.  
  564. bool applyLocalOptimizations(ExpressionTree &tree)
  565. {
  566. bool changed = false;
  567.  
  568. tree.getRoot()->postorder([&](ExpressionTreeNode &node)
  569. {
  570. if (node.op.type == ExprOpType::MUX)
  571. return;
  572.  
  573. // Constant folding.
  574. if (node.op.type != ExprOpType::CONSTANT && isConstantExpr(node)) {
  575. float val = evalConstantExpr(node);
  576. replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, val } });
  577. changed = true;
  578. }
  579.  
  580. // Move constants to right-hand side to simplify identities.
  581. if (isOpCode(node, { ExprOpType::ADD, ExprOpType::MUL }) && isConstant(*node.left) && !isConstant(*node.right)) {
  582. std::swap(node.left, node.right);
  583. changed = true;
  584. }
  585.  
  586. // x + 0 = x x - 0 = x
  587. if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && isConstant(*node.right, 0.0f)) {
  588. replaceNode(node, *node.left);
  589. changed = true;
  590. }
  591.  
  592. // x * 0 = 0 0 / x = 0
  593. if ((node.op == ExprOpType::MUL && isConstant(*node.right, 0.0f)) || (node.op == ExprOpType::DIV && isConstant(*node.left, 0.0f))) {
  594. replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
  595. changed = true;
  596. }
  597.  
  598. // x * 1 = x x / 1 = x
  599. if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, 1.0f)) {
  600. replaceNode(node, *node.left);
  601. changed = true;
  602. }
  603.  
  604. // sqrt(x) = x ** 0.5
  605. if (node.op == ExprOpType::SQRT) {
  606. node.op = ExprOpType::POW;
  607. node.right = tree.makeNode({ ExprOpType::CONSTANT, 0.5f });
  608. node.right->parent = &node;
  609. changed = true;
  610. }
  611.  
  612. // log(exp(x)) = x exp(log(x)) = x
  613. if ((node.op == ExprOpType::LOG && node.left->op == ExprOpType::EXP) || (node.op == ExprOpType::EXP && node.left->op == ExprOpType::LOG)) {
  614. replaceNode(node, *node.left->left);
  615. changed = true;
  616. }
  617.  
  618. // x ** 0 = 1
  619. if (node.op == ExprOpType::POW && isConstant(*node.right, 0.0f)) {
  620. replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
  621. changed = true;
  622. }
  623.  
  624. // x ** 1 = x
  625. if (node.op == ExprOpType::POW && isConstant(*node.right, 1.0f)) {
  626. replaceNode(node, *node.left);
  627. changed = true;
  628. }
  629.  
  630. // (a ** b) ** c = a ** (b * c)
  631. if (node.op == ExprOpType::POW && node.left->op == ExprOpType::POW) {
  632. ExpressionTreeNode *a = node.left->left;
  633. ExpressionTreeNode *b = node.left->right;
  634. ExpressionTreeNode *c = node.right;
  635. replaceNode(*node.left, *a);
  636. node.right = tree.makeNode(ExprOpType::MUL);
  637. node.right->left = b;
  638. node.right->right = c;
  639. node.right->left->parent = node.right;
  640. node.right->right->parent = node.right;
  641. changed = true;
  642. }
  643.  
  644. // 0 ? x y = y 1 ? x y = x
  645. if (node.op == ExprOpType::TERNARY && isConstant(*node.left)) {
  646. ExpressionTreeNode *replacement = node.left->op.imm.f > 0.0f ? node.right->left : node.right->right;
  647. replaceNode(node, *replacement);
  648. changed = true;
  649. }
  650. });
  651.  
  652. return changed;
  653. }
  654.  
  // Exponent of each factor of a product, keyed by the factor's value number.
  using exponentMap = std::map<int, float>;
  // A sum of products: each entry is (factor exponents, scalar coefficient).
  using additiveTermList = std::vector<std::pair<exponentMap, float>>;
  657.  
  658. bool isEqualTerm(const exponentMap &lhs, const exponentMap &rhs)
  659. {
  660. auto it1 = lhs.begin();
  661. auto it2 = rhs.begin();
  662.  
  663. while (it1 != lhs.end() && it2 != rhs.end()) {
  664. if (it1->first != it2->first || it1->second != it2->second)
  665. return false;
  666.  
  667. ++it1;
  668. ++it2;
  669. }
  670.  
  671. return it1 == lhs.end() && it2 == rhs.end();
  672. }
  673.  
  674. bool sortTerms(additiveTermList &list, const std::unordered_map<int, const ExpressionTreeNode *> &values)
  675. {
  676. auto pred = [&](const std::pair<exponentMap, float> &lhs, const std::pair<exponentMap, float> &rhs)
  677. {
  678. std::vector<std::pair<int, float>> lhsTerms(lhs.first.begin(), lhs.first.end());
  679. std::vector<std::pair<int, float>> rhsTerms(rhs.first.begin(), rhs.first.end());
  680.  
  681. auto pred2 = [&](const std::pair<int, float> &lhs2, const std::pair<int, float> &rhs2)
  682. {
  683. std::initializer_list<ExprOpType> memOpCodes = { ExprOpType::MEM_LOAD_U8, ExprOpType::MEM_LOAD_U16, ExprOpType::MEM_LOAD_F16, ExprOpType::MEM_LOAD_F32 };
  684.  
  685. if (lhs2.first == rhs2.first)
  686. return lhs2.second < rhs2.second;
  687.  
  688. const ExpressionTreeNode *lhsValue = values.at(lhs2.first);
  689. const ExpressionTreeNode *rhsValue = values.at(rhs2.first);
  690. int lhsCategory = isConstant(*lhsValue) ? 0 : isOpCode(*lhsValue, memOpCodes) ? 1 : 2;
  691. int rhsCategory = isConstant(*rhsValue) ? 0 : isOpCode(*rhsValue, memOpCodes) ? 1 : 2;
  692.  
  693. // Simpler terms towards the right.
  694. if (lhsCategory != rhsCategory)
  695. return lhsCategory > rhsCategory;
  696.  
  697. if (lhsCategory == 0)
  698. return lhsValue->op.imm.f < rhsValue->op.imm.f;
  699. else if (lhsCategory == 1)
  700. return lhsValue->op.imm.u < rhsValue->op.imm.u;
  701. else
  702. return lhs2.first < rhs2.first;
  703. };
  704.  
  705. std::sort(lhsTerms.begin(), lhsTerms.end(), pred2);
  706. std::sort(rhsTerms.begin(), rhsTerms.end(), pred2);
  707. return std::lexicographical_compare(lhsTerms.begin(), lhsTerms.end(), rhsTerms.begin(), rhsTerms.end(), pred2);
  708. };
  709.  
  710. if (std::is_sorted(list.begin(), list.end(), pred))
  711. return true;
  712.  
  713. std::sort(list.begin(), list.end(), pred);
  714. return false;
  715. }
  716.  
  // Flatten a multiplicative term in place until no factor is itself a product.
  //
  // A factor v with exponent e is replaced, repeatedly until nothing changes, by:
  //   POW(x, const k) -> x with exponent e * k
  //   MUL(x, y)       -> x and y, each with exponent e
  //   DIV(x, y)       -> x with exponent e, y with exponent -e
  // Any operand that becomes a factor is recorded in `values` (value number -> node).
  // Exponents that cancel to 0 (e.g. x / x) are left in `term`; this function
  // does not remove them.
  //
  // Preconditions (not checked): every key of `term` is present in `values`
  // (values.at throws std::out_of_range otherwise), and the operand nodes carry
  // assigned value numbers (presumably via applyValueNumbering — confirm).
  void expandMultiplies(exponentMap &term, std::unordered_map<int, const ExpressionTreeNode *> &values)
  {
      bool changed = true;

      while (changed) {
          changed = false;

          // std::map insertions do not invalidate `it`, so new factors can be
          // added while iterating. Keys inserted behind `it` are picked up by
          // the next pass of the while loop. The float accumulation order
          // depends on this exact loop structure.
          for (auto it = term.begin(); it != term.end();) {
              const ExpressionTreeNode *value = values.at(it->first);

              if (value->op == ExprOpType::POW && isConstant(*value->right)) {
                  values[value->left->valueNum] = value->left;

                  term[value->left->valueNum] += it->second * value->right->op.imm.f;
                  it = term.erase(it);
                  changed = true;
                  continue;
              } else if (value->op == ExprOpType::MUL) {
                  values[value->left->valueNum] = value->left;
                  values[value->right->valueNum] = value->right;

                  term[value->left->valueNum] += it->second;
                  term[value->right->valueNum] += it->second;
                  it = term.erase(it);
                  changed = true;
                  continue;
              } else if (value->op == ExprOpType::DIV) {
                  values[value->left->valueNum] = value->left;
                  values[value->right->valueNum] = value->right;

                  // Division contributes the divisor with a negated exponent.
                  term[value->left->valueNum] += it->second;
                  term[value->right->valueNum] -= it->second;
                  it = term.erase(it);
                  changed = true;
                  continue;
              }

              ++it;
          }
      }
  }
  758.  
  759. std::pair<float, size_t> addConstants(additiveTermList &terms, const std::unordered_map<int, const ExpressionTreeNode *> &values)
  760. {
  761. float scalarTerm = 0.0f;
  762. size_t numScalarEliminated = 0;
  763. bool nonTerminalScalar = false;
  764.  
  765. for (auto it1 = terms.begin(); it1 < terms.end();) {
  766. for (auto it2 = it1->first.begin(); it2 != it1->first.end(); ++it2) {
  767. const ExpressionTreeNode *value = values.at(it2->first);
  768.  
  769. if (isConstant(*value)) {
  770. it1->second *= std::pow(value->op.imm.f, it2->second);
  771. it2->second = 0.0f;
  772. }
  773. }
  774.  
  775. for (auto it2 = it1->first.begin(); it2 != it1->first.end();) {
  776. if (it2->second == 0.0f) {
  777. it2 = it1->first.erase(it2);
  778. continue;
  779. }
  780. ++it2;
  781. }
  782.  
  783. if (it1->first.empty()) {
  784. scalarTerm += it1->second;
  785. it1->second = 0.0f;
  786.  
  787. nonTerminalScalar = nonTerminalScalar || (it1 + 1 != terms.end());
  788.  
  789. it1 = terms.erase(it1);
  790. ++numScalarEliminated;
  791. continue;
  792. }
  793.  
  794. ++it1;
  795. }
  796.  
  797. return{ scalarTerm, numScalarEliminated + nonTerminalScalar };
  798. }
  799.  
  800. size_t addIdenticalTerms(additiveTermList &terms)
  801. {
  802. size_t numCanceled = 0;
  803.  
  804. for (auto it1 = terms.begin(); it1 < terms.end();) {
  805. for (auto it2 = it1 + 1; it2 < terms.end(); ++it2) {
  806. if (isEqualTerm(it1->first, it2->first)) {
  807. it1->second += it2->second;
  808. it2->second = 0.0f;
  809. }
  810. }
  811.  
  812. if (it1->second == 0.0f) {
  813. it1 = terms.erase(it1);
  814. ++numCanceled;
  815. continue;
  816. }
  817.  
  818. ++it1;
  819. }
  820.  
  821. return numCanceled;
  822. }
  823.  
  824. ExpressionTreeNode *emitMultiplicativeSequence(ExpressionTree &tree, const exponentMap &terms, float scalarTerm, const std::unordered_map<int, const ExpressionTreeNode *> &values)
  825. {
  826. ExpressionTreeNode *node = nullptr;
  827.  
  828. for (auto &t : terms) {
  829. ExpressionTreeNode *powNode = tree.makeNode(ExprOpType::POW);
  830. powNode->left = tree.clone(values.at(t.first));
  831. powNode->right = tree.makeNode({ ExprOpType::CONSTANT, t.second });
  832. powNode->left->parent = powNode;
  833. powNode->right->parent = powNode;
  834.  
  835. if (node) {
  836. ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
  837. mulNode->left = node;
  838. mulNode->right = powNode;
  839. mulNode->left->parent = mulNode;
  840. mulNode->right->parent = mulNode;
  841. node = mulNode;
  842. } else {
  843. node = powNode;
  844. }
  845. }
  846.  
  847. if (node) {
  848. ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
  849. mulNode->left = node;
  850. mulNode->right = tree.makeNode({ ExprOpType::CONSTANT, scalarTerm });
  851. mulNode->left->parent = mulNode;
  852. mulNode->right->parent = mulNode;
  853. node = mulNode;
  854. } else {
  855. node = tree.makeNode({ ExprOpType::CONSTANT, scalarTerm });
  856. }
  857.  
  858. return node;
  859. }
  860.  
  861. ExpressionTreeNode *emitAdditiveSequence(ExpressionTree &tree, const additiveTermList &terms, float scalarTerm, const std::unordered_map<int, const ExpressionTreeNode *> &values)
  862. {
  863. ExpressionTreeNode *head = nullptr;
  864.  
  865. for (const auto &term : terms) {
  866. assert(!term.first.empty());
  867.  
  868. ExpressionTreeNode *node = emitMultiplicativeSequence(tree, term.first, term.second, values);
  869.  
  870. if (head) {
  871. ExpressionTreeNode *addNode = tree.makeNode(term.second < 0 ? ExprOpType::SUB : ExprOpType::ADD);
  872. addNode->left = head;
  873. addNode->right = node;
  874. addNode->left->parent = addNode;
  875. addNode->right->parent = addNode;
  876. head = addNode;
  877. } else {
  878. head = node;
  879. }
  880. }
  881.  
  882. if (head) {
  883. ExpressionTreeNode *addNode = tree.makeNode(scalarTerm < 0 ? ExprOpType::SUB : ExprOpType::ADD);
  884. addNode->left = head;
  885. addNode->right = tree.makeNode({ ExprOpType::CONSTANT, std::fabs(scalarTerm) });
  886. addNode->left->parent = addNode;
  887. addNode->right->parent = addNode;
  888. head = addNode;
  889. } else {
  890. head = tree.makeNode({ ExprOpType::CONSTANT, 0.0f });
  891. }
  892.  
  893. return head;
  894. }
  895.  
  896. bool analyzeAdditiveExpression(ExpressionTree &tree, ExpressionTreeNode &node)
  897. {
  898. // Stores the exponent of each term in a multiplicative tuple.
  899. // e.g. 3 * v0^1.5 * v1^2 * v2 => { 0: 1.5, 1: 2, 2: 1 }.
  900. additiveTermList terms;
  901. std::unordered_map<int, const ExpressionTreeNode *> values;
  902.  
  903. node.preorder([&](ExpressionTreeNode &node)
  904. {
  905. if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }))
  906. return false;
  907.  
  908. // Deduce net sign of term.
  909. const ExpressionTreeNode *parent = node.parent;
  910. const ExpressionTreeNode *cur = &node;
  911. int polarity = 1;
  912.  
  913. while (parent && isOpCode(*parent, { ExprOpType::ADD, ExprOpType::SUB })) {
  914. if (parent->op == ExprOpType::SUB && cur == parent->right)
  915. polarity = -polarity;
  916.  
  917. cur = parent;
  918. parent = parent->parent;
  919. }
  920.  
  921. exponentMap term{ { node.valueNum, 1.0f } };
  922. terms.emplace_back(std::move(term), static_cast<float>(polarity));
  923.  
  924. values[node.valueNum] = &node;
  925. return true;
  926. });
  927.  
  928. for (auto &term : terms) {
  929. expandMultiplies(term.first, values);
  930. }
  931.  
  932. // Combine constant terms.
  933. float scalarTerm = 0.0f;
  934. size_t numScalarEliminated = 0;
  935. {
  936. auto result = addConstants(terms, values);
  937. scalarTerm += result.first;
  938. numScalarEliminated += result.second;
  939. }
  940.  
  941. // Cancel identical terms.
  942. size_t numCanceled = addIdenticalTerms(terms);
  943.  
  944. // Normalize order of terms to assist multiplicative analysis.
  945. bool wasSorted = sortTerms(terms, values);
  946.  
  947. if (numCanceled > 0 || numScalarEliminated > 1 || !wasSorted) {
  948. ExpressionTreeNode *seq = emitAdditiveSequence(tree, terms, scalarTerm, values);
  949. replaceNode(node, *seq);
  950. return true;
  951. }
  952.  
  953. return false;
  954. }
  955.  
  956. bool analyzeMultiplicativeExpression(ExpressionTree &tree, ExpressionTreeNode &node)
  957. {
  958. std::vector<int> termOrder;
  959. exponentMap term;
  960. std::unordered_map<int, const ExpressionTreeNode *> values;
  961. float scalarTerm = 1.0f;
  962. size_t numDivs = 0;
  963.  
  964. node.preorder([&](ExpressionTreeNode &node)
  965. {
  966. if (node.op == ExprOpType::DIV)
  967. ++numDivs;
  968.  
  969. if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }))
  970. return false;
  971.  
  972. // Deduce net sign of term.
  973. const ExpressionTreeNode *parent = node.parent;
  974. const ExpressionTreeNode *cur = &node;
  975. int polarity = 1;
  976.  
  977. while (parent && isOpCode(*parent, { ExprOpType::MUL, ExprOpType::DIV })) {
  978. if (parent->op == ExprOpType::DIV && cur == parent->right)
  979. polarity = -polarity;
  980.  
  981. cur = parent;
  982. parent = parent->parent;
  983. }
  984.  
  985. term[node.valueNum] += static_cast<float>(polarity);
  986. termOrder.push_back(node.valueNum);
  987. values[node.valueNum] = &node;
  988. return true;
  989. });
  990.  
  991. expandMultiplies(term, values);
  992.  
  993. // Combine constants.
  994. for (auto it = term.begin(); it != term.end();) {
  995. const ExpressionTreeNode *node = values[it->first];
  996.  
  997. if (isConstant(*node)) {
  998. scalarTerm *= std::powf(node->op.imm.f, it->second);
  999. it = term.erase(it);
  1000. continue;
  1001. }
  1002.  
  1003. ++it;
  1004. }
  1005.  
  1006. size_t origScalarTerms = 0;
  1007. bool nonTerminalScalar = false;
  1008. for (auto it = termOrder.begin(); it != termOrder.end();) {
  1009. if (isConstant(*values[*it])) {
  1010. nonTerminalScalar = nonTerminalScalar || it + 1 != termOrder.end();
  1011. it = termOrder.erase(it);
  1012. ++origScalarTerms;
  1013. continue;
  1014. }
  1015.  
  1016. ++it;
  1017. }
  1018.  
  1019. if (term.size() + (scalarTerm != 1.0f) < termOrder.size() + origScalarTerms || !std::is_sorted(termOrder.begin(), termOrder.end()) || nonTerminalScalar || numDivs) {
  1020. ExpressionTreeNode *seq = emitMultiplicativeSequence(tree, term, scalarTerm, values);
  1021. replaceNode(node, *seq);
  1022. return true;
  1023. }
  1024.  
  1025. return false;
  1026. }
  1027.  
  1028. bool applyAlgebraicOptimizations(ExpressionTree &tree)
  1029. {
  1030. bool changed = false;
  1031.  
  1032. applyValueNumbering(tree);
  1033.  
  1034. tree.getRoot()->preorder([&](ExpressionTreeNode &node)
  1035. {
  1036. if (node.op.type == ExprOpType::CMP && node.left->valueNum == node.right->valueNum) {
  1037. ComparisonType type = static_cast<ComparisonType>(node.op.imm.u);
  1038. if (type == ComparisonType::EQ || type == ComparisonType::LE || type == ComparisonType::NLT)
  1039. replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
  1040. else
  1041. replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
  1042.  
  1043. changed = true;
  1044. return changed;
  1045. }
  1046.  
  1047. if (node.op.type == ExprOpType::TERNARY && node.right->left->valueNum == node.right->right->valueNum) {
  1048. replaceNode(node, *node.right->left);
  1049. changed = true;
  1050. return changed;
  1051. }
  1052.  
  1053. if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::ADD, ExprOpType::SUB }))) {
  1054. changed = changed || analyzeAdditiveExpression(tree, node);
  1055. return changed;
  1056. }
  1057.  
  1058. if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::MUL, ExprOpType::DIV }))) {
  1059. changed = changed || analyzeMultiplicativeExpression(tree, node);
  1060. return changed;
  1061. }
  1062.  
  1063. return false;
  1064. });
  1065.  
  1066. return changed;
  1067. }
  1068.  
  1069. bool applyStrengthReduction(ExpressionTree &tree)
  1070. {
  1071. bool changed = false;
  1072.  
  1073. tree.getRoot()->postorder([&](ExpressionTreeNode &node)
  1074. {
  1075. if (node.op == ExprOpType::MUX)
  1076. return;
  1077.  
  1078. // 0 - x = -x
  1079. if (node.op == ExprOpType::SUB && isConstant(*node.left, 0.0f)) {
  1080. ExpressionTreeNode *tmp = node.right;
  1081. replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
  1082. node.left = tmp;
  1083. node.left->parent = &node;
  1084. changed = true;
  1085. }
  1086.  
  1087. // x * -1 = -x x / -1 = -x
  1088. if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, -1.0f)) {
  1089. ExpressionTreeNode *tmp = node.left;
  1090. replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
  1091. node.left = tmp;
  1092. node.left->parent = &node;
  1093. changed = true;
  1094. }
  1095.  
  1096. // a + -b = a - b a - -b = a + b
  1097. if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && node.right->op.type == ExprOpType::NEG) {
  1098. node.op = node.op == ExprOpType::ADD ? ExprOpType::SUB : ExprOpType::ADD;
  1099. replaceNode(*node.right, *node.right->left);
  1100. changed = true;
  1101. }
  1102.  
  1103. // -a + b = b - a
  1104. if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::NEG) {
  1105. node.op = ExprOpType::SUB;
  1106. replaceNode(*node.left, *node.left->left);
  1107. std::swap(node.left, node.right);
  1108. }
  1109.  
  1110. // -(a - b) = b - a
  1111. if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::SUB) {
  1112. replaceNode(node, *node.left);
  1113. std::swap(node.left, node.right);
  1114. changed = true;
  1115. }
  1116.  
  1117. // x * 2 = x + x
  1118. if (node.op == ExprOpType::MUL && isConstant(*node.right, 2.0f) && (!node.parent || node.parent->op != ExprOpType::ADD)) {
  1119. ExpressionTreeNode *replacement = tree.clone(node.left);
  1120. node.op = ExprOpType::ADD;
  1121. replaceNode(*node.right, *replacement);
  1122. changed = true;
  1123. }
  1124.  
  1125. // x / y = x * (1 / y)
  1126. if (node.op == ExprOpType::DIV && isConstant(*node.right)) {
  1127. node.op = ExprOpType::MUL;
  1128. node.right->op.imm.f = 1.0f / node.right->op.imm.f;
  1129. changed = true;
  1130. }
  1131.  
  1132. // (1 / x) * y = y / x
  1133. if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV && isConstant(*node.left->left, 1.0f)) {
  1134. node.op = ExprOpType::DIV;
  1135. replaceNode(*node.left, *node.left->right);
  1136. std::swap(node.left, node.right);
  1137. changed = true;
  1138. }
  1139.  
  1140. // x * (1 / y) = x / y
  1141. if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV && isConstant(*node.right->left, 1.0f)) {
  1142. node.op = ExprOpType::DIV;
  1143. replaceNode(*node.right, *node.right->right);
  1144. changed = true;
  1145. }
  1146.  
  1147. // (a / b) * c = (a * c) / b
  1148. if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV) {
  1149. node.op = ExprOpType::DIV;
  1150. node.left->op = ExprOpType::MUL;
  1151. std::swap(node.left->right, node.right);
  1152. node.left->right->parent = node.left;
  1153. node.right->parent = &node;
  1154. changed = true;
  1155. }
  1156.  
  1157. // a * (b / c) = (a * b) / c
  1158. if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV) {
  1159. node.op = ExprOpType::DIV;
  1160. node.right->op = ExprOpType::MUL;
  1161. std::swap(node.left, node.right); // (b * c) / a
  1162. std::swap(node.left->left, node.left->right); // (c * b) / a
  1163. std::swap(node.left->left, node.right); // (a * b) / c
  1164. node.left->left->parent = node.left;
  1165. node.right->parent = &node;
  1166. changed = true;
  1167. }
  1168.  
  1169. // a / (b / c) = (a * c) / b
  1170. if (node.op == ExprOpType::DIV && node.right->op == ExprOpType::DIV) {
  1171. node.right->op = ExprOpType::MUL; // a / (b * c)
  1172. std::swap(node.left, node.right); // (b * c) / a
  1173. std::swap(node.left->left, node.right); // (a * c) / b
  1174. node.left->left->parent = node.left;
  1175. node.right->parent = &node;
  1176. changed = true;
  1177. }
  1178.  
  1179. // (a / b) / c = a / (b * c)
  1180. if (node.op == ExprOpType::DIV && node.left->op == ExprOpType::DIV) {
  1181. node.left->op = ExprOpType::MUL; // (a * b) / c
  1182. std::swap(node.left, node.right); // c / (a * b)
  1183. std::swap(node.left, node.right->left); // a / (c * b)
  1184. std::swap(node.right->left, node.right->right); // a / (b * c)
  1185. node.left->parent = &node;
  1186. node.right->left->parent = node.right;
  1187. node.right->right->parent = node.right;
  1188. changed = true;
  1189. }
  1190.  
  1191. // x ** (n / 2) = sqrt(x ** n)
  1192. if (node.op == ExprOpType::POW && isConstant(*node.right) && !isInteger(node.right->op.imm.f) && isInteger(node.right->op.imm.f * 2.0f)) {
  1193. ExpressionTreeNode *dup = tree.clone(&node);
  1194. replaceNode(node, ExpressionTreeNode{ ExprOpType::SQRT });
  1195. node.left = dup;
  1196. node.left->parent = &node;
  1197. node.left->right->op.imm.f *= 2.0f;
  1198. changed = true;
  1199. }
  1200.  
  1201. // x ** -N = 1 / (x ** N)
  1202. if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f < 0) {
  1203. ExpressionTreeNode *dup = tree.clone(&node);
  1204. replaceNode(node, ExpressionTreeNode{ ExprOpType::DIV });
  1205. node.left = tree.makeNode({ ExprOpType::CONSTANT, 1.0f });
  1206. node.right = dup;
  1207. node.left->parent = &node;
  1208. node.right->parent = &node;
  1209. node.right->right->op.imm.f = -node.right->right->op.imm.f;
  1210. changed = true;
  1211. }
  1212.  
  1213. // x ** N = x * x * x * ...
  1214. if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f > 0) {
  1215. ExpressionTreeNode *replacement = integerPower(tree, *node.left, static_cast<int>(node.right->op.imm.f));
  1216. replaceNode(node, *replacement);
  1217. changed = true;
  1218. }
  1219. });
  1220.  
  1221. return changed;
  1222. }
  1223.  
  1224. bool applyOpFusion(ExpressionTree &tree)
  1225. {
  1226. std::unordered_map<int, size_t> refCount;
  1227. bool changed = false;
  1228.  
  1229. applyValueNumbering(tree);
  1230.  
  1231. tree.getRoot()->postorder([&](ExpressionTreeNode &node)
  1232. {
  1233. if (node.op == ExprOpType::MUX)
  1234. return;
  1235.  
  1236. refCount[node.valueNum]++;
  1237. });
  1238.  
  1239. tree.getRoot()->postorder([&](ExpressionTreeNode &node)
  1240. {
  1241. if (node.op == ExprOpType::MUX)
  1242. return;
  1243.  
  1244. // FMA.
  1245. if (node.op == ExprOpType::ADD && node.right->op == ExprOpType::MUL && refCount[node.right->valueNum] <= 1) {
  1246. node.right->op = ExprOpType::MUX;
  1247. node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
  1248. changed = true;
  1249. }
  1250. if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::MUL && refCount[node.left->valueNum] <= 1) {
  1251. std::swap(node.left, node.right);
  1252. node.right->op = ExprOpType::MUX;
  1253. node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
  1254. changed = true;
  1255. }
  1256. if (node.op == ExprOpType::SUB && node.right->op == ExprOpType::MUL && refCount[node.right->valueNum] <= 1) {
  1257. node.right->op = ExprOpType::MUX;
  1258. node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FNMADD) };
  1259. changed = true;
  1260. }
  1261. if (node.op == ExprOpType::SUB && node.left->op == ExprOpType::MUL && refCount[node.left->valueNum] <= 1) {
  1262. std::swap(node.left, node.right);
  1263. node.right->op = ExprOpType::MUX;
  1264. node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMSUB) };
  1265. changed = true;
  1266. }
  1267. if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::FMA && refCount[node.left->valueNum] <= 1) {
  1268. ExpressionTreeNode *replacement = node.left;
  1269. node.op = replacement->op;
  1270. node.left = replacement->left;
  1271. node.right = replacement->right;
  1272.  
  1273. switch (static_cast<FMAType>(node.op.imm.u)) {
  1274. case FMAType::FMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FNMSUB); break;
  1275. case FMAType::FMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FNMADD); break;
  1276. case FMAType::FNMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FMSUB); break;
  1277. case FMAType::FNMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FMADD); break;
  1278. }
  1279.  
  1280. changed = true;
  1281. }
  1282. });
  1283.  
  1284. return changed;
  1285. }
  1286.  
  1287. std::vector<ExprInstruction> compile(ExpressionTree &tree, const VSFormat *format)
  1288. {
  1289. std::vector<ExprInstruction> code;
  1290. std::unordered_set<int> found;
  1291.  
  1292. if (!tree.getRoot())
  1293. return code;
  1294.  
  1295. while (applyLocalOptimizations(tree) || applyAlgebraicOptimizations(tree)) {
  1296. // ...
  1297. }
  1298.  
  1299. // Substitution rules can hide algebraic expressions from the optimizer, so they run in a separate pass.
  1300. while (applyStrengthReduction(tree) || applyOpFusion(tree)) {
  1301. // ...
  1302. }
  1303.  
  1304. applyValueNumbering(tree);
  1305.  
  1306. tree.getRoot()->postorder([&](ExpressionTreeNode &node)
  1307. {
  1308. if (node.op.type == ExprOpType::MUX)
  1309. return;
  1310. if (found.find(node.valueNum) != found.end())
  1311. return;
  1312.  
  1313. ExprInstruction opcode(node.op);
  1314. opcode.dst = node.valueNum;
  1315.  
  1316. if (node.left) {
  1317. assert(node.left->valueNum >= 0);
  1318. opcode.src1 = node.left->valueNum;
  1319. }
  1320. if (node.right) {
  1321. if (node.right->op.type == ExprOpType::MUX) {
  1322. assert(node.right->left->valueNum >= 0);
  1323. assert(node.right->right->valueNum >= 0);
  1324. opcode.src2 = node.right->left->valueNum;
  1325. opcode.src3 = node.right->right->valueNum;
  1326. } else {
  1327. assert(node.right->valueNum >= 0);
  1328. opcode.src2 = node.right->valueNum;
  1329. }
  1330. }
  1331.  
  1332. code.push_back(opcode);
  1333. found.insert(node.valueNum);
  1334. });
  1335.  
  1336.  
  1337. ExprInstruction store(ExprOpType::MEM_STORE_U8);
  1338.  
  1339. if (format->sampleType == stInteger && format->bytesPerSample == 1)
  1340. store.op.type = ExprOpType::MEM_STORE_U8;
  1341. else if (format->sampleType == stInteger && format->bytesPerSample == 2)
  1342. store.op.type = ExprOpType::MEM_STORE_U16;
  1343. else if (format->sampleType == stFloat && format->bytesPerSample == 2)
  1344. store.op.type = ExprOpType::MEM_STORE_F16;
  1345. else if (format->sampleType == stFloat && format->bytesPerSample == 4)
  1346. store.op.type = ExprOpType::MEM_STORE_F32;
  1347.  
  1348. store.src1 = code.back().dst;
  1349. code.push_back(store);
  1350.  
  1351. return code;
  1352. }
  1353.  
  1354. int main(int argc, char **argv)
  1355. {
  1356. VSFormat format{};
  1357. VSVideoInfo realvi{};
  1358. const VSVideoInfo *vi[26];
  1359.  
  1360. format.bytesPerSample = 1;
  1361. format.sampleType = stInteger;
  1362. realvi.format = &format;
  1363.  
  1364. for (int i = 0; i < 26; ++i) {
  1365. vi[i] = &realvi;
  1366. }
  1367.  
  1368. for (int i = 1; i < 2; ++i) {
  1369. std::cout << argv[i] << '\n';
  1370.  
  1371. ExpressionTree tree = parseExpr(argv[i], vi, 26);
  1372. std::vector<ExprInstruction> code = compile(tree, &format);
  1373.  
  1374. for (auto &insn : code) {
  1375. std::cout << std::setw(12) << std::left << op_names[static_cast<size_t>(insn.op.type)];
  1376.  
  1377. if (insn.op.type == ExprOpType::MEM_STORE_U8 || insn.op.type == ExprOpType::MEM_STORE_U16 || insn.op.type == ExprOpType::MEM_STORE_F16 || insn.op.type == ExprOpType::MEM_STORE_F32) {
  1378. std::cout << " r" << insn.src1 << '\n';
  1379. continue;
  1380. }
  1381.  
  1382. std::cout << " r" << insn.dst;
  1383.  
  1384. if (insn.src1 >= 0)
  1385. std::cout << ",r" << insn.src1;
  1386. if (insn.src2 >= 0)
  1387. std::cout << ",r" << insn.src2;
  1388. if (insn.src3 >= 0)
  1389. std::cout << ",r" << insn.src3;
  1390.  
  1391. switch (insn.op.type) {
  1392. case ExprOpType::MEM_LOAD_U8:
  1393. case ExprOpType::MEM_LOAD_U16:
  1394. case ExprOpType::MEM_LOAD_F16:
  1395. case ExprOpType::MEM_LOAD_F32:
  1396. std::cout << ',' << static_cast<char>(insn.op.imm.u < 3 ? 'x' + insn.op.imm.u : 'a' + insn.op.imm.u - 3);
  1397. break;
  1398. case ExprOpType::CONSTANT:
  1399. std::cout << ',' << insn.op.imm.f;
  1400. break;
  1401. case ExprOpType::FMA:
  1402. std::cout << "," << insn.op.imm.u;
  1403. break;
  1404. case ExprOpType::CMP:
  1405. std::cout << ',' << cmp_names[insn.op.imm.u];
  1406. break;
  1407. }
  1408.  
  1409. std::cout << '\n';
  1410. }
  1411. }
  1412.  
  1413. return 0;
  1414. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement