Advertisement
Guest User

Untitled

a guest
Apr 27th, 2017
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.19 KB | None | 0 0
  1. # coding: utf-8
  2. from collections import Counter
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from sklearn.metrics import accuracy_score
  6.  
  7. class KNearestNeightbors(object):
  8. def __init__(self, k = 1):
  9. self._train_data = None
  10. self._target_data = None
  11. self._k = k
  12.  
  13. def fit(self, train_data, target_data):
  14. self._train_data = train_data
  15. self._target_data = target_data
  16.  
  17. def predict(self, x):
  18. distances = np.array([self._distance(p, x) for p in self._train_data])
  19. nearest_indexes = distances.argsort()[:self._k]
  20. nearest_labels = self._target_data[nearest_indexes]
  21. c = Counter(nearest_labels)
  22.  
  23. return c.most_common(1)[0][0]
  24.  
  25. def _distance(self, p0, p1):
  26. return np.sum((p0 - p1) ** 2)
  27.  
  28. def find_x2(self, cls0, cls1, x1):
  29. diff = []
  30. x2 = np.linspace(0.5, 3.0, 100)
  31. for y in x2:
  32. x = [x1, y]
  33. distances_0 = np.array([self._distance(p, x) for p in cls0])
  34. distances_1 = np.array([self._distance(p, x) for p in cls1])
  35. diff.append(np.absolute(np.min(np.array(distances_0)) - np.min(np.array(distances_1))))
  36. x2_indexes = np.argmin(diff)
  37. return x2[x2_indexes]
  38.  
  39. def main():
  40. iris_dataset = np.loadtxt("dataset/iris_traning.csv", delimiter=",")
  41. features = iris_dataset[:, 1:3]
  42. targets = iris_dataset[:, 0]
  43. iris_dataset_test = np.loadtxt("dataset/iris_test.csv", delimiter=",")
  44. features_test = iris_dataset_test[:, 1:3]
  45. targets_test = iris_dataset_test[:, 0]
  46.  
  47. for k in [1, 5, 10]:
  48. model = KNearestNeightbors(k)
  49. model.fit(features, targets)
  50. predicted_labels = []
  51. for test in features_test:
  52. predicted_label = model.predict(test)
  53. predicted_labels.append(predicted_label)
  54. score = accuracy_score(targets_test, predicted_labels)
  55. print("k = {}".format(k) )
  56. print("acuracy : {}".format(score))
  57.  
  58. cls0 = iris_dataset[targets == 0, 1:3]
  59. cls1 = iris_dataset[targets > 0, 1:3]
  60. x1 = np.linspace(3, 7, 100)
  61. x2 = [model.find_x2(cls0, cls1, x) for x in x1]
  62.  
  63. plt.scatter(features[:,0], features[:,1], linewidths=0, alpha=1,
  64. c=targets
  65. )
  66. plt.plot(x1, x2)
  67. plt.show()
  68.  
  69. if __name__ == '__main__':
  70. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement