Advertisement
Guest User

Untitled

a guest
Oct 20th, 2019
176
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.62 KB | None | 0 0
  1. # Kyle Andrus
  2. # CS 4375 PS3
  3. # 10.17.2019
  4.  
  5. from csv import reader
  6. import random
  7. import os
  8. import sys
  9.  
  10.  
  11. # Load a CSV file into list 'dataset'
  12. def load_csv(filename):
  13.     dataset = list()
  14.     with open(os.path.join(os.path.dirname(sys.argv[0]), 'heart_train.data')) as file:
  15.         csv_reader = reader(file)
  16.         for row in csv_reader:
  17.             if not row:
  18.                 continue
  19.             dataset.append(row)
  20.     return dataset
  21.  
  22. # Choose 3 attributes from data to split on randomly
  23. def choose_attributes(data):
  24.     num_of_attributes = (len(data[0])) - 1 # -1 for class label
  25.     chosen_attributes = [0,0,0]
  26.  
  27.     # Randomly pick 3 attributes to split on
  28.     for i in range(3):
  29.         chosen_attributes[i] = random.randrange(0, num_of_attributes)
  30.     return chosen_attributes
  31.  
  32. # Return the majority class label in a bucket
  33. def majority_vote(bucket):
  34.     zeros = 0
  35.     ones = 0
  36.     for i in (range(len(bucket))):
  37.         if (int)(bucket[i][0]) == 0:
  38.             zeros = zeros + 1
  39.         if (int)(bucket[i][0]) == 1:
  40.             ones = ones + 1
  41.     if(zeros >= ones):
  42.         return 0
  43.     elif(ones > zeros):
  44.         return 1
  45.  
  46. # Chose next split based on most uncertain branch
  47. def choose_next_split(left_bucket, right_bucket):
  48.     print("to do (^:")
  49.  
  50. def calculate_error(bucket):
  51.     # First majority vote the bucket to determine it's overall class
  52.     result = majority_vote(bucket)
  53.     # Now, compare the result with each class label in the bucket
  54.     error_sum = 0
  55.     for i in range(1, len(bucket)):
  56.         if (int)(bucket[i][0]) != result:
  57.             error_sum = error_sum + 1
  58.     true_error = (error_sum / len(bucket)) #normalized by # of elements in bucket
  59.     error_results = list()
  60.     error_results.append(result)
  61.     error_results.append(true_error)
  62.     return error_results
  63.  
  64. def average_errors(errors):
  65.     error_sum = 0 # sum of all error values to be avg'd
  66.     error_num = 0 # number of buckets that have any error
  67.     for i in range(len(errors)):
  68.         if float(errors[i][1]) > 0:
  69.             error_sum += float(errors[i][1])
  70.             error_num += 1
  71.     if error_num != 0:
  72.         error_epsilon = error_sum / error_num
  73.     else:
  74.         error_num = 1
  75.         error_epsilon = error_sum / error_num
  76.     return error_epsilon
  77.  
  78. def calculate_alpha(epsilon):
  79.     alpha = .5 * ((1 - epsilon)/epsilon)
  80.     return alpha
  81.  
  82. def init_weights(data):
  83.     for i in range(len(data)):
  84.         data[i].append(1)
  85.  
  86. def update_weights(alpha, data):
  87.     for i in range(len(data)):
  88.         data[i][23] = alpha
  89.  
  90. def build_tree_with_one_split(data, attribute):
  91.     left = list()
  92.     right = list()
  93.  
  94.     for i in range(1, len(data)-1):
  95.         if ((int)(data[i][attribute]) == 0):
  96.             left.append(data[i])
  97.         if ((int)(data[i][attribute]) == 1):
  98.             right.append(data[i])
  99.  
  100.     print("------Nodes------")
  101.     print("split on attribute {}".format(attribute))
  102.     print("top left")
  103.     print(len(left))
  104.     print("top right")
  105.     print(len(right))
  106.     print("------------------")
  107.  
  108.     errors = list()
  109.     left_error = calculate_error(left)
  110.     errors.append(left_error)
  111.     right_error = calculate_error(right)
  112.     errors.append(right_error)
  113.  
  114.     print("Split: ")
  115.     print(errors)
  116.     epsilon = average_errors(errors)
  117.     print("Epsilon = {}".format(epsilon))
  118.     alpha = calculate_alpha(epsilon)
  119.     print("Alpha = {}".format(alpha))
  120.  
  121.  
  122. # Construct a decision tree with three attribute splits, and
  123. # calculate the alpha weight and epsilon error for that tree
  124. def build_tree_with_three_splits(data, attributes):
  125.     # Save the original data for later (;
  126.     original_data = list(data)
  127.  
  128.     # First split
  129.     top_left = list()
  130.     top_right = list()
  131.  
  132.     for i in range(1, len(data)-1):
  133.         if ((int)(data[i][attributes[0]]) == 0):
  134.             top_left.append(data[i])
  135.         if ((int)(data[i][attributes[0]]) == 1):
  136.             top_right.append(data[i])
  137.  
  138.     # Figure out which branch to split on next based on uncertainty
  139.     # For now, just split the branch with the most data
  140.     if (len(top_left) >= len(top_right)):
  141.         data = list(top_left)
  142.         del top_left[:] # empty the list since it's contents will be split into 2 new buckets
  143.     elif (len(top_left) < len(top_right)):
  144.         data = list(top_right)
  145.         del top_right[:]
  146.  
  147.     # Second split
  148.     middle_left = list()
  149.     middle_right = list()
  150.  
  151.     for i in range(1, len(data)-1):
  152.         if ((int)(data[i][attributes[1]]) == 0):
  153.             middle_left.append(data[i])
  154.         if ((int)(data[i][attributes[1]]) == 1):
  155.             middle_right.append(data[i])
  156.  
  157.     # Figure out which branch to split on next based on uncertainty
  158.     # For now, just split the branch with the most data
  159.     if (len(middle_left) >= len(middle_right)):
  160.         data = list(middle_left)
  161.         del middle_left[:] # empty the list since it's contents will be split into 2 new buckets
  162.     elif (len(middle_left) < len(middle_right)):
  163.         data = list(middle_right)
  164.         del middle_right[:]
  165.  
  166.     # Third and final split
  167.     bottom_left = list()
  168.     bottom_right = list()
  169.  
  170.     for i in range(1, len(data)-1):
  171.         if ((int)(data[i][attributes[2]]) == 0):
  172.             bottom_left.append(data[i])
  173.         if ((int)(data[i][attributes[2]]) == 1):
  174.             bottom_right.append(data[i])
  175.  
  176.     print("------Nodes------")
  177.     print("split on attribute {}".format(attributes[0]))
  178.     print("top left")
  179.     print(len(top_left))
  180.     print("top right")
  181.     print(len(top_right))
  182.     print("split on attribute {}".format(attributes[1]))
  183.     print("middle left")
  184.     print(len(middle_left))
  185.     print("middle right")
  186.     print(len(middle_right))
  187.     print("split on attribute {}".format(attributes[2]))
  188.     print("bottom left")
  189.     print(len(bottom_left))
  190.     print("bottom right")
  191.     print(len(bottom_right))
  192.     print("------------------")
  193.  
  194.     # We now have the data seperated into six buckets,
  195.     # next we calculate the error that our selected attributes achieved
  196.     # calculate_error returns tuple of predicted label and actual error (% of bucket misclassified)
  197.     # errors is a list of these tuples for all buckets containing points
  198.     errors = list()
  199.  
  200.     if len(top_left) > 0:
  201.         tl_error = calculate_error(top_left)
  202.         errors.append(tl_error)
  203.     if len(top_right) > 0:
  204.         tr_error = calculate_error(top_right)
  205.         errors.append(tr_error)
  206.     if len(middle_left) > 0:
  207.         ml_error = calculate_error(middle_left)
  208.         errors.append(ml_error)
  209.     if len(middle_right) > 0:
  210.         mr_error = calculate_error(middle_right)
  211.         errors.append(mr_error)
  212.     if len(bottom_left) > 0:
  213.         bl_error = calculate_error(bottom_left)
  214.         errors.append(bl_error)
  215.     if len(bottom_right) > 0:
  216.         br_error = calculate_error(bottom_right)
  217.         errors.append(br_error)
  218.  
  219.     print("Resulting 4 data buckets: (format [class label, bucket error]")
  220.     print(errors)
  221.     epsilon = average_errors(errors)
  222.     print("Epsilon = {}".format(epsilon))
  223.     alpha = calculate_alpha(epsilon)
  224.     print("Alpha = {}".format(alpha))
  225.  
  226.     update_weights(alpha, data)
  227.  
  228. def ada_boost(data, iterations):
  229.     for i in range(0, iterations):
  230.         init_weights(data)
  231.         attributes = choose_attributes(data)
  232.         print("\n\nIteration {}\n".format(i+1))
  233.         build_tree_with_three_splits(data, attributes)
  234.  
  235. def coordinate_descent(data, iterations):
  236.     for i in range(1, iterations):
  237.         print("\n\nIteration {}\n".format(i+1))
  238.         build_tree_with_one_split(data, i)
  239.  
# Script entry: load the training data and run the chosen learner.
data = load_csv("heart_train.data")

# NOTE(review): the defined function is ada_boost, not adaboost — this
# commented-out call would raise NameError if re-enabled as written.
#adaboost(data, 5)
coordinate_descent(data, 22)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement