import time

import matplotlib.pyplot as plt
import tensorflow as tf

# encoder, decoder, optimizer, dataset, BATCH_SIZE, N_BATCH and s2i_train are
# assumed to be defined earlier (outside this paste); only the training step
# and the training loop are shown here.

@tf.function
def train_step(inp, targ, enc_hidden, s2i=s2i_train):
    """One training step: encode the batch, decode with teacher forcing,
    and apply gradients to both the encoder and the decoder."""
    loss = 0

    with tf.GradientTape() as tape:
        enc_output, enc_hidden = encoder(inp, enc_hidden)

        dec_hidden = enc_hidden

        # Start every decoded sequence with the BOS token.
        dec_input = tf.expand_dims([s2i['BOS']] * BATCH_SIZE, 1)

        # Teacher forcing - feeding the target as the next input
        for t in range(1, targ.shape[1]):
            # passing enc_output to the decoder
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

            loss += loss_function(targ[:, t], predictions)

            # using teacher forcing
            dec_input = tf.expand_dims(targ[:, t], 1)

    batch_loss = (loss / int(targ.shape[1]))

    variables = encoder.trainable_variables + decoder.trainable_variables

    gradients = tape.gradient(loss, variables)

    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss

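# NOTE: the definitions below are not part of the original paste. train_step
# above calls loss_function(), and the training loop below uses checkpoint and
# checkpoint_prefix, none of which are defined in the paste. Minimal sketches
# are given here, assuming the standard TensorFlow seq2seq-with-attention
# recipe: a masked sparse categorical cross-entropy loss (token id 0 treated
# as padding, decoder outputs taken as logits) and a tf.train.Checkpoint saved
# under './training_checkpoints'. Names such as loss_object and checkpoint_dir
# are illustrative, not from the original.
import os

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # Zero out the loss at padding positions so they do not affect training.
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)

# Checkpoint tracking the optimizer and both models, saved by the loop below.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
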
EPOCHS = 20
epoch_loss = []

for epoch in range(EPOCHS):
    start = time.time()

    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(N_BATCH)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

        if batch % 50 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))

    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    epoch_loss.append(total_loss / N_BATCH)
    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                        total_loss / N_BATCH))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

plt.plot(range(EPOCHS), epoch_loss)
plt.title('intent slot filling - training loss - seq2seq with attention')
plt.xlabel('epoch')
plt.ylabel('loss');