Guest User

Untitled

a guest
Dec 11th, 2018
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.73 KB | None | 0 0
  # Seed tensor that the while loop concatenates each padded document onto.
  # NOTE: it starts with one all-zero entry along axis 0, so anything
  # concatenated onto it along axis 0 lands after that zero entry.
  tf_padded_final = tf.zeros(shape=[1, max_sent_seq_len, output_size * 2])

  # 1-D vector with one entry per document in the batch: the number of
  # sentences in that document (also used for the
  # tf.bidirectional_dynamic_rnn sequence lengths argument).
  sentence_batch_len = tf.placeholder(
      shape=[None], dtype=tf.int32, name="sentence_batch_len"
  )

  # 2-D array with one row per document: column 0 is the index of the first
  # sentence of the document to gather; column 1 is passed to tf.range as its
  # `limit`, i.e. it is an exclusive end index (one past the last sentence).
  sentence_index_offsets = tf.placeholder(
      shape=[None, 2], dtype=tf.int32, name="sentence_index_offsets"
  )

  # Scalar loop counter for tf.while_loop, starting at the first document.
  i = tf.constant(0)
  11.  
  def while_cond(i, tf_padded_final):
      """Return a boolean tensor that is true while ``i < mini_batch_size``.

      Args:
          i: scalar loop counter.
          tf_padded_final: loop-carried accumulator; unused here, but
              tf.while_loop passes every loop variable to the condition.
      """
      return tf.less(i, tf.constant(mini_batch_size))
  16.  
  17.  
  # Loop through the mini batch of documents (not sentences) one at a time
  # and roll up the sentence vectors into a single row (one row per document).
  # A while loop is used because the sentence count varies across documents.
  def body(i, tf_padded_final):
      """Append document ``i``'s zero-padded sentence vectors to the accumulator.

      Args:
          i: scalar loop counter; indexes ``sentence_index_offsets`` and
              ``sentence_batch_len``.
          tf_padded_final: accumulator that each padded document is
              concatenated onto along axis 0.

      Returns:
          A tuple ``(i + 1, accumulator)`` where the accumulator has one more
          entry along axis 0 than it had on entry.
      """
      st_idx = sentence_index_offsets[i, 0]
      end_idx = sentence_index_offsets[i, 1]

      # Row indices of this document's sentences; `limit` is exclusive.
      tf_range = tf.range(start=st_idx, limit=end_idx)
      tf_slice = tf.gather(outputs, tf_range)

      # Zero-pad at the end of axis 0 only; the padding spec below has two
      # entries, so the gathered slice is treated as rank 2.
      # NOTE(review): the padded length equals max_sent_seq_len only if
      # sentence_batch_len[i] == end_idx - st_idx -- confirm against the feed.
      pad_len = max_sent_seq_len - sentence_batch_len[i]
      tf_slice_padding = [[0, pad_len], [0, 0]]
      tf_slice_padded = tf.pad(tf_slice, tf_slice_padding, 'CONSTANT')

      # Add a leading axis so the document can be concatenated as one entry.
      tf_slice_padded_3D = tf.expand_dims(tf_slice_padded, axis=0)
      tf_padded_final = tf.concat([tf_padded_final, tf_slice_padded_3D], axis=0)

      i = tf.add(i, 1)

      return i, tf_padded_final
  39.  
  # The accumulator grows along axis 0 on every iteration, so that dimension
  # must be declared unknown (None). The trailing dimensions are taken from
  # the seed tensor's static shape instead of being hard-coded, so the
  # invariant stays in sync with max_sent_seq_len and output_size.
  _, tf_padded_final_2 = tf.while_loop(
      while_cond,
      body,
      [i, tf_padded_final],
      shape_invariants=[
          i.get_shape(),
          tf.TensorShape([None]).concatenate(tf_padded_final.get_shape()[1:]),
      ],
  )
Add Comment
Please, Sign In to add comment