Advertisement
Guest User

Untitled

a guest
Dec 11th, 2019
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.65 KB | None | 0 0
import math
import sys

import numpy as np
import pygame
from ple import PLE
from ple.games.snake import Snake

from agent import Agent
  7.  
  8. def get_dist(head_x, head_y, obs_x, obs_y):
  9. return ((head_x - obs_x) ** 2 + (head_y - obs_y) ** 2) ** 0.5
  10.  
  11.  
  12. def get_state(state):
  13. head_x, head_y = state[0], state[1]
  14. min_dist_walls = min(get_dist(head_x, head_y, head_x, 0), get_dist(head_x, head_y, 0, head_y),
  15. get_dist(head_x, head_y, 600, head_y), get_dist(head_x, head_y, head_x, 600))
  16. return [state[0], state[1], state[2], state[3], min(min(state[4][2:]), min_dist_walls)]
  17.  
  18.  
  19. def vision(state):
  20. my_vision = [[0, 0, 0] for _ in range(4)]
  21. head_x, head_y = state[0], state[1]
  22. food_x, food_y = state[2], state[3]
  23.  
  24. # food
  25. dist_x, dist_y = head_x - food_x, head_y - food_y
  26. if abs(dist_y) < 100:
  27. if dist_x < 0:
  28. my_vision[3][0] = 1
  29. else:
  30. my_vision[2][0] = 1
  31. if abs(dist_x) < 100:
  32. if dist_y < 0:
  33. my_vision[1][0] = 1
  34. else:
  35. my_vision[0][0] = 1
  36.  
  37. # wall
  38. if head_x <= 50:
  39. my_vision[2][1] = 1
  40. elif 600 - head_x <= 50:
  41. my_vision[3][1] = 1
  42. if head_y <= 50:
  43. my_vision[0][1] = 1
  44. elif 600 - head_y <= 50:
  45. my_vision[1][1] = 1
  46.  
  47. # body
  48. for body_x, body_y in state[5][3:]:
  49. # print(body_x,body_y)
  50. dist_x = head_x - body_x
  51. dist_y = head_y - body_y
  52. if abs(dist_x) <= 50:
  53. if dist_x > 0:
  54. my_vision[2][2] = 1
  55. else:
  56. my_vision[3][2] = 1
  57. if abs(dist_y) <= 50:
  58. if dist_y < 0:
  59. my_vision[1][2] = 1
  60. else:
  61. my_vision[0][2] = 1
  62. output = []
  63. [output.extend(item) for item in my_vision]
  64. return output
  65.  
  66.  
  67. def process_state(state):
  68. return np.array([state.values()])
  69.  
  70. def run():
  71. game = Snake(600, 600)
  72. p = PLE(game,reward_values={"positive": 100.0,
  73. "negative": -100.0,
  74. "tick": -0.5,
  75. "loss": -50.0,
  76. "win": 5.0}, display_screen=False, state_preprocessor=process_state)
  77. n_games = 10000
  78. print(sys.argv[1])
  79. agent = Agent(alpha=float(sys.argv[1]), gamma=float(sys.argv[2]), n_actions=4, epsilon=0.99, batch_size=64, input_shape=12, epsilon_dec=0.09,
  80. epsilon_end=0.01,
  81. memory_size=1000000,file_name=sys.argv[3],activations = [str(sys.argv[4]),str(sys.argv[5])])
  82. # agent.load_game()
  83. actions = [119,115,97,100]
  84. scores = []
  85. for _ in range(100000):
  86. if p.game_over():
  87. p.reset_game()
  88. score = 0
  89. # state = p.getGameState()
  90. while not p.game_over():
  91. old_state = np.array(vision(list(p.getGameState()[0])))
  92. # print(old_state)
  93.  
  94. action = agent.choose_action(old_state)
  95. reward = p.act(actions[action])
  96. new_state = np.array(vision(list(p.getGameState()[0])))
  97. agent.add_experience(old_state,action,reward,new_state)
  98. agent.learn()
  99. score = p.score()
  100. scores.append(score)
  101. print(f"Score for model iteration number _ {str(sys.argv[3])} with learning_rate {sys.argv[1]}, gama {sys.argv[2]}, activations: {sys.argv[4],sys.argv[5]} is score {score}")
  102. agent.save_game()
  103. with open('scoruri.txt',"a") as my_scores:
  104. my_scores.write(f'Scorurile pentru rularea cu activarile {sys.argv[4], sys.argv[5]}')
  105. my_scores.write(scores)
  106.  
  107.  
  108. if __name__ == '__main__':
  109. run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement