Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # coding: utf-8
- from collections import Counter
- import numpy as np
- import matplotlib.pyplot as plt
- from sklearn.metrics import accuracy_score
class KNearestNeightbors(object):
    """k-nearest-neighbours classifier based on squared Euclidean distance.

    The (misspelled) class name is kept as-is because callers refer to it.
    """

    def __init__(self, k=1):
        """Create an unfitted classifier.

        Args:
            k: Number of nearest training samples that vote on a prediction.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        # k == 0 used to end in an IndexError inside predict(), and a negative
        # k silently sliced "all but the last |k|" neighbours; reject both.
        if k < 1:
            raise ValueError("k must be at least 1, got {!r}".format(k))
        self._train_data = None
        self._target_data = None
        self._k = k

    def fit(self, train_data, target_data):
        """Store the training samples and their labels.

        Args:
            train_data: Sequence/array of training samples, one per row.
            target_data: Sequence/array with one label per training sample.

        Raises:
            ValueError: If the two inputs have different lengths.
        """
        # Convert up front so predict() can use array indexing even when the
        # caller passed plain lists. Existing arrays are stored without a copy.
        train = np.asarray(train_data)
        targets = np.asarray(target_data)
        if len(train) != len(targets):
            raise ValueError(
                "train_data and target_data must have the same length "
                "({} != {})".format(len(train), len(targets))
            )
        self._train_data = train
        self._target_data = targets

    def predict(self, x):
        """Return the majority label among the ``k`` nearest training samples.

        Args:
            x: A single sample with the same number of features as the
                training samples.

        Returns:
            The most common label among the nearest neighbours. When several
            labels are equally common, the one encountered first (i.e. the one
            belonging to the closest neighbour) wins.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
            ValueError: If the model was fitted with no samples.
        """
        if self._train_data is None or self._target_data is None:
            raise RuntimeError("fit() must be called before predict()")
        distances = self._distances_to(self._train_data, x)
        # A stable sort keeps equidistant samples in training order, so the
        # neighbour selection is deterministic.
        nearest_indexes = np.argsort(distances, kind="stable")[:self._k]
        nearest_labels = np.asarray(self._target_data)[nearest_indexes]
        votes = Counter(nearest_labels)
        return votes.most_common(1)[0][0]

    def _distance(self, p0, p1):
        """Return the squared Euclidean distance between two points.

        The square root is omitted on purpose: it does not change the
        ordering of distances, which is all the classifier needs.
        """
        return np.sum((p0 - p1) ** 2)

    def _distances_to(self, points, x):
        """Return the squared distance from ``x`` to every row of ``points``."""
        points = np.asarray(points)
        if len(points) == 0:
            raise ValueError("cannot compute distances to an empty set of points")
        # Flatten each sample so one broadcasted expression replaces the
        # per-sample Python loop.
        flat = points.reshape(len(points), -1)
        return np.sum((flat - np.ravel(x)) ** 2, axis=1)

    def find_x2(self, cls0, cls1, x1, x2_min=0.5, x2_max=3.0, num=100):
        """Find the second coordinate where both classes are equally close.

        Scans ``num`` evenly spaced candidates for ``x2`` and returns the one
        at which the distance from ``(x1, x2)`` to the nearest point of
        ``cls0`` and to the nearest point of ``cls1`` differ the least.
        Neither ``k`` nor the fitted data is used.

        Args:
            cls0: Points of the first class, one two-feature point per row.
            cls1: Points of the second class, one two-feature point per row.
            x1: Fixed first coordinate of the query point.
            x2_min: Lower end of the scanned range (default 0.5).
            x2_max: Upper end of the scanned range (default 3.0).
            num: Number of candidates to scan (default 100).

        Returns:
            The candidate ``x2`` value with the smallest nearest-distance gap.

        Raises:
            ValueError: If ``cls0`` or ``cls1`` is empty.
        """
        candidates = np.linspace(x2_min, x2_max, num)
        gaps = []
        for y in candidates:
            query = [x1, y]
            nearest_0 = np.min(self._distances_to(cls0, query))
            nearest_1 = np.min(self._distances_to(cls1, query))
            gaps.append(np.absolute(nearest_0 - nearest_1))
        return candidates[np.argmin(gaps)]
def main():
    """Score the k-NN classifier on the iris CSVs and plot a class boundary.

    Column 0 of each CSV is used as the label and columns 1-2 as the two
    features. Accuracy on the test file is printed for k = 1, 5 and 10, then
    the training points are drawn together with the curve returned by
    ``find_x2`` for class 0 versus all other classes.
    """
    # NOTE(review): the pasted source lost its indentation; this assumes
    # scoring/printing run once per k and plotting runs once after the loop.
    train_table = np.loadtxt("dataset/iris_traning.csv", delimiter=",")
    features = train_table[:, 1:3]
    targets = train_table[:, 0]

    test_table = np.loadtxt("dataset/iris_test.csv", delimiter=",")
    features_test = test_table[:, 1:3]
    targets_test = test_table[:, 0]

    for k in [1, 5, 10]:
        model = KNearestNeightbors(k)
        model.fit(features, targets)
        predicted_labels = [model.predict(sample) for sample in features_test]
        score = accuracy_score(targets_test, predicted_labels)
        print("k = {}".format(k) )
        print("acuracy : {}".format(score))

    # Split the training rows into "label 0" and "every other label".
    zero_mask = targets == 0
    rest_mask = targets > 0
    cls0 = train_table[zero_mask, 1:3]
    cls1 = train_table[rest_mask, 1:3]

    # The model left over from the last loop iteration is reused here.
    x1 = np.linspace(3, 7, 100)
    x2 = []
    for first_coord in x1:
        x2.append(model.find_x2(cls0, cls1, first_coord))

    plt.scatter(
        features[:, 0],
        features[:, 1],
        linewidths=0,
        alpha=1,
        c=targets,
    )
    plt.plot(x1, x2)
    plt.show()
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement