Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def combined_loss(gan_model, **kwargs):
    """Generator loss combining an adversarial and a non-adversarial term.

    Intended for use as ``generator_loss_fn`` in
    ``tf.contrib.gan.estimator.GANEstimator``.

    Args:
        gan_model: A TF-GAN ``GANModel`` tuple; this function reads its
            ``real_data``, ``generated_data`` and ``generator_variables``.
        **kwargs: Forwarded to the least-squares generator loss. May contain
            ``add_summaries`` (bool), which also toggles the summaries of the
            combined loss; defaults to True when absent.

    Returns:
        A scalar Tensor holding the combined generator loss.
    """
    # Non-adversarial reconstruction term — L1 (mean absolute error).
    non_adversarial_loss = tf.losses.absolute_difference(
        gan_model.real_data, gan_model.generated_data)
    # Adversarial term: least-squares GAN generator loss.
    generator_loss = tf.contrib.gan.losses.least_squares_generator_loss(
        gan_model,
        **kwargs)
    # The kwargs structure changes between TF versions, so read the flag
    # defensively. dict.get replaces the original bare `except:`, which
    # would have silently swallowed *any* error, not just a missing key.
    add_summaries = kwargs.get('add_summaries', True)
    # Combine the two terms. Exactly one of weight_factor / gradient_ratio
    # must be non-None; here the non-adversarial term is weighted by 1.0.
    return tf.contrib.gan.losses.wargs.combine_adversarial_loss(
        non_adversarial_loss,
        generator_loss,
        weight_factor=1.0,
        gradient_ratio=None,
        variables=gan_model.generator_variables,
        scalar_summaries=add_summaries,
        gradient_summaries=add_summaries)
# Build a TF-GAN estimator wired with the combined (L1 + least-squares
# adversarial) generator loss defined above, a least-squares discriminator
# loss, and Adam (learning_rate=1e-4, beta1=0.5) for both networks.
# NOTE(review): `model_dir`, `generator_fn` and `discriminator_fn` are not
# visible in this chunk — presumably defined earlier in the file; confirm.
gan_estimator = tf.contrib.gan.estimator.GANEstimator(
    model_dir,
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=combined_loss,
    discriminator_loss_fn=tf.contrib.gan.losses.least_squares_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement