gt22

Untitled

Oct 14th, 2018
265
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.08 KB | None | 0 0
  1. class SpamDetector:
  2.  
  3.     def __init__(self):
  4.         self.p_class = [0, 0]
  5.         self.p_word = [{}, {}]
  6.  
  7.     # typical machine learning model interface
  8.     def fit(self, x, y):
  9.         for c in [0, 1]:
  10.             data = x[y == c]
  11.             words = self._count_words(data)
  12.             word_sum = sum(words.values())
  13.             self.p_word[c] = {word: -log(count / word_sum) for word, count in words.items()}
  14.             self.p_class[c] = -log(np.mean(y == c))
  15.  
  16.     def predict(self, d):
  17.         return np.argmax(self.predict_proba(d), axis=1)
  18.  
  19.     def predict_proba(self, d):
  20.         ret = np.zeros((d.shape[0], 2))
  21.         ret += self.p_class
  22.         for i, msg in enumerate(d):
  23.             for c in [0, 1]:
  24.                 ret[i, c] = self.__p_message(msg, self.p_word[c])
  25.         return ret
  26.  
  27.     # helper functions
  28.     def _count_words(self, d):
  29.         ret = defaultdict(lambda: 0)
  30.         for line in d:
  31.             for word in line:
  32.                 ret[word] += 1
  33.         return ret
  34.  
  35.     def __p_message(self, msg, d):
  36.         return sum(d.get(s, 0) for s in msg)
Advertisement
Add Comment
Please, Sign In to add comment