[test]Refine recall test (#21789)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2023-01-19 09:49:44 +08:00 committed by GitHub
parent c82b5d15b6
commit 81f2840682
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 63 deletions

View File

@ -2,11 +2,12 @@ import h5py
import numpy as np import numpy as np
import time import time
from pathlib import Path from pathlib import Path
from loguru import logger
import pymilvus import pymilvus
from pymilvus import ( from pymilvus import (
connections, connections,
FieldSchema, CollectionSchema, DataType, FieldSchema, CollectionSchema, DataType,
Collection Collection, utility
) )
pymilvus_version = pymilvus.__version__ pymilvus_version = pymilvus.__version__
@ -37,13 +38,16 @@ def milvus_recall_test(host='127.0.0.1'):
] ]
default_schema = CollectionSchema( default_schema = CollectionSchema(
fields=default_fields, description="test collection") fields=default_fields, description="test collection")
collection = Collection(name="sift_128_euclidean", schema=default_schema)
name = f"sift_128_euclidean"
logger.info(f"Create collection {name}")
collection = Collection(name=name, schema=default_schema)
nb = len(train) nb = len(train)
batch_size = 50000 batch_size = 50000
epoch = int(nb / batch_size) epoch = int(nb / batch_size)
t0 = time.time() t0 = time.time()
for i in range(epoch): for i in range(epoch):
print("epoch:", i) logger.info(f"epoch: {i}")
start = i * batch_size start = i * batch_size
end = (i + 1) * batch_size end = (i + 1) * batch_size
if end > nb: if end > nb:
@ -56,74 +60,86 @@ def milvus_recall_test(host='127.0.0.1'):
] ]
collection.insert(data) collection.insert(data)
t1 = time.time() t1 = time.time()
print(f"\nInsert {nb} vectors cost {t1 - t0:.4f} seconds") logger.info(f"Insert {nb} vectors cost {t1 - t0:.4f} seconds")
t0 = time.time() t0 = time.time()
print(f"\nGet collection entities...") logger.info(f"Get collection entities...")
if pymilvus_version >= "2.2.0": if pymilvus_version >= "2.2.0":
collection.flush() collection.flush()
else: else:
collection.num_entities collection.num_entities
print(collection.num_entities) logger.info(collection.num_entities)
t1 = time.time() t1 = time.time()
print(f"\nGet collection entities cost {t1 - t0:.4f} seconds") logger.info(f"Get collection entities cost {t1 - t0:.4f} seconds")
# create index # create index
default_index = {"index_type": "IVF_SQ8", default_index = {"index_type": "IVF_SQ8",
"metric_type": "L2", "params": {"nlist": 64}} "metric_type": "L2", "params": {"nlist": 64}}
print(f"\nCreate index...") logger.info(f"Create index...")
t0 = time.time() t0 = time.time()
collection.create_index(field_name="float_vector", collection.create_index(field_name="float_vector",
index_params=default_index) index_params=default_index)
t1 = time.time() t1 = time.time()
print(f"\nCreate index cost {t1 - t0:.4f} seconds") logger.info(f"Create index cost {t1 - t0:.4f} seconds")
# load collection # load collection
replica_number = 1 replica_number = 1
print(f"\nload collection...") logger.info(f"load collection...")
t0 = time.time() t0 = time.time()
collection.load(replica_number=replica_number) collection.load(replica_number=replica_number)
t1 = time.time() t1 = time.time()
print(f"\nload collection cost {t1 - t0:.4f} seconds") logger.info(f"load collection cost {t1 - t0:.4f} seconds")
res = utility.get_query_segment_info(name)
cnt = 0
logger.info(f"segments info: {res}")
for segment in res:
cnt += segment.num_rows
assert cnt == collection.num_entities
logger.info(f"wait for loading complete...")
time.sleep(30)
res = utility.get_query_segment_info(name)
logger.info(f"segments info: {res}")
# search # search
topK = 100 topK = 100
nq = 10000 nq = 10000
search_params = {"metric_type": "L2", "params": {"nprobe": 10}} search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
t0 = time.time()
print(f"\nSearch...")
# define output_fields of search result # define output_fields of search result
res = collection.search( for i in range(3):
test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT t0 = time.time()
) logger.info(f"Search...")
t1 = time.time() res = collection.search(
print(f"search cost {t1 - t0:.4f} seconds") test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT
result_ids = [] )
for hits in res: t1 = time.time()
result_id = [] logger.info(f"search cost {t1 - t0:.4f} seconds")
for hit in hits: result_ids = []
result_id.append(hit.entity.get("int64")) for hits in res:
result_ids.append(result_id) result_id = []
for hit in hits:
# calculate recall result_id.append(hit.entity.get("int64"))
true_ids = neighbors[:nq, :topK] result_ids.append(result_id)
sum_radio = 0.0
for index, item in enumerate(result_ids):
# tmp = set(item).intersection(set(flat_id_list[index]))
assert len(item) == len(true_ids[index])
tmp = set(true_ids[index]).intersection(set(item))
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
assert recall >= 0.95
print(f"recall={recall}")
# calculate recall
true_ids = neighbors[:nq, :topK]
sum_radio = 0.0
logger.info(f"Calculate recall...")
for index, item in enumerate(result_ids):
# tmp = set(item).intersection(set(flat_id_list[index]))
assert len(item) == len(true_ids[index])
tmp = set(true_ids[index]).intersection(set(item))
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
logger.info(f"recall={recall}")
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
# query # query
expr = "int64 in [2,4,6,8]" expr = "int64 in [2,4,6,8]"
output_fields = ["int64", "float"] output_fields = ["int64", "float"]
res = collection.query(expr, output_fields, timeout=TIMEOUT) res = collection.query(expr, output_fields, timeout=TIMEOUT)
sorted_res = sorted(res, key=lambda k: k['int64']) sorted_res = sorted(res, key=lambda k: k['int64'])
for r in sorted_res: for r in sorted_res:
print(r) logger.info(r)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,6 +2,7 @@ import h5py
import numpy as np import numpy as np
import time import time
from pathlib import Path from pathlib import Path
from loguru import logger
from pymilvus import connections, Collection from pymilvus import connections, Collection
@ -27,32 +28,34 @@ def search_test(host="127.0.0.1"):
nq = 10000 nq = 10000
topK = 100 topK = 100
search_params = {"metric_type": "L2", "params": {"nprobe": 10}} search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
t0 = time.time() for i in range(3):
print(f"\nSearch...") t0 = time.time()
# define output_fields of search result logger.info(f"\nSearch...")
res = collection.search( # define output_fields of search result
test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT res = collection.search(
) test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT
t1 = time.time() )
print(f"search cost {t1 - t0:.4f} seconds") t1 = time.time()
result_ids = [] logger.info(f"search cost {t1 - t0:.4f} seconds")
for hits in res: result_ids = []
result_id = [] for hits in res:
for hit in hits: result_id = []
result_id.append(hit.entity.get("int64")) for hit in hits:
result_ids.append(result_id) result_id.append(hit.entity.get("int64"))
result_ids.append(result_id)
# calculate recall
true_ids = neighbors[:nq, :topK]
sum_radio = 0.0
for index, item in enumerate(result_ids):
# tmp = set(item).intersection(set(flat_id_list[index]))
assert len(item) == len(true_ids[index]), f"get {len(item)} but expect {len(true_ids[index])}"
tmp = set(true_ids[index]).intersection(set(item))
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
logger.info(f"recall={recall}")
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
# calculate recall
true_ids = neighbors[:nq,:topK]
sum_radio = 0.0
for index, item in enumerate(result_ids):
# tmp = set(item).intersection(set(flat_id_list[index]))
assert len(item) == len(true_ids[index]), f"get {len(item)} but expect {len(true_ids[index])}"
tmp = set(true_ids[index]).intersection(set(item))
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
assert recall >= 0.95, f"recall is {recall}, less than 0.95"
print(f"recall={recall}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -38,4 +38,4 @@ minio==7.1.5
h5py==3.1.0 h5py==3.1.0
# for log # for log
loguru==0.5.3 loguru==0.6.0