mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
[skip ci] Add comment for the accuracy test (#8092)
Signed-off-by: wangting0128 <ting.wang@zilliz.com>
This commit is contained in:
parent
5a008c72e0
commit
10ca86ef79
@ -120,6 +120,7 @@ class AccAccuracyRunner(AccuracyRunner):
|
||||
def extract_cases(self, collection):
|
||||
collection_name = collection["collection_name"] if "collection_name" in collection else None
|
||||
(data_type, dimension, metric_type) = parser.parse_ann_collection_name(collection_name)
|
||||
# hdf5_source_file: The path of the source data file saved on the NAS
|
||||
hdf5_source_file = collection["source_file"]
|
||||
index_types = collection["index_types"]
|
||||
index_params = collection["index_params"]
|
||||
@ -136,11 +137,14 @@ class AccAccuracyRunner(AccuracyRunner):
|
||||
}
|
||||
filters = collection["filters"] if "filters" in collection else []
|
||||
filter_query = []
|
||||
# Expand the search and index parameter lists into parameter dictionaries via utils.generate_combinations
|
||||
search_params = utils.generate_combinations(search_params)
|
||||
index_params = utils.generate_combinations(index_params)
|
||||
cases = list()
|
||||
case_metrics = list()
|
||||
self.init_metric(self.name, collection_info, {}, search_info=None)
|
||||
|
||||
# true_ids: the dataset["neighbors"] ids, used to verify the ids returned by the query
|
||||
true_ids = np.array(dataset["neighbors"])
|
||||
for index_type in index_types:
|
||||
for index_param in index_params:
|
||||
@ -192,11 +196,14 @@ class AccAccuracyRunner(AccuracyRunner):
|
||||
"vector_query": vector_query,
|
||||
"true_ids": true_ids
|
||||
}
|
||||
# Obtain the parameters of the use case to be tested
|
||||
cases.append(case)
|
||||
case_metrics.append(case_metric)
|
||||
return cases, case_metrics
|
||||
|
||||
def prepare(self, **case_param):
|
||||
""" According to the test case parameters, initialize the test """
|
||||
|
||||
collection_name = case_param["collection_name"]
|
||||
metric_type = case_param["metric_type"]
|
||||
dimension = case_param["dimension"]
|
||||
@ -211,6 +218,7 @@ class AccAccuracyRunner(AccuracyRunner):
|
||||
self.milvus.drop()
|
||||
dataset = case_param["dataset"]
|
||||
self.milvus.create_collection(dimension, data_type=vector_type)
|
||||
# Take the "train" split of the dataset (normalized for metric_type) as the vectors to insert into the collection
|
||||
insert_vectors = utils.normalize(metric_type, np.array(dataset["train"]))
|
||||
if len(insert_vectors) != dataset["train"].shape[0]:
|
||||
raise Exception("Row count of insert vectors: %d is not equal to dataset size: %d" % (
|
||||
@ -224,6 +232,7 @@ class AccAccuracyRunner(AccuracyRunner):
|
||||
start = i * INSERT_INTERVAL
|
||||
end = min((i + 1) * INSERT_INTERVAL, len(insert_vectors))
|
||||
if start < end:
|
||||
# Insert in batches of at most INSERT_INTERVAL (50000) vectors
|
||||
tmp_vectors = insert_vectors[start:end]
|
||||
ids = [i for i in range(start, end)]
|
||||
if not isinstance(tmp_vectors, list):
|
||||
@ -256,7 +265,9 @@ class AccAccuracyRunner(AccuracyRunner):
|
||||
top_k = case_metric.search["topk"]
|
||||
query_res = self.milvus.query(case_param["vector_query"], filter_query=case_param["filter_query"])
|
||||
result_ids = self.milvus.get_ids(query_res)
|
||||
# Compute the recall of the returned ids against the first nq rows / top_k columns of true_ids
|
||||
acc_value = utils.get_recall_value(true_ids[:nq, :top_k].tolist(), result_ids)
|
||||
tmp_result = {"acc": acc_value}
|
||||
# Return accuracy results for reporting
|
||||
return tmp_result
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user