Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import math
import string

from nltk import PorterStemmer
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from prettytable import PrettyTable
from termcolor import colored
from texttable import Texttable
def get_data(file_name):
    """Load a labeled SMS dataset in the SMSSpamCollection format.

    Each line is expected to start with a label, ``spam`` or ``ham``,
    followed by a single separator character and the message text.
    Lines with any other prefix are silently skipped (same as the
    original behavior).

    Parameters:
        file_name: path of the dataset file to read.

    Returns:
        (data, target): parallel lists where ``data[i]`` is the message
        text (trailing newline preserved) and ``target[i]`` is 1 for
        spam, 0 for ham.
    """
    data = []
    target = []
    # "with" guarantees the file is closed even if a read fails.
    with open(file_name, "r") as f:
        for line in f:
            if line.startswith("spam"):
                data.append(line[5:])  # skip "spam" + separator char
                target.append(1)
            elif line.startswith("ham"):
                data.append(line[4:])  # skip "ham" + separator char
                target.append(0)
    return data, target
class SpamDetector(object):
    """Multinomial Naive Bayes spam classifier with Laplace smoothing.

    Usage: ``train(X, Y)`` with messages and 1/0 (spam/ham) labels,
    then ``predict(X)`` to classify and ``classify(true, predicted)``
    to print a confusion matrix and summary metrics.
    """

    def tokenize(self, text):
        """Normalize *text* into a list of stemmed, lowercase tokens.

        Pipeline: word-tokenize, lowercase, strip punctuation from each
        token, keep only alphabetic tokens longer than one character,
        drop English stop words, then Porter-stem.
        """
        tokens = word_tokenize(text, 'english')
        tokens = [w.lower() for w in tokens]
        # Remove punctuation characters embedded in each token.
        table = str.maketrans('', '', string.punctuation)
        stripped = [w.translate(table) for w in tokens]
        # Keep only purely alphabetic tokens of length > 1.
        words = [word for word in stripped if word.isalpha() and len(word) > 1]
        stop_words = set(stopwords.words('english'))
        words = [w for w in words if w not in stop_words]
        porter = PorterStemmer()
        return [porter.stem(word) for word in words]

    def get_word_counts(self, words):
        """Return a dict mapping each word to its occurrence count (float)."""
        word_counts = {}
        for word in words:
            word_counts[word] = word_counts.get(word, 0.0) + 1.0
        return word_counts

    def train(self, X, Y):
        """Fit class priors and per-class word counts.

        Parameters:
            X: iterable of message strings.
            Y: iterable of labels, 1 for spam and 0 for ham.

        Sets: ``priors_prob``, ``word_counts`` (per-class word
        frequencies), ``all_words`` (per-class token totals) and
        ``vocab`` (the global training vocabulary).
        """
        self.num_messages = {}
        self.priors_prob = {}
        self.word_counts = {}
        self.all_words = {}
        self.vocab = set()
        n = len(X)
        self.num_messages['spam'] = sum(1 for label in Y if label == 1)
        self.num_messages['ham'] = sum(1 for label in Y if label == 0)
        # Prior probability of each class.
        self.priors_prob['spam'] = self.num_messages['spam'] / n
        self.priors_prob['ham'] = self.num_messages['ham'] / n
        self.word_counts['spam'] = {}
        self.word_counts['ham'] = {}
        self.all_words['spam'] = 0
        self.all_words['ham'] = 0
        for x, y in zip(X, Y):
            c = 'spam' if y == 1 else 'ham'
            counts = self.get_word_counts(self.tokenize(x))
            for word, count in counts.items():
                self.vocab.add(word)  # global vocabulary (set: add is idempotent)
                if word not in self.word_counts[c]:
                    self.word_counts[c][word] = 0.0
                self.word_counts[c][word] += count
                self.all_words[c] += count

    def predict(self, X):
        """Classify each message in X; return a list of 1 (spam) / 0 (ham).

        Uses multinomial Naive Bayes scored in log space with Laplace
        (add-one) smoothing.

        BUG FIX: the original computed
        ``word_count + 1 / (total + V) * prior`` — operator precedence
        made the smoothing/prior term negligible, so the "score" was
        just a raw sum of training word counts, not a probability.
        """
        result_multinomial = []
        vocab_size = len(self.vocab)
        for x in X:
            counts = self.get_word_counts(self.tokenize(x))
            # Start from the log priors.
            spam_score = math.log(self.priors_prob['spam'])
            ham_score = math.log(self.priors_prob['ham'])
            for word, count in counts.items():
                if word not in self.vocab:
                    continue  # words never seen in training are ignored
                # Laplace-smoothed per-class log-likelihood, weighted by
                # the word's count in this message (multinomial model).
                spam_score += count * math.log(
                    (self.word_counts['spam'].get(word, 0.0) + 1)
                    / (self.all_words['spam'] + vocab_size))
                ham_score += count * math.log(
                    (self.word_counts['ham'].get(word, 0.0) + 1)
                    / (self.all_words['ham'] + vocab_size))
            result_multinomial.append(1 if spam_score > ham_score else 0)
        return result_multinomial

    def classify(self, true, predicted):
        """Print a confusion matrix and metrics for predicted vs true labels.

        Parameters:
            true: true labels (1 = spam, 0 = ham).
            predicted: predicted labels from :meth:`predict`.

        Prints the confusion matrix plus error rate, accuracy, recall,
        precision and F1-score; returns nothing.
        """
        true_positives = 0
        true_negatives = 0
        false_positives = 0
        false_negatives = 0
        for x, y in zip(true, predicted):
            if x == 1 and y == 1:
                true_positives += 1
            elif x == 0 and y == 0:
                true_negatives += 1
            elif x == 1 and y == 0:
                false_negatives += 1
            elif x == 0 and y == 1:
                false_positives += 1
        total = true_positives + true_negatives + false_positives + false_negatives
        error_rate = (false_negatives + false_positives) / total
        accuracy = 1 - error_rate
        # By convention: no false negatives -> perfect recall,
        # no false positives -> perfect precision.
        if false_negatives == 0:
            recall = 1
        else:
            recall = true_positives / (true_positives + false_negatives)
        if false_positives == 0:
            precision = 1
        else:
            precision = true_positives / (true_positives + false_positives)
        # Guard against 0/0 (tp == 0 with both fp > 0 and fn > 0),
        # which crashed the original with ZeroDivisionError.
        if precision + recall == 0:
            f1_score = 0
        else:
            f1_score = 2 * (precision * recall) / (precision + recall)
        t = PrettyTable(['', 'predicted', 'class'],)
        t.add_row(['true class\nyes', "yes\n"
                   + colored("TP: ", "green") + colored(str(true_positives), "green"), "no\n"
                   + colored("FN: ", "red") + colored(str(false_negatives), "red")])
        t.add_row(['no', colored("FP: ", "red")
                   + colored(str(false_positives), "red"),
                   colored("TN: ", "green") + colored(str(true_negatives), "green")])
        print("Confusion matrix")
        print(t.get_string())
        t = Texttable()
        t.add_rows([['parameter', 'value'], ['Error rate:', error_rate],
                    ['Accuracy:', accuracy], ['Recall:', recall],
                    ['Precision:', precision], ['F1-score:', f1_score]])
        print(t.draw())
if __name__ == '__main__':
    # Train on the full SMS corpus, then evaluate on the training data.
    messages, labels = get_data("SMSSpamCollection")
    detector = SpamDetector()
    detector.train(messages, labels)
    predictions = detector.predict(messages)
    detector.classify(labels, predictions)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement