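This paste is the per-epoch training step of a simple (vanilla) policy gradient agent written against the TensorFlow 1.x session API and the classic Gym interface; it closely resembles OpenAI Spinning Up's "simplest policy gradient" example. The function references names defined elsewhere in the full script (env, sess, obs_ph, act_ph, weights_ph, actions, loss, train_op, batch_size, np). A minimal sketch of that assumed setup follows; the environment choice, network size, and learning rate are illustrative, not taken from the paste:

import numpy as np
import gym
import tensorflow as tf   # TensorFlow 1.x API

# environment and sizes (CartPole is an assumption; any discrete-action env works)
env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n
batch_size = 5000   # minimum timesteps collected per epoch before the update

# placeholders fed in train_one_epoch()
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
weights_ph = tf.placeholder(shape=(None,), dtype=tf.float32)  # R(tau) per step

# small MLP policy producing action logits
hidden = tf.layers.dense(obs_ph, units=32, activation=tf.tanh)
logits = tf.layers.dense(hidden, units=n_acts, activation=None)

# op that samples one action from the current policy
actions = tf.squeeze(tf.multinomial(logits=logits, num_samples=1), axis=1)

# surrogate loss whose gradient is the policy gradient:
# the negative batch mean of R(tau) * log pi(a|s)
action_masks = tf.one_hot(act_ph, n_acts)
log_probs = tf.reduce_sum(action_masks * tf.nn.log_softmax(logits), axis=1)
loss = -tf.reduce_mean(weights_ph * log_probs)
train_op = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(loss)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())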
# for training policy
def train_one_epoch():
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # render first episode of each epoch
    finished_rendering_this_epoch = False

    # collect experience by acting in the environment with current policy
    while True:

        # rendering
        if not finished_rendering_this_epoch:
            env.render()

        # save obs
        batch_obs.append(obs.copy())

        # act in the environment
        act = sess.run(actions, {obs_ph: obs.reshape(1, -1)})[0]
        obs, rew, done, _ = env.step(act)

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len

            # reset episode-specific variables
            obs, done, ep_rews = env.reset(), False, []

            # won't render again this epoch
            finished_rendering_this_epoch = True

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break

    # take a single policy gradient update step
    batch_loss, _ = sess.run([loss, train_op],
                             feed_dict={
                                 obs_ph: np.array(batch_obs),
                                 act_ph: np.array(batch_acts),
                                 weights_ph: np.array(batch_weights)
                             })
    return batch_loss, batch_rets, batch_lens
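
A typical driver just calls the function once per epoch and logs the batch statistics it returns (a hypothetical loop, not part of the paste):

# hypothetical training loop around train_one_epoch()
for i in range(50):
    batch_loss, batch_rets, batch_lens = train_one_epoch()
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f' %
          (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))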