Advertisement
Guest User

Untitled

a guest
Jun 5th, 2013
61
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.43 KB | None | 0 0
  1. from collections import deque
  2. from heapq import heappush, heappop, heappushpop
  3. import time
  4. #SIZE = 218#12
  5. inf = float('inf')
  6.  
  7. class Map:
  8.     def __init__(self, size):
  9.         self.size = size
  10.  
  11.     def cost(self, pos):
  12.         x, y = pos
  13.         return x*x + y*y + abs(x)
  14.  
  15.     def valid(self, pos):
  16.         x, y = pos
  17.         return abs(x) + abs(y) < self.size
  18.  
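# Goal: best total over a fixed number of steps (originally phrased as "most treasure after 10 steps"); astar below searches for a path of themap.size cells with the lowest summed Map.cost.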
  20. class NormalAgent:
  21.     def __init__(self):
  22.         self.actions = [
  23.                 (1, 0),
  24.                 (0, 1),
  25.                 (-1, 0),
  26.                 (0, -1)]
  27.  
  28. class Option:
  29.     FIRST_HALF = 1
  30.  
  31. class OptionAgent(NormalAgent):
  32.     def __init__(self):
  33.         NormalAgent.__init__(self)
  34.         self.actions.append(Option.FIRST_HALF)
  35.  
  36. class Path:
  37.     def __init__(self, nodes, accumulated):
  38.         self.nodes = nodes
  39.         self.accumulated = accumulated
  40.     def __repr__(self):
  41.         return 'Path:{'+str(self.nodes)+"}"
  42.  
  43. #def search_alg_generator(select_next):
  44. #    def run(agent, themap, startloc):
  45. #        visited = set()
  46. #        best_solution, best_cost = None, inf
  47. #        q = deque()
  48. #        path = Path([startloc], themap.cost(startloc))
  49. #        q.append(path)
  50. #        i = 0
  51. #        thresh = 1
  52. #        thresh_num = 1
  53. #        while q:
  54. #            i += 1
  55. #            path = select_next(q)
  56. #            node = path.nodes[-1]
  57. #            visited.add(node)
  58. #            l = len(path.nodes)
  59. #            if l > thresh:
  60. #                t = i - thresh_num
  61. #                print 'new thresh', l, t
  62. #                thresh = l
  63. #                thresh_num = i
  64. #            if len(path.nodes) >= SIZE:
  65. #                if path.accumulated < best_cost:
  66. #                    best_solution, best_cost = path, path.accumulated
  67. #                if best_solution.nodes[-1] == (SIZE-1,0):
  68. #                    break # vs continue changes things here...
  69. #                continue # don't build paths longer than SIZE
  70. #            for action in agent.actions:
  71. #                if action == Option.FIRST_HALF:
  72. #                    if node != (0,0):
  73. #                        continue
  74. #                    new_path = Path(path.nodes + [(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)], 70)
  75. #                    nodes = [(x,0) for x in range(SIZE/2)]
  76. #                    tcost = sum(themap.cost(l) for l in nodes)
  77. #                    q.append(Path(nodes, tcost))
  78. #                else:
  79. #                    new_node = tuple(map(sum, zip(action, node)))
  80. #                    if themap.valid(new_node) and new_node not in visited:
  81. #                        new_path = Path(path.nodes + [new_node],
  82. #                                path.accumulated + themap.cost(new_node))
  83. #                        q.append(new_path)
  84. #        return best_solution, i
  85. #    return run
  86.  
  87. #dfs = search_alg_generator(lambda q: q.pop())
  88. #bfs = search_alg_generator(lambda q: q.popleft())
  89. def estimator(themap, nodes):
  90.     steps_left = themap.size - len(nodes)
  91.     return steps_left # relaxation of problem. all locs are positive, so we'll have at least this many..
  92.  
  93. def heappush_node(q, themap, node, prev=None):
  94.     """prev is a Path"""
  95.     if prev is None:
  96.         prev = Path([], 0)
  97.     gval = themap.cost(node)
  98.     nodes = prev.nodes + [node]
  99.     path = Path(nodes, prev.accumulated + gval)
  100.     hval = estimator(themap, nodes)
  101.     fval = hval + gval
  102.     heappush(q, (fval, Path(nodes, gval)))
  103.  
  104. def astar(agent, themap, startloc):
  105.     solution = None
  106.     q = [] #priority queue on f
  107.     visited = set()
  108.     heappush_node(q, themap, startloc)
  109.     i = 0
  110.     while q:
  111.         i += 1
  112.         heapval, path = heappop(q)
  113.         node = path.nodes[-1]
  114.         if len(path.nodes) >= themap.size:
  115.             solution = path
  116.             break
  117.         if node not in visited:
  118.             visited.add(node)
  119.         for action in agent.actions:
  120.             if action == Option.FIRST_HALF:
  121.                 if node != (0,0):
  122.                     continue
  123.                 nodes = [(0,0), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (1,-1), (2,-1)]
  124.                 gval = sum(themap.cost(z) for z in nodes)
  125.                 hval = estimator(themap, nodes)
  126.                 fval = gval + hval
  127.                 heappush(q, (fval, Path(nodes, gval)))
  128.                 # NOTE stops needless branching. kind of a hack, but works for our purposes.
  129.                 agent.actions.pop()
  130.             else:
  131.                 new_node = tuple(map(sum, zip(action, node)))
  132.                 if new_node in path.nodes: # can't revisit the same spot
  133.                     continue
  134.                 heappush_node(q, themap, new_node, path)
  135.     return solution, i
  136.  
  137. if __name__ == '__main__':
  138. #    print '\tnoopt\topt'
  139.     for SIZE in [5, 10,15, 20,30,40,50,100,150,200,250,300,330]:
  140.         #print '%i\t' %SIZE,
  141.         themap =  Map(SIZE)
  142.         for cls, tag in [(NormalAgent, 'nooptions'),
  143.                 (OptionAgent, 'options')]:
  144.             agent = cls()
  145.             t = time.time()
  146.             path, i = astar(agent, themap, (0,0))
  147.             #path, i  = bfs(agent, themap, (0,0))
  148.             dt = time.time() - t
  149.             #print time.strftime('%a %b  %d %H:%M:%S %Z %Y')
  150.             #print '%i\t' %i,
  151.             print SIZE,tag, 'cost:', path.accumulated,'loops:', i
  152.             #print path.nodes
  153.             #print path.accumulated
  154.             #print 'runtime %.3fs' % dt
  155.             #print '%d loop iterations' % i
  156.             #print
  157.         #print
  158.     #print
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement