Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- from sklearn.naive_bayes import MultinomialNB
- import matplotlib.pyplot as plt
def value_to_bins(train_images, num_bins):
    """Return, for every value, the index of the interval it falls into.

    Thin wrapper around ``np.digitize``. Despite its name, ``num_bins`` is
    forwarded as the array of bin edges, not as a bin count.

    Args:
        train_images: Array-like of values to discretize (any shape).
        num_bins: 1-D monotonic array of bin edges.

    Returns:
        Integer array with the same shape as ``train_images`` holding the
        bin index of each value, as numbered by ``np.digitize``.
    """
    bin_indices = np.digitize(x=train_images, bins=num_bins)
    return bin_indices
def naive_bayes_score(training_data, training_labels, testing_data, testing_labels):
    """Train a multinomial naive Bayes classifier and return its test score.

    Args:
        training_data: Feature matrix used to fit the model.
        training_labels: Target labels for ``training_data``.
        testing_data: Feature matrix the fitted model is evaluated on.
        testing_labels: Ground-truth labels for ``testing_data``.

    Returns:
        The value of ``MultinomialNB.score`` on the test set (mean accuracy).
    """
    naive_bayes_model = MultinomialNB()
    naive_bayes_model.fit(training_data, training_labels)
    # score() performs the prediction itself, so a separate predict() call
    # whose result is thrown away would only repeat the work.
    return naive_bayes_model.score(testing_data, testing_labels)
def naive_bayes_predict(training_data, training_labels, testing_data):
    """Fit a multinomial naive Bayes model and return its test predictions.

    Args:
        training_data: Feature matrix used to fit the model.
        training_labels: Target labels for ``training_data``.
        testing_data: Feature matrix to predict labels for.

    Returns:
        The predicted label for every row of ``testing_data``.
    """
    # fit() returns the fitted estimator, so the calls can be chained.
    return MultinomialNB().fit(training_data, training_labels).predict(testing_data)
def print_missclassified_images(y_pred, y_true, testing_data, image_shape=(28, 28)):
    """Display every test image whose predicted label differs from the truth.

    One matplotlib figure is shown per misclassified sample; the figure title
    reports the label the classifier assigned to it.

    Args:
        y_pred: Predicted labels, one per row of ``testing_data``.
        y_true: Ground-truth labels, aligned with ``y_pred``.
        testing_data: 2-D array with one flattened image per row.
        image_shape: Shape each row is reshaped to before plotting.
            Defaults to ``(28, 28)``, the shape previously hard-coded here.
    """
    # enumerate keeps the row index in step with the label pair, replacing
    # the manually incremented counter.
    for row, (predicted, actual) in enumerate(zip(y_pred, y_true)):
        if predicted == actual:
            continue
        image = np.reshape(testing_data[row, :], image_shape)
        plt.title("Aceasta imagine a fost clasificata ca " + str(predicted))
        plt.imshow(image.astype(np.uint8), cmap='gray')
        plt.show()
def confusion_matrix(y_true, y_pred, num_classes=10):
    """Build, print and return the confusion matrix of a classifier.

    Entry ``[i, j]`` counts the samples whose true label is ``i`` and whose
    predicted label is ``j``.

    Args:
        y_true: Iterable of ground-truth labels.
        y_pred: Iterable of predicted labels, aligned with ``y_true``.
        num_classes: Number of distinct classes; the matrix is
            ``num_classes x num_classes``. Defaults to 10, the size that was
            previously hard-coded.

    Returns:
        The float-valued confusion matrix (it is also printed to stdout, as
        before).
    """
    matrix = np.zeros((num_classes, num_classes))
    for predicted, actual in zip(y_pred, y_true):
        # int() accepts NumPy scalars as well as plain Python numbers, whereas
        # calling .astype on an element fails for ordinary ints and lists.
        matrix[int(actual), int(predicted)] += 1
    print(matrix)
    return matrix
# Load the training/test images and their labels from text files.
train_images = np.loadtxt("data/train_images.txt")
test_images = np.loadtxt("data/test_images.txt")
train_labels = np.loadtxt("data/train_labels.txt")
test_labels = np.loadtxt("data/test_labels.txt")

# Five evenly spaced bin edges spanning [0, 255].
num_bins = np.linspace(start=0, stop=255, num=5)

# Shift the bin indices down by one so they start at 0.
digitized_train_images = value_to_bins(train_images, num_bins) - 1
digitized_test_images = value_to_bins(test_images, num_bins) - 1

# Report the score of the classifier on the test set.
accuracy = naive_bayes_score(
    digitized_train_images,
    train_labels,
    digitized_test_images,
    test_labels,
)
print(accuracy)

# Predict the test labels and summarize the errors as a confusion matrix.
y_pred = naive_bayes_predict(
    digitized_train_images,
    train_labels,
    digitized_test_images,
)
# Optional: show each misclassified image in its own figure.
# print_missclassified_images(y_pred, test_labels, test_images)
confusion_matrix(test_labels, y_pred)
Advertisement
Add Comment
Please, Sign In to add comment