Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def calculate_classification_metrics(test_targets, predictions, *, verbose: bool = True) -> dict:
    """Compute accuracy, F1 (micro and macro), and weighted precision/recall.

    Keys of the returned dict:

    * ``'accuracy'``  -- fraction of predictions that match the targets.
    * ``'f1-micro'``  -- F1 pooled over all samples. Per the scikit-learn docs,
      for single-label data this equals accuracy, so it mostly reflects the
      majority classes.
    * ``'f1-macro'``  -- unweighted mean of per-class F1. Every class counts
      equally, so this is the more telling score for unbalanced classes.
    * ``'precision'`` -- per-class precision averaged with support weights.
    * ``'recall'``    -- per-class recall averaged with support weights.

    Args:
        test_targets: Ground-truth labels.
        predictions: Predicted labels, aligned element-wise with ``test_targets``.
        verbose: If True (default, the original behaviour), print each metric
            to stdout with three decimals.

    Returns:
        Dict mapping the metric names above to their scores.
    """
    results = {
        'accuracy': accuracy_score(test_targets, predictions),
        # Micro-F1 pools every sample, so majority classes dominate it;
        # macro-F1 below weights each class equally.
        'f1-micro': f1_score(test_targets, predictions, average="micro"),
        'f1-macro': f1_score(test_targets, predictions, average="macro"),
        # Precision/recall are averaged across classes weighted by support.
        'precision': precision_score(test_targets, predictions, average="weighted"),
        'recall': recall_score(test_targets, predictions, average="weighted"),
    }
    if verbose:
        # Human-readable label for each result key, in print order.
        display_names = {
            'accuracy': 'Accuracy score',
            'f1-micro': 'F1-micro',
            'f1-macro': 'F1-macro',
            'precision': 'Precision score',
            'recall': 'Recall score',
        }
        for key, name in display_names.items():
            print(f"\t{name}: {results[key]:.3f}")
    return results
def feature_importances(model, feature_list: list, importance_limit: float = 0.05, display_top_n: int = 20,
                        display_graph=True):
    """Rank a model's features by importance and select the strongest ones.

    Pairs ``feature_list`` with ``model.feature_importances_`` and sorts them from
    most to least important. It then walks the top ``display_top_n`` entries,
    printing and keeping each one until an importance falls below
    ``importance_limit``. The kept list is printed as well. Optionally it plots
    the top ``display_top_n`` importances.

    Args:
        model: Any object exposing a ``feature_importances_`` sequence with one
            value per feature (scikit-learn tree-based estimators provide one).
        feature_list: Feature names, in the same order as ``feature_importances_``.
        importance_limit: Minimum raw (unrounded) importance for a feature to be kept.
        display_top_n: Maximum number of features to print, keep and plot.
        display_graph: If True, show a line plot of the top importances.

    Returns:
        Names of the kept features, most important first.
    """
    # (name, raw importance) pairs, most important first. zip stops at the
    # shorter of the two inputs.
    ranked = sorted(zip(feature_list, model.feature_importances_),
                    key=lambda pair: pair[1], reverse=True)

    features_to_keep = []
    for rank, (feature, importance) in enumerate(ranked[:display_top_n], start=1):
        # Compare the raw value: rounding first would let e.g. 0.046 pass a 0.05 limit.
        if importance < importance_limit:
            break  # sorted descending, so every remaining feature is below the limit too
        features_to_keep.append(feature)
        # Round only for display.
        print(f'{rank:3}. Feature: {feature:20} Importance: {round(importance, 2)}')
    print(features_to_keep)

    if display_graph:
        top = ranked[:display_top_n]
        features = [feature for feature, _ in top]
        importances = [importance for _, importance in top]
        sns.set(style='darkgrid')
        plt.plot(features, importances)
        plt.xticks(rotation=90)
        plt.show()
    return features_to_keep
def visualize_dict(D, sort=True):
    """Show a bar chart of a dict's values, labelling each bar with its value.

    Args:
        D: Mapping from category label (used as the bar name) to a numeric value.
        sort: If True, order bars by ascending value; otherwise keep the dict's
            insertion order.
    """
    if sort:
        D = dict(sorted(D.items(), key=lambda item: item[1]))

    # One row with one column per key: wide-form data yields one bar per column.
    # (The former ``index=list(D.keys())`` replicated every value into an N x N frame.)
    df = pd.DataFrame([D])

    ax = plt.gca()
    # Count bars already on the axes so labels attach only to the bars drawn here.
    n_existing = len(ax.patches)
    ax = sns.barplot(data=df, ax=ax)

    new_bars = list(ax.patches)[n_existing:]
    for bar, label in zip(new_bars, D.values()):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2,
                height,
                label,
                ha='center',
                # Place the label just beyond the bar's end, also for negative values.
                va='bottom' if height >= 0 else 'top')
    plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement