Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from collections import deque

try:
    import Queue  # Python 2 name of the stdlib queue module
except ImportError:  # Python 3 renamed the module to ``queue``
    import queue as Queue

import numpy as np
def init_marks(m):
    """Seed the mark grid from the row and column maxima of ``m``.

    One mark goes on the argmax of every row and one on the argmax of every
    column (rows first, then columns), so a cell that is both a row maximum
    and a column maximum ends up holding two marks.

    Returns ``(marks, collis)``: ``marks`` is an int array shaped like ``m``
    holding per-cell mark counts, and ``collis`` lists every cell whose count
    rose above one, in the order that happened.
    """
    marks = np.zeros(m.shape, dtype=int)
    collis = []
    # (row, col) of each row's maximum, followed by (row, col) of each
    # column's maximum.
    seeds = list(enumerate(m.argmax(1)))
    seeds.extend((row, col) for col, row in enumerate(m.argmax(0)))
    for cell in seeds:
        add_mark(marks, cell, collis)
    return marks, collis
def add_mark(m, pos, collis):
    """Add one mark at ``pos`` in the count grid ``m`` (modified in place).

    If the cell now holds more than one mark, ``pos`` is appended to
    ``collis`` so the collision can be dealt with afterwards.
    """
    m[pos] = m[pos] + 1
    if m[pos] <= 1:
        return
    collis.append(pos)
def move_mark(m, marks, markPos):
    """Move one mark off ``markPos`` onto the best free cell reachable from it.

    Starting at ``markPos``, a breadth-first search scans its row and its
    column.  Every marked cell encountered is queued so that its *other*
    line is scanned next: its column if it was found by a row scan, its row
    if found by a column scan.  Every unmarked cell encountered is a
    candidate.  The candidate with the largest value in ``m`` receives the
    mark (the first one seen wins ties), and ``marks[markPos]`` is
    decremented.

    Parameters
    ----------
    m : 2-D array
        Value matrix.  Scans run over ``range(m.shape[0])`` in both
        directions, so ``m`` is assumed to be square.
    marks : 2-D int array
        Per-cell mark counts, expected to have the same shape as ``m``.
        Modified in place.
    markPos : tuple (row, col)
        Cell holding the mark to move.

    Raises
    ------
    RuntimeError
        If no unmarked candidate is reachable, or no candidate value compares
        greater than ``-inf`` (e.g. all are NaN or -inf).  ``marks`` is left
        untouched in that case.
    """
    n = m.shape[0]
    done = np.zeros(m.shape, dtype=bool)
    done[markPos] = True
    # Entries are (scan_row, cell): scan the cell's row when scan_row is
    # True, otherwise its column.  Start by scanning both lines of markPos.
    todo = deque([(True, markPos), (False, markPos)])
    maxVal = -np.inf
    maxPos = None
    while todo:
        doRow, pos = todo.popleft()
        for i in range(n):
            newPos = (pos[0], i) if doRow else (i, pos[1])
            # Every queued cell (and markPos) is already flagged in `done`,
            # so this also skips the cell whose line is being scanned.
            if done[newPos]:
                continue
            done[newPos] = True
            if marks[newPos] > 0:
                # Marked cell: keep searching along its other line.
                todo.append((not doRow, newPos))
            elif m[newPos] > maxVal:
                maxVal = m[newPos]
                maxPos = newPos
    if maxPos is None:
        raise RuntimeError('no free cell reachable from %r' % (markPos,))
    marks[markPos] -= 1
    marks[maxPos] += 1
def creamsteak(m):
    """Mark row and column maxima of a square matrix and sum the marked cells.

    One mark is placed on each row's maximum and one on each column's
    maximum (via ``init_marks``).  Every cell that received two marks then
    has one mark relocated with ``move_mark``.  This is a greedy procedure;
    nothing in this function establishes that the resulting total is
    optimal.

    Parameters
    ----------
    m : square 2-D array_like
        Value matrix.  Nested lists are accepted and converted with
        ``np.asarray``.

    Returns
    -------
    (total, marks)
        ``marks`` is the int mark-count matrix, with every entry 0 or 1.
        ``total`` is the sum of ``m`` over the cells marked 1, accumulated
        in row-major order.

    Raises
    ------
    ValueError
        If ``m`` is not a square 2-D matrix, or is 1x1.  In the 1x1 case the
        row and column maxima coincide and there is no other cell to move a
        mark to.
    RuntimeError
        If any cell does not end with exactly 0 or 1 marks (internal
        invariant violated).  Exceptions from ``init_marks`` / ``move_mark``
        propagate unchanged.
    """
    m = np.asarray(m)
    if m.ndim != 2 or m.shape[0] != m.shape[1]:
        raise ValueError('Not square!')
    n = m.shape[0]
    if n == 1:
        raise ValueError('1x1 case has no solution')
    marks, collis = init_marks(m)
    for pos in collis:
        move_mark(m, marks, pos)
    # Every collision should have been resolved above; any count other than
    # 0 or 1 means the relocation step broke its invariant.
    if np.any((marks != 0) & (marks != 1)):
        raise RuntimeError('OH NOES!')
    # np.nonzero yields indices in row-major order, matching a row-by-row scan.
    total = sum(m[i, j] for i, j in zip(*np.nonzero(marks)))
    return total, marks
def pretty_test(m):
    """Run ``creamsteak`` on ``m`` and print the input, chosen cells and total.

    Any exception raised by the solver is caught and its message printed, so
    one failing case (e.g. the rejected 1x1 matrix) does not abort a batch of
    tests.  Only single-argument ``print(...)`` calls are used, so the output
    is identical under Python 2 and Python 3.
    """
    rule = "--------------------------------------------"
    print("")
    print("------------------BEGIN TEST----------------")
    print("Input matrix:")
    print(repr(m))
    print(rule)
    try:
        total, marks = creamsteak(m)
    except Exception as e:
        # Test-harness boundary: report the failure and carry on.
        print("Exception:  %s" % (e,))
    else:
        print("Choice of elements (1=chosen, 0=not chosen):")
        print(repr(marks))
        print(rule)
        print("Calculated total: %s" % (total,))
    print("-------------------END TEST-----------------")
    print("")
def trivial_tests():
    """Run the two smallest cases through ``pretty_test``: a 1x1 matrix
    (which ``creamsteak`` rejects) and a 2x2 matrix."""
    for rows in ([[1]], [[1, 2], [3, 4]]):
        pretty_test(np.array(rows))
def other_tests():
    """Run a batch of larger hand-built square matrices through ``pretty_test``.

    The batch is four 4x4 matrices (one of them all zeros) and one 5x5
    matrix, run in the order listed.
    """
    cases = (
        [[9, 8, 7, 6],
         [1, 9, 0, 0],
         [0, 0, 9, 0],
         [0, 0, 0, 9]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[9, 0, 0, 0],
         [0, 8, 0, 7],
         [0, 0, 7, 6],
         [0, 7, 6, 5]],
        [[3, 0, 0, 0, 1],
         [0, 9, 9, 9, 0],
         [0, 9, 0, 9, 0],
         [0, 9, 9, 9, 0],
         [4, 0, 0, 0, 3]],
        [[1, 0, 0, 2],
         [0, 9, 9, 9],
         [0, 9, 9, 9],
         [2, 9, 9, 9]],
    )
    for rows in cases:
        pretty_test(np.array(rows))
if __name__ == '__main__':
    # Degenerate sizes first, then the hand-built larger matrices.
    for run_suite in (trivial_tests, other_tests):
        run_suite()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement