Emania

distance train

May 14th, 2020
173
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.53 KB | None | 0 0
  1.  
  2. class MyDataset(Dataset):
  3. """
  4. dataset spocte vsechna mozna data a pro kazdou tridu si udrzuje vlastni listy a indexy,
  5. pri kazdem __getitem__ bere s pravdepodobnosti 0.5 z kazdeho seznamu -> ignoruje argument idx
  6.  
  7. """
  8.     def __init__(self, pickle_path, train=True):
  9.  
  10.         # load the pickle
  11.         with open(pickle_path, 'rb') as f:
  12.             self.pickle_dict = pickle.load(f)
  13.         print('keys ', self.pickle_dict.keys())
  14.         self.descriptors = np.array([self.pickle_dict[key]['trn_descriptors'] for key in self.pickle_dict['keys']]).squeeze(-1)
  15.  
  16.         # labels
  17.         self.train = train
  18.         labels_name = 'trn_labels' if self.train else 'tst_labels'
  19.         self.labels = self.pickle_dict[256][labels_name]
  20.         self.size = len(self.labels)
  21.         y = self.labels[:, None]
  22.         z = np.repeat(y, len(self.labels), -1)
  23.         self.indexes = (y.T == z).flatten()
  24.  
  25.         # same and different idxs
  26.         self.same_i, self.different_i, self.same_turn  = 0, 0, True
  27.         mask = [x == True and self.labels[i // self.size] != -1 for i, x in enumerate(self.indexes)]
  28.         self.different = np.where(np.logical_not(mask))
  29.         self.same = np.where(mask)
  30.         print('same length ', len(self.same))
  31.         print('different length ', len(self.different))
  32.  
  33.  
  34.     def __len__(self):
  35.         return self.size**2
  36.  
  37.     def __getitem__(self, idx):
  38.  
  39.         if torch.round(torch.rand(1)).item():
  40.             i = self.same_i
  41.             self.same_i = (self.same_i +1) % len(self.same)
  42.             true_x = self.same[i] // self.size
  43.             true_y = self.same[i] % self.size
  44.         else:
  45.             i = self.different_i
  46.             true_x = self.different[i] // self.size
  47.             true_y = self.different[i] % self.size
  48.             self.different_i = (self.different_i + 1) % len(self.different)
  49.  
  50.  
  51.         data = self.descriptors[:,true_x].dot(self.descriptors[:, true_y].T).flatten()
  52.         label = int(self.labels[true_x] == self.labels[true_y])
  53.  
  54.         self.same_turn = not self.same_turn
  55.         return data, label
  56.  
  57.  
  58. ## main
  59. def moving_average(a, n=3) :
  60.     ret = np.cumsum(a, dtype=float)
  61.     ret[n:] = ret[n:] - ret[:-n]
  62.     return ret[n - 1:] / n
  63.  
  64. def plot_history(losses, accuracy, epoch, window):
  65.  
  66.     plt_loss = moving_average(np.array(losses), window)
  67.     plt_accuracy = moving_average(np.array(accuracy), window)
  68.  
  69.     plt.figure()
  70.     plt.subplot(121)
  71.     plt.plot(np.linspace(0,epoch, len(plt_loss)), plt_loss)
  72.     plt.title('loss')
  73.  
  74.     plt.subplot(122)
  75.     plt.plot(np.linspace(0,epoch, len(plt_accuracy)), plt_accuracy)
  76.     plt.title('accuracy')
  77.     plt.savefig('models/learning_state.png')
  78.  
  79.  
  80. def main():
  81.     epochs = 30
  82.     bs = 20
  83.     dev = torch.device('cuda')  if torch.cuda.is_available() else torch.device('cpu')
  84.  
  85.     # Prepare model
  86.     model = nn.Sequential(nn.Linear(25, 100), nn.Linear(100,1))
  87.     model.to(dev)
  88.     optimizer = torch.optim.SGD(model.parameters(), lr=0.00005, momentum=0.9)
  89.     model.train()
  90.     loss = nn.BCEWithLogitsLoss()
  91.  
  92.     # Prepare Dataset
  93.     dataset = MyDataset('output/gl18-tl-resnet101-gem-w')
  94.     loader = torch.utils.data.DataLoader(dataset, bs, shuffle=True)
  95.     loader_len = len(loader)
  96.  
  97.     losses = []
  98.     accuracy = []
  99.     ewa_a = .5
  100.     q = 0.1
  101.     ewa_l = .5
  102.     last_p = 0 # lepsi nazev by byl best_p
  103.     print('starting ')
  104.  
  105.     # trenovani site
  106.     for epoch in range(1, epochs + 1):
  107.         for i, (x,y) in enumerate(loader):
  108.  
  109.             # data preparation
  110.             x = x.to(dev)
  111.             y = y.to(dev)
  112.             y = y.float()
  113.  
  114.             # step
  115.             optimizer.zero_grad()
  116.             out = model.forward(x)
  117.             l = loss(out.flatten().float(), y.flatten().float())
  118.             l.backward()
  119.             optimizer.step()
  120.  
  121.             # loss
  122.             ewa_l = (1-q) * ewa_l + q*l # Tohle konvertovat na float!!
  123.             losses.append(ewa_l)
  124.  
  125.             # accuracy
  126.             classify = torch.round(torch.sigmoid(out)).float().flatten()
  127.             acc = (classify == y).float().mean()
  128.             ewa_a  = (1-q) * ewa_a  + q*acc # Tohle konvertovat na float!!
  129.             accuracy.append(ewa_a)
  130.            
  131.             if i%100 == 0:
  132.                 print(f'Epoch: {epoch}/{epochs}, batch {i}/{loader_len} EWA loss {ewa_l}, EWA accuracy {ewa_a}')
  133.  
  134.         if ewa_a  > last_p:
  135.             last_p = ewa_a
  136.             torch.save(model.state_dict(), 'models/best')
  137.  
  138.         plot_history(losses, accuracy, epoch, len(loader)//10)
  139.  
# Script entry point: run training only when executed directly.
if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment