Advertisement
mayankjoin3

multi-class-claude

Feb 7th, 2025
63
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.32 KB | None | 0 0
  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. import matplotlib.pyplot as plt
  5. import seaborn as sns
  6. import time
  7. from scipy import stats
  8.  
  9. from sklearn.model_selection import StratifiedKFold, train_test_split
  10. from sklearn.ensemble import RandomForestClassifier
  11. from sklearn.preprocessing import StandardScaler, LabelEncoder
  12. from sklearn.metrics import (
  13.     accuracy_score, precision_score, recall_score, f1_score,
  14.     confusion_matrix, roc_curve, roc_auc_score, precision_recall_curve,
  15.     average_precision_score, matthews_corrcoef, cohen_kappa_score,
  16.     mean_squared_error, mean_absolute_error, log_loss,
  17.     hamming_loss, jaccard_score, balanced_accuracy_score
  18. )
  19. from sklearn.preprocessing import label_binarize
  20.  
  21. class UNSWRandomForestClassifier:
  22.     def __init__(self, n_estimators=100, random_state=42):
  23.         self.n_estimators = n_estimators
  24.         self.random_state = random_state
  25.         self.model = RandomForestClassifier(
  26.             n_estimators=n_estimators,
  27.             random_state=random_state,
  28.             n_jobs=-1
  29.         )
  30.        
  31.     def load_and_preprocess_data(self, file_path):
  32.         data = pd.read_csv(file_path)
  33.        
  34.         X = data.drop('label', axis=1)
  35.         y = data['label']
  36.        
  37.         le = LabelEncoder()
  38.         y = le.fit_transform(y)
  39.        
  40.         scaler = StandardScaler()
  41.         X = scaler.fit_transform(X)
  42.        
  43.         return X, y
  44.    
  45.     def compute_comprehensive_metrics(self, y_true, y_pred, y_prob, fold_number):
  46.         n_classes = len(np.unique(y_true))
  47.         metrics = {'Fold': fold_number}
  48.        
  49.         # Global Metrics
  50.         metrics.update({
  51.             'Accuracy': accuracy_score(y_true, y_pred),
  52.             'Balanced Accuracy': balanced_accuracy_score(y_true, y_pred),
  53.             'Matthews Correlation Coefficient': matthews_corrcoef(y_true, y_pred),
  54.             'Cohen Kappa Score': cohen_kappa_score(y_true, y_pred)
  55.         })
  56.        
  57.         # Per-Class Metrics
  58.         for i in range(n_classes):
  59.             y_true_binary = (y_true == i)
  60.             y_pred_binary = (y_pred == i)
  61.            
  62.             tn, fp, fn, tp = confusion_matrix(y_true_binary, y_pred_binary).ravel()
  63.            
  64.             # Per-Class Metrics
  65.             metrics[f'Class_{i}_Precision'] = precision_score(y_true_binary, y_pred_binary)
  66.             metrics[f'Class_{i}_Recall'] = recall_score(y_true_binary, y_pred_binary)
  67.             metrics[f'Class_{i}_F1_Score'] = f1_score(y_true_binary, y_pred_binary)
  68.             metrics[f'Class_{i}_TPR'] = tp / (tp + fn)
  69.             metrics[f'Class_{i}_TNR'] = tn / (tn + fp)
  70.             metrics[f'Class_{i}_FPR'] = fp / (fp + tn)
  71.             metrics[f'Class_{i}_FNR'] = fn / (fn + tp)
  72.             metrics[f'Class_{i}_PPV'] = tp / (tp + fp)
  73.             metrics[f'Class_{i}_NPV'] = tn / (tn + fn)
  74.        
  75.         # ROC and PR Metrics
  76.         y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
  77.         metrics.update({
  78.             'AUC (Macro)': roc_auc_score(y_true_bin, y_prob, multi_class='ovr', average='macro'),
  79.             'AUC-PR (Macro)': average_precision_score(y_true_bin, y_prob, average='macro')
  80.         })
  81.        
  82.         # Error Metrics
  83.         metrics.update({
  84.             'Mean Squared Error': mean_squared_error(y_true, y_pred),
  85.             'Mean Absolute Error': mean_absolute_error(y_true, y_pred),
  86.             'Root Mean Squared Error': np.sqrt(mean_squared_error(y_true, y_pred)),
  87.             'Log Loss': log_loss(y_true, y_prob)
  88.         })
  89.        
  90.         # Other Metrics
  91.         metrics.update({
  92.             'Hamming Loss': hamming_loss(y_true, y_pred),
  93.             'Jaccard Score (Macro)': jaccard_score(y_true, y_pred, average='macro')
  94.         })
  95.        
  96.         return pd.DataFrame([metrics])
  97.    
  98.     def train_and_evaluate(self, X, y, n_splits=5, output_csv_path='metrics_output.csv'):
  99.         skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
  100.        
  101.         fold_metrics_list = []
  102.         fold_feature_importances = []
  103.         total_training_time = 0
  104.         total_testing_time = 0
  105.        
  106.         os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
  107.        
  108.         for fold, (train_index, val_index) in enumerate(skf.split(X, y), 1):
  109.             X_train, X_val = X[train_index], X[val_index]
  110.             y_train, y_val = y[train_index], y[val_index]
  111.            
  112.             start_train_time = time.time()
  113.             self.model.fit(X_train, y_train)
  114.             training_time = time.time() - start_train_time
  115.             total_training_time += training_time
  116.            
  117.             start_test_time = time.time()
  118.             y_pred = self.model.predict(X_val)
  119.             y_prob = self.model.predict_proba(X_val)
  120.             testing_time = time.time() - start_test_time
  121.             total_testing_time += testing_time
  122.            
  123.             fold_metrics = self.compute_comprehensive_metrics(y_val, y_pred, y_prob, fold)
  124.             fold_metrics_list.append(fold_metrics)
  125.            
  126.             fold_metrics.to_csv(output_csv_path, mode='a', header=not os.path.exists(output_csv_path), index=False)
  127.            
  128.             fold_feature_importances.append(self.model.feature_importances_)
  129.        
  130.         all_fold_metrics = pd.concat(fold_metrics_list, ignore_index=True)
  131.         avg_metrics = all_fold_metrics.mean()
  132.        
  133.         avg_feature_importance = np.mean(fold_feature_importances, axis=0)
  134.        
  135.         return avg_metrics, avg_feature_importance
  136.    
  137.     def plot_one_vs_all_auc_roc(self, X, y):
  138.         y_bin = label_binarize(y, classes=np.unique(y))
  139.         n_classes = y_bin.shape[1]
  140.        
  141.         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
  142.        
  143.         self.model.fit(X_train, y_train)
  144.         y_score = self.model.predict_proba(X_test)
  145.        
  146.         plt.figure(figsize=(10, 8))
  147.         colors = ['blue', 'red', 'green', 'orange', 'purple']
  148.        
  149.         for i, color in zip(range(n_classes), colors[:n_classes]):
  150.             fpr, tpr, _ = roc_curve(y_test == i, y_score[:, i])
  151.             roc_auc = auc(fpr, tpr)
  152.            
  153.             plt.plot(fpr, tpr, color=color,
  154.                      label=f'ROC curve (class {i}, AUC = {roc_auc:.2f})')
  155.        
  156.         plt.plot([0, 1], [0, 1], 'k--')
  157.         plt.xlim([0.0, 1.0])
  158.         plt.ylim([0.0, 1.05])
  159.         plt.xlabel('False Positive Rate')
  160.         plt.ylabel('True Positive Rate')
  161.         plt.title('Receiver Operating Characteristic (ROC) - One-vs-All')
  162.         plt.legend(loc="lower right")
  163.         plt.tight_layout()
  164.         plt.show()
  165.  
  166. def main():
  167.     file_path = 'path/to/unsw_nb15_dataset.csv'
  168.     output_csv_path = 'results/metrics_output.csv'
  169.    
  170.     rf_classifier = UNSWRandomForestClassifier(n_estimators=100)
  171.    
  172.     X, y = rf_classifier.load_and_preprocess_data(file_path)
  173.    
  174.     feature_names = pd.read_csv(file_path).drop('label', axis=1).columns
  175.    
  176.     avg_metrics, feature_importance = rf_classifier.train_and_evaluate(
  177.         X, y,
  178.         n_splits=5,
  179.         output_csv_path=output_csv_path
  180.     )
  181.    
  182.     print("Average Metrics across 5-Fold Cross-Validation:")
  183.     for metric, value in avg_metrics.items():
  184.         print(f"{metric}: {value}")
  185.    
  186.     rf_classifier.plot_one_vs_all_auc_roc(X, y)
  187.  
  188. if __name__ == "__main__":
  189.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement