from ple.games.snake import Snake
from ple import PLE
import numpy as np
from agent import Agent
import pygame
import sys
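
# Expected invocation (inferred from the sys.argv reads in run(); the paste does
# not document it): python <script>.py <alpha> <gamma> <model_name> <activation_1> <activation_2>
# e.g. `python snake_dqn.py 0.0005 0.99 model_1 relu relu` (hypothetical filename and values).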


def get_dist(head_x, head_y, obs_x, obs_y):
    # Euclidean distance between the snake's head and a target point (food, wall, or body segment).
    return ((head_x - obs_x) ** 2 + (head_y - obs_y) ** 2) ** 0.5


def get_state(state):
    # Compact 5-feature state (not used by run() below, which relies on vision() instead).
    head_x, head_y = state[0], state[1]
    # Shortest distance from the head to any of the four walls of the 600x600 board.
    min_dist_walls = min(get_dist(head_x, head_y, head_x, 0), get_dist(head_x, head_y, 0, head_y),
                         get_dist(head_x, head_y, 600, head_y), get_dist(head_x, head_y, head_x, 600))
    # [head_x, head_y, food_x, food_y, distance to the nearest obstacle (body or wall)].
    return [state[0], state[1], state[2], state[3], min(min(state[4][4:]), min_dist_walls)]


def vision(state):
    # Four directions in the order [up, down, left, right]; each entry is [food signal, danger signal].
    my_vision = [[0, 0] for _ in range(4)]
    head_x, head_y = state[0], state[1]
    food_x, food_y = state[2], state[3]

    # Food: flag the horizontal and vertical directions that point toward it.
    dist_x, dist_y = head_x - food_x, head_y - food_y
    if abs(dist_y) < 500:
        if dist_x < 0:
            my_vision[3][0] = 10
        else:
            my_vision[2][0] = 10
    if abs(dist_x) < 500:
        if dist_y < 0:
            my_vision[1][0] = 10
        else:
            my_vision[0][0] = 10

    # Walls: mark a direction as dangerous when the head is within 50 px of that wall.
    if head_x <= 50:
        my_vision[2][1] = -100
    elif 600 - head_x <= 50:
        my_vision[3][1] = -100
    if head_y <= 50:
        my_vision[0][1] = -100
    elif 600 - head_y <= 50:
        my_vision[1][1] = -100

    # Body: mark directions with a body segment within 50 px (the 3 segments nearest the head are skipped).
    for body_x, body_y in state[5][3:]:
        dist_x = head_x - body_x
        dist_y = head_y - body_y
        if abs(dist_x) <= 50:
            if dist_x > 0:
                my_vision[2][1] = -100
            else:
                my_vision[3][1] = -100
        if abs(dist_y) <= 50:
            if dist_y < 0:
                my_vision[1][1] = -100
            else:
                my_vision[0][1] = -100

    # Flatten the 4x2 grid and append the raw coordinates: 12 network inputs in total.
    output = []
    for item in my_vision:
        output.extend(item)
    output.extend([head_x, head_y, food_x, food_y])
    return output
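
# Worked example (hypothetical values): head at (300, 300), food at (400, 300), no wall
# or body within 50 px. dist_x = -100 flags "right", and dist_y == 0 falls into the else
# branch flagging "up", so vision(state) == [10, 0, 0, 0, 0, 0, 10, 0, 300, 300, 400, 300].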


def prepare_correct_directions(direction):
    # Given the current heading, return the three legal next moves (a 180-degree turn is excluded).
    # Keys are pygame key codes: 119 = w (up), 115 = s (down), 97 = a (left), 100 = d (right).
    direction = str(direction)
    if direction == "Left":
        return {119: "Up", 115: "Down", 97: "Left"}
    if direction == "Right":
        return {115: "Down", 119: "Up", 100: "Right"}
    if direction == "Up":
        return {100: "Right", 97: "Left", 119: "Up"}
    if direction == "Down":
        return {97: "Left", 100: "Right", 115: "Down"}
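
# For example, prepare_correct_directions("Right") returns {115: "Down", 119: "Up", 100: "Right"};
# run() below lists its items, so action indices 0/1/2 pick Down/Up/Right respectively
# (this relies on dicts preserving insertion order, i.e. Python 3.7+).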


def process_state(state):
    # PLE hands its raw state dict to this preprocessor; expose the values as a single-row
    # array (dtype=object because the values mix scalars with lists of body segments).
    return np.array([list(state.values())], dtype=object)
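
# vision() indexes the preprocessed row positionally, so this script assumes the Snake
# state values arrive in the order [head_x, head_y, food_x, food_y, body distances,
# body positions], matching the key order of PLE's Snake state dict.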


def run():
    game = Snake(600, 600)
    p = PLE(game, state_preprocessor=process_state, force_fps=True, display_screen=True, frame_skip=5)
    print(sys.argv[1])
    # Hyperparameters come from the command line: alpha, gamma, model file name, two activations.
    agent = Agent(alpha=float(sys.argv[1]), gamma=float(sys.argv[2]), n_actions=3, epsilon=0.99, batch_size=100,
                  input_shape=12, epsilon_dec=0.99999,
                  epsilon_end=0.1,
                  memory_size=50000000, file_name=sys.argv[3], activations=[str(sys.argv[4]), str(sys.argv[5])])
    p.init()
    # agent.load_game()

    scores = []

    for _ in range(100000):
        if p.game_over():
            p.reset_game()
        score = 0
        initial_direction = "Right"
        game_state = np.array(vision(list(p.getGameState()[0])))
        # Head/food coordinates sit at indices 8-11 of vision()'s output (0-7 are direction flags).
        prec_dist = get_dist(game_state[8], game_state[9], game_state[10], game_state[11])

        while not p.game_over():
            old_state = np.array(vision(list(p.getGameState()[0])))

            action = agent.choose_action(old_state)

            # Map the agent's action index (0-2) onto the three moves legal from the current heading.
            possible_directions = prepare_correct_directions(initial_direction)
            possible_directions_tuples = list(possible_directions.items())
            direction = possible_directions_tuples[action]
            initial_direction = direction[1]

            reward = p.act(direction[0])
            # Reward shaping: replace the default step penalty with a bonus when the head moved closer to the food.
            if reward == -0.1:
                game_state = np.array(vision(list(p.getGameState()[0])))
                curr_dist = get_dist(game_state[8], game_state[9], game_state[10], game_state[11])
                if prec_dist > curr_dist:
                    reward = 1.5
                prec_dist = curr_dist

            print(reward)

            new_state = np.array(vision(list(p.getGameState()[0])))
            agent.add_experience(old_state, action, reward, new_state)
            agent.learn()

        score = p.score()
        scores.append(score)
        print(f"Score for model {sys.argv[3]} with learning_rate {sys.argv[1]}, gamma {sys.argv[2]}, "
              f"activations {sys.argv[4]}/{sys.argv[5]} is {score}. Epsilon is {agent.epsilon}")
        agent.save_game()


if __name__ == '__main__':
    run()
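
# ---------------------------------------------------------------------------
# agent.py is not included in this paste. A minimal sketch of the interface the
# calls above assume (names and signatures inferred from this script, not from
# the original agent.py):
#
#   class Agent:
#       def __init__(self, alpha, gamma, n_actions, epsilon, batch_size,
#                    input_shape, epsilon_dec, epsilon_end, memory_size,
#                    file_name, activations): ...
#       def choose_action(self, state): ...      # epsilon-greedy pick among n_actions
#       def add_experience(self, state, action, reward, new_state): ...  # store one transition
#       def learn(self): ...                     # sample a batch and update the Q-network
#       def save_game(self): ...                 # persist weights to file_name
#       def load_game(self): ...                 # restore weights from file_name
#
# The agent is also expected to expose an `epsilon` attribute, read by the
# logging line in run().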