Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
class SpamDetector:
    """Two-class multinomial naive Bayes classifier for tokenized messages.

    Messages are iterables of hashable tokens (for example, lists of words);
    labels are 0 and 1. Word probabilities use additive (Laplace) smoothing
    over the vocabulary seen during ``fit``.

    Attributes:
        alpha: additive smoothing strength (must be positive).
        p_class: class priors ``[P(y=0), P(y=1)]`` estimated by ``fit``.
        p_words: one dict per class mapping every training-vocabulary token
            to its smoothed log-probability ``log P(token | y=c)``.
    """

    def __init__(self, alpha=1.0):
        if alpha <= 0:
            raise ValueError("alpha must be positive")
        self.alpha = alpha
        self.p_class = [0, 0]
        self.p_words = [{}, {}]

    # typical machine learning model interface
    def fit(self, x, y):
        """Estimate class priors and smoothed per-class token log-probabilities.

        Args:
            x: sequence of messages (list or numpy array), each an iterable
                of tokens.
            y: sequence of labels in {0, 1}, one per message.

        Returns:
            self, so calls can be chained.

        Raises:
            ValueError: if ``x`` and ``y`` differ in length, are empty, or
                ``y`` contains a label other than 0/1.
        """
        y = np.asarray(y)
        if len(x) != len(y):
            raise ValueError("x and y must have the same length")
        if len(y) == 0:
            raise ValueError("cannot fit on an empty training set")
        if not np.all((y == 0) | (y == 1)):
            raise ValueError("labels must be 0 or 1")

        self.p_class[1] = float(np.mean(y))
        self.p_class[0] = 1 - self.p_class[1]

        counts = []
        for c in (0, 1):
            # Explicit filter instead of x[y == c] so plain lists work too.
            messages = [msg for msg, label in zip(x, y) if label == c]
            counts.append(self._count_words(messages))

        # Shared vocabulary: every known token gets a smoothed entry in BOTH
        # classes. Otherwise a token seen only in one class would be missing
        # from the other's dict and contribute log(1) = 0 there, unfairly
        # favouring the class that never saw it.
        vocab = set(counts[0]) | set(counts[1])
        for c in (0, 1):
            # alpha > 0 and len(vocab) >= 1 whenever this dict is non-empty,
            # so the denominator is never zero (even for a class with no data).
            denom = sum(counts[c].values()) + self.alpha * len(vocab)
            self.p_words[c] = {
                word: log((counts[c][word] + self.alpha) / denom) for word in vocab
            }
        return self

    def predict(self, d):
        """Return a boolean array that is True where label 1 is the more probable class.

        Ties (including messages made only of unknown tokens under equal
        priors) resolve to label 0, i.e. False.
        """
        proba = self.predict_proba(d)
        return proba[:, 1] > proba[:, 0]

    def predict_proba(self, d):
        """Return posterior class probabilities for each message in ``d``.

        Args:
            d: sequence of messages, each an iterable of tokens.

        Returns:
            float array of shape ``(len(d), 2)``; column c holds
            ``P(y=c | message)`` and each row sums to 1. Before ``fit`` has
            been called every row is ``[0.5, 0.5]``.
        """
        # Work in log space: a prior of 0 becomes -inf rather than raising.
        log_prior = [log(p) if p > 0 else -np.inf for p in self.p_class]
        ret = np.zeros((len(d), 2))
        for i, msg in enumerate(d):
            scores = np.array(
                [log_prior[c] + self.__p_message(msg, c) for c in (0, 1)]
            )
            top = scores.max()
            if top == -np.inf:
                # Both classes impossible (unfitted model): no information.
                ret[i] = 0.5
                continue
            # Log-sum-exp normalization: shifting by the max keeps exp() from
            # underflowing to 0 for long messages.
            weights = np.exp(scores - top)
            ret[i] = weights / weights.sum()
        return ret

    # helper functions
    def _count_words(self, d):
        """Count token occurrences across an iterable of messages."""
        ret = defaultdict(int)
        for line in d:
            for word in line:
                ret[word] += 1
        return ret

    def __p_message(self, msg, c):
        """Return the sum of ``log P(token | y=c)`` over the tokens of ``msg``.

        Tokens outside the training vocabulary are skipped. After ``fit``,
        both classes share the same vocabulary, so this skipping affects the
        two classes identically.
        """
        d = self.p_words[c]
        return sum(d[s] for s in msg if s in d)
Advertisement
Add Comment
Please sign in to add a comment.