Advertisement
Arham-4

MDP Value Iteration

Nov 27th, 2021
949
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.17 KB | None | 0 0
  1. import sys
  2.  
  3. def update_state_dict(state_dict, raw_line, num_states, num_actions):
  4.     split = raw_line.split(' ')
  5.     state = split[0]
  6.     if int(state.replace('s', '')) > num_states:
  7.         return
  8.     reward = int(split[1])
  9.     transitions = {}
  10.     for i in range(2, len(split), 3):
  11.         action = split[i].replace('(', '')
  12.         if action == '':
  13.             break
  14.         if int(action.replace('a', '')) > num_actions:
  15.             continue
  16.         to_state = split[i + 1]
  17.         probability = float(split[i + 2].replace(')', ''))
  18.         if action in transitions:
  19.             transitions[action][to_state] = probability
  20.         else:
  21.             transitions[action] = {to_state: probability}
  22.     state_dict[state] = (reward, transitions)
  23.  
  24. def init_dp(state_dict):
  25.     init = [{}]
  26.     for state in state_dict.keys():
  27.         init[0][state] = (0, 0)
  28.     return init
  29.    
  30. def compute_max_action(action_results):
  31.     max_action = (0, -100000)
  32.     for result in action_results:
  33.         if result[1] >= max_action[1]:
  34.             max_action = result
  35.     return max_action
  36.  
  37. def sorted_states(list_of_states):
  38.     states = []
  39.     sorted_states = []
  40.     for state in list_of_states:
  41.         states.append(int(state.replace('s', '')))
  42.     for state in sorted(states):
  43.         sorted_states.append('s' + str(state))
  44.     return sorted_states
  45.  
  46. def compute_j(current_index, value_iteration_dp, state_dict, discount_factor):
  47.     for state in state_dict.keys():
  48.         state_num = int(state.replace('s', ''))
  49.         reward = state_dict[state][0]
  50.         transitions = state_dict[state][1]
  51.         action_results = []
  52.         for action, probabilities in transitions.items():
  53.             summation = 0
  54.             for other_state in state_dict.keys():
  55.                 probability = 0
  56.                 if other_state in probabilities:
  57.                     probability = probabilities[other_state]
  58.                 previous_j = value_iteration_dp[current_index - 1][other_state][1]
  59.                 summation += probability * previous_j
  60.             action_results.append((action, reward + discount_factor * summation))
  61.         max_action = compute_max_action(action_results)
  62.         value_iteration_dp[current_index][state] = max_action
  63.    
  64. def print_j(i, value_iteration_dp):
  65.     print('After iteration ' + str(i) + ':')
  66.     for state in sorted_states(value_iteration_dp[i].keys()):
  67.         print('(' + state + ' ' + value_iteration_dp[i][state][0] + ' %.4f) ' % value_iteration_dp[i][state][1], end='')
  68.     print()
  69.  
  70. if len(sys.argv) != 5:
  71.     print('The program must be run with the number of states, number of actions, the test file, and discount factor as inputs.')
  72. else:
  73.     num_states = int(sys.argv[1])
  74.     num_actions = int(sys.argv[2])
  75.     test_file = open(sys.argv[3], 'r')
  76.     discount_factor = float(sys.argv[4])
  77.     state_dict = {}
  78.     for line in test_file.read().splitlines():
  79.         update_state_dict(state_dict, line, num_states, num_actions)
  80.     value_iteration_dp = init_dp(state_dict)
  81.     for i in range(1, 101):
  82.         value_iteration_dp.append({})
  83.         compute_j(i, value_iteration_dp, state_dict, discount_factor)
  84.         print_j(i, value_iteration_dp)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement