Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def post_body_init(self):
    '''
    Finish the setup steps that can only run once a body exists.

    Binds the first body of the agent as `self.body` (single-body algo),
    attaches an extra replay memory for SIL to that body, then initializes
    the algorithm parameters and the networks and logs a self description.
    '''
    # single-body algo: only the first body is used
    self.body = self.agent.nanflat_body_a[0]
    body = self.body

    # the replay memory class is looked up by the name given in the memory spec
    replay_cls = getattr(memory, self.memory_spec['sil_replay_name'])
    body.replay_memory = replay_cls(self.memory_spec, self, body)

    self.init_algorithm_params()
    self.init_nets()
    logger.info(util.self_desc(self))
- # ...
def sample(self):
    '''
    Sample the onpolicy batch and also append each of its entries to replay.

    Returns the concatenated batch converted with `util.to_torch_batch`.
    '''
    onpolicy_batches = []
    for body in self.agent.nanflat_body_a:
        onpolicy_batches.append(body.memory.sample())
    batch = util.concat_batches(onpolicy_batches)

    # copy every sampled entry into the replay memory, one index at a time,
    # passing the fields in the order the replay memory lists in data_keys
    replay = self.body.replay_memory
    keys = replay.data_keys
    num_entries = len(batch['dones'])
    for i in range(num_entries):
        experience = [batch[key][i] for key in keys]
        replay.add_experience(*experience)

    return util.to_torch_batch(batch, self.net.gpu)
def replay_sample(self):
    '''
    Sample a batch from the replay memory of every body.

    The per-body batches are concatenated and converted with
    `util.to_torch_batch`; the result is returned after a sanity check
    that its 'states' entry contains no NaN.
    '''
    replay_batches = []
    for body in self.agent.nanflat_body_a:
        replay_batches.append(body.replay_memory.sample())
    merged = util.concat_batches(replay_batches)
    torch_batch = util.to_torch_batch(merged, self.net.gpu)
    # sanity check: the sampled states must not contain NaN
    assert not torch.isnan(torch_batch['states']).any()
    return torch_batch
Add Comment
Please, Sign In to add comment