Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def fast_knn(elements, k, neighborhood_fraction=.01, metric='euclidean',
             verbose=True):
    """Return the indices of the ``k`` nearest neighbors of every row of ``elements``.

    The search avoids building the full n x n distance matrix.  Random
    "anchor" rows are drawn from the samples that are still unresolved.  For
    each anchor, the ``neighborhood_size`` rows closest to it form a
    neighborhood, and pairwise distances are computed only inside that
    neighborhood.  For a sample ``x`` in the neighborhood, every row ``y``
    outside it satisfies, by the triangle inequality,

        d(x, y) >= d(anchor, y) - d(anchor, x) >= R - d(anchor, x),

    where ``R`` is the anchor's distance to its farthest neighborhood member.
    When the k-th nearest neighbor of ``x`` inside the neighborhood is no
    farther than that bound, the local k-NN set is also the global one and
    ``x`` is marked complete.  Rounds repeat until every sample is complete.

    Parameters
    ----------
    elements : array_like, shape (n_samples, n_features)
        Samples, one per row.
    k : int
        Number of neighbors per sample; must satisfy ``1 <= k < n_samples``.
    neighborhood_fraction : float
        Neighborhood size as a fraction of ``n_samples``.  The neighborhood
        always holds at least ``3 * k`` rows and is capped at ``n_samples``.
    metric : str or callable
        Any metric accepted by ``scipy.spatial.distance.cdist``/``pdist``.
        The pruning bound relies on the triangle inequality, so results are
        exact only for true metrics (e.g. 'euclidean', 'cityblock',
        'chebyshev', 'minkowski').  For non-metrics such as 'sqeuclidean' or
        'cosine', results may be approximate.
    verbose : bool
        Print a running count of completed samples to stdout.

    Returns
    -------
    numpy.ndarray of int, shape (n_samples, k)
        Row ``i`` lists the neighbors of sample ``i``, never ``i`` itself,
        ordered from nearest to farthest.  Ties are broken arbitrarily.

    Raises
    ------
    ValueError
        If ``elements`` is non-empty and ``k`` is not in ``[1, n_samples)``.
    """
    elements = np.asarray(elements)
    n_samples = elements.shape[0]
    if n_samples == 0:
        return np.zeros((0, k), dtype=int)
    if not 1 <= k < n_samples:
        raise ValueError(
            f"k must satisfy 1 <= k < n_samples ({n_samples}), got {k}")

    nearest_neighbors = np.zeros((n_samples, k), dtype=int)
    complete = np.zeros(n_samples, dtype=bool)
    # Cap at n_samples for two reasons: argpartition's kth must stay in
    # range, and each round must draw at least one anchor.  Uncapped,
    # 3 * k > n_samples would yield zero anchors and the loop would never end.
    neighborhood_size = min(
        n_samples, max(k * 3, int(n_samples * neighborhood_fraction)))
    anchors_per_round = (n_samples // neighborhood_size) * 3

    while not complete.all():
        available = np.flatnonzero(~complete)
        np.random.shuffle(available)
        for anchor in available[:anchors_per_round]:
            if verbose:
                print(f"Complete:{np.sum(complete)}\r", end='')
            anchor_distances = cdist(
                elements[anchor].reshape(1, -1), elements, metric=metric)[0]
            # Pin the anchor into its own neighborhood, even when more than
            # neighborhood_size rows coincide with it at distance 0.
            anchor_distances[anchor] = -np.inf
            neighborhood = np.argpartition(
                anchor_distances, neighborhood_size - 1)[:neighborhood_size]
            anchor_local = np.flatnonzero(neighborhood == anchor)[0]
            local_distances = squareform(
                pdist(elements[neighborhood], metric=metric))

            # R - d(anchor, x): lower bound (triangle inequality) on the
            # distance from each local x to any row outside the neighborhood.
            to_anchor = local_distances[anchor_local]
            bound = to_anchor.max() - to_anchor

            pending = np.flatnonzero(~complete[neighborhood])
            if pending.size == 0:
                continue
            # Fancy indexing copies, so local_distances is left intact.
            # Setting the self-distance to inf excludes each sample from its
            # own neighbor list, even when duplicates sit at distance 0.
            rows = local_distances[pending]
            rows[np.arange(pending.size), pending] = np.inf
            knn_local = np.argpartition(rows, k - 1, axis=1)[:, :k]
            knn_dist = np.take_along_axis(rows, knn_local, axis=1)
            # Use <= rather than <.  A tie with an outside row still leaves a
            # valid k-NN set, and <= guarantees the anchor itself always
            # resolves, so every round makes progress (e.g. on fully
            # duplicated data).
            resolved = knn_dist.max(axis=1) <= bound[pending]
            if not resolved.any():
                continue

            # Order each accepted neighbor list from nearest to farthest.
            order = np.argsort(knn_dist[resolved], axis=1, kind='stable')
            knn_local = np.take_along_axis(knn_local[resolved], order, axis=1)
            # Translate local neighborhood positions to global row indices.
            samples = neighborhood[pending[resolved]]
            nearest_neighbors[samples] = neighborhood[knn_local]
            complete[samples] = True
    if verbose:
        print("\n")
    return nearest_neighbors
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement