Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _load_or_extract_features(img_clf, raw_images, cache_path, kind):
    """Return extracted features for *raw_images*, backed by a pickle cache.

    If *cache_path* exists and unpickles cleanly, its contents are returned
    unchanged. Otherwise the features are computed with
    ``img_clf.extract_image_features(raw_images)``, pickled to *cache_path*
    for the next run, and returned.

    Args:
        img_clf: object providing ``extract_image_features``.
        raw_images: images as returned by ``img_clf.load_data_from_folder``.
        cache_path: filename of the pickle cache to read / write.
        kind: word used in progress messages ("training" or "testing").

    Returns:
        Whatever ``extract_image_features`` returns (or what was cached).
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. That is acceptable only because this cache is written locally by
    # this same script -- never point cache_path at untrusted data.
    try:
        with open(cache_path, 'rb') as cache_file:
            print(f"Loading cached {kind} image features...")
            return pickle.load(cache_file)
    except FileNotFoundError:
        pass  # No cache yet: fall through and build it.
    except (EOFError, pickle.UnpicklingError) as err:
        # A run interrupted mid-dump leaves a truncated file; rebuild it
        # rather than crashing on every subsequent run.
        print(f"Cached {kind} image features are unreadable ({err}); re-extracting...")

    print(f"Extracting {kind} image features...")
    features = img_clf.extract_image_features(raw_images)
    with open(cache_path, 'wb') as cache_file:
        print(f"Caching {kind} image features...")
        pickle.dump(features, cache_file)
    return features


def _print_results(title, true_labels, predicted_labels):
    """Print confusion matrix, accuracy and micro-averaged F1 under *title*."""
    print(f"\n{title} results")
    print("=============================")
    print("Confusion Matrix:\n", metrics.confusion_matrix(true_labels, predicted_labels))
    print("Accuracy: ", metrics.accuracy_score(true_labels, predicted_labels))
    # For single-label classification, micro-averaged F1 is numerically equal
    # to accuracy, so this line adds no new information in that case.
    print("F1 score: ", metrics.f1_score(true_labels, predicted_labels, average='micro'))


def main():
    """Train an ImageClassifier on ./train/ and report metrics on train and test.

    Images are loaded from ``./train/`` and ``./test/``; extracted features are
    cached in pickle files so later runs can skip feature extraction. The
    classifier is then trained on the training features and evaluated on both
    the training set and the test set, with results printed to stdout.
    """
    img_clf = ImageClassifier()

    # Feature cache files.
    # If you modify img_clf.extract_image_features(),
    # change the filename or delete the cache files.
    train_cache = 'train_data_v1.dat'
    test_cache = 'test_data_v1.dat'

    # Raw images are loaded even when the feature caches exist, because the
    # label lists come back from the same call and are always needed.
    print("Loading training images...")
    (train_raw, train_labels) = img_clf.load_data_from_folder('./train/')
    print("Loading testing images...")
    (test_raw, test_labels) = img_clf.load_data_from_folder('./test/')

    # Convert images into features (both paths go through the same helper,
    # so each cache is always read from and written to its own variable).
    train_data = _load_or_extract_features(img_clf, train_raw, train_cache, 'training')
    test_data = _load_or_extract_features(img_clf, test_raw, test_cache, 'testing')

    # Train, then evaluate on the training data. This measures fit only,
    # not generalization -- the test-set numbers below are the ones that matter.
    print("Training classifier...")
    img_clf.train_classifier(train_data, train_labels)
    print("Predicting training image labels...")
    predicted_labels = img_clf.predict_labels(train_data)
    _print_results("Training", train_labels, predicted_labels)

    # Evaluate on held-out test data.
    print("Predicting testing image labels...")
    predicted_labels = img_clf.predict_labels(test_data)
    _print_results("Test", test_labels, predicted_labels)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement