From 10ca86ef79c6711fabca74372ae0789cdf5a1286 Mon Sep 17 00:00:00 2001 From: wt Date: Thu, 16 Sep 2021 17:59:49 +0800 Subject: [PATCH] [skip ci] Add comment for the accuracy test (#8092) Signed-off-by: wangting0128 --- tests/benchmark/milvus_benchmark/runners/accuracy.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/benchmark/milvus_benchmark/runners/accuracy.py b/tests/benchmark/milvus_benchmark/runners/accuracy.py index 0ec2a1aaae..c2be5d2b9e 100644 --- a/tests/benchmark/milvus_benchmark/runners/accuracy.py +++ b/tests/benchmark/milvus_benchmark/runners/accuracy.py @@ -120,6 +120,7 @@ class AccAccuracyRunner(AccuracyRunner): def extract_cases(self, collection): collection_name = collection["collection_name"] if "collection_name" in collection else None (data_type, dimension, metric_type) = parser.parse_ann_collection_name(collection_name) + # hdf5_source_file: The path of the source data file saved on the NAS hdf5_source_file = collection["source_file"] index_types = collection["index_types"] index_params = collection["index_params"] @@ -136,11 +137,14 @@ class AccAccuracyRunner(AccuracyRunner): } filters = collection["filters"] if "filters" in collection else [] filter_query = [] + # Convert list data into a set of dictionary data search_params = utils.generate_combinations(search_params) index_params = utils.generate_combinations(index_params) cases = list() case_metrics = list() self.init_metric(self.name, collection_info, {}, search_info=None) + + # true_ids: The data set used to verify the results returned by query true_ids = np.array(dataset["neighbors"]) for index_type in index_types: for index_param in index_params: @@ -192,11 +196,14 @@ class AccAccuracyRunner(AccuracyRunner): "vector_query": vector_query, "true_ids": true_ids } + # Obtain the parameters of the use case to be tested cases.append(case) case_metrics.append(case_metric) return cases, case_metrics def prepare(self, **case_param): + """ According to the test case 
parameters, initialize the test """ + collection_name = case_param["collection_name"] metric_type = case_param["metric_type"] dimension = case_param["dimension"] @@ -211,6 +218,7 @@ class AccAccuracyRunner(AccuracyRunner): self.milvus.drop() dataset = case_param["dataset"] self.milvus.create_collection(dimension, data_type=vector_type) + # Get the train data set for inserting into the collection insert_vectors = utils.normalize(metric_type, np.array(dataset["train"])) if len(insert_vectors) != dataset["train"].shape[0]: raise Exception("Row count of insert vectors: %d is not equal to dataset size: %d" % ( @@ -224,6 +232,7 @@ class AccAccuracyRunner(AccuracyRunner): start = i * INSERT_INTERVAL end = min((i + 1) * INSERT_INTERVAL, len(insert_vectors)) if start < end: + # Insert up to INSERT_INTERVAL=50000 at a time tmp_vectors = insert_vectors[start:end] ids = [i for i in range(start, end)] if not isinstance(tmp_vectors, list): @@ -256,7 +265,9 @@ class AccAccuracyRunner(AccuracyRunner): top_k = case_metric.search["topk"] query_res = self.milvus.query(case_param["vector_query"], filter_query=case_param["filter_query"]) result_ids = self.milvus.get_ids(query_res) + # Calculate the accuracy of the query result acc_value = utils.get_recall_value(true_ids[:nq, :top_k].tolist(), result_ids) tmp_result = {"acc": acc_value} + # Return accuracy results for reporting return tmp_result