Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import sys, os, time
- from contextlib import closing
- from random import randint
- import numpy as np
- from six import StringIO, b
- from gym import utils, make
- from gym.envs.registration import register
- from gym.envs.toy_text import discrete
# Action indices; the environment has exactly four discrete actions.
LEFT = 0
DOWN = 1
RIGHT = 2
UP = 3

# Characters used for the cells of a map.
TREAT = "+"
WALL = "0"
SNAKE = "S"

# Named maps. Every row is a plain string of equal length so that
# ``np.asarray(rows, dtype='c')`` produces a 2-D (nrow, ncol) character grid.
# Wrapping each row in its own list would add a third dimension and break the
# ``nrow, ncol = desc.shape`` unpacking done by the environment.
MAPS = {
    "basic": [
        "0000000000000000",
        "0--------------0",
        "0--------------0",
        "0--------------0",
        "0--------------0",
        "0--------------0",
        "0--------------0",
        "0000000000000000",
    ],
}
class SnakeEnv(discrete.DiscreteEnv):
    """Tabular snake grid world built on gym's ``DiscreteEnv``.

    Illustrative layout::

        0000000000000000
        0---SSS--------0
        0-----SSSSS----0
        0--------------0
        0--------------0
        0-------+------0
        0--------------0
        0000000000000000

    Legend:
        0 : wall, bad
        + : treat, good
        S : snake-boi (cells the snake may start in)

    The episode ends when the snake reaches a treat or runs into a wall.
    Reaching a treat yields a reward of 1; every other move yields 0.
    """

    metadata = {'render.modes': ['human', 'ansi']}

    def __init__(self, desc=None, map_name="basic", is_slippery=True):
        """Build the state space and transition table for a map.

        Args:
            desc: Optional map given as a sequence of equal-length row
                strings. When ``None`` the map is looked up in ``MAPS``.
            map_name: Key into ``MAPS``; only used when ``desc`` is ``None``.
            is_slippery: Accepted for signature compatibility; it is not
                used, every move is deterministic.

        Raises:
            ValueError: If the map is not a rectangular 2-D grid or has no
                cell in which the snake could start.
        """
        # Only fall back to the named map when no explicit map was given.
        if desc is None:
            desc = MAPS[map_name]

        grid = np.asarray(desc, dtype='c')
        # Rows wrapped in one-element lists (``["0--0"]``) yield a
        # (nrow, 1, ncol) array; drop the spurious middle axis.
        if grid.ndim == 3 and grid.shape[1] == 1:
            grid = grid[:, 0, :]
        if grid.ndim != 2:
            raise ValueError(
                "map must be a rectangular grid of characters, got shape "
                "{}".format(grid.shape)
            )

        self.desc = desc = grid
        self.nrow, self.ncol = nrow, ncol = desc.shape
        self.reward_range = (0, 1)

        nA = 4
        nS = nrow * ncol

        # Start uniformly over the cells marked as snake. A map without any
        # such cell would make the normalisation below divide by zero, so
        # fall back to every cell that is neither wall nor treat.
        start_cells = desc == SNAKE.encode()
        if not start_cells.any():
            start_cells = (desc != WALL.encode()) & (desc != TREAT.encode())
        if not start_cells.any():
            raise ValueError("map has no cell the snake can start in")
        isd = start_cells.astype('float64').ravel()
        isd /= isd.sum()

        P = self._build_transitions(desc)
        super(SnakeEnv, self).__init__(nS, nA, P, isd)

    @staticmethod
    def _build_transitions(desc):
        """Return the ``{state: {action: [(prob, next, reward, done)]}}`` table for ``desc``."""
        nrow, ncol = desc.shape
        treat = TREAT.encode()
        # Walls and treats both end the episode.
        terminal = (desc == WALL.encode()) | (desc == treat)
        # (row delta, column delta) per action.
        moves = {LEFT: (0, -1), DOWN: (1, 0), RIGHT: (0, 1), UP: (-1, 0)}

        P = {}
        for row in range(nrow):
            for col in range(ncol):
                s = row * ncol + col
                P[s] = {}
                for a, (drow, dcol) in moves.items():
                    if terminal[row, col]:
                        # Terminal cells are absorbing and give no reward.
                        P[s][a] = [(1.0, s, 0, True)]
                        continue
                    # Clamp so a move never leaves the grid.
                    newrow = min(max(row + drow, 0), nrow - 1)
                    newcol = min(max(col + dcol, 0), ncol - 1)
                    rew = float(desc[newrow, newcol] == treat)
                    done = bool(terminal[newrow, newcol])
                    P[s][a] = [(1.0, newrow * ncol + newcol, rew, done)]
        return P

    def spawnTreat(self):
        """Place a treat on a randomly chosen free cell.

        A cell is free when it holds no wall, snake marker or treat and is
        not the cell the agent currently occupies. The transition table is
        rebuilt afterwards so the new treat is rewarded.

        Returns:
            The ``(row, col)`` of the new treat, or ``None`` when the map has
            no free cell left.
        """
        blocked = (WALL.encode(), SNAKE.encode(), TREAT.encode())
        agent = getattr(self, "s", None)
        free = [
            (r, c)
            for r in range(self.nrow)
            for c in range(self.ncol)
            if self.desc[r, c] not in blocked and r * self.ncol + c != agent
        ]
        if not free:
            return None

        # randint includes both bounds, hence the ``- 1``.
        r, c = free[randint(0, len(free) - 1)]
        self.desc[r, c] = TREAT.encode()
        self.P = self._build_transitions(self.desc)
        return (r, c)

    def render(self, mode='human'):
        """Draw the grid with the agent's cell highlighted.

        Args:
            mode: ``'ansi'`` returns the frame as a string; any other value
                prints it to stdout, waits one second and clears the
                terminal.

        Returns:
            The rendered frame for ``'ansi'``, otherwise ``None``.
        """
        as_text = mode == 'ansi'
        outfile = StringIO() if as_text else sys.stdout

        row, col = divmod(self.s, self.ncol)
        desc = [[c.decode('utf-8') for c in line] for line in self.desc.tolist()]
        desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)

        if self.lastaction is not None:
            outfile.write("Eelmine samm: ({})\n".format(["Left","Down","Right","Up"][self.lastaction]))
        else:
            outfile.write("\n")
        outfile.write("\n".join(''.join(line) for line in desc)+"\n")

        if as_text:
            with closing(outfile):
                return outfile.getvalue()

        # Pacing and screen clearing only make sense for an interactive
        # terminal; keep them out of the string-returning mode.
        time.sleep(1)
        os.system("cls" if os.name == "nt" else "clear")
        return None
# Registering at import time lets ``gym.make('snake-game-v100')`` work for any
# module that imports this file.
register(
    id='snake-game-v100',
    entry_point=SnakeEnv,
    max_episode_steps=200,
    reward_threshold=25.0,
)


def _run_random_demo(steps=1000):
    """Render the environment while taking ``steps`` random actions."""
    env = make('snake-game-v100')
    try:
        env.reset()
        for _ in range(steps):
            env.render()
            # take a random action
            _, _, done, _ = env.step(env.action_space.sample())
            if done:
                # Stepping a finished episode is undefined; start a new one.
                env.reset()
    finally:
        # Release the environment even if rendering or stepping fails.
        env.close()


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    _run_random_demo()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement