import random

import torch
import torch.nn as nn
import torch.optim as optim

# Assumes device, padding_value, teacher_forcing_ratio, skip_training,
# encoder, decoder and trainloader are defined earlier.


def compute_loss(decoder_outputs, pad_target_seqs):
    """Token-level loss that skips padded positions (kept for reference;
    the training loop below uses the vectorized criterion call instead)."""
    batch_size = decoder_outputs.size(1)

    n = 0
    loss = torch.zeros([1]).to(device)
    for i in range(batch_size):
        for j, l in enumerate(decoder_outputs[:, i, :]):
            t = pad_target_seqs[j, i]
            if t != padding_value:
                n += 1
                loss += criterion(l.view(1, -1), t.view(1))
    return loss.squeeze().to(device), n


if not skip_training:
    num_epochs = 30
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)
    # mean-reduced NLL loss that ignores padded target positions
    criterion = nn.NLLLoss(ignore_index=padding_value, reduction='mean')
    encoder.train()
    decoder.train()

    # with reduction='mean' the loss is already averaged over tokens,
    # so n_words simply counts batches for the running average below
    n_words = 0
    running_loss = 0.0

    for epoch in range(num_epochs):
        for i, batch in enumerate(trainloader):
            pad_input_seqs, input_seq_lengths, pad_target_seqs = batch
            batch_size = pad_input_seqs.size(1)
            pad_input_seqs, pad_target_seqs = pad_input_seqs.to(device), pad_target_seqs.to(device)

            # initialize hidden state and reset gradients
            encoder_hidden = encoder.init_hidden(batch_size)
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # encode input sequence
            encoder_output, encoder_hidden = encoder(pad_input_seqs, input_seq_lengths, encoder_hidden)

            # decoder starts from the final encoder hidden state
            decoder_hidden = encoder_hidden

            # decide whether to use teacher forcing for this batch
            teacher_forcing = random.random() < teacher_forcing_ratio

            decoder_outputs, _ = decoder(decoder_hidden, pad_target_seqs, teacher_forcing=teacher_forcing)

            # decoder_outputs: (seq_len, batch, n_classes) -> (seq_len, n_classes, batch)
            # so NLLLoss's (N, C, d) input matches targets of shape (seq_len, batch)
            decoder_outputs = decoder_outputs.permute(0, 2, 1)
            loss = criterion(decoder_outputs, pad_target_seqs)
            n_words += 1

            loss.backward()

            # optionally clip gradients to avoid gradient explosion
            # nn.utils.clip_grad_norm_(encoder.parameters(), 0.30)
            # nn.utils.clip_grad_norm_(decoder.parameters(), 0.30)

            # update model weights
            encoder_optimizer.step()
            decoder_optimizer.step()

            # print statistics
            counter = len(trainloader) * epoch + i
            running_loss += loss.item()
            if counter > 0:
                print('[%d, %5d] loss: %.4f' % (epoch + 1, i, running_loss / n_words))
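# A minimal, self-contained shape check for the loss call above. Illustrative only:
# the 10-class vocabulary, padding_value of 0 and the tensor names are made up here,
# not taken from the actual encoder/decoder.

import torch
import torch.nn as nn

seq_len, batch, n_classes, padding_value = 5, 3, 10, 0
log_probs = torch.log_softmax(torch.randn(seq_len, batch, n_classes), dim=-1)
targets = torch.randint(1, n_classes, (seq_len, batch))
targets[-1, :] = padding_value  # pretend the last time step is padding

criterion = nn.NLLLoss(ignore_index=padding_value, reduction='mean')
# NLLLoss accepts (N, C, d) input with (N, d) targets, so permuting
# (seq_len, batch, classes) -> (seq_len, classes, batch) lines up with
# targets of shape (seq_len, batch); padded positions are excluded from the mean.
loss = criterion(log_probs.permute(0, 2, 1), targets)
print(loss.item())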