Advertisement
naren_paste

Hyperparameter tuning

Dec 5th, 2023
759
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.85 KB | Source Code | 0 0
  1. from sklearn.model_selection import GridSearchCV
  2.  
  3. def evaluate_model(model, X_train, y_train, X_test, y_test, hyperparameters=None):
  4.     if hyperparameters is not None:
  5.         # Perform hyperparameter tuning using GridSearchCV
  6.         grid_search = GridSearchCV(model, hyperparameters, scoring='accuracy', cv=5)
  7.         grid_search.fit(X_train, y_train)
  8.         best_model = grid_search.best_estimator_
  9.     else:
  10.         best_model = model
  11.  
  12.     # Train the best model on the entire training set
  13.     best_model.fit(X_train, y_train)
  14.  
  15.     # Evaluate the model on the test set
  16.     accuracy = best_model.score(X_test, y_test)
  17.  
  18.     return accuracy, best_model
  19.  
  20. # Example usage
  21. from sklearn.ensemble import RandomForestClassifier
  22. from sklearn.svm import SVC
  23. from sklearn.datasets import load_iris
  24. from sklearn.model_selection import train_test_split
  25.  
  26. # Load dataset
  27. iris = load_iris()
  28. X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
  29.  
  30. # Define models
  31. rf_model = RandomForestClassifier(random_state=42)
  32. svc_model = SVC(random_state=42)
  33.  
  34. # Define hyperparameters for tuning (customize based on your needs)
  35. rf_hyperparameters = {'n_estimators': [50, 100, 200], 'max_depth': [None, 10, 20]}
  36. svc_hyperparameters = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}
  37.  
  38. # Evaluate Random Forest with hyperparameter tuning
  39. rf_accuracy, best_rf_model = evaluate_model(rf_model, X_train, y_train, X_test, y_test, hyperparameters=rf_hyperparameters)
  40.  
  41. # Evaluate SVC with hyperparameter tuning
  42. svc_accuracy, best_svc_model = evaluate_model(svc_model, X_train, y_train, X_test, y_test, hyperparameters=svc_hyperparameters)
  43.  
  44. print(f"Random Forest Accuracy: {rf_accuracy}")
  45. print(f"Tuned Random Forest Model: {best_rf_model}")
  46.  
  47. print(f"SVC Accuracy: {svc_accuracy}")
  48. print(f"Tuned SVC Model: {best_svc_model}")
  49.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement