Advertisement
hatfulofhollow81

Untitled

Apr 19th, 2019
572
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.21 KB | None | 0 0
  1. # CS1656 Spring 2019 Assignment 5: Decision Trees
  2. # Author: Michael Korst (mpk44@pitt.edu) GH username: paranoidandroid81
  3.  
  4. import argparse, csv, re
  5. import numpy as np
  6.  
  7. # Object to store decision tree, operations
  8. class EpicTree:
  9.  
  10.   # initialize tree with tree root node + empty references, track all options
  11.   # so as to build tree later
  12.   def __init__(self, root):
  13.     self.root = root
  14.     self.root_ref = {}
  15.     self.options = []
  16.  
  17.   # add a rule associated with a variable option
  18.   def add_rule(self, parents, option, rule):
  19.     # use parents list to dig into ref dict and add rule
  20.     self.options.append(option) # save option
  21.     curr = self.root_ref
  22.     for parent in parents:
  23.       if parent not in curr.keys():
  24.         curr[parent] = {}
  25.       curr = curr[parent]
  26.     curr[option] = rule
  27.     return
  28.    
  29.   # add a variable to the tree
  30.   def add_ref(self, parents, ref):
  31.     # use parents list to dig into ref dict and add rule
  32.     if len(parents) == 1:
  33.       # if only one parent, connected directly to root, considered option
  34.       self.options.append(ref)
  35.     curr = self.root_ref
  36.     for parent in parents:
  37.       if parent not in curr.keys():
  38.         curr[parent] = {}
  39.       curr = curr[parent]
  40.     curr[ref] = {}
  41.     return
  42.  
  43.   # find a rule in the table, returns label (good or bad) of existent, False otherwise
  44.   def get_rule(self, parents):
  45.     # use parents list to dig into ref dict and find rule, return False if not there
  46.     curr = self.root_ref
  47.     for parent in parents:
  48.       if type(curr) is not dict or parent not in curr.keys():
  49.         return False
  50.       curr = curr[parent]
  51.     # get label, removed from parentheses, removed whitespace
  52.     # dict indicates no rule found
  53.     if type(curr) is dict:
  54.       return False
  55.     else:
  56.       return normalize_in(curr.split('(')[0])
  57.  
  58.   # helper to build data structure for a tree based on rules and options
  59.   def build_tree(self, rule_dict):
  60.     all_rules = list(rule_dict.keys())
  61.     self.options = list(map(lambda x: normalize_in(x), self.options))
  62.     # heuristic: last word in each rule will be the label (i.e. good or bad)
  63.     tree_out = []
  64.     for rule in all_rules:
  65.       raw_rule = rule
  66.       rule = raw_rule.split(',')
  67.       rule = list(map(lambda x: normalize_in(x), rule))
  68.       raw_label = normalize_in(rule[-1])
  69.       label = re.sub("'", '', raw_label) # get rid of single quotes
  70.       for idx, word in enumerate(rule):
  71.         word = re.sub("'", '', word) # get rid of single quotes
  72.         # go thru and build list
  73.         # if next is label, must add colon
  74.         if word != label:
  75.           if rule[idx + 1] == raw_label:
  76.             tree_out.append(f"{word}:")
  77.           else:
  78.             tree_out.append(word)
  79.         else:
  80.           tree_out.append(word)
  81.         if word == label:
  82.           # if label, must append count
  83.           tree_out.append(f"({rule_dict[raw_rule]})")
  84.       tree_out.append("\n")
  85.     # now must, add branches
  86.     split_tree = tree_out
  87.     tree_out = ' '.join(tree_out)
  88.     pos_pats = [] # store possible variable/option patterns for adding branches
  89.     for line in tree_out.split("\n"):
  90.       line = normalize_in(line)
  91.       words = line.split(' ')
  92.       i = 0
  93.       cur_pat = words[i:(i+2)] # take first two words, add as possible pattern
  94.       # add patterns if not last before label, will be new branches
  95.       while(len(cur_pat) >= 2 and ':' not in cur_pat[1]):
  96.         pos_pats.append(cur_pat)
  97.         i += 2
  98.         cur_pat = words[i:i + 2]
  99.     # replace all multi-levels with branches
  100.     p_2_done = {} # map each pattern to boolean if already done
  101.     for pat in pos_pats:
  102.       # continue if already parsed
  103.       pat_key = ' '.join(pat)
  104.       if pat_key in p_2_done.keys():
  105.         if p_2_done[pat_key]:
  106.           continue
  107.       pat = ' '.join(pat)
  108.       p_2_done[pat_key] = True
  109.       split_tree = re.split(pat, ' '.join(split_tree))
  110.       first = True
  111.       # go thru replace with correct branches
  112.       for idx, spl in enumerate(split_tree):
  113.         # if first idx is space, indicates was split
  114.         if spl[0] == ' ':
  115.           split_tree[idx] = re.sub(r"^\s", '', spl) # get rid of space
  116.           if first:
  117.             # first one special case
  118.             split_tree.insert(idx, f"{pat}\n|   ")
  119.             first = False
  120.           else:
  121.             split_tree[idx] = f"|   {spl}"
  122.     # get rid of beginning spaces
  123.     return ' '.join(split_tree)
  124.  
  125.   # toString method
  126.   def __str__(self):
  127.     str_tree = []
  128.     for k, v in self.root_ref.items():
  129.       str_tree.append(f"{k}=>{v}")
  130.     return ''.join(str_tree)
  131.  
  132.    
  133.  
  134.  
  135. # function to normalize tree input, i.e. remove \n, spaces from beginning or end, tabs, quotes
  136. def normalize_in(raw_in):
  137.   out = raw_in
  138.   out = re.sub('\n', '', out)
  139.   out = re.sub(r"^\s+|\s+$", '', out)
  140.   out = re.sub('\t', '', out)
  141.   out = re.sub('"', '', out)
  142.   out = re.sub(r"\[|\]", '', out) # remove brackets
  143.   return out
  144.  
  145. # recursive helper function to map each branch of tree
  146. def traverse_tree(raw_tree, dt, pos, parents):
  147.   for rt in raw_tree:
  148.     if not rt:
  149.       # empty string
  150.       return
  151.     rt = normalize_in(rt)
  152.     if '|' in rt:
  153.       # indicates >=2 levels, split by '|'
  154.       curr = rt.split('|')
  155.       # first index is the option for the top level
  156.       curr_option = normalize_in(curr[0])
  157.       dt.add_ref(parents, curr_option)
  158.       curr.pop(0) # don't need option anymore
  159.       # use heuristic of 1st being variable name
  160.       curr_lev = normalize_in(curr[0])
  161.       all_words = curr_lev.split(' ')
  162.       curr_lev = normalize_in(all_words[0])
  163.       # map branches
  164.       curr_pos = parents + [curr_option]
  165.       dt.add_ref(curr_pos, curr_lev)
  166.       # recurse with updated postion
  167.       traverse_tree(curr, dt, curr_lev, curr_pos + [curr_lev])
  168.     else:
  169.       # 1 level, just map option to rule on tree
  170.       curr = rt.split(':')
  171.       # remove instances of higher level variable in option
  172.       curr_option = normalize_in(curr[0].replace(pos, ''))
  173.       dt.add_rule(parents, curr_option, normalize_in(curr[1]))
  174.  
  175. # generates permutations
  176. # based on Python library definition of permutations
  177. def permutations(iterable, r=None):
  178.     pool = tuple(iterable)
  179.     n = len(pool)
  180.     r = n if r is None else r
  181.     if r > n:
  182.         return
  183.     indices = list(range(n))
  184.     cycles = list(range(n, n-r, -1))
  185.     yield tuple(pool[i] for i in indices[:r])
  186.     while n:
  187.         for i in reversed(range(r)):
  188.             cycles[i] -= 1
  189.             if cycles[i] == 0:
  190.                 indices[i:] = indices[i+1:] + indices[i:i+1]
  191.                 cycles[i] = n - i
  192.             else:
  193.                 j = cycles[i]
  194.                 indices[i], indices[-j] = indices[-j], indices[i]
  195.                 yield tuple(pool[i] for i in indices[:r])
  196.                 break
  197.         else:
  198.             return
  199.  
  200. # helper to generate combinations of variables for tree
  201. def gen_permuts(types):
  202.   permuts = []
  203.   # go thru all length permutations from 1 to len
  204.   for i in range(1, len(types) + 1):
  205.     permuts.extend(list(permutations(types, i)))
  206.   return permuts
  207.  
  208. ### BEGIN MAIN LOGIC ###
  209. # first, validate args
  210. parser = argparse.ArgumentParser()
  211. parser.add_argument('tree_filename')
  212. parser.add_argument('data_filename')
  213. args = vars(parser.parse_args()) # store args, convert to dict
  214.  
  215. # first, read in whole decision tree
  216. with open('./' + args['tree_filename']) as f:
  217.   raw_tree = f.read()
  218.  
  219. # remove unnecessary characters and split by levels and rules
  220. raw_tree = normalize_in(raw_tree)
  221. all_levs = {}
  222.  
  223. # heuristic: 1st word in tree will be top-level variable name
  224. all_words = raw_tree.split(' ')
  225. lev_name = normalize_in(all_words[0])
  226. dt = EpicTree(lev_name) # instantiate custom tree data struct
  227. # map out decision tree recursively for each branch
  228. all_branches = re.split(lev_name, raw_tree)
  229. for branch in all_branches:
  230.   traverse_tree([branch], dt, lev_name, [lev_name])
  231.  
  232. # next, read in test data
  233. test_headers = [] # list of data names
  234. test_rows = [] # list of all rows
  235. with open('./' + args['data_filename']) as csvf:
  236.   creader = csv.reader(csvf)
  237.   for idx, row in enumerate(creader):
  238.     row = list(map(lambda x: normalize_in(x), row))
  239.     if idx == 0:
  240.       # first row = headers
  241.       test_headers = row
  242.     else:
  243.       test_rows.append(row)
  244. # map headers for each row to the data
  245. test_zip = []
  246. for row in test_rows:
  247.   test_zip.append(dict(zip(test_headers, row)))
  248.  
  249. # now, evaluate data based on decision tree, find statistics for rules
  250. # first generate all permutations for possible rules
  251. permuts = gen_permuts(test_headers)
  252. no_match = 0 # check number where not found rules
  253. rule_count = {} # map each rule to count of times
  254. # for each row, check each permut for matching rule track
  255. for tz in test_zip:
  256.   found = False # track if matched
  257.   for permut in permuts:
  258.     parents = []
  259.     for var in permut:
  260.       parents.append(var)
  261.       parents.append(tz[var])
  262.     ret_rule = dt.get_rule(parents)
  263.     if not ret_rule:
  264.       # if False, not found
  265.       continue
  266.     else:
  267.       found = True
  268.       parents.append(ret_rule)
  269.       curr = normalize_in(str(parents))
  270.       if curr not in rule_count.keys():
  271.         rule_count[curr] = 0
  272.       rule_count[curr] += 1
  273.   if not found:
  274.     no_match += 1
  275.  
  276. # now just print out results in specified tree format
  277. r_2_label = {} # map rule to label for sorting by rule
  278. for rule in rule_count.keys():
  279.   # heuristic: last element will be by label
  280.   # we should sort by everything except label
  281.   rule = rule.split(', ')
  282.   curr = rule[:-1]
  283.   curr = normalize_in(str(curr))
  284.   r_2_label[curr] = normalize_in(rule[-1])
  285. sort_keys = sorted(r_2_label.keys())
  286. # after sort, we can print out final tree
  287. # form keys for rule_count dict, map key to count of rule
  288. sort_keys = [f"{key}, {r_2_label[key]}" for key in sort_keys]
  289. sort_dict = {k: rule_count[k] for k in sort_keys}
  290. stat_tree = dt.build_tree(sort_dict)
  291. stat_tree = stat_tree.split("\n")
  292. # make into lines, remove beginning spaces
  293. stat_tree = list(map(lambda x: re.sub(r"^\s+", '', x), stat_tree))
  294. # remove empty member at end (possibly)
  295. if not stat_tree[-1]:
  296.   stat_tree.pop(-1)
  297. if no_match > 0:
  298.   # if rows not matched, contained UNMATCHED line
  299.   stat_tree.append(f"UNMATCHED: {no_match}")
  300. stat_tree = '\n'.join(stat_tree)
  301. print(stat_tree)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement