[test] Update recall test (#22358)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2023-02-23 16:43:46 +08:00 committed by GitHub
parent 5534d38449
commit fc328870db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 13 deletions

View File

@ -1,6 +1,9 @@
import threading
import h5py import h5py
import numpy as np import numpy as np
import time import time
import sys
import copy
from pathlib import Path from pathlib import Path
from loguru import logger from loguru import logger
import pymilvus import pymilvus
@ -12,6 +15,47 @@ from pymilvus import (
pymilvus_version = pymilvus.__version__ pymilvus_version = pymilvus.__version__
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
default_index_params = [{"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
{"M": 48, "efConstruction": 500}, {"n_trees": 50}]
index_params_map = dict(zip(all_index_types, default_index_params))
def gen_index_params(index_type, metric_type="L2"):
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": metric_type}
index = copy.deepcopy(default_index)
index["index_type"] = index_type
index["params"] = index_params_map[index_type]
if index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
index["metric_type"] = "HAMMING"
return index
def gen_search_param(index_type, metric_type="L2"):
search_params = []
if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]:
for nprobe in [10]:
ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}}
search_params.append(ivf_search_params)
elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
for nprobe in [10]:
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
search_params.append(bin_search_params)
elif index_type in ["HNSW"]:
for ef in [200]:
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
search_params.append(hnsw_search_param)
elif index_type == "ANNOY":
for search_k in [1000]:
annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}}
search_params.append(annoy_search_param)
else:
logger.info("Invalid index_type.")
raise Exception("Invalid index_type.")
return search_params[0]
def read_benchmark_hdf5(file_path): def read_benchmark_hdf5(file_path):
f = h5py.File(file_path, 'r') f = h5py.File(file_path, 'r')
@ -26,7 +70,8 @@ dim = 128
TIMEOUT = 200 TIMEOUT = 200
def milvus_recall_test(host='127.0.0.1'): def milvus_recall_test(host='127.0.0.1', index_type="HNSW"):
logger.info(f"recall test for index type {index_type}")
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5" file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
train, test, neighbors = read_benchmark_hdf5(file_path) train, test, neighbors = read_benchmark_hdf5(file_path)
connections.connect(host=host, port="19530") connections.connect(host=host, port="19530")
@ -39,7 +84,7 @@ def milvus_recall_test(host='127.0.0.1'):
default_schema = CollectionSchema( default_schema = CollectionSchema(
fields=default_fields, description="test collection") fields=default_fields, description="test collection")
name = f"sift_128_euclidean" name = f"sift_128_euclidean_{index_type}"
logger.info(f"Create collection {name}") logger.info(f"Create collection {name}")
collection = Collection(name=name, schema=default_schema) collection = Collection(name=name, schema=default_schema)
nb = len(train) nb = len(train)
@ -73,8 +118,7 @@ def milvus_recall_test(host='127.0.0.1'):
logger.info(f"Get collection entities cost {t1 - t0:.4f} seconds") logger.info(f"Get collection entities cost {t1 - t0:.4f} seconds")
# create index # create index
default_index = {"index_type": "IVF_SQ8", default_index = gen_index_params(index_type)
"metric_type": "L2", "params": {"nlist": 64}}
logger.info(f"Create index...") logger.info(f"Create index...")
t0 = time.time() t0 = time.time()
collection.create_index(field_name="float_vector", collection.create_index(field_name="float_vector",
@ -103,14 +147,14 @@ def milvus_recall_test(host='127.0.0.1'):
# search # search
topK = 100 topK = 100
nq = 10000 nq = 10000
search_params = {"metric_type": "L2", "params": {"nprobe": 10}} current_search_params = gen_search_param(index_type)
# define output_fields of search result # define output_fields of search result
for i in range(3): for i in range(3):
t0 = time.time() t0 = time.time()
logger.info(f"Search...") logger.info(f"Search...")
res = collection.search( res = collection.search(
test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT test[:nq], "float_vector", current_search_params, topK, output_fields=["int64"], timeout=TIMEOUT
) )
t1 = time.time() t1 = time.time()
logger.info(f"search cost {t1 - t0:.4f} seconds") logger.info(f"search cost {t1 - t0:.4f} seconds")
@ -132,7 +176,10 @@ def milvus_recall_test(host='127.0.0.1'):
sum_radio = sum_radio + len(tmp) / len(item) sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3) recall = round(sum_radio / len(result_ids), 3)
logger.info(f"recall={recall}") logger.info(f"recall={recall}")
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95" if index_type in ["IVF_PQ", "ANNOY"]:
assert recall >= 0.6, f"recall={recall} < 0.6"
else:
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
# query # query
expr = "int64 in [2,4,6,8]" expr = "int64 in [2,4,6,8]"
output_fields = ["int64", "float"] output_fields = ["int64", "float"]
@ -149,4 +196,7 @@ if __name__ == "__main__":
default="127.0.0.1", help='milvus server ip') default="127.0.0.1", help='milvus server ip')
args = parser.parse_args() args = parser.parse_args()
host = args.host host = args.host
milvus_recall_test(host) tasks = []
for index_type in all_index_types:
milvus_recall_test(host, index_type)

View File

@ -1,11 +1,16 @@
import h5py import h5py
import numpy as np import numpy as np
import time import time
import sys
import threading
from pathlib import Path from pathlib import Path
from loguru import logger from loguru import logger
from pymilvus import connections, Collection from pymilvus import connections, Collection
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
def read_benchmark_hdf5(file_path): def read_benchmark_hdf5(file_path):
f = h5py.File(file_path, 'r') f = h5py.File(file_path, 'r')
@ -16,18 +21,43 @@ def read_benchmark_hdf5(file_path):
return train, test, neighbors return train, test, neighbors
def gen_search_param(index_type, metric_type="L2"):
search_params = []
if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]:
for nprobe in [10]:
ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}}
search_params.append(ivf_search_params)
elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
for nprobe in [10]:
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
search_params.append(bin_search_params)
elif index_type in ["HNSW"]:
for ef in [200]:
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
search_params.append(hnsw_search_param)
elif index_type == "ANNOY":
for search_k in [1000]:
annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}}
search_params.append(annoy_search_param)
else:
logger.info("Invalid index_type.")
raise Exception("Invalid index_type.")
return search_params[0]
dim = 128 dim = 128
TIMEOUT = 200 TIMEOUT = 200
def search_test(host="127.0.0.1"): def search_test(host="127.0.0.1", index_type="HNSW"):
logger.info(f"recall test for index type {index_type}")
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5" file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
train, test, neighbors = read_benchmark_hdf5(file_path) train, test, neighbors = read_benchmark_hdf5(file_path)
connections.connect(host=host, port="19530") connections.connect(host=host, port="19530")
collection = Collection(name="sift_128_euclidean") collection = Collection(name=f"sift_128_euclidean_{index_type}")
nq = 10000 nq = 10000
topK = 100 topK = 100
search_params = {"metric_type": "L2", "params": {"nprobe": 10}} search_params = gen_search_param(index_type)
for i in range(3): for i in range(3):
t0 = time.time() t0 = time.time()
logger.info(f"\nSearch...") logger.info(f"\nSearch...")
@ -54,14 +84,20 @@ def search_test(host="127.0.0.1"):
sum_radio = sum_radio + len(tmp) / len(item) sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3) recall = round(sum_radio / len(result_ids), 3)
logger.info(f"recall={recall}") logger.info(f"recall={recall}")
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95" if index_type in ["IVF_PQ", "ANNOY"]:
assert recall >= 0.6, f"recall={recall} < 0.6"
else:
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
import threading
parser = argparse.ArgumentParser(description='config for recall test') parser = argparse.ArgumentParser(description='config for recall test')
parser.add_argument('--host', type=str, default="127.0.0.1", help='milvus server ip') parser.add_argument('--host', type=str, default="127.0.0.1", help='milvus server ip')
args = parser.parse_args() args = parser.parse_args()
host = args.host host = args.host
search_test(host) tasks = []
for index_type in all_index_types:
search_test(host, index_type)