Advertisement
Guest User

Untitled

a guest
Oct 15th, 2019
117
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.04 KB | None | 0 0
  1. epochs = 60
  2.  
  3. def loss(actual, predicted):
  4. crossentropy_loss = tf.losses.categorical_crossentropy(actual, predicted)
  5. average_loss = tf.reduce_mean(crossentropy_loss)
  6. return average_loss
  7.  
  8.  
  9. def train(train_dataset, validation_dataset, epochs, learning_rate=1e-1, momentum=9e-1, decay=1e-6):
  10.  
  11. optimizer = tf.optimizers.SGD(learning_rate=learning_rate=,
  12. momentum=momentum,
  13. decay=decay)
  14.  
  15. for epoch in range(epochs):
  16. epoch_loss = []
  17. train_accuracy = []
  18. validation_accuracy = []
  19. for train_batch, validation_batch in zip(train_dataset, validation_dataset):
  20.  
  21. train_batch_features, train_batch_labels = train_batch
  22. validation_batch_features, validation_batch_labels = validation_batch
  23.  
  24. with tf.GradientTape() as tape:
  25. predictions = model(train_batch_features)
  26. train_loss = loss(train_batch_labels, predictions)
  27. gradients = tape.gradient(train_loss, model.trainable_variables)
  28. optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  29.  
  30. epoch_loss.append(train_loss)
  31.  
  32. accuracy = tf.metrics.Accuracy()
  33. accuracy(tf.argmax(train_batch_labels, 1),
  34. tf.argmax(predictions, 1))
  35. train_accuracy.append(accuracy.result())
  36.  
  37. validation_predictions = model(validation_batch_features)
  38. accuracy = tf.metrics.Accuracy()
  39. accuracy(tf.argmax(validation_batch_labels, 1),
  40. tf.argmax(validation_predictions, 1))
  41. validation_accuracy.append(accuracy.result())
  42. epoch_loss = tf.reduce_mean(epoch_loss)
  43. train_accuracy = tf.reduce_mean(train_accuracy)
  44. validation_accuracy = tf.reduce_mean(validation_accuracy)
  45.  
  46. print('Epoch {} / {} : train loss = {}, train accuracy = {}, validation accuracy = {}'.format(epoch_loss,
  47. train_accuracy,
  48. validation_accuracy))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement