Advertisement
Guest User

Untitled

a guest
Mar 19th, 2019
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.44 KB | None | 0 0
  1. def combined_loss(gan_model, **kwargs):
  2. """Wrapper function for combine adversarial loss, use as generator loss"""
  3. # Define non-adversarial loss - for example L1
  4. non_adversarial_loss = tf.losses.absolute_difference(
  5. gan_model.real_data, gan_model.generated_data)
  6.  
  7. # Define generator loss
  8. generator_loss = tf.contrib.gan.losses.least_squares_generator_loss(
  9. gan_model,
  10. **kwargs)
  11.  
  12. # The structure of kwargs changes between versions, better to add exception
  13. try:
  14. add_summaries = kwargs['add_summaries']
  15. except:
  16. add_summaries = True
  17.  
  18. # Combine these losses - you can specify more parameters
  19. # Exactly one of weight_factor and gradient_ratio must be non-None
  20. combined_loss = tf.contrib.gan.losses.wargs.combine_adversarial_loss(
  21. non_adversarial_loss,
  22. generator_loss,
  23. weight_factor=1.0,
  24. gradient_ratio=None,
  25. variables=gan_model.generator_variables,
  26. scalar_summaries=add_summaries,
  27. gradient_summaries=add_summaries)
  28. return combined_loss
  29.  
  30.  
  31. gan_estimator = tf.contrib.gan.estimator.GANEstimator(
  32. model_dir,
  33. generator_fn=generator_fn,
  34. discriminator_fn=discriminator_fn,
  35. generator_loss_fn=combined_loss,
  36. discriminator_loss_fn=tf.contrib.gan.losses.least_squares_discriminator_loss,
  37. generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
  38. discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement