#!/usr/bin/env python
# coding: utf-8

# # Sutton and Barto Racetrack: Sarsa
# Exercise 5.8 from *Reinforcement Learning: An Introduction* by Sutton and Barto.
#
# This notebook applies the **Sarsa** algorithm from Chapter 6 to the Racetrack problem from Chapter 5.
#
# Python Notebook by Patrick Coady: [Learning Artificial Intelligence](https://learningai.io/)
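#
# The one-step Sarsa update applied in the training loop below (step size alpha, discount gamma):
#
#     Q(s, a)  <-  Q(s, a) + alpha * [r + gamma * Q(s', a') - Q(s, a)]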

# In[1]:


import numpy as np
import random
import matplotlib.pyplot as plt


# In[2]:


class RaceTrack(object):
    """
    RaceTrack object maintains and updates the race track state.
    Interaction with the class is through the take_action() method,
    which applies an action and returns the reward (r); the successor
    state (s') is then read with get_state().

    The class constructor is given a race course as a list of
    strings. The constructor loads the course and initializes
    the environment state. A short usage sketch follows the course
    definitions below.
    """

    def __init__(self, course):
        """
        Load race course, set any min or max limits in the
        environment (e.g. max speed), and set initial state.
        Initial state is random position on start line with
        velocity = (0, 0).

        Args:
            course: List of text strings used to construct
                race-track.
                '-': start line
                '+': finish line
                'o': track
                'W': wall
        """
        self.NOISE = 0.0
        self.EPS = 0.1  # epsilon-greedy coefficient
        self.MAX_VELOCITY = 4
        self.start_positions = []
        self.course = None
        self._load_course(course)
        self._random_start_position()
        self.velocity = np.array([0, 0], dtype=np.int16)

    def take_action(self, action):
        """
        Take action, update state, and return the reward.
        (The successor state s' is read with get_state().)

        Args:
            action: 2-tuple of requested change in velocity in x- and
                y-direction. Valid action is -1, 0, +1 in each axis.

        Returns:
            reward: float (-1.0 per step, +100.0 on reaching the finish)
        """

        self._update_velocity(action)
        self._update_position()
        if self.is_terminal_state():
            return 100.0

        return -1.0

    def get_state(self):
        """Return 2-tuple: (position, velocity). Each is a length-2 numpy array."""
        return self.position.copy(), self.velocity.copy()

    def _update_velocity(self, action):
        """
        Update x- and y-velocity. Clip at 0 and self.MAX_VELOCITY.

        Args:
            action: 2-tuple of requested change in velocity in x- and
                y-direction. Valid action is -1, 0, +1 in each axis.
        """
        if np.random.rand() > self.NOISE:
            self.velocity += np.array(action, dtype=np.int16)
        self.velocity = np.minimum(self.velocity, self.MAX_VELOCITY)
        self.velocity = np.maximum(self.velocity, 0)

    def reset(self):
        self._random_start_position()
        self.velocity = np.array([0, 0], dtype=np.int16)

    def _update_position(self):
        """
        Update position based on present velocity. Check at fine time
        scale for wall or finish. If wall is hit, set position to random
        position at start line. If finish is reached, set position to
        first crossed point on finish line.
        """
        for tstep in range(0, self.MAX_VELOCITY + 1):
            t = tstep / self.MAX_VELOCITY
            pos = self.position + np.round(self.velocity * t).astype(np.int16)
            if self._is_wall(pos):
                self._random_start_position()
                self.velocity = np.array([0, 0], dtype=np.int16)
                return
            if self._is_finish(pos):
                self.position = pos
                self.velocity = np.array([0, 0], dtype=np.int16)
                return
        self.position = pos

    def _random_start_position(self):
        """Set car to random position on start line"""
        self.position = np.array(random.choice(self.start_positions),
                                 dtype=np.int16)

    def _load_course(self, course):
        """Load course. Internally represented as numpy array"""
        y_size, x_size = len(course), len(course[0])
        self.course = np.zeros((x_size, y_size), dtype=np.int16)
        for y in range(y_size):
            for x in range(x_size):
                point = course[y][x]
                if point == 'o':
                    self.course[x, y] = 1
                elif point == '-':
                    self.course[x, y] = 0
                elif point == '+':
                    self.course[x, y] = 2
                elif point == 'W':
                    self.course[x, y] = -1
        # flip left/right so (0,0) is in bottom-left corner
        self.course = np.fliplr(self.course)
        for y in range(y_size):
            for x in range(x_size):
                if self.course[x, y] == 0:
                    self.start_positions.append((x, y))

    def _is_wall(self, pos):
        """Return True if position is wall"""
        return self.course[pos[0], pos[1]] == -1

    def _is_finish(self, pos):
        """Return True if position is finish line"""
        return self.course[pos[0], pos[1]] == 2

    def is_terminal_state(self):
        """Return True at episode terminal state"""
        return (self.course[self.position[0],
                            self.position[1]] == 2)

    def action_to_tuple(self, a):
        """Convert integer action to 2-tuple: (ax, ay)"""
        ax = a // 3 - 1
        ay = a % 3 - 1

        return ax, ay

    def tuple_to_action(self, a):
        """Convert 2-tuple to integer action: {0-8}"""
        return int((a[0] + 1) * 3 + a[1] + 1)

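    # Worked example of the action encoding used by the two helpers above
    # (ax changes x-velocity, ay changes y-velocity):
    #   action 0 -> (ax, ay) = (-1, -1)
    #   action 4 -> (ax, ay) = ( 0,  0)   (coast)
    #   action 8 -> (ax, ay) = (+1, +1)
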
    def greedy_eps(self, Q):
        """Based on state and Q values, return epsilon-greedy action"""
        s = self.get_state()
        s_x, s_y = s[0][0], s[0][1]
        s_vx, s_vy = s[1][0], s[1][1]
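        # With probability 1 - EPS act greedily; if every action value for this
        # state is still identical (e.g. unvisited), default to coasting (0, 0).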
        if np.random.rand() > self.EPS:
            if (np.max(Q[s_x, s_y, s_vx, s_vy, :, :]) ==
                    np.min(Q[s_x, s_y, s_vx, s_vy, :, :])):
                a = (0, 0)
            else:
                a = np.argmax(Q[s_x, s_y, s_vx, s_vy, :, :])
                a = np.unravel_index(a, (3, 3)) - np.array([1, 1])
                a = (a[0], a[1])
        else:
            a = self.action_to_tuple(random.randrange(9))

        return a

    def state_action(self, s, a):
        """Build state-action tuple for indexing Q NumPy array"""
        s_x, s_y = s[0][0], s[0][1]
        s_vx, s_vy = s[1][0], s[1][1]
        a_x, a_y = a[0] + 1, a[1] + 1
        s_a = (s_x, s_y, s_vx, s_vy, a_x, a_y)

        return s_a

# In[3]:


# Race Track from Sutton and Barto Figure 5.6

big_course = ['WWWWWWWWWWWWWWWWWW',
              'WWWWooooooooooooo+',
              'WWWoooooooooooooo+',
              'WWWoooooooooooooo+',
              'WWooooooooooooooo+',
              'Woooooooooooooooo+',
              'Woooooooooooooooo+',
              'WooooooooooWWWWWWW',
              'WoooooooooWWWWWWWW',
              'WoooooooooWWWWWWWW',
              'WoooooooooWWWWWWWW',
              'WoooooooooWWWWWWWW',
              'WoooooooooWWWWWWWW',
              'WoooooooooWWWWWWWW',
              'WoooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWooooooooWWWWWWWW',
              'WWWoooooooWWWWWWWW',
              'WWWoooooooWWWWWWWW',
              'WWWoooooooWWWWWWWW',
              'WWWoooooooWWWWWWWW',
              'WWWoooooooWWWWWWWW',
              'WWWoooooooWWWWWWWW',
              'WWWoooooooWWWWWWWW',
              'WWWWooooooWWWWWWWW',
              'WWWWooooooWWWWWWWW',
              'WWWW------WWWWWWWW']

# Tiny course for debug

tiny_course = ['WWWWWW',
               'Woooo+',
               'Woooo+',
               'WooWWW',
               'WooWWW',
               'WooWWW',
               'WooWWW',
               'W--WWW']

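# Illustrative usage sketch of the RaceTrack API on the tiny course. This is a minimal
# sanity check; the names demo_track and demo_reward are throwaway and not used by the
# rest of the notebook.
demo_track = RaceTrack(tiny_course)
print('Demo start state (position, velocity):', demo_track.get_state())
demo_reward = demo_track.take_action((1, 1))   # request +1 velocity in x and y
print('Demo reward:', demo_reward, 'next state:', demo_track.get_state())
demo_track.reset()
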
# In[4]:


# Problem Initialization

course = big_course
x_size, y_size = len(course[0]), len(course)
# Q[x_pos, y_pos, x_velocity, y_velocity, x_acceleration, y_acceleration]
Q = np.zeros((x_size, y_size, 5, 5, 3, 3), dtype=np.float64)
position_map = np.zeros((x_size, y_size), dtype=np.float64)  # track explored positions
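# For the big course this tabular Q holds 18 * 33 * 5 * 5 * 3 * 3 = 133,650 entries.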

N = 2000  # num episodes
gamma = 1.0
alpha = 0.1
track = RaceTrack(course)

# Sarsa

epochs = []
counts = []
count = 0
for e in range(N):
    if (e + 1) % 200 == 0: print('Episode {}'.format(e + 1))
    track.reset()
    s = track.get_state()
    a = track.greedy_eps(Q)

    while not track.is_terminal_state():
        position_map[s[0][0], s[0][1]] += 1
        count += 1
        r = track.take_action(a)
        s_prime = track.get_state()
        a_prime = track.greedy_eps(Q)
        s_a = track.state_action(s, a)
        s_a_prime = track.state_action(s_prime, a_prime)
        Q[s_a] = Q[s_a] + alpha * (r + gamma * Q[s_a_prime] - Q[s_a])
        s, a = s_prime, a_prime
    epochs.append(e)
    counts.append(count)

# In[5]:


plt.plot(epochs, counts)
plt.title('Simulation Steps vs. Episodes')
plt.xlabel('Episodes')
plt.ylabel('Total Simulation Steps')
plt.show()

# In[6]:


print('Heat map of position exploration:')
plt.imshow(np.flipud(position_map.T), cmap='hot', interpolation='nearest')
plt.show()

# In[7]:


# Convert Q (action-values) to pi (policy)
pi = np.zeros((x_size, y_size, 5, 5), dtype=np.int16)
for idx in np.ndindex(x_size, y_size, 5, 5):
    a = np.argmax(Q[idx[0], idx[1], idx[2], idx[3], :, :])
    a = np.unravel_index(a, (3, 3))
    pi[idx] = track.tuple_to_action(a - np.array([1, 1]))

# In[8]:


# Run learned policy on test case

pos_map = np.zeros((x_size, y_size))
track.reset()
for e in range(1000):
    s = track.get_state()
    s_x, s_y = s[0][0], s[0][1]
    s_vx, s_vy = s[1][0], s[1][1]
    pos_map[s_x, s_y] += 1  # exploration map
    act = track.action_to_tuple(pi[s_x, s_y, s_vx, s_vy])
    track.take_action(act)
    if track.is_terminal_state(): break

print('Sample trajectory on learned policy:')
pos_map = (pos_map > 0).astype(np.float32)
pos_map += track.course  # overlay track course
plt.imshow(np.flipud(pos_map.T), cmap='hot', interpolation='nearest')
plt.show()