Update high level api test cases (#25118)
Signed-off-by: Binbin Lv <binbin.lv@zilliz.com>
This commit is contained in:
parent fe24228909
commit 31122a6858
tests/python_client/base/client_base.py

@@ -10,6 +10,7 @@ from base.partition_wrapper import ApiPartitionWrapper
from base.index_wrapper import ApiIndexWrapper
from base.utility_wrapper import ApiUtilityWrapper
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
from base.high_level_api_wrapper import HighLevelApiWrapper
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct

@@ -28,6 +29,7 @@ class Base:
    field_schema_wrap = None
    collection_object_list = []
    resource_group_list = []
    high_level_api_wrap = None

    def setup_class(self):
        log.info("[setup_class] Start setup class...")

@@ -45,6 +47,7 @@ class Base:
        self.index_wrap = ApiIndexWrapper()
        self.collection_schema_wrap = ApiCollectionSchemaWrapper()
        self.field_schema_wrap = ApiFieldSchemaWrapper()
        self.high_level_api_wrap = HighLevelApiWrapper()

    def teardown_method(self, method):
        log.info(("*" * 35) + " teardown " + ("*" * 35))

@@ -118,18 +121,28 @@ class TestcaseBase(Base):
    Public methods that can be used for test cases.
    """

    def _connect(self):
    def _connect(self, enable_high_level_api=False):
        """ Add a connection and create the connect """
        if cf.param_info.param_user and cf.param_info.param_password:
            res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
                                                        host=cf.param_info.param_host,
                                                        port=cf.param_info.param_port, user=cf.param_info.param_user,
                                                        password=cf.param_info.param_password,
                                                        secure=cf.param_info.param_secure)
        if enable_high_level_api:
            if cf.param_info.param_uri:
                uri = cf.param_info.param_uri
            else:
                uri = "http://" + cf.param_info.param_host + ":" + str(cf.param_info.param_port)
            res, is_succ = self.connection_wrap.MilvusClient(uri=uri,
                                                             token=cf.param_info.param_token)
        else:
            res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
                                                        host=cf.param_info.param_host,
                                                        port=cf.param_info.param_port)
            if cf.param_info.param_user and cf.param_info.param_password:
                res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
                                                            host=cf.param_info.param_host,
                                                            port=cf.param_info.param_port,
                                                            user=cf.param_info.param_user,
                                                            password=cf.param_info.param_password,
                                                            secure=cf.param_info.param_secure)
            else:
                res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING,
                                                            host=cf.param_info.param_host,
                                                            port=cf.param_info.param_port)

        return res

    def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
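For orientation, a minimal sketch (not part of the commit) of how a test opts into the new client path; the flag name follows the diff above, and in that branch _connect returns the created MilvusClient:

    # inside a TestcaseBase test method
    client = self._connect(enable_high_level_api=True)   # MilvusClient built from --uri/--token, or from host:port
    conn = self._connect()                                # default ORM-style connection, unchanged behaviour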
tests/python_client/base/collection_wrapper.py

@@ -330,6 +330,7 @@ class ApiCollectionWrapper:
        check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run()
        return res, check_result

    @trace()
    def get_compaction_state(self, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        func_name = sys._getframe().f_code.co_name

@@ -337,6 +338,7 @@ class ApiCollectionWrapper:
        check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run()
        return res, check_result

    @trace()
    def get_compaction_plans(self, timeout=None, check_task=None, check_items={}, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        func_name = sys._getframe().f_code.co_name

@@ -350,6 +352,7 @@ class ApiCollectionWrapper:
        # log.debug(res)
        return res

    @trace()
    def get_replicas(self, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        func_name = sys._getframe().f_code.co_name

@@ -357,9 +360,12 @@ class ApiCollectionWrapper:
        check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run()
        return res, check_result

    @trace()
    def describe(self, timeout=None, check_task=None, check_items=None):
        timeout = TIMEOUT if timeout is None else timeout
        func_name = sys._getframe().f_code.co_name
        res, check = api_request([self.collection.describe, timeout])
        check_result = ResponseChecker(res, func_name, check_task, check_items, check).run()
        return res, check_result
tests/python_client/base/connections_wrapper.py

@@ -1,5 +1,6 @@
from pymilvus import Connections
from pymilvus import DefaultConfig
from pymilvus import MilvusClient
import sys

sys.path.append("..")

@@ -58,3 +59,10 @@ class ApiConnectionsWrapper:
        response, is_succ = api_request([self.connection.get_connection_addr, alias])
        check_result = ResponseChecker(response, func_name, check_task, check_items, is_succ, alias=alias).run()
        return response, check_result

    # high level api
    def MilvusClient(self, check_task=None, check_items=None, **kwargs):
        func_name = sys._getframe().f_code.co_name
        response, succ = api_request([MilvusClient], **kwargs)
        check_result = ResponseChecker(response, func_name, check_task, check_items, succ, **kwargs).run()
        return response, check_result
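A usage note, assuming api_request forwards keyword arguments to the wrapped callable as it does elsewhere in this test suite: the wrapper ends up constructing MilvusClient(**kwargs), so it is called with keyword arguments only, e.g.

    # values are placeholders; _connect() above builds the uri from --uri or from host:port
    response, is_succ = self.connection_wrap.MilvusClient(uri="http://127.0.0.1:19530", token="")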
tests/python_client/base/high_level_api_wrapper.py (new file, 163 lines)

@@ -0,0 +1,163 @@
import sys
import time
import timeout_decorator
from numpy import NaN

from pymilvus import Collection

sys.path.append("..")
from check.func_check import ResponseChecker
from utils.api_request import api_request
from utils.wrapper import trace
from utils.util_log import test_log as log
from pymilvus.orm.types import CONSISTENCY_STRONG
from common.common_func import param_info

TIMEOUT = 120
INDEX_NAME = ""


# keep small timeout for stability tests
# TIMEOUT = 5


class HighLevelApiWrapper:

    def __init__(self, active_trace=False):
        self.active_trace = active_trace

    @trace()
    def create_collection(self, client, collection_name, dimension, timeout=None, check_task=None,
                          check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.create_collection, collection_name, dimension], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name, dimension=dimension,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def insert(self, client, collection_name, data, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.insert, collection_name, data], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name, data=data,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None,
               timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})
        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.search, collection_name, data, filter, limit,
                                  output_fields, search_params], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name, data=data, limit=limit, filter=filter,
                                       output_fields=output_fields, search_params=search_params,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def query(self, client, collection_name, filter=None, output_fields=None,
              timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.query, collection_name, filter, output_fields], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name, filter=filter,
                                       output_fields=output_fields,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def get(self, client, collection_name, ids, output_fields=None,
            timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.get, collection_name, ids, output_fields], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name, ids=ids,
                                       output_fields=output_fields,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def num_entities(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.num_entities, collection_name], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def delete(self, client, collection_name, pks, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.delete, collection_name, pks], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name, pks=pks,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def flush(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.flush, collection_name], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def describe_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.describe_collection, collection_name], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def list_collections(self, client, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.list_collections], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def drop_collection(self, client, collection_name, check_task=None, check_items=None, **kwargs):
        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.drop_collection, collection_name], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       collection_name=collection_name,
                                       **kwargs).run()
        return res, check_result
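For context, a condensed sketch of how the new test cases drive this wrapper from inside a TestcaseBase test; the collection name and dimension are placeholders, and the full flow appears in test_high_level_search_query_default further down:

    client_w = HighLevelApiWrapper()
    client = self._connect(enable_high_level_api=True)                    # pymilvus MilvusClient
    client_w.create_collection(client, "hl_demo", 8)                      # implicit "id"/"vector" schema
    client_w.insert(client, "hl_demo", [{"id": 0, "vector": [0.1] * 8}])
    res, _ = client_w.search(client, "hl_demo", [[0.1] * 8], limit=1)
    client_w.drop_collection(client, "hl_demo")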
tests/python_client/check/func_check.py

@@ -83,9 +83,14 @@ class ResponseChecker:
        elif self.check_task == CheckTasks.check_permission_deny:
            # Collection interface response check
            result = self.check_permission_deny(self.response, self.succ)

        elif self.check_task == CheckTasks.check_rg_property:
            # describe resource group interface response check
            result = self.check_rg_property(self.response, self.func_name, self.check_items)

        elif self.check_task == CheckTasks.check_describe_collection_property:
            # describe collection interface (high level api) response check
            result = self.check_describe_collection_property(self.response, self.func_name, self.check_items)

        # Add check_items here if something new needs to be verified

@@ -178,6 +183,48 @@ class ResponseChecker:
            assert collection.primary_field.name == check_items.get("primary")
        return True

    @staticmethod
    def check_describe_collection_property(res, func_name, check_items):
        """
        According to the check_items, check the collection properties of res returned by func_name
        :param res: actual response of describe collection
        :type res: dict

        :param func_name: describe collection API
        :type func_name: str

        :param check_items: which items are expected to be checked, including name, schema, num_entities, primary
        :type check_items: dict, {check_key: expected_value}
        """
        exp_func_name = "describe_collection"
        if func_name != exp_func_name:
            log.warning("The function name is {} rather than {}".format(func_name, exp_func_name))
        if len(check_items) == 0:
            raise Exception("No expect values found in the check task")
        if check_items.get("collection_name", None) is not None:
            assert res["collection_name"] == check_items.get("collection_name")
        if check_items.get("auto_id", False):
            assert res["auto_id"] == check_items.get("auto_id")
        if check_items.get("num_shards", 1):
            assert res["num_shards"] == check_items.get("num_shards", 1)
        if check_items.get("consistency_level", 2):
            assert res["consistency_level"] == check_items.get("consistency_level", 2)
        if check_items.get("enable_dynamic_field", True):
            assert res["enable_dynamic_field"] == check_items.get("enable_dynamic_field", True)
        if check_items.get("num_partitions", 1):
            assert res["num_partitions"] == check_items.get("num_partitions", 1)
        if check_items.get("id_name", "id"):
            assert res["fields"][0]["name"] == check_items.get("id_name", "id")
        if check_items.get("vector_name", "vector"):
            assert res["fields"][1]["name"] == check_items.get("vector_name", "vector")
        if check_items.get("dim", None) is not None:
            assert res["fields"][1]["params"]["dim"] == check_items.get("dim")
        assert res["fields"][0]["is_primary"] is True
        assert res["fields"][0]["field_id"] == 100 and res["fields"][0]["type"] == 5
        assert res["fields"][1]["field_id"] == 101 and res["fields"][1]["type"] == 101

        return True
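For reference, the tests below exercise this checker with check_items like the following (taken from test_high_level_search_query_default); note that keys guarded by truthy defaults, such as auto_id=False, are effectively skipped rather than asserted:

    client_w.describe_collection(client, collection_name,
                                 check_task=CheckTasks.check_describe_collection_property,
                                 check_items={"collection_name": collection_name, "dim": default_dim})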
    @staticmethod
    def check_partition_property(partition, func_name, check_items):
        exp_func_name = "init_partition"

@@ -248,18 +295,26 @@ class ResponseChecker:
            assert len(search_res) == check_items["nq"]
        else:
            log.info("search_results_check: Numbers of query searched is correct")
        enable_high_level_api = check_items.get("enable_high_level_api", False)
        log.debug(search_res)
        for hits in search_res:
            searched_original_vectors = []
            ids = []
            if enable_high_level_api:
                for hit in hits:
                    ids.append(hit['id'])
            else:
                ids = list(hits.ids)
            if (len(hits) != check_items["limit"]) \
                    or (len(hits.ids) != check_items["limit"]):
                    or (len(ids) != check_items["limit"]):
                log.error("search_results_check: limit(topK) searched (%d) "
                          "is not equal with expected (%d)"
                          % (len(hits), check_items["limit"]))
                assert len(hits) == check_items["limit"]
                assert len(hits.ids) == check_items["limit"]
                assert len(ids) == check_items["limit"]
            else:
                if check_items.get("ids", None) is not None:
                    ids_match = pc.list_contain_check(list(hits.ids),
                    ids_match = pc.list_contain_check(ids,
                                                      list(check_items["ids"]))
                    if not ids_match:
                        log.error("search_results_check: ids searched not match")
tests/python_client/common/common_func.py

@@ -38,8 +38,10 @@ class ParamInfo:
        self.param_password = ""
        self.param_secure = False
        self.param_replica_num = ct.default_replica_num
        self.param_uri = ""
        self.param_token = ""

    def prepare_param_info(self, host, port, handler, replica_num, user, password, secure):
    def prepare_param_info(self, host, port, handler, replica_num, user, password, secure, uri, token):
        self.param_host = host
        self.param_port = port
        self.param_handler = handler

@@ -47,6 +49,8 @@ class ParamInfo:
        self.param_password = password
        self.param_secure = secure
        self.param_replica_num = replica_num
        self.param_uri = uri
        self.param_token = token


param_info = ParamInfo()
tests/python_client/common/common_type.py

@@ -253,6 +253,7 @@ class CheckTasks:
    check_permission_deny = "check_permission_deny"
    check_value_equal = "check_value_equal"
    check_rg_property = "check_resource_group_property"
    check_describe_collection_property = "check_describe_collection_property"


class BulkLoadStates:
tests/python_client/conftest.py

@@ -45,6 +45,8 @@ def pytest_addoption(parser):
    parser.addoption('--field_name', action='store', default="field_name", help="field_name of index")
    parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number")
    parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip")
    parser.addoption('--uri', action='store', default="", help="uri for high level api")
    parser.addoption('--token', action='store', default="", help="token for high level api")


@pytest.fixture

@@ -174,6 +176,16 @@ def minio_host(request):
    return request.config.getoption("--minio_host")


@pytest.fixture
def uri(request):
    return request.config.getoption("--uri")


@pytest.fixture
def token(request):
    return request.config.getoption("--token")


""" fixture func """


@@ -188,6 +200,8 @@ def initialize_env(request):
    secure = request.config.getoption("--secure")
    clean_log = request.config.getoption("--clean_log")
    replica_num = request.config.getoption("--replica_num")
    uri = request.config.getoption("--uri")
    token = request.config.getoption("--token")

    """ params check """
    assert ip_check(host) and number_check(port)

@@ -200,7 +214,7 @@ def initialize_env(request):

    log.info("#" * 80)
    log.info("[initialize_milvus] Log cleaned up, start testing...")
    param_info.prepare_param_info(host, port, handler, replica_num, user, password, secure)
    param_info.prepare_param_info(host, port, handler, replica_num, user, password, secure, uri, token)


@pytest.fixture(params=ct.get_invalid_strs)
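To show how the new options are consumed, a hypothetical invocation with placeholder values; when --uri is omitted, _connect() falls back to building the URI from the configured host and port as shown earlier:

    pytest testcases/test_high_level_api.py --uri "http://127.0.0.1:19530" --token ""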
tests/python_client/testcases/test_high_level_api.py (new file, 311 lines)

@@ -0,0 +1,311 @@
import multiprocessing
import numbers
import random
import numpy
import threading
import pytest
import pandas as pd
import decimal
from decimal import Decimal, getcontext
from time import sleep
import heapq

from base.client_base import TestcaseBase
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
from common.common_type import CaseLabel, CheckTasks
from utils.util_pymilvus import *
from common.constants import *
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
from base.high_level_api_wrapper import HighLevelApiWrapper
client_w = HighLevelApiWrapper()

prefix = "high_level_api"
epsilon = ct.epsilon
default_nb = ct.default_nb
default_nb_medium = ct.default_nb_medium
default_nq = ct.default_nq
default_dim = ct.default_dim
default_limit = ct.default_limit
default_search_exp = "id >= 0"
exp_res = "exp_res"
default_search_string_exp = "varchar >= \"0\""
default_search_mix_exp = "int64 >= 0 && varchar >= \"0\""
default_invaild_string_exp = "varchar >= 0"
default_json_search_exp = "json_field[\"number\"] >= 0"
perfix_expr = 'varchar like "0%"'
default_search_field = ct.default_float_vec_field_name
default_search_params = ct.default_search_params
default_primary_key_field_name = "id"
default_vector_field_name = "vector"
default_float_field_name = ct.default_float_field_name
default_bool_field_name = ct.default_bool_field_name
default_string_field_name = ct.default_string_field_name


class TestHighLevelApi(TestcaseBase):
    """ Test case of search interface """

    @pytest.fixture(scope="function", params=[False, True])
    def auto_id(self, request):
        yield request.param

    @pytest.fixture(scope="function", params=["COSINE", "L2"])
    def metric_type(self, request):
        yield request.param

    """
    ******************************************************************
    # The following are invalid base cases
    ******************************************************************
    """

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.xfail(reason="pymilvus issue 1554")
    def test_high_level_collection_invalid_primary_field(self):
        """
        target: test high level api: client.create_collection
        method: create collection with invalid primary field
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        error = {ct.err_code: 1, ct.err_msg: f"Param id_type must be int or string"}
        client_w.create_collection(client, collection_name, default_dim, id_type="invalid",
                                   check_task=CheckTasks.err_res, check_items=error)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_collection_string_auto_id(self):
        """
        target: test high level api: client.create_collection
        method: create collection with auto id on string primary key
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        error = {ct.err_code: 1, ct.err_msg: f"The auto_id can only be specified on field with DataType.INT64"}
        client_w.create_collection(client, collection_name, default_dim, id_type="string", auto_id=True,
                                   check_task=CheckTasks.err_res, check_items=error)

    @pytest.mark.tags(CaseLabel.L1)
    def test_high_level_create_same_collection_different_params(self):
        """
        target: test high level api: client.create_collection
        method: create collections with the same name, first with identical and then with different params
        expected: 1. Successfully create the collection with the same params
                  2. Report an error when creating a collection with the same name but different params
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim)
        # 2. create collection with same params
        client_w.create_collection(client, collection_name, default_dim)
        # 3. create collection with same name and different params
        error = {ct.err_code: 1, ct.err_msg: f"create duplicate collection with different parameters, "
                                             f"collection: {collection_name}"}
        client_w.create_collection(client, collection_name, default_dim+1,
                                   check_task=CheckTasks.err_res, check_items=error)
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_collection_invalid_metric_type(self):
        """
        target: test high level api: client.create_collection
        method: create collection with an invalid metric type
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        error = {ct.err_code: 1, ct.err_msg: f"metric type not found or not supported, supported: [L2 IP COSINE]"}
        client_w.create_collection(client, collection_name, default_dim, metric_type="invalid",
                                   check_task=CheckTasks.err_res, check_items=error)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_search_not_consistent_metric_type(self, metric_type):
        """
        target: test search with a metric type that is inconsistent with that of the index (default is IP)
        method: create connection, collection, insert and search with an inconsistent metric type
        expected: Raise exception
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim)
        # 2. search
        rng = np.random.default_rng(seed=19530)
        vectors_to_search = rng.random((1, 8))
        search_params = {"metric_type": metric_type}
        error = {ct.err_code: 1, ct.err_msg: f"metric type not match: expected=IP, actual={metric_type}"}
        client_w.search(client, collection_name, vectors_to_search, limit=default_limit,
                        search_params=search_params,
                        check_task=CheckTasks.err_res, check_items=error)
        client_w.drop_collection(client, collection_name)

    """
    ******************************************************************
    # The following are valid base cases
    ******************************************************************
    """

    @pytest.mark.tags(CaseLabel.L1)
    def test_high_level_search_query_default(self):
        """
        target: test search (high level api) normal case
        method: create connection, collection, insert and search
        expected: search/query successfully
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim)
        collections = client_w.list_collections(client)[0]
        assert collection_name in collections
        client_w.describe_collection(client, collection_name,
                                     check_task=CheckTasks.check_describe_collection_property,
                                     check_items={"collection_name": collection_name,
                                                  "dim": default_dim})
        # 2. insert
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        client_w.insert(client, collection_name, rows)
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. search
        vectors_to_search = rng.random((1, default_dim))
        insert_ids = [i for i in range(default_nb)]
        client_w.search(client, collection_name, vectors_to_search,
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "ids": insert_ids,
                                     "limit": default_limit})
        # 4. query
        client_w.query(client, collection_name, filter=default_search_exp,
                       check_task=CheckTasks.check_query_results,
                       check_items={exp_res: rows,
                                    "with_vec": True,
                                    "primary_field": default_primary_key_field_name})
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.skip(reason="issue 25110")
    def test_high_level_search_query_string(self):
        """
        target: test search (high level api) for string primary key
        method: create connection, collection, insert and search
        expected: search/query successfully
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim, id_type="string", max_length=ct.default_length)
        client_w.describe_collection(client, collection_name,
                                     check_task=CheckTasks.check_describe_collection_property,
                                     check_items={"collection_name": collection_name,
                                                  "dim": default_dim,
                                                  "auto_id": auto_id})
        # 2. insert
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: str(i), default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        client_w.insert(client, collection_name, rows)
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. search
        vectors_to_search = rng.random((1, default_dim))
        client_w.search(client, collection_name, vectors_to_search,
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "limit": default_limit})
        # 4. query
        client_w.query(client, collection_name, filter=default_search_exp,
                       check_task=CheckTasks.check_query_results,
                       check_items={exp_res: rows,
                                    "with_vec": True,
                                    "primary_field": default_primary_key_field_name})
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    def test_high_level_search_different_metric_types(self, metric_type, auto_id):
        """
        target: test search (high level api) normal case
        method: create connection, collection, insert and search
        expected: search successfully with limit(topK)
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim, metric_type=metric_type, auto_id=auto_id)
        # 2. insert
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        if auto_id:
            for row in rows:
                row.pop(default_primary_key_field_name)
        client_w.insert(client, collection_name, rows)
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. search
        vectors_to_search = rng.random((1, default_dim))
        search_params = {"metric_type": metric_type}
        client_w.search(client, collection_name, vectors_to_search, limit=default_limit,
                        search_params=search_params,
                        output_fields=[default_primary_key_field_name],
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "limit": default_limit})
        client_w.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_high_level_delete(self):
        """
        target: test delete (high level api)
        method: create connection, collection, insert, delete, and search
        expected: search/query successfully without deleted data
        """
        client = self._connect(enable_high_level_api=True)
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        client_w.create_collection(client, collection_name, default_dim, consistency_level="Strong")
        # 2. insert
        default_nb = 1000
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        pks = client_w.insert(client, collection_name, rows)[0]
        client_w.flush(client, collection_name)
        assert client_w.num_entities(client, collection_name)[0] == default_nb
        # 3. get first primary key
        first_pk_data = client_w.get(client, collection_name, pks[0:1])
        # 4. delete
        delete_num = 3
        client_w.delete(client, collection_name, pks[0:delete_num])
        # 5. search
        vectors_to_search = rng.random((1, default_dim))
        insert_ids = [i for i in range(default_nb)]
        for insert_id in pks[0:delete_num]:
            if insert_id in insert_ids:
                insert_ids.remove(insert_id)
        limit = default_nb - delete_num
        client_w.search(client, collection_name, vectors_to_search, limit=default_nb,
                        check_task=CheckTasks.check_search_results,
                        check_items={"enable_high_level_api": True,
                                     "nq": len(vectors_to_search),
                                     "ids": insert_ids,
                                     "limit": limit})
        # 6. query
        client_w.query(client, collection_name, filter=default_search_exp,
                       check_task=CheckTasks.check_query_results,
                       check_items={exp_res: rows[delete_num:],
                                    "with_vec": True,
                                    "primary_field": default_primary_key_field_name})
        client_w.drop_collection(client, collection_name)
tests/python_client/testcases/test_search.py

@@ -3771,8 +3771,7 @@ class TestCollectionSearch(TestcaseBase):
        collection_w.search(vectors[:nq], default_search_field,
                            default_search_params, limit,
                            default_search_exp, _async=_async,
                            **kwargs
                            )
                            **kwargs)

    @pytest.mark.tags(CaseLabel.L1)
    def test_search_with_consistency_session(self, nq, dim, auto_id, _async, enable_dynamic_field):

@@ -5624,7 +5623,7 @@ class TestSearchDiskann(TestcaseBase):
        collection_w.create_index(ct.default_float_vec_field_name, default_index)
        collection_w.load()
        search_list = 20
        default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}}
        default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}}
        vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
        output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
        collection_w.search(vectors[:default_nq], default_search_field,