diff --git a/tests/python_client/base/client_base.py b/tests/python_client/base/client_base.py index 113c65945f..7636393a27 100644 --- a/tests/python_client/base/client_base.py +++ b/tests/python_client/base/client_base.py @@ -10,6 +10,7 @@ from base.partition_wrapper import ApiPartitionWrapper from base.index_wrapper import ApiIndexWrapper from base.utility_wrapper import ApiUtilityWrapper from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper +from base.high_level_api_wrapper import HighLevelApiWrapper from utils.util_log import test_log as log from common import common_func as cf from common import common_type as ct @@ -28,6 +29,7 @@ class Base: field_schema_wrap = None collection_object_list = [] resource_group_list = [] + high_level_api_wrap = None def setup_class(self): log.info("[setup_class] Start setup class...") @@ -45,6 +47,7 @@ class Base: self.index_wrap = ApiIndexWrapper() self.collection_schema_wrap = ApiCollectionSchemaWrapper() self.field_schema_wrap = ApiFieldSchemaWrapper() + self.high_level_api_wrap = HighLevelApiWrapper() def teardown_method(self, method): log.info(("*" * 35) + " teardown " + ("*" * 35)) @@ -118,18 +121,28 @@ class TestcaseBase(Base): Public methods that can be used for test cases. 
""" - def _connect(self): + def _connect(self, enable_high_level_api=False): """ Add a connection and create the connect """ - if cf.param_info.param_user and cf.param_info.param_password: - res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, - host=cf.param_info.param_host, - port=cf.param_info.param_port, user=cf.param_info.param_user, - password=cf.param_info.param_password, - secure=cf.param_info.param_secure) + if enable_high_level_api: + if cf.param_info.param_uri: + uri = cf.param_info.param_uri + else: + uri = "http://" + cf.param_info.param_host + ":" + str(cf.param_info.param_port) + res, is_succ = self.connection_wrap.MilvusClient(uri=uri, + token=cf.param_info.param_token) else: - res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, - host=cf.param_info.param_host, - port=cf.param_info.param_port) + if cf.param_info.param_user and cf.param_info.param_password: + res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, + host=cf.param_info.param_host, + port=cf.param_info.param_port, + user=cf.param_info.param_user, + password=cf.param_info.param_password, + secure=cf.param_info.param_secure) + else: + res, is_succ = self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, + host=cf.param_info.param_host, + port=cf.param_info.param_port) + return res def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None, diff --git a/tests/python_client/base/collection_wrapper.py b/tests/python_client/base/collection_wrapper.py index cd34d1305a..45d6015374 100644 --- a/tests/python_client/base/collection_wrapper.py +++ b/tests/python_client/base/collection_wrapper.py @@ -330,6 +330,7 @@ class ApiCollectionWrapper: check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run() return res, check_result + @trace() def get_compaction_state(self, timeout=None, check_task=None, check_items=None, **kwargs): timeout = TIMEOUT if 
timeout is None else timeout func_name = sys._getframe().f_code.co_name @@ -337,6 +338,7 @@ class ApiCollectionWrapper: check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run() return res, check_result + @trace() def get_compaction_plans(self, timeout=None, check_task=None, check_items={}, **kwargs): timeout = TIMEOUT if timeout is None else timeout func_name = sys._getframe().f_code.co_name @@ -350,6 +352,7 @@ class ApiCollectionWrapper: # log.debug(res) return res + @trace() def get_replicas(self, timeout=None, check_task=None, check_items=None, **kwargs): timeout = TIMEOUT if timeout is None else timeout func_name = sys._getframe().f_code.co_name @@ -357,9 +360,12 @@ class ApiCollectionWrapper: check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run() return res, check_result + @trace() def describe(self, timeout=None, check_task=None, check_items=None): timeout = TIMEOUT if timeout is None else timeout func_name = sys._getframe().f_code.co_name res, check = api_request([self.collection.describe, timeout]) check_result = ResponseChecker(res, func_name, check_task, check_items, check).run() return res, check_result + + diff --git a/tests/python_client/base/connections_wrapper.py b/tests/python_client/base/connections_wrapper.py index 80d4f01cac..4d89a0bde7 100644 --- a/tests/python_client/base/connections_wrapper.py +++ b/tests/python_client/base/connections_wrapper.py @@ -1,5 +1,6 @@ from pymilvus import Connections from pymilvus import DefaultConfig +from pymilvus import MilvusClient import sys sys.path.append("..") @@ -58,3 +59,10 @@ class ApiConnectionsWrapper: response, is_succ = api_request([self.connection.get_connection_addr, alias]) check_result = ResponseChecker(response, func_name, check_task, check_items, is_succ, alias=alias).run() return response, check_result + + # high level api + def MilvusClient(self, check_task=None, check_items=None, **kwargs): + func_name = 
sys._getframe().f_code.co_name + response, succ = api_request([MilvusClient], **kwargs) + check_result = ResponseChecker(response, func_name, check_task, check_items, succ, **kwargs).run() + return response, check_result diff --git a/tests/python_client/base/high_level_api_wrapper.py b/tests/python_client/base/high_level_api_wrapper.py new file mode 100644 index 0000000000..671d999de1 --- /dev/null +++ b/tests/python_client/base/high_level_api_wrapper.py @@ -0,0 +1,163 @@ +import sys +import time +import timeout_decorator +from numpy import NaN + +from pymilvus import Collection + +sys.path.append("..") +from check.func_check import ResponseChecker +from utils.api_request import api_request +from utils.wrapper import trace +from utils.util_log import test_log as log +from pymilvus.orm.types import CONSISTENCY_STRONG +from common.common_func import param_info + +TIMEOUT = 120 +INDEX_NAME = "" + + +# keep small timeout for stability tests +# TIMEOUT = 5 + + +class HighLevelApiWrapper: + + def __init__(self, active_trace=False): + self.active_trace = active_trace + + @trace() + def create_collection(self, client, collection_name, dimension, timeout=None, check_task=None, + check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.create_collection, collection_name, dimension], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, dimension=dimension, + **kwargs).run() + return res, check_result + + @trace() + def insert(self, client, collection_name, data, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.insert, collection_name, data], **kwargs) + check_result = 
ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, data=data, + **kwargs).run() + return res, check_result + + @trace() + def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None, + timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.search, collection_name, data, filter, limit, + output_fields, search_params], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, data=data, limit=limit, filter=filter, + output_fields=output_fields, search_params=search_params, + **kwargs).run() + return res, check_result + + @trace() + def query(self, client, collection_name, filter=None, output_fields=None, + timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.query, collection_name, filter, output_fields], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, filter=filter, + output_fields=output_fields, + **kwargs).run() + return res, check_result + + @trace() + def get(self, client, collection_name, ids, output_fields=None, + timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.get, collection_name, ids, output_fields], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, ids=ids, + output_fields=output_fields, + **kwargs).run() + return 
res, check_result + + @trace() + def num_entities(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.num_entities, collection_name], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, + **kwargs).run() + return res, check_result + + @trace() + def delete(self, client, collection_name, pks, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.delete, collection_name, pks], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, pks=pks, + **kwargs).run() + return res, check_result + + @trace() + def flush(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.flush, collection_name], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, + **kwargs).run() + return res, check_result + + @trace() + def describe_collection(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.describe_collection, collection_name], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, + **kwargs).run() + return 
res, check_result + + @trace() + def list_collections(self, client, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.list_collections], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + **kwargs).run() + return res, check_result + + @trace() + def drop_collection(self, client, collection_name, check_task=None, check_items=None, **kwargs): + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.drop_collection, collection_name], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, + **kwargs).run() + return res, check_result + diff --git a/tests/python_client/check/func_check.py b/tests/python_client/check/func_check.py index b86aac74a1..3549d72157 100644 --- a/tests/python_client/check/func_check.py +++ b/tests/python_client/check/func_check.py @@ -83,9 +83,14 @@ class ResponseChecker: elif self.check_task == CheckTasks.check_permission_deny: # Collection interface response check result = self.check_permission_deny(self.response, self.succ) + elif self.check_task == CheckTasks.check_rg_property: # describe resource group interface response check result = self.check_rg_property(self.response, self.func_name, self.check_items) + + elif self.check_task == CheckTasks.check_describe_collection_property: + # describe collection interface(high level api) response check + result = self.check_describe_collection_property(self.response, self.func_name, self.check_items) # Add check_items here if something new need verify @@ -178,6 +183,48 @@ class ResponseChecker: assert collection.primary_field.name == check_items.get("primary") return True + @staticmethod + def check_describe_collection_property(res, func_name, check_items): + """ + According to the 
check_items to check collection properties of res, which return from func_name + :param res: actual response of init collection + :type res: Collection + + :param func_name: init collection API + :type func_name: str + + :param check_items: which items expected to be checked, including name, schema, num_entities, primary + :type check_items: dict, {check_key: expected_value} + """ + exp_func_name = "describe_collection" + if func_name != exp_func_name: + log.warning("The function name is {} rather than {}".format(func_name, exp_func_name)) + if len(check_items) == 0: + raise Exception("No expect values found in the check task") + if check_items.get("collection_name", None) is not None: + assert res["collection_name"] == check_items.get("collection_name") + if check_items.get("auto_id", False): + assert res["auto_id"] == check_items.get("auto_id") + if check_items.get("num_shards", 1): + assert res["num_shards"] == check_items.get("num_shards", 1) + if check_items.get("consistency_level", 2): + assert res["consistency_level"] == check_items.get("consistency_level", 2) + if check_items.get("enable_dynamic_field", True): + assert res["enable_dynamic_field"] == check_items.get("enable_dynamic_field", True) + if check_items.get("num_partitions", 1): + assert res["num_partitions"] == check_items.get("num_partitions", 1) + if check_items.get("id_name", "id"): + assert res["fields"][0]["name"] == check_items.get("id_name", "id") + if check_items.get("vector_name", "vector"): + assert res["fields"][1]["name"] == check_items.get("vector_name", "vector") + if check_items.get("dim", None) is not None: + assert res["fields"][1]["params"]["dim"] == check_items.get("dim") + assert res["fields"][0]["is_primary"] is True + assert res["fields"][0]["field_id"] == 100 and res["fields"][0]["type"] == 5 + assert res["fields"][1]["field_id"] == 101 and res["fields"][1]["type"] == 101 + + return True + @staticmethod def check_partition_property(partition, func_name, check_items): 
exp_func_name = "init_partition" @@ -248,18 +295,26 @@ class ResponseChecker: assert len(search_res) == check_items["nq"] else: log.info("search_results_check: Numbers of query searched is correct") + enable_high_level_api = check_items.get("enable_high_level_api", False) + log.debug(search_res) for hits in search_res: searched_original_vectors = [] + ids = [] + if enable_high_level_api: + for hit in hits: + ids.append(hit['id']) + else: + ids = list(hits.ids) if (len(hits) != check_items["limit"]) \ - or (len(hits.ids) != check_items["limit"]): + or (len(ids) != check_items["limit"]): log.error("search_results_check: limit(topK) searched (%d) " "is not equal with expected (%d)" % (len(hits), check_items["limit"])) assert len(hits) == check_items["limit"] - assert len(hits.ids) == check_items["limit"] + assert len(ids) == check_items["limit"] else: if check_items.get("ids", None) is not None: - ids_match = pc.list_contain_check(list(hits.ids), + ids_match = pc.list_contain_check(ids, list(check_items["ids"])) if not ids_match: log.error("search_results_check: ids searched not match") diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index c53df134cd..cc055f12ad 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -38,8 +38,10 @@ class ParamInfo: self.param_password = "" self.param_secure = False self.param_replica_num = ct.default_replica_num + self.param_uri = "" + self.param_token = "" - def prepare_param_info(self, host, port, handler, replica_num, user, password, secure): + def prepare_param_info(self, host, port, handler, replica_num, user, password, secure, uri, token): self.param_host = host self.param_port = port self.param_handler = handler @@ -47,6 +49,8 @@ class ParamInfo: self.param_password = password self.param_secure = secure self.param_replica_num = replica_num + self.param_uri = uri + self.param_token = token param_info = ParamInfo() diff --git 
a/tests/python_client/common/common_type.py b/tests/python_client/common/common_type.py index 36ad5e3e66..3f18e0605b 100644 --- a/tests/python_client/common/common_type.py +++ b/tests/python_client/common/common_type.py @@ -253,6 +253,7 @@ class CheckTasks: check_permission_deny = "check_permission_deny" check_value_equal = "check_value_equal" check_rg_property = "check_resource_group_property" + check_describe_collection_property = "check_describe_collection_property" class BulkLoadStates: diff --git a/tests/python_client/conftest.py b/tests/python_client/conftest.py index e717c8eec2..8511acbe63 100644 --- a/tests/python_client/conftest.py +++ b/tests/python_client/conftest.py @@ -45,6 +45,8 @@ def pytest_addoption(parser): parser.addoption('--field_name', action='store', default="field_name", help="field_name of index") parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number") parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip") + parser.addoption('--uri', action='store', default="", help="uri for high level api") + parser.addoption('--token', action='store', default="", help="token for high level api") @pytest.fixture @@ -174,6 +176,16 @@ def minio_host(request): return request.config.getoption("--minio_host") +@pytest.fixture +def uri(request): + return request.config.getoption("--uri") + + +@pytest.fixture +def token(request): + return request.config.getoption("--token") + + """ fixture func """ @@ -188,6 +200,8 @@ def initialize_env(request): secure = request.config.getoption("--secure") clean_log = request.config.getoption("--clean_log") replica_num = request.config.getoption("--replica_num") + uri = request.config.getoption("--uri") + token = request.config.getoption("--token") """ params check """ assert ip_check(host) and number_check(port) @@ -200,7 +214,7 @@ def initialize_env(request): log.info("#" * 80) log.info("[initialize_milvus] Log 
cleaned up, start testing...") - param_info.prepare_param_info(host, port, handler, replica_num, user, password, secure) + param_info.prepare_param_info(host, port, handler, replica_num, user, password, secure, uri, token) @pytest.fixture(params=ct.get_invalid_strs) diff --git a/tests/python_client/testcases/test_high_level_api.py b/tests/python_client/testcases/test_high_level_api.py new file mode 100644 index 0000000000..d4519d51a5 --- /dev/null +++ b/tests/python_client/testcases/test_high_level_api.py @@ -0,0 +1,311 @@ +import multiprocessing +import numbers +import random +import numpy +import threading +import pytest +import pandas as pd +import decimal +from decimal import Decimal, getcontext +from time import sleep +import heapq + +from base.client_base import TestcaseBase +from utils.util_log import test_log as log +from common import common_func as cf +from common import common_type as ct +from common.common_type import CaseLabel, CheckTasks +from utils.util_pymilvus import * +from common.constants import * +from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY +from base.high_level_api_wrapper import HighLevelApiWrapper +client_w = HighLevelApiWrapper() + +prefix = "high_level_api" +epsilon = ct.epsilon +default_nb = ct.default_nb +default_nb_medium = ct.default_nb_medium +default_nq = ct.default_nq +default_dim = ct.default_dim +default_limit = ct.default_limit +default_search_exp = "id >= 0" +exp_res = "exp_res" +default_search_string_exp = "varchar >= \"0\"" +default_search_mix_exp = "int64 >= 0 && varchar >= \"0\"" +default_invaild_string_exp = "varchar >= 0" +default_json_search_exp = "json_field[\"number\"] >= 0" +perfix_expr = 'varchar like "0%"' +default_search_field = ct.default_float_vec_field_name +default_search_params = ct.default_search_params +default_primary_key_field_name = "id" +default_vector_field_name = "vector" +default_float_field_name = ct.default_float_field_name 
+default_bool_field_name = ct.default_bool_field_name +default_string_field_name = ct.default_string_field_name + + +class TestHighLevelApi(TestcaseBase): + """ Test case of search interface """ + + @pytest.fixture(scope="function", params=[False, True]) + def auto_id(self, request): + yield request.param + + @pytest.fixture(scope="function", params=["COSINE", "L2"]) + def metric_type(self, request): + yield request.param + + """ + ****************************************************************** + # The following are invalid base cases + ****************************************************************** + """ + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.xfail(reason="pymilvus issue 1554") + def test_high_level_collection_invalid_primary_field(self): + """ + target: test high level api: client.create_collection + method: create collection with invalid primary field + expected: Raise exception + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. create collection + error = {ct.err_code: 1, ct.err_msg: f"Param id_type must be int or string"} + client_w.create_collection(client, collection_name, default_dim, id_type="invalid", + check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + def test_high_level_collection_string_auto_id(self): + """ + target: test high level api: client.create_collection + method: create collection with auto id on string primary key + expected: Raise exception + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. 
create collection + error = {ct.err_code: 1, ct.err_msg: f"The auto_id can only be specified on field with DataType.INT64"} + client_w.create_collection(client, collection_name, default_dim, id_type="string", auto_id=True, + check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_high_level_create_same_collection_different_params(self): + """ + target: test high level api: client.create_collection + method: create + expected: 1. Successfully to create collection with same params + 2. Report errors for creating collection with same name and different params + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. create collection + client_w.create_collection(client, collection_name, default_dim) + # 2. create collection with same params + client_w.create_collection(client, collection_name, default_dim) + # 3. create collection with same name and different params + error = {ct.err_code: 1, ct.err_msg: f"create duplicate collection with different parameters, " + f"collection: {collection_name}"} + client_w.create_collection(client, collection_name, default_dim+1, + check_task=CheckTasks.err_res, check_items=error) + client_w.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L2) + def test_high_level_collection_invalid_metric_type(self): + """ + target: test high level api: client.create_collection + method: create collection with auto id on string primary key + expected: Raise exception + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. 
create collection + error = {ct.err_code: 1, ct.err_msg: f"metric type not found or not supported, supported: [L2 IP COSINE]"} + client_w.create_collection(client, collection_name, default_dim, metric_type="invalid", + check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + def test_high_level_search_not_consistent_metric_type(self, metric_type): + """ + target: test search with inconsistent metric type (default is IP) with that of index + method: create connection, collection, insert and search with not consistent metric type + expected: Raise exception + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. create collection + client_w.create_collection(client, collection_name, default_dim) + # 2. search + rng = np.random.default_rng(seed=19530) + vectors_to_search = rng.random((1, 8)) + search_params = {"metric_type": metric_type} + error = {ct.err_code: 1, ct.err_msg: f"metric type not match: expected=IP, actual={metric_type}"} + client_w.search(client, collection_name, vectors_to_search, limit=default_limit, + search_params=search_params, + check_task=CheckTasks.err_res, check_items=error) + client_w.drop_collection(client, collection_name) + + """ + ****************************************************************** + # The following are valid base cases + ****************************************************************** + """ + + @pytest.mark.tags(CaseLabel.L1) + def test_high_level_search_query_default(self): + """ + target: test search (high level api) normal case + method: create connection, collection, insert and search + expected: search/query successfully + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. 
create collection + client_w.create_collection(client, collection_name, default_dim) + collections = client_w.list_collections(client)[0] + assert collection_name in collections + client_w.describe_collection(client, collection_name, + check_task=CheckTasks.check_describe_collection_property, + check_items={"collection_name": collection_name, + "dim": default_dim}) + # 2. insert + rng = np.random.default_rng(seed=19530) + rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)] + client_w.insert(client, collection_name, rows) + client_w.flush(client, collection_name) + assert client_w.num_entities(client, collection_name)[0] == default_nb + # 3. search + vectors_to_search = rng.random((1, default_dim)) + insert_ids = [i for i in range(default_nb)] + client_w.search(client, collection_name, vectors_to_search, + check_task=CheckTasks.check_search_results, + check_items={"enable_high_level_api": True, + "nq": len(vectors_to_search), + "ids": insert_ids, + "limit": default_limit}) + # 4. query + client_w.query(client, collection_name, filter=default_search_exp, + check_task=CheckTasks.check_query_results, + check_items={exp_res: rows, + "with_vec": True, + "primary_field": default_primary_key_field_name}) + client_w.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.skip(reason="issue 25110") + def test_high_level_search_query_string(self): + """ + target: test search (high level api) for string primary key + method: create connection, collection, insert and search + expected: search/query successfully + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. 
create collection + client_w.create_collection(client, collection_name, default_dim, id_type="string", max_length=ct.default_length) + client_w.describe_collection(client, collection_name, + check_task=CheckTasks.check_describe_collection_property, + check_items={"collection_name": collection_name, + "dim": default_dim, + "auto_id": False}) + # 2. insert + rng = np.random.default_rng(seed=19530) + rows = [{default_primary_key_field_name: str(i), default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)] + client_w.insert(client, collection_name, rows) + client_w.flush(client, collection_name) + assert client_w.num_entities(client, collection_name)[0] == default_nb + # 3. search + vectors_to_search = rng.random((1, default_dim)) + client_w.search(client, collection_name, vectors_to_search, + check_task=CheckTasks.check_search_results, + check_items={"enable_high_level_api": True, + "nq": len(vectors_to_search), + "limit": default_limit}) + # 4. query + client_w.query(client, collection_name, filter=default_search_exp, + check_task=CheckTasks.check_query_results, + check_items={exp_res: rows, + "with_vec": True, + "primary_field": default_primary_key_field_name}) + client_w.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L2) + def test_high_level_search_different_metric_types(self, metric_type, auto_id): + """ + target: test search (high level api) normal case + method: create connection, collection, insert and search + expected: search successfully with limit(topK) + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. create collection + client_w.create_collection(client, collection_name, default_dim, metric_type=metric_type, auto_id=auto_id) + # 2. 
insert + rng = np.random.default_rng(seed=19530) + rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)] + if auto_id: + for row in rows: + row.pop(default_primary_key_field_name) + client_w.insert(client, collection_name, rows) + client_w.flush(client, collection_name) + assert client_w.num_entities(client, collection_name)[0] == default_nb + # 3. search + vectors_to_search = rng.random((1, default_dim)) + search_params = {"metric_type": metric_type} + client_w.search(client, collection_name, vectors_to_search, limit=default_limit, + search_params=search_params, + output_fields=[default_primary_key_field_name], + check_task=CheckTasks.check_search_results, + check_items={"enable_high_level_api": True, + "nq": len(vectors_to_search), + "limit": default_limit}) + client_w.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L1) + def test_high_level_delete(self): + """ + target: test delete (high level api) + method: create connection, collection, insert delete, and search + expected: search/query successfully without deleted data + """ + client = self._connect(enable_high_level_api=True) + collection_name = cf.gen_unique_str(prefix) + # 1. create collection + client_w.create_collection(client, collection_name, default_dim, consistency_level="Strong") + # 2. insert + default_nb = 1000 + rng = np.random.default_rng(seed=19530) + rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)] + pks = client_w.insert(client, collection_name, rows)[0] + client_w.flush(client, collection_name) + assert client_w.num_entities(client, collection_name)[0] == default_nb + # 3. 
get first primary key + first_pk_data = client_w.get(client, collection_name, pks[0:1]) + # 4. delete + delete_num = 3 + client_w.delete(client, collection_name, pks[0:delete_num]) + # 5. search + vectors_to_search = rng.random((1, default_dim)) + insert_ids = [i for i in range(default_nb)] + for insert_id in pks[0:delete_num]: + if insert_id in insert_ids: + insert_ids.remove(insert_id) + limit = default_nb - delete_num + client_w.search(client, collection_name, vectors_to_search, limit=default_nb, + check_task=CheckTasks.check_search_results, + check_items={"enable_high_level_api": True, + "nq": len(vectors_to_search), + "ids": insert_ids, + "limit": limit}) + # 6. query + client_w.query(client, collection_name, filter=default_search_exp, + check_task=CheckTasks.check_query_results, + check_items={exp_res: rows[delete_num:], + "with_vec": True, + "primary_field": default_primary_key_field_name}) + client_w.drop_collection(client, collection_name) diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 5c28188839..3779ebdbfa 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -3771,8 +3771,7 @@ class TestCollectionSearch(TestcaseBase): collection_w.search(vectors[:nq], default_search_field, default_search_params, limit, default_search_exp, _async=_async, - **kwargs - ) + **kwargs) @pytest.mark.tags(CaseLabel.L1) def test_search_with_consistency_session(self, nq, dim, auto_id, _async, enable_dynamic_field): @@ -5624,7 +5623,7 @@ class TestSearchDiskann(TestcaseBase): collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() search_list = 20 - default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}} + default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] output_fields = 
[default_int64_field_name, default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field,