Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- from sklearn.datasets import make_multilabel_classification
- from sklearn.ensemble import GradientBoostingClassifier
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import confusion_matrix
# Dataset init
# Collapse the 3-label multilabel target into a single multiclass target:
# each sample's class is the number of labels it carries (between 0 and 3).
x, y = make_multilabel_classification(
    n_samples=1000, n_features=10, n_classes=3, n_labels=1, random_state=0
)
y = y.sum(axis=1)
x_train, x_test, y_train, y_test = train_test_split(
    x, y, random_state=0, test_size=0.33
)
# Classification
# Seed the booster too, so the whole pipeline is reproducible run to run,
# matching the fixed seeds already used for data generation and splitting.
classifier = GradientBoostingClassifier(random_state=0)
classifier.fit(x_train, y_train)
# Hard class predictions from .predict (not probabilities / decision scores).
y_score = classifier.predict(x_test)
# Pin the label set so row/column i of cm always refers to the same class,
# even if some class is absent from both y_test and y_score.
class_labels = np.unique(y)
cm = confusion_matrix(y_test, y_score, labels=class_labels)
def _safe_ratio(numerator, denominator):
    """Return element-wise ``numerator / denominator`` as floats, NaN where ``denominator == 0``."""
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    ratio = np.full(numerator.shape, np.nan)
    # `where=` skips zero denominators entirely, so no RuntimeWarning is
    # emitted and those slots keep their NaN ("undefined") value.
    np.divide(numerator, denominator, out=ratio, where=denominator != 0)
    return ratio


def calculate_tpr_tnr(cm):
    """
    Sensitivity (TPR) and specificity (TNR) calculation
    per class for scikit-learn machine learning algorithms.

    Each class is treated one-vs-rest: for class ``i``, samples whose true
    label is ``i`` are positives and all other samples are negatives.

    Parameters
    ----------
    cm : array_like of shape (n_classes, n_classes)
        Confusion matrix obtained with `sklearn.metrics.confusion_matrix`
        method (rows are true labels, columns are predicted labels).

    Returns
    -------
    sensitivities : ndarray of shape (n_classes,)
        TP / (TP + FN) per class; NaN for a class with no true samples.
    specificities : ndarray of shape (n_classes,)
        TN / (TN + FP) per class; NaN for a class with no negative samples
        (i.e. every sample truly belongs to that class).

    Raises
    ------
    ValueError
        If ``cm`` is not a square 2-D matrix.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D matrix, got shape {cm.shape}")

    # TP of a class is its diagonal element.
    tp = np.diag(cm)
    # Row i holds every sample whose true label is class i: TP + FN.
    actual_positives = cm.sum(axis=1)
    # Column i holds every sample predicted as class i: TP + FP.
    fp = cm.sum(axis=0) - tp
    # Every sample outside row i is a true negative for class i: TN + FP.
    actual_negatives = cm.sum() - actual_positives
    tn = actual_negatives - fp

    sensitivities = _safe_ratio(tp, actual_positives)
    specificities = _safe_ratio(tn, actual_negatives)
    return sensitivities, specificities
# Bind and report the per-class metrics; a bare expression statement only
# displays its value in an interactive/notebook session, never in a script.
sensitivities, specificities = calculate_tpr_tnr(cm)
for class_idx, (tpr, tnr) in enumerate(zip(sensitivities, specificities)):
    # class_idx is the row/column index into cm.
    print(f"class {class_idx}: sensitivity={tpr:.3f} specificity={tnr:.3f}")
Add Comment
Please, Sign In to add comment