import numpy as np
from gridworld import GridworldEnv

env = GridworldEnv()

def policy_eval(policy, env, discount_factor=1.0, epsilon=0.00001):
    """
    Evaluate a policy given an environment and a full description of the
    environment's dynamics.

    Args:
        policy: [S, A] shaped matrix representing the policy.
        env: OpenAI env. env.P represents the transition probabilities of the
            environment. env.P[s][a] is a list of transition tuples
            (prob, next_state, reward, done). env.nS is the number of states
            and env.nA is the number of actions in the environment.
        discount_factor: Gamma discount factor.
        epsilon: We stop evaluation once the value function changes by less
            than epsilon for all states.

    Returns:
        Vector of length env.nS representing the value function.
    """
    # Start with an all-zero value function
    V = np.zeros(env.nS)

    while True:

        # value function produced by this sweep
        V_new = np.zeros(env.nS)
        # largest change seen in this sweep (stopping condition)
        delta = 0

        # loop over the state space
        for s in range(env.nS):

            # accumulate the Bellman expectation backup for state s
            expected_value = 0
            # probability distribution over actions under the policy
            action_probs = policy[s]

            # loop over possible actions
            for a in range(env.nA):

                # loop over the transitions for this state-action pair
                for prob, next_state, reward, done in env.P[s][a]:
                    # apply the Bellman expectation equation
                    expected_value += action_probs[a] * prob * (
                        reward + discount_factor * V[next_state])

            # track the biggest difference over the state space
            delta = max(delta, abs(expected_value - V[s]))

            # store the updated state value
            V_new[s] = expected_value

        # adopt the new value function
        V = V_new

        # stop once the value function has converged
        if delta < epsilon:
            break

    return np.array(V)


random_policy = np.ones([env.nS, env.nA]) / env.nA
v = policy_eval(random_policy, env)

expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20,
                       -18, -14, -22, -20, -14, 0])
np.testing.assert_array_almost_equal(v, expected_v, decimal=2)

print(v)
print(expected_v)
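
For reference, the update the inner loops implement is the standard Bellman expectation backup for a fixed policy pi (this is the textbook form, not something spelled out in the paste itself):

    V_{k+1}(s) = \sum_{a} \pi(a \mid s) \sum_{s'} p(s' \mid s, a) \left[ r(s, a, s') + \gamma V_k(s') \right]

Here \pi(a \mid s) corresponds to action_probs[a], p(s' \mid s, a) to prob from env.P[s][a], r(s, a, s') to reward, and \gamma to discount_factor. Iterating this backup until delta falls below epsilon converges to the value function of the evaluated policy.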