Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # CS1656 Spring 2019 Assignment 5: Decision Trees
- # Author: Michael Korst (mpk44@pitt.edu) GH username: paranoidandroid81
- import argparse, csv, re
- import numpy as np
# Object to store decision tree, operations
class EpicTree:
    """Decision tree stored as nested dicts, plus helpers to query and print it.

    ``root_ref`` is a nested dict keyed alternately by whatever the callers
    pass in their ``parents`` paths; a leaf value is the rule string stored by
    ``add_rule``. ``options`` records every option seen while building.
    """

    def __init__(self, root):
        """Create an empty tree whose top-level variable name is *root*."""
        self.root = root
        self.root_ref = {}
        self.options = []

    def add_rule(self, parents, option, rule):
        """Store *rule* under *option*, creating the *parents* path as needed."""
        self.options.append(option)  # save option
        curr = self.root_ref
        for parent in parents:
            curr = curr.setdefault(parent, {})
        curr[option] = rule

    def add_ref(self, parents, ref):
        """Add an (empty) sub-tree named *ref* below the *parents* path.

        Any existing entry for *ref* at that position is replaced by ``{}``.
        """
        if len(parents) == 1:
            # if only one parent, connected directly to root, considered option
            self.options.append(ref)
        curr = self.root_ref
        for parent in parents:
            curr = curr.setdefault(parent, {})
        curr[ref] = {}

    def get_rule(self, parents):
        """Look up the rule reached by following *parents*.

        Returns the label text in front of the first ``(`` of the stored rule
        (normalized), or ``False`` when the path does not exist or ends on a
        sub-tree instead of a rule.
        """
        curr = self.root_ref
        for parent in parents:
            if not isinstance(curr, dict) or parent not in curr:
                return False
            curr = curr[parent]
        # a dict here means the path stopped on a variable/option, not a rule
        if isinstance(curr, dict):
            return False
        return normalize_in(curr.split('(')[0])

    def build_tree(self, rule_dict):
        """Render *rule_dict* as the textual tree and return it as one string.

        *rule_dict* maps a comma-separated rule string (tokens may be wrapped
        in single quotes; the last token is the label) to its count. Lines of
        the result are separated by ``"\\n"``; nested levels are prefixed
        with ``"| "``. As a side effect ``self.options`` is normalized.
        """
        self.options = [normalize_in(opt) for opt in self.options]
        # heuristic: last word in each rule will be the label (i.e. good or bad)
        tree_out = []
        for raw_rule, count in rule_dict.items():
            tokens = [normalize_in(tok) for tok in raw_rule.split(',')]
            last = len(tokens) - 1
            for idx, token in enumerate(tokens):
                word = token.replace("'", '')  # get rid of single quotes
                if idx == last:
                    # the label is identified by position, so an option value
                    # that happens to equal the label text is not mistaken
                    # for it; the label is followed by its count
                    tree_out.append(word)
                    tree_out.append(f"({count})")
                elif idx == last - 1:
                    # word directly before the label gets the colon
                    tree_out.append(f"{word}:")
                else:
                    tree_out.append(word)
            tree_out.append("\n")

        # now add branches
        split_tree = tree_out
        flat = ' '.join(tree_out)
        pos_pats = []  # possible variable/option patterns that open a branch
        for line in flat.split("\n"):
            words = normalize_in(line).split(' ')
            i = 0
            cur_pat = words[i:i + 2]  # take first two words as possible pattern
            # a pair whose second word carries ':' is the last one before the
            # label and therefore not a branch header
            while len(cur_pat) >= 2 and ':' not in cur_pat[1]:
                pos_pats.append(cur_pat)
                i += 2
                cur_pat = words[i:i + 2]

        # replace all multi-levels with branches
        done = set()  # patterns already expanded
        for pat_words in pos_pats:
            pat = ' '.join(pat_words)
            if pat in done:
                continue
            done.add(pat)
            # the pattern is literal text, so escape regex metacharacters
            split_tree = re.split(re.escape(pat), ' '.join(split_tree))
            first = True
            # NOTE: the list is intentionally modified while being enumerated;
            # after the insert the stripped chunk is visited again but no
            # longer starts with a space, so it is left alone.
            for idx, spl in enumerate(split_tree):
                # a leading space indicates the chunk followed a split point;
                # startswith() also tolerates the empty chunk produced when
                # the text begins with the pattern
                if spl.startswith(' '):
                    split_tree[idx] = spl[1:]  # get rid of space
                    if first:
                        # first one special case: re-emit the branch header
                        split_tree.insert(idx, f"{pat}\n| ")
                        first = False
                    else:
                        split_tree[idx] = f"| {spl}"
        return ' '.join(split_tree)

    def __str__(self):
        """Return every top-level ``key=>value`` pair concatenated."""
        return ''.join(f"{k}=>{v}" for k, v in self.root_ref.items())
# function to normalize tree input, i.e. remove \n, spaces from beginning or end, tabs, quotes
def normalize_in(raw_in):
    """Return *raw_in* cleaned up for parsing.

    Newlines are deleted first, then surrounding whitespace is trimmed, and
    finally every tab, double quote and square bracket is removed.
    """
    trimmed = raw_in.replace('\n', '').strip()
    # one pass that drops tabs, double quotes and both brackets
    return trimmed.translate(str.maketrans('', '', '\t"[]'))
# recursive helper function to map each branch of tree
def traverse_tree(raw_tree, dt, pos, parents):
    """Register every piece of *raw_tree* on the tree object *dt*.

    *raw_tree* is a list of text pieces, *pos* is the variable name of the
    current level and *parents* the path leading to it. A piece containing
    ``|`` describes a nested level and is handled recursively; any other
    piece is an ``option: rule`` pair stored with ``dt.add_rule``. The first
    empty piece stops processing of the whole list.
    """
    for piece in raw_tree:
        if not piece:
            # empty string: give up on the remaining pieces as well
            return
        piece = normalize_in(piece)

        if '|' not in piece:
            # 1 level, just map option to rule on tree
            fields = piece.split(':')
            # remove instances of higher level variable in option
            option = normalize_in(fields[0].replace(pos, ''))
            dt.add_rule(parents, option, normalize_in(fields[1]))
            continue

        # >= 2 levels, split by '|'; the first segment is the option that
        # leads into the nested level
        segments = piece.split('|')
        option = normalize_in(segments.pop(0))
        dt.add_ref(parents, option)
        # heuristic: first word of the next segment is the variable name
        first_segment = normalize_in(segments[0])
        variable = normalize_in(first_segment.split(' ')[0])
        branch_path = parents + [option]
        dt.add_ref(branch_path, variable)
        # recurse with updated position
        traverse_tree(segments, dt, variable, branch_path + [variable])
# generates permutations
# based on Python library definition of permutations
def permutations(iterable, r=None):
    """Yield successive *r*-length permutations of *iterable* as tuples.

    With ``r=None`` full-length permutations are produced; nothing is
    yielded when *r* exceeds the number of items.
    """
    items = tuple(iterable)
    size = len(items)
    if r is None:
        r = size
    if r > size:
        return
    order = list(range(size))
    countdown = list(range(size, size - r, -1))
    yield tuple(items[k] for k in order[:r])
    while size:
        for slot in reversed(range(r)):
            countdown[slot] -= 1
            if countdown[slot] != 0:
                # swap the slot with the element picked by the counter and
                # emit the new arrangement
                offset = countdown[slot]
                order[slot], order[-offset] = order[-offset], order[slot]
                yield tuple(items[k] for k in order[:r])
                break
            # counter ran out: rotate the tail left by one and reset it
            order[slot:] = order[slot + 1:] + order[slot:slot + 1]
            countdown[slot] = size - slot
        else:
            # every slot was exhausted, so all permutations are done
            return
# helper to generate combinations of variables for tree
def gen_permuts(types):
    """Return a list of all permutations of *types* of every length 1..len(types).

    Shorter permutations come first; each entry is a tuple.
    """
    return [
        combo
        for length in range(1, len(types) + 1)
        for combo in permutations(types, length)
    ]
### BEGIN MAIN LOGIC ###
# Reads a decision tree file and a CSV data file named on the command line,
# counts how many data rows match each rule of the tree, and prints the tree
# annotated with those counts (plus an UNMATCHED line when rows match nothing).

# first, validate args
parser = argparse.ArgumentParser()
parser.add_argument('tree_filename')
parser.add_argument('data_filename')
args = vars(parser.parse_args())  # store args, convert to dict

# first, read in whole decision tree
# (open the path exactly as given: prefixing './' would break absolute paths)
with open(args['tree_filename']) as f:
    raw_tree = f.read()
# remove unnecessary characters and split by levels and rules
raw_tree = normalize_in(raw_tree)
# heuristic: 1st word in tree will be top-level variable name
all_words = raw_tree.split(' ')
lev_name = normalize_in(all_words[0])
dt = EpicTree(lev_name)  # instantiate custom tree data struct
# map out decision tree recursively for each branch; the variable name is
# literal text, so it is escaped before being used as a split pattern
all_branches = re.split(re.escape(lev_name), raw_tree)
for branch in all_branches:
    traverse_tree([branch], dt, lev_name, [lev_name])

# next, read in test data
test_headers = []  # list of data names
test_rows = []  # list of all rows
# newline='' lets the csv module do its own newline handling
with open(args['data_filename'], newline='') as csvf:
    creader = csv.reader(csvf)
    for idx, row in enumerate(creader):
        row = [normalize_in(cell) for cell in row]
        if idx == 0:
            # first row = headers
            test_headers = row
        else:
            test_rows.append(row)
# map headers for each row to the data
test_zip = [dict(zip(test_headers, row)) for row in test_rows]

# now, evaluate data based on decision tree, find statistics for rules
# first generate all permutations for possible rules
permuts = gen_permuts(test_headers)
no_match = 0  # number of rows for which no rule was found
rule_count = {}  # map each rule to count of times
# for each row, check each permutation for a matching rule
for tz in test_zip:
    found = False  # track if matched
    for permut in permuts:
        # path alternates variable name and this row's value for it
        parents = []
        for var in permut:
            parents.append(var)
            parents.append(tz[var])
        ret_rule = dt.get_rule(parents)
        if not ret_rule:
            # if False, not found
            continue
        found = True
        parents.append(ret_rule)
        # key is the list's repr with brackets/double quotes stripped
        curr = normalize_in(str(parents))
        rule_count[curr] = rule_count.get(curr, 0) + 1
    if not found:
        no_match += 1

# now just print out results in specified tree format
r_2_label = {}  # map rule to label for sorting by rule
for rule in rule_count:
    # heuristic: last element will be the label
    # we should sort by everything except label
    rule = rule.split(', ')
    curr = normalize_in(str(rule[:-1]))
    r_2_label[curr] = normalize_in(rule[-1])
# after sort, we can print out final tree
# form keys for rule_count dict, map key to count of rule
sort_keys = [f"{key}, {r_2_label[key]}" for key in sorted(r_2_label)]
sort_dict = {k: rule_count[k] for k in sort_keys}
stat_tree = dt.build_tree(sort_dict)
# make into lines, remove beginning spaces
stat_tree = [re.sub(r"^\s+", '', line) for line in stat_tree.split("\n")]
# remove empty member at end (possibly)
if not stat_tree[-1]:
    stat_tree.pop(-1)
if no_match > 0:
    # if rows not matched, add UNMATCHED line
    stat_tree.append(f"UNMATCHED: {no_match}")
stat_tree = '\n'.join(stat_tree)
print(stat_tree)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement