Advertisement
GeorgePashev_88

KNN Classifier

Mar 14th, 2024
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.43 KB | None | 0 0
  1. class KNN_NLC_Classifer():
  2.     def __init__(self, k=1, distance_type='path'):
  3.         self.k = k
  4.         self.distance_type = distance_type
  5.  
  6.     # This function is used for training
  7.     def fit(self, x_train, y_train):
  8.         self.x_train = x_train
  9.         self.y_train = y_train
  10.  
  11.     # This function runs the K(1) nearest neighbour algorithm and
  12.     # returns the label with closest match.
  13.     def predict(self, x_test):
  14.         self.x_test = x_test
  15.         y_predict = []
  16.  
  17.         for i in range(len(x_test)):
  18.             max_sim = 0
  19.             max_index = 0
  20.             for j in range(self.x_train.shape[0]):
  21.                 temp = self.document_similarity(x_test[i], self.x_train[j])
  22.                 if temp > max_sim:
  23.                     max_sim = temp
  24.                     max_index = j
  25.             y_predict.append(self.y_train[max_index])
  26.         return y_predict
  27.  
  28.     def convert_tag(self, tag):
  29.         """Convert the tag given by nltk.pos_tag to the tag used by wordnet.synsets"""
  30.         tag_dict = {'N': 'n', 'J': 'a', 'R': 'r', 'V': 'v'}
  31.         try:
  32.             return tag_dict[tag[0]]
  33.         except KeyError:
  34.             return None
  35.  
  36.     def doc_to_synsets(self, doc):
  37.         """
  38.            Returns a list of synsets in document.
  39.            Tokenizes and tags the words in the document doc.
  40.            Then finds the first synset for each word/tag combination.
  41.        If a synset is not found for that combination it is skipped.
  42.  
  43.        Args:
  44.            doc: string to be converted
  45.  
  46.        Returns:
  47.            list of synsets
  48.        """
  49.         tokens = word_tokenize(doc + ' ')
  50.  
  51.         l = []
  52.         tags = nltk.pos_tag([tokens[0] + ' ']) if len(tokens) == 1 else nltk.pos_tag(tokens)
  53.  
  54.         for token, tag in zip(tokens, tags):
  55.             syntag = self.convert_tag(tag[1])
  56.             syns = wn.synsets(token, syntag)
  57.             if (len(syns) > 0):
  58.                 l.append(syns[0])
  59.         return l
  60.  
  61.     def similarity_score(self, s1, s2, distance_type='path'):
  62.         """
  63.        Calculate the normalized similarity score of s1 onto s2
  64.        For each synset in s1, finds the synset in s2 with the largest similarity value.
  65.        Sum of all of the largest similarity values and normalize this value by dividing it by the
  66.        number of largest similarity values found.
  67.  
  68.        Args:
  69.            s1, s2: list of synsets from doc_to_synsets
  70.  
  71.        Returns:
  72.            normalized similarity score of s1 onto s2
  73.        """
  74.         s1_largest_scores = []
  75.  
  76.         for i, s1_synset in enumerate(s1, 0):
  77.             max_score = 0
  78.             for s2_synset in s2:
  79.                 if distance_type == 'path':
  80.                     score = s1_synset.path_similarity(s2_synset, simulate_root=False)
  81.                 else:
  82.                     score = s1_synset.wup_similarity(s2_synset)
  83.                 if score != None:
  84.                     if score > max_score:
  85.                         max_score = score
  86.  
  87.             if max_score != 0:
  88.                 s1_largest_scores.append(max_score)
  89.  
  90.         mean_score = np.mean(s1_largest_scores)
  91.  
  92.         return mean_score
  93.  
  94.     def document_similarity(self, doc1, doc2):
  95.         """Finds the symmetrical similarity between doc1 and doc2"""
  96.  
  97.         synsets1 = self.doc_to_synsets(doc1)
  98.         synsets2 = self.doc_to_synsets(doc2)
  99.  
  100.         return (self.similarity_score(synsets1, synsets2) + self.similarity_score(synsets2, synsets1)) / 2
  101.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement