From 01e210854df8a88f7ed81ce0e1b5a3dec35d911d Mon Sep 17 00:00:00 2001 From: NicoYuan1986 <109071306+NicoYuan1986@users.noreply.github.com> Date: Thu, 9 Feb 2023 19:18:32 +0800 Subject: [PATCH] Modify search test cases (#22054) Signed-off-by: nico --- tests/python_client/testcases/test_search.py | 22 ++++++++------------ 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 0e6d6634ad..a319e091a9 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -3991,7 +3991,7 @@ class TestsearchPagination(TestcaseBase): ids = hits.ids assert set(ids).issubset(filter_ids_set) res_distance = res[0].distances[offset:] - assert sorted(search_res[0].distances) == sorted(res_distance) + assert sorted(search_res[0].distances, key=float) == sorted(res_distance, key=float) assert set(search_res[0].ids) == set(res[0].ids[offset:]) @pytest.mark.tags(CaseLabel.L2) @@ -4033,7 +4033,7 @@ class TestsearchPagination(TestcaseBase): res.done() res = res.result() res_distance = res[0].distances[offset:] - assert sorted(search_res[0].distances) == sorted(res_distance) + assert sorted(search_res[0].distances, key=float) == sorted(res_distance, key=float) assert set(search_res[0].ids) == set(res[0].ids[offset:]) @pytest.mark.tags(CaseLabel.L2) @@ -4204,7 +4204,7 @@ class TestsearchPagination(TestcaseBase): res.done() res = res.result() res_distance = res[0].distances[offset:] - assert sorted(search_res[0].distances) == sorted(res_distance) + assert sorted(search_res[0].distances, key=float) == sorted(res_distance, key=float) assert set(search_res[0].ids) == set(res[0].ids[offset:]) @@ -4301,10 +4301,8 @@ class TestsearchDiskann(TestcaseBase): default_index = {"index_type": "DISKANN", "metric_type":"L2", "params": {}} collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() - - - default_search_params ={"metric_type": "L2", "params": {"search_list": 30}} + default_search_params = {"metric_type": "L2", "params": {"search_list": 30}} vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, @@ -4338,7 +4336,7 @@ class TestsearchDiskann(TestcaseBase): default_index = {"index_type": "DISKANN", "metric_type":"L2", "params": {}} collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}} + default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, @@ -4369,7 +4367,7 @@ class TestsearchDiskann(TestcaseBase): default_index = {"index_type": "DISKANN", "metric_type":"L2", "params": {}} collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}} + default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, @@ -4412,6 +4410,7 @@ class TestsearchDiskann(TestcaseBase): check_items={"err_code": 1, "err_msg": "fail to search on all shard leaders"} ) + @pytest.mark.tags(CaseLabel.L2) def test_search_with_diskann_with_string_pk(self, dim): """ @@ -4443,7 +4442,6 @@ class TestsearchDiskann(TestcaseBase): "limit": default_limit} ) - @pytest.mark.tags(CaseLabel.L2) def test_search_with_delete_data(self, dim, auto_id, _async): """ @@ -4484,8 +4482,7 @@ class TestsearchDiskann(TestcaseBase): "limit": default_limit, "_async": _async} ) - - + @pytest.mark.tags(CaseLabel.L2) def test_search_with_diskann_and_more_index(self, dim, auto_id, _async): """ @@ -4503,7 +4500,7 @@ class TestsearchDiskann(TestcaseBase): collection_w.create_index(ct.default_float_vec_field_name, default_index, index_name=index_name1) index_params_one = {} collection_w.create_index("float", index_params_one, index_name="a") - index_param_two ={} + index_param_two = {} collection_w.create_index("varchar", index_param_two, index_name="b") collection_w.load() @@ -4558,7 +4555,6 @@ class TestsearchDiskann(TestcaseBase): limit = 4 - default_search_params ={"metric_type": "L2", "params": {"nprobe": 64}} vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]