Advertisement
Guest User

Untitled

a guest
Mar 25th, 2019
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.07 KB | None | 0 0
  1. """
  2. openml_rerf_test.py
  3. """
  4.  
import argparse
import sys

import numpy as np
import pandas as pd

import openml
import sklearn
from sklearn import compose, impute, feature_selection
from sklearn import pipeline, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

from RerF import fastRerF, fastPredict
  17.  
  18. # --
  19. # CLI
  20.  
  21. def parse_args():
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument('--task-id', type=int, default=3)
  24. parser.add_argument('--num-cores', type=int, default=16)
  25. parser.add_argument('--num-trees', type=int, default=500)
  26. parser.add_argument('--seed', type=int, default=123)
  27. return parser.parse_args()
  28.  
  29. args = parse_args()
  30. np.random.seed(args.seed)
  31.  
  32. # --
  33. # Load dataset
  34.  
  35. task = openml.tasks.get_task(args.task_id)
  36. X, y = task.get_X_and_y()
  37.  
  38. # Use first split (for now)
  39. train_idx, test_idx = task.get_train_test_split_indices()
  40. X_train, X_test = X[train_idx], X[test_idx]
  41. y_train, y_test = y[train_idx], y[test_idx]
  42.  
  43. # --
  44. # Preprocess data
  45.  
  46. dataset = task.get_dataset()
  47. nominal_indices = dataset.get_features_by_type(data_type='nominal', exclude=[task.target_name])
  48. numeric_indices = dataset.get_features_by_type(data_type='numeric', exclude=[task.target_name])
  49.  
  50. prep = sklearn.pipeline.make_pipeline(
  51. sklearn.compose.ColumnTransformer(
  52. transformers=[
  53. ('numeric', sklearn.pipeline.make_pipeline(
  54. sklearn.preprocessing.Imputer(),
  55. sklearn.preprocessing.StandardScaler(),
  56. ), numeric_indices),
  57. ('nominal', sklearn.pipeline.make_pipeline(
  58. sklearn.impute.SimpleImputer(strategy='constant', fill_value=-1),
  59. sklearn.preprocessing.OneHotEncoder(handle_unknown='ignore'),
  60. ), nominal_indices)
  61. ],
  62. remainder='passthrough',
  63. ),
  64. sklearn.feature_selection.VarianceThreshold(),
  65. )
  66.  
  67. Xf_train = prep.fit_transform(X_train)
  68. Xf_test = prep.transform(X_test)
  69.  
  70. # --
  71. # Train models
  72.  
  73. def fit_rerf(Xf_train, Xf_test, y_train, y_test, num_trees, num_cores):
  74. rerf_forest = fastRerF(
  75. X=Xf_train,
  76. Y=y_train,
  77. forestType="binnedBaseRerF",
  78. trees=num_trees,
  79. numCores=num_cores,
  80. )
  81.  
  82. return fastPredict(X=Xf_test, forest=rerf_forest)
  83.  
  84.  
  85. def fit_sklearn(Xf_train, Xf_test, y_train, y_test, num_trees, num_cores):
  86. sk_forest = RandomForestClassifier(n_estimators=num_trees, n_jobs=num_cores)
  87. sk_forest = sk_forest.fit(Xf_train, y_train)
  88. return sk_forest.predict(Xf_test)
  89.  
  90. kwargs = {
  91. "Xf_train" : Xf_train,
  92. "Xf_test" : Xf_test,
  93. "y_train" : y_train,
  94. "y_test" : y_test,
  95. "num_trees" : args.num_trees,
  96. "num_cores" : args.num_cores
  97. }
  98.  
  99. print('-' * 50, file=sys.stderr)
  100. print('fit rerf', file=sys.stderr)
  101. rerf_pred = [fit_rerf(**kwargs) for _ in range(10)]
  102. rerf_accs = [(y_test == p).mean() for p in rerf_pred]
  103.  
  104. print('-' * 50, file=sys.stderr)
  105. print('fit sklearn', file=sys.stderr)
  106. sk_pred = [fit_sklearn(**kwargs) for _ in range(10)]
  107. sk_accs = [(y_test == p).mean() for p in sk_pred]
  108.  
  109. print('np.mean(rerf_accs)', np.mean(rerf_accs))
  110. print('np.mean(sk_accs)', np.mean(sk_accs))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement