Advertisement
spkenny

branch and bound

Mar 7th, 2015
291
0
Never
Not a member of Pastebin yet? Sign up; it unlocks many cool features!
Python | 3.80 KB | None | 0 0
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3.  
  4. from collections import namedtuple
  5. Item = namedtuple("Item", ['index', 'value', 'weight'])
  6.  
  7.  
  8. def estimate(items, capacity):
  9.     value = 0
  10.     weight = 0
  11.  
  12.     ratios = [(item, float(item.value)/float(item.weight)) for item in items]
  13.     global total_value
  14.     total_value = sum([item.value for item in items])
  15.     ratios = [item[0] for item in reversed(sorted(ratios, key=lambda (item, ratio): ratio))]
  16.     taken = [0]*len(items)
  17.  
  18.     for item in ratios:
  19.         if weight + item.weight <= capacity:
  20.             taken[item.index] = 1
  21.             value += item.value
  22.             weight += item.weight
  23.  
  24.         else:
  25.             rest = capacity - weight
  26.             value += item.value * rest/float(item.weight)
  27.  
  28.     return value
  29.  
  30.  
  31. def neighbors(state, length, items):
  32.     global total_value
  33.     return_value = []
  34.     print state
  35.     old_state = state[0][0]
  36.     value = state[0][1]
  37.     rest_capacity = state[0][2]
  38.     curr_estimation = state[0][3]
  39.     right_state = old_state + [1]
  40.     left_state = old_state + [0]
  41.     current_pos = len(old_state)
  42.     estim = estimation(old_state, curr_estimation, items)
  43.     if current_pos < length:
  44.         left = [(left_state, value, rest_capacity, curr_estimation)]
  45.         if rest_capacity - items[current_pos].weight > 0:
  46.             right = [(right_state, value + items[current_pos].value, rest_capacity - items[current_pos].weight, estim)]
  47.             return_value = [left, right]
  48.         else:
  49.             return_value = [left]
  50.  
  51.     return return_value
  52.  
  53.  
  54. def estimation(state, total, items):
  55.     t = total
  56.     for i in range(len(state)):
  57.         if state[i] == 0:
  58.             t -= items[i].value
  59.     return min(t, total)
  60.  
  61.  
  62. def solve_it(input_data):
  63.     # Modify this code to run your optimization algorithm
  64.  
  65.     # parse the input
  66.     lines = input_data.split('\n')
  67.  
  68.     firstLine = lines[0].split()
  69.     item_count = int(firstLine[0])
  70.     capacity = int(firstLine[1])
  71.  
  72.     items = []
  73.  
  74.     for i in range(1, item_count+1):
  75.         line = lines[i]
  76.         parts = line.split()
  77.         items.append(Item(i-1, int(parts[0]), int(parts[1])))
  78.  
  79.  
  80.     # a trivial greedy algorithm for filling the knapsack
  81.     # it takes items in-order until the knapsack is full
  82.  
  83.     estimation = estimate(items, capacity)
  84.  
  85.     fringe = []
  86.     #fringe.append(0, capacity, estimation)
  87.  
  88.     state = [([], 0, capacity, estimation)]
  89.     fringe.append(state)
  90.  
  91.     while len(fringe) > 0:
  92.         state = fringe.pop()
  93.         nodes = neighbors(state, item_count, items)
  94.         for node in nodes:
  95.             if node:
  96.                 fringe.append(node)
  97.                 print node
  98.  
  99.  
  100.  
  101.  
  102.     value = 0
  103.     weight = 0
  104.  
  105.     ratios = [(item, float(item.value)/float(item.weight)) for item in items]
  106.     ratios = [item[0] for item in reversed(sorted(ratios, key=lambda (item, ratio): ratio))]
  107.     taken = [0]*len(items)
  108.  
  109.     for item in ratios:
  110.         if weight + item.weight <= capacity:
  111.             taken[item.index] = 1
  112.             value += item.value
  113.             weight += item.weight
  114.  
  115.         else:
  116.             rest = capacity - weight
  117.             value += item.value * rest/float(item.weight)
  118.    
  119.     # prepare the solution in the specified output format
  120.     output_data = str(value) + ' ' + str(0) + '\n'
  121.     output_data += ' '.join(map(str, taken))
  122.     return output_data
  123.  
  124.  
  125. import sys
  126.  
  127. if __name__ == '__main__':
  128.     #if len(sys.argv) > 1:
  129.     file_location = 'data/test.data'   #sys.argv[1].strip()
  130.     input_data_file = open(file_location, 'r')
  131.     input_data = ''.join(input_data_file.readlines())
  132.     input_data_file.close()
  133.     print solve_it(input_data)
  134.     #else:
  135.     #print 'This test requires an input file.  Please select one from the data directory. (i.e. python solver.py ./data/ks_4_0)'
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement