Advertisement
adwas33

Untitled

May 4th, 2023
16
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.12 KB | None | 0 0
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import tensorflow as tf
  4. from sklearn.model_selection import train_test_split
  5. from tensorflow.keras import datasets, layers, models
  6. from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
  7. from tensorflow.keras.metrics import Precision
  8. from tensorflow.keras.utils import to_categorical
  9. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  10.  
  11.  
  12. # Wczytanie danych
  13. (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
  14.  
  15. # Normalizacja wartości pikseli
  16. train_images, test_images = train_images / 255.0, test_images / 255.0
  17.  
  18. # Podział danych na zbiory treningowy i walidacyjny
  19. train_images, val_images, train_labels, val_labels = train_test_split(train_images, train_labels, test_size=0.2, random_state=42)
  20.  
  21. # Konwersja etykiet na one-hot encoding
  22. num_classes = 10
  23. train_labels_categorical = to_categorical(train_labels, num_classes=num_classes)
  24. val_labels_categorical = to_categorical(val_labels, num_classes=num_classes)
  25.  
  26. data_generator = ImageDataGenerator(
  27. horizontal_flip=True,
  28. rotation_range=15,
  29. width_shift_range=0.1,
  30. height_shift_range=0.1,
  31. zoom_range=0.1
  32. )
  33.  
  34. # Budowanie modelu
  35. model = models.Sequential()
  36. model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
  37. model.add(layers.MaxPooling2D((2, 2)))
  38. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  39. model.add(layers.MaxPooling2D((2, 2)))
  40. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  41. model.add(layers.Dropout(0.5))
  42. model.add(Flatten())
  43. model.add(Dense(64, activation='relu'))
  44. model.add(Dense(10))
  45.  
  46. # Kompilacja modelu
  47. model.compile(optimizer='adam',
  48. loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
  49. metrics=['accuracy'])
  50.  
  51. # Inicjalizacja list do przechowywania metryk
  52. train_accuracy = []
  53. val_accuracy = []
  54. train_precision = []
  55. val_precision = []
  56.  
  57. # Uczenie modelu i obliczanie metryk dla każdej epoki
  58. epochs = 30
  59. for epoch in range(epochs):
  60. print(f"Epoch {epoch + 1}/{epochs}")
  61.  
  62. # Uczenie modelu na danych treningowych
  63. history = model.fit(
  64. data_generator.flow(train_images, train_labels_categorical, batch_size=32),
  65. validation_data=(val_images, val_labels_categorical),
  66. steps_per_epoch=len(train_images) // 32,
  67. epochs=1,
  68. verbose=1
  69. )
  70.  
  71.  
  72. # Dodawanie metryk do list
  73. train_accuracy.append(history.history['accuracy'][0])
  74. val_accuracy.append(history.history['val_accuracy'][0])
  75.  
  76. # Obliczanie precyzji dla danych treningowych i walidacyjnych
  77. train_preds = model.predict(train_images)
  78. train_preds = np.argmax(train_preds, axis=1)
  79. train_prec = Precision()
  80. train_prec.update_state(train_labels, train_preds)
  81. train_precision.append(train_prec.result().numpy())
  82.  
  83. val_preds = model.predict(val_images)
  84. val_preds = np.argmax(val_preds, axis=1)
  85. val_prec = Precision()
  86. val_prec.update_state(val_labels, val_preds)
  87. val_precision.append(val_prec.result().numpy())
  88.  
  89. print(f"Precision (train): {train_precision[-1]:.4f}")
  90. print(f"Precision (validation): {val_precision[-1]:.4f}")
  91.  
  92.  
  93. # Rysowanie wykresów
  94. fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
  95.  
  96. # Wykres zmiany# dokładności ('accuracy') i dokładności walidacyjnej ('val_accuracy') w czasie
  97. ax1.plot(train_accuracy, label='accuracy')
  98. ax1.plot(val_accuracy, label='val_accuracy')
  99. ax1.set_xlabel('Epoch')
  100. ax1.set_ylabel('Accuracy')
  101. ax1.set_ylim([0.5, 1])
  102. ax1.legend(loc='lower right')
  103. ax1.set_title('Accuracy and Validation Accuracy')
  104.  
  105. # Wykres zmiany precyzji ('precision') i precyzji walidacyjnej ('val_precision') w czasie
  106. ax2.plot(train_precision, label='precision')
  107. ax2.plot(val_precision, label='val_precision')
  108. ax2.set_xlabel('Epoch')
  109. ax2.set_ylabel('Precision')
  110. ax2.set_ylim([0.5, 1])
  111. ax2.legend(loc='lower right')
  112. ax2.set_title('Precision and Validation Precision')
  113.  
  114. plt.show()
  115.  
  116. # Ewaluacja modelu na danych testowych
  117. test_loss, test_acc = model.evaluate(test_images, to_categorical(test_labels, num_classes=num_classes), verbose=2)
  118. print(f"Test accuracy: {test_acc}")
  119.  
  120.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement