from random import seed
from random import randrange
from csv import reader


def accuracy_metric(actual, predicted):
    # Percentage of predictions that match the actual class labels.
    correct = 0
    for i in range(len(actual)):
        if actual[i] == predicted[i]:
            correct += 1
    return correct / float(len(actual)) * 100.0
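# Example: accuracy_metric([0, 1, 1, 0], [0, 1, 0, 0]) returns 75.0,
# since three of the four predictions match the actual labels.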


def str_column_to_float(dataset, column):
    # Convert one column of the dataset from strings to floats, in place.
    for row in dataset:
        row[column] = float(row[column].strip())


def load_csv(filename):
    # Load a CSV file into a list of rows, skipping any blank lines.
    with open(filename, 'r', newline='') as file:
        dataset = [row for row in reader(file) if row]
    return dataset


def evaluate_algorithm(dataset, algorithm, n_folds, *args):
    # Estimate the accuracy of an algorithm using k-fold cross-validation.
    folds = cross_validation_split(dataset, n_folds)
    scores = list()
    for fold in folds:
        # Train on all folds except the current one.
        train_set = list(folds)
        train_set.remove(fold)
        train_set = sum(train_set, [])
        # Copy the held-out fold and hide its class labels.
        test_set = list()
        for row in fold:
            row_copy = list(row)
            test_set.append(row_copy)
            row_copy[-1] = None
        predicted = algorithm(train_set, test_set, *args)
        actual = [row[-1] for row in fold]
        accuracy = accuracy_metric(actual, predicted)
        scores.append(accuracy)
    return scores


def cross_validation_split(dataset, n_folds):
    # Randomly partition the dataset into n_folds folds of equal size,
    # sampling rows without replacement.
    dataset_split = list()
    dataset_copy = list(dataset)
    fold_size = int(len(dataset) / n_folds)
    for i in range(n_folds):
        fold = list()
        while len(fold) < fold_size:
            index = randrange(len(dataset_copy))
            fold.append(dataset_copy.pop(index))
        dataset_split.append(fold)
    return dataset_split
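# Note that fold_size uses integer division, so when the number of rows is not
# an exact multiple of n_folds the leftover rows are simply left out of the
# folds (e.g. 10 rows with n_folds=3 gives three folds of 3 rows, 1 row unused).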


def gini_index(groups, classes):
    # count all samples at split point
    n_instances = float(sum([len(group) for group in groups]))
    # sum weighted Gini index for each group
    gini = 0.0
    for group in groups:
        size = float(len(group))
        # avoid divide by zero
        if size == 0:
            continue
        score = 0.0
        # score the group based on the score for each class
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val) / size
            score += p * p
        # weight the group score by its relative size
        gini += (1.0 - score) * (size / n_instances)
    return gini
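# Worked example: gini_index([[[1, 1], [1, 0]], [[1, 1]]], [0, 1]) returns
# roughly 0.333 -- the first group is an even mix (1 - 0.5**2 - 0.5**2 = 0.5,
# weighted by 2/3 of the rows) and the second group is pure (contributes 0.0).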


def test_split(index, value, dataset):
    # Split the dataset rows into two groups based on an attribute value.
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right


def get_split(dataset):
    # Exhaustively check every attribute/value pair as a candidate split point
    # and keep the one with the lowest Gini index.
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    for index in range(len(dataset[0])-1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index':b_index, 'value':b_value, 'groups':b_groups}


def to_terminal(group):
    # Create a terminal node: predict the most common class in the group.
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)


def split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del node['groups']
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth+1)
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth+1)


def build_tree(train, max_depth, min_size):
    # Build the decision tree: find the root split, then split recursively.
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root


def decision_tree(train, test, max_depth, min_size):
    # Decision tree algorithm: fit a tree on the training set and classify
    # each row of the test set.
    tree = build_tree(train, max_depth, min_size)
    predictions = list()
    for row in test:
        prediction = predict(tree, row)
        predictions.append(prediction)
    return predictions


def predict(node, row):
    # Route the row down the tree until a terminal (non-dict) node is reached.
    if row[node['index']] < node['value']:
        if isinstance(node['left'], dict):
            return predict(node['left'], row)
        else:
            return node['left']
    else:
        if isinstance(node['right'], dict):
            return predict(node['right'], row)
        else:
            return node['right']


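# Illustrative sketch (the names below are not part of the original script):
# the tree helpers can be exercised directly on a tiny in-memory dataset in
# which the last value of each row is the class label.
def _toy_tree_demo():
    toy = [[2.7, 0], [1.4, 0], [3.3, 0], [7.4, 1], [9.0, 1], [7.6, 1]]
    tree = build_tree(toy, 1, 1)        # depth-1 stump split on the only feature
    return predict(tree, [6.5, None])   # 6.5 is below the split value, so 0

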
if __name__ == '__main__':
    seed(1)

    # The dataset is expected to be all-numeric CSV rows with the class label
    # in the last column.
    filename = 'dataset.csv'
    dataset = load_csv(filename)

    # convert every column, including the class label, to floats
    for i in range(len(dataset[0])):
        str_column_to_float(dataset, i)

    # evaluate the decision tree with 5-fold cross-validation
    n_folds = 5
    max_depth = 5
    min_size = 10
    scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
    print('Scores: %s' % scores)
    print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))