Advertisement
Martmists

bytecode_optimizer.py

Nov 18th, 2019
171
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 11.64 KB | None | 0 0
  1. from dis import opmap, dis, stack_effect, hasjrel, hasjabs, hasname, hasconst, haslocal
  2. from struct import pack
  3. from types import CodeType
  4. from typing import List, Tuple, Any, T, Sequence, Generator
  5.  
  6.  
  7. class Flags:
  8.     DEBUG = False
  9.     REMOVE_UNUSED_VARS = True
  10.     TAIL_CALL_OPTIMIZATION = True
  11.     OPTIMIZE_ACCESSORS = True
  12.     OPTIMIZE_NAMES = True
  13.  
  14.  
  15. def debug(*args):
  16.     if Flags.DEBUG:
  17.         print(*args)
  18.  
  19.  
  20. def dump(ops):
  21.     if Flags.DEBUG:
  22.         print("-"*50)
  23.         dis(b"".join(pack("BB", op[0], op[1]) for op in ops))
  24.  
  25.  
  26. def iter_size(it: Sequence[T], size: int) -> Generator[Sequence[T], None, None]:
  27.     index = 0
  28.     while index < len(it) - size:
  29.         yield it[index:index+size]
  30.         index += 1
  31.     return
  32.  
  33.  
  34. def is_name_used_upper(name, code) -> bool:
  35.     return name not in code.co_varnames
  36.  
  37.  
  38. def remove_unused_variables(ops: List[Tuple[int, int, int]], code: CodeType) -> List[Tuple[int, int, int]]:
  39.     found = True
  40.     continuefrom = 0
  41.     while found:
  42.         copy_ops = ops[continuefrom:]
  43.         found = False
  44.         stored = None
  45.         stored_index = 0
  46.         for i, op in enumerate(copy_ops):
  47.             i = i + continuefrom
  48.             if stored is None and op[0] in (opmap["STORE_FAST"], opmap["STORE_NAME"]):
  49.                 if op[0] == opmap["STORE_FAST"]:
  50.                     if not is_name_used_upper(code.co_varnames[op[1]], code):
  51.                         debug(f"Checking: {op}")
  52.                         stored = op
  53.                         stored_index = i
  54.                 elif op[0] == opmap["STORE_NAME"]:
  55.                     if not is_name_used_upper(code.co_names[op[1]], code):
  56.                         debug(f"Checking: {op}")
  57.                         stored = op
  58.                         stored_index = i
  59.  
  60.             elif stored and op[0] in (*haslocal, *hasname):
  61.                 if op[1] == stored[1]:
  62.                     if op[0] == stored[0]:  # Same store
  63.                         debug(f"Found overriding store {op} for {stored}")
  64.                         # Useless variable
  65.                         ops[stored_index] = (opmap["POP_TOP"], 0, stored[2])
  66.                         continuefrom = stored_index + 1
  67.                         found = True
  68.                         break
  69.                     elif ((op[0] in haslocal and stored[0] == opmap["STORE_FAST"]) or
  70.                           (op[0] in hasname and stored[0] == opmap["STORE_NAME"])):
  71.                         debug(f"Found load {op} for {stored}")
  72.                         # variable is used
  73.                         continuefrom = stored_index+1
  74.                         found = True
  75.                         break
  76.         else:
  77.             if stored:
  78.                 debug(f"Removing op {stored}")
  79.                 ops[stored_index] = (opmap["POP_TOP"], 0, stored[2])
  80.                 continuefrom = stored_index + 1
  81.                 found = True
  82.  
  83.     changed = True
  84.     while changed:
  85.         changed = False
  86.         for i, (first, second) in enumerate(zip(ops[:-1], ops[1:])):
  87.             try:
  88.                 effect = stack_effect(first[0], first[1])
  89.             except ValueError:
  90.                 try:
  91.                     effect = stack_effect(first[0])
  92.                 except ValueError:
  93.                     effect = 1
  94.             if effect == 1 and second[0] == opmap["POP_TOP"]:
  95.                 del ops[i+1]
  96.                 del ops[i]
  97.                 changed = True
  98.                 break
  99.  
  100.     return ops
  101.  
  102.  
  103. def optimize_accessors(ops: List[Tuple[int, int, int]]) -> List[Tuple[int, int, int]]:
  104.     changed = True
  105.     while changed:
  106.         changed = False
  107.         for i, (first, second) in enumerate(zip(ops[:-1], ops[1:])):
  108.             if first[1] == second[1]:
  109.                 if (first[0], second[0]) in ((opmap["STORE_FAST"], opmap["LOAD_FAST"]),
  110.                                              (opmap["STORE_NAME"], opmap["LOAD_NAME"])):
  111.                     if not any(arg[:2] == second[:2] for arg in ops[i+2:]):
  112.                         # Make sure the fast isn't accessed a second time
  113.                         del ops[i + 1]
  114.                         del ops[i]
  115.                         changed = True
  116.                         break
  117.     return ops
  118.  
  119.  
  120. def get_stack_size(ops: List[Tuple[int, int, int]]) -> int:
  121.     stack = 0
  122.     max_stack = 0
  123.     for op in ops:
  124.         try:
  125.             stack += stack_effect(op[0], op[1])
  126.         except ValueError:
  127.             try:
  128.                 stack += stack_effect(op[0])
  129.             except ValueError:
  130.                 stack += 1
  131.         max_stack = max(max_stack, stack)
  132.     return max_stack
  133.  
  134.  
  135. def fix_jumps(ops: List[Tuple[int, int, int]]) -> List[Tuple[int, int, int]]:
  136.     for i, op in enumerate(ops):
  137.         if op[0] in hasjabs:
  138.             target = [x for x in ops if x[2] >= op[1]][0]  # if not exists, jump to next instruction
  139.             ops[i] = (op[0], ops.index(target)*2, op[2])
  140.         elif op[0] in hasjrel:
  141.             target = [x for x in ops if x[2] >= op[1]+op[2]][0]  # if not exists, jump to next instruction
  142.             ops[i] = (op[0], (ops.index(target)-ops.index(op))*2, op[2])
  143.     return ops
  144.  
  145.  
  146. def optimize_names(opcodes: List[Tuple[int, int, int]], code: CodeType) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[Any, ...]]:
  147.     accessed_names = []
  148.     accessed_varnames = []
  149.     for i in range(code.co_argcount + code.co_kwonlyargcount):
  150.         accessed_varnames.append(code.co_varnames[i])
  151.     accessed_consts = []
  152.     for op in opcodes:
  153.         if op[0] in hasname and code.co_names[op[1]] not in accessed_names:
  154.             accessed_names.append(code.co_names[op[1]])
  155.         elif op[0] in haslocal and code.co_varnames[op[1]] not in accessed_varnames:
  156.             accessed_varnames.append(code.co_varnames[op[1]])
  157.         elif op[0] in hasconst and code.co_consts[op[1]] not in accessed_consts:
  158.             accessed_consts.append(code.co_consts[op[1]])
  159.  
  160.     for i, op in enumerate(opcodes):
  161.         if op[0] in hasname:
  162.             opcodes[i] = (op[0], accessed_names.index(code.co_names[op[1]]), op[2])
  163.         elif op[0] in haslocal:
  164.             opcodes[i] = (op[0], accessed_varnames.index(code.co_varnames[op[1]]), op[2])
  165.         elif op[0] in hasconst:
  166.             opcodes[i] = (op[0], accessed_consts.index(code.co_consts[op[1]]), op[2])
  167.  
  168.     return tuple(accessed_names), tuple(accessed_varnames), tuple(accessed_consts)
  169.  
  170.  
  171. def optimize_tco(ops: List[Tuple[int, int, int]], code: CodeType) -> List[Tuple[int, int, int]]:
  172.     ops_copy = ops[:]
  173.     changed = True
  174.     name = code.co_name if "<optimized>" not in code.co_name else code.co_name[12:]
  175.  
  176.     while changed:
  177.         changed = False
  178.         for i, new_ops in enumerate(iter_size(ops_copy, 3)):
  179.             if new_ops[0][0] in (opmap["LOAD_DEREF"], opmap["LOAD_GLOBAL"]):
  180.                 names = code.co_names if new_ops[0][0] == opmap["LOAD_GLOBAL"] else list(code.co_freevars) + list(code.co_cellvars)
  181.                 if names[new_ops[0][1]] == name:
  182.                     print("Searching for CALL_FUNCTION RETURN_VALUE")
  183.                     for j, (op, op2) in enumerate(iter_size(ops_copy[i+3:], 2)):
  184.                         if op[0] == opmap["CALL_FUNCTION"] and op2[0] == opmap["RETURN_VALUE"]:
  185.                             print("Found CALL_FUNCTION RETURN_VALUE")
  186.                             nargs = op[1]
  187.                             added_ops = []
  188.                             for k in reversed(range(nargs)):
  189.                                 added_ops.append((opmap["STORE_FAST"], k, op[2]))
  190.                             added_ops.append((opmap["JUMP_ABSOLUTE"], 0, op2[2]))
  191.                             print(ops_copy)
  192.                             ops_copy[i+j+3:i+j+5] = added_ops
  193.                             print(ops_copy)
  194.                             changed = True
  195.                             break
  196.  
  197.     return ops_copy
  198.  
  199.  
  200. def nested_tco(ops: List[Tuple[int, int, int]], code: CodeType) -> Tuple[Any, ...]:
  201.     consts = list(code.co_consts)
  202.  
  203.     for i, new_ops in enumerate(iter_size(ops, 4)):
  204.         if new_ops[0][0] == opmap["MAKE_FUNCTION"]:
  205.             for op in ops[i+4:]:
  206.                 if op[0] == opmap["STORE_NAME"] and op[1] == new_ops[3][1]:
  207.                     # dont optimize
  208.                     break
  209.             else:
  210.                 # optimize
  211.                 new_code = consts[ops[i-2][1]]
  212.                 opcodes = [*zip(new_code.co_code[::2], new_code.co_code[1::2], range(0, int(len(new_code.co_code)), 2))]
  213.                 opcodes = optimize_tco(opcodes, new_code)
  214.                 co_code = b"".join(pack("BB", op[0], op[1]) for op in opcodes)
  215.                 code = CodeType(new_code.co_argcount, new_code.co_kwonlyargcount, new_code.co_nlocals, new_code.co_stacksize,
  216.                                 new_code.co_flags, co_code, new_code.co_consts, new_code.co_names, new_code.co_varnames,
  217.                                 new_code.co_filename, new_code.co_name, new_code.co_firstlineno, new_code.co_lnotab, new_code.co_freevars,
  218.                                 new_code.co_cellvars)
  219.                 consts[ops[i-2][1]] = code
  220.     return tuple(consts)
  221.  
  222.  
  223. def optimize_code(code: CodeType) -> CodeType:
  224.     co_argcount = None
  225.     co_kwonlyargcount = None
  226.     co_nlocals = None
  227.     co_flags = None
  228.     co_filename = None
  229.     co_names = None
  230.     co_varnames = None
  231.     co_name = "<optimized> "+code.co_name if not code.co_name.startswith("<") else "<optimized " + code.co_name[1:]
  232.     co_firstlineno = None
  233.     co_lnotab = None
  234.     co_freevars = None
  235.     co_cellvars = None
  236.  
  237.     opcodes = [*zip(code.co_code[::2], code.co_code[1::2], range(0, int(len(code.co_code)), 2))]
  238.     co_consts = tuple((const if not isinstance(const, CodeType)
  239.                        else optimize_code(const))
  240.                       for const in code.co_consts)
  241.     # We do this to optimize out all nested code consts first
  242.     code = CodeType(code.co_argcount, code.co_kwonlyargcount, code.co_nlocals, code.co_stacksize, code.co_flags,
  243.                     code.co_code, co_consts, code.co_names, code.co_varnames, code.co_filename, code.co_name,
  244.                     code.co_firstlineno, code.co_lnotab, code.co_freevars, code.co_cellvars)
  245.  
  246.     dump(opcodes)
  247.     if Flags.REMOVE_UNUSED_VARS:
  248.         opcodes = remove_unused_variables(opcodes, code)
  249.         dump(opcodes)
  250.  
  251.     if Flags.TAIL_CALL_OPTIMIZATION:
  252.         co_consts = nested_tco(opcodes, code)
  253.         code = CodeType(code.co_argcount, code.co_kwonlyargcount, code.co_nlocals, code.co_stacksize, code.co_flags,
  254.                         code.co_code, co_consts, code.co_names, code.co_varnames, code.co_filename, code.co_name,
  255.                         code.co_firstlineno, code.co_lnotab, code.co_freevars, code.co_cellvars)
  256.         dump(opcodes)
  257.  
  258.     if Flags.OPTIMIZE_ACCESSORS:
  259.         opcodes = optimize_accessors(opcodes)
  260.         dump(opcodes)
  261.     opcodes = fix_jumps(opcodes)
  262.  
  263.     if Flags.OPTIMIZE_NAMES:
  264.         co_names, co_varnames, co_consts = optimize_names(opcodes, code)
  265.  
  266.     co_stacksize = get_stack_size(opcodes)
  267.  
  268.     co_code = b"".join(pack("BB", op[0], op[1]) for op in opcodes)
  269.     return CodeType(
  270.         co_argcount or code.co_argcount,
  271.         co_kwonlyargcount or code.co_kwonlyargcount,
  272.         co_nlocals or code.co_nlocals,
  273.         co_stacksize or code.co_stacksize,
  274.         co_flags or code.co_flags,
  275.         co_code or code.co_code,
  276.         co_consts or code.co_consts,
  277.         co_names or code.co_names,
  278.         co_varnames or code.co_varnames,
  279.         co_filename or code.co_filename,
  280.         co_name or code.co_name,
  281.         co_firstlineno or code.co_firstlineno,
  282.         co_lnotab or code.co_lnotab,
  283.         co_freevars or code.co_freevars,
  284.         co_cellvars or code.co_cellvars
  285.     )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement