Advertisement
Guest User

Untitled

a guest
Jul 17th, 2019
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.18 KB | None | 0 0
  1. def get_comb_list(start=4, finish = 40):
  2.  
  3. sim_pairs, diff_pairs, triplets = get_pairs_cor_labels(start, finish)
  4. comb_list = sim_pairs + diff_pairs
  5. random.shuffle(comb_list)
  6. return (comb_list)
  7.  
  8. ##########################################################################################################################
  9. #Getting input tensors for the above pairs along with labels
  10. #_cl:contrastive loss
  11. def get_input_tensors_cl(cur_batch):
  12.  
  13. from torchvision.transforms import ToTensor
  14. batch_imgs_anc = torch.Tensor()
  15. batch_imgs_pn = torch.Tensor()
  16. #batch_imgs_neg = torch.Tensor()
  17. labels = []
  18. root_path = 'C:\\Users\\..\\atnt_faces\\'
  19.  
  20.  
  21.  
  22. for i in range(len(cur_batch)):
  23. file_path_anc = cur_batch[i][0]
  24. file_path_pn = cur_batch[i][1]
  25.  
  26. new_tensor_anc = ToTensor()(Image.open(root_path + file_path_anc)).unsqueeze(0)
  27. new_tensor_pn = ToTensor()(Image.open(root_path + file_path_pn)).unsqueeze(0)
  28.  
  29. batch_imgs_anc = torch.cat((batch_imgs_anc, new_tensor_anc), dim =0)
  30. batch_imgs_pn = torch.cat((batch_imgs_pn, new_tensor_pn), dim =0)
  31.  
  32. labels.append(cur_batch[i][2])
  33.  
  34. return(batch_imgs_anc, batch_imgs_pn, labels)
  35.  
  36. ####################################################################################################################
  37.  
  38. #generate nn, note - 2 networks are used in contrast to triplet loss
  39. conv_net_sig, dense_net_sig = gen_nn()
  40. conv_net_sig1, dense_net_sig1 = gen_nn()
  41. LR = 0.0001
  42. optimizer_sig = torch.optim.Adam(([p for p in conv_net_sig.parameters()] + [p for p in dense_net_sig.parameters()]
  43. + [p for p in conv_net_sig1.parameters()] + [p for p in dense_net_sig1.parameters()]), lr=LR)
  44. criterion = nn.BCELoss()
  45.  
  46. #####################################################################################################################
  47. #Train
  48. epochs = 20
  49. print_every = 1
  50. LR = 0.0001
  51. batch_size = 50
  52.  
  53. comb_list = get_comb_list(start=4, finish = 41)
  54.  
  55. if len(comb_list) % batch_size ==0:
  56. num_batches = len(comb_list)//batch_size
  57. else:
  58. num_batches = len(comb_list)//batch_size + 1
  59. #or use:
  60. #num_batches = get_num_batches(the_list, batch_size)
  61.  
  62. losses_sig = []
  63.  
  64. for e in range(epochs):
  65. comb_list = get_comb_list(start=4, finish = 41)
  66. for i in range(num_batches):
  67. cur_batch = get_batch(i, batch_size, num_batches, comb_list)
  68.  
  69. batch_imgs_anc, batch_imgs_pn, labels = get_input_tensors(cur_batch)
  70.  
  71.  
  72. anc_fc_out = forward(batch_imgs_anc, conv_net_sig, dense_net_sig)
  73.  
  74. pn_fc_out = forward(batch_imgs_pn, conv_net_sig1, dense_net_sig1)
  75. #margin is not used here but
  76.  
  77.  
  78. dot_inp_tar = torch.sum(torch.mul(anc_fc_out, pn_fc_out), dim =1).reshape(-1, 1)
  79.  
  80. #sigmoid activation squashes the scores to 1 or 0
  81. sig_logits = nn.Sigmoid()(dot_inp_tar)
  82.  
  83. optimizer_sig.zero_grad()
  84. loss = criterion(sig_logits, torch.Tensor(labels).view(sig_logits.shape[0], 1))
  85. loss.backward()
  86. optimizer_sig.step()
  87.  
  88. losses_sig.append(loss.item())
  89.  
  90.  
  91.  
  92.  
  93. if e % print_every == 0:
  94. print(loss.item())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement