Advertisement
Viraax

Untitled

Nov 25th, 2022
710
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.83 KB | None | 0 0
  1. import time
  2. import numpy
  3. from annoy import AnnoyIndex
  4. from pymongo import MongoClient
  5.  
  6. DIMENSION = 128
  7.  
  8.  
  9. def vectorize(str):
  10.     return [int(bit) for bit in str]
  11.  
  12.  
  13. def int_to_bin(i):
  14.     return bin(i)[2:].rjust(32, '0')
  15.  
  16.  
  17. def build_tree(filename, n_trees=10, n_jobs=-1):
  18.     client = MongoClient()
  19.     nb_images = client.local['images'].count_documents({})
  20.     batch_size = 5000000
  21.     total_batch = nb_images // batch_size + 1
  22.     tree = AnnoyIndex(DIMENSION, 'hamming')
  23.     query_start = time.time()
  24.     images = list(client.local['images'].find({}).limit(batch_size))
  25.     print(f'Query executed in {time.time() - query_start}')
  26.  
  27.     i = 0
  28.     for image in images:
  29.         h = int_to_bin(image['h1']) + \
  30.             int_to_bin(image['h2']) + \
  31.             int_to_bin(image['h3']) + \
  32.             int_to_bin(image['h4'])
  33.         tree.add_item(i, vectorize(h))
  34.         i = i + 1
  35.  
  36.     # samples_start = time.time()
  37.     # numbers = numpy.random.rand(samples)
  38.     # print(f'Generated {samples} samples in {time.time() - samples_start}')
  39.  
  40.     build_start = time.time()
  41.     tree.build(n_trees, n_jobs=n_jobs)
  42.     print(f'Built tree in {time.time() - build_start}')
  43.  
  44.     save_start = time.time()
  45.     tree.save(filename)
  46.     print(f'Saved tree in {time.time() - save_start}')
  47.  
  48.  
  49. def search_tree(filename, n=100):
  50.     search_start = time.time()
  51.     tree = AnnoyIndex(DIMENSION, 'hamming')
  52.     tree.load(filename)
  53.     results = tree.get_nns_by_vector(vectorize(
  54.         "10110000000100000110110111111111000001111111110001011001000000000001011101001111001001110011001111010111110111100100100111010011"), n, include_distances=True)
  55.     print(results)
  56.     print(f'Found {len(results[0])} in {time.time() - search_start}')
  57.  
  58.  
  59. if __name__ == "__main__":
  60.     #build_tree('C:/tmp/test.ann')
  61.     search_tree('C:/tmp/test.ann')
  62.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement