Advertisement
Guest User

Untitled

a guest
Jul 27th, 2017
46
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.28 KB | None | 0 0
  1. def train_epoch(flag=False, initial=True):
  2. diter = 5
  3. count = 0
  4. large_iter = 100
  5. if flag :
  6. final_iter = large_iter
  7. else:
  8. final_iter = diter
  9. run=0
  10. start_time = time.time()
  11. loss_val = [0,0]
  12. while run <= num_examples:
  13. for t in range(final_iter):
  14. feed_list = gan.generate_batch()
  15. run += batch_size
  16. feed_dict = {
  17. placeholders['image_input'] : feed_list[0],
  18. placeholders['x'] : feed_list[1],
  19. placeholders['image_class_input'] : feed_list[2],
  20. }
  21. _, loss_val[0] = session.run([optimizers["discriminator"],losses["disc_image_discriminator"]], feed_dict=feed_dict)
  22.  
  23. for _ in range(2*diter):
  24. feed_list = gan.generate_batch()
  25. run += batch_size
  26. feed_dict = {
  27. placeholders['image_input'] : feed_list[0],
  28. placeholders['image_class_input'] : feed_list[2],
  29. }
  30. if initial :
  31. _, loss_val[6] = session.run([optimizers["generator"], losses["generator_image"]], feed_dict=feed_dict)
  32. else:
  33. _, loss_val[6] = session.run([optimizers["generator_gan"], losses["generator_image"]], feed_dict=feed_dict)
  34.  
  35. # z_c = session.run(z_hat_c, feed_dict=feed_dict)
  36. count += 1
  37. if count % 10 == 0 or flag:
  38. print("%d:%d : "%(ep+1,run) + " : ".join(map(lambda x : str(x),loss_val)) + " " + str(time.time() - start_time))
  39. # print(z_c)
  40. start_time = time.time()
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement