Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import dis
- from types import CodeType
def _patch_code(code: CodeType, **kwargs):
    """Return a copy of ``code`` with the given ``co_*`` attributes replaced.

    Code objects are immutable, so "patching" means building a new one.
    Keyword names are code-object attribute names, e.g. ``co_code=...``;
    every attribute not named keeps its original value.
    """
    # CodeType.replace() (Python 3.8+) knows the constructor signature of the
    # running interpreter -- which gained co_posonlyargcount in 3.8 and more
    # fields later -- so prefer it whenever it exists.
    replace = getattr(code, "replace", None)
    if replace is not None:
        return replace(**kwargs)

    # Older interpreters: collect the original co_* attributes, patch the
    # new values over them, and rebuild the object positionally.
    code_attrs = {
        attr: getattr(code, attr) for attr in dir(code) if attr.startswith("co_")
    }
    code_attrs.update(kwargs)
    return CodeType(
        code_attrs["co_argcount"],
        code_attrs["co_kwonlyargcount"],
        code_attrs["co_nlocals"],
        code_attrs["co_stacksize"],
        code_attrs["co_flags"],
        code_attrs["co_code"],
        code_attrs["co_consts"],
        code_attrs["co_names"],
        code_attrs["co_varnames"],
        code_attrs["co_filename"],
        code_attrs["co_name"],
        code_attrs["co_firstlineno"],
        code_attrs["co_lnotab"],
        # Without these, code that uses closures would lose its
        # free/cell variables.
        code_attrs["co_freevars"],
        code_attrs["co_cellvars"],
    )
def _assemble(*instructions):
    """
    Assemble CPython wordcode from ``(op_name, arg_value)`` two-tuples.

    Each instruction becomes two bytes: the opcode followed by its argument.
    An argument of ``None`` (for instructions that take none) is encoded as 0.

    Raises:
        KeyError: if ``op_name`` is not in ``dis.opmap`` for the running
            interpreter.
        ValueError: if an argument does not fit in a single byte; this
            assembler does not emit EXTENDED_ARG prefixes.
    """
    code = bytearray()
    for op_name, arg in instructions:
        # Some instructions don't take arguments, so just null it.
        if arg is None:
            arg = 0
        op_code = dis.opmap[op_name]
        if not 0 <= arg <= 0xFF:
            raise ValueError(
                f"argument {arg!r} for {op_name} does not fit in one byte "
                "(EXTENDED_ARG is not supported)"
            )
        # Write the raw bytes directly: no text encoding is involved, so
        # opcodes >= 128 (e.g. CALL_FUNCTION) need no special handling.
        code.extend((op_code, arg))
    return bytes(code)
def _safe_search(tup, *items):
    """
    Search a tuple for each of ``items``, appending any that are missing.

    Returns one flat tuple: the (possibly extended) contents of ``tup``
    followed by the index of each item, in the order the items were given.
    Callers unpack it as ``*new_tup, idx_a, idx_b = _safe_search(tup, a, b)``.

    Items are matched by type as well as value. Because ``False == 0`` and
    ``True == 1``, a plain ``in``/``index`` lookup would hand back an existing
    ``0``/``1`` constant where ``False``/``True`` was requested.
    """
    entries = list(tup)
    indices = []
    for item in items:
        index = next(
            (
                position
                for position, entry in enumerate(entries)
                if type(entry) is type(item) and entry == item
            ),
            None,
        )
        if index is None:
            # Not present yet: append it so it gets the next free index.
            index = len(entries)
            entries.append(item)
        indices.append(index)
    return (*entries, *indices)
def _exec_code(code: CodeType, *args, **kwargs):
    """Run ``code`` as a function body with the given arguments.

    The code object is grafted onto a throwaway placeholder function defined
    here, and that function is then called normally. The placeholder has no
    default arguments and no closure, and its globals are this module's.
    Returns whatever the executed code returns.
    """
    host = lambda *args, **kwargs: None  # noqa: E731 -- body gets replaced
    host.__code__ = code
    outcome = host(*args, **kwargs)
    return outcome
def cpp_stdio(func):
    """
    Modifies the bytecode of a function so that C++ style `cin` and
    `cout` calls are possible in place of `input` and `print`, then
    executes it.

    Example:
        >>> @cpp_stdio
        ... def hello_name():
        ...     cout << "Please enter your name: ";
        ...     cin >> name;
        ...
        ...     cout << "Hello, " << name << "!" << endl;
        ...

    Rewritten statement shapes:
      * ``cin >> var;`` becomes ``var = input()``. A ``var`` that is not
        otherwise local is turned into a local variable.
      * ``cout << a << b;`` becomes
        ``print(a, b, sep='', end='', flush=False)``. A trailing
        ``<< endl`` becomes ``end='\\n', flush=True`` instead.

    The rewriting assumes 2-byte wordcode and opcodes such as CALL_FUNCTION,
    CALL_FUNCTION_KW and BINARY_LSHIFT. These exist only in older CPython
    releases (3.6-3.10).

    The patched function is built on the first call and reused afterwards.
    It runs with ``func``'s own globals, defaults and closure, and its return
    value is passed back to the caller.

    Note: I will hurt you if you unironically use this. :D
    """
    import functools
    from types import FunctionType

    def _find_instruction(code_bytes, instruction, start=0):
        """Return the first instruction-aligned offset >= start of `instruction`, or -1."""
        # bytes.find() can also match across two instructions (an argument
        # byte followed by the next opcode), so skip odd offsets.
        pos = code_bytes.find(instruction, start)
        while pos >= 0 and pos % 2:
            pos = code_bytes.find(instruction, pos + 1)
        return pos

    def _split_instructions(code_bytes):
        """Split wordcode into its 2-byte (opcode, argument) instructions."""
        return [code_bytes[i:i + 2] for i in range(0, len(code_bytes), 2)]

    def _patch_cin(code: CodeType):
        """Patch instances of C++ style `cin` in the function."""
        # The attributes themselves are read-only, so we work on copies and
        # patch them onto a new code object at the end.
        nlocals = code.co_nlocals
        varnames = code.co_varnames
        new_code = code.co_code

        # We need cin_num and input_num for instruction args.
        *names, cin_num, input_num = _safe_search(code.co_names, "cin", "input")

        # This will be used to find where `cin` is used.
        cin_start = _assemble(("LOAD_GLOBAL", cin_num))
        load_global = dis.opmap["LOAD_GLOBAL"]
        load_fast = dis.opmap["LOAD_FAST"]
        rshift = dis.opmap["BINARY_RSHIFT"]
        pop_top = dis.opmap["POP_TOP"]

        # (global name index, local variable index) pairs for variables that
        # are implicitly declared by the `cin` statement alone -- the
        # 'pythonic' twist. :P
        imp_decl = []
        start_pos = 0
        while True:
            start_pos = _find_instruction(new_code, cin_start, start_pos)
            if start_pos < 0:
                break
            # `cin >> var;` is 4 instructions x 2 bytes = 8 bytes:
            # LOAD_GLOBAL cin, LOAD_GLOBAL/LOAD_FAST var, BINARY_RSHIFT, POP_TOP
            end_pos = start_pos + 8
            cin_call = new_code[start_pos:end_pos]
            if (
                len(cin_call) != 8
                or cin_call[2] not in (load_global, load_fast)
                or cin_call[4] != rshift
                or cin_call[6] != pop_top
            ):
                # Not a plain `cin >> var;` statement (e.g. chained `>>`);
                # rewriting it would corrupt the stack, so leave it alone.
                start_pos += 2
                continue

            # The second instruction's argument identifies the variable.
            store_num = cin_call[3]
            if cin_call[2] == load_global:
                # Declare the variable in the function's local scope.
                prev_store_num = store_num
                known_locals = len(varnames)
                *varnames, store_num = _safe_search(varnames, names[prev_store_num])
                # Only a genuinely new local needs a new slot; several `cin`
                # statements may read into the same variable.
                if len(varnames) > known_locals:
                    nlocals += 1
                imp_decl.append((prev_store_num, store_num))

            # This is the `var = input()` bytecode. It directly replaces the
            # `cin` statement, so `cin` never needs to exist. The NOP pads it
            # to the original 8 bytes so jump targets and the line-number
            # table after this point stay valid.
            changed_code = _assemble(
                ("LOAD_GLOBAL", input_num),
                ("CALL_FUNCTION", 0),
                ("STORE_FAST", store_num),
                ("NOP", None),
            )
            new_code = new_code[:start_pos] + changed_code + new_code[end_pos:]
            start_pos = end_pos

        # Reads of an implicitly-declared variable were compiled as global
        # lookups; turn them into local lookups (instruction-aligned only).
        replacements = {
            _assemble(("LOAD_GLOBAL", prev_store)): _assemble(("LOAD_FAST", new_store))
            for prev_store, new_store in imp_decl
        }
        if replacements:
            new_code = b"".join(
                replacements.get(instr, instr) for instr in _split_instructions(new_code)
            )
        return _patch_code(
            code,
            co_code=new_code,
            co_names=tuple(names),
            co_varnames=tuple(varnames),
            co_nlocals=nlocals,
        )

    def _patch_cout(code: CodeType):
        """
        Patch instances of C++ style `cout` in the function.

        Note: I'm very aware that one can simply make a custom class
        which overrides the __lshift__ magic method and does a bunch
        of fancy stuff, and that's what I had originally.
        This is more fun though :D
        """
        new_code = code.co_code
        # Find the const arg numbers for the two bools, the newline and
        # empty strings, and the `print` kwarg names.
        *consts, false_num, true_num, newln_num, empty_str, print_kws = _safe_search(
            code.co_consts,
            False, True, "\n", "", ("sep", "end", "flush"),
        )
        # Find the global arg numbers of `cout`, `endl` and `print`.
        *names, cout_num, endl_num, print_num = _safe_search(
            code.co_names,
            "cout", "endl", "print",
        )
        # `cout` statements begin with this instruction ...
        cout_start = _assemble(("LOAD_GLOBAL", cout_num))
        # ... separate each value with a `<<` ...
        separator = _assemble(("BINARY_LSHIFT", None))
        # ... and end by discarding the result.
        cout_end = _assemble(("POP_TOP", None))
        # And this is what `endl` will appear as.
        endl_value = _assemble(("LOAD_GLOBAL", endl_num))

        # NOTE(review): the rewritten call is usually longer than the `<<`
        # chain it replaces, and jump targets / the line-number table are not
        # adjusted. Straight-line functions are fine; a `cout` before a
        # branch or loop target may break the jumps.

        # The rewritten print() calls hold more values on the stack than the
        # `<<` chains did; track the extra room they need.
        extra_stack = 0
        start_pos = 0
        while True:
            # Find the boundaries and the bytecode of the `cout` statement.
            start_pos = _find_instruction(new_code, cout_start, start_pos)
            if start_pos < 0:
                break
            end_pos = _find_instruction(new_code, cout_end, start_pos)
            if end_pos < 0:
                # No statement-ending POP_TOP, so nothing safe to rewrite.
                break
            # Drop the leading `LOAD_GLOBAL cout` and every `<<`. What is left
            # pushes the values to print, in order.
            instructions = _split_instructions(new_code[start_pos + 2:end_pos])
            out_values = b"".join(i for i in instructions if i != separator)
            # There is one `<<` per printed value. Counting separators (rather
            # than instructions) keeps multi-instruction values like `x * 2`.
            args_length = instructions.count(separator)
            has_endl = out_values.endswith(endl_value)
            if has_endl:
                # Strip `endl`; it is expressed through print's kwargs instead.
                out_values = out_values[:-2]
                args_length -= 1

            # Push print, then each value, then sep='' (`cout` typically
            # doesn't have a separator).
            changed_code = (
                _assemble(("LOAD_GLOBAL", print_num))
                + out_values
                + _assemble(("LOAD_CONST", empty_str))
            )
            if has_endl:
                # `endl` follows print defaults and flushes the stream.
                changed_code += _assemble(
                    ("LOAD_CONST", newln_num),  # end='\n'
                    ("LOAD_CONST", true_num),  # flush=True
                )
            else:
                changed_code += _assemble(
                    ("LOAD_CONST", empty_str),  # end=''
                    ("LOAD_CONST", false_num),  # flush=False
                )
            # Load the kwarg names and call the function.
            changed_code += _assemble(
                ("LOAD_CONST", print_kws),
                ("CALL_FUNCTION_KW", 3 + args_length),
            )
            # Stack slots used: print + values + 3 kwarg values + names tuple.
            extra_stack = max(extra_stack, args_length + 5)
            # The original POP_TOP at end_pos stays and discards print()'s result.
            new_code = new_code[:start_pos] + changed_code + new_code[end_pos:]
            start_pos += len(changed_code)

        return _patch_code(
            code,
            co_code=new_code,
            co_consts=tuple(consts),
            co_names=tuple(names),
            co_stacksize=code.co_stacksize + extra_stack,
        )

    patched_func = None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal patched_func
        if patched_func is None:
            code = func.__code__
            # We only need to patch `cin` and `cout` if they're used.
            if "cin" in code.co_names:
                code = _patch_cin(code)
            if "cout" in code.co_names:
                code = _patch_cout(code)
            # Rebuild the function around the patched code with func's own
            # globals, defaults and closure, so module-level names, default
            # arguments and closures behave as they did before patching.
            patched = FunctionType(
                code,
                func.__globals__,
                func.__name__,
                func.__defaults__,
                func.__closure__,
            )
            patched.__kwdefaults__ = func.__kwdefaults__
            patched_func = patched
        # Finally, execute the patched code as if nothing has happened :P
        return patched_func(*args, **kwargs)

    return wrapper
if __name__ == '__main__':
    # Demo: read two numbers from stdin and print their sum with C++-style
    # syntax. `x` and `y` are never assigned in Python terms; @cpp_stdio
    # rewrites each `cin >> var;` statement into `var = input()` before
    # the function body runs.
    # NOTE(review): the rewriting targets older CPython bytecode (2-byte
    # wordcode with CALL_FUNCTION / BINARY_LSHIFT); confirm the interpreter
    # version before relying on this demo.
    @cpp_stdio
    def addition():
        cout << "Enter a number: ";
        cin >> x;
        cout << "And another: ";
        cin >> y;
        # `x` and `y` hold the strings returned by input(), hence int().
        result = int(x) + int(y)
        cout << x << " + " << y << " = " << result << endl;

    addition()
Add Comment
Please, Sign In to add comment