enhance: Update groupby tests (#31297)

related issue: #29883
The TestSearchGroupBy test class is skipped for now (marked "not ready for running").

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
This commit is contained in:
yanliang567 2024-03-15 15:21:03 +08:00 committed by GitHub
parent c408a32db6
commit 8563c4a5ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9898,12 +9898,12 @@ class TestSearchIterator(TestcaseBase):
"err_msg": "Not support multiple vector iterator at present"})
@pytest.mark.skip("not ready for running")
class TestSearchGroupBy(TestcaseBase):
""" Test case of search group by """
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.parametrize("index_type, metric", zip(["FLAT", "IVF_FLAT", "HNSW"], ct.float_metrics))
@pytest.mark.skip(reason="issue #29883")
def test_search_group_by_default(self, index_type, metric):
"""
target: test search group by
@ -9927,7 +9927,7 @@ class TestSearchGroupBy(TestcaseBase):
collection_w.flush()
collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params)
time.sleep(30)
time.sleep(10)
collection_w.load()
search_params = {"metric_type": metric, "params": {"ef": 128}}
@ -9986,7 +9986,6 @@ class TestSearchGroupBy(TestcaseBase):
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("metric", ["JACCARD", "HAMMING"])
@pytest.mark.skip(reason="issue #29883")
def test_search_binary_vec_group_by(self, metric):
"""
target: test search group by
@ -10008,7 +10007,7 @@ class TestSearchGroupBy(TestcaseBase):
collection_w.flush()
collection_w.create_index(ct.default_binary_vec_field_name, index_params=_index)
time.sleep(30)
time.sleep(5)
collection_w.load()
search_params = {"metric_type": metric, "params": {"ef": 128}}
@ -10052,7 +10051,6 @@ class TestSearchGroupBy(TestcaseBase):
# verify no dup values of the group_by_field in results
assert len(grpby_values) == len(set(grpby_values))
@pytest.mark.skip(reason="issue #29883")
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.parametrize("grpby_field", [ct.default_string_field_name, ct.default_int8_field_name])
def test_search_group_by_with_field_indexed(self, grpby_field):
@ -10092,6 +10090,7 @@ class TestSearchGroupBy(TestcaseBase):
output_fields=[grpby_field])[0]
for i in range(nq):
grpby_values = []
dismatch = 0
results_num = 2 if grpby_field == ct.default_bool_field_name else limit
for l in range(results_num):
top1 = res1[i][l]
@ -10107,12 +10106,17 @@ class TestSearchGroupBy(TestcaseBase):
output_fields=[grpby_field])[0]
top1_expr_pk = res_tmp[0][0].id
log.info(f"nq={i}, limit={l}")
assert top1_grpby_pk == top1_expr_pk
# assert top1_grpby_pk == top1_expr_pk
if top1_grpby_pk != top1_expr_pk:
dismatch += 1
log.info(f"{grpby_field} on {metric} dismatch_item, top1_grpby_dis: {top1.distance}, top1_expr_dis: {res_tmp[0][0].distance}")
log.info(f"{grpby_field} on {metric} top1_dismatch_num: {dismatch}, results_num: {results_num}, dismatch_rate: {dismatch / results_num}")
baseline = 1 if grpby_field == ct.default_bool_field_name else 0.2 # skip baseline check for boolean
assert dismatch / results_num <= baseline
# verify no dup values of the group_by_field in results
assert len(grpby_values) == len(set(grpby_values))
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.skip(reason="issue #29967")
@pytest.mark.parametrize("grpby_unsupported_field", [ct.default_float_field_name, ct.default_json_field_name,
ct.default_double_field_name, ct.default_float_vec_field_name])
def test_search_group_by_unsupported_filed(self, grpby_unsupported_field):
@ -10137,7 +10141,7 @@ class TestSearchGroupBy(TestcaseBase):
# search with groupby
err_code = 999
err_msg = "unsupported"
err_msg = f"unsupported data type {grpby_unsupported_field} for group by operator"
collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
param=search_params, limit=limit,
group_by_field=grpby_unsupported_field,
@ -10148,7 +10152,7 @@ class TestSearchGroupBy(TestcaseBase):
@pytest.mark.parametrize("index, params",
zip(ct.all_index_types[:7],
ct.default_index_params[:7]))
@pytest.mark.skip(reason="issue #29968")
# @pytest.mark.skip(reason="issue #29968")
def test_search_group_by_unsupported_index(self, index, params):
"""
target: test search group by with the unsupported vector index
@ -10183,9 +10187,39 @@ class TestSearchGroupBy(TestcaseBase):
check_task=CheckTasks.err_res,
check_items={"err_code": err_code, "err_msg": err_msg})
@pytest.mark.tags(CaseLabel.L2)
def test_search_group_by_multi_fields(self):
    """
    target: test search group by with multiple fields
    method: 1. create a collection without inserting data
            2. create an HNSW index on the float vector field and load
            3. search with group_by_field given as a list of two field names
    verify: the search is rejected with the expected error code and msg
    """
    metric = "IP"
    # insert_data=False: the rejection is expected regardless of stored rows
    collection_w = self.init_collection_general(prefix, insert_data=False, is_index=False,
                                                is_all_data_type=True, with_json=True)[0]
    _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
    collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
    collection_w.load()

    search_params = {"metric_type": metric, "params": {"ef": 128}}
    nq = 1
    limit = 1
    search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)

    # search with groupby on a list of fields instead of a single field name
    err_code = 1700
    err_msg = "groupBy field not found in schema"
    collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
                        param=search_params, limit=limit,
                        group_by_field=[ct.default_string_field_name, ct.default_int32_field_name],
                        check_task=CheckTasks.err_res,
                        check_items={"err_code": err_code, "err_msg": err_msg})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("grpby_nonexist_field", ["nonexit_field", 100])
def test_search_group_by_nonexit_filed(self, grpby_nonexist_field):
def test_search_group_by_nonexit_fields(self, grpby_nonexist_field):
"""
target: test search group by with the nonexisting field
method: 1. create a collection with data
@ -10194,7 +10228,7 @@ class TestSearchGroupBy(TestcaseBase):
verify: the error code and msg
"""
metric = "IP"
collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
collection_w = self.init_collection_general(prefix, insert_data=False, is_index=False,
is_all_data_type=True, with_json=True, )[0]
_index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
@ -10215,7 +10249,7 @@ class TestSearchGroupBy(TestcaseBase):
check_items={"err_code": err_code, "err_msg": err_msg})
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.skip(reason="issue #30033, #30828")
@pytest.mark.xfail(reason="issue #30828")
def test_search_pagination_group_by(self):
"""
target: test search group by
@ -10267,22 +10301,23 @@ class TestSearchGroupBy(TestcaseBase):
check_items={"nq": 1, "limit": limit * page_rounds}
)[0]
hit_num = len(set(total_res[0].ids).intersection(set(all_pages_ids)))
assert hit_num / (limit * page_rounds) > 0.90
hit_rate = round(hit_num / (limit * page_rounds), 3)
assert hit_rate > 0.90
log.info(f"search pagination with groupby hit_rate: {hit_rate}")
grpby_field_values = []
for i in range(limit * page_rounds):
grpby_field_values.append(total_res[0][i].fields.get(grpby_field))
assert len(grpby_field_values) == len(set(grpby_field_values))
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.skip(reason="not support iterator + group by")
def test_search_iterator_group_by(self):
def test_search_iterator_not_support_group_by(self):
"""
target: test search group by
method: 1. create a collection with data
2. create index HNSW
3. search iterator with group by
4. search with filtering every value of group_by_field
verify: verify successfully and iterators are correct
verify: error code and msg
"""
metric = "COSINE"
collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
@ -10303,29 +10338,16 @@ class TestSearchGroupBy(TestcaseBase):
search_params = {"metric_type": metric}
batch_size = 10
# res = collection_w.search(search_vectors,ct.default_float_vec_field_name,
# search_params, group_by_field=grpby_field, limit=10)[0]
ite_res = collection_w.search_iterator(search_vectors, ct.default_float_vec_field_name,
search_params, batch_size, group_by_field=grpby_field,
output_fields=[grpby_field]
)[0]
# iterators = 0
# while True and iterators < value_num/batch_size:
# res = ite_res.next() # turn to the next page
# if len(res) == 0:
# ite_res.close() # close the iterator
# break
# iterators += 1
# grp_values = []
# for j in range(len(res)):
# grp_values.append(res.get__item(j).get(grpby_field))
# log.info(f"iterators: {iterators}, grp_values: {grp_values}")
# assert iterators == value_num/batch_size
err_code = 1100
err_msg = "Not allowed to do groupBy when doing iteration"
collection_w.search_iterator(search_vectors, ct.default_float_vec_field_name,
search_params, batch_size, group_by_field=grpby_field,
output_fields=[grpby_field],
check_task=CheckTasks.err_res,
check_items={"err_code": err_code, "err_msg": err_msg})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip(reason="not support range search + group by")
def test_range_search_group_by(self):
def test_range_search_not_support_group_by(self):
"""
target: test search group by
method: 1. create a collection with data
@ -10354,16 +10376,14 @@ class TestSearchGroupBy(TestcaseBase):
grpby_field = ct.default_int32_field_name
range_search_params = {"metric_type": "COSINE", "params": {"radius": 0.1,
"range_filter": 0.5}}
res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
range_search_params, limit,
default_search_exp, group_by_field=grpby_field,
output_fields=[grpby_field],
check_task=CheckTasks.check_search_results,
check_items={"nq": nq, "limit": limit})[0]
# grpby_field_values = []
# for i in range(limit):
# grpby_field_values.append(res[0][i].fields.get(grpby_field))
# assert len(grpby_field_values) == len(set(grpby_field_values))
err_code = 1100
err_msg = f"Not allowed to do range-search"
collection_w.search(search_vectors, ct.default_float_vec_field_name,
range_search_params, limit,
default_search_exp, group_by_field=grpby_field,
output_fields=[grpby_field],
check_task=CheckTasks.err_res,
check_items={"err_code": err_code, "err_msg": err_msg})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip(reason="not completed")
@ -10371,10 +10391,11 @@ class TestSearchGroupBy(TestcaseBase):
"""
target: test search group by
method: 1. create a collection with multiple vector fields
2. create index hnsw and hnsw
2. create index hnsw and load
3. hybrid_search with group by
verify: the error code and msg
"""
# 1. initialize collection with data
pass
@pytest.mark.tags(CaseLabel.L1)