test: Update hybrid search tests to milvus client style (#45772)

related issue: #45326

---------

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
This commit is contained in:
yanliang567 2025-11-24 17:55:07 +08:00 committed by GitHub
parent 5efb0cedc8
commit 1da75c0ee2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 918 additions and 1004 deletions

View File

@ -173,7 +173,6 @@ class TestcaseBase(Base):
log.info(f"server version: {server_version}") log.info(f"server version: {server_version}")
return res return res
def get_tokens_by_analyzer(self, text, analyzer_params): def get_tokens_by_analyzer(self, text, analyzer_params):
if cf.param_info.param_uri: if cf.param_info.param_uri:
uri = cf.param_info.param_uri uri = cf.param_info.param_uri

View File

@ -175,19 +175,6 @@ class TestMilvusClientV2Base(Base):
**kwargs).run() **kwargs).run()
return res, check_result return res, check_result
@trace()
def hybrid_search(self, client, collection_name, reqs, ranker, limit=10, output_fields=None, timeout=None,
check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
output_fields], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
output_fields=output_fields, **kwargs).run()
return res, check_result
@trace() @trace()
def search_iterator(self, client, collection_name, data, batch_size, limit=-1, filter=None, output_fields=None, def search_iterator(self, client, collection_name, data, batch_size, limit=-1, filter=None, output_fields=None,
@ -210,16 +197,16 @@ class TestMilvusClientV2Base(Base):
return res, check_result return res, check_result
@trace() @trace()
def hybrid_search(self, client, collection_name, reqs, rerank, limit=10, def hybrid_search(self, client, collection_name, reqs, ranker, limit=10,
output_fields=None, timeout=None, partition_names=None, output_fields=None, timeout=None, partition_names=None,
check_task=None, check_items=None, **kwargs): check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout timeout = TIMEOUT if timeout is None else timeout
# kwargs.update({"timeout": timeout}) # kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name func_name = sys._getframe().f_code.co_name
res, check = api_request([client.hybrid_search, collection_name, reqs, rerank, limit, res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
output_fields, timeout, partition_names], **kwargs) output_fields, timeout, partition_names], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check, check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name, reqs=reqs, rerank=rerank, limit=limit, collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
output_fields=output_fields, timeout=timeout, partition_names=partition_names, **kwargs).run() output_fields=output_fields, timeout=timeout, partition_names=partition_names, **kwargs).run()
return res, check_result return res, check_result
@ -332,15 +319,6 @@ class TestMilvusClientV2Base(Base):
self.tear_down_collection_names.remove(collection_name) self.tear_down_collection_names.remove(collection_name)
return res, check_result return res, check_result
@trace()
def list_partitions(self, client, collection_name, check_task=None, check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.list_partitions, collection_name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name,
**kwargs).run()
return res, check_result
@trace() @trace()
def list_indexes(self, client, collection_name, field_name=None, check_task=None, check_items=None, **kwargs): def list_indexes(self, client, collection_name, field_name=None, check_task=None, check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name func_name = sys._getframe().f_code.co_name
@ -359,16 +337,6 @@ class TestMilvusClientV2Base(Base):
**kwargs).run() **kwargs).run()
return res, check_result return res, check_result
@trace()
def prepare_index_params(self, client, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.prepare_index_params], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
**kwargs).run()
return res, check_result
@trace() @trace()
def load_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs): def load_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
@ -803,19 +771,6 @@ class TestMilvusClientV2Base(Base):
object_name=object_name, db_name=db_name, **kwargs).run() object_name=object_name, db_name=db_name, **kwargs).run()
return res, check_result return res, check_result
@trace()
def grant_privilege_v2(self, client, role_name, privilege, collection_name, db_name=None,
timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.grant_privilege_v2, role_name, privilege, collection_name,
db_name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
role_name=role_name, privilege=privilege,
collection_name=collection_name, db_name=db_name, **kwargs).run()
return res, check_result
@trace() @trace()
def revoke_privilege(self, client, role_name, object_type, privilege, object_name, db_name="", def revoke_privilege(self, client, role_name, object_type, privilege, object_name, db_name="",
timeout=None, check_task=None, check_items=None, **kwargs): timeout=None, check_task=None, check_items=None, **kwargs):
@ -1082,7 +1037,7 @@ class TestMilvusClientV2Base(Base):
kwargs.update({"timeout": timeout}) kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name func_name = sys._getframe().f_code.co_name
res, check = api_request([client.update_resource_groups, name], **kwargs) res, check = api_request([client.drop_resource_group, name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check, check_result = ResponseChecker(res, func_name, check_task, check_items, check,
name=name, **kwargs).run() name=name, **kwargs).run()
return res, check_result return res, check_result

View File

@ -454,8 +454,8 @@ class ResponseChecker:
else: else:
ids = list(hits.ids) ids = list(hits.ids)
distances = list(hits.distances) distances = list(hits.distances)
if (len(hits) != check_items["limit"]) \ if check_items.get("limit", None) is not None \
or (len(ids) != check_items["limit"]): and ((len(hits) != check_items["limit"]) or (len(ids) != check_items["limit"])):
log.error("search_results_check: limit(topK) searched (%d) " log.error("search_results_check: limit(topK) searched (%d) "
"is not equal with expected (%d)" "is not equal with expected (%d)"
% (len(hits), check_items["limit"])) % (len(hits), check_items["limit"]))

View File

@ -431,20 +431,23 @@ def output_field_value_check(search_res, original, pk_name):
:return: True or False :return: True or False
""" """
pk_name = ct.default_primary_field_name if pk_name is None else pk_name pk_name = ct.default_primary_field_name if pk_name is None else pk_name
nq = len(search_res)
limit = len(search_res[0]) limit = len(search_res[0])
for i in range(limit): check_nqs = min(2, nq) # the output field values are wrong only at nq>=2 #45338
entity = search_res[0][i].fields for n in range(check_nqs):
_id = search_res[0][i].id for i in range(limit):
for field in entity.keys(): entity = search_res[n][i].fields
if isinstance(entity[field], list): _id = search_res[n][i].id
for order in range(0, len(entity[field]), 4): for field in entity.keys():
assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon if isinstance(entity[field], list):
elif isinstance(entity[field], dict) and field != ct.default_json_field_name: for order in range(0, len(entity[field]), 4):
# sparse checking, sparse vector must be the last, this is a bit hacky, assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
# but sparse only supports list data type insertion for now elif isinstance(entity[field], dict) and field != ct.default_json_field_name:
assert entity[field].keys() == original[-1][_id].keys() # sparse checking, sparse vector must be the last, this is a bit hacky,
else: # but sparse only supports list data type insertion for now
num = original[original[pk_name] == _id].index.to_list()[0] assert entity[field].keys() == original[-1][_id].keys()
assert original[field][num] == entity[field] else:
num = original[original[pk_name] == _id].index.to_list()[0]
assert original[field][num] == entity[field], f"the output field values are wrong at nq={n}"
return True return True

View File

@ -1,19 +1,16 @@
import logging import logging
import numpy as np
from common.constants import *
from utils.util_pymilvus import * from utils.util_pymilvus import *
from common.common_type import CaseLabel, CheckTasks from common.common_type import CaseLabel, CheckTasks
from common import common_type as ct from common import common_type as ct
from common import common_func as cf from common import common_func as cf
from utils.util_log import test_log as log from utils.util_log import test_log as log
from base.client_v2_base import TestMilvusClientV2Base from base.client_v2_base import TestMilvusClientV2Base
from base.client_base import TestcaseBase
import random import random
import pytest import pytest
import pandas as pd import pandas as pd
from faker import Faker from faker import Faker
import inspect
Faker.seed(19530) Faker.seed(19530)
fake_en = Faker("en_US") fake_en = Faker("en_US")

View File

@ -218,7 +218,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
@pytest.mark.tags(CaseLabel.L0) @pytest.mark.tags(CaseLabel.L0)
async def test_async_client_with_schema(self, schema): async def test_async_client_with_schema(self):
# init async client # init async client
pk_field_name = "id" pk_field_name = "id"
self.init_async_milvus_client() self.init_async_milvus_client()
@ -390,7 +390,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
await self.async_milvus_client_wrap.create_database(db_name) await self.async_milvus_client_wrap.create_database(db_name)
await self.async_milvus_client_wrap.close() await self.async_milvus_client_wrap.close()
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}" uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
self.async_milvus_client_wrap.init_async_client(uri, db_name=db_name) self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token, db_name=db_name)
# create collection # create collection
c_name = cf.gen_unique_str(prefix) c_name = cf.gen_unique_str(prefix)
@ -450,7 +450,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
async def test_async_client_close(self): async def test_async_client_close(self):
# init async client # init async client
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}" uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
self.async_milvus_client_wrap.init_async_client(uri) self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token)
# create collection # create collection
c_name = cf.gen_unique_str(prefix) c_name = cf.gen_unique_str(prefix)