Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def sparse_truncate(sp_tensor, truncate_lengths):
    """Keep only the first ``truncate_lengths[r]`` stored entries of each row.

    Truncation is by *position among the row's stored entries*, not by column
    index: a row whose entries sit at columns 0, 3 and 7, truncated to 2,
    keeps the entries at columns 0 and 3.

    Assumes ``sp_tensor`` is in canonical row-major order (sorted indices),
    as produced by ``tf.sparse_reorder``; both the segment sum and the
    flattening of the masks below rely on that ordering. Presumably
    ``sp_tensor`` is 2-D (row = ``indices[:, 0]``) -- for higher ranks the
    "position" is the order of entries within a leading-dimension slice.

    Args:
        sp_tensor: ``tf.SparseTensor`` whose first index dimension is the row.
        truncate_lengths: integer tensor with one entry per row
            (``dense_shape[0]`` values) giving how many stored entries of
            that row to keep. Values larger than a row's entry count keep the
            whole row; values <= 0 drop the whole row.

    Returns:
        A ``tf.SparseTensor`` with the surviving indices/values and the same
        ``dense_shape`` as ``sp_tensor``.
    """
    row_indices = sp_tensor.indices[:, 0]
    # Number of stored entries in each row; rows without entries get 0.
    # tf.sparse_segment_sum requires the segment ids (row_indices) to be
    # sorted, which holds for a canonically ordered SparseTensor.
    sequence_lengths = tf.sparse_segment_sum(
        tf.ones_like(row_indices),
        tf.range(tf.size(row_indices)),
        row_indices,
        num_segments=sp_tensor.dense_shape[0])
    # reduce_max of an empty vector (zero rows) returns the dtype's minimum;
    # clamp at 0 so sequence_mask does not build a range with negative limit.
    max_length = tf.maximum(tf.reduce_max(sequence_lengths), 0)
    # present[r, j]: row r has a j-th stored entry.
    present = tf.sequence_mask(sequence_lengths, max_length)
    # within_limit[r, j]: position j is below row r's truncation length.
    within_limit = tf.sequence_mask(truncate_lengths, max_length)
    # Selecting with `present` alone (not `present | within_limit`) yields
    # exactly one flag per stored entry, in row-major order, so the result
    # lines up with sp_tensor.indices / sp_tensor.values even when a
    # truncation length exceeds the row's entry count.
    keep = tf.boolean_mask(within_limit, present)
    new_indices = tf.boolean_mask(sp_tensor.indices, keep)
    new_values = tf.boolean_mask(sp_tensor.values, keep)
    return tf.SparseTensor(new_indices, new_values, sp_tensor.dense_shape)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement