gt22

Untitled

Oct 8th, 2018
521
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.17 KB | None | 0 0
  1. class SpamDetector:
  2.  
  3.     def __init__(self):
  4.         self.p_class = [0, 0]
  5.         self.p_words = [{}, {}]
  6.  
  7.     # typical machine learning model interface
  8.     def fit(self, x, y):
  9.         self.p_class[1] = np.mean(y)
  10.         self.p_class[0] = 1 - self.p_class[1]
  11.         for c in [0, 1]:
  12.             data = x[y == c]
  13.             words = self._count_words(data)
  14.             words_sum = sum(words.values())
  15.             self.p_words[c] = {word: log(count / words_sum) for (word, count) in words.items()}
  16.            
  17.     def predict(self, d):
  18.         proba = self.predict_proba(d)
  19.         return proba[:, 1] < proba[:, 0]
  20.  
  21.     def predict_proba(self, d):
  22.         ret = np.zeros((d.shape[0], 2))
  23.         for i, msg in enumerate(d):
  24.             for c in [0, 1]:
  25.                 ret[i, c] = self.p_class[c] * self.__p_message(msg, c)
  26.         return ret
  27.  
  28.     # helper functions
  29.     def _count_words(self, d):
  30.         ret = defaultdict(lambda: 0)
  31.         for line in d:
  32.             for word in line:
  33.                 ret[word] += 1
  34.         return ret
  35.  
  36.     def __p_message(self, msg, c):
  37.         d = self.p_words[c]
  38.         return sum(d[s] if s in d else 0 for s in msg)
Advertisement
Add Comment
Please, Sign In to add comment