Untitled

a guest, Jun 25th, 2019

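import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session)

# Random batch: 10 sequences, each 10 steps long, with 64 features per step.
input_data = np.random.rand(10, 10, 64)
# One position per sequence: the time step to pick from each batch element.
positions = [0, 9, 8, 2, 3, 5, 1, 2, 4, 6]

input_tensor = tf.placeholder(shape=(None, 10, 64), dtype=tf.float32)
positions_tensor = tf.placeholder(shape=(None,), dtype=tf.int32)


def gather_indexes(input_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch.

  input_tensor: float tensor of shape [batch_size, seq_length, dim].
  positions: int32 tensor of shape [batch_size], one index per sequence.
  Returns a tensor of shape [batch_size, dim] whose row b is
  input_tensor[b, positions[b], :].
  """
  batch_size = tf.shape(input_tensor)[0]
  seq_length = tf.shape(input_tensor)[1]
  dim = tf.shape(input_tensor)[2]

  # Start row of each sequence once the batch is flattened: [[0], [L], [2L], ...].
  flat_offsets = tf.reshape(tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  # Absolute row index of each requested position in the flattened tensor.
  flat_positions = tf.reshape(tf.expand_dims(positions, axis=-1) + flat_offsets, [-1])
  # Merge the batch and sequence axes: [batch_size * seq_length, dim].
  flat_sequence_tensor = tf.reshape(input_tensor,
                                    [batch_size * seq_length, dim])
  # Pick exactly one row per sequence.
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor


# Build the gather op once, before opening the session.
output = gather_indexes(input_tensor, positions_tensor)

with tf.Session() as sess:
    result = sess.run(output,
                      feed_dict={input_tensor: input_data,
                                 positions_tensor: positions})

# Check: row 0 of the result should equal input_data[0, positions[0], :].
print(result[0])
print(input_data[0][positions[0]])
print(np.allclose(result[0], input_data[0][positions[0]]))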
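The index arithmetic in gather_indexes can be checked without TensorFlow. Below is a minimal NumPy sketch, assuming NumPy 1.17+ (for default_rng); gather_indexes_np is a hypothetical name used only for illustration, not part of TensorFlow or NumPy. It flattens the batch and sequence axes, adds the offset b * seq_length to each position (for example, sequence 3 with position 2 maps to flat row 3 * 10 + 2 = 32), and compares the result with direct advanced indexing.

```python
import numpy as np


def gather_indexes_np(sequence, positions):
    """Hypothetical NumPy counterpart of gather_indexes.

    sequence:  array of shape (batch_size, seq_length, dim)
    positions: int sequence of length batch_size, one index per sequence
    returns:   array of shape (batch_size, dim) with row b = sequence[b, positions[b]]
    """
    batch_size, seq_length, dim = sequence.shape
    # Offset of each sequence's first row once the batch is flattened.
    flat_offsets = np.arange(batch_size) * seq_length
    # Absolute row index of each requested position.
    flat_positions = np.asarray(positions) + flat_offsets
    # Collapse batch and sequence axes so one 1-D index selects one row.
    flat_sequence = sequence.reshape(batch_size * seq_length, dim)
    return flat_sequence[flat_positions]


rng = np.random.default_rng(0)
sequence = rng.random((10, 10, 64))
positions = [0, 9, 8, 2, 3, 5, 1, 2, 4, 6]

flat_result = gather_indexes_np(sequence, positions)
# The same selection written directly with NumPy advanced indexing.
direct_result = sequence[np.arange(10), positions]

print(flat_result.shape)                           # (10, 64)
print(np.array_equal(flat_result, direct_result))  # True
```

The advanced-indexing line expresses the selection in one step; the reshape-plus-offset form mirrors what the TensorFlow function does with tf.reshape and tf.gather.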