Advertisement
Guest User

Untitled

a guest
Jan 17th, 2017
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.28 KB | None | 0 0
  def get_minibatch(lines, index, batch_size, word2ind, max_len, add_start=False, add_end=True):
      """Build padded input/target id arrays and a mask for one minibatch.

      The batch is ``lines[index:index + batch_size]``. Each sentence is
      optionally wrapped in ``'<s>'`` / ``'</s>'`` markers, truncated to at
      most ``max_len`` tokens (markers included, so a long sentence can lose
      its ``'</s>'``), and mapped to ids through ``word2ind``. Tokens missing
      from ``word2ind`` map to ``word2ind['<unk>']``; short rows are filled
      with ``word2ind['<pad>']``.

      Args:
          lines: Sequence of tokenized sentences (lists of tokens).
          index: Offset of the first sentence of the batch within ``lines``.
          batch_size: Maximum number of sentences taken from ``index`` on.
          word2ind: Mapping from token to integer id; must contain
              ``'<pad>'``, and ``'<unk>'`` if any token is out of vocabulary.
          max_len: Maximum number of tokens kept per sentence.
          add_start: Prepend ``'<s>'`` to every sentence.
          add_end: Append ``'</s>'`` to every sentence.

      Returns:
          A tuple ``(input_lines, output_lines, lens, mask)``:
          ``input_lines`` is an int32 array holding every token but the last
          of each sentence, ``output_lines`` is an int32 array holding every
          token but the first (the input shifted by one), ``lens`` is the
          list of truncated sentence lengths, and ``mask`` is a float32 array
          with 1.0 on real positions and 0.0 on padding. All arrays have
          ``len(lens)`` rows and ``max(lens) - 1`` columns (never fewer
          than 0).

      Raises:
          ValueError: If the selected batch contains no sentences.
          KeyError: If a required special token is missing from ``word2ind``.
      """
      # Slice the batch unconditionally. Previously the slice lived inside
      # the add_start/add_end branches, so with both flags off the whole
      # dataset was returned as a single batch.
      batch = lines[index:index + batch_size]
      if add_start:
          batch = [['<s>'] + line for line in batch]
      if add_end:
          batch = [line + ['</s>'] for line in batch]
      batch = [line[:max_len] for line in batch]

      lens = [len(line) for line in batch]
      if not lens:
          raise ValueError(
              'get_minibatch: empty batch (index=%r, batch_size=%r, %d lines)'
              % (index, batch_size, len(lines))
          )

      pad_id = word2ind['<pad>']
      # Inputs and targets each drop one token, hence the "- 1". Clamp at
      # zero so a batch made only of empty sentences still yields
      # rectangular (n, 0) arrays.
      width = max(max(lens) - 1, 0)

      def encode(tokens):
          # '<unk>' is looked up lazily so a vocabulary without it still
          # works as long as every token is known.
          ids = [
              word2ind[w] if w in word2ind else word2ind['<unk>']
              for w in tokens
          ]
          # Pad to the common width; this keeps rows rectangular even for
          # an empty sentence, whose slice is not one shorter than itself.
          return ids + [pad_id] * (width - len(ids))

      input_lines = np.array(
          [encode(line[:-1]) for line in batch]
      ).astype(np.int32)

      output_lines = np.array(
          [encode(line[1:]) for line in batch]
      ).astype(np.int32)

      # A sentence of n tokens contributes max(n - 1, 0) real positions.
      real_counts = [max(n - 1, 0) for n in lens]
      mask = np.array(
          [[1] * real + [0] * (width - real) for real in real_counts]
      ).astype(np.float32)

      return input_lines, output_lines, lens, mask
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement