mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
test: Update hybrid search tests to milvus client style (#45772)
related issue: #45326 --------- Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
This commit is contained in:
parent
5efb0cedc8
commit
1da75c0ee2
@ -173,7 +173,6 @@ class TestcaseBase(Base):
|
|||||||
log.info(f"server version: {server_version}")
|
log.info(f"server version: {server_version}")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_by_analyzer(self, text, analyzer_params):
|
def get_tokens_by_analyzer(self, text, analyzer_params):
|
||||||
if cf.param_info.param_uri:
|
if cf.param_info.param_uri:
|
||||||
uri = cf.param_info.param_uri
|
uri = cf.param_info.param_uri
|
||||||
|
|||||||
@ -175,19 +175,6 @@ class TestMilvusClientV2Base(Base):
|
|||||||
**kwargs).run()
|
**kwargs).run()
|
||||||
return res, check_result
|
return res, check_result
|
||||||
|
|
||||||
@trace()
|
|
||||||
def hybrid_search(self, client, collection_name, reqs, ranker, limit=10, output_fields=None, timeout=None,
|
|
||||||
check_task=None, check_items=None, **kwargs):
|
|
||||||
timeout = TIMEOUT if timeout is None else timeout
|
|
||||||
kwargs.update({"timeout": timeout})
|
|
||||||
|
|
||||||
func_name = sys._getframe().f_code.co_name
|
|
||||||
res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
|
|
||||||
output_fields], **kwargs)
|
|
||||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
|
||||||
collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
|
|
||||||
output_fields=output_fields, **kwargs).run()
|
|
||||||
return res, check_result
|
|
||||||
|
|
||||||
@trace()
|
@trace()
|
||||||
def search_iterator(self, client, collection_name, data, batch_size, limit=-1, filter=None, output_fields=None,
|
def search_iterator(self, client, collection_name, data, batch_size, limit=-1, filter=None, output_fields=None,
|
||||||
@ -210,16 +197,16 @@ class TestMilvusClientV2Base(Base):
|
|||||||
return res, check_result
|
return res, check_result
|
||||||
|
|
||||||
@trace()
|
@trace()
|
||||||
def hybrid_search(self, client, collection_name, reqs, rerank, limit=10,
|
def hybrid_search(self, client, collection_name, reqs, ranker, limit=10,
|
||||||
output_fields=None, timeout=None, partition_names=None,
|
output_fields=None, timeout=None, partition_names=None,
|
||||||
check_task=None, check_items=None, **kwargs):
|
check_task=None, check_items=None, **kwargs):
|
||||||
timeout = TIMEOUT if timeout is None else timeout
|
timeout = TIMEOUT if timeout is None else timeout
|
||||||
# kwargs.update({"timeout": timeout})
|
# kwargs.update({"timeout": timeout})
|
||||||
func_name = sys._getframe().f_code.co_name
|
func_name = sys._getframe().f_code.co_name
|
||||||
res, check = api_request([client.hybrid_search, collection_name, reqs, rerank, limit,
|
res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
|
||||||
output_fields, timeout, partition_names], **kwargs)
|
output_fields, timeout, partition_names], **kwargs)
|
||||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||||
collection_name=collection_name, reqs=reqs, rerank=rerank, limit=limit,
|
collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
|
||||||
output_fields=output_fields, timeout=timeout, partition_names=partition_names, **kwargs).run()
|
output_fields=output_fields, timeout=timeout, partition_names=partition_names, **kwargs).run()
|
||||||
return res, check_result
|
return res, check_result
|
||||||
|
|
||||||
@ -332,15 +319,6 @@ class TestMilvusClientV2Base(Base):
|
|||||||
self.tear_down_collection_names.remove(collection_name)
|
self.tear_down_collection_names.remove(collection_name)
|
||||||
return res, check_result
|
return res, check_result
|
||||||
|
|
||||||
@trace()
|
|
||||||
def list_partitions(self, client, collection_name, check_task=None, check_items=None, **kwargs):
|
|
||||||
func_name = sys._getframe().f_code.co_name
|
|
||||||
res, check = api_request([client.list_partitions, collection_name], **kwargs)
|
|
||||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
|
||||||
collection_name=collection_name,
|
|
||||||
**kwargs).run()
|
|
||||||
return res, check_result
|
|
||||||
|
|
||||||
@trace()
|
@trace()
|
||||||
def list_indexes(self, client, collection_name, field_name=None, check_task=None, check_items=None, **kwargs):
|
def list_indexes(self, client, collection_name, field_name=None, check_task=None, check_items=None, **kwargs):
|
||||||
func_name = sys._getframe().f_code.co_name
|
func_name = sys._getframe().f_code.co_name
|
||||||
@ -359,16 +337,6 @@ class TestMilvusClientV2Base(Base):
|
|||||||
**kwargs).run()
|
**kwargs).run()
|
||||||
return res, check_result
|
return res, check_result
|
||||||
|
|
||||||
@trace()
|
|
||||||
def prepare_index_params(self, client, timeout=None, check_task=None, check_items=None, **kwargs):
|
|
||||||
timeout = TIMEOUT if timeout is None else timeout
|
|
||||||
kwargs.update({"timeout": timeout})
|
|
||||||
|
|
||||||
func_name = sys._getframe().f_code.co_name
|
|
||||||
res, check = api_request([client.prepare_index_params], **kwargs)
|
|
||||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
|
||||||
**kwargs).run()
|
|
||||||
return res, check_result
|
|
||||||
|
|
||||||
@trace()
|
@trace()
|
||||||
def load_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
|
def load_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
|
||||||
@ -803,19 +771,6 @@ class TestMilvusClientV2Base(Base):
|
|||||||
object_name=object_name, db_name=db_name, **kwargs).run()
|
object_name=object_name, db_name=db_name, **kwargs).run()
|
||||||
return res, check_result
|
return res, check_result
|
||||||
|
|
||||||
@trace()
|
|
||||||
def grant_privilege_v2(self, client, role_name, privilege, collection_name, db_name=None,
|
|
||||||
timeout=None, check_task=None, check_items=None, **kwargs):
|
|
||||||
timeout = TIMEOUT if timeout is None else timeout
|
|
||||||
kwargs.update({"timeout": timeout})
|
|
||||||
func_name = sys._getframe().f_code.co_name
|
|
||||||
res, check = api_request([client.grant_privilege_v2, role_name, privilege, collection_name,
|
|
||||||
db_name], **kwargs)
|
|
||||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
|
||||||
role_name=role_name, privilege=privilege,
|
|
||||||
collection_name=collection_name, db_name=db_name, **kwargs).run()
|
|
||||||
return res, check_result
|
|
||||||
|
|
||||||
@trace()
|
@trace()
|
||||||
def revoke_privilege(self, client, role_name, object_type, privilege, object_name, db_name="",
|
def revoke_privilege(self, client, role_name, object_type, privilege, object_name, db_name="",
|
||||||
timeout=None, check_task=None, check_items=None, **kwargs):
|
timeout=None, check_task=None, check_items=None, **kwargs):
|
||||||
@ -1082,7 +1037,7 @@ class TestMilvusClientV2Base(Base):
|
|||||||
kwargs.update({"timeout": timeout})
|
kwargs.update({"timeout": timeout})
|
||||||
|
|
||||||
func_name = sys._getframe().f_code.co_name
|
func_name = sys._getframe().f_code.co_name
|
||||||
res, check = api_request([client.update_resource_groups, name], **kwargs)
|
res, check = api_request([client.drop_resource_group, name], **kwargs)
|
||||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||||
name=name, **kwargs).run()
|
name=name, **kwargs).run()
|
||||||
return res, check_result
|
return res, check_result
|
||||||
|
|||||||
@ -454,8 +454,8 @@ class ResponseChecker:
|
|||||||
else:
|
else:
|
||||||
ids = list(hits.ids)
|
ids = list(hits.ids)
|
||||||
distances = list(hits.distances)
|
distances = list(hits.distances)
|
||||||
if (len(hits) != check_items["limit"]) \
|
if check_items.get("limit", None) is not None \
|
||||||
or (len(ids) != check_items["limit"]):
|
and ((len(hits) != check_items["limit"]) or (len(ids) != check_items["limit"])):
|
||||||
log.error("search_results_check: limit(topK) searched (%d) "
|
log.error("search_results_check: limit(topK) searched (%d) "
|
||||||
"is not equal with expected (%d)"
|
"is not equal with expected (%d)"
|
||||||
% (len(hits), check_items["limit"]))
|
% (len(hits), check_items["limit"]))
|
||||||
|
|||||||
@ -431,10 +431,13 @@ def output_field_value_check(search_res, original, pk_name):
|
|||||||
:return: True or False
|
:return: True or False
|
||||||
"""
|
"""
|
||||||
pk_name = ct.default_primary_field_name if pk_name is None else pk_name
|
pk_name = ct.default_primary_field_name if pk_name is None else pk_name
|
||||||
|
nq = len(search_res)
|
||||||
limit = len(search_res[0])
|
limit = len(search_res[0])
|
||||||
|
check_nqs = min(2, nq) # the output field values are wrong only at nq>=2 #45338
|
||||||
|
for n in range(check_nqs):
|
||||||
for i in range(limit):
|
for i in range(limit):
|
||||||
entity = search_res[0][i].fields
|
entity = search_res[n][i].fields
|
||||||
_id = search_res[0][i].id
|
_id = search_res[n][i].id
|
||||||
for field in entity.keys():
|
for field in entity.keys():
|
||||||
if isinstance(entity[field], list):
|
if isinstance(entity[field], list):
|
||||||
for order in range(0, len(entity[field]), 4):
|
for order in range(0, len(entity[field]), 4):
|
||||||
@ -445,6 +448,6 @@ def output_field_value_check(search_res, original, pk_name):
|
|||||||
assert entity[field].keys() == original[-1][_id].keys()
|
assert entity[field].keys() == original[-1][_id].keys()
|
||||||
else:
|
else:
|
||||||
num = original[original[pk_name] == _id].index.to_list()[0]
|
num = original[original[pk_name] == _id].index.to_list()[0]
|
||||||
assert original[field][num] == entity[field]
|
assert original[field][num] == entity[field], f"the output field values are wrong at nq={n}"
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -1,19 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from common.constants import *
|
|
||||||
from utils.util_pymilvus import *
|
from utils.util_pymilvus import *
|
||||||
from common.common_type import CaseLabel, CheckTasks
|
from common.common_type import CaseLabel, CheckTasks
|
||||||
from common import common_type as ct
|
from common import common_type as ct
|
||||||
from common import common_func as cf
|
from common import common_func as cf
|
||||||
from utils.util_log import test_log as log
|
from utils.util_log import test_log as log
|
||||||
from base.client_v2_base import TestMilvusClientV2Base
|
from base.client_v2_base import TestMilvusClientV2Base
|
||||||
from base.client_base import TestcaseBase
|
|
||||||
import random
|
import random
|
||||||
import pytest
|
import pytest
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
import inspect
|
|
||||||
|
|
||||||
Faker.seed(19530)
|
Faker.seed(19530)
|
||||||
fake_en = Faker("en_US")
|
fake_en = Faker("en_US")
|
||||||
|
|||||||
@ -218,7 +218,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
|
|||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L0)
|
@pytest.mark.tags(CaseLabel.L0)
|
||||||
async def test_async_client_with_schema(self, schema):
|
async def test_async_client_with_schema(self):
|
||||||
# init async client
|
# init async client
|
||||||
pk_field_name = "id"
|
pk_field_name = "id"
|
||||||
self.init_async_milvus_client()
|
self.init_async_milvus_client()
|
||||||
@ -390,7 +390,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
|
|||||||
await self.async_milvus_client_wrap.create_database(db_name)
|
await self.async_milvus_client_wrap.create_database(db_name)
|
||||||
await self.async_milvus_client_wrap.close()
|
await self.async_milvus_client_wrap.close()
|
||||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||||
self.async_milvus_client_wrap.init_async_client(uri, db_name=db_name)
|
self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token, db_name=db_name)
|
||||||
|
|
||||||
# create collection
|
# create collection
|
||||||
c_name = cf.gen_unique_str(prefix)
|
c_name = cf.gen_unique_str(prefix)
|
||||||
@ -450,7 +450,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
|
|||||||
async def test_async_client_close(self):
|
async def test_async_client_close(self):
|
||||||
# init async client
|
# init async client
|
||||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||||
self.async_milvus_client_wrap.init_async_client(uri)
|
self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token)
|
||||||
|
|
||||||
# create collection
|
# create collection
|
||||||
c_name = cf.gen_unique_str(prefix)
|
c_name = cf.gen_unique_str(prefix)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user