import collections
import re

import torch
import torch.utils.data.dataloader as torch_dataloader
from torch._six import int_classes, string_classes
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence


def _collate_fn(batch, padding=True):
    """Called by PyTorch when it assembles a minibatch.

    Pass this function as the collate_fn argument of
    torch.utils.data.DataLoader; you do not call it directly.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        out = None
        # _use_shared_memory is a module-level flag that DataLoader worker
        # processes set to True, so read it from the module at call time.
        if torch_dataloader._use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy.
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # Arrays of strings or objects cannot be collated into tensors.
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            if padding:
                inputs = [Variable(torch.from_numpy(b)) for b in batch]
                lengths = [b.shape[0] for b in batch]
                # Sort by length in decreasing order, as required by
                # torch.nn.utils.rnn.pack_padded_sequence.
                sorted_lengths_idxes = sorted(
                    range(len(lengths)), key=lambda k: lengths[k], reverse=True)
                # Callers that pack the padded batch will need these lengths,
                # although only the padded tensor is returned here.
                sorted_lengths = sorted(lengths, reverse=True)
                sorted_inputs = [inputs[i] for i in sorted_lengths_idxes]
                return pad_sequence(sorted_inputs, batch_first=True)
            else:
                return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return torch_dataloader.numpy_type_map[elem.dtype.name](
                list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: _collate_fn([d[key] for d in batch], padding)
                for key in batch[0]}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [_collate_fn(samples, padding) for samples in transposed]

    raise TypeError(error_msg.format(type(batch[0])))
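Since _collate_fn is only ever handed to DataLoader, a minimal usage sketch may help. The VarLengthDataset below is hypothetical: it simply yields variable-length float32 NumPy arrays, which is the case the padding branch is written for.

import numpy as np
import torch.utils.data


class VarLengthDataset(torch.utils.data.Dataset):
    """Hypothetical dataset yielding variable-length float32 sequences."""

    def __init__(self, lengths, dim=8):
        self.data = [np.random.randn(n, dim).astype('float32')
                     for n in lengths]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


loader = torch.utils.data.DataLoader(
    VarLengthDataset([5, 3, 7, 2]), batch_size=4, collate_fn=_collate_fn)

for padded in loader:
    # Shape (batch, max_len, dim); rows are sorted by decreasing true
    # length, so the batch is ready for pack_padded_sequence.
    print(padded.size())  # torch.Size([4, 7, 8])

Note that the function computes sorted_lengths but does not return them, so a caller that goes on to call pack_padded_sequence would need to extend the return value (or recover the lengths from the padded batch) to obtain them.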