SHARE
TWEET

Untitled

a guest Mar 22nd, 2019 71 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. {
  2.  "cells": [
  3.   {
  4.    "cell_type": "code",
  5.    "execution_count": 1,
  6.    "metadata": {},
  7.    "outputs": [],
  8.    "source": [
  9.     "%matplotlib inline"
  10.    ]
  11.   },
  12.   {
  13.    "cell_type": "code",
  14.    "execution_count": 2,
  15.    "metadata": {},
  16.    "outputs": [],
  17.    "source": [
  18.     "import numpy as np\n",
  19.     "\n",
  20.     "import matplotlib.pyplot as plt\n",
  21.     "\n",
  22.     "from sklearn.model_selection import StratifiedShuffleSplit\n",
  23.     "from sklearn.datasets import load_iris\n",
  24.     "from matplotlib.patches import Patch"
  25.    ]
  26.   },
  27.   {
  28.    "cell_type": "code",
  29.    "execution_count": 3,
  30.    "metadata": {},
  31.    "outputs": [],
  32.    "source": [
  33.     "plt.style.use('seaborn-colorblind')"
  34.    ]
  35.   },
  36.   {
  37.    "cell_type": "code",
  38.    "execution_count": 4,
  39.    "metadata": {},
  40.    "outputs": [],
  41.    "source": [
  42.     "random_state = 42\n",
  43.     "cmap_data = plt.cm.Paired\n",
  44.     "cmap_cv = plt.cm.coolwarm"
  45.    ]
  46.   },
  47.   {
  48.    "cell_type": "code",
  49.    "execution_count": 5,
  50.    "metadata": {},
  51.    "outputs": [],
  52.    "source": [
  53.     "X, y = load_iris(return_X_y=True)"
  54.    ]
  55.   },
  56.   {
  57.    "cell_type": "code",
  58.    "execution_count": 6,
  59.    "metadata": {},
  60.    "outputs": [
  61.     {
  62.      "data": {
  63.       "text/plain": [
  64.        "(150, 4)"
  65.       ]
  66.      },
  67.      "execution_count": 6,
  68.      "metadata": {},
  69.      "output_type": "execute_result"
  70.     }
  71.    ],
  72.    "source": [
  73.     "X.shape"
  74.    ]
  75.   },
  76.   {
  77.    "cell_type": "code",
  78.    "execution_count": 7,
  79.    "metadata": {},
  80.    "outputs": [
  81.     {
  82.      "data": {
  83.       "text/plain": [
  84.        "(150,)"
  85.       ]
  86.      },
  87.      "execution_count": 7,
  88.      "metadata": {},
  89.      "output_type": "execute_result"
  90.     }
  91.    ],
  92.    "source": [
  93.     "y.shape"
  94.    ]
  95.   },
  96.   {
  97.    "cell_type": "code",
  98.    "execution_count": 8,
  99.    "metadata": {},
  100.    "outputs": [],
  101.    "source": [
  102.     "def plot_cv_indices(X, y, cv, ax=None):\n",
  103.     "    \"\"\"Create a sample plot for indices of a cross-validation object.\"\"\"\n",
  104.     "\n",
  105.     "    num_samples = len(X)\n",
  106.     "\n",
  107.     "    # Generate the training/testing visualizations for each CV split\n",
  108.     "    for i, partition_inds in enumerate(cv.split(X, y)):\n",
  109.     "\n",
  110.     "        # Fill in indices with the training/test groups\n",
  111.     "        indices = np.empty_like(y)\n",
  112.     "\n",
  113.     "        for j, ind in enumerate(partition_inds):\n",
  114.     "\n",
  115.     "            indices[ind] = j\n",
  116.     "\n",
  117.     "            # Visualize the results\n",
  118.     "            ax.scatter(range(len(indices)), [i + 0.5] * len(indices),\n",
  119.     "                       c=indices, marker='_', lw=10, cmap=cmap_cv)\n",
  120.     "#                        vmin=-0.2, vmax=1.2)\n",
  121.     "\n",
  122.     "    # Plot the data classes and groups at the end\n",
  123.     "    ax.scatter(range(num_samples), [i + 1.5] * num_samples,\n",
  124.     "               c=y, marker='_', lw=10, cmap=cmap_data)\n",
  125.     "\n",
  126.     "    n_splits = i + 1\n",
  127.     "\n",
  128.     "    # Formatting\n",
  129.     "\n",
  130.     "    yticklabels = list(range(n_splits))\n",
  131.     "    yticklabels.append('class')\n",
  132.     "\n",
  133.     "    ax.set(yticks=np.arange(n_splits+1) + 0.5, yticklabels=yticklabels,\n",
  134.     "           xlabel='Sample index', ylabel=\"CV iteration\",\n",
  135.     "           ylim=[n_splits + 1.2, -0.2], xlim=[0, num_samples])\n",
  136.     "    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)\n",
  137.     "\n",
  138.     "    return ax"
  139.    ]
  140.   },
  141.   {
  142.    "cell_type": "code",
  143.    "execution_count": 9,
  144.    "metadata": {},
  145.    "outputs": [
  146.     {
  147.      "data": {
  148.       "text/plain": [
  149.        "StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2,\n",
  150.        "            train_size=None)"
  151.       ]
  152.      },
  153.      "execution_count": 9,
  154.      "metadata": {},
  155.      "output_type": "execute_result"
  156.     }
  157.    ],
  158.    "source": [
  159.     "sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=random_state)\n",
  160.     "sss"
  161.    ]
  162.   },
  163.   {
  164.    "cell_type": "code",
  165.    "execution_count": 10,
  166.    "metadata": {},
  167.    "outputs": [
  168.     {
  169.      "data": {
  170.       "text/plain": [
  171.        "1"
  172.       ]
  173.      },
  174.      "execution_count": 10,
  175.      "metadata": {},
  176.      "output_type": "execute_result"
  177.     }
  178.    ],
  179.    "source": [
  180.     "sss.get_n_splits(X, y)"
  181.    ]
  182.   },
  183.   {
  184.    "cell_type": "code",
  185.    "execution_count": 11,
  186.    "metadata": {},
  187.    "outputs": [
  188.     {
  189.      "data": {
  190.       "text/plain": [
  191.        "(array([  8, 106,  76,   9,  89, 146,  94, 133, 135, 117, 105,  78,  60,\n",
  192.        "         67,  92,  29,  19, 108, 137,  31,  72,  46, 131,  25,  81, 103,\n",
  193.        "        111,  26, 121,  40,  21,  71,  91,  15, 139, 148,  79,  59, 144,\n",
  194.        "         70,   6,  50,  16, 130,   1,  17, 101,  35,  41,  12,  45,  64,\n",
  195.        "        100,  55,   0, 149,  99, 136,  47, 142,  36,  53, 113,  24,  83,\n",
  196.        "         90, 122,  66,  54, 115,  39,  23,   4, 119,  82, 129,  80, 145,\n",
  197.        "        123,  85,  34, 114,  68,  43, 120,  32, 109,  98,  86,  30,  97,\n",
  198.        "        110,  44,  13, 124, 118, 112,  87, 126,   5, 143,  96, 125, 102,\n",
  199.        "         48,  74,  73,  95,  88,  65,  27, 128,  62,  61,  11,  37,   2,\n",
  200.        "         33,  52,   3]),\n",
  201.        " array([ 38, 127,  57,  93,  42,  56,  22,  20, 147,  84, 107, 141, 104,\n",
  202.        "         51,   7,  49,  14,  69,  63, 138,  10, 140,  58, 134, 132,  77,\n",
  203.        "         75,  18, 116,  28]))"
  204.       ]
  205.      },
  206.      "execution_count": 11,
  207.      "metadata": {},
  208.      "output_type": "execute_result"
  209.     }
  210.    ],
  211.    "source": [
  212.     "train_index, test_index = next(sss.split(X, y))\n",
  213.     "train_index, test_index"
  214.    ]
  215.   },
  216.   {
  217.    "cell_type": "code",
  218.    "execution_count": 12,
  219.    "metadata": {},
  220.    "outputs": [
  221.     {
  222.      "data": {
  223.       "image/png": "\n",
  224.       "text/plain": [
  225.        "<Figure size 720x72 with 1 Axes>"
  226.       ]
  227.      },
  228.      "metadata": {
  229.       "needs_background": "light"
  230.      },
  231.      "output_type": "display_data"
  232.     }
  233.    ],
  234.    "source": [
  235.     "fig, ax = plt.subplots(figsize=(10, 1))\n",
  236.     "\n",
  237.     "plot_cv_indices(X, y, sss, ax=ax)\n",
  238.     "\n",
  239.     "plt.show()"
  240.    ]
  241.   },
  242.   {
  243.    "cell_type": "markdown",
  244.    "metadata": {},
  245.    "source": [
  246.     "#### Train-test Proportion\n",
  247.     "\n",
  248.     "Verify that the requested train-test proportion has been satisfied:"
  249.    ]
  250.   },
  251.   {
  252.    "cell_type": "code",
  253.    "execution_count": 13,
  254.    "metadata": {},
  255.    "outputs": [
  256.     {
  257.      "data": {
  258.       "text/plain": [
  259.        "(120, 4)"
  260.       ]
  261.      },
  262.      "execution_count": 13,
  263.      "metadata": {},
  264.      "output_type": "execute_result"
  265.     }
  266.    ],
  267.    "source": [
  268.     "X[train_index].shape"
  269.    ]
  270.   },
  271.   {
  272.    "cell_type": "code",
  273.    "execution_count": 14,
  274.    "metadata": {},
  275.    "outputs": [
  276.     {
  277.      "data": {
  278.       "text/plain": [
  279.        "(30, 4)"
  280.       ]
  281.      },
  282.      "execution_count": 14,
  283.      "metadata": {},
  284.      "output_type": "execute_result"
  285.     }
  286.    ],
  287.    "source": [
  288.     "X[test_index].shape"
  289.    ]
  290.   },
  291.   {
  292.    "cell_type": "code",
  293.    "execution_count": 15,
  294.    "metadata": {},
  295.    "outputs": [
  296.     {
  297.      "data": {
  298.       "text/plain": [
  299.        "0.2"
  300.       ]
  301.      },
  302.      "execution_count": 15,
  303.      "metadata": {},
  304.      "output_type": "execute_result"
  305.     }
  306.    ],
  307.    "source": [
  308.     "len(test_index) / len(np.hstack((train_index, test_index)))"
  309.    ]
  310.   },
  311.   {
  312.    "cell_type": "markdown",
  313.    "metadata": {},
  314.    "source": [
  315.     "#### Stratified\n",
  316.     "\n",
  317.     "Verify that each partition preserves the class proportions of the full dataset, i.e. that the split is stratified."
  318.    ]
  319.   },
  320.   {
  321.    "cell_type": "code",
  322.    "execution_count": 16,
  323.    "metadata": {},
  324.    "outputs": [
  325.     {
  326.      "data": {
  327.       "text/plain": [
  328.        "(array([0, 1, 2]), array([40, 40, 40]))"
  329.       ]
  330.      },
  331.      "execution_count": 16,
  332.      "metadata": {},
  333.      "output_type": "execute_result"
  334.     }
  335.    ],
  336.    "source": [
  337.     "np.unique(y[train_index], return_counts=True)"
  338.    ]
  339.   },
  340.   {
  341.    "cell_type": "code",
  342.    "execution_count": 17,
  343.    "metadata": {
  344.     "scrolled": true
  345.    },
  346.    "outputs": [
  347.     {
  348.      "data": {
  349.       "text/plain": [
  350.        "(array([0, 1, 2]), array([10, 10, 10]))"
  351.       ]
  352.      },
  353.      "execution_count": 17,
  354.      "metadata": {},
  355.      "output_type": "execute_result"
  356.     }
  357.    ],
  358.    "source": [
  359.     "np.unique(y[test_index], return_counts=True)"
  360.    ]
  361.   },
  362.   {
  363.    "cell_type": "markdown",
  364.    "metadata": {},
  365.    "source": [
  366.     "### Visualization\n",
  367.     "\n",
  368.     "Visualize for a large number of CV splits (10)."
  369.    ]
  370.   },
  371.   {
  372.    "cell_type": "code",
  373.    "execution_count": 18,
  374.    "metadata": {},
  375.    "outputs": [
  376.     {
  377.      "data": {
  378.       "text/plain": [
  379.        "StratifiedShuffleSplit(n_splits=10, random_state=42, test_size=0.1,\n",
  380.        "            train_size=None)"
  381.       ]
  382.      },
  383.      "execution_count": 18,
  384.      "metadata": {},
  385.      "output_type": "execute_result"
  386.     }
  387.    ],
  388.    "source": [
  389.     "sss = StratifiedShuffleSplit(test_size=0.1, random_state=random_state)\n",
  390.     "sss"
  391.    ]
  392.   },
  393.   {
  394.    "cell_type": "code",
  395.    "execution_count": 19,
  396.    "metadata": {},
  397.    "outputs": [
  398.     {
  399.      "data": {
  400.       "image/png": "\n",
  401.       "text/plain": [
  402.        "<Figure size 720x360 with 1 Axes>"
  403.       ]
  404.      },
  405.      "metadata": {
  406.       "needs_background": "light"
  407.      },
  408.      "output_type": "display_data"
  409.     }
  410.    ],
  411.    "source": [
  412.     "fig, ax = plt.subplots(figsize=(10, 5))\n",
  413.     "\n",
  414.     "plot_cv_indices(X, y, sss, ax=ax)\n",
  415.     "\n",
  416.     "plt.show()"
  417.    ]
  418.   },
  419.   {
  420.    "cell_type": "markdown",
  421.    "metadata": {},
  422.    "source": [
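  423.     "## Generalizing to Multiple Partitions\n",
  424.     "\n",
  425.     "`RecursiveStratifiedShuffleSplit` first splits off a held-out remainder of relative size `partition_sizes[0]` with a regular `StratifiedShuffleSplit`, then recursively splits that remainder using the remaining `partition_sizes`. Each later size is therefore a fraction of what is left, not of the whole dataset: `partition_sizes=[0.2, 0.5]` yields an 80/10/10 train/validation/test split."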
  424.    ]
  425.   },
  426.   {
  427.    "cell_type": "code",
  428.    "execution_count": 20,
  429.    "metadata": {},
  430.    "outputs": [],
  431.    "source": [
  432.     "class RecursiveStratifiedShuffleSplit(StratifiedShuffleSplit):\n",
  433.     "    \n",
  434.     "    def __init__(self, n_splits=10, partition_sizes=None, random_state=None):\n",
  435.     "\n",
  436.     "        if partition_sizes is None:\n",
  437.     "\n",
  438.     "            partition_sizes = ['default']\n",
  439.     "\n",
  440.     "        head_size, *tail_sizes = partition_sizes\n",
  441.     "\n",
  442.     "        if tail_sizes:\n",
  443.     "\n",
  444.     "            self.rsss_tail = RecursiveStratifiedShuffleSplit(n_splits=1,\n",
  445.     "                                                             partition_sizes=tail_sizes,\n",
  446.     "                                                             random_state=random_state)\n",
  447.     "        else:\n",
  448.     "\n",
  449.     "            self.rsss_tail = None\n",
  450.     "\n",
  451.     "        super(RecursiveStratifiedShuffleSplit, self).__init__(n_splits=n_splits, \n",
  452.     "                                                              test_size=head_size, \n",
  453.     "                                                              random_state=random_state)    \n",
  454.     "\n",
  455.     "    def split(self, X, y, groups=None):\n",
  456.     "        \n",
  457.     "        # iterate over `n_splits` splits\n",
  458.     "        for a_ind, b_ind in super(RecursiveStratifiedShuffleSplit, self).split(X, y):\n",
  459.     "\n",
  460.     "            inds = [a_ind]\n",
  461.     "\n",
  462.     "            if self.rsss_tail is None:\n",
  463.     "\n",
  464.     "                inds.append(b_ind)\n",
  465.     "\n",
  466.     "            else:\n",
  467.     "\n",
  468.     "                # generator yields only `n_splits=1` split by definition\n",
  469.     "                tail_inds = next(self.rsss_tail.split(X[b_ind], y[b_ind]))\n",
  470.     "\n",
  471.     "                # iterate through `len(partition_sizes) + 1` indices\n",
  472.     "                for ind in tail_inds:\n",
  473.     "\n",
  474.     "                    inds.append(b_ind[ind])\n",
  475.     "  \n",
  476.     "            yield tuple(inds)"
  477.    ]
  478.   },
  479.   {
  480.    "cell_type": "code",
  481.    "execution_count": 21,
  482.    "metadata": {},
  483.    "outputs": [],
  484.    "source": [
  485.     "rsss = RecursiveStratifiedShuffleSplit(n_splits=1, partition_sizes=[0.2, 0.5], random_state=random_state)"
  486.    ]
  487.   },
  488.   {
  489.    "cell_type": "code",
  490.    "execution_count": 22,
  491.    "metadata": {},
  492.    "outputs": [
  493.     {
  494.      "data": {
  495.       "text/plain": [
  496.        "1"
  497.       ]
  498.      },
  499.      "execution_count": 22,
  500.      "metadata": {},
  501.      "output_type": "execute_result"
  502.     }
  503.    ],
  504.    "source": [
  505.     "rsss.get_n_splits()"
  506.    ]
  507.   },
  508.   {
  509.    "cell_type": "code",
  510.    "execution_count": 23,
  511.    "metadata": {},
  512.    "outputs": [
  513.     {
  514.      "data": {
  515.       "text/plain": [
  516.        "(array([  8, 106,  76,   9,  89, 146,  94, 133, 135, 117, 105,  78,  60,\n",
  517.        "         67,  92,  29,  19, 108, 137,  31,  72,  46, 131,  25,  81, 103,\n",
  518.        "        111,  26, 121,  40,  21,  71,  91,  15, 139, 148,  79,  59, 144,\n",
  519.        "         70,   6,  50,  16, 130,   1,  17, 101,  35,  41,  12,  45,  64,\n",
  520.        "        100,  55,   0, 149,  99, 136,  47, 142,  36,  53, 113,  24,  83,\n",
  521.        "         90, 122,  66,  54, 115,  39,  23,   4, 119,  82, 129,  80, 145,\n",
  522.        "        123,  85,  34, 114,  68,  43, 120,  32, 109,  98,  86,  30,  97,\n",
  523.        "        110,  44,  13, 124, 118, 112,  87, 126,   5, 143,  96, 125, 102,\n",
  524.        "         48,  74,  73,  95,  88,  65,  27, 128,  62,  61,  11,  37,   2,\n",
  525.        "         33,  52,   3]),\n",
  526.        " array([ 42, 127,  10, 116, 140,  77,  93,  38,  49,  18,  69, 132,  57,\n",
  527.        "         84, 107]),\n",
  528.        " array([  7,  75, 147,  22, 141, 138,  51,  56, 134,  20,  28, 104,  63,\n",
  529.        "         14,  58]))"
  530.       ]
  531.      },
  532.      "execution_count": 23,
  533.      "metadata": {},
  534.      "output_type": "execute_result"
  535.     }
  536.    ],
  537.    "source": [
  538.     "train_index, val_index, test_index = next(rsss.split(X, y))\n",
  539.     "train_index, val_index, test_index"
  540.    ]
  541.   },
  542.   {
  543.    "cell_type": "code",
  544.    "execution_count": 24,
  545.    "metadata": {},
  546.    "outputs": [
  547.     {
  548.      "data": {
  549.       "image/png": "\n",
  550.       "text/plain": [
  551.        "<Figure size 720x72 with 1 Axes>"
  552.       ]
  553.      },
  554.      "metadata": {
  555.       "needs_background": "light"
  556.      },
  557.      "output_type": "display_data"
  558.     }
  559.    ],
  560.    "source": [
  561.     "fig, ax = plt.subplots(figsize=(10, 1))\n",
  562.     "\n",
  563.     "plot_cv_indices(X, y, rsss, ax=ax)\n",
  564.     "\n",
  565.     "plt.show()"
  566.    ]
  567.   },
  568.   {
  569.    "cell_type": "markdown",
  570.    "metadata": {},
  571.    "source": [
  572.     "#### Train-val-test Proportion\n",
  573.     "\n",
  574.     "Verify that the requested train-val-test proportions have been satisfied:"
  575.    ]
  576.   },
  577.   {
  578.    "cell_type": "code",
  579.    "execution_count": 25,
  580.    "metadata": {},
  581.    "outputs": [
  582.     {
  583.      "data": {
  584.       "text/plain": [
  585.        "(120, 4)"
  586.       ]
  587.      },
  588.      "execution_count": 25,
  589.      "metadata": {},
  590.      "output_type": "execute_result"
  591.     }
  592.    ],
  593.    "source": [
  594.     "X[train_index].shape"
  595.    ]
  596.   },
  597.   {
  598.    "cell_type": "code",
  599.    "execution_count": 26,
  600.    "metadata": {},
  601.    "outputs": [
  602.     {
  603.      "data": {
  604.       "text/plain": [
  605.        "(15, 4)"
  606.       ]
  607.      },
  608.      "execution_count": 26,
  609.      "metadata": {},
  610.      "output_type": "execute_result"
  611.     }
  612.    ],
  613.    "source": [
  614.     "X[val_index].shape"
  615.    ]
  616.   },
  617.   {
  618.    "cell_type": "code",
  619.    "execution_count": 27,
  620.    "metadata": {},
  621.    "outputs": [
  622.     {
  623.      "data": {
  624.       "text/plain": [
  625.        "(15, 4)"
  626.       ]
  627.      },
  628.      "execution_count": 27,
  629.      "metadata": {},
  630.      "output_type": "execute_result"
  631.     }
  632.    ],
  633.    "source": [
  634.     "X[test_index].shape"
  635.    ]
  636.   },
  637.   {
  638.    "cell_type": "code",
  639.    "execution_count": 28,
  640.    "metadata": {},
  641.    "outputs": [
  642.     {
  643.      "data": {
  644.       "text/plain": [
  645.        "0.2"
  646.       ]
  647.      },
  648.      "execution_count": 28,
  649.      "metadata": {},
  650.      "output_type": "execute_result"
  651.     }
  652.    ],
  653.    "source": [
  654.     "len(np.hstack((val_index, test_index))) / len(np.hstack((train_index, val_index, test_index)))"
  655.    ]
  656.   },
  657.   {
  658.    "cell_type": "code",
  659.    "execution_count": 29,
  660.    "metadata": {},
  661.    "outputs": [
  662.     {
  663.      "data": {
  664.       "text/plain": [
  665.        "0.5"
  666.       ]
  667.      },
  668.      "execution_count": 29,
  669.      "metadata": {},
  670.      "output_type": "execute_result"
  671.     }
  672.    ],
  673.    "source": [
  674.     "len(test_index) / len(np.hstack((val_index, test_index)))"
  675.    ]
  676.   },
  677.   {
  678.    "cell_type": "markdown",
  679.    "metadata": {},
  680.    "source": [
  681.     "#### Stratified\n",
  682.     "\n",
  683.     "Verify that each partition preserves the class proportions of the full dataset, i.e. that the split is stratified."
  684.    ]
  685.   },
  686.   {
  687.    "cell_type": "code",
  688.    "execution_count": 30,
  689.    "metadata": {},
  690.    "outputs": [
  691.     {
  692.      "data": {
  693.       "text/plain": [
  694.        "(array([0, 1, 2]), array([40, 40, 40]))"
  695.       ]
  696.      },
  697.      "execution_count": 30,
  698.      "metadata": {},
  699.      "output_type": "execute_result"
  700.     }
  701.    ],
  702.    "source": [
  703.     "np.unique(y[train_index], return_counts=True)"
  704.    ]
  705.   },
  706.   {
  707.    "cell_type": "code",
  708.    "execution_count": 31,
  709.    "metadata": {},
  710.    "outputs": [
  711.     {
  712.      "data": {
  713.       "text/plain": [
  714.        "(array([0, 1, 2]), array([5, 5, 5]))"
  715.       ]
  716.      },
  717.      "execution_count": 31,
  718.      "metadata": {},
  719.      "output_type": "execute_result"
  720.     }
  721.    ],
  722.    "source": [
  723.     "np.unique(y[val_index], return_counts=True)"
  724.    ]
  725.   },
  726.   {
  727.    "cell_type": "code",
  728.    "execution_count": 32,
  729.    "metadata": {
  730.     "scrolled": true
  731.    },
  732.    "outputs": [
  733.     {
  734.      "data": {
  735.       "text/plain": [
  736.        "(array([0, 1, 2]), array([5, 5, 5]))"
  737.       ]
  738.      },
  739.      "execution_count": 32,
  740.      "metadata": {},
  741.      "output_type": "execute_result"
  742.     }
  743.    ],
  744.    "source": [
  745.     "np.unique(y[test_index], return_counts=True)"
  746.    ]
  747.   },
  748.   {
  749.    "cell_type": "markdown",
  750.    "metadata": {},
  751.    "source": [
  752.     "### Visualization\n",
  753.     "\n",
  754.     "Visualize for a large number of CV splits (10). Each partition should satisfy the properties we tested above."
  755.    ]
  756.   },
  757.   {
  758.    "cell_type": "code",
  759.    "execution_count": 33,
  760.    "metadata": {},
  761.    "outputs": [
  762.     {
  763.      "data": {
  764.       "text/plain": [
  765.        "RecursiveStratifiedShuffleSplit(n_splits=10, partition_sizes=None,\n",
  766.        "                random_state=42)"
  767.       ]
  768.      },
  769.      "execution_count": 33,
  770.      "metadata": {},
  771.      "output_type": "execute_result"
  772.     }
  773.    ],
  774.    "source": [
  775.     "rsss = RecursiveStratifiedShuffleSplit(partition_sizes=[0.2, 0.5], random_state=random_state)\n",
  776.     "rsss"
  777.    ]
  778.   },
  779.   {
  780.    "cell_type": "code",
  781.    "execution_count": 34,
  782.    "metadata": {},
  783.    "outputs": [
  784.     {
  785.      "data": {
  786.       "image/png": "\n",
  787.       "text/plain": [
  788.        "<Figure size 720x360 with 1 Axes>"
  789.       ]
  790.      },
  791.      "metadata": {
  792.       "needs_background": "light"
  793.      },
  794.      "output_type": "display_data"
  795.     }
  796.    ],
  797.    "source": [
  798.     "fig, ax = plt.subplots(figsize=(10, 5))\n",
  799.     "\n",
  800.     "plot_cv_indices(X, y, rsss, ax=ax)\n",
  801.     "\n",
  802.     "plt.show()"
  803.    ]
  804.   },
  805.   {
  806.    "cell_type": "markdown",
  807.    "metadata": {},
  808.    "source": [
  809.     "### Even more partitions\n",
  810.     "\n",
  811.     "We can approximately halve the dataset, then approximately halve one of the resulting halves, and so on until this process can no longer be repeated. Each entry of `partition_sizes` adds one such level, so the four entries below produce five partitions."
  812.    ]
  813.   },
  814.   {
  815.    "cell_type": "code",
  816.    "execution_count": 35,
  817.    "metadata": {},
  818.    "outputs": [
  819.     {
  820.      "data": {
  821.       "text/plain": [
  822.        "RecursiveStratifiedShuffleSplit(n_splits=10, partition_sizes=None,\n",
  823.        "                random_state=42)"
  824.       ]
  825.      },
  826.      "execution_count": 35,
  827.      "metadata": {},
  828.      "output_type": "execute_result"
  829.     }
  830.    ],
  831.    "source": [
  832.     "rsss = RecursiveStratifiedShuffleSplit(partition_sizes=[0.5, 0.5, 0.5, 0.5], random_state=random_state)\n",
  833.     "rsss"
  834.    ]
  835.   },
  836.   {
  837.    "cell_type": "code",
  838.    "execution_count": 36,
  839.    "metadata": {},
  840.    "outputs": [
  841.     {
  842.      "data": {
  843.       "text/plain": [
  844.        "[75, 37, 19, 9, 10]"
  845.       ]
  846.      },
  847.      "execution_count": 36,
  848.      "metadata": {},
  849.      "output_type": "execute_result"
  850.     }
  851.    ],
  852.    "source": [
  853.     "[len(ind) for ind in next(rsss.split(X, y))]"
  854.    ]
  855.   },
  856.   {
  857.    "cell_type": "code",
  858.    "execution_count": 37,
  859.    "metadata": {},
  860.    "outputs": [
  861.     {
  862.      "data": {
  863.       "image/png": "\n",
  864.       "text/plain": [
  865.        "<Figure size 720x360 with 1 Axes>"
  866.       ]
  867.      },
  868.      "metadata": {
  869.       "needs_background": "light"
  870.      },
  871.      "output_type": "display_data"
  872.     }
  873.    ],
  874.    "source": [
  875.     "fig, ax = plt.subplots(figsize=(10, 5))\n",
  876.     "\n",
  877.     "plot_cv_indices(X, y, rsss, ax=ax)\n",
  878.     "\n",
  879.     "plt.show()"
  880.    ]
  881.   }
  882.  ],
  883.  "metadata": {
  884.   "kernelspec": {
  885.    "display_name": "Python 3",
  886.    "language": "python",
  887.    "name": "python3"
  888.   },
  889.   "language_info": {
  890.    "codemirror_mode": {
  891.     "name": "ipython",
  892.     "version": 3
  893.    },
  894.    "file_extension": ".py",
  895.    "mimetype": "text/x-python",
  896.    "name": "python",
  897.    "nbconvert_exporter": "python",
  898.    "pygments_lexer": "ipython3",
  899.    "version": "3.5.2"
  900.   }
  901.  },
  902.  "nbformat": 4,
  903.  "nbformat_minor": 2
  904. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top