Advertisement
Guest User

Untitled

a guest
Apr 20th, 2018
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.65 KB | None | 0 0
  1.     def gb_train(self,X_train,y_train,X_val,y_val):
  2.         self.logger.info("..creating the datasets for gb")
  3.         param_grid = {
  4.             'num_leaves': [20],
  5.             'min_data_in_leaf': [10],
  6.             'max_depth': [40],
  7.             'bagging_fraction': [0.85],
  8.             'colsample_bytree': [0.75],
  9.             'subsample': [0.7],
  10.             'reg_alpha': [1.2]}
  11.         std_params = {'boosting_type':'gbdt',
  12.                       'objective':'regression',
  13.                       'bagging_freq': 5,
  14.                       'learning_rate':0.05,
  15.                       'metric':'rmse',
  16.                       'num_leaves': 10,
  17.                       'min_data_in_leaf': 1,
  18.                       'max_depth': 15,
  19.                       'bagging_fraction': 0.95}
  20.         lgb_estimator = lgb.LGBMRegressor(boosting_type='gbdt',
  21.                                           objective='regression',
  22.                                           bagging_freq=5,
  23.                                           num_boost_round= 10000,
  24.                                           learning_rate= 0.075,
  25.                                           metric='rmse',
  26.                                           early_stopping_rounds=5,
  27.                                           verbose=-2)
  28.         gsearch = GridSearchCV(estimator=lgb_estimator,
  29.                                param_grid=param_grid,
  30.                                cv=4,
  31.                                refit=True)
  32.         lgb_model = gsearch.fit(X = X_train,
  33.                                 y = y_train,
  34.                                 eval_set = (X_val, y_val),
  35.                                 eval_metric = 'rmse',
  36.                                 verbose=-3)
  37.         self.logger.info("..cross validation has finished correctly")
  38.         msg = "..the best score over cross validation is {}"
  39.         msg = msg.format(lgb_model.best_score_)
  40.         self.logger.info(msg)
  41.         self.logger.info("..showing best params")
  42.         params = lgb_model.best_params_
  43.         self.logger.info("..{}".format(params))
  44.         for key in std_params:
  45.             if key in params:
  46.                 std_params[key]=params[key]
  47.         lgb_X_train = lgb.Dataset(X_train,y_train)
  48.         lgb_X_eval = lgb.Dataset(X_val,y_val)
  49.         gbm = lgb.train(std_params,
  50.                         num_boost_round=10000,
  51.                         early_stopping_rounds=5,
  52.                         train_set=lgb_X_train,
  53.                         valid_sets=lgb_X_eval)
  54.         self.logger.info("..plotting features importance")
  55.         lgb.plot_importance(gbm)
  56.         #plt.show()
  57.         gbm.save_model('gbm_model.txt')
  58.         self.gb_model = gbm
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement