Advertisement
zholnin

Hungarian algorithm for perfect allocation - Python

Apr 10th, 2014
110
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.26 KB | None | 0 0
  1. import sys
  2.  
  3. class PerfectAllocation:
  4.     def __init__(self, table):
  5.         self.table = table
  6.         self.zerorows = {}
  7.         self.zerocolumns = {}
  8.         self.matchingrows = {}
  9.         self.matchingcolumns = {}
  10.         self.N = len(table)
  11.         self.rows = []
  12.         self.columns = []
  13.         for x in range(self.N):
  14.             self.zerorows[x] = []
  15.             self.zerocolumns[x] = []
  16.             self.columns.append([])
  17.         for x in range(self.N):
  18.             self.rows.append(table[x][:])
  19.             for y in range(self.N):
  20.                 self.columns[y].append(table[x][y])
  21.                 if table[x][y] == 0:
  22.                     self.zerorows[x].append(y)
  23.                     self.zerocolumns[y].append(x)
  24.                    
  25.    
  26.     def adjustCell(self, x,y,v):
  27.         t = self.columns[x][y] + v
  28.         self.columns[x][y] = t
  29.         self.rows[y][x] = t
  30.         if t == 0:
  31.             self.zerorows[y].append(x)
  32.             self.zerocolumns[x].append(y)
  33.         elif t == v:
  34.             self.zerorows[y].remove(x)
  35.             self.zerocolumns[x].remove(y)
  36.            
  37.     def match(self, x, y):
  38.         self.matchingrows[x] = y
  39.         self.matchingcolumns[y] = x
  40.  
  41.     def unmatch(self, x, y):
  42.         if self.matchingrows[x] == y:
  43.             del self.mathingrows[x]
  44.         if self.matchingcolumns[y] == x:
  45.             del self.matchingcolumns[y]
  46.          
  47.     def step0(self): #reduce
  48.         for x in range(self.N):
  49.             a = min(self.rows[x])
  50.             if a > 0:
  51.                 for i in range(self.N):
  52.                     self.adjustCell(i, x, -a)
  53.  
  54.         for x in range(self.N):
  55.             a = min(self.columns[x])
  56.             if a > 0:
  57.                 for i in range(self.N):
  58.                     self.adjustCell(x, i, -a)
  59.    
  60.     def step01(self): #initial_match
  61.         for x in range(self.N):
  62.             if x not in self.matchingrows:
  63.                 for y in self.zerorows[x]:
  64.                     if y not in self.matchingcolumns:
  65.                         self.match(x,y)
  66.                         break
  67.    
  68.     def step1(self): #check if finished
  69.         if len(self.matchingrows) == self.N:
  70.             S = 0
  71.             for x in self.matchingrows:
  72.                 S = S + self.table[x][self.matchingrows[x]]
  73.             return True
  74.         return False
  75.    
  76.     def step2(self): #iteratively improve graph
  77.         def NextOption(x):
  78.             R = self.zerorows[self.matchingcolumns[x]][:]
  79.             R.remove(x)
  80.             return R
  81.                        
  82.         path = []
  83.         Done = sys.maxint
  84.         for x in self.zerorows:
  85.             if x not in self.matchingrows:
  86.                 for y in self.zerorows[x]:
  87.                     path.append([x, y])
  88.                     if y not in self.matchingcolumns:
  89.                         Done = 2
  90.         thislevelnodes = []
  91.         thislevel = 2
  92.         while len(path) > 0 and Done > len(path[0]):
  93.             x = path.pop(0)
  94.             if len(x) > thislevel:
  95.                 thislevel += 1
  96.                 thislevelnodes = []
  97.             nextx = NextOption(x[-1])
  98.             t = 0
  99.             while t <  len(nextx):
  100.                 if nextx[t] in thislevelnodes or nextx[t] in x[1:]:
  101.                     nextx.pop(t)
  102.                     continue
  103.                 else:
  104.                     path.append(x[:] + [nextx[t]])
  105.                     thislevelnodes.append(nextx[t])
  106.                     if nextx[t] not in self.matchingcolumns:
  107.                         Done = len(path[-1])
  108.                 t += 1
  109.        
  110.         duplR = []
  111.         duplL = []    
  112.        
  113.         if len(path) == 0:
  114.             return False
  115.        
  116.         for x in path:
  117.             if x[-1] in self.matchingcolumns:
  118.                 continue
  119.             if x[0] in duplL:
  120.                 continue
  121.            
  122.             duplL.append(x[0])
  123.             for t in range(1, len(x)):
  124.                 if x[t] not in duplR:
  125.                     duplR.append(x[t])
  126.                 else:
  127.                     break
  128.             else:
  129.                 if len(x) > 2:
  130.                     for t in range(len(x) - 1, 1, -1):
  131.                         self.match(self.matchingcolumns[x[t-1]], x[t])
  132.                         self.unmatch(self.matchingcolumns[x[t-1]], x[t-1])
  133.                 self.match(x[0], x[1])
  134.        
  135.         return True
  136.                    
  137.     def step345(self):
  138.         L = []
  139.         OL = range(0, self.N)
  140.         R = []
  141.         UL = []
  142.         UR = []
  143.        
  144.         for x in self.zerorows:
  145.             if x not in self.matchingrows:
  146.                 L.append(x)
  147.                 OL.remove(x)
  148.                 UL.append(x)
  149.  
  150.         while len(UL) > 0:
  151.             for x in UL:
  152.                 for y in self.zerorows[x]:
  153.                     if y not in R:
  154.                         R.append(y)
  155.                         UR.append(y)
  156.             UL = []
  157.             for x in UR:
  158.                 if self.matchingcolumns[x] not in L:
  159.                     L.append(self.matchingcolumns[x])
  160.                     OL.remove(self.matchingcolumns[x])
  161.                     UL.append(self.matchingcolumns[x])
  162.             UR = []
  163.        
  164.         NL = L
  165.         L = OL
  166.  
  167.         NR = [x for x in range(self.N) if x not in R]
  168.         m = sys.maxint
  169.         for x in NL:
  170.             for y in NR:
  171.                 if self.rows[x][y] < m:
  172.                     m = self.rows[x][y]
  173. #                m = min(m, self.rows[x][y])
  174. #                if m == None or self.rows[x][y] < m:
  175. #                    m = self.rows[x][y]
  176.         for x in L:
  177.             for y in R:
  178.                 self.adjustCell(y, x, m)
  179.                
  180.         for x in NL:
  181.             for y in NR:
  182.                 self.adjustCell(y, x, -m)    
  183.                
  184.     def step6(self):
  185.         s = 0
  186.         for x in self.matchingrows:
  187.             y = self.matchingrows[x]
  188.             s += self.table[x][y]
  189.         return s
  190.    
  191.     def solve(self):
  192.         self.step0()
  193.         self.step01()
  194.         self.step1()
  195.         while 1:
  196.             while self.step2():
  197.                 if len(self.matchingrows) == self.N:
  198.                     return self.step6()
  199.             if len(self.matchingrows) == self.N:
  200.                 return self.step6()
  201.             self.step345()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement