Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def train(num_iteration: int) -> None:
    """Run ``num_iteration`` optimizer steps, resuming from ``total_iterations``.

    Each iteration draws one training batch and one validation batch and
    runs ``optimizer`` on the training batch. Whenever the global iteration
    index is a multiple of the number of batches per epoch, the validation
    cost is evaluated, ``show_progress`` is called, and the session is saved
    under the checkpoint name ``'dogs-cats-model'``.

    Relies on module-level names defined outside this function: ``data``,
    ``batch_size``, ``x``, ``y_true``, ``session``, ``optimizer``, ``cost``,
    ``saver``, ``show_progress`` and the ``total_iterations`` counter.

    Args:
        num_iteration: Number of optimizer steps to run in this call.

    Side effects:
        Increments the global ``total_iterations`` by ``num_iteration`` and
        writes checkpoints via ``saver.save``.
    """
    global total_iterations

    # NOTE(review): the pasted source had lost its indentation; the nesting
    # below (checkpoint save inside the epoch-boundary branch, counter update
    # after the loop) is reconstructed -- confirm against the original file.

    # Loop-invariant, so compute it once. Clamp to 1 so a training set
    # smaller than a single batch does not cause a modulo-by-zero below.
    batches_per_epoch = max(1, int(data.train.num_examples / batch_size))

    for i in range(total_iterations, total_iterations + num_iteration):
        # next_batch yields a 4-tuple; only the first two items are used here.
        x_batch, y_true_batch, _, _ = data.train.next_batch(batch_size)
        x_valid_batch, y_valid_batch, _, _ = data.valid.next_batch(batch_size)

        feed_dict_tr = {x: x_batch, y_true: y_true_batch}
        feed_dict_val = {x: x_valid_batch, y_true: y_valid_batch}

        session.run(optimizer, feed_dict=feed_dict_tr)

        # Report and checkpoint once per epoch's worth of iterations.
        if i % batches_per_epoch == 0:
            val_loss = session.run(cost, feed_dict=feed_dict_val)
            epoch = i // batches_per_epoch
            show_progress(epoch, feed_dict_tr, feed_dict_val, val_loss)
            saver.save(session, 'dogs-cats-model')

    # Persist progress so the next call continues the iteration numbering.
    total_iterations += num_iteration
Add Comment
Please, Sign In to add comment