Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
# Train and evaluate the CNN with early stopping.
#
# Relies on names defined earlier in the file (not visible in this snippet):
# the graph tensors/ops (X, y, training, training_op, loss, accuracy, init,
# saver), the data splits (X_train/y_train, X_valid/y_valid, X_test/y_test),
# the shuffle_batch / get_model_params / restore_model_params helpers, and the
# early-stopping state (n_epochs, batch_size, iteration, check_interval,
# best_loss_val, checks_since_last_progress, max_checks_without_progress,
# best_model_params), which must already be bound before this block runs.
with tf.Session() as sess:
    init.run()
    # The validation feed is identical for every evaluation, so build it once.
    # NOTE(review): `training` is not fed for evaluation, so this presumably
    # relies on its default being False (e.g. a placeholder_with_default) —
    # confirm where `training` is defined.
    valid_feed = {X: X_valid, y: y_valid}
    for epoch in range(n_epochs):
        for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
            iteration += 1
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch, training: True})
            # Every `check_interval` iterations, measure the validation loss:
            # snapshot the parameters when it improves, otherwise count one
            # more check without progress.
            if iteration % check_interval == 0:
                loss_val = loss.eval(feed_dict=valid_feed)
                if loss_val < best_loss_val:
                    best_loss_val = loss_val
                    checks_since_last_progress = 0
                    best_model_params = get_model_params()
                else:
                    checks_since_last_progress += 1
        # Per-epoch report: accuracy on the last training batch and on the
        # validation set, plus the best validation loss seen so far.
        acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_val = accuracy.eval(feed_dict=valid_feed)
        print("Epoch {}, last batch accuracy: {:.4f}%, valid. accuracy: {:.4f}%, valid. best loss: {:.6f}".format(
            epoch, acc_batch * 100, acc_val * 100, best_loss_val))
        # The early-stopping condition is only tested at epoch boundaries, so
        # training may run a few extra checks past the threshold.
        if checks_since_last_progress > max_checks_without_progress:
            print("Early stopping!")
            break

    # Roll back to the best snapshot (if one was taken) before the final test
    # evaluation, so the saved checkpoint holds the best-validation weights.
    if best_model_params:
        restore_model_params(best_model_params)
    acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
    print("Final accuracy on test set:", acc_test)
    save_path = saver.save(sess, "./my_mnist_model")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement