test: Update hybrid search tests to milvus client style (#45772)

related issue: #45326

---------

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
This commit is contained in:
yanliang567 2025-11-24 17:55:07 +08:00 committed by GitHub
parent 5efb0cedc8
commit 1da75c0ee2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 918 additions and 1004 deletions

View File

@ -173,7 +173,6 @@ class TestcaseBase(Base):
log.info(f"server version: {server_version}")
return res
def get_tokens_by_analyzer(self, text, analyzer_params):
if cf.param_info.param_uri:
uri = cf.param_info.param_uri

View File

@ -175,19 +175,6 @@ class TestMilvusClientV2Base(Base):
**kwargs).run()
return res, check_result
@trace()
def hybrid_search(self, client, collection_name, reqs, ranker, limit=10, output_fields=None, timeout=None,
check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
output_fields], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
output_fields=output_fields, **kwargs).run()
return res, check_result
@trace()
def search_iterator(self, client, collection_name, data, batch_size, limit=-1, filter=None, output_fields=None,
@ -210,16 +197,16 @@ class TestMilvusClientV2Base(Base):
return res, check_result
@trace()
def hybrid_search(self, client, collection_name, reqs, rerank, limit=10,
def hybrid_search(self, client, collection_name, reqs, ranker, limit=10,
output_fields=None, timeout=None, partition_names=None,
check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
# kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.hybrid_search, collection_name, reqs, rerank, limit,
res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
output_fields, timeout, partition_names], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name, reqs=reqs, rerank=rerank, limit=limit,
collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
output_fields=output_fields, timeout=timeout, partition_names=partition_names, **kwargs).run()
return res, check_result
@ -332,15 +319,6 @@ class TestMilvusClientV2Base(Base):
self.tear_down_collection_names.remove(collection_name)
return res, check_result
@trace()
def list_partitions(self, client, collection_name, check_task=None, check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.list_partitions, collection_name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name,
**kwargs).run()
return res, check_result
@trace()
def list_indexes(self, client, collection_name, field_name=None, check_task=None, check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name
@ -359,16 +337,6 @@ class TestMilvusClientV2Base(Base):
**kwargs).run()
return res, check_result
@trace()
def prepare_index_params(self, client, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.prepare_index_params], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
**kwargs).run()
return res, check_result
@trace()
def load_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
@ -803,19 +771,6 @@ class TestMilvusClientV2Base(Base):
object_name=object_name, db_name=db_name, **kwargs).run()
return res, check_result
@trace()
def grant_privilege_v2(self, client, role_name, privilege, collection_name, db_name=None,
timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.grant_privilege_v2, role_name, privilege, collection_name,
db_name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
role_name=role_name, privilege=privilege,
collection_name=collection_name, db_name=db_name, **kwargs).run()
return res, check_result
@trace()
def revoke_privilege(self, client, role_name, object_type, privilege, object_name, db_name="",
timeout=None, check_task=None, check_items=None, **kwargs):
@ -1082,7 +1037,7 @@ class TestMilvusClientV2Base(Base):
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.update_resource_groups, name], **kwargs)
res, check = api_request([client.drop_resource_group, name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
name=name, **kwargs).run()
return res, check_result

View File

@ -454,8 +454,8 @@ class ResponseChecker:
else:
ids = list(hits.ids)
distances = list(hits.distances)
if (len(hits) != check_items["limit"]) \
or (len(ids) != check_items["limit"]):
if check_items.get("limit", None) is not None \
and ((len(hits) != check_items["limit"]) or (len(ids) != check_items["limit"])):
log.error("search_results_check: limit(topK) searched (%d) "
"is not equal with expected (%d)"
% (len(hits), check_items["limit"]))

View File

@ -431,20 +431,23 @@ def output_field_value_check(search_res, original, pk_name):
:return: True or False
"""
pk_name = ct.default_primary_field_name if pk_name is None else pk_name
nq = len(search_res)
limit = len(search_res[0])
for i in range(limit):
entity = search_res[0][i].fields
_id = search_res[0][i].id
for field in entity.keys():
if isinstance(entity[field], list):
for order in range(0, len(entity[field]), 4):
assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
elif isinstance(entity[field], dict) and field != ct.default_json_field_name:
# sparse checking, sparse vector must be the last, this is a bit hacky,
# but sparse only supports list data type insertion for now
assert entity[field].keys() == original[-1][_id].keys()
else:
num = original[original[pk_name] == _id].index.to_list()[0]
assert original[field][num] == entity[field]
check_nqs = min(2, nq) # the output field values are wrong only at nq>=2 #45338
for n in range(check_nqs):
for i in range(limit):
entity = search_res[n][i].fields
_id = search_res[n][i].id
for field in entity.keys():
if isinstance(entity[field], list):
for order in range(0, len(entity[field]), 4):
assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
elif isinstance(entity[field], dict) and field != ct.default_json_field_name:
# sparse checking, sparse vector must be the last, this is a bit hacky,
# but sparse only supports list data type insertion for now
assert entity[field].keys() == original[-1][_id].keys()
else:
num = original[original[pk_name] == _id].index.to_list()[0]
assert original[field][num] == entity[field], f"the output field values are wrong at nq={n}"
return True

View File

@ -1,19 +1,16 @@
import logging
import numpy as np
from common.constants import *
from utils.util_pymilvus import *
from common.common_type import CaseLabel, CheckTasks
from common import common_type as ct
from common import common_func as cf
from utils.util_log import test_log as log
from base.client_v2_base import TestMilvusClientV2Base
from base.client_base import TestcaseBase
import random
import pytest
import pandas as pd
from faker import Faker
import inspect
Faker.seed(19530)
fake_en = Faker("en_US")

View File

@ -218,7 +218,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
await asyncio.gather(*tasks)
@pytest.mark.tags(CaseLabel.L0)
async def test_async_client_with_schema(self, schema):
async def test_async_client_with_schema(self):
# init async client
pk_field_name = "id"
self.init_async_milvus_client()
@ -390,7 +390,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
await self.async_milvus_client_wrap.create_database(db_name)
await self.async_milvus_client_wrap.close()
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
self.async_milvus_client_wrap.init_async_client(uri, db_name=db_name)
self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token, db_name=db_name)
# create collection
c_name = cf.gen_unique_str(prefix)
@ -450,7 +450,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
async def test_async_client_close(self):
# init async client
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
self.async_milvus_client_wrap.init_async_client(uri)
self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token)
# create collection
c_name = cf.gen_unique_str(prefix)