Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def gb_train(self, X_train, y_train, X_val, y_val):
    """Tune, train and persist a LightGBM regression model.

    Pipeline: (1) run a 4-fold GridSearchCV over a small hyper-parameter
    grid using the sklearn wrapper, (2) overlay the winning parameters
    onto a fixed baseline configuration, (3) retrain with the native
    ``lgb.train`` API using early stopping on the validation set,
    (4) plot feature importances and save the booster to 'gbm_model.txt'.

    Parameters
    ----------
    X_train, y_train :
        Training features and target.
    X_val, y_val :
        Validation features and target; used for early stopping in both
        the grid search and the final fit.

    Side effects
    ------------
    Sets ``self.gb_model`` to the trained booster and writes
    'gbm_model.txt' to the current working directory.
    """
    self.logger.info("..creating the datasets for gb")

    # Grid explored by cross-validation. Single-value lists keep the
    # search cheap while still exercising the CV machinery; widen the
    # lists to search for real.
    param_grid = {
        'num_leaves': [20],
        'min_data_in_leaf': [10],
        'max_depth': [40],
        'bagging_fraction': [0.85],
        'colsample_bytree': [0.75],
        'subsample': [0.7],
        'reg_alpha': [1.2],
    }

    # Baseline parameters for the final native-API training run; any key
    # also present in the grid-search winner is overwritten below.
    std_params = {
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'bagging_freq': 5,
        'learning_rate': 0.05,
        'metric': 'rmse',
        'num_leaves': 10,
        'min_data_in_leaf': 1,
        'max_depth': 15,
        'bagging_fraction': 0.95,
    }

    # FIX: the sklearn wrapper's round budget is 'n_estimators'
    # ('num_boost_round' is native-API only and was silently ignored),
    # and the minimum valid verbosity level is -1.
    # NOTE(review): recent LightGBM versions moved early stopping to
    # callbacks=[lgb.early_stopping(5)] on fit() — confirm against the
    # pinned lightgbm version.
    lgb_estimator = lgb.LGBMRegressor(
        boosting_type='gbdt',
        objective='regression',
        bagging_freq=5,
        n_estimators=10000,
        learning_rate=0.075,
        metric='rmse',
        early_stopping_rounds=5,
        verbose=-1,
    )

    gsearch = GridSearchCV(
        estimator=lgb_estimator,
        param_grid=param_grid,
        cv=4,
        refit=True,
    )

    # FIX: eval_set must be a *list* of (X, y) tuples; the bogus
    # verbose=-3 fit kwarg was dropped.
    lgb_model = gsearch.fit(
        X=X_train,
        y=y_train,
        eval_set=[(X_val, y_val)],
        eval_metric='rmse',
    )
    self.logger.info("..cross validation has finished correctly")
    msg = "..the best score over cross validation is {}"
    msg = msg.format(lgb_model.best_score_)
    self.logger.info(msg)
    self.logger.info("..showing best params")
    params = lgb_model.best_params_
    self.logger.info("..{}".format(params))

    # Overlay tuned values onto the baseline configuration.
    # NOTE(review): grid keys absent from std_params (colsample_bytree,
    # subsample, reg_alpha) are silently dropped here — confirm that is
    # intentional before widening the grid.
    for key in std_params:
        if key in params:
            std_params[key] = params[key]

    # Final fit with the native API, early-stopped on the hold-out set.
    lgb_X_train = lgb.Dataset(X_train, y_train)
    lgb_X_eval = lgb.Dataset(X_val, y_val)
    gbm = lgb.train(
        std_params,
        num_boost_round=10000,
        early_stopping_rounds=5,
        train_set=lgb_X_train,
        # FIX: valid_sets expects a list of Datasets.
        valid_sets=[lgb_X_eval],
    )
    self.logger.info("..plotting features importance")
    lgb.plot_importance(gbm)
    #plt.show()
    gbm.save_model('gbm_model.txt')  # persisted for later reload
    self.gb_model = gbm
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement