Advertisement
cat_baxter

Programming Challenge 54 - Fastest Railroad Route [Python]

Apr 20th, 2012
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.80 KB | None | 0 0
  1. import sys, time, copy, heapq, pprint
  2.  
# Module-level placeholder. In the code visible here main() builds its own
# local set named STATIONS and passes it to the helpers explicitly, so this
# global is never filled in.
STATIONS = set()
# Traversal cost per map symbol. main() adds an entry of 5 for every station
# symbol it finds in the input file.
COSTS = {'+':1, 'S':3}
# Sentinel distance meaning "no known connection" in the distance matrices.
INF = 1000000000
  6.  
  7. def dijkstra(G, start, end):
  8.     def flatten(L):
  9.         while len(L) > 0:
  10.             yield L[0]
  11.             L = L[1]
  12.     q = [(0, start, ())]
  13.     visited = set()
  14.     while True:
  15.         (cost, v1, path) = heapq.heappop(q)
  16.         if v1 not in visited:
  17.             visited.add(v1)
  18.             if v1 == end:
  19.                 return list(flatten(path))[::-1] + [v1]
  20.             path = (v1, path)
  21.             for (v2, cost2) in G[v1].iteritems():
  22.                 if v2 not in visited:
  23.                     heapq.heappush(q, (cost + cost2, v2, path))
  24.  
  25. def hamiltonian(dist):
  26.     n = len(dist)
  27.     dp = [[INF for _ in range(n)] for _ in range(1 << n)]
  28.     for i in range(n): dp[1 << i][i] = 0
  29.     for mask in range(1<<n):
  30.         for i in range(n):
  31.             if mask & (1 << i):
  32.                 for j in range(n):
  33.                     if mask & (1 << j):
  34.                         dp[mask][i] = min(dp[mask][i], dp[mask ^ 1 << i][j] + dist[j][i])
  35.     res = INF
  36.     v   = 0
  37.     for i in range(n):
  38.         value = dp[(1 << n) - 1][i]
  39.         if value < res:
  40.             res = value
  41.             v = i
  42.  
  43.     order = n*[0]
  44.     cur = (1 << n) - 1
  45.     cur ^= 1 << v
  46.     rest = res
  47.     order[n - 1] = v
  48.     for i in range(n - 2, -1, -1):
  49.         for j in range(n):
  50.             if cur & 1 << j:
  51.                if rest == dp[cur][j] + dist[j][v]:
  52.                    rest -= dist[j][v]
  53.                    v = j
  54.                    order[i] = v
  55.                    cur ^= 1 << v
  56.                    break
  57.     return res, order
  58.  
  59. def print_track(content):
  60.     s1 = s2 = ' '*3
  61.     last = '0'
  62.     for i in range(len(content[0])):
  63.         s = str(i).zfill(2)
  64.         if s[0] != last:
  65.             s1 += s[0]
  66.             last = s[0]
  67.         else:
  68.             s1 += ' '
  69.         s2 += s[1]
  70.     print s1
  71.     print s2
  72.     no = 0
  73.     for line in content:
  74.         print str(no).zfill(2), line
  75.         no += 1
  76.  
  77. def get_nexts(data, x, y, visited, stack, STATIONS):
  78.     res = []
  79.     moves = [(x,y-1), (x+1,y-1), (x+1,y), (x+1,y+1), (x,y+1), (x-1,y+1), (x-1,y), (x-1,y-1)]
  80.     for cx,cy in moves:
  81.         if (cx,cy) not in visited:
  82.             d = data[cy][cx]
  83.             if d in STATIONS or d == 'S':
  84.                 res.append((data[cy][cx],cx,cy))
  85.                 break
  86.     if not len(res):
  87.         for cx,cy in moves:
  88.             if (cx,cy) not in visited and (cx,cy) not in stack:
  89.                d = data[cy][cx]
  90.                if d != '.':
  91.                    res.append((data[cy][cx],cx,cy))
  92.     return res
  93.  
  94. def bfs(data, station, x, y, STATIONS):
  95.     g = {}
  96.     stack = set()
  97.     q = [((x,y), station, [])]
  98.     stack.add((x,y))
  99.     visited = set()
  100.     while len(q):
  101.         (x, y), vertex, l = q.pop(0)
  102.         if (x, y) not in visited:
  103.             visited.add((x, y))
  104.             for (v2, x2, y2) in get_nexts(data, x, y, visited, stack, STATIONS):
  105.                 new_len = l + [v2]
  106.                 if v2 != '+':
  107.                     if v2 == 'S':
  108.                         new_vertex = 'S%d_%d'%(x2,y2)
  109.                     else:
  110.                         new_vertex = v2
  111.                     g.setdefault(vertex, {})[new_vertex] = sum([COSTS[e] for e in new_len])
  112.                     vertex = new_vertex
  113.                     new_len = []
  114.                 if (x2, y2) not in visited:
  115.                     q.append(((x2, y2), vertex, new_len))
  116.                     stack.add((x2,y2))
  117.     return g
  118.  
  119. def find_coordinates(data, station):
  120.     x = y = 0
  121.     for row in data:
  122.         if station in row:
  123.            x = row.index(station)
  124.            break
  125.         else:
  126.            y += 1
  127.     return x, y
  128.  
  129. def main():
  130.     start = time.time()
  131.     content = open(sys.argv[1]).read().split()
  132.  
  133.     print_track(content)
  134.  
  135.     # find all stations in text file
  136.     STATIONS = set()
  137.     for row in content:
  138.         for e in row:
  139.             if e not in ['+','.','S']:
  140.                STATIONS.add(e)
  141.  
  142.     for s in STATIONS: COSTS[s]=5
  143.     data = [list(content[i]) for i in range(len(content))]
  144.  
  145.     # build the graph using bfs
  146.     G = {}
  147.     for station in sorted(STATIONS):
  148.         x, y = find_coordinates(data, station)
  149.         SG = bfs(data, station, x, y, STATIONS)
  150.         for k,v in SG.items():
  151.             for sv in v:
  152.                  value = SG[k][sv]
  153.                  G.setdefault(k,{})[sv] = value
  154.  
  155.     N = len(G.keys())
  156.     d = [[INF for _ in xrange(N)] for _ in xrange(N)]
  157.  
  158.     SYMBOLS = G.keys()
  159.     for k,v in G.items():
  160.         for sv in v:
  161.             d[SYMBOLS.index(k)][SYMBOLS.index(sv)] = G[k][sv]
  162.  
  163.     # Floyd Warshall
  164.     floyd = copy.deepcopy(d)
  165.     for k in xrange(N):
  166.         for i in xrange(N):
  167.             for j in xrange(N):
  168.                 if i != j:
  169.                     floyd[i][j] = min(floyd[i][j], floyd[i][k] + floyd[k][j])
  170.  
  171.     # calc paths between only all stations
  172.     N = len(STATIONS)
  173.     STATION_LIST = list(STATIONS)
  174.     m = [[INF for _ in xrange(N)] for _ in xrange(N)]
  175.     for src in STATIONS:
  176.         for dst in STATIONS:
  177.             if src != dst:
  178.                 m[STATION_LIST.index(src)][STATION_LIST.index(dst)] = floyd[SYMBOLS.index(src)][SYMBOLS.index(dst)]
  179.  
  180.     # calc shortest hamiltonian path
  181.     shortest_path_len, path = hamiltonian(m)
  182.     real_path = []
  183.  
  184.     for i in range(0,len(path)-1):
  185.         src = STATION_LIST[path[i]]
  186.         dst = STATION_LIST[path[i+1]]
  187.         p = dijkstra(G, src, dst)
  188.         calc_size = 0
  189.         for i in range(len(p)-1):
  190.             calc_size += G[p[i]][p[i+1]]
  191.         real_path.append((src, dst, calc_size, p))
  192.  
  193.     for src, dst, calc_size, p in real_path:
  194.         print src,' -> ',dst,' = ', calc_size, ' path = ', p
  195.  
  196.     print "Shortest path = ", shortest_path_len
  197.     print "Time (s):", time.time()-start
  198.  
# Run the solver as soon as the file is executed.  There is no __main__ guard,
# so importing this file runs it as well.
main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement