SHARE
TWEET

Untitled

a guest Jul 24th, 2019 73 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def cem(n_iterations=500, max_t=1000, gamma=1.0, print_every=10, pop_size=50, elite_frac=0.2, sigma=0.5):
  2.     """PyTorch implementation of the cross-entropy method.
  3.        
  4.     Params
  5.     ======
  6.         n_iterations (int): maximum number of training iterations
  7.         max_t (int): maximum number of timesteps per episode
  8.         gamma (float): discount rate
  9.         print_every (int): how often to print average score (over last 100 episodes)
  10.         pop_size (int): size of population at each iteration
  11.         elite_frac (float): percentage of top performers to use in update
  12.         sigma (float): standard deviation of additive noise
  13.     """
  14.     n_elite=int(pop_size*elite_frac)
  15.  
  16.     scores_deque = deque(maxlen=100)
  17.     scores = []
  18.     best_weight = sigma*np.random.randn(agent.get_weights_dim())
  19.  
  20.     for i_iteration in range(1, n_iterations+1):
  21.         weights_pop = [best_weight + (sigma*np.random.randn(agent.get_weights_dim())) for i in range(pop_size)]
  22.         rewards = np.array([agent.evaluate(weights, gamma, max_t) for weights in weights_pop])
  23.  
  24.         elite_idxs = rewards.argsort()[-n_elite:]
  25.         elite_weights = [weights_pop[i] for i in elite_idxs]
  26.         best_weight = np.array(elite_weights).mean(axis=0)
  27.         reward = agent.evaluate(best_weight, gamma=1.0)
  28.         scores_deque.append(reward)
  29.         scores.append(reward)
  30.        
  31.         torch.save(agent.state_dict(), 'checkpoint.pth')
  32.        
  33.         if i_iteration % print_every == 0:
  34.             print('Episode {}\tAverage Score: {:.2f}'.format(i_iteration, np.mean(scores_deque)))
  35.  
  36.         if np.mean(scores_deque)>=90.0:
  37.             print('\nEnvironment solved in {:d} iterations!\tAverage Score: {:.2f}'.format(i_iteration-100, np.mean(scores_deque)))
  38.             break
  39.     return scores
  40.  
  41. scores = cem()
  42.  
  43. # plot the scores
  44. fig = plt.figure()
  45. ax = fig.add_subplot(111)
  46. plt.plot(np.arange(1, len(scores)+1), scores)
  47. plt.ylabel('Score')
  48. plt.xlabel('Episode #')
  49. plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top