Sparse Truncate in Python

Posted by a guest on Aug 17th, 2019

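import tensorflow as tf


def sparse_truncate(sp_tensor, truncate_lengths):
    """Keep only the first truncate_lengths[i] stored entries of row i of a 2-D SparseTensor.

    Assumes sp_tensor is in canonical (row-major) order; call tf.sparse.reorder first if unsure.
    """
    row_indices = sp_tensor.indices[:, 0]
    # Number of stored entries in each row (rows with no entries get 0).
    sequence_lengths = tf.math.unsorted_segment_sum(
        tf.ones_like(row_indices), row_indices, num_segments=sp_tensor.dense_shape[0])
    max_length = tf.reduce_max(sequence_lengths)
    # existing[i, j] is True iff row i has at least j + 1 stored entries.
    existing = tf.sequence_mask(sequence_lengths, max_length)
    # wanted[i, j] is True iff position j lies inside row i's truncation length.
    wanted = tf.sequence_mask(truncate_lengths, max_length)
    # Read `wanted` only where an entry exists: in row-major order this yields exactly
    # one flag per stored entry, aligned with sp_tensor.indices and sp_tensor.values.
    # (Selecting on existing OR wanted would add spurious flags for rows shorter than
    # their truncation length and misalign the mask.)
    keep = tf.boolean_mask(wanted, existing)
    new_indices = tf.boolean_mask(sp_tensor.indices, keep)
    new_values = tf.boolean_mask(sp_tensor.values, keep)
    return tf.SparseTensor(new_indices, new_values, sp_tensor.dense_shape)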
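
As a framework-free cross-check of what `sparse_truncate` should return, here is a small pure-Python reference of the same per-row truncation on COO-style `(row, col)` index lists. It is a sketch under assumptions: the helper name `truncate_coo_reference` and its list-based inputs are hypothetical, and entries are assumed to be in row-major (canonical) order, as the TensorFlow version also assumes. Each row keeps its first `truncate_lengths[row]` stored entries in storage order, not the entries whose column index is below the limit.

```python
from collections import defaultdict


def truncate_coo_reference(indices, values, truncate_lengths):
    """Hypothetical reference: keep the first truncate_lengths[r] stored entries of each row r.

    indices: list of (row, col) pairs in row-major (canonical) order.
    values: list of values aligned with indices.
    truncate_lengths: per-row limits, indexed by row number.
    """
    # How many entries of each row have been visited so far.
    position_in_row = defaultdict(int)
    new_indices, new_values = [], []
    for (row, col), value in zip(indices, values):
        # Keep the entry only if its position within its row is below the row's limit.
        if position_in_row[row] < truncate_lengths[row]:
            new_indices.append((row, col))
            new_values.append(value)
        position_in_row[row] += 1
    return new_indices, new_values


# Example: row 0 has 3 entries, row 1 has 1, row 2 has 2.
idx = [(0, 0), (0, 2), (0, 5), (1, 1), (2, 0), (2, 3)]
vals = [10, 20, 30, 40, 50, 60]
print(truncate_coo_reference(idx, vals, [2, 2, 1]))
# → ([(0, 0), (0, 2), (1, 1), (2, 0)], [10, 20, 40, 50])
```

Row 1 stores only one entry against a limit of 2; a correct truncation keeps that entry and must not count the unused slot as an extra entry, so comparing the TensorFlow result on the same input against this reference is a quick way to catch a misaligned mask.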