Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from sklearn.model_selection import train_test_split
- import tensorflow as tf
- import os
def _get_label(file_path):
    """Return the class label of a file: the name of its parent directory.

    For example ``"dataset/cat/img_001.jpg"`` -> ``"cat"``.

    Arguments:
        file_path: The file path as ``str`` or ``bytes``. An eager string
            ``tf.Tensor`` is also accepted, for compatibility with the
            previous TF-based implementation.

    Returns:
        The parent directory name as ``str``.

    Raises:
        ValueError: If ``file_path`` has no parent directory component, so
            no label can be inferred.
    """
    # Keep accepting eager tensors, which the original tf.strings version handled.
    if hasattr(file_path, "numpy"):
        file_path = file_path.numpy()
    # Plain string handling: no per-file TF op dispatch, and os.path understands
    # both separators on Windows (unlike a split on os.path.sep alone).
    path = os.fsdecode(file_path)
    parent = os.path.dirname(path)
    if not parent:
        raise ValueError(
            f"cannot infer a class label from {path!r}: it has no parent directory"
        )
    return os.path.basename(parent)
def _collect_labelled_files(pattern):
    """Return ``(paths, labels)`` for every non-directory entry matching ``pattern``."""
    paths = []
    labels = []
    for match in tf.io.matching_files(pattern):
        path = match.numpy().decode('utf8')
        # The glob can also match directories; only regular entries are samples.
        if tf.io.gfile.isdir(path):
            continue
        paths.append(path)
        labels.append(_get_label(path))
    return paths, labels


def split(pattern, test_size=0.25, val_size=None, random_state=None, shuffle=True, stratify=False):
    """Split all files matching glob pattern(s) into train/test(/validation) datasets.

    Each sample is a ``(path, label)`` pair, where the label is the name of
    the file's parent directory. Directories matched by the pattern are
    skipped.

    Arguments:
        pattern: A string, a list of strings, or a tf.Tensor of string type
            (scalar or vector) holding the filename glob (shell wildcard)
            pattern(s) to match.
        test_size: If float, between 0.0 and 1.0: the proportion of the whole
            dataset placed in the test split. If int, the absolute number of
            test samples.
        val_size: If None (default), no validation split is made. If float,
            between 0.0 and 1.0: the proportion of the remaining train data
            placed in the validation split. If int, the absolute number of
            validation samples.
        random_state: Passed to both ``train_test_split`` calls. If int, the
            seed of the random number generator; if a RandomState instance,
            the generator itself; if None, numpy's global RandomState.
        shuffle: Whether to shuffle the data before splitting. If False,
            ``stratify`` must also be False.
        stratify: If True, preserve the per-label proportions in every split.

    Returns:
        ``(train_ds, test_ds)`` when ``val_size`` is None, otherwise
        ``(train_ds, test_ds, val_ds)``. Each is a ``tf.data.Dataset`` of
        ``(path, label)`` string pairs.

    Raises:
        ValueError: If no file matches ``pattern``, or if ``train_test_split``
            rejects the requested sizes/stratification.
    """
    X, y = _collect_labelled_files(pattern)
    if not X:
        # Fail with a clear message instead of sklearn's "n_samples=0" error.
        raise ValueError(f"no files match pattern {pattern!r}")

    # Train / test split over the whole dataset.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        shuffle=shuffle,
        stratify=y if stratify else None,
    )
    test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test))

    if val_size is None:
        return (
            tf.data.Dataset.from_tensor_slices((X_train, y_train)),
            test_ds,
        )

    # Carve the validation split out of the train portion only.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train,
        test_size=val_size,
        random_state=random_state,
        shuffle=shuffle,
        stratify=y_train if stratify else None,
    )
    return (
        tf.data.Dataset.from_tensor_slices((X_train, y_train)),
        test_ds,
        tf.data.Dataset.from_tensor_slices((X_val, y_val)),
    )
def main():
    """Split ``dataset/*/*`` and print the per-label sample counts of each split.

    Prints the train, test and validation label counts (as dicts), each
    surrounded by lines containing a single space.
    """
    from collections import Counter

    train_ds, test_ds, val_ds = split('dataset/*/*', val_size=0.1, stratify=True)

    def _count_labels(dataset):
        # Each element is a (path, label) pair; count the decoded labels in
        # first-seen order (dict() keeps that order and prints as a plain dict).
        return dict(Counter(label.numpy().decode('utf8') for _, label in dataset))

    counts = [_count_labels(ds) for ds in (train_ds, test_ds, val_ds)]

    print(' ')
    for count in counts:
        print(f'{count}')
        print(' ')


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement