SHARE
TWEET

Untitled

a guest Jan 26th, 2020 75 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. class ModelTrainer:
  2.  
  3.     _classInstances = {}
  4.  
  5.     def __init__(self, dataset_name, test_fraction):
  6.         self.dataset_name = dataset_name
  7.         self.test_fraction = test_fraction
  8.         if dataset_name == 'boston':
  9.             self.X, self.y = 100 * test_fraction, 1
  10.         elif dataset_name == 'iris':
  11.             self.X, self.y = 200 * test_fraction, 2
  12.         print('In init!!!')
  13. #         if isinstance(dataset, tuple):
  14. #             X, y = dataset
  15. #         else:
  16. #             X, y = shuffle(dataset.data, dataset.target, random_state=13)
  17.  
  18. #         offset = int(X.shape[0] * (1 - test_fraction))
  19. #         X_train, y_train = X[:offset], y[:offset]
  20. #         X_test = X[offset:]
  21.  
  22. #         fitted_estimator = estimator.fit(X_train, y_train)
  23.  
  24. #         if isinstance(estimator, (LinearClassifierMixin, SVC, NuSVC)):
  25. #             y_pred = estimator.decision_function(X_test)
  26. #         elif isinstance(estimator, DecisionTreeClassifier):
  27. #             y_pred = estimator.predict_proba(X_test.astype(np.float32))
  28. #         elif isinstance(
  29. #                 estimator,
  30. #                 (forest.ForestClassifier, XGBClassifier, LGBMClassifier)):
  31. #             y_pred = estimator.predict_proba(X_test)
  32. #         else:
  33. #             y_pred = estimator.predict(X_test)
  34.  
  35. #         return X_test, y_pred, fitted_estimator
  36.  
  37.     @classmethod
  38.     def get_instance(cls, dataset_name, test_fraction=0.07):
  39.         print('In get_instance!!!')
  40.         key = dataset_name + " {}".format(test_fraction)
  41.         if key not in cls._classInstances:
  42.             cls._classInstances[key] = ModelTrainer(dataset_name, test_fraction)
  43.         return cls._classInstances[key]
  44.    
  45.     def __call__(self, estimator):
  46.         print('In call!!!')
  47.         a = self.X
  48.         return a
  49.  
  50.  
from functools import partial
# Shortcut that pre-binds the dataset name "boston" for
# ModelTrainer.get_instance; calling it returns the trainer instance that
# get_instance returns.
# NOTE(review): the first positional argument of this callable is bound to
# get_instance's `test_fraction` parameter, so passing an estimator here
# would be treated as the test fraction -- confirm this is intended.
train_model_regression = partial(ModelTrainer.get_instance, "boston")
  53.  
  54.  
  55. def train_model_classification(estimator, test_fraction=0.02):
  56.     return ModelTrainer.get_instance(estimator, "iris", test_fraction)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Top