daily pastebin goal
53%
SHARE
TWEET

Untitled

a guest Feb 23rd, 2019 66 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # Set up training process.
  2. from collections import deque
  3.  
  4. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  5.  
  6. agent_a2c = A2CModel().to(device)
  7. optimizer = optim.Adam(agent_a2c.parameters(), lr=0.00015)
  8.  
  9. env_info = env.reset(train_mode=True)[brain_name]
  10. states = env_info.vector_observations
  11. init_states = states
  12.  
  13. n_episodes = 1
  14. n_steps = 10
  15. episode_end = False
  16. a2c_ep_rewards_list = []
  17. ep_rewards_deque = deque([0], maxlen=100) # initialize with 0
  18. ep_rewards = 0
  19.  
  20. while True:
  21.     batch_s, batch_a, batch_v_t, accu_rewards, init_states, episode_end = collect_trajectories(
  22.         agent_a2c, env, brain_name, init_states, episode_end, n_steps)
  23.  
  24.     loss, mus, stds = learn(batch_s, batch_a, batch_v_t, agent_a2c, optimizer)
  25.     ep_rewards += accu_rewards
  26.     print('\rEpisode {:>4}\tEpisodic Score {:>7.3f}\tLoss {:>12.6f}'.format(
  27.         n_episodes, np.mean(ep_rewards_deque), float(loss)), end="")
  28.  
  29.     if episode_end == True:
  30.         if n_episodes % 100 == 0:
  31.             print('\rEpisode {:>4}\tEpisodic Score {:>7.3f}\tLoss {:>12.6f}'.format(
  32.                 n_episodes, np.mean(ep_rewards_deque), float(loss)))
  33.  
  34.         if np.mean(ep_rewards_deque) >= 34:
  35.             break
  36.         a2c_ep_rewards_list.append(ep_rewards/num_agents)
  37.         ep_rewards_deque.append(ep_rewards/num_agents)
  38.         ep_rewards = 0
  39.         n_episodes += 1
  40.         episode_end = False
  41.  
  42.  
  43. # save a2c model
  44. pth = './checkpoint/a2c_checkpoint.pth'
  45. torch.save(agent_a2c.state_dict(), pth)
  46.  
  47. a2c_ep_rewards_list = np.array(a2c_ep_rewards_list)
  48. np.save('./data/a2c_ep_rewards_list.npy', a2c_ep_rewards_list)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top