Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
# MNIST input pipeline.
#
# Loads MNIST, converts images to float32 of shape (N, 28, 28, 1) scaled to
# [0, 1], splits the provided test set 50/50 (stratified by label) into
# validation and test halves, one-hot encodes all labels and wraps each split
# in a tf.data pipeline.
#
# NOTE(review): relies on `tf`, `train_test_split` (presumably
# sklearn.model_selection) and the integer constant `BATCH_SIZE` being defined
# earlier in the file.

NUM_CLASSES = 10  # MNIST digit classes 0-9
SPLIT_SEED = 42  # fixed seed so the validation/test split is identical across runs
# Evaluation batches are a quarter of the training batch size; max() guards
# against BATCH_SIZE < 4, which would otherwise yield an invalid batch size of 0.
EVAL_BATCH_SIZE = max(1, BATCH_SIZE // 4)


def _preprocess_images(images):
    """Return *images* reshaped to (N, 28, 28, 1) float32 scaled to [0, 1]."""
    return images.reshape(-1, 28, 28, 1).astype('float32') / 255.


(train_features, train_labels), (test_features, test_labels) = tf.keras.datasets.mnist.load_data()
train_features = _preprocess_images(train_features)
test_features = _preprocess_images(test_features)

# Split while the labels are still integers so stratification is directly by
# class; the fixed random_state keeps the held-out sets reproducible.
validation_features, test_features, validation_labels, test_labels = train_test_split(
    test_features,
    test_labels,
    test_size=0.50,
    stratify=test_labels,
    random_state=SPLIT_SEED,
)

# Explicit num_classes guarantees every split gets the same one-hot width,
# independent of which labels happen to be present in it.
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
validation_labels = tf.keras.utils.to_categorical(validation_labels, num_classes=NUM_CLASSES)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)

# Stage order matters: shuffle individual examples, then batch, then prefetch.
# Prefetching last buffers ready-made batches so input preparation overlaps
# with training; prefetching before shuffle/batch only buffers single examples.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_features, train_labels))
    .shuffle(train_features.shape[0])  # buffer spans the whole set -> full shuffle
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

validation_dataset = (
    tf.data.Dataset.from_tensor_slices((validation_features, validation_labels))
    .batch(EVAL_BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_features, test_labels))
    .batch(EVAL_BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement