Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Policy Evaluation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_state(state, action, grid_width=4, grid_height=4):\n",
- "    \"\"\"Return the cell reached by taking `action` from `state`.\n",
- "\n",
- "    A move that would leave the grid is clamped, so the result is the\n",
- "    border cell the move started from.\n",
- "\n",
- "    Args:\n",
- "        state: (row, col) pair, list or tuple; it is not modified.\n",
- "        action: index into the move table: 0 = up, 1 = down, 2 = left, 3 = right.\n",
- "        grid_width: number of columns; defaults to 4.\n",
- "        grid_height: number of rows; defaults to 4.\n",
- "\n",
- "    Returns:\n",
- "        Tuple (row, col) of the next state.\n",
- "    \"\"\"\n",
- "    # (d_row, d_col) for each action index\n",
- "    action_grid = [(-1, 0), (1, 0), (0, -1), (0, 1)]\n",
- "    d_row, d_col = action_grid[action]\n",
- "\n",
- "    # Work on fresh locals so the caller's state object is never mutated.\n",
- "    row = min(max(state[0] + d_row, 0), grid_height - 1)\n",
- "    col = min(max(state[1] + d_col, 0), grid_width - 1)\n",
- "\n",
- "    return row, col"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "def policy_evaluation(grid_width, grid_height, action, policy, iter_num, reward=-1, dis=1):\n",
- "    \"\"\"Run `iter_num` synchronous value-table sweeps over the grid world.\n",
- "\n",
- "    Each sweep builds a new table from the previous one: every non-terminal\n",
- "    cell gets the largest `reward + dis * value(next cell)` over all actions,\n",
- "    where the next cell comes from `get_state`. The top-left and bottom-right\n",
- "    corners are terminal and stay at 0.\n",
- "\n",
- "    NOTE(review): `policy` is accepted but never read; the backup takes the\n",
- "    max over actions instead of weighting them by the policy. The printed\n",
- "    results depend on that, so it is left as is -- confirm the intent.\n",
- "\n",
- "    Args:\n",
- "        grid_width: number of columns in the value table.\n",
- "        grid_height: number of rows in the value table.\n",
- "        action: iterable of action indices handed to `get_state`.\n",
- "        policy: unused (see the note above).\n",
- "        iter_num: number of sweeps. 0 prints and returns the all-zero table;\n",
- "            a negative value returns the all-zero table without printing.\n",
- "        reward: reward added for every move (default -1).\n",
- "        dis: discount factor applied to the next cell's value (default 1).\n",
- "\n",
- "    Returns:\n",
- "        numpy array of shape (grid_height, grid_width) holding the table\n",
- "        after the last sweep.\n",
- "    \"\"\"\n",
- "    # table initialize\n",
- "    post_value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
- "\n",
- "    if iter_num == 0:\n",
- "        print('Iteration: {} \\n{}\\n'.format(iter_num, post_value_table))\n",
- "        return post_value_table\n",
- "\n",
- "    terminal_states = {(0, 0), (grid_height - 1, grid_width - 1)}\n",
- "\n",
- "    for iteration in range(1, iter_num + 1):\n",
- "        next_value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
- "        for i in range(grid_height):\n",
- "            for j in range(grid_width):\n",
- "                if (i, j) in terminal_states:\n",
- "                    continue  # terminal cells keep their value of 0\n",
- "                candidates = []\n",
- "                for act in action:\n",
- "                    i_, j_ = get_state([i, j], act)\n",
- "                    candidates.append(reward + dis * post_value_table[i_][j_])\n",
- "                next_value_table[i][j] = max(candidates)\n",
- "\n",
- "        # Print every 10th sweep (every 20th once past 100) and always the last one.\n",
- "        print_every = 20 if iteration > 100 else 10\n",
- "        if iteration == iter_num or iteration % print_every == 0:\n",
- "            print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table))\n",
- "\n",
- "        post_value_table = next_value_table\n",
- "\n",
- "    # Equals the last sweep's table; also defined when the loop never ran.\n",
- "    return post_value_table"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "grid_width = 4\n",
- "grid_height = grid_width\n",
- "action = [0, 1, 2, 3] # up, down, left, right\n",
- "\n",
- "# Every cell starts with probability 0.25 for each of the len(action) moves.\n",
- "policy = np.full([grid_height, grid_width, len(action)], 0.25, dtype=float)\n",
- "\n",
- "# The two terminal corners take no action, so all their probabilities are zero.\n",
- "policy[0][0] = 0\n",
- "policy[grid_height - 1][grid_width - 1] = 0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Iteration: 1 \n",
- "[[ 0. -1. -1. -1.]\n",
- " [-1. -1. -1. -1.]\n",
- " [-1. -1. -1. -1.]\n",
- " [-1. -1. -1. 0.]]\n",
- "\n",
- "Iteration: 2 \n",
- "[[ 0. -1. -2. -2.]\n",
- " [-1. -2. -2. -2.]\n",
- " [-2. -2. -2. -1.]\n",
- " [-2. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 3 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 10 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n"
- ]
- }
- ],
- "source": [
- "# Run the evaluation for 1, 2, 3 and 10 sweeps in turn; `value` keeps the last table.\n",
- "for sweeps in (1, 2, 3, 10):\n",
- "    value = policy_evaluation(grid_width, grid_height, action, policy, sweeps)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Iteration: 10 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 20 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 30 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 40 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 50 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 60 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 70 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 80 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 90 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n",
- "Iteration: 100 \n",
- "[[ 0. -1. -2. -3.]\n",
- " [-1. -2. -3. -2.]\n",
- " [-2. -3. -2. -1.]\n",
- " [-3. -2. -1. 0.]]\n",
- "\n"
- ]
- }
- ],
- "source": [
- "# Run 100 sweeps and keep the resulting table in `value`.\n",
- "value = policy_evaluation(grid_width, grid_height, action, policy, iter_num=100)"
- ]
- }
- ],
- "metadata": {
- "anaconda-cloud": {},
- "kernelspec": {
- "display_name": "Python [default]",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment