enhance: Update groupby tests (#31297)
related issue: #29883
skip running for now

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
commit 8563c4a5ee (parent c408a32db6)
@@ -9898,12 +9898,12 @@ class TestSearchIterator(TestcaseBase):
                                                          "err_msg": "Not support multiple vector iterator at present"})
 
 
+@pytest.mark.skip("not ready for running")
 class TestSearchGroupBy(TestcaseBase):
     """ Test case of search group by """
 
     @pytest.mark.tags(CaseLabel.L0)
     @pytest.mark.parametrize("index_type, metric", zip(["FLAT", "IVF_FLAT", "HNSW"], ct.float_metrics))
-    @pytest.mark.skip(reason="issue #29883")
     def test_search_group_by_default(self, index_type, metric):
         """
         target: test search group by
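Note: the class-level @pytest.mark.skip added above disables every test in TestSearchGroupBy, which is why the per-test skip markers for issue #29883 can be dropped. A minimal sketch of that pytest behavior, with hypothetical test names:

    import pytest

    @pytest.mark.skip("not ready for running")
    class TestGroupByExample:
        # Both tests are still collected, but pytest reports them as skipped,
        # so per-test skip markers inside the class become redundant.
        def test_a(self):
            assert True

        def test_b(self):
            assert True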
@@ -9927,7 +9927,7 @@ class TestSearchGroupBy(TestcaseBase):
 
         collection_w.flush()
         collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params)
-        time.sleep(30)
+        time.sleep(10)
         collection_w.load()
 
         search_params = {"metric_type": metric, "params": {"ef": 128}}
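The shortened time.sleep only gives the index build a fixed head start before load. If a hard-coded wait is undesirable, pymilvus also offers a blocking helper; a hedged sketch, assuming a collection named groupby_demo_collection and the same build-then-load flow as the test:

    from pymilvus import utility

    # Block until the index build on the collection finishes,
    # instead of sleeping for a fixed number of seconds.
    utility.wait_for_index_building_complete("groupby_demo_collection", timeout=60)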
@@ -9986,7 +9986,6 @@ class TestSearchGroupBy(TestcaseBase):
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("metric", ["JACCARD", "HAMMING"])
-    @pytest.mark.skip(reason="issue #29883")
     def test_search_binary_vec_group_by(self, metric):
         """
         target: test search group by
@@ -10008,7 +10007,7 @@ class TestSearchGroupBy(TestcaseBase):
 
         collection_w.flush()
         collection_w.create_index(ct.default_binary_vec_field_name, index_params=_index)
-        time.sleep(30)
+        time.sleep(5)
         collection_w.load()
 
         search_params = {"metric_type": metric, "params": {"ef": 128}}
@@ -10052,7 +10051,6 @@ class TestSearchGroupBy(TestcaseBase):
             # verify no dup values of the group_by_field in results
             assert len(grpby_values) == len(set(grpby_values))
 
-    @pytest.mark.skip(reason="issue #29883")
     @pytest.mark.tags(CaseLabel.L0)
     @pytest.mark.parametrize("grpby_field", [ct.default_string_field_name, ct.default_int8_field_name])
     def test_search_group_by_with_field_indexed(self, grpby_field):
@@ -10092,6 +10090,7 @@ class TestSearchGroupBy(TestcaseBase):
                                     output_fields=[grpby_field])[0]
         for i in range(nq):
             grpby_values = []
             dismatch = 0
+            results_num = 2 if grpby_field == ct.default_bool_field_name else limit
             for l in range(results_num):
                 top1 = res1[i][l]
@@ -10107,12 +10106,17 @@ class TestSearchGroupBy(TestcaseBase):
                                               output_fields=[grpby_field])[0]
                 top1_expr_pk = res_tmp[0][0].id
-                log.info(f"nq={i}, limit={l}")
-                assert top1_grpby_pk == top1_expr_pk
+                # assert top1_grpby_pk == top1_expr_pk
+                if top1_grpby_pk != top1_expr_pk:
+                    dismatch += 1
+                    log.info(f"{grpby_field} on {metric} dismatch_item, top1_grpby_dis: {top1.distance}, top1_expr_dis: {res_tmp[0][0].distance}")
+            log.info(f"{grpby_field} on {metric} top1_dismatch_num: {dismatch}, results_num: {results_num}, dismatch_rate: {dismatch / results_num}")
+            baseline = 1 if grpby_field == ct.default_bool_field_name else 0.2    # skip baseline check for boolean
+            assert dismatch / results_num <= baseline
             # verify no dup values of the group_by_field in results
             assert len(grpby_values) == len(set(grpby_values))
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.skip(reason="issue #29967")
     @pytest.mark.parametrize("grpby_unsupported_field", [ct.default_float_field_name, ct.default_json_field_name,
                                                          ct.default_double_field_name, ct.default_float_vec_field_name])
     def test_search_group_by_unsupported_filed(self, grpby_unsupported_field):
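With this change the group-by top-1 no longer has to match the per-group reference search exactly; mismatches are counted and only the mismatch rate is bounded by a baseline (1 for bool, i.e. effectively unchecked, 0.2 otherwise). A self-contained sketch of that tolerance check with hypothetical primary keys:

    # Count how many per-group top-1 hits disagree with the reference results,
    # then assert the disagreement rate stays under the baseline.
    grpby_top1_pks = [1, 2, 3, 4, 5]    # hypothetical group-by top-1 primary keys
    reference_pks = [1, 2, 9, 4, 5]     # hypothetical per-group reference top-1 keys

    dismatch = sum(1 for a, b in zip(grpby_top1_pks, reference_pks) if a != b)
    baseline = 0.2                      # tolerate up to 20% disagreement
    assert dismatch / len(grpby_top1_pks) <= baseline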
@@ -10137,7 +10141,7 @@ class TestSearchGroupBy(TestcaseBase):
 
         # search with groupby
         err_code = 999
-        err_msg = "unsupported"
+        err_msg = f"unsupported data type {grpby_unsupported_field} for group by operator"
         collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
                             param=search_params, limit=limit,
                             group_by_field=grpby_unsupported_field,
@@ -10148,7 +10152,7 @@ class TestSearchGroupBy(TestcaseBase):
     @pytest.mark.parametrize("index, params",
                              zip(ct.all_index_types[:7],
                                  ct.default_index_params[:7]))
-    @pytest.mark.skip(reason="issue #29968")
+    # @pytest.mark.skip(reason="issue #29968")
     def test_search_group_by_unsupported_index(self, index, params):
         """
         target: test search group by with the unsupported vector index
@@ -10183,9 +10187,39 @@ class TestSearchGroupBy(TestcaseBase):
                             check_task=CheckTasks.err_res,
                             check_items={"err_code": err_code, "err_msg": err_msg})
 
+    @pytest.mark.tags(CaseLabel.L2)
+    def test_search_group_by_multi_fields(self):
+        """
+        target: test search group by with the multi fields
+        method: 1. create a collection with data
+                2. create index
+                3. search with group by the multi fields
+        verify: the error code and msg
+        """
+        metric = "IP"
+        collection_w = self.init_collection_general(prefix, insert_data=False, is_index=False,
+                                                    is_all_data_type=True, with_json=True, )[0]
+        _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
+        collection_w.load()
+
+        search_params = {"metric_type": metric, "params": {"ef": 128}}
+        nq = 1
+        limit = 1
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+
+        # search with groupby
+        err_code = 1700
+        err_msg = f"groupBy field not found in schema"
+        collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                            param=search_params, limit=limit,
+                            group_by_field=[ct.default_string_field_name, ct.default_int32_field_name],
+                            check_task=CheckTasks.err_res,
+                            check_items={"err_code": err_code, "err_msg": err_msg})
+
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("grpby_nonexist_field", ["nonexit_field", 100])
-    def test_search_group_by_nonexit_filed(self, grpby_nonexist_field):
+    def test_search_group_by_nonexit_fields(self, grpby_nonexist_field):
         """
         target: test search group by with the nonexisting field
         method: 1. create a collection with data
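The new negative test passes a list to group_by_field, which is rejected because group by accepts a single scalar field. A hedged sketch of how the same rejection would surface through the plain pymilvus API; the collection and field names are hypothetical, and the error code and message are the ones asserted in the test above:

    from pymilvus import Collection, MilvusException

    collection = Collection("groupby_demo_collection")  # hypothetical, already indexed and loaded

    try:
        collection.search(
            data=[[0.0] * 128],
            anns_field="float_vector",
            param={"metric_type": "IP", "params": {"ef": 128}},
            limit=1,
            group_by_field=["varchar_field", "int32_field"],  # a list is not a valid field name
        )
    except MilvusException as e:
        # Expected per the test above: err_code 1700, "groupBy field not found in schema"
        print(e.code, e.message)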
@@ -10194,7 +10228,7 @@ class TestSearchGroupBy(TestcaseBase):
         verify: the error code and msg
         """
         metric = "IP"
-        collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
+        collection_w = self.init_collection_general(prefix, insert_data=False, is_index=False,
                                                     is_all_data_type=True, with_json=True, )[0]
         _index = {"index_type": "HNSW", "metric_type": metric, "params": {"M": 16, "efConstruction": 128}}
         collection_w.create_index(ct.default_float_vec_field_name, index_params=_index)
@@ -10215,7 +10249,7 @@ class TestSearchGroupBy(TestcaseBase):
                             check_items={"err_code": err_code, "err_msg": err_msg})
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.skip(reason="issue #30033, #30828")
+    @pytest.mark.xfail(reason="issue #30828")
     def test_search_pagination_group_by(self):
         """
         target: test search group by
@@ -10267,22 +10301,23 @@ class TestSearchGroupBy(TestcaseBase):
                                          check_items={"nq": 1, "limit": limit * page_rounds}
                                          )[0]
         hit_num = len(set(total_res[0].ids).intersection(set(all_pages_ids)))
-        assert hit_num / (limit * page_rounds) > 0.90
+        hit_rate = round(hit_num / (limit * page_rounds), 3)
+        assert hit_rate > 0.90
+        log.info(f"search pagination with groupby hit_rate: {hit_rate}")
         grpby_field_values = []
         for i in range(limit * page_rounds):
             grpby_field_values.append(total_res[0][i].fields.get(grpby_field))
         assert len(grpby_field_values) == len(set(grpby_field_values))
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.skip(reason="not support iterator + group by")
-    def test_search_iterator_group_by(self):
+    def test_search_iterator_not_support_group_by(self):
         """
         target: test search group by
         method: 1. create a collection with data
                 2. create index HNSW
                 3. search iterator with group by
-                4. search with filtering every value of group_by_field
-        verify: verify successfully and iterators are correct
+        verify: error code and msg
         """
         metric = "COSINE"
         collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False,
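The pagination check now names and logs the hit rate between the ids collected page by page and the ids of one full-sized search, rather than asserting on the raw ratio inline. A minimal standalone sketch with hypothetical id lists:

    # Ids gathered page by page vs. ids from one full search of size limit * page_rounds.
    all_pages_ids = list(range(20))
    full_search_ids = list(range(19)) + [999]   # one id differs

    hit_num = len(set(full_search_ids).intersection(set(all_pages_ids)))
    hit_rate = round(hit_num / len(all_pages_ids), 3)
    print(hit_rate)             # 0.95
    assert hit_rate > 0.90      # small divergence between pages and the full search is tolerated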
@@ -10303,29 +10338,16 @@ class TestSearchGroupBy(TestcaseBase):
         search_params = {"metric_type": metric}
         batch_size = 10
 
-        # res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
-        #                           search_params, group_by_field=grpby_field, limit=10)[0]
-
-        ite_res = collection_w.search_iterator(search_vectors, ct.default_float_vec_field_name,
-                                               search_params, batch_size, group_by_field=grpby_field,
-                                               output_fields=[grpby_field]
-                                               )[0]
-        # iterators = 0
-        # while True and iterators < value_num/batch_size:
-        #     res = ite_res.next()  # turn to the next page
-        #     if len(res) == 0:
-        #         ite_res.close()  # close the iterator
-        #         break
-        #     iterators += 1
-        #     grp_values = []
-        #     for j in range(len(res)):
-        #         grp_values.append(res.get__item(j).get(grpby_field))
-        #     log.info(f"iterators: {iterators}, grp_values: {grp_values}")
-        # assert iterators == value_num/batch_size
+        err_code = 1100
+        err_msg = "Not allowed to do groupBy when doing iteration"
+        collection_w.search_iterator(search_vectors, ct.default_float_vec_field_name,
+                                     search_params, batch_size, group_by_field=grpby_field,
+                                     output_fields=[grpby_field],
+                                     check_task=CheckTasks.err_res,
+                                     check_items={"err_code": err_code, "err_msg": err_msg})
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.skip(reason="not support range search + group by")
-    def test_range_search_group_by(self):
+    def test_range_search_not_support_group_by(self):
         """
         target: test search group by
         method: 1. create a collection with data
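The iterator test now asserts the server-side rejection (err_code 1100) instead of paging through results. A sketch of how that looks against the raw pymilvus iterator API, assuming group_by_field is forwarded as a keyword the same way the test wrapper does; collection and field names are hypothetical:

    from pymilvus import Collection, MilvusException

    collection = Collection("groupby_demo_collection")  # hypothetical, indexed and loaded

    try:
        collection.search_iterator(
            data=[[0.0] * 128],
            anns_field="float_vector",
            param={"metric_type": "COSINE"},
            batch_size=10,
            group_by_field="int32_field",  # group by is rejected for iterator searches
        )
    except MilvusException as e:
        # Expected per the test above: err_code 1100,
        # "Not allowed to do groupBy when doing iteration"
        print(e.code, e.message)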
@@ -10354,16 +10376,14 @@ class TestSearchGroupBy(TestcaseBase):
         grpby_field = ct.default_int32_field_name
         range_search_params = {"metric_type": "COSINE", "params": {"radius": 0.1,
                                                                    "range_filter": 0.5}}
-        res = collection_w.search(search_vectors, ct.default_float_vec_field_name,
-                                  range_search_params, limit,
-                                  default_search_exp, group_by_field=grpby_field,
-                                  output_fields=[grpby_field],
-                                  check_task=CheckTasks.check_search_results,
-                                  check_items={"nq": nq, "limit": limit})[0]
-        # grpby_field_values = []
-        # for i in range(limit):
-        #     grpby_field_values.append(res[0][i].fields.get(grpby_field))
-        # assert len(grpby_field_values) == len(set(grpby_field_values))
+        err_code = 1100
+        err_msg = f"Not allowed to do range-search"
+        collection_w.search(search_vectors, ct.default_float_vec_field_name,
+                            range_search_params, limit,
+                            default_search_exp, group_by_field=grpby_field,
+                            output_fields=[grpby_field],
+                            check_task=CheckTasks.err_res,
+                            check_items={"err_code": err_code, "err_msg": err_msg})
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.skip(reason="not completed")
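Range search is selected purely by the radius/range_filter entries in the search params, and combining it with group_by_field is now expected to fail with err_code 1100. A small sketch of the params shape used above; the bounds semantics are the documented behavior for similarity metrics, the values are simply the test's:

    # For similarity metrics such as COSINE, results must satisfy:
    # radius < similarity <= range_filter.
    range_search_params = {
        "metric_type": "COSINE",
        "params": {
            "radius": 0.1,        # lower bound (exclusive)
            "range_filter": 0.5,  # upper bound (inclusive)
        },
    }
    # Passing group_by_field together with these params is expected to fail
    # with an error starting "Not allowed to do range-search" (per the test above).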
@@ -10371,10 +10391,11 @@ class TestSearchGroupBy(TestcaseBase):
         """
         target: test search group by
         method: 1. create a collection with multiple vector fields
-                2. create index hnsw and hnsw
+                2. create index hnsw and load
                 3. hybrid_search with group by
         verify: the error code and msg
         """
         # 1. initialize collection with data
         pass
 
     @pytest.mark.tags(CaseLabel.L1)