SHARE
TWEET

Untitled

a guest Jul 17th, 2019 62 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def get_comb_list(start=4, finish = 40):
  2.    
  3.     sim_pairs, diff_pairs, triplets = get_pairs_cor_labels(start, finish)
  4.     comb_list = sim_pairs + diff_pairs
  5.     random.shuffle(comb_list)
  6.     return (comb_list)
  7.  
  8. ##########################################################################################################################
  9. #Getting input tensors for the above pairs along with labels
  10. #_cl:contrastive loss
  11. def get_input_tensors_cl(cur_batch):
  12.    
  13.     from torchvision.transforms import ToTensor
  14.     batch_imgs_anc = torch.Tensor()
  15.     batch_imgs_pn = torch.Tensor()
  16.     #batch_imgs_neg = torch.Tensor()
  17.     labels = []
  18.     root_path = 'C:\\Users\\..\\atnt_faces\\'
  19.  
  20.  
  21.  
  22.     for i in range(len(cur_batch)):
  23.         file_path_anc = cur_batch[i][0]
  24.         file_path_pn = cur_batch[i][1]
  25.        
  26.         new_tensor_anc = ToTensor()(Image.open(root_path + file_path_anc)).unsqueeze(0)
  27.         new_tensor_pn = ToTensor()(Image.open(root_path + file_path_pn)).unsqueeze(0)
  28.        
  29.         batch_imgs_anc = torch.cat((batch_imgs_anc, new_tensor_anc), dim =0)
  30.         batch_imgs_pn = torch.cat((batch_imgs_pn, new_tensor_pn), dim =0)
  31.        
  32.         labels.append(cur_batch[i][2])
  33.        
  34.     return(batch_imgs_anc, batch_imgs_pn, labels)
  35.  
  36. ####################################################################################################################
  37.  
  38. #generate nn, note - 2 networks are used in contrast to triplet loss
  39. conv_net_sig, dense_net_sig = gen_nn()
  40. conv_net_sig1, dense_net_sig1 = gen_nn()
  41. LR = 0.0001
  42. optimizer_sig = torch.optim.Adam(([p for p in conv_net_sig.parameters()] + [p for p in dense_net_sig.parameters()]
  43.                                   + [p for p in conv_net_sig1.parameters()] + [p for p in dense_net_sig1.parameters()]), lr=LR)
  44. criterion = nn.BCELoss()
  45.  
  46. #####################################################################################################################
  47. #Train
  48. epochs = 20
  49. print_every = 1
  50. LR = 0.0001
  51. batch_size = 50
  52.  
  53. comb_list = get_comb_list(start=4, finish = 41)
  54.    
  55. if len(comb_list) % batch_size ==0:
  56.     num_batches = len(comb_list)//batch_size
  57. else:
  58.     num_batches = len(comb_list)//batch_size + 1
  59. #or use:  
  60. #num_batches = get_num_batches(the_list, batch_size)
  61.  
  62. losses_sig = []
  63.  
  64. for e in range(epochs):
  65.     comb_list = get_comb_list(start=4, finish = 41)
  66.     for i in range(num_batches):
  67.         cur_batch = get_batch(i, batch_size, num_batches, comb_list)
  68.        
  69.         batch_imgs_anc, batch_imgs_pn, labels = get_input_tensors(cur_batch)
  70.        
  71.  
  72.         anc_fc_out = forward(batch_imgs_anc, conv_net_sig, dense_net_sig)
  73.        
  74.         pn_fc_out = forward(batch_imgs_pn, conv_net_sig1, dense_net_sig1)
  75.         #margin is not used here but
  76.        
  77.        
  78.         dot_inp_tar = torch.sum(torch.mul(anc_fc_out, pn_fc_out), dim =1).reshape(-1, 1)
  79.  
  80.         #sigmoid activation squashes the scores to 1 or 0
  81.         sig_logits = nn.Sigmoid()(dot_inp_tar)
  82.  
  83.         optimizer_sig.zero_grad()
  84.         loss = criterion(sig_logits, torch.Tensor(labels).view(sig_logits.shape[0], 1))
  85.         loss.backward()
  86.         optimizer_sig.step()      
  87.              
  88.         losses_sig.append(loss.item())
  89.        
  90.        
  91.        
  92.        
  93.     if e % print_every == 0:
  94.         print(loss.item())
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top