Advertisement
Guest User

Untitled

a guest
Mar 4th, 2016
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.12 KB | None | 0 0
  class ModelEmbedder:
      """Wrap a model with target-mean encoding of object (categorical) columns.

      ``fit`` replaces each object-dtype column of the training frame with the
      mean of ``target`` per category. Categories occurring fewer than
      ``rare_threshold`` times are pooled into a ``"RARE"`` bucket, which is
      encoded with the global target mean. At prediction time, every category
      that is not present in the pooled training data (rare or unseen) is
      encoded with that same global mean.

      Parameters
      ----------
      model : object
          Wrapped estimator. It must expose ``fit(X, y)`` and
          ``predict_proba(X)``; ``predict`` and ``get_params`` are needed only
          if the matching wrappers are called. Presumably a scikit-learn style
          estimator -- not enforced here.
      rare_threshold : int
          Minimum number of occurrences for a category to keep its own encoding.
      """

      def __init__(self, model, rare_threshold):
          self.model = model
          # column name -> Series mapping category -> mean target (incl. "RARE")
          self.means = {}
          self.rare_threshold = rare_threshold
          # encoded training frame that was passed to model.fit
          self.train = None
          # training frame after NaN fill and rare pooling, before encoding
          self.origin_train = None
          # global target mean; encoding for rare / unseen categories
          self.average = 0

      def _pool_rare(self, column):
          """Return *column* with values seen < rare_threshold times set to "RARE"."""
          # Per-row frequency of that row's value within the column itself.
          frequency = column.map(column.value_counts())
          return column.where(frequency >= self.rare_threshold, "RARE")

      def fit(self, train, target):
          """Learn per-category target means, encode *train*, and fit the model.

          Parameters
          ----------
          train : pandas.DataFrame
              Feature frame. NaNs are filled with -1 before encoding.
          target : pandas.Series or array-like
              Target values. They are attached as a column of ``train``, so a
              Series must share ``train``'s index (pandas aligns on index).

          Returns
          -------
          ModelEmbedder
              ``self``, so calls can be chained.
          """
          self.origin_train = train.copy().fillna(-1)
          self.train = train.copy().fillna(-1)
          self.train['target'] = target
          self.average = target.mean()
          # Reset so a re-fit on different columns leaves no stale encodings.
          self.means = {}

          for feat in train.columns:
              if feat == 'target' or self.train[feat].dtype != 'object':
                  continue
              pooled = self._pool_rare(self.train[feat])
              self.origin_train[feat] = pooled
              self.train[feat] = pooled
              # sort=False: the -1 fill value mixed with strings is unorderable.
              means = self.train.groupby(feat, sort=False)['target'].mean()
              # Rare categories are encoded with the global mean, not their own.
              means.loc["RARE"] = self.average
              self.means[feat] = means
              # Every pooled value is a key of `means`, so no value is left unmapped.
              self.train[feat] = pooled.map(means)

          del self.train['target']

          self.model.fit(self.train, target)
          return self

      def _pre_treat_test(self, test):
          """Return an encoded copy of *test* using the encodings learned in ``fit``.

          Known categories get their training mean. Everything else gets the
          global training mean ``self.average``: categories that were rare in
          training (they were pooled away) and categories never seen in training.
          Raises KeyError if *test* lacks a column that was encoded during ``fit``.
          """
          test = test.fillna(-1)  # fillna returns a new frame; caller's is untouched

          for feat, means in self.means.items():
              column = test[feat]
              known = column.isin(means.index)
              # Row-wise mapping: no dependence on the training frame's length/order.
              test[feat] = column.map(means).where(known, self.average)

          return test

      def predict_proba(self, test):
          """Encode *test* and return the wrapped model's ``predict_proba``."""
          test = self._pre_treat_test(test)
          return self.model.predict_proba(test)

      def predict(self, test):
          """Encode *test* and return the wrapped model's ``predict``."""
          return self.model.predict(self._pre_treat_test(test))

      def get_params(self, deep=True):
          """Return the wrapped model's parameters.

          NOTE(review): ``rare_threshold`` is not included. Tooling that
          rebuilds the estimator from these params (e.g. scikit-learn's
          ``clone``) would therefore not reconstruct this wrapper -- confirm
          whether that is relied on.
          """
          return self.model.get_params(deep)
  56.  
  # Base classifier to be wrapped. Assumes `ensemble` (presumably
  # sklearn.ensemble), `n_estimators` and `model_embedder` are defined or
  # imported earlier in the script -- not visible here.
  rf = ensemble.ExtraTreesClassifier(
      n_estimators=n_estimators,
      random_state=11,
      n_jobs=7,
  )

  # Categories seen fewer than 10 times during fit are pooled as "RARE".
  rf_embedded = model_embedder.ModelEmbedder(rf, rare_threshold=10)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement