Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "from sklearn.datasets import load_iris\n",
- "from sklearn.model_selection import StratifiedShuffleSplit"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Global seed: pass it as `random_state=` to each splitter so re-runs reproduce the same partitions.\n",
- "random_state = 42"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Iris as plain arrays: feature matrix X and integer class labels y (classes 0, 1, 2).\n",
- "X, y = load_iris(return_X_y=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(150, 4)"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# (n_samples, n_features)\n",
- "X.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "StratifiedShuffleSplit(n_splits=1, random_state=0, test_size=0.4,\n",
- " train_size=None)"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# A single stratified split that holds out 40% of the samples as the test part.\n",
- "# NOTE(review): seed 0 is hard-coded (the saved outputs were produced with it),\n",
- "# whereas the `random_state` config cell defines 42.\n",
- "sss = StratifiedShuffleSplit(\n",
- "    n_splits=1,\n",
- "    test_size=0.4,\n",
- "    random_state=0\n",
- ")\n",
- "sss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(array([128, 101, 52, 28, 94, 38, 76, 141, 148, 75, 126, 87, 69,\n",
- " 45, 8, 115, 4, 127, 79, 84, 108, 82, 140, 59, 131, 10,\n",
- " 22, 97, 13, 95, 63, 135, 33, 15, 56, 105, 16, 27, 32,\n",
- " 78, 104, 26, 92, 60, 41, 58, 119, 93, 112, 11, 146, 72,\n",
- " 83, 116, 62, 91, 120, 48, 57, 7, 133, 106, 31, 132, 80,\n",
- " 73, 66, 111, 107, 20, 30, 25, 42, 14, 70, 138, 35, 137,\n",
- " 2, 18, 124, 122, 74, 143, 43, 117, 29, 125, 96, 34]),\n",
- " array([121, 109, 36, 144, 1, 9, 39, 147, 98, 89, 23, 149, 118,\n",
- " 44, 61, 100, 65, 37, 113, 142, 64, 24, 145, 46, 99, 53,\n",
- " 102, 19, 54, 139, 40, 130, 71, 86, 110, 47, 136, 51, 81,\n",
- " 123, 50, 49, 68, 103, 129, 85, 88, 0, 17, 6, 3, 134,\n",
- " 90, 21, 5, 55, 114, 12, 67, 77]))"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# n_splits=1, so the splitter yields exactly one (train, test) pair of row indices.\n",
- "train_index, test_index = next(sss.split(X, y))\n",
- "train_index, test_index"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(90, 4)"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# train part: the 60% of rows that were not held out\n",
- "X[train_index].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(60, 4)"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# test part: test_size=0.4 of the rows\n",
- "X[test_index].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(array([0, 1, 2]), array([20, 20, 20]))"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# class counts in the test part; stratification keeps them proportional to the full data\n",
- "np.unique(y[test_index], return_counts=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(array([0, 1, 2]), array([30, 30, 30]))"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# class counts in the train part; stratification keeps them proportional to the full data\n",
- "np.unique(y[train_index], return_counts=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "def recursive_stratified_shuffle_split(sizes, random_state=None):\n",
- "    \"\"\"Build a splitter that cuts a dataset into several stratified parts.\n",
- "\n",
- "    The first entry of ``sizes`` is the ``test_size`` of a single\n",
- "    ``StratifiedShuffleSplit``; the held-out part is then split again with\n",
- "    the next entry, and so on down the list.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    sizes : iterable of int or float\n",
- "        Successive ``test_size`` values. Each one applies to the part being\n",
- "        split at that level (an absolute count or a fraction of that part).\n",
- "        Must contain at least one entry.\n",
- "    random_state : int, RandomState instance or None, default None\n",
- "        Forwarded unchanged to every ``StratifiedShuffleSplit`` in the chain.\n",
- "        With None, each call of the returned ``split`` draws a new partition.\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    split : callable\n",
- "        Generator function ``split(X, y)`` yielding ``len(sizes) + 1`` arrays\n",
- "        of row indices into ``X``: the rows kept at each level, followed by\n",
- "        the rows held out at the last level.\n",
- "\n",
- "    Raises\n",
- "    ------\n",
- "    ValueError\n",
- "        If ``sizes`` is empty.\n",
- "    \"\"\"\n",
- "    sizes = list(sizes)\n",
- "    if not sizes:\n",
- "        raise ValueError(\"sizes must contain at least one split size\")\n",
- "\n",
- "    head, *tail = sizes\n",
- "    sss = StratifiedShuffleSplit(n_splits=1, test_size=head, random_state=random_state)\n",
- "\n",
- "    # Build the rest of the chain once here, not on every call to split().\n",
- "    split_tail = (recursive_stratified_shuffle_split(sizes=tail, random_state=random_state)\n",
- "                  if tail else None)\n",
- "\n",
- "    def split(X, y):\n",
- "        y = np.asarray(y)\n",
- "        a_index, b_index = next(sss.split(X, y))\n",
- "\n",
- "        yield a_index\n",
- "\n",
- "        if split_tail is None:\n",
- "            yield b_index\n",
- "        else:\n",
- "            # Stratification only depends on y, so a zeros placeholder of the\n",
- "            # right length stands in for X[b_index]; this keeps the recursion\n",
- "            # working when X is a list or a DataFrame.\n",
- "            for ind in split_tail(np.zeros(len(b_index)), y[b_index]):\n",
- "                # Map positions within the held-out part back to rows of X.\n",
- "                yield b_index[ind]\n",
- "\n",
- "    return split"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Hold out 80 samples (keeping the other 70), then split those 80 into 60 + 20.\n",
- "# A fixed random_state makes every split(X, y) call below return the same partition.\n",
- "split = recursive_stratified_shuffle_split(sizes=[80, 20], random_state=random_state)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[array([139, 138, 7, 34, 109, 128, 24, 132, 76, 96, 22, 101, 83,\n",
- " 140, 146, 46, 67, 8, 61, 44, 88, 85, 1, 9, 35, 74,\n",
- " 145, 0, 65, 6, 57, 136, 73, 4, 54, 43, 69, 55, 75,\n",
- " 131, 99, 60, 18, 79, 125, 5, 111, 63, 12, 149, 13, 89,\n",
- " 106, 25, 122, 113, 119, 49, 80, 11, 59, 52, 115, 142, 38,\n",
- " 45, 20, 118, 130, 123]),\n",
- " array([ 64, 40, 15, 114, 36, 124, 50, 2, 107, 53, 141, 30, 87,\n",
- " 62, 17, 39, 134, 105, 19, 70, 66, 42, 129, 116, 86, 37,\n",
- " 21, 94, 72, 41, 71, 84, 68, 110, 148, 82, 98, 137, 31,\n",
- " 48, 47, 102, 127, 23, 133, 27, 51, 95, 121, 77, 120, 32,\n",
- " 104, 16, 58, 147, 33, 103, 92, 135]),\n",
- " array([ 56, 14, 112, 143, 93, 26, 108, 78, 144, 100, 117, 29, 126,\n",
- " 97, 28, 10, 81, 3, 90, 91])]"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# One index array into X per part: rows kept at the first level, then the 80-sample part split into 60 + 20.\n",
- "list(split(X, y))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[(70, 4), (60, 4), (20, 4)]"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# NOTE: split(X, y) draws a new partition on every call unless `split` was built with a fixed\n",
- "# random_state, so successive cells may not be looking at the same parts.\n",
- "[X[index].shape for index in split(X, y)]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[(array([0, 1, 2]), array([24, 23, 23])),\n",
- " (array([0, 1, 2]), array([20, 20, 20])),\n",
- " (array([0, 1, 2]), array([6, 7, 7]))]"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Class counts per part; each part stays (approximately) stratified.\n",
- "[np.unique(y[index], return_counts=True) for index in split(X, y)]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Hold out 40% (keeping 60%), then split the held-out 40% in half.\n",
- "# Seeded so the shape and class-count cells below describe the same partition.\n",
- "split = recursive_stratified_shuffle_split(sizes=[0.4, 0.5], random_state=random_state)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[(90, 4), (30, 4), (30, 4)]"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Part sizes: 60% of the rows kept, then the held-out 40% split in half.\n",
- "[X[index].shape for index in split(X, y)]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[(array([0, 1, 2]), array([30, 30, 30])),\n",
- " (array([0, 1, 2]), array([10, 10, 10])),\n",
- " (array([0, 1, 2]), array([10, 10, 10]))]"
- ]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Class counts per part.\n",
- "[np.unique(y[index], return_counts=True) for index in split(X, y)]"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment