mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
[test] Update recall test (#22358)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
5534d38449
commit
fc328870db
@ -1,6 +1,9 @@
|
||||
import threading
|
||||
import h5py
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
import pymilvus
|
||||
@ -12,6 +15,47 @@ from pymilvus import (
|
||||
|
||||
pymilvus_version = pymilvus.__version__
|
||||
|
||||
|
||||
# Index types exercised by the recall test, in the order they are run.
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
# Default build parameters, positionally aligned with all_index_types.
default_index_params = [{"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
                        {"M": 48, "efConstruction": 500}, {"n_trees": 50}]
# Lookup table: index type name -> default build parameters.
index_params_map = dict(zip(all_index_types, default_index_params))


def gen_index_params(index_type, metric_type="L2"):
    """Build the index description dict for ``index_type``.

    Args:
        index_type: Name of the index; must be a key of ``index_params_map``.
        metric_type: Metric stored under ``"metric_type"``. It is overridden
            with ``"HAMMING"`` for the binary index types.

    Returns:
        A new dict with the keys ``"index_type"``, ``"params"`` and
        ``"metric_type"``. The ``"params"`` value is an independent copy, so
        mutating the result never alters ``index_params_map``.

    Raises:
        KeyError: If ``index_type`` has no entry in ``index_params_map``.
    """
    index = {
        "index_type": index_type,
        # Copy the parameters: handing out the module-level dict itself would
        # let a caller's in-place edit leak into every later call.
        "params": copy.deepcopy(index_params_map[index_type]),
        "metric_type": metric_type,
    }
    # Only reachable once binary index types are added to index_params_map;
    # until then the lookup above raises KeyError for them.
    if index_type in ("BIN_FLAT", "BIN_IVF_FLAT"):
        index["metric_type"] = "HAMMING"
    return index
|
||||
|
||||
|
||||
def gen_search_param(index_type, metric_type="L2"):
    """Return the search parameters used for ``index_type``.

    Args:
        index_type: Name of the index the collection was built with.
        metric_type: Metric stored under ``"metric_type"``; ignored for the
            binary index types, which always use ``"HAMMING"``.

    Returns:
        A new dict with the keys ``"metric_type"`` and ``"params"``.

    Raises:
        ValueError: If ``index_type`` is not one of the supported index types.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.
    """
    if index_type in ("FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"):
        return {"metric_type": metric_type, "params": {"nprobe": 10}}
    if index_type in ("BIN_FLAT", "BIN_IVF_FLAT"):
        return {"metric_type": "HAMMING", "params": {"nprobe": 10}}
    if index_type == "HNSW":
        return {"metric_type": metric_type, "params": {"ef": 200}}
    if index_type == "ANNOY":
        return {"metric_type": metric_type, "params": {"search_k": 1000}}
    # Name the offending value so a failed run can be diagnosed from the log.
    message = f"Invalid index_type: {index_type!r}"
    logger.error(message)
    raise ValueError(message)
|
||||
|
||||
|
||||
def read_benchmark_hdf5(file_path):
|
||||
|
||||
f = h5py.File(file_path, 'r')
|
||||
@ -26,7 +70,8 @@ dim = 128
|
||||
TIMEOUT = 200
|
||||
|
||||
|
||||
def milvus_recall_test(host='127.0.0.1'):
|
||||
def milvus_recall_test(host='127.0.0.1', index_type="HNSW"):
|
||||
logger.info(f"recall test for index type {index_type}")
|
||||
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
|
||||
train, test, neighbors = read_benchmark_hdf5(file_path)
|
||||
connections.connect(host=host, port="19530")
|
||||
@ -39,7 +84,7 @@ def milvus_recall_test(host='127.0.0.1'):
|
||||
default_schema = CollectionSchema(
|
||||
fields=default_fields, description="test collection")
|
||||
|
||||
name = f"sift_128_euclidean"
|
||||
name = f"sift_128_euclidean_{index_type}"
|
||||
logger.info(f"Create collection {name}")
|
||||
collection = Collection(name=name, schema=default_schema)
|
||||
nb = len(train)
|
||||
@ -73,8 +118,7 @@ def milvus_recall_test(host='127.0.0.1'):
|
||||
logger.info(f"Get collection entities cost {t1 - t0:.4f} seconds")
|
||||
|
||||
# create index
|
||||
default_index = {"index_type": "IVF_SQ8",
|
||||
"metric_type": "L2", "params": {"nlist": 64}}
|
||||
default_index = gen_index_params(index_type)
|
||||
logger.info(f"Create index...")
|
||||
t0 = time.time()
|
||||
collection.create_index(field_name="float_vector",
|
||||
@ -103,14 +147,14 @@ def milvus_recall_test(host='127.0.0.1'):
|
||||
# search
|
||||
topK = 100
|
||||
nq = 10000
|
||||
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
|
||||
current_search_params = gen_search_param(index_type)
|
||||
|
||||
# define output_fields of search result
|
||||
for i in range(3):
|
||||
t0 = time.time()
|
||||
logger.info(f"Search...")
|
||||
res = collection.search(
|
||||
test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT
|
||||
test[:nq], "float_vector", current_search_params, topK, output_fields=["int64"], timeout=TIMEOUT
|
||||
)
|
||||
t1 = time.time()
|
||||
logger.info(f"search cost {t1 - t0:.4f} seconds")
|
||||
@ -132,7 +176,10 @@ def milvus_recall_test(host='127.0.0.1'):
|
||||
sum_radio = sum_radio + len(tmp) / len(item)
|
||||
recall = round(sum_radio / len(result_ids), 3)
|
||||
logger.info(f"recall={recall}")
|
||||
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
|
||||
if index_type in ["IVF_PQ", "ANNOY"]:
|
||||
assert recall >= 0.6, f"recall={recall} < 0.6"
|
||||
else:
|
||||
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
|
||||
# query
|
||||
expr = "int64 in [2,4,6,8]"
|
||||
output_fields = ["int64", "float"]
|
||||
@ -149,4 +196,7 @@ if __name__ == "__main__":
|
||||
default="127.0.0.1", help='milvus server ip')
|
||||
args = parser.parse_args()
|
||||
host = args.host
|
||||
milvus_recall_test(host)
|
||||
tasks = []
|
||||
for index_type in all_index_types:
|
||||
milvus_recall_test(host, index_type)
|
||||
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
import h5py
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from pymilvus import connections, Collection
|
||||
|
||||
|
||||
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
|
||||
|
||||
|
||||
def read_benchmark_hdf5(file_path):
|
||||
|
||||
f = h5py.File(file_path, 'r')
|
||||
@ -16,18 +21,43 @@ def read_benchmark_hdf5(file_path):
|
||||
return train, test, neighbors
|
||||
|
||||
|
||||
def gen_search_param(index_type, metric_type="L2"):
    """Return the search parameters used for ``index_type``.

    Args:
        index_type: Name of the index the collection was built with.
        metric_type: Metric stored under ``"metric_type"``; ignored for the
            binary index types, which always use ``"HAMMING"``.

    Returns:
        A new dict with the keys ``"metric_type"`` and ``"params"``.

    Raises:
        ValueError: If ``index_type`` is not one of the supported index types.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.
    """
    if index_type in ("FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"):
        return {"metric_type": metric_type, "params": {"nprobe": 10}}
    if index_type in ("BIN_FLAT", "BIN_IVF_FLAT"):
        return {"metric_type": "HAMMING", "params": {"nprobe": 10}}
    if index_type == "HNSW":
        return {"metric_type": metric_type, "params": {"ef": 200}}
    if index_type == "ANNOY":
        return {"metric_type": metric_type, "params": {"search_k": 1000}}
    # Name the offending value so a failed run can be diagnosed from the log.
    message = f"Invalid index_type: {index_type!r}"
    logger.error(message)
    raise ValueError(message)
|
||||
|
||||
|
||||
dim = 128
|
||||
TIMEOUT = 200
|
||||
|
||||
|
||||
def search_test(host="127.0.0.1"):
|
||||
def search_test(host="127.0.0.1", index_type="HNSW"):
|
||||
logger.info(f"recall test for index type {index_type}")
|
||||
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
|
||||
train, test, neighbors = read_benchmark_hdf5(file_path)
|
||||
connections.connect(host=host, port="19530")
|
||||
collection = Collection(name="sift_128_euclidean")
|
||||
collection = Collection(name=f"sift_128_euclidean_{index_type}")
|
||||
nq = 10000
|
||||
topK = 100
|
||||
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
|
||||
search_params = gen_search_param(index_type)
|
||||
for i in range(3):
|
||||
t0 = time.time()
|
||||
logger.info(f"\nSearch...")
|
||||
@ -54,14 +84,20 @@ def search_test(host="127.0.0.1"):
|
||||
sum_radio = sum_radio + len(tmp) / len(item)
|
||||
recall = round(sum_radio / len(result_ids), 3)
|
||||
logger.info(f"recall={recall}")
|
||||
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
|
||||
if index_type in ["IVF_PQ", "ANNOY"]:
|
||||
assert recall >= 0.6, f"recall={recall} < 0.6"
|
||||
else:
|
||||
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import threading
|
||||
parser = argparse.ArgumentParser(description='config for recall test')
|
||||
parser.add_argument('--host', type=str, default="127.0.0.1", help='milvus server ip')
|
||||
args = parser.parse_args()
|
||||
host = args.host
|
||||
search_test(host)
|
||||
tasks = []
|
||||
for index_type in all_index_types:
|
||||
search_test(host, index_type)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user