# Source: Pastebin paste "Untitled", posted by a guest on Jul 19th, 2019
# (plain text, 9.63 KB). Web-page chrome from the paste site has been removed.
import sys
from random import randrange as rand

import numpy as np
import pygame

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, InputLayer
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory
  12.  
  13. # The configuration
  14. config = {
  15. 'cell_size': 20,
  16. 'cols': 10,
  17. 'rows': 20,
  18. 'delay': 150,
  19. 'maxfps': 30
  20. }
  21.  
  22. colors = [
  23. (0, 0, 0),
  24. (255, 0, 0),
  25. (0, 150, 0),
  26. (0, 0, 255),
  27. (255, 120, 0),
  28. (255, 255, 0),
  29. (180, 0, 255),
  30. (0, 220, 220)
  31. ]
  32.  
  33. # Define the shapes of the single parts
  34. tetris_shapes = [
  35. [[1, 1, 1],
  36. [0, 1, 0]],
  37.  
  38. [[0, 2, 2],
  39. [2, 2, 0]],
  40.  
  41. [[3, 3, 0],
  42. [0, 3, 3]],
  43.  
  44. [[4, 0, 0],
  45. [4, 4, 4]],
  46.  
  47. [[0, 0, 5],
  48. [5, 5, 5]],
  49.  
  50. [[6, 6, 6, 6]],
  51.  
  52. [[7, 7],
  53. [7, 7]]
  54. ]
  55.  
  56. reward_scores = [
  57. 100,
  58. 300,
  59. 500,
  60. 800
  61. ]
  62.  
  63. def rotate_clockwise(shape):
  64. return np.rot90(shape,k=3)
  65.  
  66.  
  67. def check_collision(board, shape, offset):
  68. off_x, off_y = offset
  69. for cy, row in enumerate(shape):
  70. for cx, cell in enumerate(row):
  71. try:
  72. if cell and board[cy + off_y][cx + off_x]:
  73. return True
  74. except IndexError:
  75. return True
  76. return False
  77.  
  78.  
  79. def remove_row(board, row):
  80. del board[row]
  81. return [[0 for i in range(config['cols'])]] + board
  82.  
  83.  
  84. def join_matrixes(mat1, mat2, mat2_off):
  85. off_x, off_y = mat2_off
  86. for cy, row in enumerate(mat2):
  87. for cx, val in enumerate(row):
  88. mat1[cy + off_y - 1][cx + off_x] += val
  89. return mat1
  90.  
  91.  
  92. def new_board():
  93. board = [[0 for x in range(config['cols'])]
  94. for y in range(config['rows'])]
  95. board += [[1 for x in range(config['cols'])]]
  96. return board
  97.  
  98.  
  99. class TetrisApp(object):
  100. def __init__(self):
  101. pygame.init()
  102. pygame.key.set_repeat(250, 25)
  103. self.width = config['cell_size'] * config['cols'] + 200
  104. self.height = config['cell_size'] * config['rows']
  105.  
  106. self.stonebag = [i for i in range(len(tetris_shapes))] * 2
  107. self.current_stone = tetris_shapes[self.stonebag.pop(rand(len(self. stonebag)))]
  108. self.next_stone = tetris_shapes[self.stonebag.pop(rand(len(self. stonebag)))]
  109.  
  110. self.score = 0
  111. self.current_reward = 0
  112.  
  113. self.gameover = False
  114. self.paused = False
  115.  
  116. # self.screen = pygame.display.set_mode((self.width, self.height))
  117. # pygame.event.set_blocked(pygame.MOUSEMOTION) # We do not need
  118. # mouse movement
  119. # events, so we
  120. # block them.
  121. self.init_game()
  122.  
  123. def new_stone(self):
  124. if not self.stonebag:
  125. self.stonebag = [i for i in range(len(tetris_shapes))] * 2
  126. self.current_stone = self.next_stone
  127. self.next_stone = tetris_shapes[self.stonebag.pop(rand(len(self.stonebag)))]
  128.  
  129. self.stone_x = int(config['cols'] / 2 - len(self.current_stone[0]) / 2)
  130. self.stone_y = 0
  131.  
  132. if check_collision(self.board,
  133. self.current_stone,
  134. (self.stone_x, self.stone_y)):
  135. self.gameover = True
  136.  
  137. def init_game(self):
  138. self.board = new_board()
  139. self.new_stone()
  140. self.step_count = 0
  141.  
  142. def center_msg(self, msg):
  143. for i, line in enumerate(msg.splitlines()):
  144. msg_image = pygame.font.Font(
  145. pygame.font.get_default_font(), 12).render(
  146. line, False, (255, 255, 255), (0, 0, 0))
  147.  
  148. msgim_center_x, msgim_center_y = msg_image.get_size()
  149. msgim_center_x //= 2
  150. msgim_center_y //= 2
  151.  
  152. self.screen.blit(msg_image, (
  153. self.width // 2 - msgim_center_x,
  154. self.height // 2 - msgim_center_y + i * 22))
  155.  
  156. def draw_matrix(self, matrix, offset):
  157. off_x, off_y = offset
  158. for y, row in enumerate(matrix):
  159. for x, val in enumerate(row):
  160. if val:
  161. pygame.draw.rect(
  162. self.screen,
  163. colors[val],
  164. pygame.Rect(
  165. (off_x + x) *
  166. (config['cell_size']),
  167. (off_y + y) *
  168. (config['cell_size']),
  169. (config['cell_size']),
  170. (config['cell_size'])), 0)
  171.  
  172. def move(self, delta_x):
  173. if not self.gameover and not self.paused:
  174. new_x = self.stone_x + delta_x
  175. if new_x < 0:
  176. new_x = 0
  177. if new_x > config['cols'] - len(self.current_stone[0]):
  178. new_x = config['cols'] - len(self.current_stone[0])
  179. if not check_collision(self.board,
  180. self.current_stone,
  181. (new_x, self.stone_y)):
  182. self.stone_x = new_x
  183.  
  184. def quit(self):
  185. self.center_msg("Exiting...")
  186. pygame.display.update()
  187. sys.exit()
  188.  
  189. def drop(self):
  190. if not self.gameover and not self.paused:
  191. self.stone_y += 1
  192. if check_collision(self.board,
  193. self.current_stone,
  194. (self.stone_x, self.stone_y)):
  195. self.board = join_matrixes(
  196. self.board,
  197. self.current_stone,
  198. (self.stone_x, self.stone_y))
  199. self.new_stone()
  200. combo = 0
  201. while True:
  202. for i, row in enumerate(self.board[:-1]):
  203. if 0 not in row:
  204. self.board = remove_row(
  205. self.board, i)
  206. combo += 1
  207. break
  208. else:
  209. break
  210. # give reward
  211. if combo != 0:
  212. self.score += reward_scores[combo - 1]
  213. self.current_reward = reward_scores[combo - 1]
  214.  
  215. def rotate_stone(self):
  216. if not self.gameover and not self.paused:
  217. new_stone = rotate_clockwise(self.current_stone)
  218. # if self.stone_x > config["cols"] - 3:
  219. for i in range(1, 4):
  220. if check_collision(self.board,
  221. new_stone,
  222. (self.stone_x, self.stone_y)) and not check_collision(self.board, new_stone, (
  223. self.stone_x - i, self.stone_y)):
  224. if self.stone_x - i >= 0:
  225. self.stone_x -= i
  226. self.current_stone = new_stone
  227. elif not check_collision(self.board,
  228. new_stone,
  229. (self.stone_x, self.stone_y)):
  230. self.current_stone = new_stone
  231.  
  232. def toggle_pause(self):
  233. self.paused = not self.paused
  234.  
  235. def reset(self):
  236. if self.gameover:
  237. self.init_game()
  238. self.score = 0
  239. self.gameover = False
  240. return self.get_state()
  241.  
  242. def get_state(self):
  243. state = []
  244. for row in self.board:
  245. for value in row:
  246. if value == 0:
  247. state += [0]
  248. else:
  249. state += [1]
  250.  
  251. for i, row in enumerate(self.current_stone):
  252. for j, column in enumerate(row):
  253. if column != 0:
  254. state[(self.stone_y + i) * config['cols'] + self.stone_x + j] = 0.5
  255. """
  256. print("## BOARD ##")
  257. for i, value in enumerate(state):
  258. if (i + 1) % config['cols'] == 0:
  259. print(value)
  260. else:
  261. print(value, end="")
  262. """
  263. return state
  264.  
  265. def step(self, a):
  266. self.step_count += 1
  267.  
  268. key_actions = {
  269. 'ESCAPE': self.quit,
  270. 'LEFT': lambda: self.move(-1),
  271. 'RIGHT': lambda: self.move(+1),
  272. 'DOWN': self.drop,
  273. 'UP': self.rotate_stone,
  274. 'p': self.toggle_pause,
  275. 'SPACE': self.reset
  276. }
  277. actions = ['LEFT', 'RIGHT', 'DOWN', 'UP']
  278.  
  279. key_actions[actions[a]]()
  280.  
  281. if self.step_count % 3:
  282. self.drop()
  283. self.current_reward += 5
  284.  
  285. new_s = self.get_state()
  286.  
  287. r = self.current_reward
  288.  
  289. done = False
  290. if self.gameover:
  291. print(self.get_state())
  292. done = True
  293. r = -100
  294.  
  295. self.current_reward = 0
  296.  
  297. return new_s, r, done, {}
  298.  
  299.  
  300. #print(s,a,r,sep="\n")
  301. nb_actions = 4
  302.  
  303. np.random.seed(123)
  304.  
  305. model = Sequential()
  306.  
  307. model.add(Flatten(input_shape=(1,210)))
  308. model.add(Dense(128))
  309. model.add(Activation('relu'))
  310. model.add(Dense(64))
  311. model.add(Activation('relu'))
  312. model.add(Dense(nb_actions))
  313. model.add(Activation('linear'))
  314. print(model.summary())
  315.  
  316. policy = EpsGreedyQPolicy()
  317. memory = SequentialMemory(limit=50000, window_length=1)
  318. dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=40,
  319. target_model_update=1e-2, policy=policy)
  320. dqn.compile(Adam(lr=1e-3), metrics=['mae'])
  321.  
  322. env = TetrisApp()
  323.  
  324. print(env)
  325.  
  326. # Okay, now it's time to learn something! We visualize the training here for show, but this slows down training quite a lot.
  327. dqn.fit(env, nb_steps=500000, visualize=False, verbose=2)
  328.  
  329. dqn.test(env, nb_episodes=5, visualize=False)
# End of paste.