Advertisement
Guest User

Untitled

a guest
Mar 26th, 2019
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.68 KB | None | 0 0
  1. import random
  2. import torch
  3. import pathlib
  4. import os
  5. from torch.utils import data
  6. import PIL
  7. import PIL.Image
  8.  
  9. class SiameseDataset(data.Dataset):
  10. def __init__(self, root, ext, transform=None, pair_transform=None, target_transform=None):
  11. super(SiameseDataset, self).__init__()
  12. self.transform = transform
  13. self.pair_transform = pair_transform
  14. self.target_transform = target_transform
  15. self.root = root
  16.  
  17. self.base_path = pathlib.Path(root)
  18. self.files = sorted(list(path.glob("*/*."+ext)))
  19. self.files_map = self._files_mapping()
  20. self.pair_files = self._pair_files()
  21.  
  22. def __len__(self):
  23. return len(self.pair_files)
  24.  
  25. def __getitem__(self, idx):
  26. (imp1, imp2), sim = self.pair_files[idx]
  27. im1 = PIL.Image.open(imp1)
  28. im2 = PIL.Image.open(imp2)
  29.  
  30. if self.transform:
  31. im1 = self.transform(im1)
  32. im2 = self.transform(im2)
  33.  
  34. if self.pair_transform:
  35. im1,im2 = self.transform_pair(im1,im2)
  36.  
  37. if self.target_transform:
  38. sim = self.target_transform(sim)
  39. return im1, im2, sim
  40.  
  41.  
  42. def _files_mapping(self):
  43. dirname = []
  44. filename = []
  45. dct = {}
  46. for f in self.files:
  47. spl = str(f).split('/')
  48. dirname = spl[-2]
  49. filename = spl[-1]
  50.  
  51. if dirname not in dct.keys():
  52. dct.update({dirname:[]})
  53. else:
  54. dct[dirname].append(filename)
  55. dct[dirname]=sorted(dct[dirname])
  56. return dct
  57.  
  58.  
  59. def _similar_pair(self):
  60. fmap = self.files_map
  61. atp = {}
  62. c = 0
  63. for key in fmap.keys():
  64. atp.update({key:[]})
  65. n = len(fmap[key])
  66. ctp = ((n-1)*n)+n
  67. for i in range(n):
  68. for j in range(n):
  69. fp = os.path.join(key, fmap[key][i])
  70. fo = os.path.join(key, fmap[key][j])
  71. atp[key].append(((fp,fo),0))
  72. return atp
  73.  
  74.  
  75. def _len_similar_pair(self):
  76. fmap = self.files_map
  77. dct = {}
  78. spair = self._similar_pair()
  79. for key in fmap.keys():
  80. dd = {key:len(spair[key])}
  81. dct.update(dd)
  82. return dct
  83.  
  84.  
  85. def _diff_pair_dircomp(self):
  86. fmap = self.files_map
  87. dirname = list(fmap.keys())
  88. pair_dircomp=[]
  89. for idx in range(len(dirname)):
  90. dirtmp = dirname.copy()
  91. dirtmp.pop(idx)
  92. odir = dirtmp
  93. pdir = dirname[idx]
  94. pdc = (pdir, odir)
  95. pair_dircomp.append(pdc)
  96. return pair_dircomp
  97.  
  98.  
  99. def _different_pair(self):
  100. fmap = self.files_map
  101. pair_sampled = {}
  102. pair_dircomp = self._diff_pair_dircomp()
  103. len_spair = self._len_similar_pair()
  104. for idx, (kp,kvo) in enumerate(pair_dircomp):
  105. val_pri = fmap[kp]
  106. num_sample = len(val_pri)//4
  107.  
  108. pair_sampled.update({kp:[]})
  109. for vp in val_pri:
  110. #get filename file primary
  111. fp = os.path.join(kp,vp)
  112. for ko in kvo:
  113. vov = fmap[ko]
  114. pair=[]
  115. for vo in vov:
  116. fo = os.path.join(ko,vo)
  117. pair.append(((fp, fo),1))
  118. mout = random.sample(pair,num_sample)
  119. pair_sampled[kp].append(mout)
  120.  
  121. for key in pair_sampled.keys():
  122. val = pair_sampled[key]
  123. num_sample =len_spair[key]
  124. tmp_val = []
  125. for va in val:
  126. for v in va:
  127. tmp_val.append(v)
  128. pair_sampled[key] = random.sample(tmp_val,num_sample)
  129.  
  130. return pair_sampled
  131.  
  132.  
  133. def _pair_files(self):
  134. fmap = self.files_map
  135. base_path = self.root
  136. sim_pair = self._similar_pair()
  137. diff_pair = self._different_pair()
  138. files_list = []
  139. for key in fmap.keys():
  140. spair = sim_pair[key]
  141. dpair = diff_pair[key]
  142. n = len(spair)
  143. for i in range(n):
  144. spair_p = os.path.join(base_path,spair[i][0][0])
  145. spair_o = os.path.join(base_path,spair[i][0][1])
  146. spair[i] = ((spair_p, spair_o), 0)
  147.  
  148. dpair_p = os.path.join(base_path, dpair[i][0][0])
  149. dpair_o = os.path.join(base_path, dpair[i][0][1])
  150. dpair[i] = ((dpair_p, dpair_o), 1)
  151.  
  152. files_list.append(spair[i])
  153. files_list.append(dpair[i])
  154.  
  155. return files_list
  156.  
  157.  
  158.  
  159. root='/data/att_faces'
  160. sd = SiameseDataset(root, ext="pgm")
  161. sd.__getitem__(3)[0]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement