Guest User

Untitled

a guest
Apr 23rd, 2018
59
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.70 KB | None | 0 0
  1. def train_nn(sess, epochs, batch_size, get_batches_fn, train_op,
  2. cross_entropy_loss, input_image,
  3. correct_label, keep_prob, learning_rate):
  4.  
  5. keep_prob_value = 0.5
  6. learning_rate_value = 0.001
  7. for epoch in range(epochs):
  8. # Create function to get batches
  9. total_loss = 0
  10. for X_batch, gt_batch in get_batches_fn(batch_size):
  11.  
  12. loss, _ = sess.run([cross_entropy_loss, train_op],
  13. feed_dict={input_image: X_batch, correct_label: gt_batch,
  14. keep_prob: keep_prob_value, learning_rate:learning_rate_value})
  15.  
  16. total_loss += loss;
  17.  
  18. print("EPOCH {} ...".format(epoch + 1))
  19. print("Loss = {:.3f}".format(total_loss))
  20. print()
Add Comment
Please, Sign In to add comment