From 6ae6b6a92231936cf65ee8f8a08c0dee4d7f32a4 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Wed, 17 Aug 2022 22:14:54 +0800 Subject: [PATCH] [test] Add case to test search and index about handoff (#18680) Signed-off-by: ThreadDao Signed-off-by: ThreadDao --- tests/python_client/testcases/test_utility.py | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py index 3db37f6c46..2fad2f6486 100644 --- a/tests/python_client/testcases/test_utility.py +++ b/tests/python_client/testcases/test_utility.py @@ -10,6 +10,7 @@ from common import common_func as cf from common import common_type as ct from common.common_type import CaseLabel, CheckTasks from common.milvus_sys import MilvusSys +from pymilvus.grpc_gen.common_pb2 import SegmentState prefix = "utility" default_schema = cf.gen_default_collection_schema() @@ -907,14 +908,16 @@ class TestUtilityBase(TestcaseBase): assert res_part_partition == {'loading_progress': '50%', 'num_loaded_partitions': 1, 'not_loaded_partitions': [partition_w.name]} - res_part_partition, _ = self.utility_wrap.loading_progress(collection_w.name, partition_names=[partition_w.name]) + res_part_partition, _ = self.utility_wrap.loading_progress(collection_w.name, + partition_names=[partition_w.name]) assert res_part_partition == {'loading_progress': '0%', 'num_loaded_partitions': 0, 'not_loaded_partitions': [partition_w.name]} collection_w.release() collection_w.load(replica_number=2) res_all_partitions, _ = self.utility_wrap.loading_progress(collection_w.name) - assert res_all_partitions == {'loading_progress': '100%', 'num_loaded_partitions': 2, 'not_loaded_partitions': []} + assert res_all_partitions == {'loading_progress': '100%', 'num_loaded_partitions': 2, + 'not_loaded_partitions': []} @pytest.mark.tags(CaseLabel.L1) def test_wait_loading_collection_empty(self): @@ -1702,7 +1705,7 @@ class TestUtilityAdvanced(TestcaseBase): if len(g.group_nodes) >= 2: group_nodes = list(g.group_nodes) break - src_node_id = group_nodes[0] + src_node_id = group_nodes[0] dst_node_ids = list(set(all_querynodes) - set(group_nodes)) res, _ = self.utility_wrap.get_query_segment_info(c_name) segment_distribution = cf.get_segment_distribution(res) @@ -1710,4 +1713,47 @@ class TestUtilityAdvanced(TestcaseBase): # load balance self.utility_wrap.load_balance(collection_w.name, src_node_id, dst_node_ids, sealed_segment_ids, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "must be in the same replica group"}) \ No newline at end of file + check_items={ct.err_code: 1, ct.err_msg: "must be in the same replica group"}) + + @pytest.mark.tags(CaseLabel.L1) + def test_handoff_query_search(self): + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), shards_num=1) + collection_w.create_index(default_field_name, default_index_params) + collection_w.load() + + # handoff: insert and flush one segment + df = cf.gen_default_dataframe_data() + insert_res, _ = collection_w.insert(df) + term_expr = f'{ct.default_int64_field_name} in {insert_res.primary_keys[:10]}' + res = df.iloc[:10, :1].to_dict('records') + collection_w.query(term_expr, check_task=CheckTasks.check_query_results, + check_items={'exp_res': res}) + search_res_before, _ = collection_w.search(df[ct.default_float_vec_field_name][:1].to_list(), + ct.default_float_vec_field_name, + ct.default_search_params, ct.default_limit) + log.debug(collection_w.num_entities) + + start = time.time() + while True: + time.sleep(0.5) + segment_infos, _ = self.utility_wrap.get_query_segment_info(collection_w.name) + # handoff done + if len(segment_infos) == 1 and segment_infos[0].state == SegmentState.Sealed: + break + if time.time() - start > 20: + raise MilvusException(1, f"Get query segment info after handoff cost more than 20s") + + # query and search from handoff segments + collection_w.query(term_expr, check_task=CheckTasks.check_query_results, + check_items={'exp_res': res}) + search_res_after, _ = collection_w.search(df[ct.default_float_vec_field_name][:1].to_list(), + ct.default_float_vec_field_name, + ct.default_search_params, ct.default_limit) + # the ids between twice search is different because of index building + log.debug(search_res_before[0].ids) + log.debug(search_res_after[0].ids) + # assert search_res_before[0].ids != search_res_after[0].ids + + # assert search result includes the nq-vector before or after handoff + assert search_res_after[0].ids[0] == 0 + assert search_res_before[0].ids[0] == search_res_after[0].ids[0]