Sparse Truncate in Python

a guest
Aug 17th, 2019

import tensorflow as tf


def sparse_truncate(sp_tensor, truncate_lengths):
    """Keep at most truncate_lengths[i] stored entries in row i of a 2-D SparseTensor.

    Entries are counted in storage order, so sp_tensor should be in canonical
    row-major order (e.g. after tf.sparse.reorder).
    """
    row_indices = sp_tensor.indices[:, 0]
    # Stored entries per row; rows with no entries get 0.
    sequence_lengths = tf.math.unsorted_segment_sum(
        tf.ones_like(row_indices), row_indices, num_segments=sp_tensor.dense_shape[0])
    max_length = tf.reduce_max(sequence_lengths)
    # base_mask[i, j] is True if row i has a j-th stored entry.
    base_mask = tf.sequence_mask(sequence_lengths, max_length)
    # trunc_mask[i, j] is True if position j falls within row i's limit.
    trunc_mask = tf.sequence_mask(truncate_lengths, max_length)
    # Read trunc_mask only where an entry exists, so keep lines up one-to-one
    # with sp_tensor.indices and sp_tensor.values, even when a row's limit
    # is larger than its number of stored entries.
    keep = tf.boolean_mask(trunc_mask, base_mask)
    new_indices = tf.boolean_mask(sp_tensor.indices, keep)
    new_values = tf.boolean_mask(sp_tensor.values, keep)
    return tf.SparseTensor(new_indices, new_values, sp_tensor.dense_shape)
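
To check sparse_truncate on small inputs, a plain-Python reference of the same truncation can be handy. This is a sketch under stated assumptions, not part of the original paste: sparse_truncate_reference is a hypothetical name, and it assumes COO input given as a list of (row, col) pairs in row-major order, a matching list of values, and one limit per row. Like the mask-based version, it keeps the first k stored entries of each row, not the entries whose column index is below k.

```python
def sparse_truncate_reference(indices, values, truncate_lengths):
    """Hypothetical pure-Python reference for sparse_truncate.

    indices: list of (row, col) pairs in row-major (canonical) order.
    values: list of values aligned with indices.
    truncate_lengths: one limit per row.
    Returns (kept_indices, kept_values).
    """
    kept_indices, kept_values = [], []
    seen = {}  # row -> number of that row's entries seen so far
    for (row, col), value in zip(indices, values):
        position = seen.get(row, 0)  # ordinal of this entry within its row
        seen[row] = position + 1
        if position < truncate_lengths[row]:
            kept_indices.append((row, col))
            kept_values.append(value)
    return kept_indices, kept_values


# Row 1 has a limit (3) larger than its single stored entry.
idx, vals = sparse_truncate_reference(
    [(0, 0), (0, 2), (0, 5), (1, 1), (2, 0), (2, 3)],
    [1, 2, 3, 4, 5, 6],
    [2, 3, 1],
)
print(idx)   # [(0, 0), (0, 2), (1, 1), (2, 0)]
print(vals)  # [1, 2, 4, 5]
```

Feeding the same data to sparse_truncate as a tf.SparseTensor with dense_shape [3, 6] should give the same indices and values.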