Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Two discretizations of the CartPole observation space: a coarse "big"
# model used far from the origin, and a fine "small" model used near it.
import gym
import MDP

env = gym.make('CartPole-v0')

A = env.action_space.n
# NOTE(review): 20 buckets per observation dimension.  This looks like it
# was meant to be 20**4 state indices rather than shape[0] * 20**4 --
# confirm against what MDP.SparseExperience expects for its state count.
S = env.observation_space.shape[0] * 20 * 20 * 20 * 20

expBig = MDP.SparseExperience(S, A)
modelBig = MDP.SparseRLModel(expBig, 0.9)

expSmall = MDP.SparseExperience(S, A)
modelSmall = MDP.SparseRLModel(expSmall, 0.9)
def isStateBig(o, thresh):
    """Return True if the observation is outside the "small" region.

    An observation *o* (a sequence of floats) is "big" when any of its
    components lies strictly outside the closed interval
    [-thresh, thresh].  An empty observation is never big.
    """
    # any() with abs() replaces the manual index loop; `abs(x) > thresh`
    # is exactly `x > thresh or x < -thresh` for real floats.
    return any(abs(x) > thresh for x in o)
def observationToState(o, thresh, bins=20):
    """Hash a continuous observation into one discrete state index.

    Each component of *o* is clipped to [-thresh, thresh] and discretized
    into *bins* equal-width buckets; the per-dimension bucket indices are
    then combined base-*bins* into a single integer in
    [0, bins ** len(o)).

    Bug fixed: the original computed ``int(((x + thresh) * 20) // 20)``,
    which collapses the whole range into at most ``2 * thresh + 1``
    buckets -- for thresh < 0.5 every observation mapped to bucket 0, so
    the fine-grained model could never distinguish states.  The value is
    now scaled by the full range so all *bins* buckets are used.
    """
    s = 0
    width = 2.0 * thresh
    for x in o:
        clipped = min(thresh, max(-thresh, x))
        # Map [-thresh, thresh] onto [0, bins).  The upper edge lands
        # exactly on `bins`, so clamp it into the last bucket.
        b = min(bins - 1, int((clipped + thresh) / width * bins))
        s = s * bins + b
    return s
# Prioritized-sweeping solvers and greedy policies over each model's
# Q-function.
solverBig = MDP.PrioritizedSweepingSparseRLModel(modelBig)
policyBig = MDP.QGreedyPolicy(solverBig.getQFunction())

solverSmall = MDP.PrioritizedSweepingSparseRLModel(modelSmall)
policySmall = MDP.QGreedyPolicy(solverSmall.getQFunction())

threshBig = 1.0    # clipping range for the coarse discretization
threshSmall = 0.2  # clipping range for the fine discretization
- for i_episode in xrange(100):
- o = env.reset()
- for t in xrange(300):
- env.render()
- sBig = observationToState(o, threshBig);
- sSmall = observationToState(o, threshSmall)
- stateBig = isStateBig(o, threshSmall);
- if stateBig:
- a = policyBig.sampleAction(sBig);
- else:
- a = policySmall.sampleAction(sSmall);
- o1, rew, done, info = env.step(a);
- s1Big = observationToState(o1, threshBig);
- s1Small = observationToState(o1, threshSmall)
- state1Big = isStateBig(o1, threshSmall);
- if done:
- env.render()
- print o1;
- if state1Big == False or t >= 199:
- rew = 10;
- else:
- rew = -10;
- if stateBig == False:
- expSmall.record(sSmall, a, s1Small, rew);
- modelSmall.sync(sSmall, a, s1Small);
- solverSmall.stepUpdateQ(sSmall, a);
- solverSmall.batchUpdateQ();
- expBig.record(sBig, a, s1Big, rew);
- modelBig.sync(sBig, a, s1Big);
- solverBig.stepUpdateQ(sBig, a);
- solverBig.batchUpdateQ();
- print state1Big, sBig, a, s1Big, rew
- print "Episode {} finished after {} timesteps".format(i_episode, t+1)
- break
- if stateBig == False and state1Big:
- rew = -10;
- elif stateBig == True and state1Big == False:
- rew = 10;
- if stateBig == False:
- expSmall.record(sSmall, a, s1Small, rew);
- modelSmall.sync(sSmall, a, s1Small);
- solverSmall.stepUpdateQ(sSmall, a);
- solverSmall.batchUpdateQ();
- expBig.record(sBig, a, s1Big, rew);
- modelBig.sync(sBig, a, s1Big);
- solverBig.stepUpdateQ(sBig, a);
- solverBig.batchUpdateQ();
- o = o1;
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement