Untitled

Posted by a guest, Jan 23rd, 2019

dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(self.batch_size, drop_remainder=True)
dataset = dataset.prefetch(100)
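
To make the order of these four calls concrete, here is a pure-Python sketch of what the repeat, shuffle, batch chain does conceptually. It is not how tf.data is implemented. The function names only mirror the tf.data methods, and prefetch is left out because it overlaps producing batches with consuming them without changing their contents.

```python
import random
from itertools import islice


def repeat(elements):
    """Cycle through a re-iterable collection forever, like repeat() with no count."""
    while True:
        yield from elements


def shuffle(stream, buffer_size, rng):
    """Buffered shuffle: hold `buffer_size` elements and, each time a new element
    arrives, emit a random buffered one and put the newcomer in its slot."""
    buffer = []
    for element in stream:
        if len(buffer) < buffer_size:
            buffer.append(element)
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = element
    # Only reached for finite input; after repeat() the stream never ends.
    rng.shuffle(buffer)
    yield from buffer


def batch(stream, batch_size, drop_remainder):
    """Group consecutive elements; with drop_remainder, discard a short last batch."""
    current = []
    for element in stream:
        current.append(element)
        if len(current) == batch_size:
            yield current
            current = []
    if current and not drop_remainder:
        yield current


# Same order as the snippet: repeat, then shuffle, then batch.
rng = random.Random(0)
pipeline = batch(shuffle(repeat(range(10)), buffer_size=4, rng=rng),
                 batch_size=3, drop_remainder=True)
first_five = list(islice(pipeline, 5))

# drop_remainder on a finite input of 10 elements:
print(list(batch(iter(range(10)), 3, drop_remainder=True)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

Two consequences of the buffered shuffle: an element can only be emitted after it has entered the buffer, so with buffer_size=100 the first output is always one of the first 100 inputs; and because shuffle comes after repeat, elements from neighbouring epochs can end up in the same batch. Shuffling before repeating keeps epoch boundaries intact.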

MirroredStrategy: This does in-graph replication with synchronous training on many GPUs on one machine. Essentially, we create copies of all variables in the model's layers on each device. We then use all-reduce to combine gradients across the devices before applying them to the variables to keep them in sync.
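
The synchronous update described above can be sketched in plain Python. Every replica holds its own copy of the variables, computes gradients on its own shard of the batch, an all-reduce averages those gradients, and every replica applies the same averaged update, so the copies never diverge. The linear model, the shard layout and averaging (rather than summing) are assumptions made for illustration; this does not reproduce TensorFlow's device placement or its all-reduce algorithms.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of mean squared error for the 1-D linear model y = w[0] * x + w[1]."""
    n = len(xs)
    residuals = [w[0] * x + w[1] - y for x, y in zip(xs, ys)]
    return [sum(2 * r * x for r, x in zip(residuals, xs)) / n,
            sum(2 * r for r in residuals) / n]


def all_reduce_mean(per_replica_grads):
    """Average gradients elementwise across replicas; every replica gets the same result."""
    n = len(per_replica_grads)
    mean = [sum(parts) / n for parts in zip(*per_replica_grads)]
    return [list(mean) for _ in range(n)]


def mirrored_step(replica_vars, shards, lr):
    """One synchronous step: local gradients, all-reduce, identical update everywhere."""
    local = [grad_mse(w, xs, ys) for w, (xs, ys) in zip(replica_vars, shards)]
    reduced = all_reduce_mean(local)
    return [[v - lr * g for v, g in zip(w, grads)]
            for w, grads in zip(replica_vars, reduced)]


NUM_REPLICAS = 4
xs = [i / 10 for i in range(16)]
ys = [3 * x + 1 for x in xs]
# Split the global batch into equal shards, one per replica.
shards = [(xs[r::NUM_REPLICAS], ys[r::NUM_REPLICAS]) for r in range(NUM_REPLICAS)]

# With equal-size shards, the averaged gradient equals the full-batch gradient.
full_batch = grad_mse([0.0, 0.0], xs, ys)
averaged = all_reduce_mean([grad_mse([0.0, 0.0], sx, sy) for sx, sy in shards])[0]
assert all(math.isclose(a, b) for a, b in zip(full_batch, averaged))

# Every replica starts from the same copy of the variables.
replicas = [[0.0, 0.0] for _ in range(NUM_REPLICAS)]
for _ in range(500):
    replicas = mirrored_step(replicas, shards, lr=0.1)
```

Because each replica applies the same reduced gradient to the same starting values, the copies stay bit-for-bit identical without any extra synchronization step.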

CollectiveAllReduceStrategy: This is a version of MirroredStrategy for multi-worker training.
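
For multi-worker training each machine needs to know the cluster layout and its own role. A minimal sketch, assuming the usual TF 1.x convention of describing the cluster in the TF_CONFIG environment variable; the hostnames and the make_tf_config helper are hypothetical, and whether a separate "chief" task is also required depends on the TensorFlow version and training mode.

```python
import json
import os


def make_tf_config(worker_hosts, index):
    """Build the TF_CONFIG JSON for one worker in a workers-only cluster (hypothetical helper)."""
    return json.dumps({
        "cluster": {"worker": list(worker_hosts)},
        "task": {"type": "worker", "index": index},
    })


# Hypothetical addresses; replace with the real host:port of each machine.
WORKERS = ["host1.example.com:2222", "host2.example.com:2222"]

# On the machine acting as worker 0 (the second machine would use index=1).
os.environ["TF_CONFIG"] = make_tf_config(WORKERS, index=0)
```

Every worker gets the same cluster list and differs only in its task index, which is what lets each process find its peers for the all-reduce.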
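
# Hypothetical value: the original called self.batch_size, but create_dataset
# is a plain function, so there is no `self` to read it from.
BATCH_SIZE = 32

def create_dataset():
    ...  # building the initial `dataset` is elided in the original paste
    dataset = dataset.repeat()
    dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(100)
    return dataset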

import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in TF 2.0
tfgan = tf.contrib.gan

NUM_GPUS = 4
strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)

optimizer = tf.train.RMSPropOptimizer(learning_rate=0.01, use_locking=True)    # generator
optimizer_d = tf.train.RMSPropOptimizer(learning_rate=0.01, use_locking=True)  # discriminator
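
Both optimizers use RMSProp with learning_rate=0.01 and TF's other defaults. A scalar sketch of the non-centered update rule as documented for tf.train.RMSPropOptimizer, with decay=0.9, momentum=0.0 and epsilon=1e-10 assumed as the defaults; starting the mean-square accumulator at 1.0 is an assumption of this sketch, not something the paste states.

```python
import math


def rmsprop_step(var, grad, ms, mom, lr=0.01, decay=0.9, momentum=0.0, epsilon=1e-10):
    """One non-centered RMSProp update for a scalar parameter:
      ms  <- decay * ms + (1 - decay) * grad**2
      mom <- momentum * mom + lr * grad / sqrt(ms + epsilon)
      var <- var - mom
    """
    ms = decay * ms + (1.0 - decay) * grad ** 2
    mom = momentum * mom + lr * grad / math.sqrt(ms + epsilon)
    return var - mom, ms, mom


# Minimize f(x) = x**2 (gradient 2x) starting from x = 5.
x, ms, mom = 5.0, 1.0, 0.0  # ms starts at 1.0 here (assumption)
max_step = 0.0
for _ in range(100):
    new_x, ms, mom = rmsprop_step(x, 2 * x, ms, mom)
    max_step = max(max_step, abs(new_x - x))
    x = new_x

# With momentum=0, one step never exceeds lr / sqrt(1 - decay) (about 0.0316 here),
# however large the gradient, because ms >= (1 - decay) * grad**2.
bound = 0.01 / math.sqrt(1 - 0.9)
```

The bound is why RMSProp steps stay on the order of the learning rate even when generator or discriminator gradients spike.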

config = tf.estimator.RunConfig(
    save_checkpoints_steps=100,
    save_summary_steps=1,
    keep_checkpoint_max=50,
    train_distribute=strategy)
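
With save_checkpoints_steps=100 and keep_checkpoint_max=50, the checkpoints left on disk after a 10000-step run form a sliding window over the last 5000 steps. A rough sketch of that arithmetic, assuming the global step advances once per step counted by train(steps=...) and ignoring any extra checkpoint written at the very start or end of training; retained_checkpoint_steps is a hypothetical helper.

```python
def retained_checkpoint_steps(total_steps, save_every, keep_max):
    """Global steps whose checkpoints are still on disk at the end of training,
    assuming one save every `save_every` steps and that only the newest
    `keep_max` checkpoints are kept."""
    saved = list(range(save_every, total_steps + 1, save_every))
    return saved[-keep_max:]


kept = retained_checkpoint_steps(total_steps=10000, save_every=100, keep_max=50)
print(len(kept), kept[0], kept[-1])  # 50 5100 10000
```

By contrast, save_summary_steps=1 writes summaries at every step, so summary output grows linearly with the 10000 steps.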

# I have more hooks here; this version is simplified to show the idea.
# update_hook and summary_hook are defined elsewhere (not shown in this paste).
def get_hooks_fn(train_ops):
    # train_ops is the tfgan GANTrainOps tuple that GANEstimator passes in.
    disjoint_train_hook_func = tfgan.get_sequential_train_hooks(
        train_steps=tfgan.GANTrainSteps(10, 1))  # 10 generator steps, 1 discriminator step
    disjoint_train_hooks = disjoint_train_hook_func(train_ops)
    return [update_hook, summary_hook] + disjoint_train_hooks
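
GANTrainSteps(10, 1) asks for 10 generator steps and 1 discriminator step per GAN step. My reading of tfgan's sequential hooks is that, on every training step, the generator hook runs its op that many times and then the discriminator hook runs its op; the pure-Python sketch below only lists that order. It does not call tfgan, and sequential_schedule and the "G"/"D" labels are hypothetical.

```python
def sequential_schedule(num_gan_steps, generator_steps, discriminator_steps):
    """Order in which the two train ops run, assuming each GAN step runs the
    generator op `generator_steps` times, then the discriminator op
    `discriminator_steps` times."""
    order = []
    for _ in range(num_gan_steps):
        order.extend(["G"] * generator_steps)
        order.extend(["D"] * discriminator_steps)
    return order


# GANTrainSteps(10, 1) as in get_hooks_fn, over two GAN steps.
schedule = sequential_schedule(num_gan_steps=2, generator_steps=10, discriminator_steps=1)
print("".join(schedule))  # GGGGGGGGGGDGGGGGGGGGGD
```

Under that reading, if each of the 10000 steps passed to train() is one GAN step, the run would make 100,000 generator updates and 10,000 discriminator updates.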

# Create GAN estimator. generator_fn, discriminator_fn, generator_loss_fn and
# discriminator_loss_fn are defined elsewhere (not shown in this paste).
gan_estimator = tfgan.estimator.GANEstimator(
    model_dir='/data/checkpoints/estimator_model',
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=generator_loss_fn,
    discriminator_loss_fn=discriminator_loss_fn,
    generator_optimizer=optimizer,
    discriminator_optimizer=optimizer_d,
    use_loss_summaries=True,
    config=config,
    get_hooks_fn=get_hooks_fn)

gan_estimator.train(input_fn=create_dataset, steps=10000)