Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from __future__ import absolute_import, division, print_function
- import argparse
- import os
- from random import randint
- import matplotlib.pyplot as plt
- import tensorflow as tf
- from tensorflow import keras
# Fine-grained CIFAR-100 class names, listed alphabetically (100 entries).
# plot_random_picture() looks names up by integer label, so the position of
# each name in this list is significant and must not be reordered.
# NOTE(review): presumably matches the label ids returned by
# cifar100.load_data(label_mode='fine') -- confirm against the dataset docs.
CIFAR100_LABELS_LIST = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
    'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
    'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
    'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]
# Coarse ("superclass") CIFAR-100 category names (20 entries).
# Not referenced by any function visible in this file.
# NOTE(review): presumably indexed by the label ids returned by
# cifar100.load_data(label_mode='coarse') -- confirm before relying on order.
CIFAR100_SUPERCLASS_LABELS_LIST = [
    'aquatic_mammals', 'fish', 'flowers', 'food_containers',
    'fruit_and_vegetables', 'household_electrical_devices',
    'household_furniture', 'insects', 'large_carnivores',
    'large_man-made_outdoor_things', 'large_natural_outdoor_scenes',
    'large_omnivores_and_herbivores', 'medium_mammals',
    'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals',
    'trees', 'vehicles_1', 'vehicles_2'
]
def file_type(file_path):
    """Validate a command-line argument as a path to an existing file.

    Written to be passed as the ``type=`` callable of
    ``ArgumentParser.add_argument``.

    Args:
        file_path: Path string supplied on the command line.

    Returns:
        ``file_path`` unchanged when it names an existing regular file.

    Raises:
        argparse.ArgumentTypeError: If nothing exists at the path, or the
            path is not a regular file (for example, a directory).
    """
    # isfile() instead of exists(): a directory "exists" but cannot be
    # opened as the input image this argument is meant to name.
    if not os.path.isfile(file_path):
        raise argparse.ArgumentTypeError('File does not exist')
    return file_path
def main():
    """Parse the command line and load the classifier, training it if needed.

    The ``-f`` option is required and must name an existing file (checked by
    ``file_type``). When nothing is saved at ``'model'``, ``train_model()``
    is run first so that the subsequent load succeeds.
    """
    # 'Process image' is the description; passed positionally it would
    # become ``prog`` and replace the script name in usage/help output.
    parser = argparse.ArgumentParser(description='Process image')
    parser.add_argument('-f', dest="filename", required=True, type=file_type,
                        help="input image file path")
    file_path = parser.parse_args().filename

    # load_model() raises on a missing path instead of returning a falsy
    # value, so the "train when absent" decision has to be made before
    # loading rather than by testing the result.
    if not os.path.exists('model'):
        train_model()
    loaded_model = tf.keras.models.load_model('model')

    # TODO: classify the image at ``file_path`` with ``loaded_model`` and
    # report the matching CIFAR100_LABELS_LIST entry. An earlier
    # commented-out draft predicted on random test images, but it relied on
    # names (np, new_model, test_images) that are not defined in this
    # function, so it has been dropped rather than kept as dead code.
def plot_random_picture(x_images, y_labels):
    """Show one randomly chosen image, titled with its CIFAR-100 class name.

    Args:
        x_images: Sized, indexable collection of images; the selected
            element is handed to ``plt.imshow``.
        y_labels: Labels parallel to ``x_images``. The selected entry is
            converted with ``int()`` and used as an index into
            ``CIFAR100_LABELS_LIST``.

    Raises:
        ValueError: If ``x_images`` is empty.
    """
    image_count = len(x_images)
    if image_count == 0:
        raise ValueError('x_images must contain at least one image')

    # Sample over the whole collection. randint() includes both endpoints,
    # hence the ``- 1``; a fixed upper bound would overrun short inputs and
    # never reach the tail of long ones.
    random_picture_index = randint(0, image_count - 1)

    plt.figure()
    plt.imshow(x_images[random_picture_index])
    plt.colorbar()
    plt.grid(False)
    plt.suptitle(CIFAR100_LABELS_LIST[int(y_labels[random_picture_index])])
    plt.show()
def train_model(epochs=100, batch_size=100, model_path='model'):
    """Train a dense classifier on CIFAR-100 fine labels and save it.

    The network flattens each 32x32x3 image and passes it through two ReLU
    dense layers (256 and 128 units) to a 100-way softmax output.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Number of samples per gradient update.
        model_path: Destination passed to ``model.save``.

    Returns:
        The trained ``keras.Sequential`` model (also saved to
        ``model_path``).
    """
    (train_images, train_labels), (test_images, test_labels) = (
        tf.keras.datasets.cifar100.load_data(label_mode='fine'))

    # Scale pixel intensities by 1/255 before feeding the network.
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(32, 32, 3)),
        keras.layers.Dense(256, activation=tf.nn.relu),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(100, activation=tf.nn.softmax)
    ])
    # Sparse loss: labels are integer class ids, not one-hot vectors.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # NOTE: the test split doubles as validation data here, so the reported
    # validation metrics are not an independent held-out estimate.
    model.fit(train_images, train_labels, epochs=epochs,
              batch_size=batch_size,
              validation_data=(test_images, test_labels))
    model.save(model_path)
    return model
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement