Advertisement
Guest User

Untitled

a guest
Jun 26th, 2019
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.62 KB | None | 0 0
  1. # Use a fixed noise vector to see how the GAN Images transition through time on a fixed noise.
  2. fixed_noise = gen_noise(16,noise_shape)
  3.  
  4. # To keep Track of losses
  5. avg_disc_fake_loss = []
  6. avg_disc_real_loss = []
  7. avg_GAN_loss = []
  8.  
  9. # We will run for num_steps iterations
  10. for step in range(num_steps):
  11. tot_step = step
  12. print("Begin step: ", tot_step)
  13. # to keep track of time per step
  14. step_begin_time = time.time()
  15.  
  16. # sample a batch of normalized images from the dataset
  17. real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir)
  18.  
  19. # Genearate noise to send as input to the generator
  20. noise = gen_noise(batch_size,noise_shape)
  21.  
  22. # Use generator to create(predict) images
  23. fake_data_X = generator.predict(noise)
  24.  
  25. # Save predicted images from the generator every 10th step
  26. if (tot_step % 100) == 0:
  27. step_num = str(tot_step).zfill(4)
  28. save_img_batch(fake_data_X,img_save_dir+step_num+"_image.png")
  29.  
  30. # Create the labels for real and fake data. We don't give exact ones and zeros but add a small amount of noise. This is an important GAN training trick
  31. real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
  32. fake_data_Y = np.random.random_sample(batch_size)*0.2
  33.  
  34. # train the discriminator using data and labels
  35.  
  36. discriminator.trainable = True
  37. generator.trainable = False
  38.  
  39. # Training Discriminator seperately on real data
  40. dis_metrics_real = discriminator.train_on_batch(real_data_X,real_data_Y)
  41. # training Discriminator seperately on fake data
  42. dis_metrics_fake = discriminator.train_on_batch(fake_data_X,fake_data_Y)
  43.  
  44. print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))
  45.  
  46. # Save the losses to plot later
  47. avg_disc_fake_loss.append(dis_metrics_fake[0])
  48. avg_disc_real_loss.append(dis_metrics_real[0])
  49.  
  50. # Train the generator using a random vector of noise and its labels (1's with noise)
  51. generator.trainable = True
  52. discriminator.trainable = False
  53.  
  54. GAN_X = gen_noise(batch_size,noise_shape)
  55. GAN_Y = real_data_Y
  56.  
  57. gan_metrics = gan.train_on_batch(GAN_X,GAN_Y)
  58. print("GAN loss: %f" % (gan_metrics[0]))
  59.  
  60. # Log results by opening a file in append mode
  61. text_file = open(log_dir+"\\training_log.txt", "a")
  62. text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f\n" % (tot_step, dis_metrics_real[0], dis_metrics_fake[0],gan_metrics[0]))
  63. text_file.close()
  64.  
  65. # save GAN loss to plot later
  66. avg_GAN_loss.append(gan_metrics[0])
  67.  
  68. end_time = time.time()
  69. diff_time = int(end_time - step_begin_time)
  70. print("Step %d completed. Time took: %s secs." % (tot_step, diff_time))
  71.  
  72. # save model at every 500 steps
  73. if ((tot_step+1) % 500) == 0:
  74. print("-----------------------------------------------------------------")
  75. print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss)))
  76. print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss)))
  77. print("Average GAN loss: %f" % (np.mean(avg_GAN_loss)))
  78. print("-----------------------------------------------------------------")
  79. discriminator.trainable = False
  80. generator.trainable = False
  81. # predict on fixed_noise
  82. fixed_noise_generate = generator.predict(noise)
  83. step_num = str(tot_step).zfill(4)
  84. save_img_batch(fixed_noise_generate,img_save_dir+step_num+"fixed_image.png")
  85. generator.save(save_model_dir+str(tot_step)+"_GENERATOR_weights_and_arch.hdf5")
  86. discriminator.save(save_model_dir+str(tot_step)+"_DISCRIMINATOR_weights_and_arch.hdf5")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement