Guest User

Untitled

a guest
Jun 23rd, 2018
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.33 KB | None | 0 0
  1. def post_body_init(self):
  2. '''Initializes the part of algorithm needing a body to exist first.'''
  3. self.body = self.agent.nanflat_body_a[0] # single-body algo
  4. # create the extra replay memory for SIL
  5. memory_name = self.memory_spec['sil_replay_name']
  6. MemoryClass = getattr(memory, memory_name)
  7. self.body.replay_memory = MemoryClass(self.memory_spec, self, self.body)
  8. self.init_algorithm_params()
  9. self.init_nets()
  10. logger.info(util.self_desc(self))
  11.  
  12. # ...
  13. def sample(self):
  14. '''Modify the onpolicy sample to also append to replay'''
  15. batches = [body.memory.sample() for body in self.agent.nanflat_body_a]
  16. batch = util.concat_batches(batches)
  17. data_keys = self.body.replay_memory.data_keys
  18. for idx in range(len(batch['dones'])):
  19. tuples = [batch[k][idx] for k in data_keys]
  20. self.body.replay_memory.add_experience(*tuples)
  21. batch = util.to_torch_batch(batch, self.net.gpu)
  22. return batch
  23.  
  24. def replay_sample(self):
  25. '''Samples a batch from memory'''
  26. batches = [body.replay_memory.sample() for body in self.agent.nanflat_body_a]
  27. batch = util.concat_batches(batches)
  28. batch = util.to_torch_batch(batch, self.net.gpu)
  29. assert not torch.isnan(batch['states']).any()
  30. return batch
Add Comment
Please, Sign In to add comment