Advertisement
Guest User

Untitled

a guest
Apr 23rd, 2019
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.76 KB | None | 0 0
  1. from __future__ import absolute_import, division, print_function
  2.  
  3. import argparse
  4. import os
  5. from random import randint
  6.  
  7. import matplotlib.pyplot as plt
  8. import tensorflow as tf
  9. from tensorflow import keras
  10.  
  11. CIFAR100_LABELS_LIST = [
  12.     'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
  13.     'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
  14.     'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
  15.     'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
  16.     'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
  17.     'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
  18.     'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
  19.     'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
  20.     'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
  21.     'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
  22.     'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
  23.     'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
  24.     'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
  25.     'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
  26. ]
  27.  
  28. CIFAR100_SUPERCLASS_LABELS_LIST = [
  29.     'aquatic_mammals', 'fish', 'flowers', 'food_containers',
  30.     'fruit_and_vegetables', 'household_electrical_devices',
  31.     'household_furniture', 'insects', 'large_carnivores',
  32.     'large_man-made_outdoor_things', 'large_natural_outdoor_scenes',
  33.     'large_omnivores_and_herbivores', 'medium_mammals',
  34.     'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals',
  35.     'trees', 'vehicles_1', 'vehicles_2'
  36. ]
  37.  
  38.  
  39. def file_type(file_path):
  40.     if not os.path.exists(file_path):
  41.         raise argparse.ArgumentTypeError('File does not exist')
  42.     return file_path
  43.  
  44.  
  45. def main():
  46.     parser = argparse.ArgumentParser('Process image')
  47.     parser.add_argument('-f', dest="filename", required=True, type=file_type,
  48.                         help="input image file path")
  49.  
  50.     file_path = parser.parse_args().filename
  51.  
  52.     loaded_model = tf.keras.models.load_model('model')
  53.     if not loaded_model:
  54.         train_model()
  55.  
  56.  
  57.     # for i in range(10):
  58.     #     traf = randint(1, 9999)
  59.     #     predictions = new_model.predict(test_images)
  60.     #     print("\n\nnumer z tablicy:   " + str(np.argmax(predictions[traf])))
  61.     #     print("kategoria:   " + str(CIFAR100_LABELS_LIST[np.argmax(predictions[traf])]))
  62.     #     print("\n\nobrazek:")
  63.     #     plt.imshow(test_images[traf])
  64.     #     plt.show()
  65.     #     print("-" * 49)
  66.  
  67.  
  68. def plot_random_picture(x_images, y_labels):
  69.     random_picture_index = randint(0, 1000)
  70.     plt.figure()
  71.     plt.imshow(x_images[random_picture_index])
  72.     plt.colorbar()
  73.     plt.grid(False)
  74.     plt.suptitle(CIFAR100_LABELS_LIST[int(y_labels[random_picture_index])])
  75.     plt.show()
  76.  
  77.  
  78. def train_model():
  79.     (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data(label_mode='fine')
  80.     train_images = train_images / 255.0
  81.     test_images = test_images / 255.0
  82.  
  83.     model = keras.Sequential([
  84.         keras.layers.Flatten(input_shape=(32, 32, 3)),
  85.         keras.layers.Dense(256, activation=tf.nn.relu),
  86.         keras.layers.Dense(128, activation=tf.nn.relu),
  87.         keras.layers.Dense(100, activation=tf.nn.softmax)
  88.  
  89.     ])
  90.  
  91.     model.compile(optimizer='adam',
  92.                   loss='sparse_categorical_crossentropy',
  93.                   metrics=['accuracy'])
  94.  
  95.     model.fit(train_images, train_labels, epochs=100, batch_size=100, validation_data=(test_images, test_labels))
  96.     model.save('model')
  97.  
  98.  
  99. if __name__ == '__main__':
  100.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement