Advertisement
Guest User

Untitled

a guest
Jun 17th, 2019
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.40 KB | None | 0 0
  1. # Train and evaluate CNN with Early Stopping procedure defined at the very top
  2. with tf.Session() as sess:
  3. init.run()
  4. for epoch in range(n_epochs):
  5. for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
  6. iteration += 1
  7. sess.run(training_op, feed_dict={X: X_batch, y: y_batch, training: True})
  8. if iteration % check_interval == 0:
  9. loss_val = loss.eval(feed_dict={X: X_valid, y: y_valid})
  10. if loss_val < best_loss_val:
  11. best_loss_val = loss_val
  12. checks_since_last_progress = 0
  13. best_model_params = get_model_params()
  14. else:
  15. checks_since_last_progress += 1
  16. acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
  17. acc_val = accuracy.eval(feed_dict={X: X_valid, y: y_valid})
  18. print("Epoch {}, last batch accuracy: {:.4f}%, valid. accuracy: {:.4f}%, valid. best loss: {:.6f}".format(
  19. epoch, acc_batch * 100, acc_val * 100, best_loss_val))
  20. if checks_since_last_progress > max_checks_without_progress:
  21. print("Early stopping!")
  22. break
  23.  
  24. if best_model_params:
  25. restore_model_params(best_model_params)
  26. acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
  27. print("Final accuracy on test set:", acc_test)
  28. save_path = saver.save(sess, "./my_mnist_model")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement