neurodrone
GroX24, Apr 26th, 2024 (edited)
Python, 16.84 KB
import pygame as pg
from math import sin, cos, pi, ceil, floor

import torch as T
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from numpy.random import random as nprand
import matplotlib.pyplot as plt
import time, os
import csv

# device = T.device("cuda" if T.cuda.is_available() else "cpu")

WIDTH, HEIGHT = 800, 600
m = 1  # drone mass
g = 4  # grav. acceleration
dt = 2 / 60  # simulation time step, s
l = 1  # length of the base
eng_l = 0.25  # length of each engine (there are two: one on the left, one on the right)
d = 0.25  # height of both the base and the engines
drag = 0.1  # drag coefficient
maxthr = 4  # max engine thrust
thr_incr = maxthr * dt / 0.5  # thrust increment per step: going 0 -> maxthr takes 0.5 s of key presses
I = m * (l + 2 * eng_l) ** 2 / 12  # moment of inertia of a thin rod of length l + 2 * eng_l
fontsize = 18

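# A quick sanity check on the derived constants (added note, not original code):
# dt = 2/60 ≈ 0.0333 s, so thr_incr = 4 * (2/60) / 0.5 ≈ 0.267 thrust units per step,
# i.e. an engine ramps from 0 to maxthr = 4 in 15 steps (0.5 simulated seconds),
# and I = 1 * (1 + 2 * 0.25)**2 / 12 = 0.1875.
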
pg.init()
font = pg.font.SysFont("arial", fontsize)


# image = pg.image.load("undrtale.png")


class QNet(nn.Module):
    """A plain fully connected Q-network: n_state inputs -> n_actions Q-values."""

    def __init__(self, n_state, n_actions, n_layers, n_neurons, lr=0.001):
        super().__init__()
        self.layers = nn.ModuleList()
        self.len = n_layers
        self.n_state = n_state
        self.n_actions = n_actions
        if n_layers == 1:
            self.layers.append(nn.Linear(n_state, n_actions))
        else:
            self.layers.append(nn.Linear(n_state, n_neurons))
            for i in range(n_layers - 2):
                self.layers.append(nn.Linear(n_neurons, n_neurons))
            self.layers.append(nn.Linear(n_neurons, n_actions))
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)
        print(f"using {self.device}")

    def forward(self, x):
        # start = time.time_ns() / 1e6
        # ReLU on every layer except the last one, which stays linear (raw Q-values)
        for i in range(self.len - 1):
            x = F.relu(self.layers[i](x))
        # end = time.time_ns() / 1e6
        # print(f"QNet forward time: {end - start} ms")
        return self.layers[-1](x)

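# A minimal usage sketch for QNet (added for illustration; the batch is made up):
#
#   net = QNet(n_state=10, n_actions=4, n_layers=2, n_neurons=64)
#   batch = T.rand(32, 10).to(net.device)  # 32 fake observations
#   q = net.forward(batch)                 # shape (32, 4): one Q-value per action
#   best = T.argmax(q, dim=1)              # greedy action for each observation
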

class Agent():
    def __init__(self, gamma, eps, lr, n_state, n_actions, batch_size,
                 max_mem=100000, eps_end=0.01, eps_dec=5e-4, n_layers=3, n_neurons=128):
        self.gamma = gamma
        self.eps = eps
        self.eps_min = eps_end
        self.eps_dec = eps_dec
        self.action_space = [i for i in range(n_actions)]
        self.lr = lr
        self.batch_size = batch_size
        self.mem_size = max_mem
        self.mem_countr = 0

        self.eval = QNet(n_state, n_actions, n_layers, n_neurons, lr)

        # circular replay buffer: (state, action, reward, next state, done)
        self.smemory = np.zeros((self.mem_size, n_state), dtype=np.float32)
        self.nsmemory = np.zeros((self.mem_size, n_state), dtype=np.float32)
        self.amemory = np.zeros(self.mem_size, dtype=np.int32)
        self.rmemory = np.zeros(self.mem_size, dtype=np.float32)
        self.terminalmemory = np.zeros(self.mem_size, dtype=np.bool_)

    def store_transition(self, state, action, reward, newstate, done):
        i = self.mem_countr % self.mem_size  # wrap around, overwriting the oldest entries
        self.smemory[i] = state
        self.amemory[i] = action
        self.rmemory[i] = reward
        self.nsmemory[i] = newstate
        self.terminalmemory[i] = done
        self.mem_countr += 1

    def policy(self, state):
        # epsilon-greedy: explore with probability eps, otherwise act greedily
        if np.random.random() < self.eps:
            action = np.random.choice(self.action_space)
        else:
            state = T.tensor([state], dtype=T.float32).to(self.eval.device)
            with T.no_grad():  # inference only, no need to build a graph
                actions = self.eval.forward(state)
            action = T.argmax(actions).item()
        return action

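    # Note on exploration (added): eps shrinks by eps_dec after every learn() call,
    # so with the values used in main() (eps=1, eps_dec=1e-5, eps_end=0.01) it takes
    # (1 - 0.01) / 1e-5 = 99,000 learning steps to reach the minimum exploration rate.
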
    def learn(self):
        # start = time.time_ns() / 1e6
        if self.mem_countr < self.batch_size:
            return
        self.eval.optimizer.zero_grad()
        mem = min(self.mem_size, self.mem_countr)
        batch = np.random.choice(mem, self.batch_size, replace=False)
        batch_i = np.arange(self.batch_size, dtype=np.int32)

        state_batch = T.tensor(self.smemory[batch]).to(self.eval.device)
        new_state_batch = T.tensor(self.nsmemory[batch]).to(self.eval.device)
        reward_batch = T.tensor(self.rmemory[batch]).to(self.eval.device)
        terminal_batch = T.tensor(self.terminalmemory[batch]).to(self.eval.device)
        action_batch = self.amemory[batch]  # plain numpy is fine for indexing

        # Q(s, a) for the actions actually taken
        q_eval = self.eval.forward(state_batch)[batch_i, action_batch]
        # Q(s', .), detached so gradients only flow through q_eval,
        # and zeroed for terminal states (no future reward after "done")
        nq_eval = self.eval.forward(new_state_batch).detach()
        nq_eval[terminal_batch] = 0.0

        # standard Q-learning target: r + gamma * max_a' Q(s', a')
        q_target = reward_batch + self.gamma * T.max(nq_eval, dim=1)[0]
        loss = self.eval.loss(q_target, q_eval)
        loss.backward()
        self.eval.optimizer.step()

        self.eps = max(self.eps_min, self.eps - self.eps_dec)
        # end = time.time_ns() / 1e6
        # print(f"Agent learn time: {end - start} ms")

    def save(self, file):
        T.save(self.eval.state_dict(), file)

    def load(self, file):
        self.eval.load_state_dict(T.load(file))

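# The intended training loop, in miniature (an added sketch; env_reset/env_step are
# hypothetical stand-ins for the simstep/get_observation2 plumbing in main()):
#
#   agent = Agent(gamma=0.99, eps=1.0, lr=1e-3, n_state=10, n_actions=4, batch_size=64)
#   obs = env_reset()
#   for step in range(100000):
#       a = agent.policy(obs)
#       next_obs, r, done = env_step(a)
#       agent.store_transition(obs, a, r, next_obs, done)
#       agent.learn()  # one gradient step per environment step
#       obs = env_reset() if done else next_obs
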

def reward(x, y, h):
    # x, y: cursor position relative to the drone; h: drone altitude
    collision_punish = 100
    R = 6
    r = (x ** 2 + y ** 2) ** 0.5  # distance to the cursor, capped at R
    if r > R:
        r = R
    # the episode ends on ground contact or on leaving the +-20 play area
    done = h < d + l / 2 + eng_l or abs(x) > 20 or abs(y) > 20
    return ((1 - r / R) * 10 + 1) * 0.01 - collision_punish * int(done), done

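# Rough scale of the shaping above (added note): with the cursor right on the drone
# (r = 0) a step pays (10 + 1) * 0.01 = 0.11; at or beyond r = R = 6 it pays only
# (0 + 1) * 0.01 = 0.01; crashing or leaving the play area costs a one-off -100.
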

def simstep(state, playable=True, action=None):
    # start = time.time_ns()
    (x, y, xc, yc, angle,
     vx, vy, vxc, vyc, vangle,
     left_thrust, right_thrust, done) = state

    # cursor (the target the drone chases); it is static for now,
    # so its velocity estimate below is always zero
    prevx = xc
    prevy = yc
    # some code for moving

    vxc = (xc - prevx) / dt
    vyc = (yc - prevy) / dt

    # forces: thrust along the drone's "up" direction, linear drag, gravity
    fx = -drag * vx - (left_thrust + right_thrust) * sin(angle)
    fy = -m * g - drag * vy + (left_thrust + right_thrust) * cos(angle)
    torque = (right_thrust - left_thrust) * (l + eng_l) / 2 - drag * vangle * 4

    # velocities first, then positions (semi-implicit Euler integration)
    vx += (fx / m) * dt
    vy += (fy / m) * dt
    vangle += (torque / I) * dt

    # position and angle
    x += vx * dt
    y += vy * dt
    angle += vangle * dt
    # keep the angle within (-pi, pi]
    if angle < -pi:
        angle += 2 * pi
    elif angle > pi:
        angle -= 2 * pi

    # Engine control
    if playable:
        # Adjust engine thrusts based on key presses;
        # thrust bleeds off twice as fast as it builds up
        if pg.key.get_pressed()[pg.K_LEFT]:
            left_thrust += thr_incr
        else:
            left_thrust -= 2 * thr_incr
        if pg.key.get_pressed()[pg.K_RIGHT]:
            right_thrust += thr_incr
        else:
            right_thrust -= 2 * thr_incr
    else:
        # old 7-action scheme, kept for reference:
        # if action in (1, 5):
        #     left_thrust -= thr_incr
        # if action in (2, 5):
        #     right_thrust -= thr_incr
        # if action in (3, 6):
        #     left_thrust += thr_incr
        # if action in (4, 6):
        #     right_thrust += thr_incr
        if action == 0:  # left roll
            left_thrust -= thr_incr
            right_thrust += thr_incr
        elif action == 1:  # right roll
            left_thrust += thr_incr
            right_thrust -= thr_incr
        elif action == 2:  # both engines up
            left_thrust += thr_incr
            right_thrust += thr_incr
        elif action == 3:  # both engines down
            left_thrust -= thr_incr
            right_thrust -= thr_incr
    # clamp thrust to [0, maxthr]
    left_thrust = max(0, min(left_thrust, maxthr))
    right_thrust = max(0, min(right_thrust, maxthr))

    rew, done = reward(xc - x, yc - y, y)
    # end = time.time_ns()
    # print(f"sim time: {end - start} ns")
    return (rew,
            [x, y, xc, yc, angle,
             vx, vy, vxc, vyc, vangle,
             left_thrust, right_thrust, done])

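# Standalone use of the simulator (an added sketch; the initial state is made up and
# follows the layout documented in main()):
#
#   state = [0, 5, 2, 5, 0,  0, 0, 0, 0, 0,  0, 0, False]  # drone at (0, 5), cursor at (2, 5)
#   for _ in range(600):  # 600 steps * dt = 20 simulated seconds
#       rew, state = simstep(state, playable=False, action=2)  # 2 = both engines up
#       if state[-1]:  # done
#           break
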

def get_observation(state):
    # original 11-component observation, kept for reference (raw angle included)
    (x, y, xc, yc, angle,
     vx, vy, vxc, vyc, vangle,
     left_thr, right_thr, done) = state
    return (xc - x, yc - y, y, angle,
            vx, vy, vxc - vx, vyc - vy, vangle,
            left_thr, right_thr)


def get_observation2(state):
    # 10-component observation used for training: cursor offset, altitude,
    # angle encoded as (sin, cos), velocities, and current thrusts
    (x, y, xc, yc, angle,
     vx, vy, vxc, vyc, vangle,
     left_thr, right_thr, done) = state
    return (xc - x, yc - y, y, sin(angle), cos(angle), vx, vy, vangle, left_thr, right_thr)

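# Why (sin, cos) instead of the raw angle (added note): the raw angle jumps from +pi
# to -pi at the wrap-around, so two nearly identical attitudes can look maximally
# different to the network. The (sin, cos) pair is continuous there: angle = 3.14
# gives (0.002, -1.0) and angle = -3.14 gives (-0.002, -1.0).
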

def render_multi_line(screen, font, text, x, y, color, fsize):
    lines = text.splitlines()
    for i, line in enumerate(lines):  # renamed from `l` to avoid shadowing the global
        screen.blit(font.render(line, 1, color), (x, y + fsize * i))


def drawgrid(cam, step, substeps, wl=1, dark=100, thin=0):
    # Draw gridlines every `step` world units, with `substeps` dimmer,
    # thinner lines in between each pair of major lines.
    w, h, scale, x, y = cam
    surf = pg.Surface((w, h), pg.SRCALPHA, 32)
    x -= w / scale / 2
    y -= h / scale / 2
    xstart = floor(x / step) * step - x
    ystart = y - ceil(y / step) * step
    for i in range(ceil(h / step) * (substeps + 1)):
        if (ystart + i * step / (substeps + 1)) * scale > h:  # compare in pixels, past the bottom edge
            break
        weaken = bool(i % (substeps + 1))  # 0 on major lines, 1 on minor ones
        pg.draw.line(surf, (255 - weaken * dark, 255 - weaken * dark, 255 - weaken * dark),
                     (0, (ystart + i * step / (substeps + 1)) * scale),
                     (w, (ystart + i * step / (substeps + 1)) * scale), wl - weaken * thin)
    for j in range(ceil(w / step) * (substeps + 1)):
        if (xstart + j * step / (substeps + 1)) * scale > w:  # past the right edge
            break
        weaken = bool(j % (substeps + 1))
        pg.draw.line(surf, (255 - weaken * dark, 255 - weaken * dark, 255 - weaken * dark),
                     ((xstart + j * step / (substeps + 1)) * scale, 0),
                     ((xstart + j * step / (substeps + 1)) * scale, h), wl - weaken * thin)
    return surf

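# For the call used in render() below, drawgrid(cam, 4, 3, 2, thin=1): a bright 2 px
# line every 4 world units (400 px at scale = 100), with 3 dimmer 1 px lines in
# between, i.e. one gridline per world unit.
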

def cam_coords(cam, x, y):
    # world -> screen: the camera is centered at (x0, y0); screen y grows downwards
    w, h, scale, x0, y0 = cam
    x = (x - x0) * scale + w / 2
    y = (y0 - y) * scale + h / 2
    return x, y

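# Worked example (added): with cam = (800, 600, 100, 0, 2), the world origin (0, 0)
# maps to ((0 - 0) * 100 + 400, (2 - 0) * 100 + 300) = (400, 500): the ground line
# sits 200 px below the screen centre.
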

def render(state, score, screen, cam, scale, w, h):
    '''Render the drone, its engines, and the ground.
    The camera follows the drone (see main()); 1 world unit corresponds to `scale` px.
    The background is black, the drone is also black with a thin white outline;
    the engines are outlined too. When they are turned on, little triangles appear,
    which represent air/propellant/whatever. The ground is grey.
    '''

    # Clear the screen
    screen.fill((0, 0, 0))

    # Unpack the state
    x, y, xc, yc, angle, vx, vy, vxc, vyc, vangle, left_thrust, right_thrust, done = state

    # Draw the ground
    pg.draw.rect(screen, (100, 100, 100), (0, cam_coords(cam, 0, 0)[1], w, h + 1))

    # Draw the grid
    grid = drawgrid(cam, 4, 3, 2, thin=1)
    screen.blit(grid, (0, 0))

    # Calculate the coordinates relative to the camera
    xc, yc = cam_coords(cam, xc, yc)

    # Draw the cursor (clamped to the screen edges so it is always visible)
    pg.draw.circle(screen, (150, 255, 150), (max(min(xc, w), 0), max(min(yc, h), 0)), 0.25 * scale)

    # Draw the drone
    thr_scale = 0.5 * scale
    l_ = l * scale
    eng_l_ = eng_l * scale
    d_ = d * scale
    drone_surf = pg.Surface((l_ + 2 * eng_l_, d_ + 2 * thr_scale), pg.SRCALPHA, 32)

    pg.draw.rect(drone_surf, (255, 255, 255), (0, thr_scale, eng_l_, d_), 2)  # left engine
    pg.draw.rect(drone_surf, (255, 255, 255), (eng_l_, thr_scale, l_, d_), 2)  # base
    pg.draw.rect(drone_surf, (255, 255, 255), (l_ + eng_l_, thr_scale, eng_l_, d_), 2)  # right engine
    pg.draw.polygon(drone_surf, (255, 255, 200),
                    [(0, d_ + thr_scale),
                     (eng_l_ // 2, d_ + (1 + left_thrust / maxthr) * thr_scale),
                     (eng_l_, d_ + thr_scale)])  # left flame, length scales with thrust
    pg.draw.polygon(drone_surf, (255, 255, 200),
                    [(l_ + eng_l_, d_ + thr_scale),
                     (l_ + eng_l_ + eng_l_ // 2, d_ + (1 + right_thrust / maxthr) * thr_scale),
                     (l_ + 2 * eng_l_, d_ + thr_scale)])  # right flame

    drone_surf = pg.transform.rotate(drone_surf, angle / pi * 180)  # pygame rotates in degrees
    drone_rect = drone_surf.get_rect()
    drone_rect.center = cam_coords(cam, x, y)
    screen.blit(drone_surf, drone_rect)

    # Print information & "HUD"
    # global image
    # screen.blit(image, (0, 500))

    winfo = 3
    trnsprt = 180
    hud = pg.Surface((fontsize * 18 + 2 * winfo, fontsize * 8 + 2 * winfo), pg.SRCALPHA, 32)
    pg.draw.rect(hud, (180, 180, 180, trnsprt), (0, 0, fontsize * 18 + 2 * winfo, fontsize * 8 + 2 * winfo))
    pg.draw.rect(hud, (0, 0, 0, trnsprt), (winfo, winfo, fontsize * 18, fontsize * 8))
    render_multi_line(hud, font,
                      f'Coords: ({x:.2f}, {y:.2f}); angle: {angle:.2f}\n'
                      f'Velocity: ({vx:.2f}, {vy:.2f}); angular: {vangle:.2f}\n'
                      f'Thrusters: left: {left_thrust:.2f}; right: {right_thrust:.2f}\n'
                      f'Score: {score:.2f}',
                      20, 20, (255, 255, 255), fontsize * 2)
    screen.blit(hud, (0, 0))
    return screen


def plot_progress(x, scores, file):
    plt.scatter(x, scores, s=1 / 4, c=((0.3, 0.6, 0.8),), linewidth=0)
    plt.savefig(file, dpi=300)
    plt.clf()  # clear the figure so successive plots don't pile up


def writedata(file, row):
    # append one CSV row; newline="" prevents blank lines on Windows
    # (the original took *args and wrote the whole list into a single cell)
    with open(file, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(row)

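# Usage, matching the calls in main() below: a header row once at startup, then one
# row per episode (the numbers here are made up):
#
#   writedata("smol/data.csv", ["i", "score", "mean_score"])
#   writedata("smol/data.csv", [0, -95.4, -95.4])
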
def main():
    print(T.cuda.is_available())
    scale = 100
    dronename = "smol"
    if not os.path.exists(dronename):
        os.mkdir(dronename)
    writedata(f"{dronename}/data.csv", ["i", "score", "mean_score"])
    screen = pg.display.set_mode((WIDTH, HEIGHT))
    pg.display.set_caption('Drone thingy')
    clock = pg.time.Clock()
    do_render = False
    playable = False
    n_games = 500000
    ## observation (not-exactly-state): [xc', yc', h, angle, vx, vy, vxc', vyc', vangle, left_thr, right_thr] - 11
    # observation2: [xc', yc', h, sin, cos, vx, vy, vangle, left_thr, right_thr] - 10
    ## actions = (0:nothing, 1:left-, 2:right-, 3:left+, 4:right+, 5:both-, 6:both+) - 7
    # actions2 = (0:left_roll, 1:right_roll, 2:both+, 3:both-) - 4
    drone = Agent(0.995, 1, 0.001, n_state=10, n_actions=4, batch_size=64, n_layers=2, n_neurons=64, eps_dec=1e-5)
    if os.path.exists(f"{dronename}/drone.pt"):  # resume from a checkpoint if there is one
        drone.load(f"{dronename}/drone.pt")
    scores, epss = np.array([], dtype=np.float32), np.array([], dtype=np.float32)
    maxcount = 1000  # episode step limit; grows slowly every time it is hit
    for i in range(n_games):
        # [x, y, xc, yc, angle, vx, vy, vxc, vyc, vangle, left_thrust, right_thrust, done]
        state = [(2 * nprand() - 1) * 10, 2 + nprand() * 8,  # x y
                 (2 * nprand() - 1) * 10, -2 + nprand() * 12,  # xc yc
                 pi * (2 * nprand() - 1) * 0.1,  # angle
                 (2 * nprand() - 1) * 1, (1.5 * nprand() - 0.5) * 1,  # vx, vy
                 0, 0,  # vxc, vyc (have to be initialised even with no actual info)
                 pi * (2 * nprand() - 1) * 1,  # vangle
                 maxthr * nprand() * 0, maxthr * nprand() * 0, False]  # thrust (zeroed), done
        # cam = (WIDTH, HEIGHT, scale, state[0], state[1])
        score = 0
        counter = 0
        while not state[-1]:
            # start = time.time_ns() / 1e6
            for event in pg.event.get():
                if event.type == pg.QUIT:
                    T.save(drone.eval.state_dict(), f"{dronename}/drone.pt")
                    return
                if event.type == pg.KEYDOWN:
                    if event.key == pg.K_r:  # R turns rendering on, Space turns it off
                        do_render = True
                    elif event.key == pg.K_SPACE:
                        do_render = False

            observation = get_observation2(state)
            action = drone.policy(observation)
            rew, state = simstep(state, playable, action)  # `rew`, so reward() isn't shadowed
            score += rew
            next_observation = get_observation2(state)

            drone.store_transition(observation, action, rew, next_observation, state[-1])
            drone.learn()

            if do_render:
                cam = (WIDTH, HEIGHT, scale, state[0], max(2, state[1]))
                screen = render(state, score, screen, cam, scale, WIDTH, HEIGHT)
                pg.display.flip()
                clock.tick(60)

            if counter > maxcount:
                state[-1] = True
                print("EXCEEDED")
                maxcount += 1  # let later episodes run a little longer
            counter += 1
            # end = time.time_ns() / 1e6
            # print(f"total time: {end - start} ms")
        scores = np.append(scores, score)
        epss = np.append(epss, drone.eps)
        avg_score = np.mean(scores[max(0, i - 500):i + 1])
        writedata(f"{dronename}/data.csv", [i, score, avg_score])
        if not i % 50:
            print(f'episode {i}:\nscore: {score}\naverage score: {avg_score}\neps: {drone.eps}\n')
            if not i % 1000:
                x = np.arange(i + 1)
                plot_progress(x, scores, f"{dronename}/plot_{i // 1000}k.png")
                if not i % 5000:
                    T.save(drone.eval.state_dict(), f"{dronename}/drone_{i // 1000}k.pt")

    drone.save(f"drone_{dronename}.pt")
    return


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    main()