Guest User

Untitled

a guest
Jun 23rd, 2018
108
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.56 KB | None | 0 0
  1. # encoding:utf-8
  2.  
  3. """
  4. faiss库 IndexScalarQuantizer 索引性能测试
  5. (Scalar quantizer (SQ) in flat mode)
  6. 4 bit per component is also implemented, but the impact on accuracy may be unacceptable
  7.  
  8. author : h-j-13
  9. time : 2018-6-22
  10. """
  11.  
  12. # QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform
  13.  
  14. import time
  15. import numpy
  16.  
  17. import faiss
  18. from recall_data import recall_data
  19.  
  20. # 基本参数
  21. d = 300 # 向量维数
  22. data_size = 10000 # 数据库大小
  23. k = 50
  24. qname = "QT_4bit"
  25.  
  26. # 生成测试数据
  27. numpy.random.seed(13)
  28. data = numpy.random.random(size=(data_size, d)).astype('float32')
  29. test_data = recall_data
  30.  
  31. # 创建索引模型并添加向量
  32. qtype = getattr(faiss.ScalarQuantizer, qname)
  33. index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2)
  34.  
  35. #  训练数据
  36. start_time = time.time()
  37. assert not index.is_trained
  38. index.train(data)
  39. assert index.is_trained
  40. print "Train Index Used %.2f sec." % (time.time() - start_time)
  41.  
  42. # 添加数据
  43. start_time = time.time()
  44. index.add(data) # 添加索引可能会有一点慢
  45. print "Add vector Used %.2f sec." % (time.time() - start_time)
  46.  
  47. start_time = time.time()
  48. D, I = index.search(data[:50], k) # 搜索每一个数据的的k临近向量
  49.  
  50. # 输出结果
  51. print "Used %.2f ms" % ((time.time() - start_time) * 1000)
  52. recall_1_count = 0
  53. recall_50_count = 0
  54. for (search_vec, test_vec) in zip(I, test_data):
  55. if test_vec[0] in search_vec:
  56. recall_1_count += 1
  57. recall_50_count += len(set(search_vec.tolist()) & set(test_vec))
  58. print "recall1@50 = " + str(recall_1_count / (50.0))
  59. print "recall50@50 = " + str(recall_50_count / (50.0 * 50.0))
Add Comment
Please, Sign In to add comment