def training_step(self, batch, batch_idx):
    # batch is [audio_ref, audio_corrupted]
    cut_batch = [self.generator.cut_tensor(speech) for speech in batch]
    reference_speech, corrupted_speech = cut_batch
    enhanced_speech, decomposed_enhanced_speech = self.generator(corrupted_speech)
    decomposed_reference_speech = self.generator.pqmf.forward(
        reference_speech, "analysis"
    )
    outs = {
        "reference": reference_speech,
        "corrupted": corrupted_speech,
        "enhanced": enhanced_speech,
    }

    opt_g, opt_d = self.optimizers()

    # ---- train discriminator ----
    self.toggle_optimizer(opt_d)
    # Detach the generator outputs so this step updates only the discriminator.
    enhanced_embeddings = self.discriminator(
        bands=decomposed_enhanced_speech[:, 1:, :].detach(),
        audio=enhanced_speech.detach(),
    )
    # The reference audio carries no generator gradients, so no detach is needed.
    reference_embeddings = self.discriminator(
        bands=decomposed_reference_speech[:, 1:, :], audio=reference_speech
    )
    # Hinge loss on real samples: push D(real) above +1.
    adv_loss_valid = 0
    for scale in range(len(reference_embeddings)):  # across scales
        certainties = reference_embeddings[scale][-1]
        adv_loss_valid += self.relu(1 - certainties).mean()  # across time
    adv_loss_valid /= len(reference_embeddings)
    # Hinge loss on fake samples: push D(fake) below -1.
    adv_loss_fake = 0
    for scale in range(len(enhanced_embeddings)):  # across scales
        certainties = enhanced_embeddings[scale][-1]
        adv_loss_fake += self.relu(1 + certainties).mean()  # across time
    adv_loss_fake /= len(enhanced_embeddings)
    dis_loss = adv_loss_valid + adv_loss_fake
    opt_d.zero_grad()
    self.manual_backward(dis_loss)
    opt_d.step()
    self.untoggle_optimizer(opt_d)

    # ---- train generator ----
    self.toggle_optimizer(opt_g)
    # Recompute the embeddings with the just-updated discriminator. The fake
    # path must NOT be detached here so gradients can flow back through the
    # discriminator into the generator; the reference path is a fixed target,
    # so skip building its graph.
    enhanced_embeddings = self.discriminator(
        bands=decomposed_enhanced_speech[:, 1:, :], audio=enhanced_speech
    )
    with torch.no_grad():
        reference_embeddings = self.discriminator(
            bands=decomposed_reference_speech[:, 1:, :], audio=reference_speech
        )
    # Feature-matching loss: L1 distance between intermediate discriminator
    # activations of reference and enhanced audio, averaged over layers.
    ftr_loss = 0
    for scale in range(len(reference_embeddings)):  # across scales
        for layer in range(
            1, len(reference_embeddings[scale]) - 1
        ):  # across layers
            a = reference_embeddings[scale][layer]
            b = enhanced_embeddings[scale][layer]
            ftr_loss += self.l1(a, b) / (len(reference_embeddings[scale]) - 2)
    ftr_loss /= len(reference_embeddings)
    # Adversarial loss for the generator: push D(fake) above +1.
    adv_loss = 0
    for scale in range(len(enhanced_embeddings)):  # across scales
        certainties = enhanced_embeddings[scale][-1]
        adv_loss += self.relu(1 - certainties).mean()  # across time
    adv_loss /= len(enhanced_embeddings)
    gen_loss = adv_loss + 100 * ftr_loss
    opt_g.zero_grad()
    self.manual_backward(gen_loss)
    opt_g.step()
    self.untoggle_optimizer(opt_g)

    # Log both losses under separate keys instead of overwriting one "loss".
    self.log("dis_loss", dis_loss, prog_bar=True)
    self.log("gen_loss", gen_loss, prog_bar=True)
    outs.update({"dis_loss": dis_loss, "gen_loss": gen_loss})
    return outs
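
# The step above only works inside a LightningModule running in manual
# optimization mode, with self.relu, self.l1, and the two optimizers wired
# up elsewhere in the class. Below is a minimal sketch of that surrounding
# setup; the class name EnhancerGAN, the Adam optimizers, and the learning
# rates are placeholder assumptions, not taken from the original code.
import torch
import torch.nn as nn
import pytorch_lightning as pl


class EnhancerGAN(pl.LightningModule):  # placeholder name (assumption)
    def __init__(self, generator, discriminator):
        super().__init__()
        # Manual optimization is required for self.optimizers(),
        # self.manual_backward(), and toggle/untoggle_optimizer()
        # in training_step.
        self.automatic_optimization = False
        self.generator = generator
        self.discriminator = discriminator
        # Loss helpers referenced in training_step.
        self.relu = nn.ReLU()
        self.l1 = nn.L1Loss()

    def configure_optimizers(self):
        # Order matters: training_step unpacks (opt_g, opt_d).
        # Adam and the learning rates are assumptions, not from the paste.
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return opt_g, opt_d

# Design note: detaching in the discriminator step and recomputing the fake
# embeddings without detach in the generator step is what lets each optimizer
# update only its own network. The detached copies block gradients into the
# generator, while the fresh forward pass gives gen_loss a path back through
# the (frozen) discriminator into the generator.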