limmen

pentest_mdp_notebook

Dec 4th, 2020
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Markov Decision Process of a Penetration Tester\n",
    "\n",
    "This notebook includes a simple MDP model of a penetration tester and code for solving the MDP using value iteration.\n",
    "\n",
    "Kim Hammar, 13 February 2020, KTH Royal Institute of Technology"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Background Theory\n",
    "\n",
    "This section summarizes the background theory on MDPs that the rest of the notebook relies on."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Markov Process\n",
    "A Markov process is a discrete-time stochastic process (a sequence of random variables $\\{X_1, X_2, \\ldots, X_n\\}$) that satisfies the Markov property.\n",
    "\n",
    "**Markov Property**: A stochastic process $\\{X_1, X_2, \\ldots, X_n\\}$ is Markovian if and only if $\\mathbb{P}[X_{t+1} \\mid X_t, X_{t-1}, \\ldots, X_1] = \\mathbb{P}[X_{t+1} \\mid X_t]$, that is, the next state depends on the history only through the current state.\n",
    "\n",
    "A Markov process is defined by a tuple of two elements, $\\langle S, \\mathcal{P} \\rangle$, where $S$ denotes the set of states and $\\mathcal{P}$ denotes the transition probabilities."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Markov Reward Process\n",
    "A Markov **reward** process (MRP) extends a Markov process with a reward function $\\mathcal{R}_{s,s^{\\prime}}$ that assigns a scalar reward to each state transition. Thus, a Markov reward process is defined by a tuple of three elements, $\\langle S, \\mathcal{P}_{s,s^{\\prime}}, \\mathcal{R}_{s,s^{\\prime}}\\rangle$, where $S$ denotes the set of states, $\\mathcal{P}_{s,s^{\\prime}}$ denotes the transition probabilities, and $\\mathcal{R}_{s,s^{\\prime}}$ denotes the reward function."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Markov Decision Process\n",
    "\n",
    "A Markov **decision** process is a further extension of Markov processes and MRPs that adds a set of actions $\\mathcal{A}$ to the process. A Markov decision process is defined by a tuple of five elements, $\\langle S, \\mathcal{P}_{s,s^{\\prime}}^a, \\mathcal{R}_{s,s^{\\prime}}^a, \\mathcal{A}, \\gamma \\rangle$, where $S$ denotes the set of states, $\\mathcal{P}_{s,s^{\\prime}}^a$ denotes the transition probabilities, $\\mathcal{R}_{s,s^{\\prime}}^a$ denotes the reward function, $\\mathcal{A}$ denotes the set of actions, and $\\gamma$ denotes the discount factor.\n",
    "\n",
    "**Fundamental Theorem of MDPs:** For any Markov decision process,\n",
    "\n",
    "- There exists an optimal policy $\\pi^{*}$ that is better than or equal to all other policies: $\\pi^{*} \\geq \\pi, \\forall \\pi$\n",
    "- All optimal policies achieve the optimal value function: $V^{\\pi^{*}}(s) = V^{*}(s)$\n",
    "- All optimal policies achieve the optimal action-value function: $Q^{\\pi^{*}}(s,a) = Q^{*}(s,a)$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solving a Markov Decision Process with Dynamic Programming\n",
    "\n",
    "If the agent has perfect knowledge of the MDP it is acting in (that is, the tuple $\\langle S, \\mathcal{P}_{s,s^{\\prime}}^a, \\mathcal{R}_{s,s^{\\prime}}^a, \\mathcal{A}, \\gamma \\rangle$ is given), then computing the value function of a policy $\\pi$ is a pure computation problem. It can be solved by **planning**, without the agent taking a single action in the environment.\n",
    "\n",
    "In this notebook we use a dynamic programming algorithm called **value iteration** to solve the MDP."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Value Iteration\n",
    "\n",
    "Value iteration is an alternative dynamic programming algorithm to policy iteration. Alternating policy evaluation and policy improvement, as policy iteration does, is guaranteed to converge. However, it entails doing\n",
    "a substantial amount of work in each iteration (a full policy evaluation), and it can take a long time to converge for large MDPs. Is it possible to redefine the algorithm so that it does less work per iteration and converges faster, while still retaining the convergence guarantees?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall that policy iteration evaluates a fixed policy $\\pi$ by repeatedly sweeping over the states with *full backups* until the estimate of $V^{\\pi}(s)$ converges. A full backup of a state $s$ computes an expectation over the actions in $\\mathcal{A}(s)$, weighted by $\\pi$, and over all successor states $s^{\\prime}$. It then backs up the result to the value estimate of $s$. Value iteration instead truncates policy evaluation to a single sweep and replaces the expectation over the policy's actions with a maximization, so each backup only considers the dominant action.\n",
    "\n",
    "The value iteration update is defined below:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align}\\label{eq:value_iter}\n",
    "  V_{k+1}(s) &= \\max_a\\mathbb{E}\\left[r_{t+1} + \\gamma V_k(s_{t+1}) \\mid s_t = s, a_t=a\\right] && \\text{Value iteration}\\\\\n",
    "               &= \\displaystyle\\max_a \\sum_{s^{\\prime}} \\mathcal{P}_{ss^{\\prime}}^a\\left[\\mathcal{R}_{ss^{\\prime}}^{a} + \\gamma V_k(s^{\\prime})\\right]\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In value iteration, each iteration performs a single sweep of backups instead of running policy evaluation to convergence. Moreover, value iteration folds the policy-improvement step of policy iteration into the update itself (through the $\\max_a$), which effectively improves the policy while the values are updated. Value iteration has the same convergence guarantee as policy iteration: it converges to $V^{*}$ in the limit. Because each iteration is much cheaper, it is often more computationally efficient overall. In fact, value iteration is just the Bellman optimality equation turned into an update rule."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the MDP\n",
    "\n",
    "In this section the MDP model under study is defined. Recall that a Markov Decision Process (MDP) is defined by the tuple $\\langle S, \\mathcal{P}_{s,s^{\\prime}}^a, \\mathcal{R}_{s,s^{\\prime}}^a, \\mathcal{A}, \\gamma \\rangle$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transition Diagram\n",
    "![Transition Diagram](./hacker_transition_graph.png \"Transition Graph\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_ACTIONS = 7\n",
    "NUM_STATES = 6\n",
    "STATES = [\"Reconnaissance\", \"Vulnerability\", \"Exploit\", \"Caught\", \"Post-Exploit\", \"Quit\"]\n",
    "ACTIONS = [\"network-sniffing\", \"port-scan\", \"exit\", \"exploit-detection\", \"exploit\", \"install-backdoor\",\n",
    "          \"stay\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Actions $\\mathcal{A}$\n",
    "\n",
    "The traditional theory of reinforcement learning assumes that the set of actions, $\\mathcal{A}$, is discrete,\n",
    "while in the more general case actions can be either discrete or continuous. In this notebook,\n",
    "actions are assumed to be discrete.\n",
    "\n",
    "The MDP under study in this notebook has 7 actions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "actions = pd.DataFrame(ACTIONS, columns=['Action'])\n",
    "actions.head(NUM_ACTIONS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### States $S$\n",
    "\n",
    "The collection of environment states, $S$, can be either discrete or continuous. Most of the foundational work on reinforcement learning assumes that states are discrete, and that is the assumption in this notebook as well. The set of states, $S$, is the set of possible states of the Markov decision\n",
    "process. The states are assumed to have the Markov property, that is, each state captures all the information from the history that is\n",
    "needed to predict the next state and reward, and hence to select the next action.\n",
    "\n",
    "The MDP under study in this notebook has 6 states:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "states = pd.DataFrame(STATES, columns=['State'])\n",
    "states.head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transition Probabilities $\\mathcal{P}_{ss^{\\prime}}^a$\n",
    "\n",
    "The transition probabilities, $\\mathcal{P}_{ss^{\\prime}}^a$, define the probability of transitioning to state $s^{\\prime}$ when taking action $a$ in state $s$. The transition probabilities can be encoded in a tensor, $\\mathcal{P}_{ss^{\\prime}}^a \\in \\mathbb{R}^{|\\mathcal{A}|\\times |S|\\times |S|}$, with\n",
    "one matrix of state transition probabilities for each action. Each row $\\mathcal{P}_{s\\cdot}^a$ is a probability distribution over next states if action $a$ is available in state $s$, and all zeros otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "P = np.zeros((NUM_ACTIONS,NUM_STATES,NUM_STATES))\n",
    "P[0,0,0] = 0.9 # network-sniffing: Reconnaissance -> Reconnaissance\n",
    "P[0,0,1] = 0.1 # network-sniffing: Reconnaissance -> Vulnerability\n",
    "P[1,0,0] = 0.8 # port-scan: Reconnaissance -> Reconnaissance\n",
    "P[1,0,1] = 0.2 # port-scan: Reconnaissance -> Vulnerability\n",
    "P[2,0,5] = 1   # exit: Reconnaissance -> Quit\n",
    "P[2,4,5] = 1   # exit: Post-Exploit -> Quit\n",
    "P[3,1,0] = 0.6 # exploit-detection: Vulnerability -> Reconnaissance\n",
    "P[3,1,2] = 0.4 # exploit-detection: Vulnerability -> Exploit\n",
    "P[4,2,1] = 0.2 # exploit: Exploit -> Vulnerability\n",
    "P[4,2,3] = 0.1 # exploit: Exploit -> Caught\n",
    "P[4,2,4] = 0.7 # exploit: Exploit -> Post-Exploit\n",
    "P[5,4,3] = 0.1 # install-backdoor: Post-Exploit -> Caught\n",
    "P[5,4,4] = 0.9 # install-backdoor: Post-Exploit -> Post-Exploit\n",
    "P[6,3,3] = 1   # stay: Caught -> Caught\n",
    "P[6,5,5] = 1   # stay: Quit -> Quit\n",
    "# Each row P[a,s] must be a probability distribution if action a is feasible in state s, and all zeros otherwise\n",
    "for i in range(0, NUM_ACTIONS):\n",
    "    for j in range(0, NUM_STATES):\n",
    "        row_sum = P[i,j].sum()\n",
    "        assert np.isclose(row_sum, 1.0) or np.isclose(row_sum, 0.0)\n",
    "P.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(P[0], columns=STATES, index=pd.Index(STATES, name=\"State transition for action 0\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(P[1], columns=STATES, index=pd.Index(STATES, name=\"State transition for action 1\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(P[2], columns=STATES, index=pd.Index(STATES, name=\"State transition for action 2\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(P[3], columns=STATES, index=pd.Index(STATES, name=\"State transition for action 3\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(P[4], columns=STATES, index=pd.Index(STATES, name=\"State transition for action 4\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(P[5], columns=STATES, index=pd.Index(STATES, name=\"State transition for action 5\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(P[6], columns=STATES, index=pd.Index(STATES, name=\"State transition for action 6\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reward Function $\\mathcal{R}_{ss^{\\prime}}^a$\n",
    "\n",
    "The reward function instructs the agent how to act in the environment; different reward functions\n",
    "result in different policies learned by the agent. To make reinforcement learning algorithms run\n",
    "in a reasonable amount of time, it is frequently necessary to use a well-chosen reward function that\n",
    "gives appropriate “hints” to the learning algorithm. Certain choices of rewards may allow an\n",
    "agent to learn orders of magnitude faster; other choices may cause the agent to learn highly suboptimal solutions. Thus, even though it is generally easier to describe a reward function than it is to\n",
    "describe an optimal policy, it is frequently difficult to manually design good reward functions as well.\n",
    "\n",
    "The reward function can be encoded in a tensor $\\mathcal{R}_{ss^{\\prime}}^a \\in \\mathbb{R}^{|\\mathcal{A}|\\times |S| \\times |S|}$. Its entry $(a, s, s^{\\prime})$ is the reward for transitioning from $s$ to $s^{\\prime}$ when taking action $a$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "R = np.zeros((NUM_ACTIONS,NUM_STATES,NUM_STATES))\n",
    "R[0,0,1] = 5   # network-sniffing: Reconnaissance -> Vulnerability\n",
    "R[1,0,1] = 5   # port-scan: Reconnaissance -> Vulnerability\n",
    "R[3,1,0] = -1  # exploit-detection: Vulnerability -> Reconnaissance\n",
    "R[3,1,2] = 3   # exploit-detection: Vulnerability -> Exploit\n",
    "R[4,2,1] = -1  # exploit: Exploit -> Vulnerability\n",
    "R[4,2,3] = -10 # exploit: Exploit -> Caught\n",
    "R[4,2,4] = 10  # exploit: Exploit -> Post-Exploit\n",
    "R[5,4,3] = -10 # install-backdoor: Post-Exploit -> Caught\n",
    "R[5,4,4] = 2   # install-backdoor: Post-Exploit -> Post-Exploit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(R[0], columns=STATES, index=pd.Index(STATES, name=\"Reward function for action 0\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(R[1], columns=STATES, index=pd.Index(STATES, name=\"Reward function for action 1\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(R[2], columns=STATES, index=pd.Index(STATES, name=\"Reward function for action 2\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(R[3], columns=STATES, index=pd.Index(STATES, name=\"Reward function for action 3\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(R[4], columns=STATES, index=pd.Index(STATES, name=\"Reward function for action 4\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(R[5], columns=STATES, index=pd.Index(STATES, name=\"Reward function for action 5\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(R[6], columns=STATES, index=pd.Index(STATES, name=\"Reward function for action 6\")).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discount Factor $\\gamma$\n",
    "\n",
    "The discount factor controls the agent’s horizon: more aggressive discounting yields a shorter\n",
    "horizon than soft discounting, which can affect the learned policy. Discounting rewards can be seen as\n",
    "merely a heuristic to modulate the rewards over time and to make the math work out (avoiding\n",
    "infinite sums). However, discounting rewards is also intuitively useful in many domains. For example,\n",
    "in finance, assets held now are considered worth more than the same assets held later. Moreover,\n",
    "whenever there is uncertainty in the environment, rewards that are close in time are valued higher than distant rewards. A probabilistic interpretation of the discount factor is that it models a small\n",
    "probability, $1-\\gamma$, that the interaction between the agent and the environment ends after each time step. Unpredictable events can occur and prevent the future reward (e.g., in the most dramatic case,\n",
    "death could happen). In summary, the discount factor can be interpreted in several ways: it can be\n",
    "seen as an interest rate, as the probability of living another step, or as a mathematical trick to bound the\n",
    "infinite sum.\n",
    "\n",
    "For the example studied in this notebook, the discount factor is $\\gamma = 0.8$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gamma = 0.8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Value Iteration Variables"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Policy $\\pi$\n",
    "\n",
    "Value iteration does not need an initial policy. The arbitrary deterministic policy $\\pi_0$ below is encoded as a one-hot matrix with one row per state and one column per action. It is only used as a baseline for comparison with the greedy policies extracted after each iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy0 = np.array([\n",
    "    [1,0,0,0,0,0,0], # Reconnaissance -> network-sniffing\n",
    "    [0,0,0,1,0,0,0], # Vulnerability -> exploit-detection\n",
    "    [0,0,0,0,1,0,0], # Exploit -> exploit\n",
    "    [0,0,0,0,0,0,1], # Caught -> stay\n",
    "    [0,0,1,0,0,0,0], # Post-Exploit -> exit\n",
    "    [0,0,0,0,0,0,1]  # Quit -> stay\n",
    "])\n",
    "pd.DataFrame(policy0, columns=ACTIONS, index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initial State Values $\\vec{v}^{\\pi}_0$\n",
    "\n",
    "Value iteration converges to $V^{*}$ from any initial value vector, so the initial values are arbitrary. Below, every state value is initialized to $1$. For reference, the value vector of a fixed policy $\\pi$ satisfies the Bellman expectation equation\n",
    "\n",
    "\\begin{align}\n",
    "\\vec{v}^{\\pi} &=\n",
    "\\begin{bmatrix}\n",
    "\\displaystyle\\sum_a \\pi(s_0,a) \\sum_{s^{\\prime}}\\mathcal{P}_{s_0s^{\\prime}}^a \\left[\\mathcal{R}_{s_0s^{\\prime}}^a + \\gamma V^{\\pi}(s^{\\prime})\\right] \\\\\n",
    "\\vdots\\\\\n",
    "\\displaystyle\\sum_a \\pi(s_{|S|-1},a) \\sum_{s^{\\prime}}\\mathcal{P}_{s_{|S|-1}s^{\\prime}}^a \\left[\\mathcal{R}_{s_{|S|-1}s^{\\prime}}^a + \\gamma V^{\\pi}(s^{\\prime})\\right]\n",
    "\\end{bmatrix}\\label{eq:closed_form_linear_3}\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v0 = np.ones((NUM_STATES))\n",
    "pd.DataFrame(v0, columns=[\"State-Value\"], index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Value Iteration Algorithm\n",
    "\n",
    "\\begin{align}\n",
    "  V_{k+1}(s) &= \\max_a\\mathbb{E}\\left[r_{t+1} + \\gamma V_k(s_{t+1}) \\mid s_t = s, a_t=a\\right] && \\text{Value iteration}\\\\\n",
    "               &= \\displaystyle\\max_a \\sum_{s^{\\prime}} \\mathcal{P}_{ss^{\\prime}}^a\\left[\\mathcal{R}_{ss^{\\prime}}^{a} + \\gamma V_k(s^{\\prime})\\right]\n",
    "\\end{align}\n",
    "\n",
    "The implementation is split into three functions. `compute_value_vector` performs one sweep of the update above, `greedy_policy` extracts the policy that is greedy with respect to given state values, and `value_iteration` runs `N` sweeps and then extracts the greedy policy."
   ]
  },
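  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_value_vector(P, R, gamma, v):\n",
    "    \"\"\"\n",
    "    Performs one synchronous sweep of the value iteration update and returns the new state values V_{k+1}.\n",
    "    A new array is returned and the input values v (V_k) are left unchanged, so old and new values can be compared.\n",
    "    \"\"\"\n",
    "    v_new = np.zeros(NUM_STATES)\n",
    "    for s in range(NUM_STATES):\n",
    "        action_values = np.zeros(NUM_ACTIONS)\n",
    "        for a in range(NUM_ACTIONS):\n",
    "            for s_prime in range(NUM_STATES):\n",
    "                action_values[a] += P[a,s,s_prime]*(R[a,s,s_prime] + gamma*v[s_prime])\n",
    "        # Exclude actions that are not available in s (all-zero rows of P) from the max\n",
    "        action_values[P[:,s,:].sum(axis=1) == 0] = -np.inf\n",
    "        v_new[s] = np.max(action_values)\n",
    "    return v_new"
   ]
  },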
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def greedy_policy(P, R, gamma, v):\n",
    "    \"\"\"\n",
    "    Returns the deterministic policy that is greedy with respect to the state values v\n",
    "\n",
    "    Args:\n",
    "        :P: the state transition probabilities for all actions in the MDP (tensor num_actions x num_states x num_states)\n",
    "        :R: the reward function in the MDP (tensor num_actions x num_states x num_states)\n",
    "        :gamma: the discount factor\n",
    "        :v: the state values (dimension NUM_STATES)\n",
    "\n",
    "    Returns:\n",
    "        :pi_prime: a new updated policy (dimensions num_states x num_actions)\n",
    "    \"\"\"\n",
    "    pi_prime = np.zeros((NUM_STATES,NUM_ACTIONS))\n",
    "    for s in range(0, NUM_STATES):\n",
    "        action_values = np.zeros(NUM_ACTIONS)\n",
    "        for a in range(0, NUM_ACTIONS):\n",
    "            for s_prime in range(0, NUM_STATES):\n",
    "                action_values[a] += P[a,s,s_prime]*(R[a,s,s_prime] + gamma*v[s_prime])\n",
    "        # Exclude actions that are not available in s (all-zero rows of P), so that the argmax\n",
    "        # picks a feasible action even when every feasible action has value zero\n",
    "        action_values[P[:,s,:].sum(axis=1) == 0] = -np.inf\n",
    "        best_action = np.argmax(action_values)\n",
    "        pi_prime[s,best_action] = 1\n",
    "    return pi_prime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def value_iteration(P, R, gamma, v, N):\n",
    "    \"\"\"\n",
    "    Runs N sweeps of value iteration starting from the state values v and returns\n",
    "    the greedy policy with respect to the final values together with those values.\n",
    "    \"\"\"\n",
    "    for i in range(0, N):\n",
    "        v = compute_value_vector(P, R, gamma, v)\n",
    "    pi = greedy_policy(P, R, gamma, v)\n",
    "    return pi, v"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the Algorithm\n",
    "\n",
    "To get a better understanding of the algorithm, we inspect its progress after each iteration."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iteration 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy1, v1 = value_iteration(P, R, gamma, v0, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compare Old State Values with New Ones:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Old $\\vec{v}^{\\pi}_0$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(v0, columns=[\"State-Value\"], index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### New $\\vec{v}^{\\pi}_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(v1, columns=[\"State-Value\"], index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compare Old Policy with New Policy:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Old $\\pi_0$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(policy0, columns=ACTIONS, index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### New $\\pi_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(policy1, columns=ACTIONS, index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iteration 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy2, v2 = value_iteration(P, R, gamma, v1, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compare Old State Values with New Ones:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Old $\\vec{v}^{\\pi}_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(v1, columns=[\"State-Value\"], index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### New $\\vec{v}^{\\pi}_2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(v2, columns=[\"State-Value\"], index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compare Old Policy with New Policy:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Old $\\pi_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(policy1, columns=ACTIONS, index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### New $\\pi_2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(policy2, columns=ACTIONS, index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iteration 3 - N"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For iterations 3 to N, simply continue in the same way until the state value vector converges\n",
    "(when $\\vec{v}^{\\pi}_{i} \\approx \\vec{v}^{\\pi}_{i+1}$). The values approach $V^{*}$ geometrically, and the error shrinks by at least a factor of $\\gamma$ per iteration. In floating point they are therefore compared up to a tolerance rather than for exact equality. For this small MDP, the greedy policy should stabilize after only a few iterations. We verify convergence by comparing the values and policies found after 50 iterations with those found after 49 iterations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 49-2 # we already computed two iterations manually\n",
    "policy49, v49 = value_iteration(P, R, gamma, v2, N)\n",
    "policy50, v50 = value_iteration(P, R, gamma, v49, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check for Convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The values only approach V* geometrically, so compare them up to a tolerance\n",
    "np.allclose(v49, v50, atol=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array_equal(policy49, policy50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### $\\vec{v}^{\\pi}_{49}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(v49, columns=[\"State-Value\"], index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### $\\vec{v}^{\\pi}_{50}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(v50, columns=[\"State-Value\"], index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### $\\pi_{49}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(policy49, columns=ACTIONS, index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### $\\pi_{50}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(policy50, columns=ACTIONS, index=pd.Index(STATES)).head(NUM_STATES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimal Policy\n",
    "\n",
    "Thus, the learned optimal policy is (chosen actions in red):\n",
    "\n",
    "![Optimal policy](./mdp_policy.png \"Optimal Policy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
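
The Markov property from the "Markov Process" cell means that the transition tensor $\mathcal{P}_{ss^{\prime}}^a$ is all that is needed to simulate the penetration tester. The next state is drawn from `P[a, s]` alone, regardless of the earlier history. The sketch below is a minimal illustration that rests on several assumptions. The tensors are rebuilt from the values in the notebook's `P` and `R` cells. `sample_episode` and `discounted_return` are hypothetical helpers that are not part of the original notebook. The initial policy $\pi_0$ is written as one action index per state. The sketch rolls out one episode from Reconnaissance and computes its discounted return with $\gamma = 0.8$.

```python
import numpy as np

STATES = ["Reconnaissance", "Vulnerability", "Exploit", "Caught", "Post-Exploit", "Quit"]
ACTIONS = ["network-sniffing", "port-scan", "exit", "exploit-detection", "exploit",
           "install-backdoor", "stay"]

# (action, state, next_state): (probability, reward), copied from the notebook's P and R cells
TRANSITIONS = {
    (0, 0, 0): (0.9, 0), (0, 0, 1): (0.1, 5),
    (1, 0, 0): (0.8, 0), (1, 0, 1): (0.2, 5),
    (2, 0, 5): (1.0, 0), (2, 4, 5): (1.0, 0),
    (3, 1, 0): (0.6, -1), (3, 1, 2): (0.4, 3),
    (4, 2, 1): (0.2, -1), (4, 2, 3): (0.1, -10), (4, 2, 4): (0.7, 10),
    (5, 4, 3): (0.1, -10), (5, 4, 4): (0.9, 2),
    (6, 3, 3): (1.0, 0), (6, 5, 5): (1.0, 0),
}
P = np.zeros((len(ACTIONS), len(STATES), len(STATES)))
R = np.zeros_like(P)
for (a, s, s_next), (p, r) in TRANSITIONS.items():
    P[a, s, s_next], R[a, s, s_next] = p, r

ABSORBING = {STATES.index("Caught"), STATES.index("Quit")}


def sample_episode(P, R, policy, start=0, max_steps=100, seed=0):
    """Roll out one episode; policy[s] is the index of the action taken in state s."""
    rng = np.random.default_rng(seed)
    s, episode = start, []
    for _ in range(max_steps):
        a = policy[s]
        # Markov property: the next state depends only on (s, a), not on earlier steps
        s_next = int(rng.choice(P.shape[1], p=P[a, s]))
        episode.append((s, a, R[a, s, s_next], s_next))
        s = s_next
        if s in ABSORBING:
            break
    return episode


def discounted_return(episode, gamma):
    """Sum of gamma**t * r_{t+1} over the sampled transitions."""
    return sum(gamma ** t * r for t, (_, _, r, _) in enumerate(episode))


pi0 = [0, 3, 4, 6, 2, 6]  # the notebook's policy0 written as action indices
episode = sample_episode(P, R, pi0)
for s, a, r, s_next in episode:
    print(f"{STATES[s]:>15} --{ACTIONS[a]}--> {STATES[s_next]:<15} reward {r:+.0f}")
print("discounted return:", discounted_return(episode, gamma=0.8))
```

Caught and Quit only have the self-loop action `stay`, and it yields zero reward. The rollout is therefore stopped as soon as one of these states is reached, because continuing would not change the return.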
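
Fixing a policy $\pi$ turns the MDP into a Markov reward process, as in the "Markov Reward Process" cell. The Bellman expectation equation in the "Initial State Values" cell then becomes the linear system $\vec{v}^{\pi} = \vec{r}^{\pi} + \gamma \mathcal{P}^{\pi}\vec{v}^{\pi}$, where $\mathcal{P}^{\pi}_{ss^{\prime}} = \sum_a \pi(s,a)\mathcal{P}_{ss^{\prime}}^a$ and $r^{\pi}(s) = \sum_a \pi(s,a)\sum_{s^{\prime}}\mathcal{P}_{ss^{\prime}}^a\mathcal{R}_{ss^{\prime}}^a$. Its solution is $\vec{v}^{\pi} = (I - \gamma \mathcal{P}^{\pi})^{-1}\vec{r}^{\pi}$. The following is a minimal sketch that evaluates the initial policy $\pi_0$ with `np.linalg.solve` rather than forming the inverse explicitly. It assumes the tensors are rebuilt from the notebook's `P` and `R` cells and uses a hypothetical helper, `evaluate_policy`.

```python
import numpy as np

# (action, state, next_state): (probability, reward), copied from the notebook's P and R cells
TRANSITIONS = {
    (0, 0, 0): (0.9, 0), (0, 0, 1): (0.1, 5),
    (1, 0, 0): (0.8, 0), (1, 0, 1): (0.2, 5),
    (2, 0, 5): (1.0, 0), (2, 4, 5): (1.0, 0),
    (3, 1, 0): (0.6, -1), (3, 1, 2): (0.4, 3),
    (4, 2, 1): (0.2, -1), (4, 2, 3): (0.1, -10), (4, 2, 4): (0.7, 10),
    (5, 4, 3): (0.1, -10), (5, 4, 4): (0.9, 2),
    (6, 3, 3): (1.0, 0), (6, 5, 5): (1.0, 0),
}
P = np.zeros((7, 6, 6))  # |A| x |S| x |S|, as in the notebook
R = np.zeros_like(P)
for (a, s, s_next), (p, r) in TRANSITIONS.items():
    P[a, s, s_next], R[a, s, s_next] = p, r


def evaluate_policy(P, R, gamma, pi):
    """Exact policy evaluation for a policy pi of shape (|S|, |A|) whose rows sum to 1."""
    P_pi = np.einsum("sa,ast->st", pi, P)          # MRP transition matrix P^pi[s, s']
    r_pi = np.einsum("sa,ast,ast->s", pi, P, R)    # expected one-step reward r^pi[s]
    v_pi = np.linalg.solve(np.eye(P.shape[1]) - gamma * P_pi, r_pi)
    return v_pi, P_pi, r_pi


# The notebook's initial policy pi_0 (one-hot rows: states, columns: actions)
policy0 = np.array([
    [1, 0, 0, 0, 0, 0, 0],  # Reconnaissance -> network-sniffing
    [0, 0, 0, 1, 0, 0, 0],  # Vulnerability -> exploit-detection
    [0, 0, 0, 0, 1, 0, 0],  # Exploit -> exploit
    [0, 0, 0, 0, 0, 0, 1],  # Caught -> stay
    [0, 0, 1, 0, 0, 0, 0],  # Post-Exploit -> exit
    [0, 0, 0, 0, 0, 0, 1],  # Quit -> stay
])
gamma = 0.8
v_pi, P_pi, r_pi = evaluate_policy(P, R, gamma, policy0)
print(np.round(v_pi, 3))
```

The matrix $I - \gamma \mathcal{P}^{\pi}$ is invertible whenever $\gamma < 1$ and $\mathcal{P}^{\pi}$ is row-stochastic. Both conditions hold here, because $\pi_0$ only selects actions that are available in each state.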
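
The "Value Iteration" cell motivates value iteration as a cheaper alternative to policy iteration. For comparison, here is a minimal policy iteration sketch for the same MDP, based on the following assumptions:

- The tensors are rebuilt from the notebook's `P` and `R` cells.
- The helper names are hypothetical.
- Policies are stored as one action index per state.
- Actions that are unavailable in a state (all-zero rows of `P`) are excluded from the greedy improvement step.

Each iteration performs a full policy evaluation (here an exact linear solve) followed by greedy improvement, and the loop stops when the policy no longer changes.

```python
import numpy as np

STATES = ["Reconnaissance", "Vulnerability", "Exploit", "Caught", "Post-Exploit", "Quit"]
ACTIONS = ["network-sniffing", "port-scan", "exit", "exploit-detection", "exploit",
           "install-backdoor", "stay"]

# (action, state, next_state): (probability, reward), copied from the notebook's P and R cells
TRANSITIONS = {
    (0, 0, 0): (0.9, 0), (0, 0, 1): (0.1, 5),
    (1, 0, 0): (0.8, 0), (1, 0, 1): (0.2, 5),
    (2, 0, 5): (1.0, 0), (2, 4, 5): (1.0, 0),
    (3, 1, 0): (0.6, -1), (3, 1, 2): (0.4, 3),
    (4, 2, 1): (0.2, -1), (4, 2, 3): (0.1, -10), (4, 2, 4): (0.7, 10),
    (5, 4, 3): (0.1, -10), (5, 4, 4): (0.9, 2),
    (6, 3, 3): (1.0, 0), (6, 5, 5): (1.0, 0),
}
P = np.zeros((len(ACTIONS), len(STATES), len(STATES)))
R = np.zeros_like(P)
for (a, s, s_next), (p, r) in TRANSITIONS.items():
    P[a, s, s_next], R[a, s, s_next] = p, r


def evaluate_policy(P, R, gamma, policy):
    """Exact evaluation of a deterministic policy given as one action index per state."""
    states = np.arange(P.shape[1])
    P_pi = P[policy, states]                           # row s is P[policy[s], s, :]
    r_pi = np.sum(P_pi * R[policy, states], axis=1)    # expected one-step reward in each state
    return np.linalg.solve(np.eye(len(states)) - gamma * P_pi, r_pi)


def policy_iteration(P, R, gamma, max_iter=100):
    feasible = P.sum(axis=2) > 0                       # (A, S): is action a available in state s?
    r = np.einsum("ast,ast->as", P, R)                 # expected immediate reward r(s, a), shape (A, S)
    policy = feasible.argmax(axis=0)                   # start from the first available action per state
    for iteration in range(1, max_iter + 1):
        V = evaluate_policy(P, R, gamma, policy)                # policy evaluation
        Q = np.where(feasible, r + gamma * (P @ V), -np.inf)    # Q(s, a) for available actions
        new_policy = Q.argmax(axis=0)                           # greedy policy improvement
        if np.array_equal(new_policy, policy):
            return policy, V, iteration
        policy = new_policy
    return policy, V, max_iter


policy, V, iterations = policy_iteration(P, R, gamma=0.8)
for s, a in enumerate(policy):
    print(f"{STATES[s]:>15}: {ACTIONS[a]:<17} V = {V[s]:.3f}")
print("policy iteration stopped after", iterations, "iterations")
```

For this small MDP the loop terminates after a few iterations. On large MDPs, the cost of each exact evaluation is what value iteration avoids.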
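
The triple loops in `compute_value_vector` and `greedy_policy` both compute the action values $\sum_{s^{\prime}} \mathcal{P}_{ss^{\prime}}^a[\mathcal{R}_{ss^{\prime}}^a + \gamma V_k(s^{\prime})]$ for every state and action. As an alternative that is not the notebook's own code, the whole sweep can be vectorized with NumPy. The fixed number of sweeps `N` can then be replaced by a stopping rule on the largest change between two consecutive value vectors. The function name, the tolerance `tol` and the all-zero initial values are illustrative choices. The tensors are rebuilt from the notebook's `P` and `R` cells, and unavailable actions are masked with $-\infty$ so that the max never selects them.

```python
import numpy as np

STATES = ["Reconnaissance", "Vulnerability", "Exploit", "Caught", "Post-Exploit", "Quit"]
ACTIONS = ["network-sniffing", "port-scan", "exit", "exploit-detection", "exploit",
           "install-backdoor", "stay"]

# (action, state, next_state): (probability, reward), copied from the notebook's P and R cells
TRANSITIONS = {
    (0, 0, 0): (0.9, 0), (0, 0, 1): (0.1, 5),
    (1, 0, 0): (0.8, 0), (1, 0, 1): (0.2, 5),
    (2, 0, 5): (1.0, 0), (2, 4, 5): (1.0, 0),
    (3, 1, 0): (0.6, -1), (3, 1, 2): (0.4, 3),
    (4, 2, 1): (0.2, -1), (4, 2, 3): (0.1, -10), (4, 2, 4): (0.7, 10),
    (5, 4, 3): (0.1, -10), (5, 4, 4): (0.9, 2),
    (6, 3, 3): (1.0, 0), (6, 5, 5): (1.0, 0),
}
P = np.zeros((len(ACTIONS), len(STATES), len(STATES)))
R = np.zeros_like(P)
for (a, s, s_next), (p, r) in TRANSITIONS.items():
    P[a, s, s_next], R[a, s, s_next] = p, r


def value_iteration_vectorized(P, R, gamma, tol=1e-10, max_sweeps=10_000):
    """Synchronous value iteration; returns V, Q (shape A x S), greedy actions and sweeps used."""
    feasible = P.sum(axis=2) > 0                 # (A, S): is action a available in state s?
    r = np.einsum("ast,ast->as", P, R)           # expected immediate reward r(s, a)
    V = np.zeros(P.shape[1])
    for sweep in range(1, max_sweeps + 1):
        # Q[a, s] = sum_s' P[a, s, s'] * (R[a, s, s'] + gamma * V[s'])
        Q = np.where(feasible, r + gamma * (P @ V), -np.inf)
        V_new = Q.max(axis=0)                    # Bellman optimality backup
        if np.max(np.abs(V_new - V)) < tol:
            return V_new, Q, Q.argmax(axis=0), sweep
        V = V_new
    return V, Q, Q.argmax(axis=0), max_sweeps


V, Q, greedy, sweeps = value_iteration_vectorized(P, R, gamma=0.8)
for s, a in enumerate(greedy):
    print(f"{STATES[s]:>15}: {ACTIONS[a]:<17} V* = {V[s]:.3f}")
print("sweeps:", sweeps)
```

`P @ V` contracts the last axis of the `(A, S, S)` tensor with `V`. This gives the expected next-state value for every action and state in one call. The greedy policy comes out of the same `Q` array via `argmax`, which mirrors how value iteration folds policy improvement into the update.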
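
The probabilistic interpretation in the "Discount Factor" cell can be made precise. Assume, as an illustrative model, that after every step the interaction continues with probability $\gamma$ and ends with probability $1-\gamma$, independently of everything else. Let $T$ be the index of the last rewarded step. Three consequences follow:

- The expected undiscounted return equals the discounted return.
- The expected number of steps is $1/(1-\gamma)$.
- Values are bounded by the largest absolute reward in $\mathcal{R}_{ss^{\prime}}^a$ (which is $10$ in this notebook) divided by $1-\gamma$.

```latex
\begin{align}
\Pr(T = t) &= \gamma^{t}(1-\gamma), \qquad \Pr(T \geq t) = \gamma^{t}, \qquad t = 0, 1, 2, \ldots\\
\mathbb{E}_T\left[\sum_{t=0}^{T} r_{t+1}\right] &= \sum_{t=0}^{\infty} \Pr(T \geq t)\, r_{t+1} = \sum_{t=0}^{\infty} \gamma^{t} r_{t+1}\\
\mathbb{E}[T+1] &= \sum_{t=0}^{\infty} \Pr(T \geq t) = \frac{1}{1-\gamma} = 5 \quad \text{for } \gamma = 0.8\\
\left|\sum_{t=0}^{\infty} \gamma^{t} r_{t+1}\right| &\leq \sum_{t=0}^{\infty} \gamma^{t} \max_{a,s,s^{\prime}} \left|\mathcal{R}_{ss^{\prime}}^a\right| = \frac{10}{1-0.8} = 50
\end{align}
```

So with $\gamma = 0.8$ the penetration tester effectively looks about five steps ahead, and every state value in this MDP lies in $[-50, 50]$.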
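
The "Iteration 3 - N" cells check convergence by comparing two consecutive value vectors. The value iteration update is a $\gamma$-contraction in the max norm, so the number of sweeps can also be bounded in advance. The contraction gives $\|\vec{v}_{j+1} - \vec{v}_j\|_\infty \leq \gamma^{j}\|\vec{v}_1 - \vec{v}_0\|_\infty$, and summing the geometric series gives $\|\vec{v}_k - V^{*}\|_\infty \leq \frac{\gamma^{k}}{1-\gamma}\|\vec{v}_1 - \vec{v}_0\|_\infty$. Below is a small sketch with a hypothetical helper, `sweeps_for_accuracy`, that returns the smallest $k$ for which this bound is at most a target accuracy $\epsilon$:

```python
def sweeps_for_accuracy(gamma, first_change, epsilon):
    """Smallest k such that gamma**k / (1 - gamma) * first_change <= epsilon.

    first_change is ||v_1 - v_0||_inf, the largest absolute change made by the first sweep.
    The bound follows from ||v_{j+1} - v_j|| <= gamma**j * ||v_1 - v_0|| and a geometric sum.
    """
    if not 0.0 <= gamma < 1.0:
        raise ValueError("gamma must lie in [0, 1)")
    if epsilon <= 0.0:
        raise ValueError("epsilon must be positive")
    k, bound = 0, first_change / (1.0 - gamma)
    while bound > epsilon:  # the bound shrinks by a factor gamma per sweep
        k += 1
        bound *= gamma
    return k


print(sweeps_for_accuracy(gamma=0.5, first_change=1.0, epsilon=0.25))  # -> 3
```

The bound is conservative, so fewer sweeps are usually enough in practice. It only needs the largest absolute change made by the first sweep, which is cheap to measure.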