Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def train_epoch(flag=False, initial=True):
    """Alternate discriminator and generator updates for one epoch.

    Each round runs a burst of discriminator steps followed by
    ``2 * diter`` generator steps.  Every step pulls a fresh batch from
    ``gan.generate_batch()`` and advances the ``run`` counter by
    ``batch_size``; rounds repeat while ``run <= num_examples``.

    Relies on names defined elsewhere in the module: ``gan``, ``session``,
    ``placeholders``, ``optimizers``, ``losses``, ``num_examples``,
    ``batch_size``, ``ep`` and the ``time`` module.

    Args:
        flag: When true, run 100 discriminator steps per round instead of
            5, and print the losses after every round rather than every
            tenth round.
        initial: When true, generator steps use ``optimizers["generator"]``;
            otherwise they use ``optimizers["generator_gan"]``.
    """
    # NOTE(review): the original paste had lost its indentation; the nesting
    # below (two inner loops per round, timer reset inside the logging
    # branch) is reconstructed from the statement order -- confirm.
    diter = 5
    large_iter = 100
    final_iter = large_iter if flag else diter
    generator_key = "generator" if initial else "generator_gan"

    count = 0
    run = 0
    start_time = time.time()
    # Slot 0: latest discriminator loss, slot 1: latest generator loss.
    loss_val = [0, 0]

    while run <= num_examples:
        # Discriminator phase.
        for _ in range(final_iter):
            feed_list = gan.generate_batch()
            run += batch_size
            feed_dict = {
                placeholders['image_input']: feed_list[0],
                placeholders['x']: feed_list[1],
                placeholders['image_class_input']: feed_list[2],
            }
            _, loss_val[0] = session.run(
                [optimizers["discriminator"],
                 losses["disc_image_discriminator"]],
                feed_dict=feed_dict,
            )

        # Generator phase (does not feed placeholders['x']).
        for _ in range(2 * diter):
            feed_list = gan.generate_batch()
            run += batch_size
            feed_dict = {
                placeholders['image_input']: feed_list[0],
                placeholders['image_class_input']: feed_list[2],
            }
            # The original wrote to loss_val[6], which is out of range for
            # the two-element list and raised IndexError on the first step.
            _, loss_val[1] = session.run(
                [optimizers[generator_key], losses["generator_image"]],
                feed_dict=feed_dict,
            )

        count += 1
        if count % 10 == 0 or flag:
            elapsed = time.time() - start_time
            print("%d:%d : " % (ep + 1, run)
                  + " : ".join(map(str, loss_val))
                  + " " + str(elapsed))
            # Restart the timer so each report covers only the rounds
            # since the previous report.
            start_time = time.time()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement