Advertisement
Guest User

Untitled

a guest
Nov 11th, 2019
130
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.98 KB | None | 0 0
  1. import sys, os, time
  2. from contextlib import closing
  3.  
  4. from random import randint
  5. import numpy as np
  6. from six import StringIO, b
  7.  
  8. from gym import utils, make
  9. from gym.envs.registration import register
  10. from gym.envs.toy_text import discrete
  11.  
  12. LEFT = 0
  13. DOWN = 1
  14. RIGHT = 2
  15. UP = 3
  16.  
  17. TREAT = "+"
  18. WALL = "0"
  19. SNAKE = "S"
  20.  
  21. MAPS = {
  22. "basic": [
  23. ["0000000000000000"],
  24. ["0--------------0"],
  25. ["0--------------0"],
  26. ["0--------------0"],
  27. ["0--------------0"],
  28. ["0--------------0"],
  29. ["0--------------0"],
  30. ["0000000000000000"]
  31. ]
  32. }
  33.  
  34. class SnakeEnv(discrete.DiscreteEnv):
  35. """
  36.  
  37. 0000000000000000
  38. 0 SSS 0
  39. 0 SSSSS 0
  40. 0 0
  41. 0 0
  42. 0 + 0
  43. 0 0
  44. 0000000000000000
  45.  
  46. 0 : wall, bad
  47. + : treat, good
  48. S : snake-boi
  49.  
  50. The episode ends when you reach the goal or fall in a hole.
  51. You receive a reward of 1 if you reach the goal, and zero otherwise.
  52.  
  53. """
  54.  
  55. metadata = {'render.modes': ['human', 'ansi']}
  56.  
  57. def __init__(self, desc=None, map_name="basic",is_slippery=True):
  58. desc = MAPS[map_name]
  59. self.desc = desc = np.asarray(desc,dtype='c')
  60. self.nrow, self.ncol = nrow, ncol = desc.shape
  61. self.reward_range = (0, 1)
  62.  
  63. nA = 4
  64. nS = nrow * ncol
  65.  
  66. isd = np.array(desc == b'S').astype('float64').ravel()
  67. isd /= isd.sum()
  68.  
  69. P = {s : {a : [] for a in range(nA)} for s in range(nS)}
  70.  
  71. def to_s(row, col):
  72. return row*ncol + col
  73.  
  74. def inc(row, col, a):
  75. if a == LEFT:
  76. col = max(col-1,0)
  77. elif a == DOWN:
  78. row = min(row+1,nrow-1)
  79. elif a == RIGHT:
  80. col = min(col+1,ncol-1)
  81. elif a == UP:
  82. row = max(row-1,0)
  83. return (row, col)
  84.  
  85. for row in range(nrow):
  86. for col in range(ncol):
  87. s = to_s(row, col)
  88. for a in range(4):
  89. li = P[s][a]
  90. letter = desc[row, col]
  91. if letter in b'GH':
  92. li.append((1.0, s, 0, True))
  93. else:
  94. newrow, newcol = inc(row, col, a)
  95. newstate = to_s(newrow, newcol)
  96. newletter = desc[newrow, newcol]
  97. done = bytes(newletter) in b'GH'
  98. rew = float(newletter == b'G')
  99. li.append((1.0, newstate, rew, done))
  100.  
  101. super(SnakeEnv, self).__init__(nS, nA, P, isd)
  102.  
  103. def spawnTreat(self):
  104. r = randint(0, self.nrow)
  105. c = randint(0, self.ncol)
  106. letter = desc[row, col]
  107. if letter is not SNAKE and not WALL:
  108. desc[row, col] = TREAT
  109. else:
  110. self.spawnTreat()
  111.  
  112. def render(self, mode='human'):
  113. outfile = StringIO() if mode == 'ansi' else sys.stdout
  114.  
  115. row, col = self.s // self.ncol, self.s % self.ncol
  116. desc = self.desc.tolist()
  117. desc = [[c.decode('utf-8') for c in line] for line in desc]
  118. desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
  119. if self.lastaction is not None:
  120. outfile.write("Eelmine samm: ({})\n".format(["Left","Down","Right","Up"][self.lastaction]))
  121. else:
  122. outfile.write("\n")
  123. outfile.write("\n".join(''.join(line) for line in desc)+"\n")
  124.  
  125. time.sleep(1)
  126. os.system("clear")
  127.  
  128. if mode != 'human':
  129. with closing(outfile):
  130. return outfile.getvalue()
  131.  
  132.  
  133. register(
  134. id='snake-game-v100',
  135. entry_point=SnakeEnv,
  136. max_episode_steps=200,
  137. reward_threshold=25.0,
  138. )
  139.  
  140. env = make('snake-game-v100')
  141. env.reset()
  142. for _ in range(1000):
  143. env.render()
  144. env.step(env.action_space.sample()) # take a random action
  145. env.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement