diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 156051e612..103a8841b4 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -10131,8 +10131,8 @@ class TestSearchIterator(TestcaseBase): class TestSearchGroupBy(TestcaseBase): """ Test case of search group by """ - @pytest.mark.tags(CaseLabel.L3) - @pytest.mark.parametrize("index_type, metric", zip(["FLAT", "IVF_FLAT", "HNSW"], ct.float_metrics)) + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("index_type, metric", zip(["DISKANN", "IVF_FLAT", "HNSW"], ct.float_metrics)) @pytest.mark.parametrize("vector_data_type", ["FLOAT16_VECTOR", "FLOAT_VECTOR", "BFLOAT16_VECTOR"]) def test_search_group_by_default(self, index_type, metric, vector_data_type): """ @@ -10147,9 +10147,8 @@ class TestSearchGroupBy(TestcaseBase): collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False, vector_data_type=vector_data_type, is_all_data_type=True, with_json=False)[0] - _index_params = {"index_type": index_type, "metric_type": metric, "params": {"M": 16, "efConstruction": 128}} - if index_type in ["IVF_FLAT", "FLAT"]: - _index_params = {"index_type": index_type, "metric_type": metric, "params": {"nlist": 128}} + _index_params = {"index_type": index_type, "metric_type": metric, + "params": cf.get_index_params_params(index_type)} collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params) # insert with the same values for scalar fields for _ in range(50): @@ -10160,7 +10159,7 @@ class TestSearchGroupBy(TestcaseBase): collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params) collection_w.load() - search_params = {"metric_type": metric, "params": {"ef": 128}} + search_params = {"params": cf.get_search_params_params(index_type)} nq = 2 limit = 15 search_vectors = cf.gen_vectors(nq, dim=ct.default_dim) @@ -10206,9 +10205,11 @@ class TestSearchGroupBy(TestcaseBase): top1_expr_pk = res_tmp[0][0].id if top1_grpby_pk != top1_expr_pk: dismatch += 1 - log.info(f"{grpby_field} on {metric} dismatch_item, top1_grpby_dis: {top1.distance}, top1_expr_dis: {res_tmp[0][0].distance}") - log.info(f"{grpby_field} on {metric} top1_dismatch_num: {dismatch}, results_num: {results_num}, dismatch_rate: {dismatch / results_num}") - baseline = 1 if grpby_field == ct.default_bool_field_name else 0.2 # skip baseline check for boolean + log.info( + f"{grpby_field} on {metric} dismatch_item, top1_grpby_dis: {top1.distance}, top1_expr_dis: {res_tmp[0][0].distance}") + log.info( + f"{grpby_field} on {metric} top1_dismatch_num: {dismatch}, results_num: {results_num}, dismatch_rate: {dismatch / results_num}") + baseline = 1 if grpby_field == ct.default_bool_field_name else 0.4 # skip baseline check for boolean assert dismatch / results_num <= baseline # verify no dup values of the group_by_field in results assert len(grpby_values) == len(set(grpby_values))