Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import matplotlib.pyplot as plt
- import numpy as np
- import tensorflow as tf
- from sklearn.model_selection import train_test_split
- from tensorflow.keras import datasets, layers, models
- from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
- from tensorflow.keras.metrics import Precision
- from tensorflow.keras.utils import to_categorical
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Load CIFAR-10 as (images, labels) pairs for the training and test splits.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Rescale pixel intensities from 0..255 to 0..1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Hold out 20% of the training split for validation (fixed seed for reproducibility).
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images,
    train_labels,
    test_size=0.2,
    random_state=42,
)

# One-hot encode the labels, as required by CategoricalCrossentropy used at compile time.
num_classes = 10
train_labels_categorical, val_labels_categorical = (
    to_categorical(labels, num_classes=num_classes)
    for labels in (train_labels, val_labels)
)
# Random augmentation applied on the fly to the training batches fed to model.fit.
data_generator = ImageDataGenerator(
    rotation_range=15,        # random rotation, up to 15 degrees
    width_shift_range=0.1,    # random horizontal shift, up to 10% of the width
    height_shift_range=0.1,   # random vertical shift, up to 10% of the height
    zoom_range=0.1,           # random zoom in/out, up to 10%
    horizontal_flip=True,     # random left-right mirroring
)
# Build the CNN: three 3x3 convolution stages (the first two followed by 2x2
# max-pooling), dropout, then a small dense classifier head.
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Dropout(0.5),
    Flatten(),
    Dense(64, activation='relu'),
    # No activation: the output layer emits raw logits, one per class.
    Dense(10),
])
# Compile with Adam and categorical cross-entropy computed directly on the
# logits produced by the final Dense layer; report accuracy during training.
logits_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=logits_loss, metrics=['accuracy'])
# Per-epoch metric history, appended to by the training loop and plotted afterwards.
train_accuracy, val_accuracy = [], []
train_precision, val_precision = [], []
def _macro_precision(y_true, y_pred, n_classes):
    """Return the macro-averaged precision of integer class predictions.

    Per class, precision = true positives / number of samples predicted as
    that class; the per-class values are then averaged with equal weight.
    A class that is never predicted contributes 0 (the same convention as
    scikit-learn's ``zero_division=0``).

    Args:
        y_true: integer class labels; flattened before use, so the (N, 1)
            column returned by ``cifar10.load_data`` is accepted.
        y_pred: integer predicted class indices with the same number of
            elements as ``y_true``.
        n_classes: total number of classes.

    Returns:
        The macro precision as a Python float in [0, 1].
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    predicted_per_class = np.bincount(y_pred, minlength=n_classes)
    true_positives = np.bincount(y_pred[y_pred == y_true], minlength=n_classes)
    per_class = np.divide(
        true_positives,
        predicted_per_class,
        out=np.zeros(n_classes, dtype=float),
        where=predicted_per_class > 0,  # avoid 0/0 for never-predicted classes
    )
    return float(per_class.mean())


# Train one epoch per fit() call so precision can be measured after each epoch.
epochs = 30
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    # Train on augmented batches; validation uses the un-augmented hold-out set.
    history = model.fit(
        data_generator.flow(train_images, train_labels_categorical, batch_size=32),
        validation_data=(val_images, val_labels_categorical),
        steps_per_epoch=len(train_images) // 32,
        epochs=1,
        verbose=1,
    )
    # Each history list has exactly one entry because epochs=1.
    train_accuracy.append(history.history['accuracy'][0])
    val_accuracy.append(history.history['val_accuracy'][0])

    # Precision per epoch. tf.keras.metrics.Precision is a *binary* metric
    # (predictions thresholded at 0.5), so feeding it class indices 0..9
    # produced meaningless numbers; compute macro precision over the
    # num_classes classes instead. The model outputs logits, so argmax yields
    # the predicted class. Training precision is measured on the
    # un-augmented training images.
    train_preds = np.argmax(model.predict(train_images, verbose=0), axis=1)
    train_precision.append(_macro_precision(train_labels, train_preds, num_classes))
    val_preds = np.argmax(model.predict(val_images, verbose=0), axis=1)
    val_precision.append(_macro_precision(val_labels, val_preds, num_classes))

    print(f"Precision (train): {train_precision[-1]:.4f}")
    print(f"Precision (validation): {val_precision[-1]:.4f}")
# Plot the per-epoch metric curves side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Left: training vs. validation accuracy over epochs.
ax1.plot(train_accuracy, label='accuracy')
ax1.plot(val_accuracy, label='val_accuracy')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy')
# Use the full [0, 1] range: early-epoch values can be below 0.5 and would be
# clipped out of view by the previous [0.5, 1] limit.
ax1.set_ylim([0, 1])
ax1.legend(loc='lower right')
ax1.set_title('Accuracy and Validation Accuracy')

# Right: training vs. validation precision over epochs.
ax2.plot(train_precision, label='precision')
ax2.plot(val_precision, label='val_precision')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Precision')
# Same full range as the accuracy panel so low early values stay visible.
ax2.set_ylim([0, 1])
ax2.legend(loc='lower right')
ax2.set_title('Precision and Validation Precision')

plt.show()
# Evaluate on the held-out test set; labels are one-hot encoded to match the loss.
test_labels_categorical = to_categorical(test_labels, num_classes=num_classes)
test_loss, test_acc = model.evaluate(test_images, test_labels_categorical, verbose=2)
print(f"Test accuracy: {test_acc}")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement