Not a member of Pastebin yet? Sign up — it unlocks many cool features!
def write_log(callback, names, logs, batch_no):
    """Write scalar metrics to the TensorBoard writer attached to *callback*.

    Args:
        callback: a Keras TensorBoard callback (only its ``writer`` is used).
        names: iterable of tag strings, e.g. ``['train_loss', 'train_mae']``.
        logs: iterable of scalar values, zipped pairwise with ``names``.
        batch_no: global step recorded with the summary.

    NOTE(review): uses the TF 1.x ``tf.Summary`` protobuf API — presumably the
    file pins TF1 / ``tf.compat.v1``; confirm before porting to TF2.
    """
    # Fix: the original built, wrote, and flushed a *separate* Summary per
    # metric. Batch all (tag, value) pairs into one Summary and flush once.
    summary = tf.Summary()
    for name, value in zip(names, logs):
        entry = summary.value.add()
        entry.simple_value = value
        entry.tag = name
    callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()
# Manual training loop with per-batch TensorBoard logging via write_log.
# (`callbacks`, `epochs`, `train_generator`, `val_generator` come from the
# enclosing method's scope — not visible in this chunk.)
callback = callbacks[0]
callback.set_model(self.keras_model)
train_names = ['train_loss', 'train_mae']
val_names = ['val_loss', 'val_mae']

# Fix: the original computed `int(STEPS_PER_EPOCH / 100.)` inline as a
# modulus, raising ZeroDivisionError whenever STEPS_PER_EPOCH < 100.
# Clamp to at least 1 and hoist the invariant out of the loop.
progress_every = max(1, int(self.config.STEPS_PER_EPOCH / 100.))

for epoch in range(epochs):
    print("epoch: {}".format(epoch))
    for step in range(self.config.STEPS_PER_EPOCH):
        if step % progress_every == 0:
            print('.', end='', flush=True)  # lightweight progress indicator
        global_step = step + epoch * self.config.STEPS_PER_EPOCH
        train_inputs = next(train_generator)
        logs = self.keras_model.train_on_batch(train_inputs[0], train_inputs[1])
        write_log(callback, train_names, logs, global_step)
        if step % 50 == 0:
            val_inputs = next(val_generator)
            # Fix: the original called train_on_batch here, which *updates the
            # weights* on validation data. test_on_batch evaluates the same
            # loss/metrics without training, which is what 'val_*' tags imply.
            logs = self.keras_model.test_on_batch(val_inputs[0], val_inputs[1])
            write_log(callback, val_names, logs, global_step)
    print('')  # newline after the progress dots
    # Checkpoint the full model once per epoch.
    self.keras_model.save(os.path.join(self.log_dir, 'tiny_{}.h5'.format(epoch)))
self.epoch = max(self.epoch, epochs)
Add a comment
Please sign in to add a comment.