Advertisement
lackofcheese

creamsteak.py

Sep 9th, 2011
104
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.43 KB | None | 0 0
  1. import numpy as np
  2. import Queue
  3.  
  4. def init_marks(m):
  5.     marks = np.zeros(m.shape, dtype=int)
  6.     rowMaxPosns = m.argmax(1)
  7.     colMaxPosns = m.argmax(0)
  8.     collis = []
  9.     for i, p in enumerate(rowMaxPosns):
  10.         add_mark(marks, (i, p), collis)
  11.     for i, p in enumerate(colMaxPosns):
  12.         add_mark(marks, (p, i), collis)
  13.     return marks, collis
  14.  
  15. def add_mark(m, pos, collis):
  16.     m[pos] += 1
  17.     if m[pos] > 1:
  18.         collis.append(pos)
  19.  
  20. def move_mark(m, marks, markPos):
  21.     n = m.shape[0]
  22.     done = np.zeros(m.shape, dtype = bool)
  23.     done[markPos] = True
  24.     todo = Queue.Queue()
  25.     todo.put((True, markPos))
  26.     todo.put((False, markPos))
  27.     maxVal = -np.inf
  28.    
  29.     while not todo.empty():
  30.         doRow, pos = todo.get()
  31.         for i in xrange(n):
  32.             newPos = (pos[0], i) if doRow else (i, pos[1])
  33.             if newPos == pos or done[newPos]:
  34.                 continue
  35.             done[newPos] = True
  36.             if marks[newPos] > 0:
  37.                 todo.put((not doRow, newPos))
  38.             else:
  39.                 newVal = m[newPos]
  40.                 if newVal > maxVal:
  41.                     maxVal = newVal
  42.                     maxPos = newPos          
  43.     marks[markPos] -= 1
  44.     marks[maxPos] += 1
  45.  
  46. def creamsteak(m):
  47.     n = m.shape[0]
  48.     if n != m.shape[1]:
  49.         raise Exception('Not square!')
  50.     if n == 1:
  51.         raise Exception('1x1 case has no solution')
  52.     marks, collis = init_marks(m)
  53.     for pos in collis:
  54.         move_mark(m, marks, pos)
  55.     total = 0
  56.     for i in xrange(n):
  57.         for j in xrange(n):
  58.             if marks[i, j] == 0:
  59.                 pass
  60.             elif marks[i, j] == 1:
  61.                 total += m[i, j]
  62.             else:
  63.                 raise Exception('OH NOES!')
  64.     return total, marks
  65.  
  66. def pretty_test(m):
  67.     print
  68.     print "------------------BEGIN TEST----------------"
  69.     print "Input matrix:"
  70.     print repr(m)
  71.     print "--------------------------------------------"
  72.     try:
  73.         total, marks = creamsteak(m)
  74.         print "Choice of elements (1=chosen, 0=not chosen):"
  75.         print repr(marks)
  76.         print "--------------------------------------------"
  77.         print "Calculated total:", total
  78.     except Exception as e:
  79.         print "Exception: ", e
  80.     print "-------------------END TEST-----------------"
  81.     print
  82.  
  83. def trivial_tests():
  84.     pretty_test(np.array([[1]]))
  85.     pretty_test(np.array([[1, 2],
  86.                           [3, 4]]))
  87.  
  88. def other_tests():
  89.     pretty_test(np.array([[9, 8, 7, 6],
  90.                           [1, 9, 0, 0],
  91.                           [0, 0, 9, 0],
  92.                           [0, 0, 0, 9]]))
  93.     pretty_test(np.array([[0, 0, 0, 0],
  94.                           [0, 0, 0, 0],
  95.                           [0, 0, 0, 0],
  96.                           [0, 0, 0, 0]]))
  97.     pretty_test(np.array([[9, 0, 0, 0],
  98.                           [0, 8, 0, 7],
  99.                           [0, 0, 7, 6],
  100.                           [0, 7, 6, 5]]))
  101.     pretty_test(np.array([[3, 0, 0, 0, 1],
  102.                           [0, 9, 9, 9, 0],
  103.                           [0, 9, 0, 9, 0],
  104.                           [0, 9, 9, 9, 0],
  105.                           [4, 0, 0, 0, 3]]))
  106.     pretty_test(np.array([[1, 0, 0, 2],
  107.                           [0, 9, 9, 9],
  108.                           [0, 9, 9, 9],
  109.                           [2, 9, 9, 9]]))
  110.        
  111. if __name__ == '__main__':
  112.     trivial_tests()
  113.     other_tests()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement