Guest User

Untitled

a guest
Jan 23rd, 2019
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.52 KB | None | 0 0
  1. def write_log(callback, names, logs, batch_no):
  2. for name, value in zip(names, logs):
  3. summary = tf.Summary()
  4. summary_value = summary.value.add()
  5. summary_value.simple_value = value
  6. summary_value.tag = name
  7. callback.writer.add_summary(summary, batch_no)
  8. callback.writer.flush()
  9.  
  10. callback = callbacks[0]
  11. callback.set_model(self.keras_model)
  12. train_names = ['train_loss', 'train_mae']
  13. val_names = ['val_loss', 'val_mae']
  14.  
  15. for epoch in range(epochs):
  16. print("epoch: {}".format(epoch))
  17. for step in range(self.config.STEPS_PER_EPOCH):
  18. # print("step: {}".format(step))
  19. if step % int(self.config.STEPS_PER_EPOCH / 100.) == 0:
  20. print('.', end='', flush=True)
  21. train_inputs = next(train_generator)
  22. logs = self.keras_model.train_on_batch(train_inputs[0], train_inputs[1])
  23. write_log(callback, train_names, logs, step + epoch * self.config.STEPS_PER_EPOCH)
  24.  
  25. if step % 50 == 0:
  26. val_inputs = next(val_generator)
  27. logs = self.keras_model.train_on_batch(val_inputs[0], val_inputs[1])
  28. write_log(callback, val_names, logs, step + epoch * self.config.STEPS_PER_EPOCH)
  29. print('') # newline
  30. self.keras_model.save(os.path.join(self.log_dir, 'tiny_{}.h5'.format(epoch)))
  31.  
  32. self.epoch = max(self.epoch, epochs)
Add Comment
Please sign in to add a comment.