Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
# NOTE(review): this fragment duplicates the pipeline inside create_dataset
# below. `dataset` and `self.batch_size` are undefined at this point —
# presumably copied out of a method for illustration; not runnable as-is.
- dataset = dataset.repeat()
- dataset = dataset.shuffle(buffer_size=100)
- dataset = dataset.batch(self.batch_size, drop_remainder=True)
- dataset = dataset.prefetch(100)
- MirroredStrategy: This does in-graph replication with synchronous
- training on many GPUs on one machine. Essentially, we create copies of all
- variables in the model's layers on each device. We then use all-reduce
- to combine gradients across the devices before applying them
- to the variables to keep them in sync.
- CollectiveAllReduceStrategy: This is a version of MirroredStrategy
- for multi-worker training.
def create_dataset(batch_size=32):
    """Build the training input pipeline for the estimator.

    Args:
        batch_size: Number of examples per batch. The original snippet read
            ``self.batch_size``, which is undefined inside a free function
            (NameError at call time); it is now an explicit parameter with a
            default so existing no-argument callers such as
            ``input_fn=create_dataset`` keep working.

    Returns:
        A ``tf.data.Dataset`` that repeats indefinitely, shuffles with a
        100-element buffer, emits fixed-size batches, and prefetches.
    """
    ...  # dataset = tf.data.TFRecordDataset(...) — elided in the original paste
    dataset = dataset.repeat()  # loop forever; the Estimator stops via `steps`
    dataset = dataset.shuffle(buffer_size=100)
    # drop_remainder=True keeps the batch dimension static, which the
    # distribution strategy needs to split batches evenly across replicas.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(100)  # overlap input prep with training
    return dataset
# Number of local GPU replicas to mirror the model across.
NUM_GPUS = 4

# In-graph replication with synchronous all-reduce across the local GPUs.
strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)

# Separate optimizer instances for generator and discriminator so each
# network keeps its own slot variables; use_locking serializes updates.
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.01, use_locking=True)
optimizer_d = tf.train.RMSPropOptimizer(learning_rate=0.01, use_locking=True)

# Checkpoint every 100 steps, summarize every step, keep the last 50
# checkpoints, and train under the mirrored strategy.
config = tf.estimator.RunConfig(
    save_checkpoints_steps=100,
    save_summary_steps=1,
    keep_checkpoint_max=50,
    train_distribute=strategy,
)
# I have more hooks here, just simplified to show
def get_hooks_fn(train_ops):
    """Return the list of training hooks for the GANEstimator.

    Args:
        train_ops: The ``tfgan.GANTrainOps`` the estimator passes in.
            (Renamed from ``GANTrainOps``, which shadowed the tfgan class
            name; the estimator invokes this callback positionally, so the
            rename is safe for callers.)

    Returns:
        A list of ``SessionRunHook``s: the custom hooks plus the sequential
        generator/discriminator hooks (10 generator steps per 1
        discriminator step).
    """
    disjoint_train_hook_func = tfgan.get_sequential_train_hooks(
        train_steps=tfgan.GANTrainSteps(10, 1)  # (g steps, d steps)
    )
    disjoint_train_hooks = disjoint_train_hook_func(train_ops)
    # NOTE(review): `update_hook` and `summary_hook` must be defined in the
    # code elided above ("more hooks here") — they are undefined as pasted.
    return [update_hook, summary_hook] + disjoint_train_hooks
# Create the GAN estimator and launch distributed training.
gan_estimator = tfgan.estimator.GANEstimator(
    model_dir='/data/checkpoints/estimator_model',
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=generator_loss_fn,
    discriminator_loss_fn=discriminator_loss_fn,
    generator_optimizer=optimizer,
    discriminator_optimizer=optimizer_d,
    use_loss_summaries=True,
    config=config,  # carries the MirroredStrategy via train_distribute
    get_hooks_fn=get_hooks_fn)
# Bug fix: the original read `input_fn=create_dataset steps=10000`, which is
# a SyntaxError — the comma between the two keyword arguments was missing.
gan_estimator.train(input_fn=create_dataset, steps=10000)
Add Comment
Please sign in to add a comment.