Advertisement
Guest User

Untitled

a guest
Jul 27th, 2017
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.82 KB | None | 0 0
  1. import gym
  2. import time
  3. import random
  4.  
  5.  
  6. def get_states(dst_observation):
  7. deg = dst_observation[2] * 180 / 3.14 / 1.2
  8. v1 = dst_observation[1] * 5
  9. v2 = dst_observation[3] * 5
  10.  
  11. if deg > 10.0:
  12. deg = 10.0
  13. elif deg < -10.0:
  14. deg = -10.0
  15.  
  16. if v1 > 10.0:
  17. v1 = 10.0
  18. elif v1 < -10.0:
  19. v1 = -10.0
  20.  
  21. if v2 > 10.0:
  22. v2 = 10.0
  23. elif v2 < -10.0:
  24. v2 = -10.0
  25.  
  26. return int(deg + 10.0) * 21 * 21 \
  27. + int(v1 + 10.0) * 21 \
  28. + int(v2 + 10.0)
  29.  
  30. env = gym.make('CartPole-v0')
  31. Q = [[random.random(), random.random()] for _ in range(21 * 21 * 21)]
  32. alpha = 0.05 # learning rate
  33. gamma = 0.99 # saving rate
  34. epsilon = 0.1 # rate for epsilon greedy method
  35.  
  36. for episode in range(10001):
  37. observation = env.reset()
  38.  
  39. R = [0 for _ in range(200)]
  40. H = [[0, 0] for _ in range(200)]
  41. i_end = 0
  42.  
  43. for i in range(200):
  44. states = get_states(observation)
  45.  
  46. # epsilon greedy
  47. action = 0
  48. if random.random() < epsilon:
  49. action = random.randint(0, 1)
  50. else:
  51. action = Q[states].index(max(Q[states]))
  52. observation, reward, done, info = env.step(action)
  53.  
  54. # logging history
  55. H[i] = [states, action]
  56.  
  57. # calculate income
  58. rate = 1.0
  59. for j in reversed(range(0, i + 1)):
  60. R[j] += reward * rate
  61. rate *= gamma
  62.  
  63. if episode % 2000 == 0:
  64. env.render()
  65. time.sleep(1.0 / 20)
  66.  
  67. if done and not episode % 1000 == 0:
  68. i_end = i
  69. break
  70.  
  71. # update Q
  72. for i in range(i_end - 1):
  73. S_t = H[i][0]
  74. S_t_1 = H[i + 1][0]
  75. A_t = H[i][1]
  76. R_t_1 = R[i + 1]
  77. Q_max = max(Q[S_t_1])
  78. Q[S_t][A_t] += alpha * (gamma * Q_max - Q[S_t][A_t] + R_t_1)
  79.  
  80. print("episode: ", episode, i_end)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement