GAN PL issue

a guest
Oct 12th, 2023
def training_step(self, batch, batch_idx, optimizer_idx=0):
    # batch is [audio_ref, audio_corrupted]; optimizer_idx is unused under manual optimization
    cut_batch = [self.generator.cut_tensor(speech) for speech in batch]

    reference_speech, corrupted_speech = cut_batch

    # generator forward: enhanced speech plus its PQMF sub-band decomposition
    enhanced_speech, decomposed_enhanced_speech = self.generator(corrupted_speech)
    decomposed_reference_speech = self.generator.pqmf.forward(
        reference_speech, "analysis"
    )

    # discriminator embeddings for the discriminator update; the generator outputs
    # are detached so this pass only produces gradients for the discriminator
    enhanced_embeddings = self.discriminator(
        bands=decomposed_enhanced_speech[:, 1:, :].detach(), audio=enhanced_speech.detach()
    )
    reference_embeddings = self.discriminator(
        bands=decomposed_reference_speech[:, 1:, :].detach(), audio=reference_speech.detach()
    )

    outs = {
        "reference": reference_speech,
        "corrupted": corrupted_speech,
        "enhanced": enhanced_speech,
    }

    opt_g, opt_d = self.optimizers()

    # train discriminator
    self.toggle_optimizer(opt_d)

    # hinge loss on real (reference) speech: push certainties above +1
    adv_loss_valid = 0
    for scale in range(len(reference_embeddings)):  # across scales
        certainties = reference_embeddings[scale][-1]
        adv_loss_valid += self.relu(1 - certainties).mean()  # across time
    adv_loss_valid /= len(reference_embeddings)

    # hinge loss on fake (enhanced) speech: push certainties below -1
    adv_loss_fake = 0
    for scale in range(len(enhanced_embeddings)):  # across scales
        certainties = enhanced_embeddings[scale][-1]
        adv_loss_fake += self.relu(1 + certainties).mean()  # across time
    adv_loss_fake /= len(enhanced_embeddings)

    # discriminator loss to backprop on
    dis_loss = adv_loss_valid + adv_loss_fake
    outs.update({"dis_loss": dis_loss})

    opt_d.zero_grad()
    self.manual_backward(dis_loss)
    opt_d.step()

    self.untoggle_optimizer(opt_d)

    # train generator
    self.toggle_optimizer(opt_g)

    # re-run the discriminator on the *non-detached* generator outputs so the
    # generator losses can backpropagate into the generator
    enhanced_embeddings_gen = self.discriminator(
        bands=decomposed_enhanced_speech[:, 1:, :], audio=enhanced_speech
    )

    # feature-matching loss: L1 between intermediate discriminator activations of
    # reference and enhanced speech, averaged over layers and scales
    ftr_loss = 0
    for scale in range(len(reference_embeddings)):  # across scales
        for layer in range(
            1, len(reference_embeddings[scale]) - 1
        ):  # across layers
            a = reference_embeddings[scale][layer].detach()  # target only, no gradient
            b = enhanced_embeddings_gen[scale][layer]
            ftr_loss += self.l1(a, b) / (len(reference_embeddings[scale]) - 2)
    ftr_loss /= len(reference_embeddings)

    # adversarial loss for the generator: push discriminator certainties on
    # enhanced speech towards "real"
    adv_loss = 0
    for scale in range(len(enhanced_embeddings_gen)):  # across scales
        certainties = enhanced_embeddings_gen[scale][-1]
        adv_loss += self.relu(1 - certainties).mean()  # across time
    adv_loss /= len(enhanced_embeddings_gen)

    gen_loss = adv_loss + 100 * ftr_loss
    outs.update({"gen_loss": gen_loss})

    opt_g.zero_grad()
    self.manual_backward(gen_loss)
    opt_g.step()

    self.untoggle_optimizer(opt_g)
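
For completeness, a minimal sketch of the LightningModule scaffolding this step assumes: manual optimization has to be enabled, and configure_optimizers has to return the generator and discriminator optimizers in the order the step unpacks them. Apart from the Lightning hooks themselves, everything below (class name, optimizer choice, learning rates) is an assumption, not taken from the paste:

import torch
import pytorch_lightning as pl


class SpeechEnhancementGAN(pl.LightningModule):  # hypothetical class name
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.l1 = torch.nn.L1Loss()
        self.relu = torch.nn.ReLU()
        # training_step calls self.optimizers() / manual_backward() / toggle_optimizer(),
        # which require manual optimization
        self.automatic_optimization = False

    def configure_optimizers(self):
        # assumed optimizer and learning rate; the order matters, since
        # training_step unpacks self.optimizers() as opt_g, opt_d
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return [opt_g, opt_d]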