Advertisement
Yesver08

lvq.py

Nov 5th, 2022
797
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.09 KB | None | 0 0
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3.  
  4.  
  5. def lvq_fit(train, target, lrate, b, max_epoch):
  6.     label, train_idx = np.unique(target, return_index=True)
  7.     weight = train[train_idx].astype(np.float64)
  8.     train = np.array([e for i, e in enumerate(
  9.         zip(train, target)) if i not in train_idx], dtype=object)
  10.     train, target = train[:, 0], train[:, 1]
  11.     epoch = 0
  12.  
  13.     while epoch < max_epoch:
  14.         for i, x in enumerate(train):
  15.             distance = [sum((w - x) ** 2) for w in weight]
  16.             min = np.argmin(distance)
  17.             sign = 1 if target[i] == label[min] else -1
  18.             weight[min] += sign * lrate * (x - weight[min])
  19.  
  20.         lrate *= b
  21.         epoch += 1
  22.  
  23.     return weight, label
  24.  
  25.  
  26. def lvq_predict(X, model):
  27.     center, label = model
  28.     Y = []
  29.  
  30.     for x in X:
  31.         d = [sum((c - x) ** 2) for c in center]
  32.         Y.append(label[np.argmin(d)])
  33.  
  34.     return Y
  35.  
  36.  
  37. def calc_accuracy(a, b):
  38.     s = [1 if a[i] == b[i] else 0 for i in range(len(a))]
  39.  
  40.     return sum(s) / len(a)
  41.  
  42.  
  43. def class_to_int(category):
  44.     return '+0*#'.index(category)
  45.  
  46.  
  47. def class_to_plot(category):
  48.     return '+o*^'[category]
  49.  
  50.  
  51. def main():
  52.     data = ('#####0000'
  53.             '#####0000'
  54.             '+++++++00'
  55.             '+++++++00'
  56.             '+++++++**'
  57.             '+++++++**'
  58.             '+++++++**'
  59.             '+++++++**'
  60.             '+++++++**')
  61.  
  62.     max_epoch = 40
  63.  
  64.     X_train = np.array([((9 - i//9)/9, (i % 9)/9) for i in range(len(data))])
  65.     y_train = np.array([class_to_int(i) for i in data])
  66.  
  67.     model = lvq_fit(X_train, y_train, lrate=.01, b=1, max_epoch=max_epoch)
  68.     output = lvq_predict(X_train, model)
  69.     accuracy = calc_accuracy(output, y_train)
  70.     colors = 'rgby'
  71.  
  72.     for x, label in zip(X_train, output):
  73.         plt.plot(x[0], x[1], colors[label] + class_to_plot(label))
  74.  
  75.     for center, label in zip(model[0], model[1]):
  76.         plt.plot(center[0], center[1], colors[label] + '.')
  77.  
  78.     plt.title(f'Epoch = {max_epoch} (Accuracy: {accuracy})')
  79.     plt.show()
  80.  
  81.  
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
  84.  
  85.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement