Advertisement
Guest User

Untitled

a guest
Dec 9th, 2019
219
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.83 KB | None | 0 0
  1. from sklearn.model_selection import train_test_split
  2. import tensorflow as tf
  3. import os
  4.  
  5. def _get_label(file_path):
  6.  
  7. # Convert the path to a list of path components
  8. parts = tf.strings.split(file_path, os.path.sep)
  9.  
  10. # The second to last is the class-directory
  11. return parts[-2].numpy().decode('utf8')
  12.  
  13. def split(pattern, test_size=0.25, val_size=None, random_state=None, shuffle=True, stratify=False):
  14. """Split a dataset of all files matching one or more glob patterns
  15.  
  16. Arguments:
  17.  
  18. pattern: A string, a list of strings, or a tf.Tensor of string
  19. type (scalar or vector), representing the filename glob (i.e. shell wildcard)
  20. pattern(s) that will be matched.
  21.  
  22. test_size: If float, should be between 0.0 and 1.0 and represent the proportion of the
  23. dataset to include in the test split. If int, represents the absolute number
  24. of test samples.
  25.  
  26. val_size: If float, should be between 0.0 and 1.0 and represent the proportion of the
  27. train dataset to include in the test split. If int, represents the absolute
  28. number of test samples.
  29.  
  30. random_state: If int, random_state is the seed used by the random number generator;
  31. If RandomState instance, random_state is the random number generator;
  32. If None, the random number generator is the RandomState instance used by np.random.
  33.  
  34. shuffle: Whether or not to shuffle the data before splitting.
  35. If shuffle=False then stratify must be None.
  36. """
  37.  
  38. # Search the files
  39. matching = tf.io.matching_files(pattern)
  40.  
  41. # X and y to split
  42. X = []
  43. y = []
  44.  
  45. for m in matching:
  46.  
  47. # Filter out directories
  48. if not tf.io.gfile.isdir(m.numpy()):
  49.  
  50. xValue = m.numpy().decode('utf8')
  51.  
  52. X.append(xValue)
  53. y.append(_get_label(xValue))
  54.  
  55. # Verify if stratify X_test
  56. _stratify = y if stratify else None
  57.  
  58. # Generate train and test datasets
  59. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state, shuffle=shuffle, stratify=_stratify)
  60.  
  61. # Verify if stratify X_val
  62. _stratify = y_train if stratify else None
  63.  
  64. # Generate validation dataset
  65. if val_size is not None:
  66.  
  67. X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=val_size, random_state=random_state, shuffle=shuffle, stratify=_stratify)
  68.  
  69. # Returns Train / Test / Val datasets
  70. # return X_train, y_train, X_test, y_test, X_val, y_val
  71. return (
  72. tf.data.Dataset.from_tensor_slices((X_train, y_train)),
  73. tf.data.Dataset.from_tensor_slices((X_test, y_test)),
  74. tf.data.Dataset.from_tensor_slices((X_val, y_val))
  75. )
  76.  
  77. # Returns Train / Test datasets
  78. # return (X_train, y_train, X_test, y_test)
  79. return (
  80. tf.data.Dataset.from_tensor_slices((X_train, y_train)),
  81. tf.data.Dataset.from_tensor_slices((X_test, y_test))
  82. )
  83.  
  84. def main():
  85.  
  86. (trainDS, testDS, valDS) = split('dataset/*/*', val_size=0.1, stratify=True)
  87.  
  88. train = {}
  89. test = {}
  90. val = {}
  91.  
  92. for d in trainDS:
  93.  
  94. value = d[1].numpy().decode('utf8')
  95.  
  96. train[value] = 1 if train.get(value) is None else train[value] + 1
  97.  
  98.  
  99. for d in testDS:
  100.  
  101. value = d[1].numpy().decode('utf8')
  102.  
  103. test[value] = 1 if test.get(value) is None else test[value] + 1
  104.  
  105.  
  106. for d in valDS:
  107.  
  108. value = d[1].numpy().decode('utf8')
  109.  
  110. val[value] = 1 if val.get(value) is None else val[value] + 1
  111.  
  112.  
  113. print(' ')
  114. print(f'{train}')
  115. print(' ')
  116. print(f'{test}')
  117. print(' ')
  118. print(f'{val}')
  119. print(' ')
  120.  
  121.  
  122. if __name__ == "__main__":
  123. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement