Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def get_comb_list(start=4, finish=40):
    """Return a shuffled list of labelled image-pair combinations.

    Merges the similar and dissimilar pairs produced by
    get_pairs_cor_labels for subjects in [start, finish) and shuffles
    them in place so every batch mixes both classes. The triplets the
    helper also returns are not needed for contrastive training.
    """
    similar, dissimilar, _triplets = get_pairs_cor_labels(start, finish)
    combined = similar + dissimilar
    random.shuffle(combined)
    return combined
##########################################################################################################################
# Getting input tensors for the above pairs along with labels
# _cl: contrastive loss
def get_input_tensors_cl(cur_batch):
    """Load a batch of image pairs from disk as stacked tensors.

    Parameters
    ----------
    cur_batch : sequence of (anchor_path, pair_path, label) triples;
        paths are relative to root_path.

    Returns
    -------
    (batch_imgs_anc, batch_imgs_pn, labels) — two tensors of shape
    (batch, C, H, W) and a plain Python list of labels.
    """
    from torchvision.transforms import ToTensor
    to_tensor = ToTensor()
    # NOTE(review): hard-coded Windows path — consider making this a parameter.
    root_path = 'C:\\Users\\..\\atnt_faces\\'
    anchor_tensors = []
    pair_tensors = []
    labels = []
    for anchor_path, pair_path, label in cur_batch:
        # `with` closes the image files instead of leaking file handles;
        # ToTensor() forces the pixel data to load before the file closes.
        with Image.open(root_path + anchor_path) as img:
            anchor_tensors.append(to_tensor(img).unsqueeze(0))
        with Image.open(root_path + pair_path) as img:
            pair_tensors.append(to_tensor(img).unsqueeze(0))
        labels.append(label)
    if not anchor_tensors:
        # Preserve the original empty-batch result (empty tensors).
        return (torch.Tensor(), torch.Tensor(), labels)
    # Concatenate once instead of per item — repeated torch.cat in the loop
    # re-copies the whole accumulated batch each iteration (quadratic cost).
    return (torch.cat(anchor_tensors, dim=0),
            torch.cat(pair_tensors, dim=0),
            labels)
####################################################################################################################
# generate nn, note - 2 networks are used in contrast to triplet loss
conv_net_sig, dense_net_sig = gen_nn()
conv_net_sig1, dense_net_sig1 = gen_nn()
LR = 0.0001
# One Adam optimizer over the parameters of all four sub-networks.
optimizer_sig = torch.optim.Adam(
    list(conv_net_sig.parameters())
    + list(dense_net_sig.parameters())
    + list(conv_net_sig1.parameters())
    + list(dense_net_sig1.parameters()),
    lr=LR,
)
criterion = nn.BCELoss()
#####################################################################################################################
# Train
epochs = 20
print_every = 1   # report the loss every `print_every` epochs
batch_size = 50
# NOTE: the dead `LR = 0.0001` reassignment was removed — the optimizer was
# already constructed with lr=LR above, so reassigning LR here had no effect.
# This call only fixes the list's size so num_batches can be computed up
# front; the list itself is regenerated (reshuffled) at every epoch below.
comb_list = get_comb_list(start=4, finish=41)
# Ceiling division: a trailing partial batch still counts as a batch.
num_batches = -(-len(comb_list) // batch_size)
# or use:
# num_batches = get_num_batches(the_list, batch_size)
losses_sig = []
for e in range(epochs):
    # Fresh shuffled pair list each epoch.
    comb_list = get_comb_list(start=4, finish=41)
    for i in range(num_batches):
        cur_batch = get_batch(i, batch_size, num_batches, comb_list)
        # BUG FIX: the loader defined above is get_input_tensors_cl (its
        # 3-tuple return matches this unpacking); the original called the
        # non-contrastive get_input_tensors by mistake.
        batch_imgs_anc, batch_imgs_pn, labels = get_input_tensors_cl(cur_batch)
        anc_fc_out = forward(batch_imgs_anc, conv_net_sig, dense_net_sig)
        pn_fc_out = forward(batch_imgs_pn, conv_net_sig1, dense_net_sig1)
        # Per-pair similarity score: dot product of the two embeddings.
        # (margin is not used here)
        dot_inp_tar = torch.sum(torch.mul(anc_fc_out, pn_fc_out), dim=1).reshape(-1, 1)
        # Sigmoid squashes the scores into (0, 1) so BCELoss can consume them.
        sig_probs = nn.Sigmoid()(dot_inp_tar)
        optimizer_sig.zero_grad()
        loss = criterion(sig_probs, torch.Tensor(labels).view(sig_probs.shape[0], 1))
        loss.backward()
        optimizer_sig.step()
        losses_sig.append(loss.item())
    # Report the last batch loss of the epoch. NOTE(review): the pasted
    # source's indentation was flattened — this print may originally have
    # been inside the batch loop; confirm against the author's copy.
    if e % print_every == 0:
        print(loss.item())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement