visoft

Out-of-core ItemList backed by NumPy memory-mapped files

Mar 1st, 2019
236
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch.nn as nn
  2. from fastai import basic_train, data_block
  3. import numpy as np
  4. from torch import Tensor
  5. import torch
  6. import torch.optim
  7. import os
  8. import time
  9.  
  10. from pathlib import Path
  11.  
  12.  
  13. class MemMapItemList(data_block.ItemList):
  14.     def __init__(self, items, path, data_shape, dtype = np.float32, **kwargs):
  15.         super().__init__(items, path, **kwargs)
  16.         self.data_shape = data_shape
  17.         self.copy_new.append("data_shape")
  18.         self._file_process_dict = {}  # Deleting this structure might cause grief when the main thread is killing the workers
  19.         self._dtype = dtype
  20.  
  21.     def get(self, i):
  22.         pid = os.getpid()
  23.         mem_file = self._file_process_dict.get(pid, None)  # each process owns its handler.
  24.         if mem_file is None:
  25.             mem_file = np.memmap(self.path, self._dtype, mode='r+', shape=self.data_shape)
  26.             self._file_process_dict[pid] = mem_file
  27.         idx = self.items[i]
  28.         item_data = np.copy(mem_file[idx, :])
  29.         if self._dtype == np.float32:
  30.             item = data_block.FloatItem(item_data)
  31.         else:
  32.             item = data_block.Category(item_data, item_data)
  33.         return item
  34.  
  35.     def reconstruct(self, t: Tensor, x: Tensor = None):
  36.         return data_block.FloatItem(t.cpu().numpy())
  37.  
  38.     def labels_from_memmap(self, npy_memfile, data_shape, dtype=np.float32, **kwargs):
  39.         y = MemMapItemList(self.items, npy_memfile, data_shape, dtype=dtype)
  40.         res = self._label_list(x=self, y=y)
  41.         return res
  42.  
  43.     @classmethod
  44.     def from_memfile(cls, path, data_shape):
  45.         "Constructs a MemMapItemList from a numpy mem mapped file"
  46.         items = np.arange(0, data_shape[0])
  47.         return MemMapItemList(items, path, data_shape)
  48.  
  49.  
  50. def gen_some_data_for_io(folder, N, lx, ly):
  51.     feat = np.random.rand(N, lx)
  52.     feat[:, 0] = np.arange(N)
  53.     target = np.random.rand(N, ly)
  54.     target[:, 0] = np.arange(N)
  55.  
  56.     fx = folder / "x.npy"
  57.     fy = folder / "y.npy"
  58.  
  59.     npfx = np.memmap(fx, np.float32, "w+", shape=feat.shape)
  60.     npfx[:] = feat[:]
  61.     npfx.flush()
  62.  
  63.     npfy = np.memmap(fy, np.float32, "w+", shape=target.shape)
  64.     npfy[:] = target[:]
  65.     npfy.flush()
  66.  
  67.     del npfx
  68.     del npfy
  69.  
  70.  
  71. class Validation_Net(nn.Module):
  72.     "Dummy learner. It passes the first feature from input to the output"
  73.  
  74.     def __init__(self, input_size=5, output_size=3):
  75.         super().__init__()
  76.         self.last = nn.Linear(input_size, output_size)
  77.  
  78.     def forward(self, x):
  79.         out = self.last(x)
  80.         out[:, 0] = x[:, 0]
  81.         return out
  82.  
  83.  
  84. class Validation_Loss(torch.nn.Module):
  85.     "Just makes sure that the first column from the input is identical with the target"
  86.  
  87.     def __init__(self):
  88.         super().__init__()
  89.  
  90.     def forward(self, x, y):
  91.         diff = x[:, 0] - y[:, 0]
  92.         abs_diff = torch.abs(diff)
  93.         abs_sum = torch.sum(abs_diff)
  94.         if abs_sum > 0.000001:
  95.             raise Exception("Input and lables are misalligned. Maybe the batch reading is wrong")
  96.         dls = x - y
  97.         dls = torch.sum(torch.pow(dls, 2))
  98.         return dls
  99.  
  100.  
  101. def train_network(folder, N, lx, ly):
  102.     train_data_shape = (N, lx)
  103.     test_data_shape = (N, ly)
  104.  
  105.     item_list = MemMapItemList.from_memfile(folder / "x.npy", data_shape=train_data_shape)
  106.     splitted = item_list.random_split_by_pct(valid_pct=0.1)
  107.     labeled = splitted.labels_from_memmap(folder / "y.npy", data_shape=test_data_shape)
  108.     data_bunch = labeled.databunch(bs=512, num_workers=4)  # Test few values to see what's best for your hw+data stack
  109.  
  110.     model = Validation_Net()
  111.     learner = basic_train.Learner(data=data_bunch, model=model, true_wd=True, wd=0.0001,
  112.                                   loss_func=Validation_Loss(), path=folder)
  113.  
  114.     learner.fit(3, lr=0.001)
  115.     t0 = time.time()
  116.     learner.fit(3, lr=0.001)
  117.     t1 = time.time()
  118.     print("Time {}".format(t1 - t0))
  119.  
  120.  
  121. if __name__ == "__main__":
  122.     N = 100000
  123.     lx = 5
  124.     ly = 3
  125.     folder = Path(".")
  126.     gen_some_data_for_io(folder, N, lx, ly)
  127.     train_network(folder, N, lx, ly)
RAW Paste Data