import torch


def train_shared(self):
    '''
    Trains the network when the actor and critic share parameters
    '''
    if self.to_train == 1:
        # on-policy A2C update
        a2c_loss = super(SIL, self).train_shared()
        # off-policy SIL update with a random minibatch from replay memory
        total_sil_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            batch = self.replay_sample()
            sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch)
            sil_loss = self.policy_loss_coef * sil_policy_loss + self.val_loss_coef * sil_val_loss
            self.net.training_step(loss=sil_loss)
            total_sil_loss += sil_loss
        sil_loss = total_sil_loss / self.training_epoch
        loss = a2c_loss + sil_loss
        self.last_loss = loss.item()
    return self.last_loss

def train_separate(self):
    '''
    Trains the network when the actor and critic are separate networks
    '''
    if self.to_train == 1:
        # on-policy A2C update
        a2c_loss = super(SIL, self).train_separate()
        # off-policy SIL update with a random minibatch from replay memory
        total_sil_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            batch = self.replay_sample()
            sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch)
            sil_policy_loss = self.policy_loss_coef * sil_policy_loss
            sil_val_loss = self.val_loss_coef * sil_val_loss
            # retain the graph so the separate critic loss can be backpropagated afterwards
            self.net.training_step(loss=sil_policy_loss, retain_graph=True)
            self.critic.training_step(loss=sil_val_loss)
            total_sil_loss += sil_policy_loss + sil_val_loss
        sil_loss = total_sil_loss / self.training_epoch
        loss = a2c_loss + sil_loss
        self.last_loss = loss.item()
    return self.last_loss
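
# ---------------------------------------------------------------------------
# Note: replay_sample() and calc_sil_policy_val_loss() are defined elsewhere in
# the class and are not part of this paste. The function below is only a
# minimal sketch of what the SIL loss computation could look like, following
# the self-imitation learning objective of Oh et al. (2018):
#     L_policy = -log pi(a|s) * max(R - V(s), 0)
#     L_value  = 0.5 * max(R - V(s), 0)^2
# The function name and tensor names (log_probs, v_preds, returns) are
# assumptions made for illustration, not the original implementation.
def sil_policy_val_loss_sketch(log_probs, v_preds, returns):
    '''
    Hypothetical stand-in for calc_sil_policy_val_loss(batch).
    log_probs: log pi(a|s) for actions stored in replay memory, shape (batch,)
    v_preds:   critic value estimates V(s), shape (batch,)
    returns:   stored Monte Carlo returns R, shape (batch,)
    '''
    # only imitate past actions whose return exceeded the current value estimate
    clipped_adv = torch.clamp(returns - v_preds, min=0.0)
    sil_policy_loss = -(log_probs * clipped_adv.detach()).mean()
    sil_val_loss = 0.5 * clipped_adv.pow(2).mean()
    return sil_policy_loss, sil_val_loss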