Advertisement
Guest User

Untitled

a guest
Jan 24th, 2017
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.79 KB | None | 0 0
  1. #!/usr/local/bin/python
  2. """
  3. Q-learning with value function approximation
  4. """
  5.  
import argparse
import os
import pdb
from collections import defaultdict

import gym
import matplotlib
import numpy as np
from gym import wrappers
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.kernel_approximation import RBFSampler
from sklearn.linear_model import SGDRegressor
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import StandardScaler
  20.  
  21.  
  22. EXP_NAME_PREFIX = 'exp/q_learning_vfa'
  23. API_KEY = 'sk_ARsYZ2eRsGoeANVhUgrQ'
  24. ENVS = {
  25. 'mountaincar': 'MountainCar-v0', # --env mountaincar --gamma 0.99 --eps 0.3 --goal -110 --upload --max_episodes 10000 --eps_schedule 500
  26. }
  27.  
  28. class FeaturesMaker(object):
  29. def __init__(self):
  30. self.pipeline = Pipeline([
  31. ('scale', StandardScaler()),
  32. ('rbf', FeatureUnion([
  33. ('rbf5', RBFSampler(gamma=5.0, n_components=100)),
  34. ('rbf2', RBFSampler(gamma=2.0, n_components=100)),
  35. ('rbf1', RBFSampler(gamma=1.0, n_components=100)),
  36. ('rbf05', RBFSampler(gamma=0.5, n_components=100)),
  37. ]))
  38. ])
  39.  
  40. def fit(self, X):
  41. return self.pipeline.fit(X)
  42.  
  43. def transform(self, X):
  44. return self.pipeline.transform(X)
  45.  
  46.  
  47. class ValueFunction(object):
  48. def __init__(self, F, nA):
  49. self.F = F
  50. self.models = [SGDRegressor(learning_rate='constant') for _ in xrange(nA)]
  51.  
  52. def predict(self, s):
  53. f = self.F.transform([s])
  54. return np.array([m.predict(f)[0] for m in self.models])
  55.  
  56. def update(self, s, a, t):
  57. f = self.F.transform([s])
  58. self.models[a].partial_fit(f, [t])
  59.  
  60.  
  61. def q_learning(env, V, max_episodes, gamma, eps, eps_schedule, goal):
  62. nA = env.action_space.n
  63.  
  64. # init
  65. for a in xrange(nA):
  66. V.update(env.observation_space.sample(), a, 0)
  67.  
  68. P = np.zeros(nA, np.float32)
  69.  
  70. tR = np.ones(100, np.float32) * (-1000)
  71. for e in xrange(max_episodes):
  72. if e % eps_schedule == 0 and e > 0:
  73. eps /= 2
  74.  
  75. s = env.reset()
  76. done = False
  77. tR[e % tR.size] = 0.
  78. nit = 0
  79. while not done:
  80. nit += 1
  81. P.fill(eps / nA)
  82. P[np.argmax(V.predict(s))] += 1 - eps
  83. a = np.random.choice(xrange(nA), p=P)
  84. ns, r, done, _ = env.step(a)
  85. t = r + gamma * np.max(V.predict(ns))
  86. V.update(s, a, t)
  87. s = ns
  88. tR[e % tR.size] += r
  89.  
  90. print 'episode %d, iterations %d, average reward: %.3f' % (e, nit, np.mean(tR))
  91. if np.mean(tR) > goal:
  92. return e
  93.  
  94. return max_episodes
  95.  
  96.  
  97. def main():
  98. parser = argparse.ArgumentParser(description='Q-learning with VF approximation')
  99. parser.add_argument('--env', choices=ENVS.keys())
  100. parser.add_argument('--max_episodes', type=int, default=10000)
  101. parser.add_argument('--gamma', type=float, default=1.0)
  102. parser.add_argument('--eps', type=float, default=0.0)
  103. parser.add_argument('--eps_schedule', type=int, default=10000)
  104. parser.add_argument('--goal', type=float, default=1.0)
  105. parser.add_argument('--upload', action='store_true', default=False)
  106. args = parser.parse_args()
  107.  
  108. exp_name = '%s_%s' % (EXP_NAME_PREFIX, args.env)
  109.  
  110. env = gym.make(ENVS[args.env])
  111. env.seed(0)
  112. np.random.seed(0)
  113. if args.upload:
  114. env = wrappers.Monitor(env, exp_name, force=True)
  115.  
  116. F = FeaturesMaker()
  117. X = np.array([env.observation_space.sample() for _ in xrange(10000)])
  118. F.fit(X)
  119. V = ValueFunction(F, env.action_space.n)
  120.  
  121. res = q_learning(env, V, args.max_episodes, args.gamma,
  122. args.eps, args.eps_schedule, args.goal)
  123. print 'result -> %d' % res
  124.  
  125. env.close()
  126. if args.upload:
  127. gym.upload(exp_name, api_key=API_KEY)
  128.  
  129.  
  130. if __name__ == '__main__':
  131. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement