Advertisement
Guest User

TensorFlow Fashion-MNIST

a guest
Jul 17th, 2020
322
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.25 KB | None | 0 0
  1. #Import
  2. import matplotlib.pyplot as plt
  3. import tensorflow as tf
  4.  
  5. #Verifie si tensorflow est en version 2.0
  6. assert hasattr(tf, 'function')
  7.  
  8. #Charge le dataset
  9. fashion_mnist = tf.keras.datasets.fashion_mnist
  10. (images, targets), (_, _) = fashion_mnist.load_data() # Target correspond à la catégorie de l'image
  11.  
  12. #Associe un nom à la catégorie
  13. targets_names = ["T-shirt", "Pantalons", "Pull", "Dress", "Veste", "Sandale", "Haut", "Chaussure", "Sac", "Bottes"]
  14.  
  15. # On créer le modèle
  16. model = tf.keras.models.Sequential() # Model sequentiel : l'instruction d'après prend en entrée le résultat de l'instruction d'avant
  17. # Applatie l'image : (28,28) -> (1,784)
  18. model.add(tf.keras.layers.Flatten(input_shape=[28,28]))
  19.  
  20. #Layers
  21. model.add(tf.keras.layers.Dense(256, activation="relu"))
  22. model.add(tf.keras.layers.Dense(128, activation="relu"))
  23. model.add(tf.keras.layers.Dense(10, activation="softmax"))
  24.  
  25. model.compile(
  26.     loss="sparse_categorical_crossentropy",
  27.     optimizer="sgd",
  28.     metrics=["accuracy"]
  29. )
  30.  
  31. history = model.fit(images,targets, epochs=10)
  32.  
  33. loss_curve = history.history["loss"]
  34. acc_curve = history.history["accuracy"]
  35.  
  36. plt.plot(loss_curve)
  37. plt.title("Loss")
  38. plt.show()
  39.  
  40. plt.plot(acc_curve)
  41. plt.title("Accuracy")
  42. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement