Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
class SpamDetector:
    """Two-class (0/1) multinomial naive Bayes text classifier.

    Messages are iterables of word tokens. Internally every probability is
    stored as a "cost", i.e. its negative natural logarithm, so a smaller
    cost means a more likely class or word.
    """

    def __init__(self, alpha=1.0):
        """Create an unfitted model.

        Args:
            alpha: additive (Laplace) smoothing strength; must be > 0. It
                stops a word that never occurred with a class during
                training from making that class look certain or impossible.

        Raises:
            ValueError: if ``alpha`` is not positive.
        """
        if alpha <= 0:
            raise ValueError("alpha must be positive")
        self.alpha = alpha
        self.p_class = [0, 0]        # per-class prior cost, -log P(c)
        self.p_word = [{}, {}]       # per-class {word: -log P(word | c)}
        self.p_unseen = [0.0, 0.0]   # cost of a vocabulary word never seen with c
        self.vocab = set()           # every word seen during fit

    # typical machine learning model interface
    def fit(self, x, y):
        """Estimate class priors and smoothed per-class word likelihoods.

        Args:
            x: sequence of messages, each an iterable of word tokens
                (a list or a numpy array both work).
            y: sequence of labels (0 or 1), same length as ``x``.

        Returns:
            ``self``, so calls can be chained: ``SpamDetector().fit(x, y)``.

        Raises:
            ValueError: if ``x`` is empty or ``x`` and ``y`` differ in length.
        """
        x = list(x)
        y = list(y)
        if len(x) != len(y):
            raise ValueError("x and y must have the same length")
        if not x:
            raise ValueError("cannot fit on empty data")

        # zip instead of a boolean mask so plain lists work as well as arrays.
        counts = [
            self._count_words(msg for msg, label in zip(x, y) if label == c)
            for c in (0, 1)
        ]
        self.vocab = set(counts[0]) | set(counts[1])
        n_vocab = len(self.vocab)
        n_msgs = len(y)

        for c in (0, 1):
            n_c = sum(1 for label in y if label == c)
            # A class absent from the training data can never be predicted.
            self.p_class[c] = np.inf if n_c == 0 else float(-np.log(n_c / n_msgs))

            words = counts[c]
            denom = sum(words.values()) + self.alpha * n_vocab
            self.p_word[c] = {
                word: float(-np.log((count + self.alpha) / denom))
                for word, count in words.items()
            }
            # denom is 0 only when the vocabulary is empty; then no word is
            # ever looked up, so the value stored here is never used.
            self.p_unseen[c] = float(-np.log(self.alpha / denom)) if denom else 0.0
        return self

    def predict(self, d):
        """Return the most likely label (0 or 1) for each message in ``d``.

        Picks the class with the lowest total cost; ties resolve to class 0.
        """
        return np.argmin(self._costs(d), axis=1)

    def predict_proba(self, d):
        """Return posterior class probabilities with shape ``(len(d), 2)``.

        Column ``c`` holds P(class c | message); each row sums to 1.
        """
        log_post = -self._costs(d)
        # Shift each row by its maximum before exponentiating so the largest
        # term is exp(0) == 1 (avoids underflow); normalising cancels the shift.
        log_post -= log_post.max(axis=1, keepdims=True)
        post = np.exp(log_post)
        return post / post.sum(axis=1, keepdims=True)

    # helper functions
    def _count_words(self, d):
        """Return a mapping word -> total occurrences across the messages ``d``."""
        ret = defaultdict(int)
        for line in d:
            for word in line:
                ret[word] += 1
        return ret

    def _costs(self, d):
        """Return an array of shape ``(len(d), 2)`` with each message's total cost per class.

        Entry ``[i, c]`` is ``-log P(c) - sum(log P(word | c))`` for message ``i``.
        """
        d = list(d)
        ret = np.zeros((len(d), 2))
        ret += self.p_class
        for i, msg in enumerate(d):
            for c in (0, 1):
                # Accumulate on top of the prior added above.
                ret[i, c] += self._message_cost(msg, c)
        return ret

    def _message_cost(self, msg, c):
        """Sum the word costs of ``msg`` under class ``c``.

        Words never seen during ``fit`` carry no information about either
        class and are skipped.
        """
        costs = self.p_word[c]
        unseen = self.p_unseen[c]
        return sum(costs.get(word, unseen) for word in msg if word in self.vocab)
Advertisement
Add Comment
Please sign in to add a comment.