Guest User

Untitled

a guest
Jan 16th, 2019
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.70 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import numpy as np\n",
  10. "\n",
  11. "from sklearn.datasets import load_iris\n",
  12. "from sklearn.model_selection import StratifiedShuffleSplit"
  13. ]
  14. },
  15. {
  16. "cell_type": "code",
  17. "execution_count": 2,
  18. "metadata": {},
  19. "outputs": [],
  20. "source": [
  21. "random_state = 42"
  22. ]
  23. },
  24. {
  25. "cell_type": "code",
  26. "execution_count": 3,
  27. "metadata": {},
  28. "outputs": [],
  29. "source": [
  30. "X, y = load_iris(return_X_y=True)"
  31. ]
  32. },
  33. {
  34. "cell_type": "code",
  35. "execution_count": 4,
  36. "metadata": {},
  37. "outputs": [
  38. {
  39. "data": {
  40. "text/plain": [
  41. "(150, 4)"
  42. ]
  43. },
  44. "execution_count": 4,
  45. "metadata": {},
  46. "output_type": "execute_result"
  47. }
  48. ],
  49. "source": [
  50. "X.shape"
  51. ]
  52. },
  53. {
  54. "cell_type": "code",
  55. "execution_count": 5,
  56. "metadata": {},
  57. "outputs": [
  58. {
  59. "data": {
  60. "text/plain": [
  61. "StratifiedShuffleSplit(n_splits=1, random_state=0, test_size=0.4,\n",
  62. " train_size=None)"
  63. ]
  64. },
  65. "execution_count": 5,
  66. "metadata": {},
  67. "output_type": "execute_result"
  68. }
  69. ],
  70. "source": [
  71. "sss = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=0)\n",
  72. "sss"
  73. ]
  74. },
  75. {
  76. "cell_type": "code",
  77. "execution_count": 6,
  78. "metadata": {},
  79. "outputs": [
  80. {
  81. "data": {
  82. "text/plain": [
  83. "(array([128, 101, 52, 28, 94, 38, 76, 141, 148, 75, 126, 87, 69,\n",
  84. " 45, 8, 115, 4, 127, 79, 84, 108, 82, 140, 59, 131, 10,\n",
  85. " 22, 97, 13, 95, 63, 135, 33, 15, 56, 105, 16, 27, 32,\n",
  86. " 78, 104, 26, 92, 60, 41, 58, 119, 93, 112, 11, 146, 72,\n",
  87. " 83, 116, 62, 91, 120, 48, 57, 7, 133, 106, 31, 132, 80,\n",
  88. " 73, 66, 111, 107, 20, 30, 25, 42, 14, 70, 138, 35, 137,\n",
  89. " 2, 18, 124, 122, 74, 143, 43, 117, 29, 125, 96, 34]),\n",
  90. " array([121, 109, 36, 144, 1, 9, 39, 147, 98, 89, 23, 149, 118,\n",
  91. " 44, 61, 100, 65, 37, 113, 142, 64, 24, 145, 46, 99, 53,\n",
  92. " 102, 19, 54, 139, 40, 130, 71, 86, 110, 47, 136, 51, 81,\n",
  93. " 123, 50, 49, 68, 103, 129, 85, 88, 0, 17, 6, 3, 134,\n",
  94. " 90, 21, 5, 55, 114, 12, 67, 77]))"
  95. ]
  96. },
  97. "execution_count": 6,
  98. "metadata": {},
  99. "output_type": "execute_result"
  100. }
  101. ],
  102. "source": [
  103. "train_index, test_index = next(sss.split(X, y))\n",
  104. "train_index, test_index"
  105. ]
  106. },
  107. {
  108. "cell_type": "code",
  109. "execution_count": 7,
  110. "metadata": {},
  111. "outputs": [
  112. {
  113. "data": {
  114. "text/plain": [
  115. "(90, 4)"
  116. ]
  117. },
  118. "execution_count": 7,
  119. "metadata": {},
  120. "output_type": "execute_result"
  121. }
  122. ],
  123. "source": [
  124. "X[train_index].shape"
  125. ]
  126. },
  127. {
  128. "cell_type": "code",
  129. "execution_count": 8,
  130. "metadata": {},
  131. "outputs": [
  132. {
  133. "data": {
  134. "text/plain": [
  135. "(60, 4)"
  136. ]
  137. },
  138. "execution_count": 8,
  139. "metadata": {},
  140. "output_type": "execute_result"
  141. }
  142. ],
  143. "source": [
  144. "X[test_index].shape"
  145. ]
  146. },
  147. {
  148. "cell_type": "code",
  149. "execution_count": 9,
  150. "metadata": {},
  151. "outputs": [
  152. {
  153. "data": {
  154. "text/plain": [
  155. "(array([0, 1, 2]), array([20, 20, 20]))"
  156. ]
  157. },
  158. "execution_count": 9,
  159. "metadata": {},
  160. "output_type": "execute_result"
  161. }
  162. ],
  163. "source": [
  164. "np.unique(y[test_index], return_counts=True)"
  165. ]
  166. },
  167. {
  168. "cell_type": "code",
  169. "execution_count": 10,
  170. "metadata": {},
  171. "outputs": [
  172. {
  173. "data": {
  174. "text/plain": [
  175. "(array([0, 1, 2]), array([30, 30, 30]))"
  176. ]
  177. },
  178. "execution_count": 10,
  179. "metadata": {},
  180. "output_type": "execute_result"
  181. }
  182. ],
  183. "source": [
  184. "np.unique(y[train_index], return_counts=True)"
  185. ]
  186. },
  187. {
  188. "cell_type": "code",
  189. "execution_count": 11,
  190. "metadata": {},
  191. "outputs": [],
  192. "source": [
  193. "def recursive_stratified_shuffle_split(sizes, random_state=None):\n",
  194. "    \"\"\"Return ``split(X, y)`` yielding len(sizes) + 1 stratified index arrays.\n",
  195. "\n",
  196. "    ``sizes[0]`` is the ``test_size`` of a StratifiedShuffleSplit of (X, y); later\n",
  197. "    sizes re-split the previous held-out part. Yields positional indices into X/y:\n",
  198. "    each kept part, then the last held-out part. Use an int ``random_state`` so\n",
  199. "    repeated ``split`` calls return the same partition. Empty ``sizes`` -> ValueError.\n",
  200. "    \"\"\"\n",
  201. "    sizes = list(sizes)\n",
  202. "    if not sizes:\n",
  203. "        raise ValueError(\"sizes must contain at least one split size\")\n",
  204. "    head, *tail = sizes\n",
  205. "    sss = StratifiedShuffleSplit(n_splits=1, test_size=head, random_state=random_state)\n",
  206. "    def split(X, y):\n",
  207. "        keep_index, held_out_index = next(sss.split(X, y))\n",
  208. "        yield keep_index\n",
  209. "        if tail:\n",
  210. "            # Split the held-out rows and map their local indices back to X/y.\n",
  211. "            split_tail = recursive_stratified_shuffle_split(sizes=tail, random_state=random_state)\n",
  212. "            for ind in split_tail(X[held_out_index], y[held_out_index]):\n",
  213. "                yield held_out_index[ind]\n",
  214. "        else:\n",
  215. "            yield held_out_index\n",
  216. "    return split"
  217. ]
  218. },
  219. {
  220. "cell_type": "code",
  221. "execution_count": 12,
  222. "metadata": {},
  223. "outputs": [],
  224. "source": [
  225. "# hold out 80 samples (70 kept), then split those 80 into 60/20; seeded so every split() call matches\n",
  226. "split = recursive_stratified_shuffle_split(sizes=[80, 20], random_state=random_state)"
  227. ]
  228. },
  229. {
  230. "cell_type": "code",
  231. "execution_count": 13,
  232. "metadata": {},
  233. "outputs": [
  234. {
  235. "data": {
  236. "text/plain": [
  237. "[array([139, 138, 7, 34, 109, 128, 24, 132, 76, 96, 22, 101, 83,\n",
  238. " 140, 146, 46, 67, 8, 61, 44, 88, 85, 1, 9, 35, 74,\n",
  239. " 145, 0, 65, 6, 57, 136, 73, 4, 54, 43, 69, 55, 75,\n",
  240. " 131, 99, 60, 18, 79, 125, 5, 111, 63, 12, 149, 13, 89,\n",
  241. " 106, 25, 122, 113, 119, 49, 80, 11, 59, 52, 115, 142, 38,\n",
  242. " 45, 20, 118, 130, 123]),\n",
  243. " array([ 64, 40, 15, 114, 36, 124, 50, 2, 107, 53, 141, 30, 87,\n",
  244. " 62, 17, 39, 134, 105, 19, 70, 66, 42, 129, 116, 86, 37,\n",
  245. " 21, 94, 72, 41, 71, 84, 68, 110, 148, 82, 98, 137, 31,\n",
  246. " 48, 47, 102, 127, 23, 133, 27, 51, 95, 121, 77, 120, 32,\n",
  247. " 104, 16, 58, 147, 33, 103, 92, 135]),\n",
  248. " array([ 56, 14, 112, 143, 93, 26, 108, 78, 144, 100, 117, 29, 126,\n",
  249. " 97, 28, 10, 81, 3, 90, 91])]"
  250. ]
  251. },
  252. "execution_count": 13,
  253. "metadata": {},
  254. "output_type": "execute_result"
  255. }
  256. ],
  257. "source": [
  258. "list(split(X, y))"
  259. ]
  260. },
  261. {
  262. "cell_type": "code",
  263. "execution_count": 14,
  264. "metadata": {},
  265. "outputs": [
  266. {
  267. "data": {
  268. "text/plain": [
  269. "[(70, 4), (60, 4), (20, 4)]"
  270. ]
  271. },
  272. "execution_count": 14,
  273. "metadata": {},
  274. "output_type": "execute_result"
  275. }
  276. ],
  277. "source": [
  278. "[X[index].shape for index in split(X, y)]"
  279. ]
  280. },
  281. {
  282. "cell_type": "code",
  283. "execution_count": 15,
  284. "metadata": {},
  285. "outputs": [
  286. {
  287. "data": {
  288. "text/plain": [
  289. "[(array([0, 1, 2]), array([24, 23, 23])),\n",
  290. " (array([0, 1, 2]), array([20, 20, 20])),\n",
  291. " (array([0, 1, 2]), array([6, 7, 7]))]"
  292. ]
  293. },
  294. "execution_count": 15,
  295. "metadata": {},
  296. "output_type": "execute_result"
  297. }
  298. ],
  299. "source": [
  300. "[np.unique(y[index], return_counts=True) for index in split(X, y)]"
  301. ]
  302. },
  303. {
  304. "cell_type": "code",
  305. "execution_count": 16,
  306. "metadata": {},
  307. "outputs": [],
  308. "source": [
  309. "# hold out 40% (60% kept), then split the held-out 40% in half; seeded so every split() call matches\n",
  310. "split = recursive_stratified_shuffle_split(sizes=[0.4, 0.5], random_state=random_state)"
  311. ]
  312. },
  313. {
  314. "cell_type": "code",
  315. "execution_count": 17,
  316. "metadata": {},
  317. "outputs": [
  318. {
  319. "data": {
  320. "text/plain": [
  321. "[(90, 4), (30, 4), (30, 4)]"
  322. ]
  323. },
  324. "execution_count": 17,
  325. "metadata": {},
  326. "output_type": "execute_result"
  327. }
  328. ],
  329. "source": [
  330. "[X[index].shape for index in split(X, y)]"
  331. ]
  332. },
  333. {
  334. "cell_type": "code",
  335. "execution_count": 18,
  336. "metadata": {},
  337. "outputs": [
  338. {
  339. "data": {
  340. "text/plain": [
  341. "[(array([0, 1, 2]), array([30, 30, 30])),\n",
  342. " (array([0, 1, 2]), array([10, 10, 10])),\n",
  343. " (array([0, 1, 2]), array([10, 10, 10]))]"
  344. ]
  345. },
  346. "execution_count": 18,
  347. "metadata": {},
  348. "output_type": "execute_result"
  349. }
  350. ],
  351. "source": [
  352. "[np.unique(y[index], return_counts=True) for index in split(X, y)]"
  353. ]
  354. }
  355. ],
  356. "metadata": {
  357. "kernelspec": {
  358. "display_name": "Python 3",
  359. "language": "python",
  360. "name": "python3"
  361. },
  362. "language_info": {
  363. "codemirror_mode": {
  364. "name": "ipython",
  365. "version": 3
  366. },
  367. "file_extension": ".py",
  368. "mimetype": "text/x-python",
  369. "name": "python",
  370. "nbconvert_exporter": "python",
  371. "pygments_lexer": "ipython3",
  372. "version": "3.5.2"
  373. }
  374. },
  375. "nbformat": 4,
  376. "nbformat_minor": 2
  377. }
Add Comment
Please, Sign In to add comment