[test] Update recall test (#22358)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2023-02-23 16:43:46 +08:00 committed by GitHub
parent 5534d38449
commit fc328870db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 13 deletions

View File

@ -1,6 +1,9 @@
import threading
import h5py
import numpy as np
import time
import sys
import copy
from pathlib import Path
from loguru import logger
import pymilvus
@ -12,6 +15,47 @@ from pymilvus import (
# Version string reported by the imported pymilvus package.
pymilvus_version = pymilvus.__version__
# Index types covered by this script. The order matters: this list is zipped
# positionally with default_index_params below.
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
# Default build parameters, one entry per element of all_index_types (same order).
default_index_params = [{"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
{"M": 48, "efConstruction": 500}, {"n_trees": 50}]
# Lookup table: index type name -> its default build parameters.
index_params_map = dict(zip(all_index_types, default_index_params))
def gen_index_params(index_type, metric_type="L2"):
    """Build the index-creation parameters for ``index_type``.

    Args:
        index_type: Name of the index; must be a key of ``index_params_map``.
        metric_type: Metric stored in the result. It is replaced by
            ``"HAMMING"`` when ``index_type`` is ``"BIN_FLAT"`` or
            ``"BIN_IVF_FLAT"``.

    Returns:
        A new dict with the keys ``"index_type"``, ``"params"`` and
        ``"metric_type"``.

    Raises:
        KeyError: If ``index_type`` has no entry in ``index_params_map``.
    """
    # Look the defaults up first so an unknown index type fails immediately.
    default_params = index_params_map[index_type]
    if index_type in ("BIN_FLAT", "BIN_IVF_FLAT"):
        metric_type = "HAMMING"
    return {
        "index_type": index_type,
        # Deep-copy: handing out the dict stored in index_params_map would let
        # a caller that tweaks the result silently change the shared defaults.
        "params": copy.deepcopy(default_params),
        "metric_type": metric_type,
    }
def gen_search_param(index_type, metric_type="L2"):
    """Return the search parameters used for ``index_type``.

    Args:
        index_type: Name of the index the collection was built with.
        metric_type: Metric placed in the result; ignored for the binary
            index types, which always use ``"HAMMING"``.

    Returns:
        A new dict of the form ``{"metric_type": ..., "params": {...}}``.

    Raises:
        Exception: If ``index_type`` is not one of the handled names.
    """
    # IVF-style indexes (and FLAT) are searched with nprobe.
    if index_type in ("FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"):
        return {"metric_type": metric_type, "params": {"nprobe": 10}}
    # Binary indexes: same knob, but the metric is fixed.
    if index_type in ("BIN_FLAT", "BIN_IVF_FLAT"):
        return {"metric_type": "HAMMING", "params": {"nprobe": 10}}
    if index_type in ("HNSW",):
        return {"metric_type": metric_type, "params": {"ef": 200}}
    if index_type == "ANNOY":
        return {"metric_type": metric_type, "params": {"search_k": 1000}}
    logger.info("Invalid index_type.")
    raise Exception("Invalid index_type.")
def read_benchmark_hdf5(file_path):
f = h5py.File(file_path, 'r')
@ -26,7 +70,8 @@ dim = 128
TIMEOUT = 200
def milvus_recall_test(host='127.0.0.1'):
def milvus_recall_test(host='127.0.0.1', index_type="HNSW"):
logger.info(f"recall test for index type {index_type}")
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
train, test, neighbors = read_benchmark_hdf5(file_path)
connections.connect(host=host, port="19530")
@ -39,7 +84,7 @@ def milvus_recall_test(host='127.0.0.1'):
default_schema = CollectionSchema(
fields=default_fields, description="test collection")
name = f"sift_128_euclidean"
name = f"sift_128_euclidean_{index_type}"
logger.info(f"Create collection {name}")
collection = Collection(name=name, schema=default_schema)
nb = len(train)
@ -73,8 +118,7 @@ def milvus_recall_test(host='127.0.0.1'):
logger.info(f"Get collection entities cost {t1 - t0:.4f} seconds")
# create index
default_index = {"index_type": "IVF_SQ8",
"metric_type": "L2", "params": {"nlist": 64}}
default_index = gen_index_params(index_type)
logger.info(f"Create index...")
t0 = time.time()
collection.create_index(field_name="float_vector",
@ -103,14 +147,14 @@ def milvus_recall_test(host='127.0.0.1'):
# search
topK = 100
nq = 10000
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
current_search_params = gen_search_param(index_type)
# define output_fields of search result
for i in range(3):
t0 = time.time()
logger.info(f"Search...")
res = collection.search(
test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT
test[:nq], "float_vector", current_search_params, topK, output_fields=["int64"], timeout=TIMEOUT
)
t1 = time.time()
logger.info(f"search cost {t1 - t0:.4f} seconds")
@ -132,7 +176,10 @@ def milvus_recall_test(host='127.0.0.1'):
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
logger.info(f"recall={recall}")
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
if index_type in ["IVF_PQ", "ANNOY"]:
assert recall >= 0.6, f"recall={recall} < 0.6"
else:
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
# query
expr = "int64 in [2,4,6,8]"
output_fields = ["int64", "float"]
@ -149,4 +196,7 @@ if __name__ == "__main__":
default="127.0.0.1", help='milvus server ip')
args = parser.parse_args()
host = args.host
milvus_recall_test(host)
tasks = []
for index_type in all_index_types:
milvus_recall_test(host, index_type)

View File

@ -1,11 +1,16 @@
import h5py
import numpy as np
import time
import sys
import threading
from pathlib import Path
from loguru import logger
from pymilvus import connections, Collection
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
def read_benchmark_hdf5(file_path):
f = h5py.File(file_path, 'r')
@ -16,18 +21,43 @@ def read_benchmark_hdf5(file_path):
return train, test, neighbors
def gen_search_param(index_type, metric_type="L2"):
    """Pick the search parameters that go with ``index_type``.

    Args:
        index_type: Name of the index the target collection uses.
        metric_type: Metric for the result; the binary index types override
            it with ``"HAMMING"``.

    Returns:
        A new dict ``{"metric_type": ..., "params": {...}}``.

    Raises:
        Exception: If ``index_type`` is not a handled index name.
    """
    ivf_family = ("FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ")
    binary_family = ("BIN_FLAT", "BIN_IVF_FLAT")

    # First choose the metric and the tuning knob, then assemble the result once.
    if index_type in ivf_family:
        metric, knobs = metric_type, {"nprobe": 10}
    elif index_type in binary_family:
        metric, knobs = "HAMMING", {"nprobe": 10}
    elif index_type in ("HNSW",):
        metric, knobs = metric_type, {"ef": 200}
    elif index_type == "ANNOY":
        metric, knobs = metric_type, {"search_k": 1000}
    else:
        logger.info("Invalid index_type.")
        raise Exception("Invalid index_type.")
    return {"metric_type": metric, "params": knobs}
dim = 128
TIMEOUT = 200
def search_test(host="127.0.0.1"):
def search_test(host="127.0.0.1", index_type="HNSW"):
logger.info(f"recall test for index type {index_type}")
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
train, test, neighbors = read_benchmark_hdf5(file_path)
connections.connect(host=host, port="19530")
collection = Collection(name="sift_128_euclidean")
collection = Collection(name=f"sift_128_euclidean_{index_type}")
nq = 10000
topK = 100
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
search_params = gen_search_param(index_type)
for i in range(3):
t0 = time.time()
logger.info(f"\nSearch...")
@ -54,14 +84,20 @@ def search_test(host="127.0.0.1"):
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
logger.info(f"recall={recall}")
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
if index_type in ["IVF_PQ", "ANNOY"]:
assert recall >= 0.6, f"recall={recall} < 0.6"
else:
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
# Script entry point: read the Milvus host from --host and run search_test
# against it.
if __name__ == "__main__":
import argparse
import threading
parser = argparse.ArgumentParser(description='config for recall test')
parser.add_argument('--host', type=str, default="127.0.0.1", help='milvus server ip')
args = parser.parse_args()
host = args.host
# NOTE(review): this text is a rendered diff with its +/- markers stripped; the
# single call below looks like the pre-change line that the loop further down
# replaces -- confirm against the actual file.
search_test(host)
# `tasks` is assigned here but never read in the visible lines.
tasks = []
# Run the search check once per entry of all_index_types.
for index_type in all_index_types:
search_test(host, index_type)