Chans

Untitled

May 22nd, 2019
# for training policy
def train_one_epoch():
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # render first episode of each epoch
    finished_rendering_this_epoch = False

    # collect experience by acting in the environment with current policy
    while True:

        # rendering
        if not finished_rendering_this_epoch:
            env.render()

        # save obs
        batch_obs.append(obs.copy())

        # act in the environment
        act = sess.run(actions, {obs_ph: obs.reshape(1, -1)})[0]
        obs, rew, done, _ = env.step(act)

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len

            # reset episode-specific variables
            obs, done, ep_rews = env.reset(), False, []

            # won't render again this epoch
            finished_rendering_this_epoch = True

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break

    # take a single policy gradient update step
    batch_loss, _ = sess.run([loss, train_op],
                             feed_dict={
                                 obs_ph: np.array(batch_obs),
                                 act_ph: np.array(batch_acts),
                                 weights_ph: np.array(batch_weights)
                             })
    return batch_loss, batch_rets, batch_lens
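The function above is only a fragment: it expects env, sess, batch_size, and the graph nodes obs_ph, act_ph, weights_ph, actions, loss, and train_op to already exist. A minimal sketch of the setup that would precede it is below, assuming TensorFlow 1.x, classic gym (4-tuple env.step), CartPole-v0, a one-hidden-layer categorical MLP policy, and illustrative hyperparameters (hidden size, learning rate, batch size, and epoch count are guesses, not taken from the paste).

# minimal setup sketch; everything here except the placeholder/op names
# used by train_one_epoch() is an assumption
import numpy as np
import tensorflow as tf
import gym

env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n
batch_size = 5000                      # assumed value; the paste only uses the name

# placeholders referenced by train_one_epoch()
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
weights_ph = tf.placeholder(shape=(None,), dtype=tf.float32)

# small MLP producing action logits (hidden size is a guess)
hidden = tf.layers.dense(obs_ph, units=32, activation=tf.tanh)
logits = tf.layers.dense(hidden, units=n_acts, activation=None)

# sample an action from the categorical policy
actions = tf.squeeze(tf.multinomial(logits=logits, num_samples=1), axis=1)

# "loss" whose gradient is the policy gradient: -mean( R(tau) * log pi(a|s) )
action_masks = tf.one_hot(act_ph, n_acts)
log_probs = tf.reduce_sum(action_masks * tf.nn.log_softmax(logits), axis=1)
loss = -tf.reduce_mean(weights_ph * log_probs)

train_op = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(loss)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# training loop calling the function above
for i in range(50):
    batch_loss, batch_rets, batch_lens = train_one_epoch()
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f' %
          (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))

Because the weights fed into weights_ph are the whole-episode returns collected in batch_weights, minimizing this loss takes a step along the simplest form of the policy gradient (no discounting, no baseline), which is exactly the R(tau) weighting the comments in the function describe.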