import random

import torch
import torch.nn as nn
import torch.optim as optim

# encoder, decoder, trainloader, device, padding_value, teacher_forcing_ratio
# and skip_training are assumed to be defined earlier in the paste.

def compute_loss(decoder_outputs, pad_target_seqs):
    # Sum the per-token NLL over non-padding positions and count those tokens.
    batch_size = decoder_outputs.size(1)
    n = 0
    loss = torch.zeros([1]).to(device)
    for i in range(batch_size):
        for j, l in enumerate(decoder_outputs[:, i, :]):
            t = pad_target_seqs[j, i]
            if t != padding_value:
                n += 1
                loss += criterion(l.view(1, -1), t.view(1))
    return loss.squeeze().to(device), n
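# Note: compute_loss above masks padding positions by hand. The training loop
# below instead relies on nn.NLLLoss(ignore_index=padding_value) to skip those
# positions internally. A minimal sketch of the equivalence (toy tensors only,
# not part of the training code; assumes padding_value fits in the toy vocab):
#
#   lp  = torch.log_softmax(torch.randn(4, 5), dim=1)   # (N=4, vocab=5) log-probs
#   tgt = torch.tensor([1, 2, padding_value, 3])
#   nll = nn.NLLLoss(ignore_index=padding_value, reduction='sum')
#   # nll(lp, tgt) sums the NLL of the three non-padding targets only, which is
#   # what the explicit loop in compute_loss computes when criterion uses
#   # reduction='sum'.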
if not skip_training:
    num_epochs = 30
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)
    # criterion = nn.NLLLoss(ignore_index=padding_value, reduction='sum')
    criterion = nn.NLLLoss(ignore_index=padding_value, reduction='mean')
    encoder.train()
    decoder.train()
    n_words = 0
    running_loss = 0.0
    loss = 0
    for epoch in range(num_epochs):
        print(epoch)
        for i, batch in enumerate(trainloader):
            pad_input_seqs, input_seq_lengths, pad_target_seqs = batch
            batch_size = pad_input_seqs.size(1)
            pad_input_seqs, pad_target_seqs = pad_input_seqs.to(device), pad_target_seqs.to(device)

            # Initialize the encoder hidden state and reset gradients
            encoder_hidden = encoder.init_hidden(batch_size)
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Encode the input sequence
            encoder_output, encoder_hidden = encoder(pad_input_seqs, input_seq_lengths, encoder_hidden)

            # Use the final encoder hidden state to initialize the decoder
            decoder_hidden = encoder_hidden

            # Decode, applying teacher forcing on a random fraction of batches
            teacher_forcing = random.random() < teacher_forcing_ratio
            decoder_outputs, _ = decoder(decoder_hidden, pad_target_seqs, teacher_forcing=teacher_forcing)

            # Move the vocabulary dimension to position 1,
            # (seq_len, batch, vocab) -> (seq_len, vocab, batch),
            # so the class dimension sits where NLLLoss expects it
            decoder_outputs = decoder_outputs.permute(0, 2, 1)
            loss = criterion(decoder_outputs, pad_target_seqs)
            n_words += 1  # counts batches; used below to average running_loss

            # Alternative: per-token loss with an explicit padding mask
            # pad_target_seqs = pad_target_seqs.T.unsqueeze(0).T
            # loss, n = compute_loss(decoder_outputs, pad_target_seqs)
            # n_words += n

            loss.backward()

            # Clip gradients to avoid gradient explosion (currently disabled)
            # nn.utils.clip_grad_norm_(encoder.parameters(), 0.30)
            # nn.utils.clip_grad_norm_(decoder.parameters(), 0.30)

            # Update model weights
            encoder_optimizer.step()
            decoder_optimizer.step()

            # Print statistics
            counter = len(trainloader) * epoch + i
            running_loss += loss.item()
            if counter % 1 == 0 and counter > 0:
                print('[%d, %5d] loss: %.4f' % (epoch + 1, i, running_loss / n_words))
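# If gradients explode, the clip_grad_norm_ calls commented out above can be
# re-enabled; clipping has to happen after loss.backward() and before the
# optimizer steps, e.g. (sketch only, the max_norm of 0.30 is just the value
# from the commented lines, not a tuned setting):
#
#   loss.backward()
#   nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=0.30)
#   nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=0.30)
#   encoder_optimizer.step()
#   decoder_optimizer.step()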