Advertisement
cjxd

image_test.py

May 18th, 2017
860
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python | 3.93 KB | None | 0 0
  1. import tensorflow as tf
  2. import sys, os, os.path
  3.  
  4. # Data parameters
  5. shuffle = False
  6. input_dtype=tf.uint8
  7. dtype=tf.float32
  8.  
  9. input_h = input_w = 1024
  10. input_ch = 4
  11. patch_h = patch_w = 32
  12. image_patch_ratio = patch_h * patch_w / (input_h * input_w)
  13. input_noise = 0
  14.  
  15. usage = "Usage: image_test.py <path to input list> <path to outputs>\n"
  16. usage+= "where <path to input list> is a list of images to process separated by newlines\n"
  17. usage+= "and <path to outputs> is the directory to store saved images."
  18.  
  19. def read_files(image_list):
  20.     filename_queue = tf.train.string_input_producer(image_list, shuffle=shuffle)
  21.  
  22.     reader = tf.WholeFileReader()
  23.     _, image_file = reader.read(filename_queue)
  24.     image = tf.image.decode_png(image_file, channels=input_ch, dtype=input_dtype)
  25.     image = tf.image.convert_image_dtype(image, dtype)
  26.     image.set_shape((input_h, input_w, input_ch))
  27.  
  28.     return image
  29.  
  30. def add_noise(image, mean=0.0, stddev=0.5):
  31.     noise = tf.random_normal(shape=image.shape,
  32.               mean=0.0, stddev=stddev,
  33.               dtype=dtype)
  34.  
  35.     return image + noise
  36.  
  37. def generate_patches(image):
  38.     patch_size = [1, patch_h, patch_w, 1]
  39.     patches = tf.extract_image_patches([image],
  40.         patch_size, patch_size, [1, 1, 1, 1], 'VALID')
  41.     patches = tf.reshape(patches, [-1, patch_h, patch_w, input_ch])
  42.  
  43.     return patches
  44.  
  45. def make_image(data):
  46.     converted = tf.image.convert_image_dtype(data, input_dtype)
  47.     encoded = tf.image.encode_png(converted)
  48.  
  49.     return encoded
  50.  
  51. def make_images(data):
  52.     data_queue = tf.train.batch([data],
  53.             batch_size=1,
  54.             enqueue_many=True,
  55.             capacity=10000)
  56.  
  57.     return make_image(data_queue[0])
  58.  
  59. def reconstruct_image(patches):
  60.     image = tf.reshape(patches, [1, input_h, input_w, input_ch])
  61.  
  62.     return make_image(image[0])
  63.  
  64. def main(args):
  65.     if len(args) != 2:
  66.         print(usage)
  67.         sys.exit(1)
  68.  
  69.     input_list = args[0]
  70.     image_dir = args[1]
  71.  
  72.     with open(input_list, 'r') as input_set:
  73.         inputs = input_set.read().splitlines()
  74.  
  75.     n_examples = len(inputs)
  76.     n_patches = n_examples // image_patch_ratio
  77.    
  78.     # Load, patch and reconstruct images
  79.     input_data = read_files(inputs)
  80.     input_img = make_image(input_data)
  81.     input_patches = generate_patches(input_data)
  82.     patch_imgs = make_images(input_patches)
  83.     output_img = reconstruct_image(input_patches)
  84.  
  85.     # Initialize session and graph
  86.     with tf.Session() as sess:
  87.         sess.run(tf.global_variables_initializer())  
  88.  
  89.         # Start input enqueue threads
  90.         coord = tf.train.Coordinator()
  91.         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  92.  
  93.         # Main loop
  94.         try:
  95.             i = 0
  96.             while not coord.should_stop():
  97.                 print("Saving image %d/%d" % (i, n_patches))
  98.                 tag = str(i)
  99.  
  100.                 # Generate files for input, patches, and output
  101.                 input_name = tf.constant(os.path.join(image_dir, tag + '_in.png'))
  102.                 patch_name = tf.constant(os.path.join(image_dir, tag + '_patch.png'))
  103.                 output_name = tf.constant(os.path.join(image_dir, tag + '_out.png'))
  104.  
  105.                 input_fwrite = tf.write_file(input_name, input_img)
  106.                 patch_fwrite = tf.write_file(patch_name, patch_imgs)
  107.                 output_fwrite = tf.write_file(output_name, output_img)
  108.  
  109.                 # Run only patch_fwrite if you want to quickly save lots of patches
  110.                 # The input and output images will be the same every time, so don't
  111.                 # waste your breath.
  112.                 sess.run([input_fwrite, patch_fwrite, output_fwrite])
  113.  
  114.                 i += 1
  115.  
  116.         except tf.errors.OutOfRangeError:
  117.             pass
  118.         finally:
  119.             coord.request_stop()
  120.  
  121.         # Wait for threads to finish.
  122.         coord.join(threads)
  123.  
  124. if __name__ == '__main__':
  125.     argv = sys.argv[1:]
  126.     main(argv)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement