!pip install cmake 'gym[atari]' scipy
!pip install git+https://github.com/jsh9/python-plot-utilities@v0.6.1
import sys
import gym
import numpy as np
import random
import math
from collections import defaultdict, deque
import matplotlib.pyplot as plt
#from plot_utils import plot_values
%matplotlib inline

env = gym.make('CliffWalking-v0')

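# Optional sanity check: CliffWalking-v0 is a 4x12 gridworld, so the environment
# exposes 48 discrete states and 4 discrete actions (UP, RIGHT, DOWN, LEFT).
# print(env.observation_space)  # Discrete(48)
# print(env.action_space)       # Discrete(4)
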
def update_Q(alpha, gamma, Q, state, action, reward, next_state=None):
    """Returns updated Q-value for the most recent experience."""
    current = Q[state][action]                                         # estimate in Q-table (for current state, action pair)
    Qsa_next = np.max(Q[next_state]) if next_state is not None else 0  # value of next state
    target = reward + (gamma * Qsa_next)                               # construct TD target
    new_value = current + (alpha * (target - current))                 # get updated value
    return new_value


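# Worked example of the update_Q logic above (illustrative numbers, not taken from a real run):
# with current = Q[s][a] = 0.5, max_a' Q[s'][a'] = 1.0, reward = -1, gamma = 1.0, alpha = 0.1,
# the TD target is -1 + 1.0 * 1.0 = 0.0, so the new value is 0.5 + 0.1 * (0.0 - 0.5) = 0.45.
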
def epsilon_greedy(Q, state, nA, eps):
    """Selects epsilon-greedy action for supplied state.

    Params
    ======
        Q (dictionary): action-value function
        state (int): current state
        nA (int): number of actions in the environment
        eps (float): epsilon
    """
    if random.random() > eps:  # select greedy action with probability 1 - epsilon
        return np.argmax(Q[state])
    else:                      # otherwise, select an action uniformly at random
        return random.choice(np.arange(nA))


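# Equivalent formulation of epsilon_greedy (sketch): sample from an explicit action
# distribution instead of the if/else above; the greedy action then has probability
# 1 - eps + eps / nA, same as in the function itself.
# probs = np.ones(nA) * eps / nA
# probs[np.argmax(Q[state])] += 1 - eps
# action = np.random.choice(np.arange(nA), p=probs)
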
def q_learning(env, num_episodes, alpha, gamma=1.0, plot_every=100):
    """Q-Learning - TD Control

    Params
    ======
        num_episodes (int): number of episodes to run the algorithm
        alpha (float): learning rate
        gamma (float): discount factor
        plot_every (int): number of episodes to use when calculating average score
    """
    nA = env.action_space.n                  # number of actions
    Q = defaultdict(lambda: np.zeros(nA))    # initialize empty dictionary of arrays

    # monitor performance
    tmp_scores = deque(maxlen=plot_every)    # deque for keeping track of scores
    avg_scores = deque(maxlen=num_episodes)  # average scores over every plot_every episodes

    for i_episode in range(1, num_episodes + 1):
        # monitor progress
        if i_episode % 100 == 0:
            print("\rEpisode {}/{}".format(i_episode, num_episodes), end="")
            sys.stdout.flush()
        score = 0              # initialize score
        state = env.reset()    # start episode
        eps = 1.0 / i_episode  # set value of epsilon

        while True:
            action = epsilon_greedy(Q, state, nA, eps)         # epsilon-greedy action selection
            next_state, reward, done, info = env.step(action)  # take action A, observe R, S'
            score += reward                                     # add reward to agent's score
            Q[state][action] = update_Q(alpha, gamma, Q,
                                        state, action, reward, next_state)
            state = next_state  # S <- S'
            if done:
                tmp_scores.append(score)  # append score
                break
        if (i_episode % plot_every == 0):
            avg_scores.append(np.mean(tmp_scores))

    # plot performance
    plt.plot(np.linspace(0, num_episodes, len(avg_scores), endpoint=False), np.asarray(avg_scores))
    plt.xlabel('Episode Number')
    plt.ylabel('Average Reward (Over Next %d Episodes)' % plot_every)
    plt.show()
    # print best 100-episode performance
    print(('Best Average Reward over %d Episodes: ' % plot_every), np.max(avg_scores))
    return Q

# obtain the estimated optimal policy and corresponding action-value function
Q_sarsamax = q_learning(env, 5000, .01)

# print the estimated optimal policy
policy_sarsamax = np.array([np.argmax(Q_sarsamax[key]) if key in Q_sarsamax else -1 for key in np.arange(48)]).reshape((4, 12))
print("\nEstimated Optimal Policy (UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3, N/A = -1):")
print(policy_sarsamax)
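# Optional (sketch): render the greedy policy as arrows for readability; the mapping
# below follows the action encoding printed above (UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3).
# arrows = {0: '^', 1: '>', 2: 'v', 3: '<', -1: '.'}
# for row in policy_sarsamax:
#     print(' '.join(arrows[a] for a in row))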

# print the estimated optimal state-value function (plot_values is unavailable here, so print row by row)
V_sarsamax = np.array([np.max(Q_sarsamax[key]) if key in Q_sarsamax else 0 for key in np.arange(48)]).reshape(4, 12)
#print(V_sarsamax)

for line in V_sarsamax:
    print(*line)
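
# Optional (sketch): visualize the state-value function as a heatmap with matplotlib,
# since plot_utils.plot_values is not imported above; V_sarsamax is the 4x12 array built earlier.
# plt.figure(figsize=(8, 3))
# plt.imshow(V_sarsamax, cmap='viridis')
# plt.colorbar()
# plt.title('Estimated Optimal State-Value Function')
# plt.show()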