Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import collections
- import itertools
- import time
# Directory that contains this script
PWD = os.path.dirname(os.path.realpath(__file__))
# Path of the source transaction data (one whitespace-separated transaction
# per line); joined with os.path.join so it is valid on every platform
Filename = os.path.join(PWD, 'mushroom.dat')
# Minimum support / confidence thresholds
MinSup = 0.1
MinConf = 0.8
# Maximum number of items in one rule (antecedent + consequent)
TargetLevel = 5
- #===================================================================
- # ** Apriori class
- #===================================================================
class Aprio:
    """Frequent-itemset counter and association-rule finder.

    Every transaction is expanded into all of its item combinations of size
    1..target_level, the combinations are counted, and the ones whose support
    reaches ``min_sup`` are kept.  Rules ``a => b`` whose confidence reaches
    ``min_conf`` are then printed and stored in ``self.result``.

    Attributes set by ``load_data``:
        data_size:  number of (non-blank) transactions read.
        appearance: list indexed by itemset size; each entry maps a sorted
                    item tuple to the number of transactions containing it.
        support:    same layout, holding support ratios of itemsets that
                    reached ``min_sup``.
        result:     list indexed by rule size (len(a) + len(b)); each entry
                    is a list of ``"a => b"`` strings.
    """

    #-----------------------------------------------------------------
    # * Object Initialization
    #-----------------------------------------------------------------
    def __init__(self, min_sup, min_conf, data_filename, target_level=None):
        """Store the thresholds, then load the data and mine the rules.

        Args:
            min_sup: minimum support ratio for an itemset to be kept.
            min_conf: minimum confidence ratio for a rule to be reported.
            data_filename: path of the transaction file to read.
            target_level: maximum number of items in a rule.  When omitted,
                the module-level ``TargetLevel`` is used, as before.
        """
        self.min_sup = min_sup
        self.min_conf = min_conf
        # Fall back to the module constant so existing callers are unaffected.
        self.target_level = TargetLevel if target_level is None else target_level
        self.load_data(data_filename)

    #-----------------------------------------------------------------
    # * Load data and start
    #-----------------------------------------------------------------
    def load_data(self, data_filename):
        """Count itemsets in ``data_filename``, then filter and mine rules.

        Each line is parsed as whitespace-separated integers.  Blank lines
        are skipped so they are not counted as transactions.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a token on a line is not an integer.
        """
        levels = self.target_level + 1
        # Index 0 of each list is unused; index i holds data for i-item sets.
        self.data_size = 0
        self.appearance = [collections.Counter() for _ in range(levels)]
        self.support = [{} for _ in range(levels)]
        self.result = [[] for _ in range(levels)]

        # Open the file that was actually requested, not the module default.
        with open(data_filename, 'r') as file:
            for line in file:
                tokens = line.split()
                if not tokens:
                    continue
                # debug: print line processed
                print(self.data_size)
                # Sorting the transaction once makes every combination come
                # out already sorted, so equal itemsets share one key.
                items = sorted(int(x) for x in tokens)
                for size in range(1, levels):
                    self.appearance[size].update(
                        itertools.combinations(items, size))
                self.data_size += 1

        self.fliter_elements()
        self.find_all_association()

    #-----------------------------------------------------------------
    # * Filter unnecessary elements
    #-----------------------------------------------------------------
    def fliter_elements(self):
        """Fill ``self.support`` with itemsets whose support >= ``min_sup``."""
        total = self.data_size
        for size in range(1, self.target_level + 1):
            kept = {}
            for itemset, count in self.appearance[size].items():
                ratio = count / total
                if ratio >= self.min_sup:
                    kept[itemset] = ratio
            self.support[size] = kept

    #-----------------------------------------------------------------
    # * Find all associations
    #-----------------------------------------------------------------
    def find_all_association(self):
        """Run ``find_association`` for every (a, b) size pair that fits.

        The antecedent size runs from 1 to target_level - 1 and the
        consequent size is bounded so that both together never exceed
        target_level.
        """
        for a_level in range(1, self.target_level):
            for b_level in range(1, self.target_level + 1 - a_level):
                self.find_association(a_level, b_level)

    #-----------------------------------------------------------------
    # * Find association of given level
    #-----------------------------------------------------------------
    def find_association(self, a_level, b_level):
        """Report rules ``a => b`` with ``a_level`` and ``b_level`` items.

        Confidence is support(a union b) / support(a).  Qualifying rules are
        printed together with their confidence and appended to
        ``self.result[a_level + b_level]``.
        """
        level = a_level + b_level
        joint_support = self.support[level]
        for set_a, sup_a in self.support[a_level].items():
            for set_b in self.support[b_level]:
                union = tuple(sorted(set(set_a).union(set_b)))
                # Skip pairs that share an item (union too short) and pairs
                # whose combined itemset did not reach the support threshold.
                if len(union) != level or union not in joint_support:
                    continue
                conf = joint_support[union] / sup_a
                if conf >= self.min_conf:
                    text = "{} => {}".format(str(set_a), str(set_b))
                    print(text, conf)
                    self.result[level].append(text)
# Time the whole run: the Aprio constructor loads the data and mines the rules.
# perf_counter is a monotonic clock, so the measurement is not disturbed by
# system clock adjustments.
cur = time.perf_counter()
aprio = Aprio(MinSup, MinConf, Filename)
t = time.perf_counter() - cur
# Results file next to this script; os.path.join keeps the path portable.
OutF = os.path.join(PWD, "out.txt")
with open(OutF, 'w') as file:
    print("Calculation time taken: {}".format(t))
    # Rules need at least two items, so levels start at 2.
    for i in range(2, TargetLevel + 1):
        for s in aprio.result[i]:
            file.write(s + '\n')
        # After the rules of each level, write how many there were.
        file.write(str(len(aprio.result[i])) + '\n')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement