Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
from torch.nn.utils.rnn import pad_sequence


def collate(list_of_samples):
    """Merge a list of samples to form a mini-batch.

    Args:
        list_of_samples: list of tuples (src_seq, tgt_seq), where
            src_seq is a 1-D tensor of shape (src_seq_length,) and
            tgt_seq is a 1-D tensor of shape (tgt_seq_length,).

    Returns:
        src_seqs of shape (max_src_seq_length, batch_size): Tensor of
            zero-padded source sequences, sorted by length in decreasing
            order — src_seqs[:, 0] is the longest, src_seqs[:, -1] the
            shortest.
        src_seq_lengths: list of lengths of the (sorted) source sequences.
        tgt_seqs of shape (max_tgt_seq_length, batch_size): Tensor of
            zero-padded target sequences, in the same (source-sorted) order.

    Raises:
        ValueError: if list_of_samples is empty.
    """
    if not list_of_samples:
        raise ValueError("collate() received an empty list of samples")

    # Sort the whole batch by source length, longest first, so that the
    # lengths list lines up with the padded columns (as required e.g. by
    # pack_padded_sequence).
    ordered = sorted(list_of_samples,
                     key=lambda pair: pair[0].size(0),
                     reverse=True)

    src_seq_lengths = [src.size(0) for src, _ in ordered]

    # pad_sequence pads with 0 and returns (max_seq_length, batch_size)
    # when batch_first=False (the default) — exactly the layout promised
    # in the docstring.
    src_seqs = pad_sequence([src for src, _ in ordered])
    tgt_seqs = pad_sequence([tgt for _, tgt in ordered])

    return src_seqs, src_seq_lengths, tgt_seqs
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement