Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import matplotlib.pyplot as plt
- import numpy as np
def lvq_fit(train, target, lrate, b, max_epoch):
    """Fit an LVQ1 model with one prototype per class.

    The first sample of each class (first occurrence, as reported by
    ``np.unique(..., return_index=True)``) seeds that class's prototype; the
    remaining samples are used for the training passes.

    Args:
        train: Feature matrix, one sample per row.
        target: Class label for each row of ``train``.
        lrate: Initial learning rate.
        b: Multiplicative decay applied to ``lrate`` after every epoch.
        max_epoch: Number of passes over the non-seed samples.

    Returns:
        ``(weight, label)`` where ``weight[k]`` (float64) is the prototype
        for class ``label[k]``; ``label`` is sorted, as ``np.unique`` returns it.
    """
    train = np.asarray(train)
    target = np.asarray(target)
    label, train_idx = np.unique(target, return_index=True)
    # astype() copies, so updating prototypes never mutates the caller's data.
    weight = train[train_idx].astype(np.float64)

    # Drop the seed samples with a boolean mask rather than an object array of
    # (x, y) tuples: numeric dtypes are preserved, and the case where every
    # sample is a seed (nothing left to train on) no longer raises.
    keep = np.ones(len(target), dtype=bool)
    keep[train_idx] = False
    samples, sample_targets = train[keep], target[keep]

    epoch = 0
    while epoch < max_epoch:
        for x, y in zip(samples, sample_targets):
            # Squared Euclidean distance from x to every prototype.
            distance = np.sum((weight - x) ** 2, axis=1)
            winner = np.argmin(distance)
            # LVQ1 rule: pull the winner towards x when the classes match,
            # push it away otherwise.
            sign = 1 if y == label[winner] else -1
            weight[winner] += sign * lrate * (x - weight[winner])
        lrate *= b
        epoch += 1
    return weight, label
def lvq_predict(X, model):
    """Classify each sample in ``X`` by its nearest prototype.

    Args:
        X: Iterable of samples.
        model: ``(center, label)`` pair as produced by ``lvq_fit``.

    Returns:
        A list with, for every sample, the label of the prototype at the
        smallest squared Euclidean distance (first prototype wins ties).
    """
    prototypes, classes = model

    def nearest_class(sample):
        sq_dists = [sum((proto - sample) ** 2) for proto in prototypes]
        return classes[np.argmin(sq_dists)]

    return [nearest_class(sample) for sample in X]
def calc_accuracy(a, b):
    """Return the fraction of positions at which ``a`` and ``b`` agree.

    Args:
        a: Sequence of predicted labels.
        b: Sequence of reference labels; must have the same length as ``a``.

    Returns:
        Accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: if the sequences differ in length or are empty.
    """
    if len(a) != len(b):
        raise ValueError(f'length mismatch: {len(a)} != {len(b)}')
    if len(a) == 0:
        raise ValueError('cannot compute accuracy of empty sequences')
    matches = sum(1 for x, y in zip(a, b) if x == y)
    return matches / len(a)
def class_to_int(category):
    """Map a one-character class symbol from the demo grid to its class index.

    ``'+'`` -> 0, ``'0'`` -> 1, ``'*'`` -> 2, ``'#'`` -> 3.

    Raises:
        ValueError: if ``category`` is not exactly one of those characters.
    """
    if len(category) != 1:
        # str.index also matches substrings, so '' or '+0' would otherwise be
        # silently mapped to class 0.
        raise ValueError(f'expected a single class character, got {category!r}')
    return '+0*#'.index(category)
def class_to_plot(category):
    """Return the matplotlib marker character for integer class ``category``.

    0 -> ``'+'``, 1 -> ``'o'``, 2 -> ``'*'``, 3 -> ``'^'``.
    """
    marker_codes = '+o*^'
    marker = marker_codes[category]
    return marker
def main():
    """Train LVQ on a 9x9 character grid, then plot predictions and prototypes.

    Each character of the grid is one cell whose symbol is its class. Cell
    ``i`` is placed at ``((9 - row) / 9, col / 9)`` with
    ``row, col = divmod(i, 9)``. Predicted classes are drawn with per-class
    colour/marker, prototypes as dots. The title reports the epoch count and
    the training-set accuracy.
    """
    rows = [
        '#####0000',
        '#####0000',
        '+++++++00',
        '+++++++00',
        '+++++++**',
        '+++++++**',
        '+++++++**',
        '+++++++**',
        '+++++++**',
    ]
    data = ''.join(rows)
    max_epoch = 40
    side = 9

    coords = []
    for idx in range(len(data)):
        row, col = divmod(idx, side)
        coords.append(((side - row) / side, col / side))
    X_train = np.array(coords)
    y_train = np.array(list(map(class_to_int, data)))

    model = lvq_fit(X_train, y_train, lrate=.01, b=1, max_epoch=max_epoch)
    predictions = lvq_predict(X_train, model)
    accuracy = calc_accuracy(predictions, y_train)

    palette = 'rgby'
    for point, cls in zip(X_train, predictions):
        plt.plot(point[0], point[1], palette[cls] + class_to_plot(cls))
    prototypes, proto_labels = model
    for proto, cls in zip(prototypes, proto_labels):
        plt.plot(proto[0], proto[1], palette[cls] + '.')
    plt.title(f'Epoch = {max_epoch} (Accuracy: {accuracy})')
    plt.show()
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement