DigitMagazine

Complete Training File for Devworx August 2018 Tensorflow

Jul 25th, 2018
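
def train(num_iteration):
    # Run num_iteration optimisation steps, continuing the count from any
    # earlier call. total_iterations, data, batch_size, x, y_true, session,
    # optimizer, cost, saver and show_progress are module-level names
    # defined elsewhere in the training script.
    global total_iterations

    # Batches per pass over the training set. Clamped to at least 1 so the
    # modulo below cannot raise ZeroDivisionError when num_examples < batch_size.
    batches_per_epoch = max(1, data.train.num_examples // batch_size)

    for i in range(total_iterations,
                   total_iterations + num_iteration):

        # Fetch the next training and validation batches.
        x_batch, y_true_batch, _, cls_batch = data.train.next_batch(batch_size)
        x_valid_batch, y_valid_batch, _, valid_cls_batch = data.valid.next_batch(batch_size)

        # Map the graph's placeholders to this iteration's batches.
        feed_dict_tr = {x: x_batch,
                        y_true: y_true_batch}
        feed_dict_val = {x: x_valid_batch,
                         y_true: y_valid_batch}

        # One optimisation step on the training batch.
        session.run(optimizer, feed_dict=feed_dict_tr)

        # Once per epoch: compute validation loss, report progress and
        # write a checkpoint.
        if i % batches_per_epoch == 0:
            val_loss = session.run(cost, feed_dict=feed_dict_val)
            epoch = i // batches_per_epoch

            show_progress(epoch, feed_dict_tr, feed_dict_val, val_loss)
            saver.save(session, 'n-shapes-model')

    # Record how far training got so the next call resumes from here.
    total_iterations += num_iteration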