# for training policy
def train_one_epoch():
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # render first episode of each epoch
    finished_rendering_this_epoch = False

    # collect experience by acting in the environment with current policy
    while True:

        # rendering
        if not finished_rendering_this_epoch:
            env.render()

        # save obs
        batch_obs.append(obs.copy())

        # act in the environment
        act = sess.run(actions, {obs_ph: obs.reshape(1, -1)})[0]
        obs, rew, done, _ = env.step(act)

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len

            # reset episode-specific variables
            obs, done, ep_rews = env.reset(), False, []

            # won't render again this epoch
            finished_rendering_this_epoch = True

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break

    # take a single policy gradient update step
    batch_loss, _ = sess.run([loss, train_op],
                             feed_dict={
                                 obs_ph: np.array(batch_obs),
                                 act_ph: np.array(batch_acts),
                                 weights_ph: np.array(batch_weights)
                             })
    return batch_loss, batch_rets, batch_lens
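
# The function above is only a fragment: it relies on env, sess, obs_ph,
# act_ph, weights_ph, actions, loss, train_op, batch_size, and np being
# defined in the surrounding script. Below is a minimal sketch of that
# surrounding setup, assuming a TensorFlow 1.x graph, a discrete-action
# Gym environment, and the -mean(logprob(a|s) * R(tau)) surrogate loss
# that the comments describe. Every name, layer size, and hyperparameter
# here is an assumption inferred from how the function uses it, not part
# of the original paste.

import numpy as np
import tensorflow as tf   # TensorFlow 1.x API
import gym

env = gym.make('CartPole-v0')                 # assumed: any discrete-action env
obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n

# policy network: observation -> action logits
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
hidden = tf.layers.dense(obs_ph, units=32, activation=tf.tanh)
logits = tf.layers.dense(hidden, units=n_acts)

# op the function samples from: one action per observation
actions = tf.squeeze(tf.multinomial(logits=logits, num_samples=1), axis=1)

# surrogate loss whose gradient is the policy gradient:
# -mean over the batch of logprob(a|s) * R(tau)
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
weights_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
action_masks = tf.one_hot(act_ph, n_acts)
log_probs = tf.reduce_sum(action_masks * tf.nn.log_softmax(logits), axis=1)
loss = -tf.reduce_mean(weights_ph * log_probs)

train_op = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(loss)

batch_size = 5000                             # assumed batch size
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# training loop: one policy gradient update per epoch
for epoch in range(50):
    batch_loss, batch_rets, batch_lens = train_one_epoch()
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f' %
          (epoch, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))

# With this wiring, the gradient of `loss` with respect to the network
# weights is a sample estimate of the policy gradient (each timestep's
# log-probability weighted by its episode's total return), so the single
# train_op step at the end of train_one_epoch is one full update.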