mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
test: Update hybrid search tests with milvus client (#46003)
related issue: https://github.com/milvus-io/milvus/issues/45326 Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
This commit is contained in:
parent
e70c01362d
commit
13a52016ac
@ -436,14 +436,13 @@ class ResponseChecker:
|
|||||||
original_entities = pandas.DataFrame(original_entities)
|
original_entities = pandas.DataFrame(original_entities)
|
||||||
pc.output_field_value_check(search_res, original_entities, pk_name=pk_name)
|
pc.output_field_value_check(search_res, original_entities, pk_name=pk_name)
|
||||||
if len(search_res) != check_items["nq"]:
|
if len(search_res) != check_items["nq"]:
|
||||||
log.error("search_results_check: Numbers of query searched (%d) "
|
log.error("search_results_check: Numbers of query searched(nq) (%d) "
|
||||||
"is not equal with expected (%d)"
|
"is not equal with expected (%d)"
|
||||||
% (len(search_res), check_items["nq"]))
|
% (len(search_res), check_items["nq"]))
|
||||||
assert len(search_res) == check_items["nq"]
|
assert len(search_res) == check_items["nq"]
|
||||||
else:
|
else:
|
||||||
log.info("search_results_check: Numbers of query searched is correct")
|
log.info("search_results_check: Numbers of query searched is correct")
|
||||||
# log.debug(search_res)
|
# log.debug(search_res)
|
||||||
nq_i = 0
|
|
||||||
for hits in search_res:
|
for hits in search_res:
|
||||||
ids = []
|
ids = []
|
||||||
distances = []
|
distances = []
|
||||||
@ -461,25 +460,24 @@ class ResponseChecker:
|
|||||||
% (len(hits), check_items["limit"]))
|
% (len(hits), check_items["limit"]))
|
||||||
assert len(hits) == check_items["limit"]
|
assert len(hits) == check_items["limit"]
|
||||||
assert len(ids) == check_items["limit"]
|
assert len(ids) == check_items["limit"]
|
||||||
else:
|
|
||||||
if check_items.get("ids", None) is not None:
|
if check_items.get("ids", None) is not None:
|
||||||
ids_match = pc.list_contain_check(ids, list(check_items["ids"]))
|
ids_match = pc.list_contain_check(ids, list(check_items["ids"]))
|
||||||
if not ids_match:
|
if not ids_match:
|
||||||
log.error("search_results_check: ids searched not match")
|
log.error("search_results_check: ids searched not match")
|
||||||
assert ids_match
|
assert ids_match
|
||||||
elif check_items.get("metric", None) is not None:
|
if check_items.get("metric", None) is not None:
|
||||||
# verify the distances are already sorted
|
# verify the distances are already sorted
|
||||||
|
num_to_check = min(100, len(distances)) # check 100 items if more than that
|
||||||
if check_items.get("metric").upper() in ["IP", "COSINE", "BM25"]:
|
if check_items.get("metric").upper() in ["IP", "COSINE", "BM25"]:
|
||||||
assert pc.compare_lists_with_epsilon_ignore_dict_order(distances, sorted(distances, reverse=True))
|
assert distances[:num_to_check] == sorted(distances[:num_to_check], reverse=True)
|
||||||
else:
|
else:
|
||||||
assert pc.compare_lists_with_epsilon_ignore_dict_order(distances, sorted(distances, reverse=False))
|
assert distances[:num_to_check] == sorted(distances[:num_to_check], reverse=False)
|
||||||
if check_items.get("vector_nq") is None or check_items.get("original_vectors") is None:
|
if check_items.get("vector_nq") is None or check_items.get("original_vectors") is None:
|
||||||
log.debug("skip distance check for knowhere does not return the precise distances")
|
log.debug("skip distance check for knowhere does not return the precise distances")
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
pass # just check nq and topk, not specific ids need check
|
pass # just check nq and topk, not specific ids need check
|
||||||
nq_i += 1
|
|
||||||
|
|
||||||
log.info("search_results_check: limit (topK) and "
|
log.info("search_results_check: limit (topK) and "
|
||||||
"ids searched for %d queries are correct" % len(search_res))
|
"ids searched for %d queries are correct" % len(search_res))
|
||||||
@ -586,12 +584,13 @@ class ResponseChecker:
|
|||||||
for single_query_result in query_res:
|
for single_query_result in query_res:
|
||||||
single_query_result[vector_field] = np.frombuffer(single_query_result[vector_field][0], dtype=np.int8).tolist()
|
single_query_result[vector_field] = np.frombuffer(single_query_result[vector_field][0], dtype=np.int8).tolist()
|
||||||
if isinstance(query_res, list):
|
if isinstance(query_res, list):
|
||||||
result = pc.compare_lists_with_epsilon_ignore_dict_order(a=query_res, b=exp_res)
|
debug_mode = check_items.get("debug_mode", False)
|
||||||
if result is False:
|
if debug_mode is True:
|
||||||
# Only for debug, compare the result with deepdiff
|
assert pc.compare_lists_with_epsilon_ignore_dict_order_deepdiff(a=query_res, b=exp_res)
|
||||||
pc.compare_lists_with_epsilon_ignore_dict_order_deepdiff(a=query_res, b=exp_res)
|
else:
|
||||||
assert result
|
assert pc.compare_lists_with_epsilon_ignore_dict_order(a=query_res, b=exp_res), \
|
||||||
return result
|
f"there exists different values between query_results and expected_results, " \
|
||||||
|
f"use debug_mode in check_items to print the difference entity by entity(but it is slow)"
|
||||||
else:
|
else:
|
||||||
log.error(f"Query result {query_res} is not list")
|
log.error(f"Query result {query_res} is not list")
|
||||||
return False
|
return False
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -344,8 +344,8 @@ class TestMilvusClientSearchBasicV2(TestMilvusClientV2Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
# @pytest.mark.parametrize("limit, nq", zip([1, 1000, ct.max_limit], [ct.max_nq, 10, 1]))
|
@pytest.mark.parametrize("limit, nq", zip([1, 1000, ct.max_limit], [ct.max_nq, 10, 1]))
|
||||||
@pytest.mark.parametrize("limit, nq", zip([ct.max_limit], [1]))
|
# @pytest.mark.parametrize("limit, nq", zip([ct.max_limit], [1]))
|
||||||
def test_search_with_different_nq_limits(self, limit, nq):
|
def test_search_with_different_nq_limits(self, limit, nq):
|
||||||
"""
|
"""
|
||||||
target: test search with different nq and limit values
|
target: test search with different nq and limit values
|
||||||
|
|||||||
@ -28,8 +28,8 @@ pytest-parallel
|
|||||||
pytest-random-order
|
pytest-random-order
|
||||||
|
|
||||||
# pymilvus
|
# pymilvus
|
||||||
pymilvus==2.7.0rc72
|
pymilvus==2.7.0rc75
|
||||||
pymilvus[bulk_writer]==2.7.0rc72
|
pymilvus[bulk_writer]==2.7.0rc75
|
||||||
# for protobuf
|
# for protobuf
|
||||||
protobuf>=5.29.5
|
protobuf>=5.29.5
|
||||||
|
|
||||||
|
|||||||
@ -254,10 +254,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
|
|||||||
assert r[0]['insert_count'] == step
|
assert r[0]['insert_count'] == step
|
||||||
|
|
||||||
# flush
|
# flush
|
||||||
# TODO: call async flush() as https://github.com/milvus-io/pymilvus/issues/3060 fixed
|
await self.async_milvus_client_wrap.flush(c_name)
|
||||||
# await self.async_milvus_client_wrap.flush(c_name)
|
|
||||||
milvus_client = self._client()
|
|
||||||
self.flush(milvus_client, c_name)
|
|
||||||
stats, _ = await self.async_milvus_client_wrap.get_collection_stats(c_name)
|
stats, _ = await self.async_milvus_client_wrap.get_collection_stats(c_name)
|
||||||
assert stats["row_count"] == async_default_nb
|
assert stats["row_count"] == async_default_nb
|
||||||
|
|
||||||
|
|||||||
@ -60,8 +60,7 @@ class TestAsyncMilvusClientIndexInvalid(TestMilvusClientV2Base):
|
|||||||
index_params = async_client.prepare_index_params()[0]
|
index_params = async_client.prepare_index_params()[0]
|
||||||
index_params.add_index(field_name="vector")
|
index_params.add_index(field_name="vector")
|
||||||
# 3. create index
|
# 3. create index
|
||||||
error = {ct.err_code: 1100, ct.err_msg: f"Invalid collection name: {name}. the first character of a collection "
|
error = {ct.err_code: 1100, ct.err_msg: f"collection not found[database=default][collection={name}]"}
|
||||||
f"name must be an underscore or letter: invalid parameter"}
|
|
||||||
await async_client.create_index(name, index_params,
|
await async_client.create_index(name, index_params,
|
||||||
check_task=CheckTasks.err_res,
|
check_task=CheckTasks.err_res,
|
||||||
check_items=error)
|
check_items=error)
|
||||||
@ -88,8 +87,7 @@ class TestAsyncMilvusClientIndexInvalid(TestMilvusClientV2Base):
|
|||||||
index_params = async_client.prepare_index_params()[0]
|
index_params = async_client.prepare_index_params()[0]
|
||||||
index_params.add_index(field_name="vector")
|
index_params.add_index(field_name="vector")
|
||||||
# 3. create index
|
# 3. create index
|
||||||
error = {ct.err_code: 1100, ct.err_msg: f"Invalid collection name: {name}. the length of a collection name "
|
error = {ct.err_code: 1100, ct.err_msg: f"collection not found[database=default][collection={name}]"}
|
||||||
f"must be less than 255 characters: invalid parameter"}
|
|
||||||
await async_client.create_index(name, index_params,
|
await async_client.create_index(name, index_params,
|
||||||
check_task=CheckTasks.err_res,
|
check_task=CheckTasks.err_res,
|
||||||
check_items=error)
|
check_items=error)
|
||||||
@ -117,7 +115,7 @@ class TestAsyncMilvusClientIndexInvalid(TestMilvusClientV2Base):
|
|||||||
index_params.add_index(field_name="vector")
|
index_params.add_index(field_name="vector")
|
||||||
# 3. create index
|
# 3. create index
|
||||||
error = {ct.err_code: 100,
|
error = {ct.err_code: 100,
|
||||||
ct.err_msg: f"can't find collection[database=default][collection={not_existed_collection_name}]"}
|
ct.err_msg: f"collection not found[database=default][collection={not_existed_collection_name}]"}
|
||||||
await async_client.create_index(not_existed_collection_name, index_params,
|
await async_client.create_index(not_existed_collection_name, index_params,
|
||||||
check_task=CheckTasks.err_res,
|
check_task=CheckTasks.err_res,
|
||||||
check_items=error)
|
check_items=error)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user