import torch  # this paste is a class-body fragment; torch is required for torch.tensor below

def train_shared(self):
    '''
    Trains the network when the actor and critic share parameters
    '''
    if self.to_train == 1:
        # on-policy A2C update
        a2c_loss = super(SIL, self).train_shared()
        # off-policy SIL update with random minibatches from replay
        total_sil_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            batch = self.replay_sample()
            sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch)
            sil_loss = self.policy_loss_coef * sil_policy_loss + self.val_loss_coef * sil_val_loss
            self.net.training_step(loss=sil_loss)
            total_sil_loss += sil_loss
        sil_loss = total_sil_loss / self.training_epoch
        loss = a2c_loss + sil_loss
        self.last_loss = loss.item()
    return self.last_loss
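
# train_shared above and train_separate below both rely on
# calc_sil_policy_val_loss, which is not included in this paste. Below is a
# minimal sketch of what it plausibly computes, following the self-imitation
# learning losses of Oh et al. (2018); the batch keys ('states', 'actions',
# 'rets') and the helpers calc_v / calc_log_probs are illustrative
# assumptions, not this codebase's actual API.
def calc_sil_policy_val_loss(self, batch):
    '''
    Hypothetical sketch: imitate stored transitions only where the stored
    return R exceeds the current value estimate V(s).
    policy loss = -log pi(a|s) * max(R - V(s), 0)
    value loss  = 0.5 * max(R - V(s), 0)^2
    '''
    states, actions, rets = batch['states'], batch['actions'], batch['rets']
    v_preds = self.calc_v(states)  # assumed critic helper
    clipped_advs = (rets - v_preds).clamp(min=0.0)  # keep only positive advantages
    log_probs = self.calc_log_probs(states, actions)  # assumed actor helper
    sil_policy_loss = -(log_probs * clipped_advs.detach()).mean()  # no grad into critic
    sil_val_loss = 0.5 * clipped_advs.pow(2).mean()
    return sil_policy_loss, sil_val_loss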
def train_separate(self):
    '''
    Trains the actor and critic when they are separate networks
    '''
    if self.to_train == 1:
        # on-policy A2C update
        a2c_loss = super(SIL, self).train_separate()
        # off-policy SIL update with random minibatches from replay
        total_sil_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            batch = self.replay_sample()
            sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch)
            sil_policy_loss = self.policy_loss_coef * sil_policy_loss
            sil_val_loss = self.val_loss_coef * sil_val_loss
            # retain_graph=True: both losses come from the same forward pass,
            # so the graph must survive the first backward for the critic step
            self.net.training_step(loss=sil_policy_loss, retain_graph=True)
            self.critic.training_step(loss=sil_val_loss)
            total_sil_loss += sil_policy_loss + sil_val_loss
        sil_loss = total_sil_loss / self.training_epoch
        loss = a2c_loss + sil_loss
        self.last_loss = loss.item()
    return self.last_loss
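
# replay_sample is likewise external to this paste. Below is a minimal sketch
# of the off-policy minibatch draw, assuming a list-of-dicts replay buffer
# with tensor fields; self.replay_buffer and self.batch_size are hypothetical
# names, not necessarily this codebase's actual attributes.
import random

def replay_sample(self):
    '''Hypothetical: draw a uniform random minibatch from the SIL replay buffer.'''
    experiences = random.sample(self.replay_buffer, self.batch_size)
    # collate the per-step dicts into one dict of stacked tensors
    return {k: torch.stack([e[k] for e in experiences]) for k in experiences[0]}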