import collections
import re

import torch
from torch._six import int_classes, string_classes  # internals of older PyTorch (0.3/0.4)
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import numpy_type_map

# In torch's own dataloader this flag is a module global that worker
# processes flip to True; default to False when this file is used standalone.
_use_shared_memory = False


def _collate_fn(batch, padding=True):
    """Merge a list of samples into a minibatch.

    PyTorch calls this function while building each minibatch: pass it as
    the ``collate_fn`` argument of ``torch.utils.data.DataLoader`` instead
    of calling it directly.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy.
            numel = sum(x.numel() for x in batch)
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # Arrays of string or object dtype cannot be collated into tensors.
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            if padding:
                # Sort the sequences in decreasing order of length and pad
                # them to the length of the longest one.
                inputs = [Variable(torch.from_numpy(b)) for b in batch]
                lengths = [b.shape[0] for b in batch]
                sorted_length_idxes = sorted(
                    range(len(lengths)), key=lambda k: lengths[k], reverse=True)
                sorted_inputs = [inputs[i] for i in sorted_length_idxes]
                return pad_sequence(sorted_inputs, batch_first=True)
            else:
                return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: _collate_fn([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [_collate_fn(samples) for samples in transposed]
    raise TypeError(error_msg.format(type(batch[0])))
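# --- Usage sketch ---
# A minimal example of wiring this collate function into a DataLoader.
# The ToySequenceDataset below is hypothetical, invented purely for
# illustration; any Dataset yielding variable-length NumPy arrays works.
import numpy as np
import torch.utils.data


class ToySequenceDataset(torch.utils.data.Dataset):
    """Hypothetical dataset yielding variable-length float32 sequences."""

    def __init__(self):
        # Three sequences of lengths 5, 3 and 4, each with 2 features.
        self._data = [np.random.rand(n, 2).astype('float32')
                      for n in (5, 3, 4)]

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx]


if __name__ == '__main__':
    loader = torch.utils.data.DataLoader(
        ToySequenceDataset(), batch_size=3, collate_fn=_collate_fn)
    for minibatch in loader:
        # With padding=True the batch is padded to the longest sequence and
        # sorted by decreasing length (the order pack_padded_sequence
        # expects downstream), so the shape here is (3, 5, 2).
        print(minibatch.size())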