Emania

Untitled

Jan 13th, 2018
121
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.88 KB | None | 0 0
  1.  
  2.  
  3. import random
  4. import pickle
  5. import sys
  6. from pprint import pprint
  7. import time
  8. import numpy as np
  9. GAMMA = 0.9
  10. GRID = None
  11.  
  12. class State:
  13.     def __init__(self, evader, pursuer1, pursuer2, grid=None):
  14.         self.evader = evader
  15.         self.pursuer1 = pursuer1
  16.         self.pursuer2 = pursuer2
  17.         if grid is not None:
  18.             self.V = 0
  19.             # key = (evader, pursuer1, pursuer2)
  20.             # self.my_actions = actionsE(grid, key)
  21.             # self.op_actions = actionsP(grid, key)
  22.  
  23.     def get_my_actions(self):
  24.         key = (self.evader, self.pursuer1, self.pursuer2)
  25.         return actionsE(GRID, key)
  26.  
  27.     def get_op_actions(self):
  28.         key = (self.evader, self.pursuer1, self.pursuer2)
  29.         return actionsP(GRID, key)
  30.  
  31.     def getV0(self):
  32.         if self.isEnd():
  33.             return -1
  34.         else:
  35.             return 1
  36.  
  37.     def isEnd(self):
  38.         self.evader == self.pursuer1 or self.evader == self.pursuer2
  39.  
  40.     def getV(self):
  41.         return self.V
  42.  
  43.  
  44.     def present(self):
  45.         print("[" + str(self.evader) + "|" + str(self.pursuer1) + str(self.pursuer2) +
  46.               str(self.V)) #+ str(self.best_action_evader) + "|" + str(self.best_action_pursuer))
  47.  
  48.     def computeQ(self, states, myA, opA):
  49.         p1s, p2s = opA
  50.         es = myA
  51.         action = sort_key((es, p1s, p2s))
  52.         s = states[action]      # transferred state
  53.         return self.getV0() + GAMMA * s.getV()
  54.  
  55.  
  56.     def updateV(self, states):
  57.         my_actions = self.get_my_actions()
  58.         op_actions = self.get_op_actions()
  59.  
  60.         if len(my_actions) == 0:
  61.             self.best_action_evader = (-1, -1)
  62.             return
  63.         if self.isEnd():
  64.             self.V = self.getV0()
  65.             return
  66.  
  67.         my_max = -10000
  68.         for m in my_actions:
  69.             op_min = 10000
  70.             for o in op_actions:
  71.                 q = self.computeQ(states, m, o)
  72.                 if q < op_min:
  73.                     op_min = q
  74.             if op_min > my_max:
  75.                 my_max = op_min
  76.         self.V = my_max
  77.  
  78.  
  79.     def set_best_action_evader(self, states):
  80.         best_value = -10000
  81.         best_action = (-1, -1)
  82.         for a in self.get_my_actions():
  83.             e = a
  84.             p1, p2 = self.pursuer1, self.pursuer2
  85.             key = sort_key((e, p1, p2))
  86.             s = states[key]
  87.             if s.V > best_value:
  88.                 best_value = s.V
  89.                 best_action = a
  90.         self.best_action_evader = best_action
  91.  
  92.  
  93.     def set_best_action_pursuer(self, states):
  94.         best_value = 10000
  95.         best_action = ((-1, -1), (-1, -1))
  96.         for a in self.op_actions:
  97.             p1, p2 = a
  98.             e = self.evader
  99.             key = sort_key((e, p1, p2))
  100.             s = states[key]
  101.             if s.V < best_value:
  102.                 best_value = s.V
  103.                 best_action = a
  104.         self.best_action_pursuer = best_action
  105.  
  106.  
  107. class ValueIteration:
  108.     def __init__(self, grid, n):
  109.         global GRID
  110.         print("ValueIteration")
  111.         GRID = grid
  112.         PIK = "pacman.policy"
  113.  
  114.         try:
  115.             with open(PIK, "rb") as f:
  116.                 print("opening a pickle " + PIK + "...")
  117.                 arr = pickle.load(f)
  118.                 self.states = self.decompress(arr)
  119.                 print("pickle loaded")
  120.         except:
  121.             self.learn(grid, PIK)
  122.         print("DONE")
  123.  
  124.     def create_states(self, grid):
  125.         self.states = dict()
  126.         for e in grid.all_nodes():
  127.             sys.stdout.write("\r creating state " + str(e))
  128.             sys.stdout.flush()
  129.             for p1 in grid.all_nodes():
  130.                 for p2 in grid.all_nodes():
  131.                     p1s, p2s = sort(p1, p2)
  132.                     if (e, p1s, p2s) in self.states:
  133.                         continue
  134.                     e = (np.int8(e[0]), np.int8(e[1]))
  135.                     p1s = (np.int8(p1s[0]), np.int8(p1s[1]))
  136.                     p2s = (np.int8(p2s[0]), np.int8(p2s[1]))
  137.                     self.states[(e, p1s, p2s)] = State(e, p1s, p2s, grid)
  138.         print()
  139.  
  140.     def learn(self, grid, PIK):
  141.         print("computing a policy")
  142.         self.grid = grid
  143.         print("creating states...")
  144.         self.create_states(grid)
  145.         print("created %d states" % len(self.states))
  146.         start_time = time.time()
  147.         total_time = 60 * 10 * 10000
  148.         i = 0
  149.         while time.time() - start_time < total_time and i < 20:
  150.             sys.stdout.write(
  151.                 "\r updating value, run %4d,  %5ds / %5ds" % (i, int(time.time() - start_time), total_time))
  152.             sys.stdout.flush()
  153.             self.updateAllV()
  154.             i += 1
  155.         print()
  156.         for s in self.states.values():
  157.             s.set_best_action_evader(self.states)
  158.             s.set_best_action_pursuer(self.states)
  159.         print("compressing...")
  160.         data = self.compress()
  161.         with open(PIK, "wb") as f:
  162.             pickle.dump(data, f)
  163.         print("policy computed and saved")
  164.  
  165.     def updateAllV(self):
  166.         keys = self.states.keys()
  167.         random.shuffle(keys)
  168.         i = 0
  169.         for key in keys:
  170.             s = self.states[key]
  171.             s.updateV(self.states)
  172.             i += 1
  173.         del keys
  174.  
  175.     def get_best_action_evader(self, e, p1, p2):
  176.         key = sort_key((e, p1, p2))
  177.         return self.states[key].best_action_evader
  178.  
  179.     def get_best_action_pursuer(self, e, my_p, p2):
  180.         key = sort_key((e, my_p, p2))
  181.         a1, a2 = self.states[key].best_action_pursuer
  182.         if my_p == key[1]:
  183.             return a1
  184.         else:
  185.             return a2
  186.  
  187.     def compress(self):
  188.         new_arr = np.zeros((len(self.states), 12), dtype=np.int8)
  189.         i = 0
  190.  
  191.         for s in self.states.values():
  192.             del s.evader
  193.             del s.pursuer1
  194.             del s.pursuer2
  195.             del s.my_actions
  196.             del s.op_actions
  197.             del s.V
  198.  
  199.         for k in self.states.keys():
  200.             s = self.states[k]
  201.             a = [k[0][0], k[0][1],
  202.                  k[1][0], k[1][1],
  203.                  k[2][0], k[2][1],
  204.                  s.best_action_evader[0], s.best_action_evader[1],
  205.                  s.best_action_pursuer[0][0], s.best_action_pursuer[0][1],
  206.                  s.best_action_pursuer[1][0], s.best_action_pursuer[1][1]]
  207.             new_arr[i, :] = a
  208.             i += 1
  209.         return new_arr
  210.  
  211.     def decompress(self, arr):
  212.         print("decompressing...")
  213.         print(arr)
  214.         states = dict()
  215.         print(type(arr))
  216.         (n, _) = arr.shape
  217.         arr = arr.astype(np.int)
  218.         print("n = " + str(n))
  219.         for i in range(n):
  220.             s = arr[i, :]
  221.             ex, ey = s[0], s[1]
  222.             p1x, p1y, p2x, p2y = s[2], s[3], s[4], s[5]
  223.             bex, bey = s[6], s[7]
  224.             bp1x, bp1y, bp2x, bp2y = s[8], s[9], s[10], s[11]
  225.             state = State((ex, ey), (p1x, p1y), (p2x, p2y))
  226.             state.best_action_evader = (bex, bey)
  227.             state.best_action_pursuer = ((bp1x, bp1y), (bp2x, bp2y))
  228.             states[((ex, ey), (p1x, p1y), (p2x, p2y))] = state
  229.         return states
  230.         print("decompressed")
  231.  
  232.  
  233.  
  234. def sort(id1, id2):
  235.     x1, y1 = id1
  236.     x2, y2 = id2
  237.     if x1 < x2 or (x1 == x2 and y1 < y2):
  238.         return id1, id2
  239.     else:
  240.         return id2, id1
  241.  
  242.  
  243. def sort_key(key):
  244.     e, p1, p2 = key
  245.     p1s, p2s = sort(p1, p2)
  246.     return e, p1s, p2s
  247.  
  248.  
  249. def actions(grid, s0):
  250.     a = grid.neighbors4(s0)
  251.     return filter(lambda s: grid.passable(s), a)
  252.  
  253. def actionsE(grid, key):  # returns actions for eaver, that are passable and don't go to persuers position
  254.     e, p1, p2 = key
  255.     return filter(lambda s: s != p1 and s != p2, actions(grid, e))
  256.  
  257.  
  258. def actionsP(grid, key):  # returns all possible tuples of actions
  259.     _, p1, p2 = key
  260.     a = []
  261.     for a1 in actions(grid, p1):
  262.         for a2 in actions(grid, p2):
  263.             a.append((a1, a2))
  264.     return a
Advertisement
Add Comment
Please, Sign In to add comment