Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import json
class MDP(object):
    """A Markov decision process: a discount factor plus an ordered list of states.

    The objects stored here are duck-typed. ``dump`` reads ``id`` and
    ``actions`` from each state, ``id`` and ``transitions`` from each action,
    and ``id``, ``prob``, ``reward`` and ``to`` from each transition.
    """

    def __init__(self, gamma=0.75):
        """Create an empty MDP with discount factor ``gamma``."""
        self.gamma = gamma
        self.__id = 0
        self.states = []

    def add_state(self, state):
        """Append ``state`` and return ``self`` so calls can be chained.

        A state whose ``id`` is negative receives the current value of the
        internal counter; any other ``id`` is kept as given.  The counter
        advances on every call, whether or not it was used.

        NOTE: ``MDPState`` in this file initialises ``id`` to 0, so the
        auto-assignment only happens when a caller sets a negative id
        explicitly before adding the state.
        """
        if state.id < 0:
            state.id = self.__id
        self.__id += 1
        self.states.append(state)
        return self

    @staticmethod
    def _transition_dict(transition):
        """Return the JSON-ready mapping for a single transition."""
        return {
            "id": transition.id,
            "probability": transition.prob,
            "reward": transition.reward,
            "to": transition.to,
        }

    @staticmethod
    def _action_dict(action):
        """Return the JSON-ready mapping for an action and its transitions."""
        return {
            "id": action.id,
            "transitions": [MDP._transition_dict(t) for t in action.transitions],
        }

    @staticmethod
    def _state_dict(state):
        """Return the JSON-ready mapping for a state and its actions."""
        return {
            "id": state.id,
            "actions": [MDP._action_dict(a) for a in state.actions],
        }

    def _as_dict(self):
        """Return the complete structure that ``dump`` serialises."""
        return {
            "gamma": self.gamma,
            "states": [MDP._state_dict(s) for s in self.states],
        }

    def dump(self):
        """Print the MDP to stdout as indented JSON with sorted keys."""
        # The call form of print with a single argument behaves the same on
        # Python 2 and Python 3; the bare print statement is a SyntaxError
        # on Python 3.
        print(json.dumps(self._as_dict(), sort_keys=True, indent=4,
                         separators=(',', ': ')))
class MDPState:
    """A state of the MDP that owns an ordered list of actions."""

    def __init__(self):
        self.actions = []
        self.__action_id = 0
        self.id = 0

    def add_action(self, action):
        """Number ``action`` with the next per-state id, store it, return ``self``."""
        assigned = self.__action_id
        action.id = assigned
        self.actions.append(action)
        # Ids are handed out sequentially in insertion order, starting at 0.
        self.__action_id = assigned + 1
        return self
class MPDAction:
    """An action available in a state; owns an ordered list of transitions."""

    def __init__(self):
        self.transitions = []
        self.__transition_id = 0
        self.id = 0

    def add_transition(self, transition):
        """Number ``transition`` with the next per-action id, store it, return ``self``."""
        assigned = self.__transition_id
        transition.id = assigned
        self.transitions.append(transition)
        # Ids are handed out sequentially in insertion order, starting at 0.
        self.__transition_id = assigned + 1
        return self
class MPDTransition:
    """One outcome of an action: target state, reward and probability."""

    def __init__(self, prob, reward, to):
        # ``id`` starts at 0; MPDAction.add_transition overwrites it.
        self.id = 0
        self.prob, self.reward, self.to = prob, reward, to
# Transition table, one row per transition:
#   [state, next_state, reward, probability]
data = [
    [0, 2, 1, 0.5],
    [0, 1, 0, 0.5],
    [0, 0, 0, 1.0],
    [1, 1, 0, 1.0],
    [1, 0, 0, 1.0],
    [2, 5, 0, 0.5],
    [2, 3, 1, 0.5],
    [2, 2, 1, 1.0],
    [5, 0, 0, 1.0],
    [5, 5, 0, 1.0],
    [3, 4, 0, 0.5],
    [3, 6, 1, 0.5],
    [3, 3, 1, 1.0],
    [4, 4, 0, 1.0],
    [4, 6, 0, 0.2],
    [4, 5, 1, 0.8],
    [6, 7, 1, 1.0],
    [6, 6, 1, 1.0],
    [7, 7, 0, 1.0],
    [7, 4, 0, 0.9],
    [7, 6, 0, 0.1],
]

# Build one MDPState per distinct source state, in order of first appearance.
# Each state gets action 0 for transitions to a different state and, once a
# self-transition row is seen, action 1 for transitions back to itself.
states = []
states_dict = {}
for ary in data:
    s, s2, reward, prob = ary
    same = s == s2

    state = states_dict.get(s)
    if state is None:
        state = states_dict[s] = MDPState()
        state.id = s
        states.append(state)
        state.add_action(MPDAction())

    # Actions are always looked up through the row's own state, so rows for
    # one state no longer have to be contiguous in ``data``; a shared
    # "current actions" list would attach transitions to whichever state was
    # created last.
    if same and len(state.actions) == 1:
        state.add_action(MPDAction())

    action_index = 1 if same else 0
    state.actions[action_index].add_transition(MPDTransition(prob, reward, s2))

mdp = MDP()
for state in states:
    mdp.add_state(state)
mdp.dump()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement