Advertisement
Guest User

Untitled

a guest
May 5th, 2016
52
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.06 KB | None | 0 0
  1. # shape: C x F/2
  2. # output = self.permutations: [num_copies x cell_size]
  3. permutations = []
  4. indices = numpy.arange(self._dim / 2) #[1 ,2 ,3 ...64]
  5. for i in range(self._num_copies):
  6. numpy.random.shuffle(indices) #[4, 48, 32, ...64]
  7. permutations.append(numpy.concatenate(
  8. [indices,
  9. [ind + self._dim / 2 for ind in indices]]))
  10. #you're appending a row with two columns -- a permutation in the first column, and the same permutation + dim/2 for imaginary
  11. # C x F (numpy)
  12. self.permutations = tf.constant(numpy.vstack(permutations), dtype = tf.int32) #This is a permutation tensor that has the stored permutations
  13. # output = self.permutations: [num_copies x cell_size]
  14.  
  15. def permute(complex_tensor): #complex tensor is [batch_size x cell_size]
  16. gather_tensor = tf.gather_nd(complex_tensor, self.permutations)
  17. return gather_tensor
  18.  
  19. def permute(self, complex_tensor):
  20. inputs_permuted = []
  21. for i in range(self.permutations.get_shape()[0].value):
  22. inputs_permuted.append(
  23. tf.gather(complex_tensor, self.permutations[i]))
  24. return tf.concat(0, inputs_permuted)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement