Advertisement
Guest User

Untitled

a guest
Mar 29th, 2020
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.78 KB | None | 0 0
  1. from torch.nn.utils.rnn import pad_sequence
  2.  
  3. def collate(list_of_samples):
  4.     """Merges a list of samples to form a mini-batch.
  5.  
  6.    Args:
  7.      list_of_samples is a list of tuples (src_seq, tgt_seq):
  8.          src_seq is of shape (src_seq_length,)
  9.          tgt_seq is of shape (tgt_seq_length,)
  10.  
  11.    Returns:
  12.      src_seqs of shape (max_src_seq_length, batch_size): Tensor of padded source sequences.
  13.          The sequences should be sorted by length in a decreasing order, that is src_seqs[:,0] should be
  14.          the longest sequence, and src_seqs[:,-1] should be the shortest.
  15.      src_seq_lengths: List of lengths of source sequences.
  16.      tgt_seqs of shape (max_tgt_seq_length, batch_size): Tensor of padded target sequences.
  17.    """
  18.     src_seq_lengths = list(map(lambda t : t[0].size()[0], list_of_samples)) # find source lengths
  19.     tgt_seq_lengths = list(map(lambda t : t[1].size()[0], list_of_samples)) # find target lengths
  20.     sort_ind = np.array(np.flip(np.argsort(src_seq_lengths))) # sort into decreasing order wrt source lengths
  21.    
  22.     src_seq_lengths = np.array(src_seq_lengths)[sort_ind]  # apply sort
  23.     tgt_seq_lengths = np.array(tgt_seq_lengths)[sort_ind]
  24.    
  25.     batch_size = len(src_seq_lengths)
  26.     max_src_seq_length = src_seq_lengths[0]  # take max length (already sorted -> first one)
  27.     max_tgt_seq_length = max(tgt_seq_lengths)# find max length
  28.    
  29.     src_seqs = torch.zeros(max_src_seq_length, batch_size, dtype=torch.long)
  30.     tgt_seqs = torch.zeros(max_tgt_seq_length, batch_size, dtype=torch.long)
  31.     for j in range(batch_size):
  32.         src_seqs[0:(src_seq_lengths[j]),j] = list_of_samples[sort_ind[j]][0]
  33.         tgt_seqs[0:(tgt_seq_lengths[j]),j] = list_of_samples[sort_ind[j]][1]
  34.    
  35.     return src_seqs, src_seq_lengths, tgt_seqs
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement