Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import tensorflow as tf
def unpool(input_images, argmax, output_shape, name='unpooling',
           include_batch_in_index=False):
    """Scatter pooled values back to the positions recorded in ``argmax``.

    Each value of ``input_images`` is written into a zero-filled tensor of
    ``output_shape`` at the flat position given by the corresponding element
    of ``argmax``; every other position of the result stays zero.

    Args:
        input_images: Pooled tensor; must contain the same number of elements
            as ``argmax``.
        argmax: Integer tensor of flat indices into the unpooled output. Its
            non-batch dimensions must be statically known. Presumably produced
            by ``tf.nn.max_pool_with_argmax`` -- confirm against callers.
        output_shape: Shape of the result, as a ``tf.TensorShape`` or a
            list/tuple of ints. All dims after the first must be known; the
            first (batch) dim may be ``None``.
        name: Name given to the returned tensor.
        include_batch_in_index: Set True when ``argmax`` already indexes the
            whole flattened output (batch included). When False (default,
            the original behaviour) row ``i`` of ``argmax`` is shifted by
            ``i * prod(output_shape[1:])``.

    Returns:
        Tensor with the dtype of ``input_images`` and shape ``output_shape``.

    Raises:
        ValueError: if a non-batch dimension of ``output_shape`` or of
            ``argmax`` is not statically known.
    """
    out_dims = tf.TensorShape(output_shape).as_list()
    arg_dims = argmax.get_shape().as_list()
    if (any(d is None for d in out_dims[1:])
            or any(d is None for d in arg_dims[1:])):
        raise ValueError(
            'unpool requires static non-batch dims; got output_shape=%s, '
            'argmax shape=%s' % (out_dims, arg_dims))
    output_hwc = int(np.prod(out_dims[1:]))
    input_hwc = int(np.prod(arg_dims[1:]))

    # Cast so int32 argmax also works; scatter_nd needs indices and shape
    # to share one integer dtype.
    flat_argmax = tf.reshape(tf.cast(argmax, tf.int64), [-1, input_hwc])
    # Batch size is read at run time so an unknown (None) batch dim works.
    batch = tf.shape(flat_argmax, out_type=tf.int64)[0]

    if include_batch_in_index:
        flat_index = flat_argmax
    else:
        # Indices are relative to a single example: shift row i by
        # i * output_hwc so they address the whole flattened output.
        offset = tf.expand_dims(
            tf.range(batch, dtype=tf.int64) * output_hwc, 1)
        flat_index = flat_argmax + offset

    # tf.scatter_nd sums updates that land on the same index (possible with
    # overlapping pooling windows); untouched positions are zero.
    scatter = tf.scatter_nd(indices=tf.reshape(flat_index, [-1, 1]),
                            updates=tf.reshape(input_images, [-1]),
                            shape=tf.stack([batch * output_hwc]))
    target_shape = [-1 if out_dims[0] is None else out_dims[0]] + out_dims[1:]
    return tf.reshape(scatter, target_shape, name=name)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement