Advertisement
Guest User

Untitled

a guest
Jun 1st, 2025
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.58 KB | None | 0 0
  1. import pygame
  2. import sys
  3. import random
  4. import numpy as np
  5. import math
  6.  
# Initialise pygame's modules at import time, before SnakeEnv opens its
# window (pygame.display.set_mode) and loads its font (pygame.font.SysFont).
pygame.init()
  8.  
  9. # Game Environment Interface for RL
  10. class SnakeEnv:
  11.     def __init__(self, screen_width=400, screen_height=400, block_size=20, starting_size=3):
  12.         self.SCREEN_WIDTH = screen_width
  13.         self.SCREEN_HEIGHT = screen_height
  14.         self.BLOCK_SIZE = block_size
  15.         self.STARTING_SIZE = starting_size
  16.  
  17.         self.grid_width = self.SCREEN_WIDTH // self.BLOCK_SIZE
  18.         self.grid_height = self.SCREEN_HEIGHT // self.BLOCK_SIZE
  19.  
  20.         self.window = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
  21.         self.font = pygame.font.SysFont('Arial', 15, bold=False)
  22.         self.clock = pygame.time.Clock()
  23.         self.reset()
  24.  
  25.     def reset(self):
  26.         self.snake_alive = True
  27.         self.snake_size = self.STARTING_SIZE
  28.         self.initial_x = 100
  29.         self.initial_y = 100
  30.         self.snake_body = [(self.initial_x, self.initial_y), (self.initial_x-self.BLOCK_SIZE, self.initial_y), (self.initial_x-(self.BLOCK_SIZE * 2), self.initial_y)]
  31.         self.snake_dir = (1, 0)
  32.         self.next_dir = self.snake_dir
  33.         self.randomize_food()
  34.        
  35.         return self.get_observation()
  36.  
  37.     # height * width of uint8 where 0 = empty, 1 = snake_body, 2 = snake_head, 3 = food, then cast to float32 and normalize to 0-1 range by dividing by 2
  38.     def get_observation(self):
  39.         grid_h, grid_w = self.grid_height, self.grid_width
  40.  
  41.         # Create 4 channels: body, head, food, direction
  42.         body_channel = np.zeros((grid_h, grid_w), dtype=np.float32)
  43.         head_channel = np.zeros((grid_h, grid_w), dtype=np.float32)
  44.         food_channel = np.zeros((grid_h, grid_w), dtype=np.float32)
  45.         direction_channel = np.zeros((grid_h, grid_w), dtype=np.float32)
  46.  
  47.         # Body (excluding head)
  48.         for (x, y) in self.snake_body[1:]:
  49.             gx = x // self.BLOCK_SIZE
  50.             gy = y // self.BLOCK_SIZE
  51.             body_channel[gy, gx] = 1.0
  52.  
  53.         # Head
  54.         head_x, head_y = self.snake_body[0]
  55.         gx = head_x // self.BLOCK_SIZE
  56.         gy = head_y // self.BLOCK_SIZE
  57.         head_channel[gy, gx] = 1.0
  58.  
  59.         # Food
  60.         fx, fy = self.food
  61.         fx = fx // self.BLOCK_SIZE
  62.         fy = fy // self.BLOCK_SIZE
  63.         food_channel[fy, fx] = 1.0
  64.  
  65.         # Direction (encoded only at head position)
  66.         if self.snake_dir == "UP":
  67.             direction_channel[gy, gx] = 0.25
  68.         elif self.snake_dir == "DOWN":
  69.             direction_channel[gy, gx] = 0.5
  70.         elif self.snake_dir == "LEFT":
  71.             direction_channel[gy, gx] = 0.75
  72.         elif self.snake_dir == "RIGHT":
  73.             direction_channel[gy, gx] = 1.0
  74.  
  75.         # Stack channels into a single observation tensor
  76.         obs = np.stack([body_channel, head_channel, food_channel, direction_channel], axis=0)  # shape: (4, H, W)
  77.         return obs
  78.  
  79.  
  80.     def step(self, action):
  81.  
  82.         reward = -0.01 # small penalty per step to incentivize faster food seeking
  83.  
  84.         directions = [(1, 0), (-1, 0), (0, -1), (0, 1)]
  85.         proposed_dir = directions[action]
  86.  
  87.         # prevent going opposite direction in one move
  88.         if (proposed_dir[0] * -1, proposed_dir[1] * -1) != self.snake_dir:
  89.             self.snake_dir = proposed_dir
  90.  
  91.         (dx, dy) = self.snake_dir
  92.         old_x, old_y = self.snake_body[0]
  93.         new_head = (old_x + dx * self.BLOCK_SIZE, old_y + dy * self.BLOCK_SIZE)
  94.  
  95.         done = False
  96.         x, y = new_head
  97.  
  98.         # incentivize moving snake closer to food
  99.         # food_x, food_y = self.food
  100.         # old_distance_to_food = abs(old_x - food_x) + abs(old_y - food_y)
  101.         # new_distance_to_food = abs(x - food_x) + abs(y - food_y)
  102.         # if old_distance_to_food > new_distance_to_food:
  103.         #     reward += 0.1
  104.         # else:
  105.         #     reward -= 0.1
  106.  
  107.         x, y = new_head
  108.         if x < 0 or x >= self.SCREEN_WIDTH or y < 0 or y >= self.SCREEN_HEIGHT:
  109.             self.snake_alive = False
  110.             reward -= 1
  111.             done = True
  112.  
  113.         # check if new head collisions with body before insertion
  114.         elif new_head in self.snake_body:
  115.             self.snake_alive = False
  116.             reward -= 1
  117.             done = True
  118.         else:
  119.             self.snake_body.insert(0, new_head)
  120.  
  121.             if new_head == self.food:
  122.                 self.snake_size += 1
  123.                 reward += 1
  124.                 self.randomize_food()
  125.             else:
  126.                 if len(self.snake_body) > self.snake_size:
  127.                     self.snake_body.pop()
  128.  
  129.         obs = self.get_observation()
  130.  
  131.         return obs, reward, done
  132.    
  133.     def randomize_food(self):
  134.         self.food = (random.randrange(0, self.SCREEN_WIDTH, self.BLOCK_SIZE), random.randrange(0, self.SCREEN_HEIGHT, self.BLOCK_SIZE))
  135.  
  136.         while self.food in self.snake_body:
  137.             self.food = (random.randrange(0, self.SCREEN_WIDTH, self.BLOCK_SIZE), random.randrange(0, self.SCREEN_HEIGHT, self.BLOCK_SIZE))
  138.  
  139.     def draw_snake(self):
  140.         for i, segment in enumerate(self.snake_body):
  141.  
  142.             max_green = 255
  143.             min_green = 90 # not fully dark
  144.             green_fade = int(max_green - (i / len(self.snake_body)) * (max_green - min_green))
  145.             color = (0, green_fade, 0)
  146.  
  147.             (x, y) = segment
  148.             rect = pygame.Rect(x, y, self.BLOCK_SIZE, self.BLOCK_SIZE)
  149.  
  150.             pygame.draw.rect(self.window, color, rect)
  151.  
  152.     def draw_food(self):    
  153.         x, y = self.food
  154.         rect = pygame.Rect(x, y, self.BLOCK_SIZE, self.BLOCK_SIZE)
  155.  
  156.         pygame.draw.rect(self.window, (255, 0, 0), rect)
  157.  
  158.     def move_snake(self):
  159.  
  160.         (dx, dy) = self.snake_dir
  161.         (x, y) = self.snake_body[0] # head coords
  162.         new_head = (x + (dx * self.BLOCK_SIZE), y + (dy * self.BLOCK_SIZE))
  163.  
  164.         # wall collision detection
  165.         x, y = new_head
  166.         if x < 0 or x >= self.SCREEN_WIDTH or y < 0 or y >= self.SCREEN_HEIGHT:
  167.             self.snake_alive = False
  168.  
  169.         # check if new head collisions with body before insertion
  170.         if new_head in self.snake_body:
  171.             self.snake_alive = False
  172.  
  173.         self.snake_body.insert(0, new_head)
  174.  
  175.         # eats fruit
  176.         if new_head == self.food:
  177.             self.snake_size += 1
  178.             self.randomize_food()
  179.  
  180.         if len(self.snake_body) > self.snake_size:
  181.             self.snake_body.pop()
  182.  
  183.     def display_score(self):
  184.         self.window.blit(self.font.render("Score: " + str(self.snake_size - self.STARTING_SIZE), True, 'white'), (0, 0))
  185.  
  186.  
  187.  
  188.  
  189. # ------- MANUAL Snake -------
  190.  
# Run the manual game only when this file is executed directly, not when it is imported.
  192. if __name__ == "__main__":
  193.        
  194.     env = SnakeEnv()
  195.  
  196.     # allows fps to be high while limiting game tickspeed
  197.     GAME_TICK = pygame.USEREVENT
  198.     pygame.time.set_timer(GAME_TICK, 150) # 150ms default
  199.  
  200.     # Main Game Loop
  201.     while True:
  202.         for event in pygame.event.get():
  203.             if event.type == pygame.QUIT:
  204.                 pygame.quit()
  205.                 sys.exit()
  206.  
  207.             if env.snake_alive:
  208.  
  209.                 # update dir not move snake on keypress since movement should be gated by game tickspeed while keypress handle should be handled constantly (at a 60fps rate)
  210.                 # updates next_dir and not snake_dir since user can queue multiple movement commands during 1 game tick which shouldn't be allowed
  211.                 if event.type == pygame.KEYDOWN:
  212.                     if event.key == pygame.K_RIGHT and env.snake_dir != (-1, 0):
  213.                         env.next_dir = (1, 0)
  214.                     elif event.key == pygame.K_LEFT and env.snake_dir != (1, 0):
  215.                         env.next_dir = (-1, 0)
  216.                     elif event.key == pygame.K_UP and env.snake_dir != (0, 1):
  217.                         env.next_dir = (0, -1)
  218.                     elif event.key == pygame.K_DOWN and env.snake_dir != (0, -1):
  219.                         env.next_dir = (0, 1)
  220.  
  221.                 if event.type == GAME_TICK:
  222.                     env.snake_dir = env.next_dir
  223.                     env.window.fill('black')
  224.                     env.draw_snake()
  225.                     env.draw_food()
  226.                     env.move_snake()
  227.                     env.display_score()
  228.             else:
  229.                 env.window.blit(env.font.render("Dead", True, 'red'), (env.SCREEN_WIDTH // 2 - 30, env.SCREEN_HEIGHT // 2))
  230.  
  231.                 if event.type == pygame.KEYDOWN:
  232.                     if event.key == pygame.K_r:
  233.                         env.reset()
  234.            
  235.         pygame.display.flip()
  236.         env.clock.tick(60)
  237.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement