def cem(n_iterations=500, max_t=1000, gamma=1.0, print_every=10, pop_size=50, elite_frac=0.2, sigma=0.5):
    """Train the module-level ``agent`` with the cross-entropy method.

    Each iteration samples ``pop_size`` candidate weight vectors around the
    current mean, scores every candidate with ``agent.evaluate``, and moves
    the mean to the average of the best-scoring ("elite") candidates.
    Training stops early once the mean score over the last (up to) 100
    iterations reaches 90.0.

    Params
    ======
        n_iterations (int): maximum number of training iterations
        max_t (int): maximum number of timesteps, forwarded to ``agent.evaluate`` when scoring candidates
        gamma (float): discount rate, forwarded to ``agent.evaluate`` when scoring candidates
        print_every (int): how often (in iterations) to print the average score (over the last 100 iterations)
        pop_size (int): size of population at each iteration
        elite_frac (float): fraction of top performers to use in update
        sigma (float): standard deviation of additive noise

    Returns
    =======
        list: score of the updated mean weights, one entry per completed iteration
    """
    # At least one elite is required: with n_elite == 0 the slice [-0:] below
    # would select the *entire* population rather than an empty elite set.
    n_elite = max(1, int(pop_size * elite_frac))

    scores_deque = deque(maxlen=100)  # sliding window used for the running average
    scores = []
    weights_dim = agent.get_weights_dim()
    best_weight = sigma * np.random.randn(weights_dim)

    for i_iteration in range(1, n_iterations + 1):
        # Sample candidates as Gaussian perturbations of the current mean.
        weights_pop = [best_weight + sigma * np.random.randn(weights_dim) for _ in range(pop_size)]
        rewards = np.array([agent.evaluate(weights, gamma, max_t) for weights in weights_pop])

        # argsort is ascending, so the last n_elite indices are the top scorers.
        elite_idxs = rewards.argsort()[-n_elite:]
        elite_weights = [weights_pop[i] for i in elite_idxs]
        best_weight = np.array(elite_weights).mean(axis=0)

        # The reported score always uses gamma=1.0 rather than the `gamma`
        # argument, and max_t is not forwarded, so evaluate's own default applies.
        reward = agent.evaluate(best_weight, gamma=1.0)
        scores_deque.append(reward)
        scores.append(reward)

        # Checkpoint after every iteration.
        # NOTE(review): presumably evaluate() leaves best_weight loaded in the
        # agent, so this saves the updated mean - confirm against the agent class.
        torch.save(agent.state_dict(), 'checkpoint.pth')

        avg_score = np.mean(scores_deque)

        if i_iteration % print_every == 0:
            print('Episode {}\tAverage Score: {:.2f}'.format(i_iteration, avg_score))

        if avg_score >= 90.0:
            # Report the iteration at which the averaging window started.
            # Subtracting the actual window length (instead of a fixed 100)
            # keeps the figure non-negative when solved in under 100 iterations.
            print('\nEnvironment solved in {:d} iterations!\tAverage Score: {:.2f}'.format(i_iteration - len(scores_deque), avg_score))
            break
    return scores
40.
# Run training with the default hyperparameters.
scores = cem()

# Plot the per-iteration scores returned by cem(), numbered from 1.
fig = plt.figure()
ax = fig.add_subplot(111)
x_values = np.arange(1, len(scores) + 1)
ax.plot(x_values, scores)
ax.set_ylabel('Score')
ax.set_xlabel('Episode #')
plt.show()
