Guest User

Untitled

a guest
Jan 6th, 2018
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.01 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import numpy as np"
  10. ]
  11. },
  12. {
  13. "cell_type": "markdown",
  14. "metadata": {},
  15. "source": [
  16. "# Policy Evaluation"
  17. ]
  18. },
  19. {
  20. "cell_type": "code",
  21. "execution_count": 2,
  22. "metadata": {},
  23. "outputs": [],
  24. "source": [
  25. "def get_state(state, action):\n",
  26. " \n",
  27. " action_grid = [(-1, 0), (1, 0), (0, -1), (0, 1)]\n",
  28. " \n",
  29. " state[0]+=action_grid[action][0]\n",
  30. " state[1]+=action_grid[action][1]\n",
  31. " \n",
  32. " if state[0] < 0 :\n",
  33. " state[0] = 0\n",
  34. " elif state[0] > 3 :\n",
  35. " state[0] = 3\n",
  36. " \n",
  37. " if state[1] < 0 :\n",
  38. " state[1] = 0\n",
  39. " elif state[1] > 3 :\n",
  40. " state[1] = 3\n",
  41. " \n",
  42. " return state[0], state[1]"
  43. ]
  44. },
  45. {
  46. "cell_type": "code",
  47. "execution_count": 3,
  48. "metadata": {},
  49. "outputs": [],
  50. "source": [
  51. "def policy_evaluation(grid_width, grid_height, action, policy, iter_num, reward=-1, dis=1):\n",
  52. " \n",
  53. " # table initialize\n",
  54. " post_value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
  55. " \n",
  56. " # iteration\n",
  57. " if iter_num == 0:\n",
  58. " print('Iteration: {} \\n{}\\n'.format(iter_num, post_value_table))\n",
  59. " return post_value_table\n",
  60. " \n",
  61. " for iteration in range(iter_num):\n",
  62. " next_value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
  63. " for i in range(grid_height):\n",
  64. " for j in range(grid_width):\n",
  65. " if i == j and ((i == 0) or (i == 3)):\n",
  66. " value_t = 0\n",
  67. " else :\n",
  68. " value_t_list= []\n",
  69. " for act in action:\n",
  70. " i_, j_ = get_state([i,j], act)\n",
  71. " value = (reward + dis*post_value_table[i_][j_])\n",
  72. " value_t_list.append(value)\n",
  73. " next_value_table[i][j] = max(value_t_list)\n",
  74. " iteration += 1\n",
  75. " \n",
  76. " # print result\n",
  77. " if (iteration % 10) != iter_num: \n",
  78. " # print result \n",
  79. " if iteration > 100 :\n",
  80. " if (iteration % 20) == 0: \n",
  81. " print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table))\n",
  82. " else :\n",
  83. " if (iteration % 10) == 0:\n",
  84. " print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table))\n",
  85. " else :\n",
  86. " print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table ))\n",
  87. " \n",
  88. " \n",
  89. " post_value_table = next_value_table\n",
  90. " \n",
  91. " \n",
  92. " return next_value_table"
  93. ]
  94. },
  95. {
  96. "cell_type": "code",
  97. "execution_count": 4,
  98. "metadata": {},
  99. "outputs": [],
  100. "source": [
  101. "grid_width = 4\n",
  102. "grid_height = grid_width\n",
  103. "action = [0, 1, 2, 3] # up, down, left, right\n",
  104. "policy = np.empty([grid_height, grid_width, len(action)], dtype=float)\n",
  105. "for i in range(grid_height):\n",
  106. " for j in range(grid_width):\n",
  107. " for k in range(len(action)):\n",
  108. " if i==j and ((i==0) or (i==3)):\n",
  109. " policy[i][j]=0.00\n",
  110. " else :\n",
  111. " policy[i][j]=0.25\n",
  112. "policy[0][0] = [0] * grid_width\n",
  113. "policy[3][3] = [0] * grid_width"
  114. ]
  115. },
  116. {
  117. "cell_type": "code",
  118. "execution_count": 5,
  119. "metadata": {},
  120. "outputs": [
  121. {
  122. "name": "stdout",
  123. "output_type": "stream",
  124. "text": [
  125. "Iteration: 1 \n",
  126. "[[ 0. -1. -1. -1.]\n",
  127. " [-1. -1. -1. -1.]\n",
  128. " [-1. -1. -1. -1.]\n",
  129. " [-1. -1. -1. 0.]]\n",
  130. "\n",
  131. "Iteration: 2 \n",
  132. "[[ 0. -1. -2. -2.]\n",
  133. " [-1. -2. -2. -2.]\n",
  134. " [-2. -2. -2. -1.]\n",
  135. " [-2. -2. -1. 0.]]\n",
  136. "\n",
  137. "Iteration: 3 \n",
  138. "[[ 0. -1. -2. -3.]\n",
  139. " [-1. -2. -3. -2.]\n",
  140. " [-2. -3. -2. -1.]\n",
  141. " [-3. -2. -1. 0.]]\n",
  142. "\n",
  143. "Iteration: 10 \n",
  144. "[[ 0. -1. -2. -3.]\n",
  145. " [-1. -2. -3. -2.]\n",
  146. " [-2. -3. -2. -1.]\n",
  147. " [-3. -2. -1. 0.]]\n",
  148. "\n"
  149. ]
  150. }
  151. ],
  152. "source": [
  153. "value = policy_evaluation(grid_width, grid_height, action, policy, 1)\n",
  154. "value = policy_evaluation(grid_width, grid_height, action, policy, 2)\n",
  155. "value = policy_evaluation(grid_width, grid_height, action, policy, 3)\n",
  156. "value = policy_evaluation(grid_width, grid_height, action, policy, 10)"
  157. ]
  158. },
  159. {
  160. "cell_type": "code",
  161. "execution_count": 6,
  162. "metadata": {},
  163. "outputs": [
  164. {
  165. "name": "stdout",
  166. "output_type": "stream",
  167. "text": [
  168. "Iteration: 10 \n",
  169. "[[ 0. -1. -2. -3.]\n",
  170. " [-1. -2. -3. -2.]\n",
  171. " [-2. -3. -2. -1.]\n",
  172. " [-3. -2. -1. 0.]]\n",
  173. "\n",
  174. "Iteration: 20 \n",
  175. "[[ 0. -1. -2. -3.]\n",
  176. " [-1. -2. -3. -2.]\n",
  177. " [-2. -3. -2. -1.]\n",
  178. " [-3. -2. -1. 0.]]\n",
  179. "\n",
  180. "Iteration: 30 \n",
  181. "[[ 0. -1. -2. -3.]\n",
  182. " [-1. -2. -3. -2.]\n",
  183. " [-2. -3. -2. -1.]\n",
  184. " [-3. -2. -1. 0.]]\n",
  185. "\n",
  186. "Iteration: 40 \n",
  187. "[[ 0. -1. -2. -3.]\n",
  188. " [-1. -2. -3. -2.]\n",
  189. " [-2. -3. -2. -1.]\n",
  190. " [-3. -2. -1. 0.]]\n",
  191. "\n",
  192. "Iteration: 50 \n",
  193. "[[ 0. -1. -2. -3.]\n",
  194. " [-1. -2. -3. -2.]\n",
  195. " [-2. -3. -2. -1.]\n",
  196. " [-3. -2. -1. 0.]]\n",
  197. "\n",
  198. "Iteration: 60 \n",
  199. "[[ 0. -1. -2. -3.]\n",
  200. " [-1. -2. -3. -2.]\n",
  201. " [-2. -3. -2. -1.]\n",
  202. " [-3. -2. -1. 0.]]\n",
  203. "\n",
  204. "Iteration: 70 \n",
  205. "[[ 0. -1. -2. -3.]\n",
  206. " [-1. -2. -3. -2.]\n",
  207. " [-2. -3. -2. -1.]\n",
  208. " [-3. -2. -1. 0.]]\n",
  209. "\n",
  210. "Iteration: 80 \n",
  211. "[[ 0. -1. -2. -3.]\n",
  212. " [-1. -2. -3. -2.]\n",
  213. " [-2. -3. -2. -1.]\n",
  214. " [-3. -2. -1. 0.]]\n",
  215. "\n",
  216. "Iteration: 90 \n",
  217. "[[ 0. -1. -2. -3.]\n",
  218. " [-1. -2. -3. -2.]\n",
  219. " [-2. -3. -2. -1.]\n",
  220. " [-3. -2. -1. 0.]]\n",
  221. "\n",
  222. "Iteration: 100 \n",
  223. "[[ 0. -1. -2. -3.]\n",
  224. " [-1. -2. -3. -2.]\n",
  225. " [-2. -3. -2. -1.]\n",
  226. " [-3. -2. -1. 0.]]\n",
  227. "\n"
  228. ]
  229. }
  230. ],
  231. "source": [
  232. "value = policy_evaluation(grid_width, grid_height, action, policy, 100)"
  233. ]
  234. }
  235. ],
  236. "metadata": {
  237. "anaconda-cloud": {},
  238. "kernelspec": {
  239. "display_name": "Python [default]",
  240. "language": "python",
  241. "name": "python3"
  242. },
  243. "language_info": {
  244. "codemirror_mode": {
  245. "name": "ipython",
  246. "version": 3
  247. },
  248. "file_extension": ".py",
  249. "mimetype": "text/x-python",
  250. "name": "python",
  251. "nbconvert_exporter": "python",
  252. "pygments_lexer": "ipython3",
  253. "version": "3.5.2"
  254. }
  255. },
  256. "nbformat": 4,
  257. "nbformat_minor": 2
  258. }
Add Comment
Please, Sign In to add comment