Advertisement
Guest User

Untitled

a guest
Jun 25th, 2019
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.25 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3.  
  4.  
  5. input_data = np.random.rand(10, 10, 64)
  6. positions = [0,9,8,2,3,5,1,2,4,6]
  7.  
  8. input_tensor = tf.placeholder(shape = (None,10,64) , dtype = tf.float32)
  9. positions_tensor = tf.placeholder(shape = (None,) , dtype = tf.int32)
  10.  
  11.  
  12. def gather_indexes(input_tensor, positions):
  13. """Gathers the vectors at the specific positions over a minibatch."""
  14.  
  15. batch_size = tf.shape(input_tensor)[0]
  16. seq_length = tf.shape(input_tensor)[1]
  17. dim = tf.shape(input_tensor)[2]
  18.  
  19. flat_offsets = tf.reshape(tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  20. flat_positions = tf.reshape(tf.expand_dims(positions, axis=-1) + flat_offsets, [-1])
  21. flat_sequence_tensor = tf.reshape(input_tensor,
  22. [batch_size * seq_length, dim])
  23. output_tensor = tf.gather(flat_sequence_tensor, flat_positions) ## slices tensor by positions
  24. return output_tensor
  25.  
  26.  
  27. init = tf.global_variables_initializer()
  28.  
  29. with tf.Session() as sess:
  30.  
  31. sess.run(init)
  32. result = sess.run(gather_indexes(input_tensor, positions_tensor),
  33. feed_dict = {input_tensor: input_data,
  34. positions_tensor: positions})
  35.  
  36. ## check result
  37. print(result[0])
  38. print(input_data[0][positions[0]])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement