AlessandroG

SpamDetector

Oct 19th, 2018
from nltk import PorterStemmer
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from prettytable import PrettyTable
from texttable import Texttable
import math
import string
from termcolor import colored

def get_data(file_name):
    """Read the labelled SMS corpus and split it into messages and 0/1 targets."""
    with open(file_name, "r") as f:
        line_list = f.readlines()

    data = []
    target = []

    for line in line_list:
        # spam messages get label 1
        if line.startswith("spam"):
            data.append(line[5:])
            target.append(1)
        # ham messages get label 0
        elif line.startswith("ham"):
            data.append(line[4:])
            target.append(0)
    return data, target

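# Each input line is expected to follow the UCI "SMS Spam Collection" format,
# i.e. a label, a tab, then the raw message (illustrative, abridged examples):
#
#   ham\tGo until jurong point, crazy..
#   spam\tFree entry in 2 a wkly comp to win FA Cup final tkts...
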
class SpamDetector(object):

    def tokenize(self, text):
        # split into words
        tokens = word_tokenize(text, 'english')
        # convert to lower case
        tokens = [w.lower() for w in tokens]
        # remove punctuation from each word
        table = str.maketrans('', '', string.punctuation)
        stripped = [w.translate(table) for w in tokens]
        # keep only alphabetic tokens longer than one character
        words = [word for word in stripped if word.isalpha() and len(word) > 1]
        # filter out stop words
        stop_words = set(stopwords.words('english'))
        words = [w for w in words if w not in stop_words]
        # stem each remaining word
        porter = PorterStemmer()
        stemmed = [porter.stem(word) for word in words]

        return stemmed

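    # Illustrative only (exact stems depend on the installed NLTK data/version):
    #   tokenize("The prizes were claimed quickly") -> ['prize', 'claim', 'quickli']
    # (stopwords, punctuation and one-letter tokens removed, Porter stems applied)
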
    def get_word_counts(self, words):
        word_counts = {}
        for word in words:
            word_counts[word] = word_counts.get(word, 0.0) + 1.0
        return word_counts

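    # For example:
    #   get_word_counts(['prize', 'claim', 'prize']) -> {'prize': 2.0, 'claim': 1.0}
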
    def train(self, X, Y):
        self.num_messages = {}
        self.priors_prob = {}
        self.word_counts = {}
        self.all_words = {}
        self.vocab = set()

        n = len(X)
        self.num_messages['spam'] = sum(1 for label in Y if label == 1)  # number of spam messages
        self.num_messages['ham'] = sum(1 for label in Y if label == 0)  # number of ham messages
        self.priors_prob['spam'] = self.num_messages['spam'] / n  # prior probability of spam
        self.priors_prob['ham'] = self.num_messages['ham'] / n  # prior probability of ham
        self.word_counts['spam'] = {}
        self.word_counts['ham'] = {}
        self.all_words['spam'] = 0
        self.all_words['ham'] = 0

        for x, y in zip(X, Y):
            c = 'spam' if y == 1 else 'ham'
            counts = self.get_word_counts(self.tokenize(x))
            for word, count in counts.items():
                if word not in self.vocab:
                    self.vocab.add(word)  # vocabulary of every word seen in training
                if word not in self.word_counts[c]:
                    self.word_counts[c][word] = 0.0  # per-class vocabularies with occurrence counts

                self.word_counts[c][word] += count
                self.all_words[c] += count  # total word count per class

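    # predict() below scores each message with the standard log-space
    # multinomial Naive Bayes decision rule with Laplace (add-one) smoothing:
    #   score(c) = log P(c) + sum over words w of count(w) * log((N_wc + 1) / (N_c + |V|))
    # where N_wc is how often word w appeared in class c during training,
    # N_c is the total word count of class c, and |V| is the vocabulary size.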
    def predict(self, X):
        result_multinomial = []

        for x in X:  # iterate over every message in the dataset
            counts = self.get_word_counts(self.tokenize(x))  # tokenize the message and count word occurrences
            # start each class score from its log prior
            spam_score = math.log(self.priors_prob['spam'])
            ham_score = math.log(self.priors_prob['ham'])
            for word, count in counts.items():  # iterate over every word in counts
                if word not in self.vocab:
                    continue  # words never seen during training are skipped
                # add the Laplace-smoothed log likelihood of the word under each class
                spam_score += count * math.log((self.word_counts['spam'].get(word, 0.0) + 1)
                                               / (self.all_words['spam'] + len(self.vocab)))
                ham_score += count * math.log((self.word_counts['ham'].get(word, 0.0) + 1)
                                              / (self.all_words['ham'] + len(self.vocab)))

            if spam_score > ham_score:  # record whether each message is predicted spam or ham
                result_multinomial.append(1)
            else:
                result_multinomial.append(0)

        return result_multinomial

    def classify(self, true, predicted):
        # compare true and predicted labels, then print a confusion
        # matrix and the usual summary metrics
        true_positives = 0
        true_negatives = 0
        false_positives = 0
        false_negatives = 0

        for x, y in zip(true, predicted):
            if x == 1 and y == 1:
                true_positives += 1
            elif x == 0 and y == 0:
                true_negatives += 1
            elif x == 1 and y == 0:
                false_negatives += 1
            elif x == 0 and y == 1:
                false_positives += 1

        total = true_positives + true_negatives + false_positives + false_negatives
        error_rate = (false_negatives + false_positives) / total
        accuracy = 1 - error_rate
        # guard the ratios against empty denominators
        recall = (true_positives / (true_positives + false_negatives)
                  if true_positives + false_negatives else 0.0)
        precision = (true_positives / (true_positives + false_positives)
                     if true_positives + false_positives else 0.0)
        f1_score = (2 * precision * recall / (precision + recall)
                    if precision + recall else 0.0)

        t = PrettyTable(['', 'predicted yes', 'predicted no'])
        t.add_row(['actual yes',
                   colored("TP: " + str(true_positives), "green"),
                   colored("FN: " + str(false_negatives), "red")])
        t.add_row(['actual no',
                   colored("FP: " + str(false_positives), "red"),
                   colored("TN: " + str(true_negatives), "green")])

        print("Confusion matrix")
        print(t.get_string())

        t = Texttable()
        t.add_rows([['metric', 'value'],
                    ['Error rate:', error_rate],
                    ['Accuracy:', accuracy],
                    ['Recall:', recall],
                    ['Precision:', precision],
                    ['F1-score:', f1_score]])
        print(t.draw())


if __name__ == '__main__':
    data, target = get_data("SMSSpamCollection")
    MNB = SpamDetector()
    MNB.train(data, target)

    predicted = MNB.predict(data)
    true = target

    MNB.classify(true, predicted)
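
# Note: the script above trains and evaluates on the same data, so the printed
# metrics are optimistic. A minimal held-out evaluation sketch (assumes
# scikit-learn is available; it is not used anywhere else in this script):
#
#     from sklearn.model_selection import train_test_split
#
#     X_train, X_test, y_train, y_test = train_test_split(
#         data, target, test_size=0.2, random_state=42)
#     MNB.train(X_train, y_train)
#     MNB.classify(y_test, MNB.predict(X_test))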