Advertisement
Guest User

Untitled

a guest
Nov 17th, 2019
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.24 KB | None | 0 0
  1. from time import perf_counter
  2. from sklearn.neighbors import KDTree
  3. # from sklearn.preprocessing import normalize
  4. # from scipy import spatial
  5.  
  6. def true_closest(X_train, X_test, k):
  7. result = []
  8. for x0 in X_test:
  9. bests = list(sorted([(i, np.linalg.norm(x - x0)) for i, x in enumerate(X_train)], key=lambda x: x[1]))
  10. bests = [i for i, d in bests]
  11. result.append(bests[:min(k, len(bests))])
  12. return result
  13.  
  14. # X, y = read_cancer_dataset('cancer.csv')
  15. X, y = read_spam_dataset('spam.csv')
  16. # X = normalize(X, axis=0, norm='l2')
  17. X_train, y_train, X_test, y_test = train_test_split(X, y, 0.9)
  18. # X_train = np.random.randn(100, 3)
  19. # X_test = np.random.randn(10, 3)
  20.  
  21. tree = KDTree(X_train, leaf_size=40)
  22.  
  23. time1 = perf_counter()
  24. _, predicted = tree.query(X_test, k=30)
  25. time1 = perf_counter() - time1
  26.  
  27. time2 = perf_counter()
  28. true = true_closest(X_train, X_test, k=30)
  29. time2 = perf_counter() - time2
  30. print(time1, time2)
  31. if np.sum(np.abs(np.array(np.array(predicted).shape) - np.array(np.array(true).shape))) != 0:
  32. print("Wrong shape")
  33. else:
  34. errors = sum([1 for row1, row2 in zip(predicted, true) for i1, i2 in zip(row1, row2) if i1 != i2])
  35. if errors > 0:
  36. print("Encounted", errors, "errors")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement