mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 09:08:43 +08:00
test: Update hybrid search tests to milvus client style (#45772)
related issue: #45326 --------- Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
This commit is contained in:
parent
5efb0cedc8
commit
1da75c0ee2
@ -173,7 +173,6 @@ class TestcaseBase(Base):
|
||||
log.info(f"server version: {server_version}")
|
||||
return res
|
||||
|
||||
|
||||
def get_tokens_by_analyzer(self, text, analyzer_params):
|
||||
if cf.param_info.param_uri:
|
||||
uri = cf.param_info.param_uri
|
||||
|
||||
@ -175,19 +175,6 @@ class TestMilvusClientV2Base(Base):
|
||||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def hybrid_search(self, client, collection_name, reqs, ranker, limit=10, output_fields=None, timeout=None,
|
||||
check_task=None, check_items=None, **kwargs):
|
||||
timeout = TIMEOUT if timeout is None else timeout
|
||||
kwargs.update({"timeout": timeout})
|
||||
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
|
||||
output_fields], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
|
||||
output_fields=output_fields, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def search_iterator(self, client, collection_name, data, batch_size, limit=-1, filter=None, output_fields=None,
|
||||
@ -210,16 +197,16 @@ class TestMilvusClientV2Base(Base):
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def hybrid_search(self, client, collection_name, reqs, rerank, limit=10,
|
||||
def hybrid_search(self, client, collection_name, reqs, ranker, limit=10,
|
||||
output_fields=None, timeout=None, partition_names=None,
|
||||
check_task=None, check_items=None, **kwargs):
|
||||
timeout = TIMEOUT if timeout is None else timeout
|
||||
# kwargs.update({"timeout": timeout})
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.hybrid_search, collection_name, reqs, rerank, limit,
|
||||
res, check = api_request([client.hybrid_search, collection_name, reqs, ranker, limit,
|
||||
output_fields, timeout, partition_names], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
collection_name=collection_name, reqs=reqs, rerank=rerank, limit=limit,
|
||||
collection_name=collection_name, reqs=reqs, ranker=ranker, limit=limit,
|
||||
output_fields=output_fields, timeout=timeout, partition_names=partition_names, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@ -332,15 +319,6 @@ class TestMilvusClientV2Base(Base):
|
||||
self.tear_down_collection_names.remove(collection_name)
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def list_partitions(self, client, collection_name, check_task=None, check_items=None, **kwargs):
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.list_partitions, collection_name], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
collection_name=collection_name,
|
||||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def list_indexes(self, client, collection_name, field_name=None, check_task=None, check_items=None, **kwargs):
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
@ -359,16 +337,6 @@ class TestMilvusClientV2Base(Base):
|
||||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def prepare_index_params(self, client, timeout=None, check_task=None, check_items=None, **kwargs):
|
||||
timeout = TIMEOUT if timeout is None else timeout
|
||||
kwargs.update({"timeout": timeout})
|
||||
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.prepare_index_params], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def load_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
|
||||
@ -803,19 +771,6 @@ class TestMilvusClientV2Base(Base):
|
||||
object_name=object_name, db_name=db_name, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def grant_privilege_v2(self, client, role_name, privilege, collection_name, db_name=None,
|
||||
timeout=None, check_task=None, check_items=None, **kwargs):
|
||||
timeout = TIMEOUT if timeout is None else timeout
|
||||
kwargs.update({"timeout": timeout})
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.grant_privilege_v2, role_name, privilege, collection_name,
|
||||
db_name], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
role_name=role_name, privilege=privilege,
|
||||
collection_name=collection_name, db_name=db_name, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def revoke_privilege(self, client, role_name, object_type, privilege, object_name, db_name="",
|
||||
timeout=None, check_task=None, check_items=None, **kwargs):
|
||||
@ -1082,7 +1037,7 @@ class TestMilvusClientV2Base(Base):
|
||||
kwargs.update({"timeout": timeout})
|
||||
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.update_resource_groups, name], **kwargs)
|
||||
res, check = api_request([client.drop_resource_group, name], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
name=name, **kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@ -454,8 +454,8 @@ class ResponseChecker:
|
||||
else:
|
||||
ids = list(hits.ids)
|
||||
distances = list(hits.distances)
|
||||
if (len(hits) != check_items["limit"]) \
|
||||
or (len(ids) != check_items["limit"]):
|
||||
if check_items.get("limit", None) is not None \
|
||||
and ((len(hits) != check_items["limit"]) or (len(ids) != check_items["limit"])):
|
||||
log.error("search_results_check: limit(topK) searched (%d) "
|
||||
"is not equal with expected (%d)"
|
||||
% (len(hits), check_items["limit"]))
|
||||
|
||||
@ -431,20 +431,23 @@ def output_field_value_check(search_res, original, pk_name):
|
||||
:return: True or False
|
||||
"""
|
||||
pk_name = ct.default_primary_field_name if pk_name is None else pk_name
|
||||
nq = len(search_res)
|
||||
limit = len(search_res[0])
|
||||
for i in range(limit):
|
||||
entity = search_res[0][i].fields
|
||||
_id = search_res[0][i].id
|
||||
for field in entity.keys():
|
||||
if isinstance(entity[field], list):
|
||||
for order in range(0, len(entity[field]), 4):
|
||||
assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
|
||||
elif isinstance(entity[field], dict) and field != ct.default_json_field_name:
|
||||
# sparse checking, sparse vector must be the last, this is a bit hacky,
|
||||
# but sparse only supports list data type insertion for now
|
||||
assert entity[field].keys() == original[-1][_id].keys()
|
||||
else:
|
||||
num = original[original[pk_name] == _id].index.to_list()[0]
|
||||
assert original[field][num] == entity[field]
|
||||
check_nqs = min(2, nq) # the output field values are wrong only at nq>=2 #45338
|
||||
for n in range(check_nqs):
|
||||
for i in range(limit):
|
||||
entity = search_res[n][i].fields
|
||||
_id = search_res[n][i].id
|
||||
for field in entity.keys():
|
||||
if isinstance(entity[field], list):
|
||||
for order in range(0, len(entity[field]), 4):
|
||||
assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
|
||||
elif isinstance(entity[field], dict) and field != ct.default_json_field_name:
|
||||
# sparse checking, sparse vector must be the last, this is a bit hacky,
|
||||
# but sparse only supports list data type insertion for now
|
||||
assert entity[field].keys() == original[-1][_id].keys()
|
||||
else:
|
||||
num = original[original[pk_name] == _id].index.to_list()[0]
|
||||
assert original[field][num] == entity[field], f"the output field values are wrong at nq={n}"
|
||||
|
||||
return True
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,19 +1,16 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from common.constants import *
|
||||
|
||||
from utils.util_pymilvus import *
|
||||
from common.common_type import CaseLabel, CheckTasks
|
||||
from common import common_type as ct
|
||||
from common import common_func as cf
|
||||
from utils.util_log import test_log as log
|
||||
from base.client_v2_base import TestMilvusClientV2Base
|
||||
from base.client_base import TestcaseBase
|
||||
import random
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from faker import Faker
|
||||
import inspect
|
||||
|
||||
Faker.seed(19530)
|
||||
fake_en = Faker("en_US")
|
||||
|
||||
@ -218,7 +218,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
async def test_async_client_with_schema(self, schema):
|
||||
async def test_async_client_with_schema(self):
|
||||
# init async client
|
||||
pk_field_name = "id"
|
||||
self.init_async_milvus_client()
|
||||
@ -390,7 +390,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
|
||||
await self.async_milvus_client_wrap.create_database(db_name)
|
||||
await self.async_milvus_client_wrap.close()
|
||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||
self.async_milvus_client_wrap.init_async_client(uri, db_name=db_name)
|
||||
self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token, db_name=db_name)
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
@ -450,7 +450,7 @@ class TestAsyncMilvusClient(TestMilvusClientV2Base):
|
||||
async def test_async_client_close(self):
|
||||
# init async client
|
||||
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
|
||||
self.async_milvus_client_wrap.init_async_client(uri)
|
||||
self.async_milvus_client_wrap.init_async_client(uri, token=cf.param_info.param_token)
|
||||
|
||||
# create collection
|
||||
c_name = cf.gen_unique_str(prefix)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user