import numpy as np
from data import Data

DATA_DIR = 'data/'


data = np.loadtxt(DATA_DIR + 'train.csv', delimiter=',', dtype=str)
data_obj = Data(data=data)
print(np.unique(data_obj.get_column('label'), return_counts=True))


def entropy(column):
    """
    Calculate the entropy of the labels.
    column specifies the target column.
    """
    elements, counts = np.unique(column, return_counts=True)
    entropy = np.sum([(-counts[i] / np.sum(counts)) * np.log2(counts[i] / np.sum(counts)) for i in range(len(elements))])
    return entropy

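# Quick sanity check of entropy (a minimal sketch; the arrays below are made-up label
# columns, not part of the dataset). Entropy is H = -sum_i p_i * log2(p_i): a perfectly
# balanced binary column gives 1.0 and a pure column gives 0.0.
print(entropy(np.array(['e', 'e', 'p', 'p'])))  # 1.0
print(entropy(np.array(['e', 'e', 'e', 'e'])))  # 0.0 (printed as -0.0)
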

def infogain(data, attribute, target_name='label'):
    """
    Calculate the information gain of a split of the dataset:
    :param data: the dataset for which the information gain is calculated
    :param attribute: the name of the feature for which the information gain should be calculated
    :param target_name: the name of the target feature; default is "label"
    :return: the information gain for this split
    """
    # NOTE: as written, this reads the global data_obj rather than the data parameter

    # calculate the total entropy of the data
    total_entropy = entropy(data_obj.get_column(target_name))

    def weighted_entropy(attribute):
        # count how many rows take each observed value of the split attribute
        vals, counts = np.unique(data_obj.get_column(attribute), return_counts=True)

        def subs(attribute):
            # list of subtables (partitions of data_obj), one per observed value,
            # built in the same order as vals so each index lines up with counts
            return [data_obj.get_row_subset(attribute, v).raw_data for v in vals]

        # weight each partition's entropy by its share of the rows
        # (the label is assumed to sit in column 0 of raw_data)
        partitions = subs(attribute)
        w_ent = np.sum([(counts[i] / np.sum(counts)) * entropy(partitions[i][:, 0]) for i in range(len(vals))])
        return w_ent

    # information gain = total entropy - weighted entropy of the split
    information_gain = total_entropy - weighted_entropy(attribute)
    return information_gain

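# Example usage (a minimal sketch; 'odor' is a hypothetical column name, so substitute
# any attribute that actually exists in the dataset). Ranking attributes by this score
# is how ID3 chooses its splits:
# print(infogain(data, 'odor'))
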
def id3(data, features, originaldata=data_obj, target_attribute_name="label", parent_node_class=None):
    """
    :param data: the data for which the ID3 algorithm should be run (in the first run, the whole dataset)
    :param features: the feature space of the dataset (needed for the recursive call)
    :param originaldata: the original dataset (for finding the most frequent label of the original dataset)
    :param target_attribute_name: the name of the target attribute
    :param parent_node_class: most frequently appearing label for the direct parent node
    :return: a classification decision tree
    """
    # NOTE: as written, the stopping criteria read the global data_obj rather than the data parameter

    # define the stopping criteria; if satisfied, return a leaf node

    # if all target values are the same, return that value
    if len(np.unique(data_obj.get_column(target_attribute_name))) <= 1:
        return np.unique(data_obj.get_column(target_attribute_name))[0]

    # if the dataset is empty, return the most frequently appearing label of the original dataset
    elif len(data) == 0:
        return np.unique(originaldata.get_column(target_attribute_name))[np.argmax(np.unique(originaldata.get_column(target_attribute_name), return_counts=True)[1])]

    # if the feature space is empty, return the most frequently occurring label of the direct parent node
    elif len(features) == 0:
        return parent_node_class

    # if none of the above holds, grow the tree
    else:
        # set the default value for this node: the most commonly occurring label of the current node
        parent_node_class = np.unique(data_obj.get_column(target_attribute_name))[np.argmax(np.unique(data_obj.get_column(target_attribute_name), return_counts=True)[1])]
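        # A rough sketch (not part of the original paste) of how the standard ID3
        # recursion typically continues; it assumes infogain scores each remaining
        # feature on the current subset and that data_obj.get_row_subset(feature, value)
        # returns the matching partition.

        # pick the feature with the highest information gain
        gains = [infogain(data, feature, target_attribute_name) for feature in features]
        best_feature = features[np.argmax(gains)]

        # represent the tree as a nested dict: {best_feature: {value: subtree, ...}}
        tree = {best_feature: {}}

        # drop the chosen feature from the feature space for the recursive calls
        remaining_features = [f for f in features if f != best_feature]

        # grow one branch per observed value of the chosen feature
        for value in np.unique(data_obj.get_column(best_feature)):
            sub_data = data_obj.get_row_subset(best_feature, value).raw_data
            tree[best_feature][value] = id3(sub_data, remaining_features,
                                            target_attribute_name=target_attribute_name,
                                            parent_node_class=parent_node_class)

        return tree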