buesingniklas

Untitled

Jan 5th, 2022
2,854
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 20.47 KB | None | 0 0
  1. use super::stack::Stack;
  2. use std::collections::HashMap;
  3. use std::iter;
  4. use rand::Rng;
  5.  
  6. #[derive(Debug, Copy, Clone)]
  7. pub enum Location<'a> {
  8.    Label(&'a str),
  9.     Address(usize),
  10.     Offset(isize),
  11. }
  12.  
  13. impl<'a> Location<'a> {
  14.     fn as_label(&'a self) -> &'a str {
  15.         match self {
  16.             &Location::Label(label) => label,
  17.             _ => panic!("Location isn't a label"),
  18.         }
  19.     }
  20.  
  21.     fn as_address(&self) -> usize {
  22.         match self {
  23.             &Location::Address(addr) => addr,
  24.             _ => panic!("Location isn't an address"),
  25.         }
  26.     }
  27.  
  28.     fn as_offset(&self) -> isize {
  29.         match self {
  30.             &Location::Offset(offset) => offset,
  31.             _ => panic!("Location isn't an offset"),
  32.         }
  33.     }
  34. }
  35.  
  36. type InstructionPointer<'a> = (Location<'a>, usize);
  37. type LocalSlotIndex = usize;
  38. type FunctionIndex = usize;
  39.  
  40. pub type Integer = i64;
  41.  
  42. #[derive(Debug, Copy, Clone)]
  43. pub enum Instruction<'a> {
  44.    CallFunction(Location<'a>),
  45.  
  46.     PushConstantInteger(Integer),
  47.     StoreInteger(LocalSlotIndex),
  48.     LoadInteger(LocalSlotIndex),
  49.     AddInteger,
  50.     MultiplyInteger,
  51.     ModInteger,
  52.     CompareInteger,
  53.     IntegerGreaterThan,
  54.     IntegerSmallerThan,
  55.     IncrementInteger(LocalSlotIndex),
  56.     DecrementInteger(LocalSlotIndex),
  57.     JumpIfIntegerLessThan(Location<'a>),
  58.    JumpIfIntegerEquals(Location<'a>),
  59.  
  60.     CompareBoolean,
  61.  
  62.     Jump(Location<'a>),
  63.    JumpIfTrue(Location<'a>),
  64.     JumpIfFalse(Location<'a>),
  65.    Return,
  66. }
  67.  
  68. impl<'a> Instruction<'a> {
  69.    fn get_jump_location(&'a self) -> Option<&'a Location> {
  70.        match self {
  71.            Instruction::JumpIfIntegerLessThan(location) => Some(location),
  72.            Instruction::JumpIfIntegerEquals(location) => Some(location),
  73.            Instruction::Jump(location) => Some(location),
  74.            Instruction::JumpIfTrue(location) => Some(location),
  75.            Instruction::JumpIfFalse(location) => Some(location),
  76.            _ => None,
  77.        }
  78.    }
  79.    fn set_jump_location(&mut self, new_location: Location<'a>) {
  80.         let location = match self {
  81.             Instruction::JumpIfIntegerLessThan(location) => Some(location),
  82.             Instruction::JumpIfIntegerEquals(location) => Some(location),
  83.             Instruction::Jump(location) => Some(location),
  84.             Instruction::JumpIfTrue(location) => Some(location),
  85.             Instruction::JumpIfFalse(location) => Some(location),
  86.             _ => None,
  87.         }
  88.         .unwrap();
  89.         std::mem::replace(location, new_location);
  90.     }
  91. }
  92.  
  93. #[derive(Debug)]
  94. pub enum Value {
  95.     Integer(Integer),
  96.     Float(f32),
  97.     Boolean(bool),
  98. }
  99.  
  100. impl Value {
  101.     fn as_integer(&self) -> Integer {
  102.         match self {
  103.             &Value::Integer(value) => value,
  104.             _ => panic!("Value is not an integer"),
  105.         }
  106.     }
  107.  
  108.     fn as_boolean(&self) -> bool {
  109.         match self {
  110.             &Value::Boolean(value) => value,
  111.             _ => panic!("Value is not an boolean"),
  112.         }
  113.     }
  114. }
  115.  
  116. #[derive(Debug)]
  117. pub struct Function<'a> {
  118.    name: &'a str,
  119.     pub local_size: usize,
  120.     instructions: Vec<Instruction<'a>>,
  121.    labels: HashMap<String, usize>,
  122. }
  123.  
  124. impl<'a> Function<'a> {
  125.    pub fn new(name: &'a str) -> Self {
  126.         Self::new_with_locals(name, 0)
  127.     }
  128.  
  129.     pub fn new_with_locals(name: &'a str, local_size: usize) -> Self {
  130.        Function {
  131.            name,
  132.            local_size,
  133.            instructions: Vec::new(),
  134.            labels: HashMap::new(),
  135.        }
  136.    }
  137.  
  138.    pub fn push_instruction(&mut self, instruction: Instruction<'a>) {
  139.         self.instructions.push(instruction);
  140.     }
  141.  
  142.     pub fn label(&mut self, label: String) {
  143.         self.labels.insert(label, self.instructions.len());
  144.     }
  145.  
  146.     pub fn label_str(&mut self, label: &str) {
  147.         self.label(label.to_owned());
  148.     }
  149.  
  150.     fn transform_jumps(&mut self) {
  151.         for (index, instruction) in self.instructions.iter_mut().enumerate() {
  152.             if let Some(location) = instruction.get_jump_location() {
  153.                 let jump_dest = self.labels[location.as_label()];
  154.                 let offset = (jump_dest as isize) - (index as isize + 1);
  155.  
  156.                 instruction.set_jump_location(Location::Offset(offset));
  157.             }
  158.         }
  159.     }
  160. }
  161.  
  162. pub struct StackFrame<'a> {
  163.    pub(crate) locals: Box<[Option<Value>]>,
  164.    return_address: InstructionPointer<'a>,
  165. }
  166.  
  167. impl<'a> StackFrame<'a> {
  168.     fn new(local_size: usize, return_address: InstructionPointer<'a>) -> Self {
  169.        let locals: Box<[Option<Value>]> = iter::repeat_with(|| None).take(local_size).collect();
  170.        StackFrame {
  171.            locals,
  172.            return_address,
  173.        }
  174.    }
  175.  
  176.    fn get_local(&self, slot: LocalSlotIndex) -> &Value {
  177.        self.locals[slot as usize]
  178.            .as_ref()
  179.            .expect("Variable is set to null")
  180.    }
  181.  
  182.    fn set_local(&mut self, slot: LocalSlotIndex, value: Value) {
  183.        self.locals[slot as usize].replace(value);
  184.    }
  185. }
  186.  
  187. pub struct VirtualMachine<'a> {
  188.     pub(crate) functions: Vec<Function<'a>>,
  189.    function_locations: HashMap<&'a str, usize>,
  190.     instruction_pointer: InstructionPointer<'a>,
  191.    pub(crate) call_stack: Stack<StackFrame<'a>>,
  192.     pub(crate) operand_stack: Stack<Value>,
  193. }
  194.  
  195. impl<'a> VirtualMachine<'a> {
  196.     pub(crate) fn new() -> Self {
  197.         Self::new_with_functions(Vec::new())
  198.     }
  199.  
  200.     pub(crate) fn new_with_functions(functions: Vec<Function<'a>>) -> Self {
  201.        VirtualMachine {
  202.            functions,
  203.            function_locations: HashMap::new(),
  204.            instruction_pointer: InstructionPointer::from((Location::Address(0), 0)),
  205.            call_stack: Stack::new(),
  206.            operand_stack: Stack::new(),
  207.        }
  208.    }
  209.  
  210.    pub(crate) fn push_function(&mut self, function: Function<'a>) {
  211.         let name = function.name;
  212.         self.functions.push(function);
  213.         self.function_locations
  214.             .insert(name, self.functions.len() - 1);
  215.     }
  216.  
  217.     pub(crate) fn run(&mut self) {
  218.         self.functions
  219.             .iter_mut()
  220.             .for_each(|function| function.transform_jumps());
  221.  
  222.         for instruction in self
  223.             .functions
  224.             .iter_mut()
  225.             .flat_map(|function| &mut function.instructions)
  226.         {
  227.             if let Instruction::CallFunction(location) = instruction {
  228.                 let address = self.function_locations[location.as_label()];
  229.                 std::mem::replace(location, Location::Address(address));
  230.             }
  231.         }
  232.  
  233.         let (main_index, main_fn) = &self
  234.             .functions
  235.             .iter()
  236.             .enumerate()
  237.             .find(|(_index, function)| function.name == "main")
  238.             .expect("No entry point present (push main fn)");
  239.  
  240.         self.call_stack.push(StackFrame::new(
  241.             main_fn.local_size,
  242.             InstructionPointer::from((Location::Address(*main_index), main_fn.instructions.len())),
  243.         ));
  244.  
  245.         loop {
  246.             let function_index = self.instruction_pointer.0.as_address();
  247.             let function = &self.functions[function_index];
  248.             let instructions = &function.instructions;
  249.             let instruction_count = instructions.len();
  250.             let instruction_index = self.instruction_pointer.1;
  251.  
  252.             let instruction = &instructions[instruction_index];
  253.  
  254.             self.instruction_pointer.1 += 1;
  255.  
  256.             match instruction {
  257.                 Instruction::CallFunction(location) => self.inst_call_function(*location),
  258.                 Instruction::PushConstantInteger(val) => self.inst_load_constant_integer(*val),
  259.                 Instruction::StoreInteger(slot) => self.inst_store_integer(*slot),
  260.                 Instruction::LoadInteger(slot) => self.inst_load_integer(*slot),
  261.                 Instruction::AddInteger => self.inst_add_integer(),
  262.                 Instruction::MultiplyInteger => self.inst_multiply_integer(),
  263.                 Instruction::ModInteger => self.inst_mod_integer(),
  264.                 Instruction::CompareInteger => self.inst_compare_integer(),
  265.                 Instruction::IntegerGreaterThan => self.inst_integer_greater_than(),
  266.                 Instruction::IntegerSmallerThan => self.inst_integer_smaller_than(),
  267.                 Instruction::IncrementInteger(slot) => self.inst_increment_integer(*slot),
  268.                 Instruction::DecrementInteger(slot) => self.inst_decrement_integer(*slot),
  269.                 Instruction::JumpIfIntegerLessThan(location) => {
  270.                     self.inst_jump_if_integer_less_than(*location)
  271.                 }
  272.                 Instruction::JumpIfIntegerEquals(location) => {
  273.                     self.inst_jump_if_integer_equals(*location)
  274.                 }
  275.                 Instruction::CompareBoolean => self.inst_compare_boolean(),
  276.                 Instruction::Jump(location) => self.inst_jump(*location),
  277.                 Instruction::JumpIfTrue(location) => self.inst_jump_if_true(*location),
  278.                 Instruction::JumpIfFalse(location) => self.inst_jump_if_false(*location),
  279.                 Instruction::Return => self.inst_return(),
  280.             }
  281.  
  282.             if self.call_stack.is_empty() {
  283.                 break;
  284.             }
  285.         }
  286.     }
  287.  
  288.     fn jump(&mut self, location: Location) {
  289.         let offset = location.as_offset();
  290.         let current_pointer = self.instruction_pointer.1 as isize;
  291.         self.instruction_pointer.1 = (current_pointer + offset) as usize;
  292.     }
  293.  
  294.     fn inst_call_function(&mut self, location: Location<'a>) {
  295.        let function = &self.functions[location.as_address()];
  296.        self.call_stack.push(StackFrame::new(
  297.            function.local_size,
  298.            self.instruction_pointer.clone(),
  299.        ));
  300.        self.instruction_pointer = InstructionPointer::from((location, 0));
  301.    }
  302.  
  303.    fn inst_load_constant_integer(&mut self, value: Integer) {
  304.        self.operand_stack.push(Value::Integer(value));
  305.    }
  306.  
  307.    fn inst_store_integer(&mut self, slot: LocalSlotIndex) {
  308.        let current_frame = self.call_stack.peek_mut();
  309.        let value = self.operand_stack.pop();
  310.        current_frame.set_local(slot, value);
  311.    }
  312.  
  313.    fn inst_load_integer(&mut self, slot: LocalSlotIndex) {
  314.        let current_frame = self.call_stack.peek_mut();
  315.        let value = current_frame.get_local(slot).as_integer();
  316.        self.operand_stack.push(Value::Integer(value));
  317.    }
  318.  
  319.    fn inst_increment_integer(&mut self, slot: LocalSlotIndex) {
  320.        let current_frame = self.call_stack.peek_mut();
  321.        current_frame.set_local(
  322.            slot,
  323.            Value::Integer(current_frame.get_local(slot).as_integer() + 1),
  324.        );
  325.    }
  326.  
  327.    fn inst_decrement_integer(&mut self, slot: LocalSlotIndex) {
  328.        let current_frame = self.call_stack.peek_mut();
  329.        current_frame.set_local(
  330.            slot,
  331.            Value::Integer(current_frame.get_local(slot).as_integer() - 1),
  332.        );
  333.    }
  334.  
  335.    fn inst_add_integer(&mut self) {
  336.        let rhs = self.operand_stack.pop().as_integer();
  337.        let lhs = self.operand_stack.pop().as_integer();
  338.        self.operand_stack.push(Value::Integer(lhs + rhs));
  339.    }
  340.  
  341.    fn inst_multiply_integer(&mut self) {
  342.        let rhs = self.operand_stack.pop().as_integer();
  343.        let lhs = self.operand_stack.pop().as_integer();
  344.        self.operand_stack.push(Value::Integer(lhs * rhs));
  345.    }
  346.  
  347.    fn inst_mod_integer(&mut self) {
  348.        let rhs = self.operand_stack.pop().as_integer();
  349.        let lhs = self.operand_stack.pop().as_integer();
  350.        self.operand_stack.push(Value::Integer(lhs % rhs));
  351.    }
  352.  
  353.    fn inst_compare_integer(&mut self) {
  354.        let rhs = self.operand_stack.pop().as_integer();
  355.        let lhs = self.operand_stack.pop().as_integer();
  356.        self.operand_stack.push(Value::Boolean(lhs == rhs));
  357.    }
  358.  
  359.    fn inst_integer_greater_than(&mut self) {
  360.        let rhs = self.operand_stack.pop().as_integer();
  361.        let lhs = self.operand_stack.pop().as_integer();
  362.        self.operand_stack.push(Value::Boolean(lhs > rhs));
  363.    }
  364.  
  365.    fn inst_integer_smaller_than(&mut self) {
  366.        let rhs = self.operand_stack.pop().as_integer();
  367.        let lhs = self.operand_stack.pop().as_integer();
  368.        self.operand_stack.push(Value::Boolean(lhs < rhs));
  369.    }
  370.  
  371.    fn inst_jump_if_integer_less_than(&mut self, location: Location) {
  372.        let rhs = self.operand_stack.pop().as_integer();
  373.        let lhs = self.operand_stack.pop().as_integer();
  374.        if lhs < rhs {
  375.            self.jump(location);
  376.        }
  377.    }
  378.  
  379.    fn inst_jump_if_integer_equals(&mut self, location: Location) {
  380.        let rhs = self.operand_stack.pop().as_integer();
  381.        let lhs = self.operand_stack.pop().as_integer();
  382.        if lhs == rhs {
  383.            self.jump(location);
  384.        }
  385.    }
  386.  
  387.    fn inst_compare_boolean(&mut self) {
  388.        let rhs = self.operand_stack.pop().as_boolean();
  389.        let lhs = self.operand_stack.pop().as_boolean();
  390.        self.operand_stack.push(Value::Boolean(lhs == rhs));
  391.    }
  392.  
  393.    fn inst_jump_if_true(&mut self, location: Location) {
  394.        let conditional = self.operand_stack.pop().as_boolean();
  395.        if conditional {
  396.            self.jump(location);
  397.        }
  398.    }
  399.  
  400.    fn inst_jump_if_false(&mut self, location: Location) {
  401.        let conditional = self.operand_stack.pop().as_boolean();
  402.        if !conditional {
  403.            self.jump(location);
  404.        }
  405.    }
  406.  
  407.    fn inst_jump(&mut self, location: Location) {
  408.        self.jump(location);
  409.    }
  410.  
  411.    fn inst_return(&mut self) {
  412.        let frame = self.call_stack.pop();
  413.        self.instruction_pointer = frame.return_address;
  414.    }
  415. }
  416.  
  417. //  vm_bench(42321, 20000000);
  418.  
  419. #[test]
  420. fn test_vm_loops() {
  421.    // locals:
  422.    //   0 (index): integer
  423.    //
  424.    // PushConstantInteger(0)                           // Initialize index with integer 0
  425.    // StoreInteger(0)                                  // ^
  426.    //
  427.    // loopBegin:
  428.    //   LoadInteger(0)                                 // Load integer index from stack
  429.    //   JumpIfIntegerLessThan(100000, "loopIteration") // If integer index less than 100000, jump to loopIteration
  430.    //
  431.    //   Jump("programEnd")                             // We didn't switch branch, so condition must be false; jump to end
  432.     //
  433.     // loopIteration:
  434.     //   IncrementInteger(0)                            // Increment integer index
  435.     //
  436.     //   Jump("loopBegin)                               // Jump back to loop condition
  437.     //
  438.     // programEnd:
  439.     //
  440.     // Return
  441.  
  442.     let mut virtual_machine = VirtualMachine::new();
  443.  
  444.     let mut main_fn = Function::new_with_locals("main", 1);
  445.     main_fn.push_instruction(Instruction::PushConstantInteger(0));
  446.     main_fn.push_instruction(Instruction::StoreInteger(0));
  447.     main_fn.label_str("loopBegin");
  448.     main_fn.push_instruction(Instruction::LoadInteger(0));
  449.     main_fn.push_instruction(Instruction::PushConstantInteger(100000));
  450.     main_fn.push_instruction(Instruction::JumpIfIntegerLessThan(Location::Label(
  451.         "loopIteration",
  452.     )));
  453.     main_fn.push_instruction(Instruction::Jump(Location::Label("programEnd")));
  454.     main_fn.label_str("loopIteration");
  455.     main_fn.push_instruction(Instruction::IncrementInteger(0));
  456.     main_fn.push_instruction(Instruction::Jump(Location::Label("loopBegin")));
  457.     main_fn.label_str("programEnd");
  458.     main_fn.push_instruction(Instruction::Return);
  459.  
  460.     virtual_machine.push_function(main_fn);
  461.  
  462.     virtual_machine.run();
  463.  
  464.     dbg!(virtual_machine.call_stack.pop().locals);
  465. }
  466.  
  467.  
  468. fn vm_bench(a: i64, b: i64) {
  469.     // locals:
  470.     //   0 a: integer
  471.     //   1 b: integer
  472.     //   2 index: integer
  473.     //   3 result: integer
  474.     //
  475.     // StoreInteger(0)                                  // Store first argument in a
  476.     // StoreInteger(1)                                  // Store second argument in b
  477.     // PushConstantInteger(1)                           // Load constant 1 onto stack
  478.     // StoreInteger(2)                                  // Store constant 1 in index
  479.     // LoadInteger(0)                                   // Load a onto stack
  480.     // LoadInteger(1)                                   // Load b onto stack
  481.     // ModInteger                                       // a % b
  482.     // StoreInteger(0)                                  // Store mod result in a
  483.     //
  484.     // loopBegin:
  485.     //   LoadInteger(2)                                 // Load integer index onto stack
  486.     //   LoadInteger(1)                                 // Load integer b onto stack
  487.     //   JumpIfIntegerLessThan("loopIteration")         // If integer index less than integer b, jump to loopIteration
  488.     //
  489.     //   Jump("didntFind")                              // We didn't switch branch, so condition must be false; jump to end
  490.     //
  491.     // loopIteration:
  492.     //   IncrementInteger(2)                            // Increment integer index
  493.     //   LoadInteger(0)                                 // Load a onto stack
  494.     //   LoadInteger(2)                                 // Load index onto stack
  495.     //   MultiplyInteger                                // Push result of a * index onto stack
  496.     //   LoadInteger(1)                                 // push b onto stack
  497.     //   ModInteger                                     // push (a * index) % b onto stack
  498.     //   StoreInteger(3)                                // Store result in result
  499.     //   LoadInteger(3)                                 // Load result back onto stack
  500.     //   PushConstantInteger(1)                         // Push integer 1 onto stack
  501.     //   JumpIfIntegerEquals("didFind")                  // Jump to didFind if result equals 1
  502.     //
  503.     //   Jump("loopBegin")                              // Jump back to loop condition
  504.     //
  505.     // didFind:
  506.     //   LoadInteger(3)                                 // Load result onto stack
  507.     //   Jump("end")                                    // Jump to end
  508.     //
  509.     // didntFind:
  510.     //   PushConstantInteger(-1)                        // Push -1 onto stack to indicate failure
  511.     //
  512.     // end:
  513.     //
  514.     // Return
  515.  
  516.     let mut virtual_machine = VirtualMachine::new();
  517.  
  518.     let mut main_fn = Function::new_with_locals("main", 4);
  519.  
  520.     main_fn.push_instruction(Instruction::StoreInteger(0));
  521.     main_fn.push_instruction(Instruction::StoreInteger(1));
  522.     main_fn.push_instruction(Instruction::PushConstantInteger(1));
  523.     main_fn.push_instruction(Instruction::StoreInteger(2));
  524.     main_fn.push_instruction(Instruction::LoadInteger(0));
  525.     main_fn.push_instruction(Instruction::LoadInteger(1));
  526.     main_fn.push_instruction(Instruction::ModInteger);
  527.     main_fn.push_instruction(Instruction::StoreInteger(0));
  528.  
  529.     main_fn.label_str("loopBegin");
  530.     main_fn.push_instruction(Instruction::LoadInteger(2));
  531.     main_fn.push_instruction(Instruction::LoadInteger(1));
  532.     main_fn.push_instruction(Instruction::JumpIfIntegerLessThan(Location::Label(
  533.         "loopIteration",
  534.     )));
  535.  
  536.     main_fn.push_instruction(Instruction::Jump(Location::Label("didntFind")));
  537.  
  538.     main_fn.label_str("loopIteration");
  539.     main_fn.push_instruction(Instruction::IncrementInteger(2));
  540.     main_fn.push_instruction(Instruction::LoadInteger(0));
  541.     main_fn.push_instruction(Instruction::LoadInteger(2));
  542.     main_fn.push_instruction(Instruction::MultiplyInteger);
  543.     main_fn.push_instruction(Instruction::LoadInteger(1));
  544.     main_fn.push_instruction(Instruction::ModInteger);
  545.     main_fn.push_instruction(Instruction::StoreInteger(3));
  546.     main_fn.push_instruction(Instruction::LoadInteger(3));
  547.     main_fn.push_instruction(Instruction::PushConstantInteger(1));
  548.     main_fn.push_instruction(Instruction::JumpIfIntegerEquals(Location::Label("didFind")));
  549.  
  550.     main_fn.push_instruction(Instruction::Jump(Location::Label("loopBegin")));
  551.  
  552.     main_fn.label_str("didFind");
  553.     main_fn.push_instruction(Instruction::LoadInteger(3));
  554.     main_fn.push_instruction(Instruction::Jump(Location::Label("end")));
  555.  
  556.     main_fn.label_str("didntFind");
  557.     main_fn.push_instruction(Instruction::PushConstantInteger(-1));
  558.  
  559.     main_fn.label_str("end");
  560.  
  561.     main_fn.push_instruction(Instruction::Return);
  562.  
  563.     virtual_machine.push_function(main_fn);
  564.  
  565.     virtual_machine.operand_stack.push(Value::Integer(b));
  566.     virtual_machine.operand_stack.push(Value::Integer(a));
  567.  
  568.     virtual_machine.run();
  569.  
  570.     println!(
  571.         "{}",
  572.         virtual_machine.call_stack.pop().get_local(2).as_integer()
  573.     );
  574. }
  575.  
  576. #[test]
  577. fn test_vm_bench() {
  578.     vm_bench(42, 2017);
  579. }
  580.  
Advertisement