fix: refine test case for search_group_by (#36401) (#36511)

related: #36401

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
Chun Han 2024-09-30 10:13:17 +08:00 committed by GitHub
parent ecb2b242e2
commit a54bffd6cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -2244,7 +2244,6 @@ class TestGroupSearch(TestCaseClassBase):
assert len(grpby_field_values) == len(set(grpby_field_values))
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.skip(reason="issue #36401")
def test_search_pagination_group_size(self):
limit = 10
group_size = 5
@@ -2256,34 +2255,44 @@ class TestGroupSearch(TestCaseClassBase):
search_vectors = cf.gen_vectors(1, dim=self.dims[1], vector_data_type=self.vector_fields[1])
all_pages_ids = []
all_pages_grpby_field_values = []
res_count = limit * group_size
for r in range(page_rounds):
page_res = self.collection_wrap.search(search_vectors, anns_field=default_search_field,
param=search_param, limit=limit, offset=limit * r,
expr=default_search_exp,
group_by_field=grpby_field, group_size=group_size,
group_strict_size=True,
output_fields=[grpby_field],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1, "limit": limit},
check_items={"nq": 1, "limit": res_count},
)[0]
for j in range(limit):
for j in range(res_count):
all_pages_grpby_field_values.append(page_res[0][j].get(grpby_field))
all_pages_ids += page_res[0].ids
hit_rate = round(len(set(all_pages_grpby_field_values)) / len(all_pages_grpby_field_values), 3)
assert hit_rate >= 0.8
hit_rate = round(len(set(all_pages_grpby_field_values)) / len(all_pages_grpby_field_values), 3)
expect_hit_rate = round(1 / group_size, 3) * 0.7
log.info(f"expect_hit_rate :{expect_hit_rate}, hit_rate:{hit_rate}, "
f"unique_group_by_value_count:{len(set(all_pages_grpby_field_values))},"
f"total_group_by_value_count:{len(all_pages_grpby_field_values)}")
assert hit_rate >= expect_hit_rate
total_count = limit * group_size * page_rounds
total_res = self.collection_wrap.search(search_vectors, anns_field=default_search_field,
param=search_param, limit=limit * page_rounds,
expr=default_search_exp,
group_by_field=grpby_field, group_size=group_size,
group_strict_size=True,
output_fields=[grpby_field],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1, "limit": limit * page_rounds}
check_items={"nq": 1, "limit": total_count}
)[0]
hit_num = len(set(total_res[0].ids).intersection(set(all_pages_ids)))
hit_rate = round(hit_num / (limit * page_rounds), 3)
assert hit_rate >= 0.8
log.info(f"search pagination with groupby hit_rate: {hit_rate}")
grpby_field_values = []
for i in range(limit * page_rounds):
for i in range(total_count):
grpby_field_values.append(total_res[0][i].fields.get(grpby_field))
assert len(grpby_field_values) == len(set(grpby_field_values))
assert len(grpby_field_values) == total_count
assert len(set(grpby_field_values)) == limit * page_rounds