mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 17:48:29 +08:00
[test] Update recall test (#22358)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
5534d38449
commit
fc328870db
@ -1,6 +1,9 @@
|
|||||||
|
import threading
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
import sys
|
||||||
|
import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import pymilvus
|
import pymilvus
|
||||||
@ -12,6 +15,47 @@ from pymilvus import (
|
|||||||
|
|
||||||
pymilvus_version = pymilvus.__version__
|
pymilvus_version = pymilvus.__version__
|
||||||
|
|
||||||
|
|
||||||
|
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
|
||||||
|
default_index_params = [{"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
|
||||||
|
{"M": 48, "efConstruction": 500}, {"n_trees": 50}]
|
||||||
|
index_params_map = dict(zip(all_index_types, default_index_params))
|
||||||
|
|
||||||
|
|
||||||
|
def gen_index_params(index_type, metric_type="L2"):
|
||||||
|
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": metric_type}
|
||||||
|
index = copy.deepcopy(default_index)
|
||||||
|
index["index_type"] = index_type
|
||||||
|
index["params"] = index_params_map[index_type]
|
||||||
|
if index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
|
||||||
|
index["metric_type"] = "HAMMING"
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def gen_search_param(index_type, metric_type="L2"):
|
||||||
|
search_params = []
|
||||||
|
if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]:
|
||||||
|
for nprobe in [10]:
|
||||||
|
ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}}
|
||||||
|
search_params.append(ivf_search_params)
|
||||||
|
elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
|
||||||
|
for nprobe in [10]:
|
||||||
|
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
|
||||||
|
search_params.append(bin_search_params)
|
||||||
|
elif index_type in ["HNSW"]:
|
||||||
|
for ef in [200]:
|
||||||
|
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
|
||||||
|
search_params.append(hnsw_search_param)
|
||||||
|
elif index_type == "ANNOY":
|
||||||
|
for search_k in [1000]:
|
||||||
|
annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}}
|
||||||
|
search_params.append(annoy_search_param)
|
||||||
|
else:
|
||||||
|
logger.info("Invalid index_type.")
|
||||||
|
raise Exception("Invalid index_type.")
|
||||||
|
return search_params[0]
|
||||||
|
|
||||||
|
|
||||||
def read_benchmark_hdf5(file_path):
|
def read_benchmark_hdf5(file_path):
|
||||||
|
|
||||||
f = h5py.File(file_path, 'r')
|
f = h5py.File(file_path, 'r')
|
||||||
@ -26,7 +70,8 @@ dim = 128
|
|||||||
TIMEOUT = 200
|
TIMEOUT = 200
|
||||||
|
|
||||||
|
|
||||||
def milvus_recall_test(host='127.0.0.1'):
|
def milvus_recall_test(host='127.0.0.1', index_type="HNSW"):
|
||||||
|
logger.info(f"recall test for index type {index_type}")
|
||||||
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
|
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
|
||||||
train, test, neighbors = read_benchmark_hdf5(file_path)
|
train, test, neighbors = read_benchmark_hdf5(file_path)
|
||||||
connections.connect(host=host, port="19530")
|
connections.connect(host=host, port="19530")
|
||||||
@ -39,7 +84,7 @@ def milvus_recall_test(host='127.0.0.1'):
|
|||||||
default_schema = CollectionSchema(
|
default_schema = CollectionSchema(
|
||||||
fields=default_fields, description="test collection")
|
fields=default_fields, description="test collection")
|
||||||
|
|
||||||
name = f"sift_128_euclidean"
|
name = f"sift_128_euclidean_{index_type}"
|
||||||
logger.info(f"Create collection {name}")
|
logger.info(f"Create collection {name}")
|
||||||
collection = Collection(name=name, schema=default_schema)
|
collection = Collection(name=name, schema=default_schema)
|
||||||
nb = len(train)
|
nb = len(train)
|
||||||
@ -73,8 +118,7 @@ def milvus_recall_test(host='127.0.0.1'):
|
|||||||
logger.info(f"Get collection entities cost {t1 - t0:.4f} seconds")
|
logger.info(f"Get collection entities cost {t1 - t0:.4f} seconds")
|
||||||
|
|
||||||
# create index
|
# create index
|
||||||
default_index = {"index_type": "IVF_SQ8",
|
default_index = gen_index_params(index_type)
|
||||||
"metric_type": "L2", "params": {"nlist": 64}}
|
|
||||||
logger.info(f"Create index...")
|
logger.info(f"Create index...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
collection.create_index(field_name="float_vector",
|
collection.create_index(field_name="float_vector",
|
||||||
@ -103,14 +147,14 @@ def milvus_recall_test(host='127.0.0.1'):
|
|||||||
# search
|
# search
|
||||||
topK = 100
|
topK = 100
|
||||||
nq = 10000
|
nq = 10000
|
||||||
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
|
current_search_params = gen_search_param(index_type)
|
||||||
|
|
||||||
# define output_fields of search result
|
# define output_fields of search result
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
logger.info(f"Search...")
|
logger.info(f"Search...")
|
||||||
res = collection.search(
|
res = collection.search(
|
||||||
test[:nq], "float_vector", search_params, topK, output_fields=["int64"], timeout=TIMEOUT
|
test[:nq], "float_vector", current_search_params, topK, output_fields=["int64"], timeout=TIMEOUT
|
||||||
)
|
)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
logger.info(f"search cost {t1 - t0:.4f} seconds")
|
logger.info(f"search cost {t1 - t0:.4f} seconds")
|
||||||
@ -132,7 +176,10 @@ def milvus_recall_test(host='127.0.0.1'):
|
|||||||
sum_radio = sum_radio + len(tmp) / len(item)
|
sum_radio = sum_radio + len(tmp) / len(item)
|
||||||
recall = round(sum_radio / len(result_ids), 3)
|
recall = round(sum_radio / len(result_ids), 3)
|
||||||
logger.info(f"recall={recall}")
|
logger.info(f"recall={recall}")
|
||||||
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
|
if index_type in ["IVF_PQ", "ANNOY"]:
|
||||||
|
assert recall >= 0.6, f"recall={recall} < 0.6"
|
||||||
|
else:
|
||||||
|
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
|
||||||
# query
|
# query
|
||||||
expr = "int64 in [2,4,6,8]"
|
expr = "int64 in [2,4,6,8]"
|
||||||
output_fields = ["int64", "float"]
|
output_fields = ["int64", "float"]
|
||||||
@ -149,4 +196,7 @@ if __name__ == "__main__":
|
|||||||
default="127.0.0.1", help='milvus server ip')
|
default="127.0.0.1", help='milvus server ip')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
host = args.host
|
host = args.host
|
||||||
milvus_recall_test(host)
|
tasks = []
|
||||||
|
for index_type in all_index_types:
|
||||||
|
milvus_recall_test(host, index_type)
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pymilvus import connections, Collection
|
from pymilvus import connections, Collection
|
||||||
|
|
||||||
|
|
||||||
|
all_index_types = ["IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"]
|
||||||
|
|
||||||
|
|
||||||
def read_benchmark_hdf5(file_path):
|
def read_benchmark_hdf5(file_path):
|
||||||
|
|
||||||
f = h5py.File(file_path, 'r')
|
f = h5py.File(file_path, 'r')
|
||||||
@ -16,18 +21,43 @@ def read_benchmark_hdf5(file_path):
|
|||||||
return train, test, neighbors
|
return train, test, neighbors
|
||||||
|
|
||||||
|
|
||||||
|
def gen_search_param(index_type, metric_type="L2"):
|
||||||
|
search_params = []
|
||||||
|
if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]:
|
||||||
|
for nprobe in [10]:
|
||||||
|
ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}}
|
||||||
|
search_params.append(ivf_search_params)
|
||||||
|
elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
|
||||||
|
for nprobe in [10]:
|
||||||
|
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
|
||||||
|
search_params.append(bin_search_params)
|
||||||
|
elif index_type in ["HNSW"]:
|
||||||
|
for ef in [200]:
|
||||||
|
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
|
||||||
|
search_params.append(hnsw_search_param)
|
||||||
|
elif index_type == "ANNOY":
|
||||||
|
for search_k in [1000]:
|
||||||
|
annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}}
|
||||||
|
search_params.append(annoy_search_param)
|
||||||
|
else:
|
||||||
|
logger.info("Invalid index_type.")
|
||||||
|
raise Exception("Invalid index_type.")
|
||||||
|
return search_params[0]
|
||||||
|
|
||||||
|
|
||||||
dim = 128
|
dim = 128
|
||||||
TIMEOUT = 200
|
TIMEOUT = 200
|
||||||
|
|
||||||
|
|
||||||
def search_test(host="127.0.0.1"):
|
def search_test(host="127.0.0.1", index_type="HNSW"):
|
||||||
|
logger.info(f"recall test for index type {index_type}")
|
||||||
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
|
file_path = f"{str(Path(__file__).absolute().parent.parent.parent)}/assets/ann_hdf5/sift-128-euclidean.hdf5"
|
||||||
train, test, neighbors = read_benchmark_hdf5(file_path)
|
train, test, neighbors = read_benchmark_hdf5(file_path)
|
||||||
connections.connect(host=host, port="19530")
|
connections.connect(host=host, port="19530")
|
||||||
collection = Collection(name="sift_128_euclidean")
|
collection = Collection(name=f"sift_128_euclidean_{index_type}")
|
||||||
nq = 10000
|
nq = 10000
|
||||||
topK = 100
|
topK = 100
|
||||||
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
|
search_params = gen_search_param(index_type)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
logger.info(f"\nSearch...")
|
logger.info(f"\nSearch...")
|
||||||
@ -54,14 +84,20 @@ def search_test(host="127.0.0.1"):
|
|||||||
sum_radio = sum_radio + len(tmp) / len(item)
|
sum_radio = sum_radio + len(tmp) / len(item)
|
||||||
recall = round(sum_radio / len(result_ids), 3)
|
recall = round(sum_radio / len(result_ids), 3)
|
||||||
logger.info(f"recall={recall}")
|
logger.info(f"recall={recall}")
|
||||||
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95"
|
if index_type in ["IVF_PQ", "ANNOY"]:
|
||||||
|
assert recall >= 0.6, f"recall={recall} < 0.6"
|
||||||
|
else:
|
||||||
|
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
import threading
|
||||||
parser = argparse.ArgumentParser(description='config for recall test')
|
parser = argparse.ArgumentParser(description='config for recall test')
|
||||||
parser.add_argument('--host', type=str, default="127.0.0.1", help='milvus server ip')
|
parser.add_argument('--host', type=str, default="127.0.0.1", help='milvus server ip')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
host = args.host
|
host = args.host
|
||||||
search_test(host)
|
tasks = []
|
||||||
|
for index_type in all_index_types:
|
||||||
|
search_test(host, index_type)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user