Advertisement
Guest User

Untitled

a guest
Mar 19th, 2019
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 14.31 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 65,
  6. "metadata": {
  7. "scrolled": false
  8. },
  9. "outputs": [],
  10. "source": [
  # Code is written for readability, not performance.
def sort_print_importances(importances):
    """Return ``importances`` as a one-column DataFrame sorted by importance, descending.

    Parameters
    ----------
    importances : dict
        Mapping of feature name -> importance value.

    Returns
    -------
    pandas.DataFrame
        Indexed by feature name, with a single ``"Importance"`` column, largest first.
        An empty mapping yields an empty frame that still has the ``"Importance"`` column.
    """
    # Build the frame from a named column instead of relabelling ``.columns``
    # afterwards: for an empty dict the frame would have zero columns and the
    # relabel to ["Importance"] would raise a length-mismatch ValueError.
    imps_df = pd.DataFrame({"Importance": list(importances.values())},
                           index=list(importances.keys()))
    return imps_df.sort_values("Importance", ascending=False)
  16. " \n",
  # ROC-AUC helper.
def get_auc(true_labels, predictions, positive_label):
    """Return the area under the ROC curve of ``predictions`` against ``true_labels``.

    ``positive_label`` is passed to ``sklearn.metrics.roc_curve`` as ``pos_label``,
    i.e. it names the class that the scores in ``predictions`` refer to.
    """
    roc_fpr, roc_tpr, _ = sklearn.metrics.roc_curve(
        true_labels, predictions, pos_label=positive_label)
    return sklearn.metrics.auc(roc_fpr, roc_tpr)
  20. " \n",
  # Fold one-hot importances back onto their source columns.
def reduce_importances(dummy_importances, original_columns):
    """Helper function to combine the importances obtained from 1-hot encoded data.

    Each key of ``dummy_importances`` is attributed to exactly one entry of
    ``original_columns``. That entry is either the column the key equals, or else
    the longest column ``c`` such that the key starts with ``c + "_"``. Choosing
    the longest match stops a dummy such as ``"a_b_x"`` from being counted under
    both ``"a"`` and ``"a_b"``.

    Parameters
    ----------
    dummy_importances : dict
        Feature name (possibly a one-hot dummy column) -> importance.
    original_columns : iterable
        Column names before one-hot encoding.

    Returns
    -------
    dict
        Original column -> summed importance, keyed in ``original_columns`` order.
        Columns with no matching feature get 0. Features that match no column
        are ignored.
    """
    original_columns = list(original_columns)  # allow any iterable; iterated once per feature
    importances = {col: 0 for col in original_columns}
    for dummy_col, val in dummy_importances.items():
        dummy_name = str(dummy_col)
        owner = None
        for col in original_columns:
            col_name = str(col)
            matches = dummy_col == col or dummy_name.startswith(col_name + "_")
            # An exact match is always the longest candidate, so it wins.
            if matches and (owner is None or len(col_name) > len(str(owner))):
                owner = col
        if owner is not None:
            importances[owner] += val
    return importances
  33. "\n",
  # Bar order follows the first importance mapping passed in.
def plot_importances(importances, names):
    """Draw a grouped bar chart comparing one or more importance mappings.

    The x-axis order is ``importances[0]``'s keys sorted by value, descending.
    Every mapping is plotted in that same order.

    Parameters
    ----------
    importances : list of dict
        Each dict maps feature name -> importance. Must be non-empty.
    names : list of str
        Legend label for each mapping; same length as ``importances``.

    Returns
    -------
    plotly.graph_objs.FigureWidget

    Raises
    ------
    ValueError
        If ``importances`` is empty or its length differs from ``names``.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``.
    if len(importances) < 1:
        raise ValueError("importances must contain at least one mapping")
    if len(importances) != len(names):
        raise ValueError("importances and names must have the same length")
    import plotly.graph_objs as go
    cols = sorted(importances[0], key=importances[0].get, reverse=True)
    fig = go.FigureWidget()
    fig.layout.title = "Importances"
    fig.layout.showlegend = True
    for imps, name in zip(importances, names):
        # A feature missing from a later mapping is drawn as 0 instead of raising KeyError.
        fig.add_bar(y=[imps.get(col, 0) for col in cols], x=cols, name=name)
    return fig
  46. ]
  47. },
  48. {
  49. "cell_type": "code",
  50. "execution_count": 34,
  51. "metadata": {
  52. "scrolled": false
  53. },
  54. "outputs": [
  55. {
  56. "name": "stdout",
  57. "output_type": "stream",
  58. "text": [
  59. "Train df size is 21815, test df size is 10746\n"
  60. ]
  61. }
  62. ],
  63. "source": [
  # Build a one-hot encoded frame and split it into train/test sets.
target_column = "yearly-income"
positive_label = " >50K"
# Object-dtype feature columns are one-hot encoded for the Gradient Boosted Tree;
# education-num is numeric but is treated as a category here as well.
categorical_cols = [
    name for name in df.columns
    if name != target_column and df.dtypes[name] == object
] + ['education-num']
dummy_df = pd.get_dummies(df, columns=categorical_cols)
# Hold out a third of the rows for testing, with a fixed seed for a repeatable split.
train_df, test_df = train_test_split(dummy_df, test_size=0.33, random_state=42)
print("Train df size is {}, test df size is {}".format(len(train_df), len(test_df)))
# Feature column names: every training column except the target.
columns = [name for name in train_df.columns]
columns.remove(target_column)
  75. ]
  76. },
  77. {
  78. "cell_type": "code",
  79. "execution_count": 35,
  80. "metadata": {
  81. "scrolled": false
  82. },
  83. "outputs": [],
  84. "source": [
  # Do a basic grid search over base parameters with early stopping.
# The learning-rate grid is an explicit list. Float-step np.arange drifts
# (e.g. 0.05 + 0.1 gives 0.15000000000000002) and its endpoint depends on rounding.
# random_state pins the estimator's internal validation split used for early
# stopping, so repeated runs search the same models (seed matches the train/test split).
gbm = sklearn.model_selection.GridSearchCV(
    ensemble.GradientBoostingClassifier(
        n_estimators=500,
        n_iter_no_change=20,
        validation_fraction=0.2,
        random_state=42),
    {
        "learning_rate": [0.05, 0.1, 0.15, 0.2, 0.25, 0.3],
        "max_depth": [1, 2, 3, 4, 5]},
    cv=5)
gbm.fit(train_df.drop([target_column], axis=1),
        train_df[target_column])
tuned_gbm = gbm
  97. ]
  98. },
  99. {
  100. "cell_type": "code",
  101. "execution_count": 47,
  102. "metadata": {
  103. "scrolled": false
  104. },
  105. "outputs": [
  106. {
  107. "name": "stdout",
  108. "output_type": "stream",
  109. "text": [
  110. "Best parameters found {'criterion': 'friedman_mse', 'init': None, 'learning_rate': 0.05, 'loss': 'deviance', 'max_depth': 4, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_impurity_split': None, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'n_estimators': 500, 'n_iter_no_change': 20, 'presort': 'auto', 'random_state': None, 'subsample': 1.0, 'tol': 0.0001, 'validation_fraction': 0.2, 'verbose': 0, 'warm_start': False}\n",
  111. "\n",
  112. "Stopped improving after: 500\n",
  113. "\n",
  114. "AUC on test is: 92.55 %\n"
  115. ]
  116. }
  117. ],
  118. "source": [
  # Report the tuned model and score it on the held-out test set.
model = tuned_gbm.best_estimator_
# best_params_ holds only the searched hyper-parameters; get_params() on the
# estimator would dump every constructor argument instead.
print("Best parameters found", tuned_gbm.best_params_)
# n_estimators_ is the number of boosting stages kept by early stopping;
# it equals n_estimators when early stopping never triggered.
print("\nStopped improving after:", model.n_estimators_)
# predict and score on test dataset
positive_index = list(model.classes_).index(positive_label)
predictions = model.predict_proba(test_df.drop(target_column, axis=1))[:, positive_index]
score = get_auc(test_df[target_column], predictions, positive_label)
print("\nAUC on test is: ", round(score * 100, 2), "%")
  126. ]
  127. },
  128. {
  129. "cell_type": "code",
  130. "execution_count": 67,
  131. "metadata": {
  132. "scrolled": false
  133. },
  134. "outputs": [
  135. {
  136. "data": {
  137. "text/html": [
  138. "<div>\n",
  139. "<style scoped>\n",
  140. " .dataframe tbody tr th:only-of-type {\n",
  141. " vertical-align: middle;\n",
  142. " }\n",
  143. "\n",
  144. " .dataframe tbody tr th {\n",
  145. " vertical-align: top;\n",
  146. " }\n",
  147. "\n",
  148. " .dataframe thead th {\n",
  149. " text-align: right;\n",
  150. " }\n",
  151. "</style>\n",
  152. "<table border=\"1\" class=\"dataframe\">\n",
  153. " <thead>\n",
  154. " <tr style=\"text-align: right;\">\n",
  155. " <th></th>\n",
  156. " <th>Importance</th>\n",
  157. " </tr>\n",
  158. " </thead>\n",
  159. " <tbody>\n",
  160. " <tr>\n",
  161. " <th>marital-status_ Married-civ-spouse</th>\n",
  162. " <td>0.362117</td>\n",
  163. " </tr>\n",
  164. " <tr>\n",
  165. " <th>capital-gain</th>\n",
  166. " <td>0.213057</td>\n",
  167. " </tr>\n",
  168. " <tr>\n",
  169. " <th>capital-loss</th>\n",
  170. " <td>0.072038</td>\n",
  171. " </tr>\n",
  172. " <tr>\n",
  173. " <th>age</th>\n",
  174. " <td>0.063056</td>\n",
  175. " </tr>\n",
  176. " <tr>\n",
  177. " <th>hours-per-week</th>\n",
  178. " <td>0.041892</td>\n",
  179. " </tr>\n",
  180. " <tr>\n",
  181. " <th>occupation_ Prof-specialty</th>\n",
  182. " <td>0.034868</td>\n",
  183. " </tr>\n",
  184. " <tr>\n",
  185. " <th>occupation_ Exec-managerial</th>\n",
  186. " <td>0.031057</td>\n",
  187. " </tr>\n",
  188. " <tr>\n",
  189. " <th>fnlwgt</th>\n",
  190. " <td>0.024566</td>\n",
  191. " </tr>\n",
  192. " <tr>\n",
  193. " <th>education-num_13</th>\n",
  194. " <td>0.017286</td>\n",
  195. " </tr>\n",
  196. " <tr>\n",
  197. " <th>education-num_14</th>\n",
  198. " <td>0.012526</td>\n",
  199. " </tr>\n",
  200. " <tr>\n",
  201. " <th>education_ Bachelors</th>\n",
  202. " <td>0.008913</td>\n",
  203. " </tr>\n",
  204. " <tr>\n",
  205. " <th>workclass_ Self-emp-not-inc</th>\n",
  206. " <td>0.008253</td>\n",
  207. " </tr>\n",
  208. " <tr>\n",
  209. " <th>education-num_16</th>\n",
  210. " <td>0.006599</td>\n",
  211. " </tr>\n",
  212. " <tr>\n",
  213. " <th>occupation_ Tech-support</th>\n",
  214. " <td>0.006491</td>\n",
  215. " </tr>\n",
  216. " <tr>\n",
  217. " <th>education_ Prof-school</th>\n",
  218. " <td>0.005582</td>\n",
  219. " </tr>\n",
  220. " <tr>\n",
  221. " <th>education-num_9</th>\n",
  222. " <td>0.004995</td>\n",
  223. " </tr>\n",
  224. " <tr>\n",
  225. " <th>relationship_ Wife</th>\n",
  226. " <td>0.004767</td>\n",
  227. " </tr>\n",
  228. " <tr>\n",
  229. " <th>occupation_ Other-service</th>\n",
  230. " <td>0.004670</td>\n",
  231. " </tr>\n",
  232. " <tr>\n",
  233. " <th>occupation_ Farming-fishing</th>\n",
  234. " <td>0.004615</td>\n",
  235. " </tr>\n",
  236. " <tr>\n",
  237. " <th>occupation_ Sales</th>\n",
  238. " <td>0.004305</td>\n",
  239. " </tr>\n",
  240. " </tbody>\n",
  241. "</table>\n",
  242. "</div>"
  243. ],
  244. "text/plain": [
  245. " Importance\n",
  246. "marital-status_ Married-civ-spouse 0.362117\n",
  247. "capital-gain 0.213057\n",
  248. "capital-loss 0.072038\n",
  249. "age 0.063056\n",
  250. "hours-per-week 0.041892\n",
  251. "occupation_ Prof-specialty 0.034868\n",
  252. "occupation_ Exec-managerial 0.031057\n",
  253. "fnlwgt 0.024566\n",
  254. "education-num_13 0.017286\n",
  255. "education-num_14 0.012526\n",
  256. "education_ Bachelors 0.008913\n",
  257. "workclass_ Self-emp-not-inc 0.008253\n",
  258. "education-num_16 0.006599\n",
  259. "occupation_ Tech-support 0.006491\n",
  260. "education_ Prof-school 0.005582\n",
  261. "education-num_9 0.004995\n",
  262. "relationship_ Wife 0.004767\n",
  263. "occupation_ Other-service 0.004670\n",
  264. "occupation_ Farming-fishing 0.004615\n",
  265. "occupation_ Sales 0.004305"
  266. ]
  267. },
  268. "execution_count": 67,
  269. "metadata": {},
  270. "output_type": "execute_result"
  271. }
  272. ],
  273. "source": [
  # Show the 20 largest raw (per-dummy-column) importances.
model_importances = tuned_gbm.best_estimator_.feature_importances_
# Check alignment before pairing names with values: the model must report
# exactly one importance per feature column.
assert len(columns) == len(model_importances)
dummy_importances_1 = dict(zip(columns, model_importances))
sort_print_importances(dummy_importances_1).head(20)
  280. ]
  281. },
  282. {
  283. "cell_type": "code",
  284. "execution_count": null,
  285. "metadata": {
  286. "scrolled": false
  287. },
  288. "outputs": [],
  289. "source": [
  # Bar chart of the raw per-dummy-column importances.
plot_importances(importances=[dummy_importances_1], names=["Scikit Importances (Dummy)"])
  291. ]
  292. },
  293. {
  294. "cell_type": "code",
  295. "execution_count": 70,
  296. "metadata": {
  297. "scrolled": false
  298. },
  299. "outputs": [
  300. {
  301. "data": {
  302. "text/html": [
  303. "<div>\n",
  304. "<style scoped>\n",
  305. " .dataframe tbody tr th:only-of-type {\n",
  306. " vertical-align: middle;\n",
  307. " }\n",
  308. "\n",
  309. " .dataframe tbody tr th {\n",
  310. " vertical-align: top;\n",
  311. " }\n",
  312. "\n",
  313. " .dataframe thead th {\n",
  314. " text-align: right;\n",
  315. " }\n",
  316. "</style>\n",
  317. "<table border=\"1\" class=\"dataframe\">\n",
  318. " <thead>\n",
  319. " <tr style=\"text-align: right;\">\n",
  320. " <th></th>\n",
  321. " <th>Importance</th>\n",
  322. " </tr>\n",
  323. " </thead>\n",
  324. " <tbody>\n",
  325. " <tr>\n",
  326. " <th>marital-status</th>\n",
  327. " <td>0.364811</td>\n",
  328. " </tr>\n",
  329. " <tr>\n",
  330. " <th>capital-gain</th>\n",
  331. " <td>0.213057</td>\n",
  332. " </tr>\n",
  333. " <tr>\n",
  334. " <th>occupation</th>\n",
  335. " <td>0.094029</td>\n",
  336. " </tr>\n",
  337. " <tr>\n",
  338. " <th>capital-loss</th>\n",
  339. " <td>0.072038</td>\n",
  340. " </tr>\n",
  341. " <tr>\n",
  342. " <th>age</th>\n",
  343. " <td>0.063056</td>\n",
  344. " </tr>\n",
  345. " <tr>\n",
  346. " <th>education-num</th>\n",
  347. " <td>0.051340</td>\n",
  348. " </tr>\n",
  349. " <tr>\n",
  350. " <th>hours-per-week</th>\n",
  351. " <td>0.041892</td>\n",
  352. " </tr>\n",
  353. " <tr>\n",
  354. " <th>education</th>\n",
  355. " <td>0.032389</td>\n",
  356. " </tr>\n",
  357. " <tr>\n",
  358. " <th>fnlwgt</th>\n",
  359. " <td>0.024566</td>\n",
  360. " </tr>\n",
  361. " <tr>\n",
  362. " <th>workclass</th>\n",
  363. " <td>0.016797</td>\n",
  364. " </tr>\n",
  365. " <tr>\n",
  366. " <th>native-country</th>\n",
  367. " <td>0.011466</td>\n",
  368. " </tr>\n",
  369. " <tr>\n",
  370. " <th>relationship</th>\n",
  371. " <td>0.010971</td>\n",
  372. " </tr>\n",
  373. " <tr>\n",
  374. " <th>race</th>\n",
  375. " <td>0.001832</td>\n",
  376. " </tr>\n",
  377. " <tr>\n",
  378. " <th>sex</th>\n",
  379. " <td>0.001756</td>\n",
  380. " </tr>\n",
  381. " </tbody>\n",
  382. "</table>\n",
  383. "</div>"
  384. ],
  385. "text/plain": [
  386. " Importance\n",
  387. "marital-status 0.364811\n",
  388. "capital-gain 0.213057\n",
  389. "occupation 0.094029\n",
  390. "capital-loss 0.072038\n",
  391. "age 0.063056\n",
  392. "education-num 0.051340\n",
  393. "hours-per-week 0.041892\n",
  394. "education 0.032389\n",
  395. "fnlwgt 0.024566\n",
  396. "workclass 0.016797\n",
  397. "native-country 0.011466\n",
  398. "relationship 0.010971\n",
  399. "race 0.001832\n",
  400. "sex 0.001756"
  401. ]
  402. },
  403. "execution_count": 70,
  404. "metadata": {},
  405. "output_type": "execute_result"
  406. }
  407. ],
  408. "source": [
  # Reduce the importances to the original columns by simple adding up
assert len(columns) == len(tuned_gbm.best_estimator_.feature_importances_)
# Pre-encoding feature columns: every column of df except the target.
original_columns = list(df.columns)
original_columns.remove(target_column)
reduced_importances_1 = reduce_importances(
    dummy_importances=dummy_importances_1,
    original_columns=original_columns,
)
sort_print_importances(importances=reduced_importances_1)
  415. ]
  416. },
  417. {
  418. "cell_type": "code",
  419. "execution_count": null,
  420. "metadata": {
  421. "scrolled": false
  422. },
  423. "outputs": [],
  424. "source": [
  # Bar chart of the importances summed back onto the original columns.
plot_importances(importances=[reduced_importances_1], names=["Reduced Importances"])
  426. ]
  427. }
  428. ],
  429. "metadata": {
  430. "kernelspec": {
  431. "display_name": "Python 3",
  432. "language": "python",
  433. "name": "python3"
  434. },
  435. "language_info": {
  436. "codemirror_mode": {
  437. "name": "ipython",
  438. "version": 3
  439. },
  440. "file_extension": ".py",
  441. "mimetype": "text/x-python",
  442. "name": "python",
  443. "nbconvert_exporter": "python",
  444. "pygments_lexer": "ipython3",
  445. "version": "3.6.5"
  446. }
  447. },
  448. "nbformat": 4,
  449. "nbformat_minor": 2
  450. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement