Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import time
from collections import deque
from heapq import heappop, heappush, heappushpop
- #SIZE = 218#12
# Sentinel "worse than any real cost" value for best-cost comparisons.
inf = float('inf')
class Map:
    """Diamond-shaped grid centred on the origin.

    A position is an ``(x, y)`` pair of integers.  ``size`` bounds the
    Manhattan distance from the origin that is still considered on the map.
    """

    def __init__(self, size):
        self.size = size

    def cost(self, pos):
        """Return the cost of standing on ``pos``: ``x*x + y*y + abs(x)``."""
        x, y = pos
        return x * x + y * y + abs(x)

    def valid(self, pos):
        """Return True when ``pos`` is strictly within ``size`` Manhattan steps of the origin."""
        x, y = pos
        return abs(x) + abs(y) < self.size
- # Goal: get most treasure after 10 steps
class NormalAgent:
    """Agent whose only moves are the four unit steps on the grid."""

    def __init__(self):
        # Each action is an (dx, dy) offset added to the current position.
        # The list is built per instance so agents never share action lists.
        self.actions = [
            (1, 0),
            (0, 1),
            (-1, 0),
            (0, -1),
        ]
class Option:
    """Identifiers for macro-actions ("options") an agent may hold."""

    # Marker placed in an agent's action list; the search treats it as a
    # request to queue a pre-built multi-node path instead of a unit step.
    FIRST_HALF = 1
class OptionAgent(NormalAgent):
    """Agent with the four unit steps plus the ``Option.FIRST_HALF`` macro-action."""

    def __init__(self):
        # Explicit base-class call (rather than super()) keeps this working
        # with old-style classes, which is how NormalAgent is declared.
        NormalAgent.__init__(self)
        # The option is appended last, after the four unit steps.
        self.actions.append(Option.FIRST_HALF)
class Path(object):
    """A sequence of visited positions together with a cost figure.

    Attributes:
        nodes: list of ``(x, y)`` positions, in visiting order.
        accumulated: the cost value stored alongside the nodes.
    """

    def __init__(self, nodes, accumulated):
        self.nodes = nodes
        self.accumulated = accumulated

    def __repr__(self):
        return 'Path:{' + str(self.nodes) + "}"

    def __lt__(self, other):
        # Paths are stored in heaps as (f, Path) tuples.  When two entries
        # tie on f the tuple comparison falls through to the Path objects,
        # which without this method is address-based on Python 2 and a
        # TypeError on Python 3.  Order by cost, then by nodes, so ties are
        # resolved deterministically.
        return (self.accumulated, self.nodes) < (other.accumulated, other.nodes)
- #def search_alg_generator(select_next):
- # def run(agent, themap, startloc):
- # visited = set()
- # best_solution, best_cost = None, inf
- # q = deque()
- # path = Path([startloc], themap.cost(startloc))
- # q.append(path)
- # i = 0
- # thresh = 1
- # thresh_num = 1
- # while q:
- # i += 1
- # path = select_next(q)
- # node = path.nodes[-1]
- # visited.add(node)
- # l = len(path.nodes)
- # if l > thresh:
- # t = i - thresh_num
- # print 'new thresh', l, t
- # thresh = l
- # thresh_num = i
- # if len(path.nodes) >= SIZE:
- # if path.accumulated < best_cost:
- # best_solution, best_cost = path, path.accumulated
- # if best_solution.nodes[-1] == (SIZE-1,0):
- # break # vs continue changes things here...
- # continue # don't build paths longer than SIZE
- # for action in agent.actions:
- # if action == Option.FIRST_HALF:
- # if node != (0,0):
- # continue
- # new_path = Path(path.nodes + [(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)], 70)
- # nodes = [(x,0) for x in range(SIZE/2)]
- # tcost = sum(themap.cost(l) for l in nodes)
- # q.append(Path(nodes, tcost))
- # else:
- # new_node = tuple(map(sum, zip(action, node)))
- # if themap.valid(new_node) and new_node not in visited:
- # new_path = Path(path.nodes + [new_node],
- # path.accumulated + themap.cost(new_node))
- # q.append(new_path)
- # return best_solution, i
- # return run
- #dfs = search_alg_generator(lambda q: q.pop())
- #bfs = search_alg_generator(lambda q: q.popleft())
def estimator(themap, nodes):
    """Heuristic: lower bound on the cost still to be paid by a path.

    Relaxation of the problem: the bound is simply the number of nodes the
    path still has to add to reach ``themap.size`` nodes, i.e. it presumes
    every remaining node costs at least 1.

    Args:
        themap: object exposing a ``size`` attribute (target node count).
        nodes: the nodes already on the path.

    Returns:
        The number of nodes still missing, never less than 0.
    """
    steps_left = themap.size - len(nodes)
    # A path may already hold more than ``size`` nodes (e.g. a pre-built
    # macro path on a small map).  Nothing is left to pay in that case, so
    # clamp at zero rather than handing the search a negative estimate.
    return max(0, steps_left)
def heappush_node(q, themap, node, prev=None):
    """Extend ``prev`` by ``node`` and push the resulting path onto heap ``q``.

    Args:
        q: list used as a heap; receives a ``(f, Path)`` entry.
        themap: map providing ``cost(node)`` (and whatever ``estimator`` reads).
        node: the position to append.
        prev: the ``Path`` being extended, or None to start a new path.

    The entry's priority is ``f = g + h`` where ``g`` is the cost accumulated
    along the whole path and ``h`` comes from ``estimator``.  The pushed
    ``Path`` stores that accumulated ``g``, matching how the macro-path entry
    built in ``astar`` sums the cost of all of its nodes.
    """
    if prev is None:
        prev = Path([], 0)
    nodes = prev.nodes + [node]
    # g must cover the whole path, not just the newest node; otherwise the
    # queue ordering and the reported path cost ignore everything before it.
    gval = prev.accumulated + themap.cost(node)
    hval = estimator(themap, nodes)
    heappush(q, (gval + hval, Path(nodes, gval)))
def astar(agent, themap, startloc):
    """Best-first search for a path holding at least ``themap.size`` nodes.

    Args:
        agent: object whose ``actions`` list holds ``(dx, dy)`` steps and,
            optionally, ``Option.FIRST_HALF``.  It is not modified.
        themap: map passed through to ``heappush_node`` / ``estimator``;
            its ``size`` is the target number of nodes.
        startloc: starting ``(x, y)`` position.

    Returns:
        ``(solution, iterations)`` where ``solution`` is the first popped
        ``Path`` with at least ``themap.size`` nodes (None if the queue runs
        dry first) and ``iterations`` is the number of queue pops performed.
    """
    solution = None
    q = []  # priority queue of (f, Path) entries
    visited = set()
    # The macro path is queued at most once per search.  Tracking that in a
    # local flag (instead of popping it off agent.actions) leaves the
    # caller's agent intact and avoids mutating a list while iterating it.
    option_available = True
    heappush_node(q, themap, startloc)
    i = 0
    while q:
        i += 1
        _, path = heappop(q)
        node = path.nodes[-1]
        if len(path.nodes) >= themap.size:
            solution = path
            break
        # NOTE(review): the source this was recovered from had lost its
        # indentation; expansion is assumed to happen only the first time a
        # position is popped -- confirm against the original intent.
        if node in visited:
            continue
        visited.add(node)
        for action in agent.actions:
            if action == Option.FIRST_HALF:
                if node != (0, 0) or not option_available:
                    continue
                # Fixed eight-node macro path, queued as a single entry with
                # the summed cost of all of its nodes.
                nodes = [(0, 0), (0, 1), (1, 1), (1, 0),
                         (1, -1), (0, -1), (1, -1), (2, -1)]
                gval = sum(themap.cost(z) for z in nodes)
                hval = estimator(themap, nodes)
                heappush(q, (gval + hval, Path(nodes, gval)))
                option_available = False
            else:
                new_node = tuple(map(sum, zip(action, node)))
                if new_node in path.nodes:  # can't revisit the same spot
                    continue
                heappush_node(q, themap, new_node, path)
    return solution, i
if __name__ == '__main__':
    # Compare the plain agent against the option-equipped agent over a
    # range of target path lengths, printing one result line per run.
    for SIZE in [5, 10, 15, 20, 30, 40, 50, 100, 150, 200, 250, 300, 330]:
        themap = Map(SIZE)
        for cls, tag in [(NormalAgent, 'nooptions'),
                         (OptionAgent, 'options')]:
            agent = cls()
            path, i = astar(agent, themap, (0, 0))
            # Single pre-formatted argument: prints the same space-separated
            # line under both the Python 2 print statement and print().
            print('%s %s cost: %s loops: %s' % (SIZE, tag, path.accumulated, i))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement