Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- use super::stack::Stack;
- use std::collections::HashMap;
- use std::iter;
- use rand::Rng;
/// A control-flow target in one of three resolution stages.
///
/// Programs are built with symbolic [`Location::Label`]s. Before execution the
/// VM rewrites jump targets into [`Location::Offset`]s and call targets into
/// [`Location::Address`]es.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Location<'a> {
    /// Symbolic name: a label inside a function, or a function name for calls.
    Label(&'a str),
    /// Absolute index (used for function indices).
    Address(usize),
    /// Relative jump distance, measured from the instruction after the jump.
    Offset(isize),
}

impl<'a> Location<'a> {
    /// Returns the label text.
    ///
    /// The result borrows the original label string (`'a`), not `self`, so it
    /// stays valid after this `Location` is dropped or overwritten.
    ///
    /// # Panics
    /// Panics if this is not a [`Location::Label`].
    fn as_label(&self) -> &'a str {
        match *self {
            Location::Label(label) => label,
            _ => panic!("Location isn't a label"),
        }
    }

    /// Returns the absolute address.
    ///
    /// # Panics
    /// Panics if this is not a [`Location::Address`].
    fn as_address(&self) -> usize {
        match *self {
            Location::Address(addr) => addr,
            _ => panic!("Location isn't an address"),
        }
    }

    /// Returns the relative offset.
    ///
    /// # Panics
    /// Panics if this is not a [`Location::Offset`].
    fn as_offset(&self) -> isize {
        match *self {
            Location::Offset(offset) => offset,
            _ => panic!("Location isn't an offset"),
        }
    }
}
- type InstructionPointer<'a> = (Location<'a>, usize);
- type LocalSlotIndex = usize;
- type FunctionIndex = usize;
- pub type Integer = i64;
- #[derive(Debug, Copy, Clone)]
- pub enum Instruction<'a> {
- CallFunction(Location<'a>),
- PushConstantInteger(Integer),
- StoreInteger(LocalSlotIndex),
- LoadInteger(LocalSlotIndex),
- AddInteger,
- MultiplyInteger,
- ModInteger,
- CompareInteger,
- IntegerGreaterThan,
- IntegerSmallerThan,
- IncrementInteger(LocalSlotIndex),
- DecrementInteger(LocalSlotIndex),
- JumpIfIntegerLessThan(Location<'a>),
- JumpIfIntegerEquals(Location<'a>),
- CompareBoolean,
- Jump(Location<'a>),
- JumpIfTrue(Location<'a>),
- JumpIfFalse(Location<'a>),
- Return,
- }
- impl<'a> Instruction<'a> {
- fn get_jump_location(&'a self) -> Option<&'a Location> {
- match self {
- Instruction::JumpIfIntegerLessThan(location) => Some(location),
- Instruction::JumpIfIntegerEquals(location) => Some(location),
- Instruction::Jump(location) => Some(location),
- Instruction::JumpIfTrue(location) => Some(location),
- Instruction::JumpIfFalse(location) => Some(location),
- _ => None,
- }
- }
- fn set_jump_location(&mut self, new_location: Location<'a>) {
- let location = match self {
- Instruction::JumpIfIntegerLessThan(location) => Some(location),
- Instruction::JumpIfIntegerEquals(location) => Some(location),
- Instruction::Jump(location) => Some(location),
- Instruction::JumpIfTrue(location) => Some(location),
- Instruction::JumpIfFalse(location) => Some(location),
- _ => None,
- }
- .unwrap();
- std::mem::replace(location, new_location);
- }
- }
- #[derive(Debug)]
- pub enum Value {
- Integer(Integer),
- Float(f32),
- Boolean(bool),
- }
- impl Value {
- fn as_integer(&self) -> Integer {
- match self {
- &Value::Integer(value) => value,
- _ => panic!("Value is not an integer"),
- }
- }
- fn as_boolean(&self) -> bool {
- match self {
- &Value::Boolean(value) => value,
- _ => panic!("Value is not an boolean"),
- }
- }
- }
- #[derive(Debug)]
- pub struct Function<'a> {
- name: &'a str,
- pub local_size: usize,
- instructions: Vec<Instruction<'a>>,
- labels: HashMap<String, usize>,
- }
- impl<'a> Function<'a> {
- pub fn new(name: &'a str) -> Self {
- Self::new_with_locals(name, 0)
- }
- pub fn new_with_locals(name: &'a str, local_size: usize) -> Self {
- Function {
- name,
- local_size,
- instructions: Vec::new(),
- labels: HashMap::new(),
- }
- }
- pub fn push_instruction(&mut self, instruction: Instruction<'a>) {
- self.instructions.push(instruction);
- }
- pub fn label(&mut self, label: String) {
- self.labels.insert(label, self.instructions.len());
- }
- pub fn label_str(&mut self, label: &str) {
- self.label(label.to_owned());
- }
- fn transform_jumps(&mut self) {
- for (index, instruction) in self.instructions.iter_mut().enumerate() {
- if let Some(location) = instruction.get_jump_location() {
- let jump_dest = self.labels[location.as_label()];
- let offset = (jump_dest as isize) - (index as isize + 1);
- instruction.set_jump_location(Location::Offset(offset));
- }
- }
- }
- }
- pub struct StackFrame<'a> {
- pub(crate) locals: Box<[Option<Value>]>,
- return_address: InstructionPointer<'a>,
- }
- impl<'a> StackFrame<'a> {
- fn new(local_size: usize, return_address: InstructionPointer<'a>) -> Self {
- let locals: Box<[Option<Value>]> = iter::repeat_with(|| None).take(local_size).collect();
- StackFrame {
- locals,
- return_address,
- }
- }
- fn get_local(&self, slot: LocalSlotIndex) -> &Value {
- self.locals[slot as usize]
- .as_ref()
- .expect("Variable is set to null")
- }
- fn set_local(&mut self, slot: LocalSlotIndex, value: Value) {
- self.locals[slot as usize].replace(value);
- }
- }
- pub struct VirtualMachine<'a> {
- pub(crate) functions: Vec<Function<'a>>,
- function_locations: HashMap<&'a str, usize>,
- instruction_pointer: InstructionPointer<'a>,
- pub(crate) call_stack: Stack<StackFrame<'a>>,
- pub(crate) operand_stack: Stack<Value>,
- }
- impl<'a> VirtualMachine<'a> {
- pub(crate) fn new() -> Self {
- Self::new_with_functions(Vec::new())
- }
- pub(crate) fn new_with_functions(functions: Vec<Function<'a>>) -> Self {
- VirtualMachine {
- functions,
- function_locations: HashMap::new(),
- instruction_pointer: InstructionPointer::from((Location::Address(0), 0)),
- call_stack: Stack::new(),
- operand_stack: Stack::new(),
- }
- }
- pub(crate) fn push_function(&mut self, function: Function<'a>) {
- let name = function.name;
- self.functions.push(function);
- self.function_locations
- .insert(name, self.functions.len() - 1);
- }
- pub(crate) fn run(&mut self) {
- self.functions
- .iter_mut()
- .for_each(|function| function.transform_jumps());
- for instruction in self
- .functions
- .iter_mut()
- .flat_map(|function| &mut function.instructions)
- {
- if let Instruction::CallFunction(location) = instruction {
- let address = self.function_locations[location.as_label()];
- std::mem::replace(location, Location::Address(address));
- }
- }
- let (main_index, main_fn) = &self
- .functions
- .iter()
- .enumerate()
- .find(|(_index, function)| function.name == "main")
- .expect("No entry point present (push main fn)");
- self.call_stack.push(StackFrame::new(
- main_fn.local_size,
- InstructionPointer::from((Location::Address(*main_index), main_fn.instructions.len())),
- ));
- loop {
- let function_index = self.instruction_pointer.0.as_address();
- let function = &self.functions[function_index];
- let instructions = &function.instructions;
- let instruction_count = instructions.len();
- let instruction_index = self.instruction_pointer.1;
- let instruction = &instructions[instruction_index];
- self.instruction_pointer.1 += 1;
- match instruction {
- Instruction::CallFunction(location) => self.inst_call_function(*location),
- Instruction::PushConstantInteger(val) => self.inst_load_constant_integer(*val),
- Instruction::StoreInteger(slot) => self.inst_store_integer(*slot),
- Instruction::LoadInteger(slot) => self.inst_load_integer(*slot),
- Instruction::AddInteger => self.inst_add_integer(),
- Instruction::MultiplyInteger => self.inst_multiply_integer(),
- Instruction::ModInteger => self.inst_mod_integer(),
- Instruction::CompareInteger => self.inst_compare_integer(),
- Instruction::IntegerGreaterThan => self.inst_integer_greater_than(),
- Instruction::IntegerSmallerThan => self.inst_integer_smaller_than(),
- Instruction::IncrementInteger(slot) => self.inst_increment_integer(*slot),
- Instruction::DecrementInteger(slot) => self.inst_decrement_integer(*slot),
- Instruction::JumpIfIntegerLessThan(location) => {
- self.inst_jump_if_integer_less_than(*location)
- }
- Instruction::JumpIfIntegerEquals(location) => {
- self.inst_jump_if_integer_equals(*location)
- }
- Instruction::CompareBoolean => self.inst_compare_boolean(),
- Instruction::Jump(location) => self.inst_jump(*location),
- Instruction::JumpIfTrue(location) => self.inst_jump_if_true(*location),
- Instruction::JumpIfFalse(location) => self.inst_jump_if_false(*location),
- Instruction::Return => self.inst_return(),
- }
- if self.call_stack.is_empty() {
- break;
- }
- }
- }
- fn jump(&mut self, location: Location) {
- let offset = location.as_offset();
- let current_pointer = self.instruction_pointer.1 as isize;
- self.instruction_pointer.1 = (current_pointer + offset) as usize;
- }
- fn inst_call_function(&mut self, location: Location<'a>) {
- let function = &self.functions[location.as_address()];
- self.call_stack.push(StackFrame::new(
- function.local_size,
- self.instruction_pointer.clone(),
- ));
- self.instruction_pointer = InstructionPointer::from((location, 0));
- }
- fn inst_load_constant_integer(&mut self, value: Integer) {
- self.operand_stack.push(Value::Integer(value));
- }
- fn inst_store_integer(&mut self, slot: LocalSlotIndex) {
- let current_frame = self.call_stack.peek_mut();
- let value = self.operand_stack.pop();
- current_frame.set_local(slot, value);
- }
- fn inst_load_integer(&mut self, slot: LocalSlotIndex) {
- let current_frame = self.call_stack.peek_mut();
- let value = current_frame.get_local(slot).as_integer();
- self.operand_stack.push(Value::Integer(value));
- }
- fn inst_increment_integer(&mut self, slot: LocalSlotIndex) {
- let current_frame = self.call_stack.peek_mut();
- current_frame.set_local(
- slot,
- Value::Integer(current_frame.get_local(slot).as_integer() + 1),
- );
- }
- fn inst_decrement_integer(&mut self, slot: LocalSlotIndex) {
- let current_frame = self.call_stack.peek_mut();
- current_frame.set_local(
- slot,
- Value::Integer(current_frame.get_local(slot).as_integer() - 1),
- );
- }
- fn inst_add_integer(&mut self) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- self.operand_stack.push(Value::Integer(lhs + rhs));
- }
- fn inst_multiply_integer(&mut self) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- self.operand_stack.push(Value::Integer(lhs * rhs));
- }
- fn inst_mod_integer(&mut self) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- self.operand_stack.push(Value::Integer(lhs % rhs));
- }
- fn inst_compare_integer(&mut self) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- self.operand_stack.push(Value::Boolean(lhs == rhs));
- }
- fn inst_integer_greater_than(&mut self) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- self.operand_stack.push(Value::Boolean(lhs > rhs));
- }
- fn inst_integer_smaller_than(&mut self) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- self.operand_stack.push(Value::Boolean(lhs < rhs));
- }
- fn inst_jump_if_integer_less_than(&mut self, location: Location) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- if lhs < rhs {
- self.jump(location);
- }
- }
- fn inst_jump_if_integer_equals(&mut self, location: Location) {
- let rhs = self.operand_stack.pop().as_integer();
- let lhs = self.operand_stack.pop().as_integer();
- if lhs == rhs {
- self.jump(location);
- }
- }
- fn inst_compare_boolean(&mut self) {
- let rhs = self.operand_stack.pop().as_boolean();
- let lhs = self.operand_stack.pop().as_boolean();
- self.operand_stack.push(Value::Boolean(lhs == rhs));
- }
- fn inst_jump_if_true(&mut self, location: Location) {
- let conditional = self.operand_stack.pop().as_boolean();
- if conditional {
- self.jump(location);
- }
- }
- fn inst_jump_if_false(&mut self, location: Location) {
- let conditional = self.operand_stack.pop().as_boolean();
- if !conditional {
- self.jump(location);
- }
- }
- fn inst_jump(&mut self, location: Location) {
- self.jump(location);
- }
- fn inst_return(&mut self) {
- let frame = self.call_stack.pop();
- self.instruction_pointer = frame.return_address;
- }
- }
- // vm_bench(42321, 20000000);
#[test]
fn test_vm_loops() {
    // Counts local 0 from 0 up to 100000, then returns it on the operand stack.
    //
    // locals:
    //   0 (index): integer
    //
    //     PushConstantInteger(0)                  // index = 0
    //     StoreInteger(0)
    // loopBegin:
    //     LoadInteger(0)                          // push index
    //     PushConstantInteger(100000)             // push limit
    //     JumpIfIntegerLessThan("loopIteration")  // index < limit => iterate
    //     Jump("programEnd")                      // otherwise we're done
    // loopIteration:
    //     IncrementInteger(0)                     // index += 1
    //     Jump("loopBegin")                       // back to the loop condition
    // programEnd:
    //     LoadInteger(0)                          // leave the final index as the result
    //     Return
    let mut virtual_machine = VirtualMachine::new();
    let mut main_fn = Function::new_with_locals("main", 1);
    main_fn.push_instruction(Instruction::PushConstantInteger(0));
    main_fn.push_instruction(Instruction::StoreInteger(0));
    main_fn.label_str("loopBegin");
    main_fn.push_instruction(Instruction::LoadInteger(0));
    main_fn.push_instruction(Instruction::PushConstantInteger(100000));
    main_fn.push_instruction(Instruction::JumpIfIntegerLessThan(Location::Label(
        "loopIteration",
    )));
    main_fn.push_instruction(Instruction::Jump(Location::Label("programEnd")));
    main_fn.label_str("loopIteration");
    main_fn.push_instruction(Instruction::IncrementInteger(0));
    main_fn.push_instruction(Instruction::Jump(Location::Label("loopBegin")));
    main_fn.label_str("programEnd");
    main_fn.push_instruction(Instruction::LoadInteger(0));
    main_fn.push_instruction(Instruction::Return);
    virtual_machine.push_function(main_fn);
    virtual_machine.run();

    // `Return` pops `main`'s frame, so the result is read from the operand
    // stack rather than from the (now empty) call stack.
    assert!(virtual_machine.call_stack.is_empty());
    assert_eq!(virtual_machine.operand_stack.pop().as_integer(), 100000);
}
/// Runs a VM program that searches for the modular inverse of `a` modulo `b`.
///
/// It tries each candidate `1..=b` and returns the first `x` with
/// `(a % b) * x % b == 1`, or `-1` if there is none. The result is also printed.
///
/// # Panics
/// Panics if `b == 0` (integer modulo by zero inside the VM).
#[cfg(test)]
fn vm_bench(a: i64, b: i64) -> i64 {
    // locals:
    //   0 a: integer
    //   1 b: integer
    //   2 index: integer
    //   3 result: integer
    //
    //     StoreInteger(0)                         // a (pushed last, so on top)
    //     StoreInteger(1)                         // b
    //     PushConstantInteger(0)                  // index = 0 (incremented before
    //     StoreInteger(2)                         //   each test, so 1 is tried first)
    //     LoadInteger(0)
    //     LoadInteger(1)
    //     ModInteger                              // a % b
    //     StoreInteger(0)                         // a = a % b
    // loopBegin:
    //     LoadInteger(2)
    //     LoadInteger(1)
    //     JumpIfIntegerLessThan("loopIteration")  // index < b => try next candidate
    //     Jump("didntFind")                       // candidates exhausted
    // loopIteration:
    //     IncrementInteger(2)                     // index += 1
    //     LoadInteger(0)
    //     LoadInteger(2)
    //     MultiplyInteger                         // a * index
    //     LoadInteger(1)
    //     ModInteger                              // (a * index) % b
    //     StoreInteger(3)                         // result = ...
    //     LoadInteger(3)
    //     PushConstantInteger(1)
    //     JumpIfIntegerEquals("didFind")          // result == 1 => index is the inverse
    //     Jump("loopBegin")
    // didFind:
    //     LoadInteger(2)                          // return the index (result is always 1 here)
    //     Jump("end")
    // didntFind:
    //     PushConstantInteger(-1)                 // -1 signals "no inverse"
    // end:
    //     Return
    let mut virtual_machine = VirtualMachine::new();
    let mut main_fn = Function::new_with_locals("main", 4);
    main_fn.push_instruction(Instruction::StoreInteger(0));
    main_fn.push_instruction(Instruction::StoreInteger(1));
    main_fn.push_instruction(Instruction::PushConstantInteger(0));
    main_fn.push_instruction(Instruction::StoreInteger(2));
    main_fn.push_instruction(Instruction::LoadInteger(0));
    main_fn.push_instruction(Instruction::LoadInteger(1));
    main_fn.push_instruction(Instruction::ModInteger);
    main_fn.push_instruction(Instruction::StoreInteger(0));
    main_fn.label_str("loopBegin");
    main_fn.push_instruction(Instruction::LoadInteger(2));
    main_fn.push_instruction(Instruction::LoadInteger(1));
    main_fn.push_instruction(Instruction::JumpIfIntegerLessThan(Location::Label(
        "loopIteration",
    )));
    main_fn.push_instruction(Instruction::Jump(Location::Label("didntFind")));
    main_fn.label_str("loopIteration");
    main_fn.push_instruction(Instruction::IncrementInteger(2));
    main_fn.push_instruction(Instruction::LoadInteger(0));
    main_fn.push_instruction(Instruction::LoadInteger(2));
    main_fn.push_instruction(Instruction::MultiplyInteger);
    main_fn.push_instruction(Instruction::LoadInteger(1));
    main_fn.push_instruction(Instruction::ModInteger);
    main_fn.push_instruction(Instruction::StoreInteger(3));
    main_fn.push_instruction(Instruction::LoadInteger(3));
    main_fn.push_instruction(Instruction::PushConstantInteger(1));
    main_fn.push_instruction(Instruction::JumpIfIntegerEquals(Location::Label("didFind")));
    main_fn.push_instruction(Instruction::Jump(Location::Label("loopBegin")));
    main_fn.label_str("didFind");
    main_fn.push_instruction(Instruction::LoadInteger(2));
    main_fn.push_instruction(Instruction::Jump(Location::Label("end")));
    main_fn.label_str("didntFind");
    main_fn.push_instruction(Instruction::PushConstantInteger(-1));
    main_fn.label_str("end");
    main_fn.push_instruction(Instruction::Return);
    virtual_machine.push_function(main_fn);
    // Arguments go on the operand stack in reverse so `a` ends up on top.
    virtual_machine.operand_stack.push(Value::Integer(b));
    virtual_machine.operand_stack.push(Value::Integer(a));
    virtual_machine.run();
    // `main`'s frame is gone after `Return`; the answer is on the operand stack.
    let result = virtual_machine.operand_stack.pop().as_integer();
    println!("{}", result);
    result
}

#[test]
fn test_vm_bench() {
    // 42 * 1969 = 82698 = 41 * 2017 + 1
    assert_eq!(vm_bench(42, 2017), 1969);
    // 3 * 5 = 15 = 2 * 7 + 1
    assert_eq!(vm_bench(3, 7), 5);
    // 8 % 7 == 1, so the inverse is 1 itself.
    assert_eq!(vm_bench(8, 7), 1);
    // gcd(2, 4) != 1: no inverse exists.
    assert_eq!(vm_bench(2, 4), -1);
}
Advertisement