import numpy as np
import torch
from torch import nn

# Assumed device definition; the class below creates action tensors on `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class A2C:
    def __init__(self, policy, optimizer, value_loss_coef=0.25,
                 entropy_coef=0.01, max_grad_norm=0.5):

        self.policy = policy
        self.optimizer = optimizer
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

        self.reward_history = []
        self.log_counter = 0

        self.logger = {
            'entropy': [],
            'value_loss': [],
            'average_reward': [],
            'policy_loss': [],
            'value_targets': [],
            'value_predictions': [],
            'gradient_norm': [],
            'advantages': [],
            'A2C_loss': []
        }

    def _log(self, rewards=False, **kwargs):
        # Append each metric to its running history; every 100 reward-logging
        # calls, record the mean of the collected rewards and reset the buffer.
        for key, value in kwargs.items():
            self.logger[key].append(value)

        if rewards:
            self.log_counter += 1
            if self.log_counter == 100:
                self.logger['average_reward'].append(np.mean(self.reward_history))
                self.reward_history = []
                self.log_counter = 0

    def policy_loss(self, trajectory):
        # Policy-gradient objective: log-probability of the taken action,
        # weighted by the detached advantage (value target minus value estimate).
        advantages = trajectory['value_targets'] - trajectory['values'].squeeze()
        actions = torch.tensor(trajectory['actions'], device=device).unsqueeze(-1)
        log_probs = torch.gather(trajectory['log_probs'], dim=1, index=actions)

        loss = torch.mean(advantages.detach() * log_probs.squeeze())

        self._log(**{
            'advantages': advantages.cpu().data.numpy().mean(),
            'policy_loss': loss.cpu().item()
        })

        return loss

    def value_loss(self, trajectory):
        # Critic loss: MSE between value predictions and detached value targets.
        value_targets = trajectory['value_targets']
        values = trajectory['values'].squeeze()
        loss = torch.nn.MSELoss()(values, value_targets.detach())

        self._log(**{
            'value_targets': value_targets.cpu().data.numpy().mean(),
            'value_predictions': values.cpu().data.numpy().mean(),
            'value_loss': loss.cpu().item()
        })

        return loss

    def loss(self, trajectory):
        # Combined A2C loss: negated policy objective plus weighted value loss,
        # minus an entropy bonus that encourages exploration.
        entropy = -torch.sum(trajectory['log_probs'] * trajectory['probs'], dim=-1)
        entropy = torch.mean(entropy)

        a2c_loss = -self.policy_loss(trajectory) + \
            self.value_loss_coef * self.value_loss(trajectory) - \
            self.entropy_coef * entropy

        self._log(**{
            'entropy': entropy.cpu().item(),
            'A2C_loss': a2c_loss.cpu().item()
        })

        return a2c_loss

    def step(self, trajectory):
        # One optimization step: backpropagate, clip gradients, then update.
        loss = self.loss(trajectory)

        loss.backward()

        # Clip after backward() so the freshly computed gradients are clipped.
        grad_norm = nn.utils.clip_grad_norm_(self.policy.parameters(),
                                             self.max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()

        self.reward_history.extend(trajectory['rewards'])
        self._log(rewards=True, **{'gradient_norm': float(grad_norm)})
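

# --- Hypothetical usage sketch (not part of the original paste) ---
# A minimal example of how the A2C class above could be driven. It assumes a
# tiny discrete-action actor-critic network and one-step TD value targets;
# PolicyNet, obs_dim, n_actions and the dummy rollout data are illustrative
# stand-ins, not names from the original code. Reuses `device` defined above.

class PolicyNet(nn.Module):
    # Shared body with an actor head (action distribution) and a critic head
    # (state-value estimate), producing the tensors A2C expects in a trajectory.
    def __init__(self, obs_dim, n_actions, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.actor = nn.Linear(hidden, n_actions)
        self.critic = nn.Linear(hidden, 1)

    def forward(self, obs):
        h = self.body(obs)
        logits = self.actor(h)
        return (torch.softmax(logits, dim=-1),      # probs
                torch.log_softmax(logits, dim=-1),  # log_probs
                self.critic(h))                     # values


obs_dim, n_actions, batch_size = 4, 2, 8
policy = PolicyNet(obs_dim, n_actions).to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
agent = A2C(policy, optimizer)

# Dummy rollout data standing in for real environment interaction.
obs = torch.randn(batch_size, obs_dim, device=device)
next_obs = torch.randn(batch_size, obs_dim, device=device)
rewards = np.random.rand(batch_size).astype(np.float32)
actions = np.random.randint(n_actions, size=batch_size)

probs, log_probs, values = policy(obs)
with torch.no_grad():
    _, _, next_values = policy(next_obs)

# One-step TD targets: r + gamma * V(s').
gamma = 0.99
value_targets = torch.tensor(rewards, device=device) + gamma * next_values.squeeze()

trajectory = {
    'probs': probs,
    'log_probs': log_probs,
    'values': values,
    'actions': actions,
    'rewards': rewards.tolist(),
    'value_targets': value_targets,
}

agent.step(trajectory)
print(agent.logger['A2C_loss'][-1], agent.logger['gradient_norm'][-1])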