Advertisement
cjxd

TensorFlow patch extraction and image reconstruction functions

May 19th, 2017
844
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.50 KB | None | 0 0
  1. def generate_patches(image, patch_h, patch_w):
  2.     '''Splits an image into patches of size patch_h x patch_w
  3.    Input: image of shape [image_h, image_w, image_ch]
  4.    Output: batch of patches shape [n, patch_h, patch_w, image_ch]
  5.    '''
  6.     assert image.shape.ndims == 3
  7.  
  8.     pad = [[0, 0], [0, 0]]
  9.     image_h = image.shape[0].value
  10.     image_w = image.shape[1].value
  11.     image_ch = image.shape[2].value
  12.     p_area = patch_h * patch_w
  13.  
  14.     patches = tf.space_to_batch_nd([image], [patch_h, patch_w], pad)
  15.     patches = tf.split(patches, p_area, 0)
  16.     patches = tf.stack(patches, 3)
  17.     patches = tf.reshape(patches, [-1, patch_h, patch_w, image_ch])
  18.  
  19.     return patches
  20.  
  21. def reconstruct_image(patches, image_h, image_w):
  22.     '''Reconstructs an image from patches of size patch_h x patch_w
  23.    Input: batch of patches shape [n, patch_h, patch_w, patch_ch]
  24.    Output: image of shape [image_h, image_w, patch_ch]
  25.    '''
  26.     assert patches.shape.ndims == 4
  27.  
  28.     pad = [[0, 0], [0, 0]]
  29.     patch_h = patches.shape[1].value
  30.     patch_w = patches.shape[2].value
  31.     patch_ch = patches.shape[3].value
  32.     p_area = patch_h * patch_w
  33.     h_ratio = image_h // patch_h
  34.     w_ratio = image_w // patch_w
  35.  
  36.     image = tf.reshape(patches, [1, h_ratio, w_ratio, p_area, patch_ch])
  37.     image = tf.split(image, p_area, 3)
  38.     image = tf.stack(image, 0)
  39.     image = tf.reshape(image, [p_area, h_ratio, w_ratio, patch_ch])
  40.     image = tf.batch_to_space_nd(image, [patch_h, patch_w], pad)
  41.  
  42.     return image[0]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement