Advertisement
Guest User

Untitled

a guest
Apr 18th, 2019
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.21 KB | None | 0 0
  1. """
  2. Arrange files for binary classification. The files are shuffled before being copied.
  3. """
  4. import os
  5. import numpy as np
  6. from shutil import copyfile
  7. from os import listdir
  8. import argparse
  9.  
  10. def get_args():
  11. parser = argparse.ArgumentParser('python')
  12.  
  13. parser.add_argument('-src_dir',
  14. default='./',
  15. required=False,
  16. help='directory containing folders of images')
  17.  
  18. parser.add_argument('-dst_dir',
  19. default='./cats_and_dogs/',
  20. required=False,
  21. help='destination directory')
  22.  
  23. parser.add_argument('-target_name',
  24. default='cats',
  25. required=False,
  26. help='directory containing folders of images')
  27.  
  28. parser.add_argument('-target_num',
  29. default='0',
  30. required=False,
  31. help='number of the target class')
  32.  
  33. parser.add_argument('-val_frac',
  34. type=float,
  35. default=0.2,
  36. required=False,
  37. help='the fraction of validation data')
  38. return parser.parse_args()
  39.  
  40. # save the splitted datasets to folders
  41. def save_to_folder(fileCollection, sourceFolder, targetFolder):
  42. m = len(fileCollection)
  43. for i in range(m):
  44. source = sourceFolder + '/' + fileCollection[i]
  45. target = targetFolder + '/' + fileCollection[i]
  46. copyfile(source, target)
  47.  
  48. def split_bc(src_dir, dst_dir, val_frac, target_name='cats', target_num='0'):
  49. src_target = src_dir + target_name
  50.  
  51. # a list of files in target class
  52. list_target = [f for f in listdir(src_target)]
  53. num_target = len(list_target)
  54. # shuffle the collection to create a new one
  55. idx_target = np.arange(num_target)
  56. permu_target = np.random.permutation(idx_target)
  57. shuffled_target = [list_target[i] for i in permu_target]
  58.  
  59. # calculate the length of each new folder of target
  60. len_target_test = int(num_target*val_frac)
  61. print('lenth of test data of {}:{}'.format(target_name, len_target_test))
  62. len_target_train = num_target - len_target_test
  63. print('length of train data of {}:{}'.format(target_name, len_target_train))
  64.  
  65. # create 2 new lists of images of target
  66. target_train_set = shuffled_target[0:len_target_train]
  67. target_test_set = shuffled_target[len_target_train : num_target]
  68.  
  69. # create sub-directory in dst_dir of nucleotide
  70. target_train_dir = dst_dir + 'train/' + target_num + '-' + target_name + '/'
  71. if not os.path.exists(target_train_dir):
  72. os.makedirs(target_train_dir)
  73. target_test_dir = dst_dir + 'val/' + target_num + '-' + target_name + '/'
  74. if not os.path.exists(target_test_dir):
  75. os.makedirs(target_test_dir)
  76.  
  77. # copy files to corresponding new folder
  78. save_to_folder(target_train_set, src_target, target_train_dir)
  79. save_to_folder(target_test_set, src_target, target_test_dir)
  80.  
  81. if __name__ == "__main__":
  82. args = get_args()
  83. src_dir = args.src_dir
  84. dst_dir = args.dst_dir
  85. target_name = args.target_name
  86. target_num = args.target_num
  87. val_frac = args.val_frac
  88.  
  89. split_bc(src_dir, dst_dir, val_frac, target_name, target_num)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement