Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import sys
def update_state_dict(state_dict, raw_line, num_states, num_actions):
    """Parse one state description line and record it in ``state_dict``.

    A line has the form ``s<N> <reward> (a<K> s<M> <p>) (a<K> s<M> <p>) ...``.
    On success ``state_dict['s<N>']`` is set to ``(reward, transitions)``,
    where ``transitions`` maps each action name to a dict of
    ``{to_state: probability}``.

    Args:
        state_dict: Dict updated in place with the parsed state.
        raw_line: One line of the input file.
        num_states: States whose number exceeds this are ignored entirely.
        num_actions: Actions whose number exceeds this are skipped.

    Raises:
        ValueError: If a state number, action number, reward or probability
            token is not numeric.
        IndexError: If the line ends in the middle of a transition triple.
    """
    # split() with no argument tolerates tabs and repeated spaces, and yields
    # an empty list for blank lines, which are skipped instead of crashing.
    tokens = raw_line.split()
    if not tokens:
        return

    state = tokens[0]
    if int(state.replace('s', '')) > num_states:
        return

    reward = int(tokens[1])
    transitions = {}
    # Tokens after the reward come in (action, to_state, probability) triples.
    for i in range(2, len(tokens), 3):
        action = tokens[i].replace('(', '')
        if action == '':
            break
        if int(action.replace('a', '')) > num_actions:
            continue
        to_state = tokens[i + 1]
        probability = float(tokens[i + 2].replace(')', ''))
        transitions.setdefault(action, {})[to_state] = probability

    state_dict[state] = (reward, transitions)
def init_dp(state_dict):
    """Build the initial value-iteration history.

    Returns a one-element list whose only entry maps every state name in
    ``state_dict`` to the placeholder pair ``(0, 0)``.
    """
    iteration_zero = {name: (0, 0) for name in state_dict}
    return [iteration_zero]
def compute_max_action(action_results):
    """Return the ``(action, value)`` pair with the largest value.

    Args:
        action_results: Iterable of ``(action, value)`` pairs.

    Returns:
        The pair with the highest value; when several pairs tie, the last
        one in iteration order wins. For an empty iterable the historical
        placeholder ``(0, -100000)`` is returned.
    """
    best = None
    for result in action_results:
        # Seed the comparison with the first real candidate rather than a
        # fixed floor, so values below -100000 are still ranked correctly.
        if best is None or result[1] >= best[1]:
            best = result
    if best is None:
        return (0, -100000)
    return best
def sorted_states(list_of_states):
    """Return state names ordered by their numeric suffix.

    Each name has its ``'s'`` characters removed and is converted to an
    integer; the integers are sorted ascending and turned back into names
    of the form ``'s<number>'``.
    """
    numbers = [int(name.replace('s', '')) for name in list_of_states]
    numbers.sort()
    return ['s' + str(number) for number in numbers]
def compute_j(current_index, value_iteration_dp, state_dict, discount_factor):
    """Fill in one iteration of the value table.

    For every state in ``state_dict`` and every action available in it, the
    candidate value is ``reward + discount_factor * sum(p * previous_value)``
    taken over the states in ``state_dict``, using the values stored at
    ``value_iteration_dp[current_index - 1]``. The pair chosen by
    ``compute_max_action`` is stored at
    ``value_iteration_dp[current_index][state]``.

    Args:
        current_index: Index of the table to fill; it must already exist in
            ``value_iteration_dp``.
        value_iteration_dp: List of per-iteration dicts mapping a state name
            to an ``(action, value)`` pair.
        state_dict: Maps a state name to ``(reward, transitions)``, where
            ``transitions`` maps an action to ``{to_state: probability}``.
        discount_factor: Multiplier applied to the expected previous value.
    """
    previous_table = value_iteration_dp[current_index - 1]
    current_table = value_iteration_dp[current_index]

    for state, entry in state_dict.items():
        reward = entry[0]
        transitions = entry[1]
        action_results = []
        for action, probabilities in transitions.items():
            # Accumulate in state_dict order with an explicit loop so the
            # floating-point result does not depend on how sum() rounds.
            # Destinations missing from state_dict are never visited, so
            # they contribute nothing.
            summation = 0
            for other_state in state_dict:
                probability = probabilities.get(other_state, 0)
                summation += probability * previous_table[other_state][1]
            action_results.append((action, reward + discount_factor * summation))
        current_table[state] = compute_max_action(action_results)
def print_j(i, value_iteration_dp):
    """Print the value table of iteration ``i`` on a single line.

    Output is a header ``After iteration <i>:`` followed by one
    ``(<state> <action> <value>) `` group per state, in the order given by
    ``sorted_states``, with the value shown to four decimal places.

    Args:
        i: Index of the iteration to print.
        value_iteration_dp: List of per-iteration dicts mapping a state name
            to an ``(action, value)`` pair.
    """
    table = value_iteration_dp[i]
    print('After iteration ' + str(i) + ':')
    for state in sorted_states(table.keys()):
        entry = table[state]
        # %s rather than string concatenation: the action slot holds the
        # integer 0 in the initial table and in the fallback pair returned
        # by compute_max_action, which '+' cannot join to a str.
        print('(%s %s %.4f) ' % (state, entry[0], entry[1]), end='')
    print()
# Command-line entry point:
#   <script> <num_states> <num_actions> <test_file> <discount_factor>
# Runs 100 iterations of value iteration and prints the table after each one.
if len(sys.argv) != 5:
    print('The program must be run with the number of states, number of actions, the test file, and discount factor as inputs.')
else:
    num_states = int(sys.argv[1])
    num_actions = int(sys.argv[2])
    discount_factor = float(sys.argv[4])

    state_dict = {}
    # The context manager closes the input file even if a line fails to parse.
    with open(sys.argv[3], 'r') as test_file:
        for line in test_file.read().splitlines():
            update_state_dict(state_dict, line, num_states, num_actions)

    value_iteration_dp = init_dp(state_dict)
    for i in range(1, 101):
        # Each iteration gets a fresh table, filled from the previous one.
        value_iteration_dp.append({})
        compute_j(i, value_iteration_dp, state_dict, discount_factor)
        print_j(i, value_iteration_dp)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement