Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
import numpy as np
import tensorflow as tf

# Random input batch laid out as (batch, seq_length, dim) = (10, 10, 64),
# matching the axis order gather_indexes reads. np.random.rand yields float64;
# it is fed into the float32 placeholder below.
input_data = np.random.rand(10, 10, 64)

# One position (index into axis 1) per batch row; length must equal the batch size (10).
positions = [0, 9, 8, 2, 3, 5, 1, 2, 4, 6]

# Graph inputs; the batch dimension is left dynamic.
input_tensor = tf.placeholder(dtype=tf.float32, shape=(None, 10, 64))
positions_tensor = tf.placeholder(dtype=tf.int32, shape=(None,))
def gather_indexes(input_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch.

    The input is flattened to ``[batch_size * seq_length, dim]`` and, for each
    batch row ``b``, the flat row ``b * seq_length + position`` is selected.

    Args:
      input_tensor: rank-3 tensor read as ``[batch_size, seq_length, dim]``
        (axes 0, 1 and 2 via ``tf.shape`` below).
      positions: int32 tensor (or int list) of indices into axis 1. Either
        ``[batch_size]`` (one position per sequence) or
        ``[batch_size, num_positions]`` (several positions per sequence).

    Returns:
      A ``[batch_size * num_positions, dim]`` tensor, with rows ordered
      batch-major. ``num_positions`` is 1 for 1-D ``positions``, so the result
      is ``[batch_size, dim]`` in that case.
    """
    positions = tf.convert_to_tensor(positions)
    shape = tf.shape(input_tensor)
    batch_size = shape[0]
    seq_length = shape[1]
    dim = shape[2]

    # Flat index of each sequence's first row: shape [batch_size, 1].
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])

    if positions.shape.ndims == 2:
        # [batch_size, num_positions] already lines up with the
        # [batch_size, 1] offsets; expand_dims here would broadcast wrongly.
        per_row_positions = positions
    else:
        # Original 1-D path (also used when the static rank is unknown):
        # [batch_size] -> [batch_size, 1].
        per_row_positions = tf.expand_dims(positions, axis=-1)

    flat_positions = tf.reshape(per_row_positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(input_tensor,
                                      [batch_size * seq_length, dim])
    # Slice out one row of the flattened tensor per flat position.
    return tf.gather(flat_sequence_tensor, flat_positions)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    result = sess.run(
        gather_indexes(input_tensor, positions_tensor),
        feed_dict={input_tensor: input_data, positions_tensor: positions},
    )

# Check the result: output row i should equal input_data[i][positions[i]].
print(result[0])
print(input_data[0][positions[0]])
# Verify every row, not just the first. The output is float32 while
# input_data is float64, so compare with a tolerance rather than by eye.
expected = input_data[np.arange(len(positions)), positions]
np.testing.assert_allclose(result, expected, rtol=1e-6)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement