Modify search test cases (#22054)

Signed-off-by: nico <cheng.yuan@zilliz.com>
This commit is contained in:
NicoYuan1986 2023-02-09 19:18:32 +08:00 committed by GitHub
parent 915503de0b
commit 01e210854d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3991,7 +3991,7 @@ class TestsearchPagination(TestcaseBase):
ids = hits.ids
assert set(ids).issubset(filter_ids_set)
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert sorted(search_res[0].distances, key=float) == sorted(res_distance, key=float)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
@ -4033,7 +4033,7 @@ class TestsearchPagination(TestcaseBase):
res.done()
res = res.result()
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert sorted(search_res[0].distances, key=float) == sorted(res_distance, key=float)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@pytest.mark.tags(CaseLabel.L2)
@ -4204,7 +4204,7 @@ class TestsearchPagination(TestcaseBase):
res.done()
res = res.result()
res_distance = res[0].distances[offset:]
assert sorted(search_res[0].distances) == sorted(res_distance)
assert sorted(search_res[0].distances, key=float) == sorted(res_distance, key=float)
assert set(search_res[0].ids) == set(res[0].ids[offset:])
@ -4301,10 +4301,8 @@ class TestsearchDiskann(TestcaseBase):
default_index = {"index_type": "DISKANN", "metric_type":"L2", "params": {}}
collection_w.create_index(ct.default_float_vec_field_name, default_index)
collection_w.load()
default_search_params ={"metric_type": "L2", "params": {"search_list": 30}}
default_search_params = {"metric_type": "L2", "params": {"search_list": 30}}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
collection_w.search(vectors[:default_nq], default_search_field,
@ -4338,7 +4336,7 @@ class TestsearchDiskann(TestcaseBase):
default_index = {"index_type": "DISKANN", "metric_type":"L2", "params": {}}
collection_w.create_index(ct.default_float_vec_field_name, default_index)
collection_w.load()
default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}}
default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
collection_w.search(vectors[:default_nq], default_search_field,
@ -4369,7 +4367,7 @@ class TestsearchDiskann(TestcaseBase):
default_index = {"index_type": "DISKANN", "metric_type":"L2", "params": {}}
collection_w.create_index(ct.default_float_vec_field_name, default_index)
collection_w.load()
default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}}
default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
collection_w.search(vectors[:default_nq], default_search_field,
@ -4412,6 +4410,7 @@ class TestsearchDiskann(TestcaseBase):
check_items={"err_code": 1,
"err_msg": "fail to search on all shard leaders"}
)
@pytest.mark.tags(CaseLabel.L2)
def test_search_with_diskann_with_string_pk(self, dim):
"""
@ -4443,7 +4442,6 @@ class TestsearchDiskann(TestcaseBase):
"limit": default_limit}
)
@pytest.mark.tags(CaseLabel.L2)
def test_search_with_delete_data(self, dim, auto_id, _async):
"""
@ -4484,8 +4482,7 @@ class TestsearchDiskann(TestcaseBase):
"limit": default_limit,
"_async": _async}
)
@pytest.mark.tags(CaseLabel.L2)
def test_search_with_diskann_and_more_index(self, dim, auto_id, _async):
"""
@ -4503,7 +4500,7 @@ class TestsearchDiskann(TestcaseBase):
collection_w.create_index(ct.default_float_vec_field_name, default_index, index_name=index_name1)
index_params_one = {}
collection_w.create_index("float", index_params_one, index_name="a")
index_param_two ={}
index_param_two = {}
collection_w.create_index("varchar", index_param_two, index_name="b")
collection_w.load()
@ -4558,7 +4555,6 @@ class TestsearchDiskann(TestcaseBase):
limit = 4
default_search_params ={"metric_type": "L2", "params": {"nprobe": 64}}
vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]