Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- -def _transform_imagnet_image(image, target_image_shape, crop_method, seed, random_flip):
- +def _transform_imagnet_image(image, target_image_shape, crop_method, seed):
- """Preprocesses ImageNet images to have a target image shape.
- Args:
- @@ -473,28 +469,10 @@ def _transform_imagnet_image(image, target_image_shape, crop_method, seed, rando
- begin = tf.cast(begin, tf.int32)
- begin = tf.concat([begin, [0]], axis=0) # Add channel dimension.
- image = tf.slice(image, begin, [size, size, 3])
- - elif crop_method == "zoomed":
- - shape = tf.shape(image)
- - h, w = shape[0], shape[1]
- - size = tf.minimum(h, w)
- - begin, size, _ = tf.image.sample_distorted_bounding_box(
- - [size, size, shape[2]],
- - tf.zeros([0, 0, 4], tf.float32),
- - aspect_ratio_range=[1.0, 1.0],
- - area_range=[0.5, 1.0],
- - use_image_if_no_bounding_boxes=True,
- - seed=seed)
- - begin = tf.cast([h - size[0], w - size[1], 0], tf.int32) // 2
- - image = tf.slice(image, begin, size)
- - # Unfortunately, the above operation loses the depth-dimension. So we need
- - # to restore it the manual way.
- - image.set_shape([None, None, target_image_shape[-1]])
- elif crop_method != "none":
- raise ValueError("Unsupported crop method: {}".format(crop_method))
- image = tf.image.resize_images(
- image, [target_image_shape[0], target_image_shape[1]])
- - if random_flip:
- - image = tf.image.random_flip_left_right(image, seed=seed)
- image.set_shape(target_image_shape)
- return image
- diff --git a/DiffAugment-biggan-imagenet/compare_gan/eval_gan_lib.py b/DiffAugment-biggan-imagenet/compare_gan/eval_gan_lib.py
- index cb7af7d..aa84e59 100644
- --- a/DiffAugment-biggan-imagenet/compare_gan/eval_gan_lib.py
- +++ b/DiffAugment-biggan-imagenet/compare_gan/eval_gan_lib.py
- @@ -93,7 +93,7 @@ def _update_bn_accumulators(sess, generated, num_accu_examples):
- def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
- - num_averaging_runs, update_bn_accumulators=True, use_tags=True):
- + num_averaging_runs):
- """Evaluate model at given checkpoint_path.
- Args:
- @@ -129,7 +129,7 @@ def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
- generator = hub.Module(
- module_spec,
- name="gen_module",
- - tags={"gen", "bs{}".format(batch_size)} if use_tags else None)
- + tags={"gen", "bs{}".format(batch_size)})
- logging.info("Generator inputs: %s", generator.get_input_info_dict())
- z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
- z = z_generator(shape=[batch_size, z_dim])
- @@ -153,7 +153,7 @@ def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
- tf.global_variables_initializer().run()
- - if update_bn_accumulators and _update_bn_accumulators(sess, generated, num_accu_examples=204800):
- + if _update_bn_accumulators(sess, generated, num_accu_examples=204800):
- saver = tf.train.Saver()
- save_path = os.path.join(module_spec, "model-with-accu.ckpt")
- checkpoint_path = saver.save(
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement