Guest User

Untitled

a guest
Jan 23rd, 2019
122
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.05 KB | None | 0 0
  1. dataset = dataset.repeat()
  2. dataset = dataset.shuffle(buffer_size=100)
  3. dataset = dataset.batch(self.batch_size, drop_remainder=True)
  4. dataset = dataset.prefetch(100)
  5.  
MirroredStrategy: performs in-graph replication with synchronous training
on multiple GPUs within a single machine. A copy of every variable in the
model's layers is created on each device, and all-reduce combines the
gradients across devices before they are applied, which keeps the copies
in sync.
  11.  
CollectiveAllReduceStrategy: the multi-worker counterpart of
MirroredStrategy, i.e. the same synchronous all-reduce training spread
across several machines.
  14.  
  15. def create_dataset():
  16. ...
  17. dataset = dataset.repeat()
  18. dataset = dataset.shuffle(buffer_size=100)
  19. dataset = dataset.batch(self.batch_size, drop_remainder=True)
  20. dataset = dataset.prefetch(100)
  21. return dataset
  22.  
  23.  
  24.  
  25. NUM_GPUS = 4
  26. strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)
  27.  
  28. optimizer = tf.train.RMSPropOptimizer(learning_rate=0.01, use_locking=True)
  29. optimizer_d = tf.train.RMSPropOptimizer(learning_rate=0.01, use_locking=True)
  30.  
  31. config = tf.estimator.RunConfig(save_checkpoints_steps=100,
  32. save_summary_steps=1, keep_checkpoint_max=50,
  33. train_distribute=strategy)
  34.  
  35. # I have more hooks here, just simplified to show
  36. def get_hooks_fn(GANTrainOps):
  37.  
  38. disjoint_train_hook_func = tfgan.get_sequential_train_hooks(
  39. train_steps=tfgan.GANTrainSteps(10, 1)
  40. ) # g steps, d steps
  41. disjoint_train_hooks = disjoint_train_hook_func(GANTrainOps)
  42. return [update_hook, summary_hook] + disjoint_train_hooks
  43.  
  44.  
  45. # Create GAN estimator.
  46. gan_estimator = tfgan.estimator.GANEstimator(
  47. model_dir = '/data/checkpoints/estimator_model',
  48. generator_fn = generator_fn,
  49. discriminator_fn = discriminator_fn,
  50. generator_loss_fn = generator_loss_fn,
  51. discriminator_loss_fn = discriminator_loss_fn,
  52. generator_optimizer = optimizer,
  53. discriminator_optimizer = optimizer_d,
  54. use_loss_summaries=True,
  55. config=config,
  56. get_hooks_fn=get_hooks_fn)
  57.  
  58.  
  59. gan_estimator.train(input_fn=create_dataset steps=10000)
Add Comment
Please, Sign In to add comment