arma.py
# CS1656 Spring 2019 Assignment 3: Association Rule Mining
# Author: Michael Korst (mpk44@pitt.edu) GH username: paranoidandroid81

import csv, os, argparse, itertools

# helper to generate unique permutations for 2-partitions of varying sizes
# returns dict mapping partition size to the set of partitions of that size
def gen_permuts(combos, size):
  # split each permutation into two parts at every possible cut point
  size_to_part = {}
  for permut in combos:
    for i in range(size - 1):
      part_1 = permut[:(i + 1)] # generate partition 1 of size i + 1
      part_2 = permut[(i + 1):] # generate partition 2 of size (size - 1 - i)
      # sort and join into a string for hashing
      part_1 = ''.join(sorted(part_1))
      part_2 = ''.join(sorted(part_2))
      if len(part_1) not in size_to_part.keys():
        size_to_part[len(part_1)] = set()
      size_to_part[len(part_1)].add(part_1)
      if len(part_2) not in size_to_part.keys():
        size_to_part[len(part_2)] = set()
      size_to_part[len(part_2)].add(part_2)
  return size_to_part

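# Example (illustration only): gen_permuts([('C', 'A', 'B')], 3) cuts the single
# permutation at each position, yielding {1: {'B', 'C'}, 2: {'AB', 'AC'}}, i.e.
# partition length mapped to the set of sorted, joined partition strings.
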
# helper to run thru 1 iteration of apriori association rule searching
def apriori_assoc(vfis, size, input_sets, num_trans, min_sup, min_conf):
  # first generate possible permutations based on each item in vfis
  rule_to_stats = {} # map found rules to (support, confidence)
  v_combos = [] # list of lists of combos/permutations for each vfi
  for v_key in vfis.keys():
    vfi_items = []
    split_key = v_key.split(',') # split by comma
    for item in split_key:
      vfi_items.append(item)
    vfi_items = set(vfi_items) # remove repeats
    v_combos.append(list(itertools.permutations(vfi_items, size)))
  for combo in v_combos:
    item_key = combo[0]
    item_key = ','.join(item_key) # string key for each itemset
    whole_sup = get_sup(combo, input_sets, num_trans) # get support of the whole combo
    s2p = gen_permuts(combo, size) # get size-to-partition mappings
    for part_size in s2p.keys():
      # go thru each possible size, create partitions
      part2_size = size - part_size # size of 2nd partition = size - size of 1st
      for part_1 in s2p[part_size]:
        part_key = ','.join(part_1)
        part_sup = get_sup([part_1], input_sets, num_trans)
        for part_2 in s2p[part2_size]:
          # as long as no element is repeated in either part, we can check the rule
          if part_2 not in part_1 and part_1 not in part_2:
            part2_key = ','.join(part_2)
            conf = float(whole_sup[item_key] / part_sup[part_key])
            # if the rule meets min support and min confidence, add it
            if whole_sup[item_key] >= min_sup and conf >= min_conf:
              conf_key = f"{part_key},'=>',{part2_key}"
              rule_to_stats[conf_key] = (whole_sup[item_key], conf)

  return rule_to_stats

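# Example (illustration only): with vfis = {'A,B': 0.5} and size = 2, the
# candidate rules are A => B and B => A; each is kept when the itemset support
# meets min_sup and support(A,B) / support(antecedent) >= min_conf, and is
# stored under a key such as "A,'=>',B" with value (support, confidence).
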
# helper to run thru 1 iteration of apriori frequency pruning
def apriori_prune(cands, size, input_sets, num_trans, min_sup, min_conf):
  # find support % for each candidate
  cand_to_sup = get_sup(cands, input_sets, num_trans)
  # keep only support percentages >= min_sup (verified frequent)
  vfis = {k: v for k, v in cand_to_sup.items() if v >= min_sup}
  # generate next candidates based on vfis
  cands = gen_next_cands(vfis, size)
  # return tuple of (verified frequent itemsets, next candidates)
  return (vfis, cands)

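# Example (illustration only): if cands = [('A',), ('B',), ('C',)], min_sup = 0.5,
# and only A and B appear in at least half of the transactions, apriori_prune
# returns the verified frequent itemsets {'A': ..., 'B': ...} together with the
# next candidate list, here just the single pair of A and B.
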
# helper to count appearances of each itemset in the input data and calculate support percentages
def get_sup(items, input_sets, num_trans):
  items_to_sup = {} # map each item to its support count (later percentage)
  # now go thru each item, counting appearances in input data
  for item in items:
    item_key = ','.join(item) # string key for each item
    for trans in input_sets:
      if set(item) <= trans:
        if item_key not in items_to_sup.keys():
          items_to_sup[item_key] = 1
        else:
          items_to_sup[item_key] += 1
  # now convert support counts to percentages
  items_to_sup.update((x, float(y / num_trans)) for x, y in items_to_sup.items())
  return items_to_sup

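# Example (illustration only): with input_sets = [{'A', 'B'}, {'A'}, {'B', 'C'}]
# and num_trans = 3, get_sup([('A',), ('A', 'B')], input_sets, 3) returns
# approximately {'A': 0.6667, 'A,B': 0.3333}; itemsets that never occur are
# simply absent from the result rather than mapped to 0.
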
# helper to generate plausible candidate itemsets of size + 1 based on vfis
def gen_next_cands(vfis, size):
  # first gather the individual items appearing in vfis
  vfi_items = []
  for v_key in vfis.keys():
    split_key = sorted(v_key.split(',')) # split by comma, sort
    for item in split_key:
      vfi_items.append(item)
  vfi_items = set(vfi_items) # get rid of repeats
  all_possible = list(itertools.combinations(vfi_items, (size + 1)))
  # convert all vfis to sets
  vfi_sets = []
  for v_key in vfis.keys():
    split_key = v_key.split(',') # split on commas
    vfi_sets.append(set(split_key))
  # now, for each possible candidate, ensure all of its size-sized subsets are in vfis
  next_cands = []
  for poss in all_possible:
    all_subs = list(itertools.combinations(poss, size))
    found_subs = 0 # increment for each found subset
    for sub in all_subs:
      # iterate through each subset of a candidate, check if it is in a vfi
      for v_set in vfi_sets:
        if set(sub) <= v_set:
          found_subs += 1
          break
    if found_subs == len(all_subs):
      next_cands.append(poss)
  return next_cands

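# Example (illustration only): gen_next_cands({'A,B': 0.5, 'B,C': 0.5, 'A,C': 0.5}, 2)
# proposes the single size-3 candidate containing A, B, and C, since every 2-item
# subset is itself verified frequent; if 'A,C' were not frequent, the candidate
# would be dropped (the downward-closure / apriori property).
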
# helper to format floats to 4 decimal places
def round_floats(num):
  return f'{num:.4f}'

# --- BEGIN MAIN CODE ---

# add expected args, parse from command line
parser = argparse.ArgumentParser()
parser.add_argument('input_filename')
parser.add_argument('output_filename')
parser.add_argument('min_support_percentage', type=float)
parser.add_argument('min_confidence', type=float)
args = vars(parser.parse_args()) # store args, convert to dict

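# Example invocation (illustration only; file names are placeholders):
#   python arma.py input.csv output.csv 0.5 0.7
# i.e. input file, output file, minimum support percentage, minimum confidence.
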
# map each transaction_id to a list of items
input_data = {}
# now open input, read in
with open('./' + args['input_filename']) as csvf:
  creader = csv.reader(csvf)
  for row in creader:
    # for each line, map trans id to list of items
    curr_id = row[0]
    # first column is trans id, remaining columns are items
    if curr_id not in input_data.keys():
      input_data[curr_id] = []
    for item in row[1:]:
      input_data[curr_id].append(item)

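# Expected input format (illustration only, inferred from the reading code above):
# one transaction per row, first column the transaction id, remaining columns the
# items, e.g.
#   1,A,B,C
#   2,A,C
#   3,B,D
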
# begin Apriori algorithm
# first build up a list of all items, used later to generate candidate combinations
# also build up a list of all transactions in set notation
all_items = []
all_sets = []
for trans_id in input_data.keys():
  curr = set()
  for item in input_data[trans_id]:
    curr.add(item)
    all_items.append(item)
  all_sets.append(curr)

all_items = set(all_items) # convert to set to remove repeated elements
i = 1 # start with all combinations of length 1
next_frequent = True # determines when to terminate algorithm
# first run is all sets of size one, generate all combinations
combos = list(itertools.combinations(all_items, i))
all_vfis = [] # keep master list of verified frequent itemsets
all_rules = [] # master list of verified association rules
while next_frequent:
  # find vfis + candidates
  prune_rv = apriori_prune(combos, i, all_sets, len(input_data.keys()),
      args['min_support_percentage'], args['min_confidence'])
  # index 0 of return = vfis, index 1 = next candidates
  all_vfis.append(prune_rv[0])
  # find association rules
  rules_rv = apriori_assoc(prune_rv[0], i, all_sets, len(input_data.keys()),
      args['min_support_percentage'], args['min_confidence'])
  if len(rules_rv.keys()) > 0:
    # if not empty, append rules
    all_rules.append(rules_rv)
  if len(prune_rv[1]) == 0:
    # when no next candidate itemsets are generated, we stop
    next_frequent = False
  else:
    i += 1 # move on to next size if there are still candidates
    combos = prune_rv[1]

# sort the items within each vfi key before printing to csv
sort_vfis = []
for type_vfi in all_vfis:
  sort_vfis.append({','.join(sorted(k.split(','))): v for k, v in type_vfi.items()})

# now begin printing to csv
with open('./' + args['output_filename'], 'w', newline='') as csvf:
  cwriter = csv.writer(csvf)
  for type_vfi in sort_vfis:
    sort_keys = sorted(type_vfi.keys())
    for skey in sort_keys:
      row = []
      row.append('S')
      row.append(round_floats(type_vfi[skey]))
      skey = skey.split(',')
      for letter in skey:
        row.append(letter)
      cwriter.writerow(row)
  # now print rules
  for rule in all_rules:
    sort_rules = sorted(rule.keys())
    for srule in sort_rules:
      row = []
      row.append('R')
      # support, then confidence
      row.append(round_floats(rule[srule][0]))
      row.append(round_floats(rule[srule][1]))
      # write out each piece of the rule (items and the '=>' separator)
      srule = srule.split(',')
      for letter in srule:
        row.append(letter)
      cwriter.writerow(row)
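
# Example output rows (illustration only): 'S' rows list support then the itemset,
# 'R' rows list support, confidence, then the rule with a literal '=>' marker, e.g.
#   S,0.6667,A
#   S,0.3333,A,B
#   R,0.3333,0.5000,A,'=>',B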