Advertisement
Guest User

Untitled

a guest
Jan 18th, 2017
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.74 KB | None | 0 0
  1. import numpy as np
  2. import tensorflow as tf
  3.  
  4. def unpool(input_images, argmax, output_shape, name='unpooling'):
  5. os = output_shape.as_list()
  6. output_sz = np.prod(os)
  7.  
  8. b = os[0]
  9. output_hwc = np.prod(os[1:])
  10. input_hwc = np.prod(argmax.get_shape().as_list()[1:])
  11. offset = tf.tile(tf.reshape(tf.range(b, dtype=tf.int64), [b, 1]), [1, input_hwc]) * output_hwc
  12. reshaped_argmax = tf.reshape(argmax, [b, input_hwc])
  13.  
  14. indices = tf.reshape(reshaped_argmax + offset, [-1, 1])
  15. updates = tf.reshape(input_images, [-1])
  16. scatter = tf.scatter_nd(indices=indices,
  17. updates=updates,
  18. shape=tf.constant([output_sz], dtype=tf.int64))
  19. return tf.reshape(scatter, output_shape, name=name)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement