Advertisement
LunaeStellsr

Apriori

Oct 24th, 2018
243
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.68 KB | None | 0 0
  1. import os
  2. import collections
  3. import itertools
  4. import time
  5.  
  6. # Present Working Dictionary
  7. PWD      = os.path.dirname(os.path.realpath(__file__))
  8.  
  9. # Source Data Path
  10. Filename = PWD + '\\mushroom.dat'
  11.  
  12. # Minimum Support/Confidence thereshold
  13. MinSup   = 0.1
  14. MinConf  = 0.8
  15.  
  16. # Max elements of relationship
  17. TargetLevel = 5
  18.  
  19. #===================================================================
  20. # ** Apriori class
  21. #===================================================================
  22. class Aprio:
  23.   #-----------------------------------------------------------------
  24.   # * Object Initialization
  25.   #-----------------------------------------------------------------
  26.   def __init__(self, min_sup, min_conf, data_filename):
  27.     self.min_sup    = min_sup
  28.     self.min_conf   = min_conf
  29.     self.load_data(data_filename)
  30.   #-----------------------------------------------------------------
  31.   # * Load data and start
  32.   #-----------------------------------------------------------------
  33.   def load_data(self, data_filename):
  34.  
  35.     # Open source data file
  36.     # level = elements count in relationship set
  37.     with open(Filename, 'r') as file:
  38.       # Raw data lines count
  39.       self.data_size    = 0
  40.  
  41.       # Array of dict: Appearance count of elements itemset
  42.       self.appearance   = [{} for i in range(TargetLevel + 1)]  
  43.  
  44.       # Array of dict: supprt of elements in each level
  45.       self.support      = [{} for i in range(TargetLevel + 1)]  
  46.  
  47.       # Array of array: result array
  48.       self.result       = [[] for i in range(TargetLevel + 1)]
  49.  
  50.       # Iterate the data by lines
  51.       for line in file:
  52.         # debug: print line processed
  53.         print(self.data_size)
  54.  
  55.         # load line and convert to list
  56.         ar = [int(x) for x in line.split()]
  57.  
  58.         # Process each item in line
  59.         for i in range(1, TargetLevel + 1):
  60.           # Generate possible itemset combined by current line elements and add its counter
  61.           for element in itertools.combinations(ar, i):
  62.             element = tuple(sorted(element))
  63.             if element in self.appearance[i]:
  64.               self.appearance[i][element] += 1
  65.             else:
  66.               self.appearance[i][element] = 1
  67.           #end of items combination
  68.         #end item size
  69.         self.data_size += 1
  70.       #end each line
  71.  
  72.       self.fliter_elements()
  73.     #end open file
  74.    
  75.     # start
  76.     self.find_all_association()
  77.   #end load_data
  78.   #-----------------------------------------------------------------
  79.   # * Fliter unecessary elements
  80.   #-----------------------------------------------------------------
  81.   def fliter_elements(self):
  82.     for i in range(1, TargetLevel + 1):
  83.       # Record supports, only the one higher than thereshold will be recorded
  84.       self.support[i] = {k: v/self.data_size for k, v in self.appearance[i].items() if v/self.data_size >= self.min_sup}
  85.  
  86.   #-----------------------------------------------------------------
  87.   # * Find all associations
  88.   #-----------------------------------------------------------------
  89.   def find_all_association(self):
  90.     # i: item number for a (a => n)
  91.     for i in range(1, TargetLevel):
  92.       # j: item number for b
  93.       for j in range(1, TargetLevel + 1 - i):
  94.         self.find_association(i, j)
  95.  
  96.   #-----------------------------------------------------------------
  97.   # * Find association of given level
  98.   #-----------------------------------------------------------------
  99.   def find_association(self, a_level, b_level):
  100.     # item number of a + b
  101.     level = a_level + b_level
  102.    
  103.     # all possbible set that contains <a_level> items
  104.     for set_a in self.support[a_level]:
  105.       # all possbible set that contains <b_level> items
  106.       for set_b in self.support[b_level]:
  107.  
  108.         # convert to the set that contains both uniq items
  109.         union = tuple(sorted(set().union(set_a, set_b)))
  110.  
  111.         # continue if any item is duplicated or support not reached thereshold
  112.         if len(union) != level or (union not in self.support[level]):
  113.           continue
  114.  
  115.         # calculate the confidence
  116.         conf = self.support[level][union] / self.support[a_level][set_a]
  117.  
  118.         # push to result if thereshold reached
  119.         if conf >= self.min_conf:
  120.           text = "{} => {}".format(str(set_a), str(set_b))
  121.           print(text, conf)
  122.           self.result[level].append(text)
  123.        
  124.   #end
  125. #end
  126.  
  127. cur = time.time()
  128. aprio = Aprio(MinSup, MinConf, Filename)
  129. t = time.time() - cur
  130.  
  131. OutF = PWD + "\\out.txt"
  132. with open(OutF, 'w') as file:
  133.   print("Calculation time taken: {}".format(t))
  134.   for i in range(2, TargetLevel + 1):
  135.     for s in aprio.result[i]:
  136.       file.write(s + '\n')
  137.     file.write(str(len(aprio.result[i])) + '\n')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement