Advertisement
Guest User

Untitled

a guest
Jul 24th, 2019
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.91 KB | None | 0 0
  1. def cem(n_iterations=500, max_t=1000, gamma=1.0, print_every=10, pop_size=50, elite_frac=0.2, sigma=0.5):
  2. """PyTorch implementation of the cross-entropy method.
  3.  
  4. Params
  5. ======
  6. n_iterations (int): maximum number of training iterations
  7. max_t (int): maximum number of timesteps per episode
  8. gamma (float): discount rate
  9. print_every (int): how often to print average score (over last 100 episodes)
  10. pop_size (int): size of population at each iteration
  11. elite_frac (float): percentage of top performers to use in update
  12. sigma (float): standard deviation of additive noise
  13. """
  14. n_elite=int(pop_size*elite_frac)
  15.  
  16. scores_deque = deque(maxlen=100)
  17. scores = []
  18. best_weight = sigma*np.random.randn(agent.get_weights_dim())
  19.  
  20. for i_iteration in range(1, n_iterations+1):
  21. weights_pop = [best_weight + (sigma*np.random.randn(agent.get_weights_dim())) for i in range(pop_size)]
  22. rewards = np.array([agent.evaluate(weights, gamma, max_t) for weights in weights_pop])
  23.  
  24. elite_idxs = rewards.argsort()[-n_elite:]
  25. elite_weights = [weights_pop[i] for i in elite_idxs]
  26. best_weight = np.array(elite_weights).mean(axis=0)
  27. reward = agent.evaluate(best_weight, gamma=1.0)
  28. scores_deque.append(reward)
  29. scores.append(reward)
  30.  
  31. torch.save(agent.state_dict(), 'checkpoint.pth')
  32.  
  33. if i_iteration % print_every == 0:
  34. print('Episode {}\tAverage Score: {:.2f}'.format(i_iteration, np.mean(scores_deque)))
  35.  
  36. if np.mean(scores_deque)>=90.0:
  37. print('\nEnvironment solved in {:d} iterations!\tAverage Score: {:.2f}'.format(i_iteration-100, np.mean(scores_deque)))
  38. break
  39. return scores
  40.  
  41. scores = cem()
  42.  
  43. # plot the scores
  44. fig = plt.figure()
  45. ax = fig.add_subplot(111)
  46. plt.plot(np.arange(1, len(scores)+1), scores)
  47. plt.ylabel('Score')
  48. plt.xlabel('Episode #')
  49. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement