Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <algorithm>
#include <cassert>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>

#include "xbyak/xbyak.h"
- class JIT: public Xbyak::CodeGenerator {
- public:
- template <typename Iter>
- JIT(const Iter beg, const Iter end): Xbyak::CodeGenerator() {
- // int (int a, int b, int c, int d)
- #ifndef XBYAK64
- push(edi);
- push(esi);
- mov(ecx, ptr[esp + 0x0C]);
- mov(edx, ptr[esp + 0x10]);
- mov(edi, ptr[esp + 0x14]);
- mov(esi, ptr[esp + 0x18]);
- #endif
- int params[4] = {};
- char Op = '\0';
- int constNum = 0;
- for (Iter it = beg; it != end; ++it) {
- if (*it >= 'a' && *it <= 'd') {
- switch (Op) {
- case '\0':
- case '+':
- ++params[*it - 'a'];
- break;
- case '-':
- --params[*it - 'a'];
- break;
- default:
- assert(false);
- }
- } else if (*it == '+' || *it == '-') {
- Op = *it;
- } else if (*it >= '0' && *it <= '9') {
- int t = parseInt(it, end);
- --it;
- constNum += Op == '-' ? -t : t;
- } else if (*it == ' ') {
- continue;
- } else {
- throw std::invalid_argument("stynax error");
- }
- }
- #ifndef XBYAK64
- const static Xbyak::Reg32 regs[] = { ecx, edx, edi, esi };
- const static Xbyak::Reg32 res_reg = eax;
- #else
- const static Xbyak::Reg64 regs[] = { rcx, rdx, r8, r9 };
- const static Xbyak::Reg64 res_reg = rax;
- #endif
- bool resInitialized = false;
- for (int i = 0; i < 4; ++i) {
- switch (params[i]) {
- case 0:
- break;
- case 1:
- if (resInitialized) {
- add(res_reg, regs[i]); // r += X
- } else {
- resInitialized = true;
- mov(res_reg, regs[i]); // r = X
- }
- break;
- case -1:
- if (resInitialized) {
- sub(res_reg, regs[i]); // r -= X
- } else {
- resInitialized = true;
- neg(regs[i]); // X = -X
- mov(res_reg, regs[i]); // r = X
- }
- break;
- case 2:
- if (resInitialized)
- lea(res_reg, ptr[res_reg + regs[i] * 2]); // r = r + X * 2
- else {
- resInitialized = true;
- lea(res_reg, ptr[regs[i] + regs[i]]); // r = X + X
- }
- break;
- case 3:
- case 5:
- case 9:
- if (resInitialized) {
- lea(regs[i], ptr[regs[i] + regs[i] * (params[i] - 1)]); // X = X + X * (N - 1)
- add(res_reg, regs[i]); // r += X
- } else {
- resInitialized = true;
- lea(res_reg, ptr[regs[i] + regs[i] * (params[i] - 1)]); // r = X + X * (N - 1)
- }
- break;
- case -2:
- case -3:
- case -5:
- case -9:
- if (resInitialized) {
- lea(regs[i], ptr[regs[i] + regs[i] * (-params[i] - 1)]); // X = X + X * (abs(N) - 1)
- sub(res_reg, regs[i]); // r = X
- } else {
- resInitialized = true;
- lea(res_reg, ptr[regs[i] + regs[i] * (-params[i] - 1)]); // r = X + X * (abs(N) - 1)
- neg(res_reg); // r = -r
- }
- break;
- case 4:
- case -4:
- case 8:
- case -8:
- if (resInitialized && params[i] > 0) {
- lea(res_reg, ptr[res_reg + regs[i] * params[i]]); // r = r + X * N
- } else {
- shl(regs[i], abs(params[i]) == 4 ? 2 : 3); // X <<= log2(abs(N))
- if (params[i] > 0) {
- resInitialized = true;
- mov(res_reg, regs[i]); // r = X
- } else { // params[i] < 0
- if (resInitialized) {
- sub(res_reg, regs[i]); // r -= X
- } else {
- resInitialized = true;
- neg(regs[i]); // X = -X
- mov(res_reg, regs[i]); // r = X
- }
- }
- }
- break;
- default:
- if (resInitialized) {
- imul(regs[i], regs[i], params[i]); // X = X * N
- add(res_reg, regs[i]); // r += X
- } else {
- resInitialized = true;
- imul(res_reg, regs[i], params[i]); // r = X * N
- }
- break;
- }
- }
- if (constNum != 0) {
- if (resInitialized) {
- add(res_reg, constNum); // r += Num
- } else {
- resInitialized = true;
- mov(res_reg, constNum); // r = Num
- }
- }
- if (!resInitialized)
- xor(res_reg, res_reg); // r = 0
- #ifndef XBYAK64
- pop(edi);
- pop(esi);
- #endif
- ret();
- }
- template <typename Iter>
- static int parseInt(Iter& it, const Iter end) {
- int res = 0, t;
- while (it != end && (t = *it - '0', t>=0 && t<=9)) {
- res *= 10;
- res += t;
- ++it;
- }
- return res;
- }
- int exec(int a=0, int b=0, int c=0, int d=0) {
- auto f = reinterpret_cast<int (*)(int, int, int, int)>(getCode());
- return f(a, b, c, d);
- }
- };
- int main() {
- std::string str = "a+a+b-b+c+c+c+c-d-d-d";
- // std::getline(std::cin, str);
- JIT jit(str.begin(), str.end());
- int a=1, b=2, c=3, d=4;
- int res = jit.exec(a, b, c, d);
- assert(a+a+b-b+c+c+c+c-d-d-d == res);
- std::cout << res;
- }
Advertisement
Add Comment
Please sign in to add a comment.