diff --git a/tests/python/__init__.py b/tests/python/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/python/conftest.py b/tests/python/conftest.py deleted file mode 100644 index 7f1fafb582..0000000000 --- a/tests/python/conftest.py +++ /dev/null @@ -1,255 +0,0 @@ -import socket -import pytest - -from .utils import * - -timeout = 60 -dimension = 128 -delete_timeout = 60 - - -def pytest_addoption(parser): - parser.addoption("--ip", action="store", default="localhost") - parser.addoption("--service", action="store", default="") - parser.addoption("--port", action="store", default=19530) - parser.addoption("--http-port", action="store", default=19121) - parser.addoption("--handler", action="store", default="GRPC") - parser.addoption("--tag", action="store", default="all", help="only run tests matching the tag.") - parser.addoption('--dry-run', action='store_true', default=False) - - -def pytest_configure(config): - # register an additional marker - config.addinivalue_line( - "markers", "tag(name): mark test to run only matching the tag" - ) - - -def pytest_runtest_setup(item): - tags = list() - for marker in item.iter_markers(name="tag"): - for tag in marker.args: - tags.append(tag) - if tags: - cmd_tag = item.config.getoption("--tag") - if cmd_tag != "all" and cmd_tag not in tags: - pytest.skip("test requires tag in {!r}".format(tags)) - - -def pytest_runtestloop(session): - if session.config.getoption('--dry-run'): - total_passed = 0 - total_skipped = 0 - test_file_to_items = {} - for item in session.items: - file_name, test_class, test_func = item.nodeid.split("::") - if test_file_to_items.get(file_name) is not None: - test_file_to_items[file_name].append(item) - else: - test_file_to_items[file_name] = [item] - for k, items in test_file_to_items.items(): - skip_case = [] - should_pass_but_skipped = [] - skipped_other_reason = [] - - level2_case = [] - for item in items: - if "pytestmark" in item.keywords.keys(): - markers = item.keywords["pytestmark"] - skip_case.extend([item.nodeid for marker in markers if marker.name == 'skip']) - should_pass_but_skipped.extend([item.nodeid for marker in markers if marker.name == 'skip' and len(marker.args) > 0 and marker.args[0] == "should pass"]) - skipped_other_reason.extend([item.nodeid for marker in markers if marker.name == 'skip' and (len(marker.args) < 1 or marker.args[0] != "should pass")]) - level2_case.extend([item.nodeid for marker in markers if marker.name == 'level' and marker.args[0] == 2]) - - print("") - print(f"[{k}]:") - print(f" Total : {len(items):13}") - print(f" Passed : {len(items) - len(skip_case):13}") - print(f" Skipped : {len(skip_case):13}") - print(f" - should pass: {len(should_pass_but_skipped):4}") - print(f" - not supported: {len(skipped_other_reason):4}") - print(f" Level2 : {len(level2_case):13}") - - print(f" ---------------------------------------") - print(f" should pass but skipped: ") - print("") - for nodeid in should_pass_but_skipped: - name, test_class, test_func = nodeid.split("::") - print(f" {name:8}: {test_class}.{test_func}") - print("") - print(f"===============================================") - total_passed += len(items) - len(skip_case) - total_skipped += len(skip_case) - - print("Total tests : ", len(session.items)) - print("Total passed: ", total_passed) - print("Total skipped: ", total_skipped) - return True - - -def check_server_connection(request): - ip = request.config.getoption("--ip") - port = request.config.getoption("--port") - - connected = True - if ip and 
(ip not in ['localhost', '127.0.0.1']): - try: - socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP) - except Exception as e: - print("Socket connect failed: %s" % str(e)) - connected = False - return connected - - -@pytest.fixture(scope="module") -def connect(request): - ip = request.config.getoption("--ip") - service_name = request.config.getoption("--service") - port = request.config.getoption("--port") - http_port = request.config.getoption("--http-port") - handler = request.config.getoption("--handler") - if handler == "HTTP": - port = http_port - try: - milvus = get_milvus(host=ip, port=port, handler=handler) - # reset_build_index_threshold(milvus) - except Exception as e: - logging.getLogger().error(str(e)) - pytest.exit("Milvus server cannot be connected, exiting pytest ...") - def fin(): - try: - milvus.close() - pass - except Exception as e: - logging.getLogger().info(str(e)) - request.addfinalizer(fin) - return milvus - - -@pytest.fixture(scope="module") -def dis_connect(request): - ip = request.config.getoption("--ip") - service_name = request.config.getoption("--service") - port = request.config.getoption("--port") - http_port = request.config.getoption("--http-port") - handler = request.config.getoption("--handler") - if handler == "HTTP": - port = http_port - milvus = get_milvus(host=ip, port=port, handler=handler) - milvus.close() - return milvus - - -@pytest.fixture(scope="module") -def args(request): - ip = request.config.getoption("--ip") - service_name = request.config.getoption("--service") - port = request.config.getoption("--port") - http_port = request.config.getoption("--http-port") - handler = request.config.getoption("--handler") - if handler == "HTTP": - port = http_port - args = {"ip": ip, "port": port, "handler": handler, "service_name": service_name} - return args - - -@pytest.fixture(scope="module") -def milvus(request): - ip = request.config.getoption("--ip") - port = request.config.getoption("--port") - http_port = request.config.getoption("--http-port") - handler = request.config.getoption("--handler") - if handler == "HTTP": - port = http_port - return get_milvus(host=ip, port=port, handler=handler) - - -@pytest.fixture(scope="function") -def collection(request, connect): - ori_collection_name = getattr(request.module, "collection_id", "test") - collection_name = gen_unique_str(ori_collection_name) - try: - default_fields = gen_default_fields() - connect.create_collection(collection_name, default_fields) - connect.load_collection(collection_name) - except Exception as e: - pytest.exit(str(e)) - def teardown(): - if connect.has_collection(collection_name): - connect.drop_collection(collection_name, timeout=delete_timeout) - request.addfinalizer(teardown) - assert connect.has_collection(collection_name) - return collection_name - -@pytest.fixture(scope="function") -def collection_without_loading(request, connect): - ori_collection_name = getattr(request.module, "collection_id", "test") - collection_name = gen_unique_str(ori_collection_name) - try: - default_fields = gen_default_fields() - connect.create_collection(collection_name, default_fields) - except Exception as e: - pytest.exit(str(e)) - def teardown(): - if connect.has_collection(collection_name): - connect.drop_collection(collection_name, timeout=delete_timeout) - request.addfinalizer(teardown) - assert connect.has_collection(collection_name) - return collection_name - - -# customised id -@pytest.fixture(scope="function") -def id_collection(request, connect): - ori_collection_name = 
getattr(request.module, "collection_id", "test") - collection_name = gen_unique_str(ori_collection_name) - try: - fields = gen_default_fields(auto_id=False) - connect.create_collection(collection_name, fields) - connect.load_collection(collection_name) - except Exception as e: - pytest.exit(str(e)) - def teardown(): - if connect.has_collection(collection_name): - connect.drop_collection(collection_name, timeout=delete_timeout) - request.addfinalizer(teardown) - assert connect.has_collection(collection_name) - return collection_name - - -@pytest.fixture(scope="function") -def binary_collection(request, connect): - ori_collection_name = getattr(request.module, "collection_id", "test") - collection_name = gen_unique_str(ori_collection_name) - try: - fields = gen_binary_default_fields() - connect.create_collection(collection_name, fields) - connect.load_collection(collection_name) - except Exception as e: - pytest.exit(str(e)) - def teardown(): - collection_names = connect.list_collections() - if connect.has_collection(collection_name): - connect.drop_collection(collection_name, timeout=delete_timeout) - request.addfinalizer(teardown) - assert connect.has_collection(collection_name) - return collection_name - - -# customised id -@pytest.fixture(scope="function") -def binary_id_collection(request, connect): - ori_collection_name = getattr(request.module, "collection_id", "test") - collection_name = gen_unique_str(ori_collection_name) - try: - fields = gen_binary_default_fields(auto_id=False) - connect.create_collection(collection_name, fields) - connect.load_collection(collection_name) - except Exception as e: - pytest.exit(str(e)) - def teardown(): - if connect.has_collection(collection_name): - connect.drop_collection(collection_name, timeout=delete_timeout) - request.addfinalizer(teardown) - assert connect.has_collection(collection_name) - return collection_name diff --git a/tests/python/constants.py b/tests/python/constants.py deleted file mode 100644 index 719ab2bc9b..0000000000 --- a/tests/python/constants.py +++ /dev/null @@ -1,22 +0,0 @@ -from . 
import utils - -default_fields = utils.gen_default_fields() -default_binary_fields = utils.gen_binary_default_fields() - -default_entity = utils.gen_entities(1) -default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1) - -default_entity_row = utils.gen_entities_rows(1) -default_raw_binary_vector_row, default_binary_entity_row = utils.gen_binary_entities_rows(1) - - -default_entities = utils.gen_entities(utils.default_nb) -default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb) - - -default_entities_new = utils.gen_entities_new(utils.default_nb) -default_raw_binary_vectors_new, default_binary_entities_new = utils.gen_binary_entities_new(utils.default_nb) - - -default_entities_rows = utils.gen_entities_rows(utils.default_nb) -default_raw_binary_vectors_rows, default_binary_entities_rows = utils.gen_binary_entities_rows(utils.default_nb) \ No newline at end of file diff --git a/tests/python/factorys.py b/tests/python/factorys.py deleted file mode 100644 index e0568a77c1..0000000000 --- a/tests/python/factorys.py +++ /dev/null @@ -1,127 +0,0 @@ -# STL imports -import random -import string -import time -import datetime -import struct -import sys -import uuid -from functools import wraps - -sys.path.append('..') -# Third party imports -import numpy as np -import faker -from faker.providers import BaseProvider - -# local application imports -from milvus.client.types import IndexType, MetricType, DataType - -# grpc -from milvus.client.grpc_handler import Prepare as gPrepare -from milvus.grpc_gen import milvus_pb2 - - -def gen_vectors(num, dim): - return [[random.random() for _ in range(dim)] for _ in range(num)] - - -def gen_single_vector(dim): - return [[random.random() for _ in range(dim)]] - - -def gen_vector(nb, d, seed=np.random.RandomState(1234)): - xb = seed.rand(nb, d).astype("float32") - return xb.tolist() - - -def gen_unique_str(str=None): - prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) - return prefix if str is None else str + "_" + prefix - - -def get_current_day(): - return time.strftime('%Y-%m-%d', time.localtime()) - - -def get_last_day(day): - tmp = datetime.datetime.now() - datetime.timedelta(days=day) - return tmp.strftime('%Y-%m-%d') - - -def get_next_day(day): - tmp = datetime.datetime.now() + datetime.timedelta(days=day) - return tmp.strftime('%Y-%m-%d') - - -def gen_long_str(num): - string = '' - for _ in range(num): - char = random.choice('tomorrow') - string += char - return string - - -def gen_one_binary(topk): - ids = [random.randrange(10000000, 99999999) for _ in range(topk)] - distances = [random.random() for _ in range(topk)] - return milvus_pb2.TopKQueryResult(struct.pack(str(topk) + 'l', *ids), struct.pack(str(topk) + 'd', *distances)) - - -def gen_nq_binaries(nq, topk): - return [gen_one_binary(topk) for _ in range(nq)] - - -def fake_query_bin_result(nq, topk): - return gen_nq_binaries(nq, topk) - - -class FakerProvider(BaseProvider): - - def collection_name(self): - return 'collection_names' + str(uuid.uuid4()).replace('-', '_') - - def normal_field_name(self): - return 'normal_field_names' + str(uuid.uuid4()).replace('-', '_') - - def vector_field_name(self): - return 'vector_field_names' + str(uuid.uuid4()).replace('-', '_') - - def name(self): - return 'name' + str(random.randint(1000, 9999)) - - def dim(self): - return random.randint(0, 999) - - -fake = faker.Faker() -fake.add_provider(FakerProvider) - -def collection_name_factory(): - return 
fake.collection_name() - -def collection_schema_factory(): - param = { - "fields": [ - {"name": fake.normal_field_name(),"type": DataType.INT32}, - {"name": fake.vector_field_name(),"type": DataType.FLOAT_VECTOR, "params": {"dim": random.randint(1, 999)}}, - ], - "auto_id": True, - } - return param - - -def records_factory(dimension, nq): - return [[random.random() for _ in range(dimension)] for _ in range(nq)] - - -def time_it(func): - @wraps(func) - def inner(*args, **kwrgs): - pref = time.perf_counter() - result = func(*args, **kwrgs) - delt = time.perf_counter() - pref - print(f"[{func.__name__}][{delt:.4}s]") - return result - - return inner diff --git a/tests/python/pytest.ini b/tests/python/pytest.ini deleted file mode 100644 index 2b18e97f6c..0000000000 --- a/tests/python/pytest.ini +++ /dev/null @@ -1,19 +0,0 @@ -[pytest] -log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s) -log_date_format = %Y-%m-%d %H:%M:%S - -# cli arguments. `-x`-stop test when error occurred; -addopts = -x - -testpaths = . - -log_cli = true -log_level = 10 - -timeout = 360 - -markers = - level: test level - serial - -; level = 1 diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt deleted file mode 100644 index 437bf67ede..0000000000 --- a/tests/python/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -grpcio==1.26.0 -grpcio-tools==1.26.0 -numpy==1.18.1 -pytest-cov==2.8.1 -pymilvus-distributed==0.0.35 -sklearn==0.0 -pytest==4.5.0 -pytest-timeout==1.3.3 -pytest-repeat==0.8.0 -allure-pytest==2.7.0 -pytest-print==0.1.2 -pytest-level==0.1.1 -pytest-xdist==1.23.2 -git+https://gitee.com/quicksilver/pytest-tags.git diff --git a/tests/python/test_create_collection.py b/tests/python/test_create_collection.py deleted file mode 100644 index 68e00e8cf2..0000000000 --- a/tests/python/test_create_collection.py +++ /dev/null @@ -1,314 +0,0 @@ -import pytest -from .utils import * -from .constants import * - -uid = "create_collection" - -class TestCreateCollection: - """ - ****************************************************************** - The following cases are used to test `create_collection` function - ****************************************************************** - """ - @pytest.fixture( - scope="function", - params=gen_single_filter_fields() - ) - def get_filter_field(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_single_vector_fields() - ) - def get_vector_field(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_segment_row_limits() - ) - def get_segment_row_limit(self, request): - yield request.param - - def test_create_collection_fields(self, connect, get_filter_field, get_vector_field): - ''' - target: test create normal collection with different fields - method: create collection with diff fields: metric/field_type/... - expected: no exception raised - ''' - filter_field = get_filter_field - # logging.getLogger().info(filter_field) - vector_field = get_vector_field - collection_name = gen_unique_str(uid) - fields = { - "fields": [filter_field, vector_field], - } - # logging.getLogger().info(fields) - connect.create_collection(collection_name, fields) - assert connect.has_collection(collection_name) - - def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field): - ''' - target: test create normal collection with different fields - method: create collection with diff fields: metric/field_type/... 
- expected: no exception raised - ''' - filter_field = get_filter_field - vector_field = get_vector_field - collection_name = gen_unique_str(uid) - fields = { - "fields": [filter_field, vector_field], - } - connect.create_collection(collection_name, fields) - assert connect.has_collection(collection_name) - - @pytest.mark.skip("no segment_row_limit") - def test_create_collection_segment_row_limit(self, connect): - ''' - target: test create normal collection with different fields - method: create collection with diff segment_row_limit - expected: no exception raised - ''' - collection_name = gen_unique_str(uid) - fields = copy.deepcopy(default_fields) - # fields["segment_row_limit"] = get_segment_row_limit - connect.create_collection(collection_name, fields) - assert connect.has_collection(collection_name) - - @pytest.mark.skip("no flush") - def _test_create_collection_auto_flush_disabled(self, connect): - ''' - target: test create normal collection, with large auto_flush_interval - method: create collection with corrent params - expected: create status return ok - ''' - disable_flush(connect) - collection_name = gen_unique_str(uid) - try: - connect.create_collection(collection_name, default_fields) - finally: - enable_flush(connect) - - def test_create_collection_after_insert(self, connect, collection): - ''' - target: test insert vector, then create collection again - method: insert vector and create collection - expected: error raised - ''' - # pdb.set_trace() - connect.insert(collection, default_entity) - - with pytest.raises(Exception) as e: - connect.create_collection(collection, default_fields) - - def test_create_collection_after_insert_flush(self, connect, collection): - ''' - target: test insert vector, then create collection again - method: insert vector and create collection - expected: error raised - ''' - connect.insert(collection, default_entity) - connect.flush([collection]) - with pytest.raises(Exception) as e: - connect.create_collection(collection, default_fields) - - def test_create_collection_without_connection(self, dis_connect): - ''' - target: test create collection, without connection - method: create collection with correct params, with a disconnected instance - expected: error raised - ''' - collection_name = gen_unique_str(uid) - with pytest.raises(Exception) as e: - dis_connect.create_collection(collection_name, default_fields) - - def test_create_collection_existed(self, connect): - ''' - target: test create collection but the collection name have already existed - method: create collection with the same collection_name - expected: error raised - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, default_fields) - - def test_create_after_drop_collection(self, connect, collection): - ''' - target: create with the same collection name after collection dropped - method: delete, then create - expected: create success - ''' - connect.drop_collection(collection) - time.sleep(2) - connect.create_collection(collection, default_fields) - - @pytest.mark.level(2) - def test_create_collection_multithread(self, connect): - ''' - target: test create collection with multithread - method: create collection using multithread, - expected: collections are created - ''' - threads_num = 8 - threads = [] - collection_names = [] - - def create(): - collection_name = gen_unique_str(uid) - collection_names.append(collection_name) - 
connect.create_collection(collection_name, default_fields) - for i in range(threads_num): - t = threading.Thread(target=create, args=()) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - for item in collection_names: - assert item in connect.list_collections() - connect.drop_collection(item) - - -class TestCreateCollectionInvalid(object): - """ - Test creating collections with invalid params - """ - @pytest.fixture( - scope="function", - params=gen_invalid_metric_types() - ) - def get_metric_type(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_ints() - ) - def get_segment_row_limit(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_ints() - ) - def get_dim(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_invalid_string(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_field_types() - ) - def get_field_type(self, request): - yield request.param - - @pytest.mark.level(2) - @pytest.mark.skip("no segment row limit") - def test_create_collection_with_invalid_segment_row_limit(self, connect, get_segment_row_limit): - collection_name = gen_unique_str() - fields = copy.deepcopy(default_fields) - fields["segment_row_limit"] = get_segment_row_limit - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, fields) - - @pytest.mark.level(2) - def test_create_collection_with_invalid_dimension(self, connect, get_dim): - dimension = get_dim - collection_name = gen_unique_str() - fields = copy.deepcopy(default_fields) - fields["fields"][-1]["params"]["dim"] = dimension - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, fields) - - @pytest.mark.level(2) - def test_create_collection_with_invalid_collectionname(self, connect, get_invalid_string): - collection_name = get_invalid_string - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, default_fields) - - @pytest.mark.level(2) - def test_create_collection_with_empty_collectionname(self, connect): - collection_name = '' - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, default_fields) - - @pytest.mark.level(2) - def test_create_collection_with_none_collectionname(self, connect): - collection_name = None - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, default_fields) - - def test_create_collection_None(self, connect): - ''' - target: test create collection but the collection name is None - method: create collection, param collection_name is None - expected: create raise error - ''' - with pytest.raises(Exception) as e: - connect.create_collection(None, default_fields) - - def test_create_collection_no_dimension(self, connect): - ''' - target: test create collection with no dimension params - method: create collection with corrent params - expected: create status return ok - ''' - collection_name = gen_unique_str(uid) - fields = copy.deepcopy(default_fields) - fields["fields"][-1]["params"].pop("dim") - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, fields) - - @pytest.mark.skip("no segment row limit") - def test_create_collection_no_segment_row_limit(self, connect): - ''' - target: test create collection with no segment_row_limit params - method: create collection with correct params - expected: use default 
default_segment_row_limit - ''' - collection_name = gen_unique_str(uid) - fields = copy.deepcopy(default_fields) - fields.pop("segment_row_limit") - connect.create_collection(collection_name, fields) - res = connect.get_collection_info(collection_name) - # logging.getLogger().info(res) - assert res["segment_row_limit"] == default_server_segment_row_limit - - def test_create_collection_limit_fields(self, connect): - collection_name = gen_unique_str(uid) - limit_num = 64 - fields = copy.deepcopy(default_fields) - for i in range(limit_num): - field_name = gen_unique_str("field_name") - field = {"name": field_name, "type": DataType.INT64} - fields["fields"].append(field) - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, fields) - - @pytest.mark.level(2) - def test_create_collection_invalid_field_name(self, connect, get_invalid_string): - collection_name = gen_unique_str(uid) - fields = copy.deepcopy(default_fields) - field_name = get_invalid_string - field = {"name": field_name, "type": DataType.INT64} - fields["fields"].append(field) - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, fields) - - def test_create_collection_invalid_field_type(self, connect, get_field_type): - collection_name = gen_unique_str(uid) - fields = copy.deepcopy(default_fields) - field_type = get_field_type - field = {"name": "test_field", "type": field_type} - fields["fields"].append(field) - with pytest.raises(Exception) as e: - connect.create_collection(collection_name, fields) diff --git a/tests/python/test_describe_collection.py b/tests/python/test_describe_collection.py deleted file mode 100644 index 13d50b7a66..0000000000 --- a/tests/python/test_describe_collection.py +++ /dev/null @@ -1,32 +0,0 @@ -import copy -from .utils import * -from .constants import * - -uid = "describe_collection" - - -class TestDescribeCollection: - """ - ****************************************************************** - The following cases are used to test `describe_collection` function - ****************************************************************** - """ - def test_describe_collection(self, connect): - ''' - target: test describe collection - method: create collection then describe the same collection - expected: returned value is the same - ''' - collection_name = gen_unique_str(uid) - df = copy.deepcopy(default_fields) - df["fields"].append({"name": "int32", "type": DataType.INT32}) - - connect.create_collection(collection_name, df) - info = connect.describe_collection(collection_name) - assert info.get("collection_name") == collection_name - assert len(info.get("fields")) == 4 - - for field in info.get("fields"): - assert field.get("name") in ["int32", "int64", "float", "float_vector"] - if field.get("name") == "float_vector": - assert field.get("params").get("dim") == str(default_dim) diff --git a/tests/python/test_drop_collection.py b/tests/python/test_drop_collection.py deleted file mode 100644 index efe376e339..0000000000 --- a/tests/python/test_drop_collection.py +++ /dev/null @@ -1,98 +0,0 @@ -import pytest -from .utils import * -from .constants import * - -uniq_id = "drop_collection" - -class TestDropCollection: - """ - ****************************************************************** - The following cases are used to test `drop_collection` function - ****************************************************************** - """ - def test_drop_collection(self, connect, collection): - ''' - target: test delete collection created with correct params - method: 
create collection and then delete, - assert the value returned by delete method - expected: status ok, and no collection in collections - ''' - connect.drop_collection(collection) - time.sleep(2) - assert not connect.has_collection(collection) - - def test_drop_collection_without_connection(self, collection, dis_connect): - ''' - target: test drop collection, without connection - method: drop collection with correct params, with a disconnected instance - expected: drop raise exception - ''' - with pytest.raises(Exception) as e: - dis_connect.drop_collection(collection) - - def test_drop_collection_not_existed(self, connect): - ''' - target: test if collection not created - method: use a random collection name, which does not exist in db, - assert the exception raised by drop_collection method - expected: exception raised - ''' - collection_name = gen_unique_str(uniq_id) - with pytest.raises(Exception) as e: - connect.drop_collection(collection_name) - - @pytest.mark.level(2) - def test_create_drop_collection_multithread(self, connect): - ''' - target: test create and drop collection with multithread - method: create and drop collection using multithread, - expected: collections are created, and dropped - ''' - threads_num = 8 - threads = [] - collection_names = [] - - def create(): - collection_name = gen_unique_str(uniq_id) - collection_names.append(collection_name) - connect.create_collection(collection_name, default_fields) - connect.drop_collection(collection_name) - - for i in range(threads_num): - t = threading.Thread(target=create, args=()) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - for item in collection_names: - assert not connect.has_collection(item) - - -class TestDropCollectionInvalid(object): - """ - Test drop collection with invalid params - """ - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.mark.level(2) - def test_drop_collection_with_invalid_collectionname(self, connect, get_collection_name): - collection_name = get_collection_name - with pytest.raises(Exception) as e: - connect.drop_collection(collection_name) - - def test_drop_collection_with_empty_collectionname(self, connect): - collection_name = '' - with pytest.raises(Exception) as e: - connect.drop_collection(collection_name) - - def test_drop_collection_with_none_collectionname(self, connect): - collection_name = None - with pytest.raises(Exception) as e: - connect.drop_collection(collection_name) diff --git a/tests/python/test_get_collection_info.py b/tests/python/test_get_collection_info.py deleted file mode 100644 index a36e4dceed..0000000000 --- a/tests/python/test_get_collection_info.py +++ /dev/null @@ -1,233 +0,0 @@ -import pytest -from .utils import * -from .constants import * - -uid = "collection_info" - -class TestInfoBase: - - @pytest.fixture( - scope="function", - params=gen_single_filter_fields() - ) - def get_filter_field(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_single_vector_fields() - ) - def get_vector_field(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_segment_row_limits() - ) - def get_segment_row_limit(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - logging.getLogger().info(request.param) - # if str(connect._cmd("mode")) == "CPU": - if 
request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return request.param - - """ - ****************************************************************** - The following cases are used to test `get_collection_info` function, no data in collection - ****************************************************************** - """ - - @pytest.mark.skip("no segment row limit and type") - def test_info_collection_fields(self, connect, get_filter_field, get_vector_field): - ''' - target: test create normal collection with different fields, check info returned - method: create collection with diff fields: metric/field_type/..., calling `get_collection_info` - expected: no exception raised, and value returned correct - ''' - filter_field = get_filter_field - vector_field = get_vector_field - collection_name = gen_unique_str(uid) - fields = { - "fields": [filter_field, vector_field], - "segment_row_limit": default_segment_row_limit - } - connect.create_collection(collection_name, fields) - res = connect.get_collection_info(collection_name) - assert res['auto_id'] == True - assert res['segment_row_limit'] == default_segment_row_limit - assert len(res["fields"]) == 2 - for field in res["fields"]: - if field["type"] == filter_field: - assert field["name"] == filter_field["name"] - elif field["type"] == vector_field: - assert field["name"] == vector_field["name"] - assert field["params"] == vector_field["params"] - - @pytest.mark.skip("no segment row limit and type") - def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit): - ''' - target: test create normal collection with different fields - method: create collection with diff segment_row_limit - expected: no exception raised - ''' - collection_name = gen_unique_str(uid) - fields = copy.deepcopy(default_fields) - fields["segment_row_limit"] = get_segment_row_limit - connect.create_collection(collection_name, fields) - # assert segment row count - res = connect.get_collection_info(collection_name) - assert res['segment_row_limit'] == get_segment_row_limit - - @pytest.mark.skip("no create Index") - def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index): - connect.create_index(collection, default_float_vec_field_name, get_simple_index) - info = connect.describe_index(collection, field_name) - assert info == get_simple_index - res = connect.get_collection_info(collection, default_float_vec_field_name) - assert index["index_type"] == get_simple_index["index_type"] - assert index["metric_type"] == get_simple_index["metric_type"] - - @pytest.mark.level(2) - def test_get_collection_info_without_connection(self, connect, collection, dis_connect): - ''' - target: test get collection info, without connection - method: calling get collection info with correct params, with a disconnected instance - expected: get collection info raise exception - ''' - with pytest.raises(Exception) as e: - assert connect.get_collection_info(dis_connect, collection) - - def test_get_collection_info_not_existed(self, connect): - ''' - target: test if collection not created - method: random a collection name, which not existed in db, - assert the value returned by get_collection_info method - expected: False - ''' - collection_name = gen_unique_str(uid) - with pytest.raises(Exception) as e: - res = connect.get_collection_info(connect, collection_name) - - @pytest.mark.level(2) - def test_get_collection_info_multithread(self, connect): - ''' - target: test create collection with 
multithread - method: create collection using multithread, - expected: collections are created - ''' - threads_num = 4 - threads = [] - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - - def get_info(): - res = connect.get_collection_info(connect, collection_name) - # assert - - for i in range(threads_num): - t = threading.Thread(target=get_info, args=()) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - """ - ****************************************************************** - The following cases are used to test `get_collection_info` function, and insert data in collection - ****************************************************************** - """ - - @pytest.mark.skip("no segment row limit and type") - def test_info_collection_fields_after_insert(self, connect, get_filter_field, get_vector_field): - ''' - target: test create normal collection with different fields, check info returned - method: create collection with diff fields: metric/field_type/..., calling `get_collection_info` - expected: no exception raised, and value returned correct - ''' - filter_field = get_filter_field - vector_field = get_vector_field - collection_name = gen_unique_str(uid) - fields = { - "fields": [filter_field, vector_field], - "segment_row_limit": default_segment_row_limit - } - connect.create_collection(collection_name, fields) - entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"]) - res_ids = connect.insert(collection_name, entities) - connect.flush([collection_name]) - res = connect.get_collection_info(collection_name) - assert res['auto_id'] == True - assert res['segment_row_limit'] == default_segment_row_limit - assert len(res["fields"]) == 2 - for field in res["fields"]: - if field["type"] == filter_field: - assert field["name"] == filter_field["name"] - elif field["type"] == vector_field: - assert field["name"] == vector_field["name"] - assert field["params"] == vector_field["params"] - - @pytest.mark.skip("not segment row limit") - def test_create_collection_segment_row_limit_after_insert(self, connect, get_segment_row_limit): - ''' - target: test create normal collection with different fields - method: create collection with diff segment_row_limit - expected: no exception raised - ''' - collection_name = gen_unique_str(uid) - fields = copy.deepcopy(default_fields) - fields["segment_row_limit"] = get_segment_row_limit - connect.create_collection(collection_name, fields) - entities = gen_entities_by_fields(fields["fields"], default_nb, fields["fields"][-1]["params"]["dim"]) - res_ids = connect.insert(collection_name, entities) - connect.flush([collection_name]) - res = connect.get_collection_info(collection_name) - assert res['auto_id'] == True - assert res['segment_row_limit'] == get_segment_row_limit - - -class TestInfoInvalid(object): - """ - Test get collection info with invalid params - """ - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - - @pytest.mark.level(2) - def test_get_collection_info_with_invalid_collectionname(self, connect, get_collection_name): - collection_name = get_collection_name - with pytest.raises(Exception) as e: - connect.get_collection_info(collection_name) - - @pytest.mark.level(2) - def test_get_collection_info_with_empty_collectionname(self, connect): - collection_name = '' - with pytest.raises(Exception) as e: - 
connect.get_collection_info(collection_name) - - @pytest.mark.level(2) - def test_get_collection_info_with_none_collectionname(self, connect): - collection_name = None - with pytest.raises(Exception) as e: - connect.get_collection_info(collection_name) - - def test_get_collection_info_None(self, connect): - ''' - target: test create collection but the collection name is None - method: create collection, param collection_name is None - expected: create raise error - ''' - with pytest.raises(Exception) as e: - connect.get_collection_info(None) diff --git a/tests/python/test_has_collection.py b/tests/python/test_has_collection.py deleted file mode 100644 index 0b5c740637..0000000000 --- a/tests/python/test_has_collection.py +++ /dev/null @@ -1,93 +0,0 @@ -import pytest -from .utils import * -from .constants import * - -uid = "has_collection" - -class TestHasCollection: - """ - ****************************************************************** - The following cases are used to test `has_collection` function - ****************************************************************** - """ - def test_has_collection(self, connect, collection): - ''' - target: test if the created collection existed - method: create collection, assert the value returned by has_collection method - expected: True - ''' - assert connect.has_collection(collection) - - @pytest.mark.level(2) - def test_has_collection_without_connection(self, collection, dis_connect): - ''' - target: test has collection, without connection - method: calling has collection with correct params, with a disconnected instance - expected: has collection raise exception - ''' - with pytest.raises(Exception) as e: - assert dis_connect.has_collection(collection) - - def test_has_collection_not_existed(self, connect): - ''' - target: test if collection not created - method: random a collection name, which not existed in db, - assert the value returned by has_collection method - expected: False - ''' - collection_name = gen_unique_str("test_collection") - assert not connect.has_collection(collection_name) - - - @pytest.mark.level(2) - def test_has_collection_multithread(self, connect): - ''' - target: test create collection with multithread - method: create collection using multithread, - expected: collections are created - ''' - threads_num = 4 - threads = [] - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - - def has(): - assert connect.has_collection(collection_name) - # assert not assert_collection(connect, collection_name) - for i in range(threads_num): - t = MilvusTestThread(target=has, args=()) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - -class TestHasCollectionInvalid(object): - """ - Test has collection with invalid params - """ - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.mark.level(2) - def test_has_collection_with_invalid_collectionname(self, connect, get_collection_name): - collection_name = get_collection_name - with pytest.raises(Exception) as e: - connect.has_collection(collection_name) - - @pytest.mark.level(2) - def test_has_collection_with_empty_collectionname(self, connect): - collection_name = '' - with pytest.raises(Exception) as e: - connect.has_collection(collection_name) - - @pytest.mark.level(2) - def test_has_collection_with_none_collectionname(self, connect): - collection_name = None - with pytest.raises(Exception) as e: - 
connect.has_collection(collection_name) diff --git a/tests/python/test_index.py b/tests/python/test_index.py deleted file mode 100644 index e24344bece..0000000000 --- a/tests/python/test_index.py +++ /dev/null @@ -1,841 +0,0 @@ -import logging -import time -import pdb -import threading -from multiprocessing import Pool, Process -import numpy -import pytest -import sklearn.preprocessing -from .utils import * -from .constants import * - -uid = "test_index" -BUILD_TIMEOUT = 300 -field_name = default_float_vec_field_name -binary_field_name = default_binary_vec_field_name -query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_k, 1) -default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} - - -# @pytest.mark.skip("wait for debugging...") -class TestIndexBase: - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - import copy - logging.getLogger().info(request.param) - #if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return copy.deepcopy(request.param) - - @pytest.fixture( - scope="function", - params=[ - 1, - 10, - 1111 - ], - ) - def get_nq(self, request): - yield request.param - - """ - ****************************************************************** - The following cases are used to test `create_index` function - ****************************************************************** - """ - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index - expected: return search success - ''' - ids = connect.insert(collection, default_entities) - connect.create_index(collection, field_name, get_simple_index) - - def test_create_index_on_field_not_existed(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index on field not existed - expected: error raised - ''' - tmp_field_name = gen_unique_str() - ids = connect.insert(collection, default_entities) - with pytest.raises(Exception) as e: - connect.create_index(collection, tmp_field_name, get_simple_index) - - @pytest.mark.level(2) - def test_create_index_on_field(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index on other field - expected: error raised - ''' - tmp_field_name = "int64" - ids = connect.insert(collection, default_entities) - with pytest.raises(Exception) as e: - connect.create_index(collection, tmp_field_name, get_simple_index) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_no_vectors(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index - expected: return search success - ''' - connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_partition(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection, create partition, and add entities in it, create index - expected: return search success - ''' - connect.create_partition(collection, default_tag) - ids = connect.insert(collection, default_entities, 
partition_tag=default_tag) - connect.flush([collection]) - connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_partition_flush(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection, create partition, and add entities in it, create index - expected: return search success - ''' - connect.create_partition(collection, default_tag) - ids = connect.insert(collection, default_entities, partition_tag=default_tag) - connect.flush() - connect.create_index(collection, field_name, get_simple_index) - - def test_create_index_without_connect(self, dis_connect, collection): - ''' - target: test create index without connection - method: create collection and add entities in it, check if added successfully - expected: raise exception - ''' - with pytest.raises(Exception) as e: - dis_connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_search_with_query_vectors(self, connect, collection, get_simple_index, get_nq): - ''' - target: test create index interface, search with more query vectors - method: create collection and add entities in it, create index - expected: return search success - ''' - ids = connect.insert(collection, default_entities) - connect.create_index(collection, field_name, get_simple_index) - # logging.getLogger().info(connect.get_collection_stats(collection)) - nq = get_nq - index_type = get_simple_index["index_type"] - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, search_params=search_param) - res = connect.search(collection, query) - assert len(res) == nq - - @pytest.mark.skip("can't_pass_ci") - @pytest.mark.timeout(BUILD_TIMEOUT) - @pytest.mark.level(2) - def test_create_index_multithread(self, connect, collection, args): - ''' - target: test create index interface with multiprocess - method: create collection and add entities in it, create index - expected: return search success - ''' - connect.insert(collection, default_entities) - - def build(connect): - connect.create_index(collection, field_name, default_index) - - threads_num = 8 - threads = [] - for i in range(threads_num): - m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"]) - t = MilvusTestThread(target=build, args=(m,)) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - def test_create_index_collection_not_existed(self, connect): - ''' - target: test create index interface when collection name not existed - method: create collection and add entities in it, create index - , make sure the collection name not in index - expected: create index failed - ''' - collection_name = gen_unique_str(uid) - with pytest.raises(Exception) as e: - connect.create_index(collection_name, field_name, default_index) - - @pytest.mark.skip("count_entries") - @pytest.mark.level(2) - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_insert_flush(self, connect, collection, get_simple_index): - ''' - target: test create index - method: create collection and create index, add entities in it - expected: create index ok, and count correct - ''' - connect.create_index(collection, field_name, get_simple_index) - ids = connect.insert(collection, default_entities) - connect.flush([collection]) - count = connect.count_entities(collection) - assert count == default_nb - - 
@pytest.mark.level(2) - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_same_index_repeatedly(self, connect, collection, get_simple_index): - ''' - target: check if index can be created repeatedly, with the same create_index params - method: create index after index have been built - expected: return code success, and search ok - ''' - connect.create_index(collection, field_name, get_simple_index) - connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.level(2) - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_different_index_repeatedly(self, connect, collection): - ''' - target: check if index can be created repeatedly, with the different create_index params - method: create another index with different index_params after index have been built - expected: return code 0, and describe index result equals with the second index params - ''' - ids = connect.insert(collection, default_entities) - indexs = [default_index, {"metric_type":"L2", "index_type": "FLAT", "params":{"nlist": 1024}}] - for index in indexs: - connect.create_index(collection, field_name, index) - stats = connect.get_collection_stats(collection) - # assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"] - assert stats["row_count"] == str(default_nb) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_ip(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index - expected: return search success - ''' - ids = connect.insert(collection, default_entities) - get_simple_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_no_vectors_ip(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index - expected: return search success - ''' - get_simple_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_partition_ip(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection, create partition, and add entities in it, create index - expected: return search success - ''' - connect.create_partition(collection, default_tag) - ids = connect.insert(collection, default_entities, partition_tag=default_tag) - connect.flush([collection]) - get_simple_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_partition_flush_ip(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection, create partition, and add entities in it, create index - expected: return search success - ''' - connect.create_partition(collection, default_tag) - ids = connect.insert(collection, default_entities, partition_tag=default_tag) - connect.flush() - get_simple_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_search_with_query_vectors_ip(self, connect, collection, get_simple_index, get_nq): - ''' - target: test create index interface, search with more query vectors - method: create collection and add entities in it, create index - expected: return 
search success - ''' - metric_type = "IP" - ids = connect.insert(collection, default_entities) - get_simple_index["metric_type"] = metric_type - connect.create_index(collection, field_name, get_simple_index) - # logging.getLogger().info(connect.get_collection_stats(collection)) - nq = get_nq - index_type = get_simple_index["index_type"] - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, metric_type=metric_type, search_params=search_param) - res = connect.search(collection, query) - assert len(res) == nq - - @pytest.mark.skip("test_create_index_multithread_ip") - @pytest.mark.timeout(BUILD_TIMEOUT) - @pytest.mark.level(2) - def test_create_index_multithread_ip(self, connect, collection, args): - ''' - target: test create index interface with multiprocess - method: create collection and add entities in it, create index - expected: return search success - ''' - connect.insert(collection, default_entities) - - def build(connect): - default_index["metric_type"] = "IP" - connect.create_index(collection, field_name, default_index) - - threads_num = 8 - threads = [] - for i in range(threads_num): - m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"]) - t = MilvusTestThread(target=build, args=(m,)) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - def test_create_index_collection_not_existed_ip(self, connect, collection): - ''' - target: test create index interface when collection name not existed - method: create collection and add entities in it, create index - , make sure the collection name not in index - expected: return code not equals to 0, create index failed - ''' - collection_name = gen_unique_str(uid) - default_index["metric_type"] = "IP" - with pytest.raises(Exception) as e: - connect.create_index(collection_name, field_name, default_index) - - @pytest.mark.skip("count_entries") - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_no_vectors_insert_ip(self, connect, collection, get_simple_index): - ''' - target: test create index interface when there is no vectors in collection, and does not affect the subsequent process - method: create collection and add no vectors in it, and then create index, add entities in it - expected: return code equals to 0 - ''' - default_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - ids = connect.insert(collection, default_entities) - connect.flush([collection]) - count = connect.count_entities(collection) - assert count == default_nb - - @pytest.mark.level(2) - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_same_index_repeatedly_ip(self, connect, collection, get_simple_index): - ''' - target: check if index can be created repeatedly, with the same create_index params - method: create index after index have been built - expected: return code success, and search ok - ''' - default_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - connect.create_index(collection, field_name, get_simple_index) - - # TODO: - - @pytest.mark.level(2) - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_different_index_repeatedly_ip(self, connect, collection): - ''' - target: check if index can be created repeatedly, with the different create_index params - method: create another index with different index_params after index have been built - expected: return code 0, and describe index result equals with the second index params - ''' - 
ids = connect.insert(collection, default_entities) - indexs = [default_index, {"index_type": "FLAT", "params": {"nlist": 1024}, "metric_type": "IP"}] - for index in indexs: - connect.create_index(collection, field_name, index) - stats = connect.get_collection_stats(collection) - # assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"] - assert stats["row_count"] == str(default_nb) - - """ - ****************************************************************** - The following cases are used to test `drop_index` function - ****************************************************************** - """ - - @pytest.mark.skip("get_collection_stats") - def test_drop_index(self, connect, collection, get_simple_index): - ''' - target: test drop index interface - method: create collection and add entities in it, create index, call drop index - expected: return code 0, and default index param - ''' - # ids = connect.insert(collection, entities) - connect.create_index(collection, field_name, get_simple_index) - connect.drop_index(collection, field_name) - stats = connect.get_collection_stats(collection) - # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type - assert not stats["partitions"][0]["segments"] - - @pytest.mark.skip("get_collection_stats") - @pytest.mark.skip("drop_index raise exception") - @pytest.mark.level(2) - def test_drop_index_repeatly(self, connect, collection, get_simple_index): - ''' - target: test drop index repeatly - method: create index, call drop index, and drop again - expected: return code 0 - ''' - connect.create_index(collection, field_name, get_simple_index) - stats = connect.get_collection_stats(collection) - connect.drop_index(collection, field_name) - connect.drop_index(collection, field_name) - stats = connect.get_collection_stats(collection) - logging.getLogger().info(stats) - # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type - assert not stats["partitions"][0]["segments"] - - @pytest.mark.level(2) - def test_drop_index_without_connect(self, dis_connect, collection): - ''' - target: test drop index without connection - method: drop index, and check if drop successfully - expected: raise exception - ''' - with pytest.raises(Exception) as e: - dis_connect.drop_index(collection, field_name) - - def test_drop_index_collection_not_existed(self, connect): - ''' - target: test drop index interface when collection name not existed - method: create collection and add entities in it, create index - , make sure the collection name not in index, and then drop it - expected: return code not equals to 0, drop index failed - ''' - collection_name = gen_unique_str(uid) - with pytest.raises(Exception) as e: - connect.drop_index(collection_name, field_name) - - def test_drop_index_collection_not_create(self, connect, collection): - ''' - target: test drop index interface when index not created - method: create collection and add entities in it, create index - expected: return code not equals to 0, drop index failed - ''' - # ids = connect.insert(collection, entities) - # no create index - connect.drop_index(collection, field_name) - - @pytest.mark.skip("drop_index") - @pytest.mark.level(2) - def test_create_drop_index_repeatly(self, connect, collection, get_simple_index): - ''' - target: test create / drop index repeatly, use the same index params - method: create index, drop index, four times - expected: return code 0 - ''' - for i in range(4): - connect.create_index(collection, field_name, 
@pytest.mark.skip("get_collection_stats") - def test_drop_index_ip(self, connect, collection, get_simple_index): - ''' - target: test drop index interface - method: create collection and add entities in it, create index, call drop index - expected: return code 0, and default index param - ''' - # ids = connect.insert(collection, entities) - get_simple_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - connect.drop_index(collection, field_name) - stats = connect.get_collection_stats(collection) - # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type - assert not stats["partitions"][0]["segments"] - - @pytest.mark.skip("get_collection_stats") - @pytest.mark.level(2) - def test_drop_index_repeatedly_ip(self, connect, collection, get_simple_index): - ''' - target: test drop index repeatedly - method: create index, call drop index, and drop again - expected: return code 0 - ''' - get_simple_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - stats = connect.get_collection_stats(collection) - connect.drop_index(collection, field_name) - connect.drop_index(collection, field_name) - stats = connect.get_collection_stats(collection) - logging.getLogger().info(stats) - # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type - assert not stats["partitions"][0]["segments"] - - @pytest.mark.level(2) - def test_drop_index_without_connect_ip(self, dis_connect, collection): - ''' - target: test drop index without connection - method: drop index, and check whether the drop succeeds - expected: raise exception - ''' - with pytest.raises(Exception) as e: - dis_connect.drop_index(collection, field_name) - - def test_drop_index_collection_not_create_ip(self, connect, collection): - ''' - target: test drop index interface when no index has been created - method: create collection and add entities in it, then drop index without creating one - expected: return code not equals to 0, drop index failed - ''' - # ids = connect.insert(collection, entities) - # no create index - connect.drop_index(collection, field_name) - - @pytest.mark.skip("drop_index") - @pytest.mark.skip("can't create and drop") - @pytest.mark.level(2) - def test_create_drop_index_repeatedly_ip(self, connect, collection, get_simple_index): - ''' - target: test create / drop index repeatedly, using the same index params - method: create index, drop index, four times - expected: return code 0 - ''' - get_simple_index["metric_type"] = "IP" - for i in range(4): - connect.create_index(collection, field_name, get_simple_index) - connect.drop_index(collection, field_name) - - -class TestIndexBinary: - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - # TODO: Determine the service mode - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not supported in CPU mode") - return request.param - - @pytest.fixture( - scope="function", - params=gen_binary_index() - ) - def get_jaccard_index(self, request, connect): - if request.param["index_type"] in binary_support(): - request.param["metric_type"] = "JACCARD" - return request.param - else: - pytest.skip("Skip index") - - @pytest.fixture( - scope="function", - params=gen_binary_index() - ) - def get_l2_index(self, request, connect): - request.param["metric_type"] = "L2" - return request.param - - @pytest.fixture( 
scope="function", - params=[ - 1, - 10, - 1111 - ], - ) - def get_nq(self, request): - yield request.param - - """ - ****************************************************************** - The following cases are used to test `create_index` function - ****************************************************************** - """ - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index(self, connect, binary_collection, get_jaccard_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index - expected: return search success - ''' - ids = connect.insert(binary_collection, default_binary_entities) - connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_partition(self, connect, binary_collection, get_jaccard_index): - ''' - target: test create index interface - method: create collection, create partition, and add entities in it, create index - expected: return search success - ''' - connect.create_partition(binary_collection, default_tag) - ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag) - connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq): - ''' - target: test create index interface, search with more query vectors - method: create collection and add entities in it, create index - expected: return search success - ''' - nq = get_nq - ids = connect.insert(binary_collection, default_binary_entities) - connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD") - search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD") - logging.getLogger().info(search_param) - res = connect.search(binary_collection, query, search_params=search_param) - assert len(res) == nq - - @pytest.mark.skip("get status for build index failed") - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_invalid_metric_type_binary(self, connect, binary_collection, get_l2_index): - ''' - target: test create index interface with invalid metric type - method: add entitys into binary connection, flash, create index with L2 metric type. 
@pytest.mark.skip("get status for build index failed") - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_invalid_metric_type_binary(self, connect, binary_collection, get_l2_index): - ''' - target: test create index interface with invalid metric type - method: add entities into the binary collection, flush, then create index with the L2 metric type - expected: return create_index failure - ''' - # insert 6000 vectors - ids = connect.insert(binary_collection, default_binary_entities) - connect.flush([binary_collection]) - - if get_l2_index["index_type"] == "BIN_FLAT": - res = connect.create_index(binary_collection, binary_field_name, get_l2_index) - else: - with pytest.raises(Exception) as e: - res = connect.create_index(binary_collection, binary_field_name, get_l2_index) - - """ - ****************************************************************** - The following cases are used to test `get_index_info` function - ****************************************************************** - """ - - @pytest.mark.skip("get_collection_stats does not impl") - def test_get_index_info(self, connect, binary_collection, get_jaccard_index): - ''' - target: test describe index interface - method: create collection and add entities in it, create index, call describe index - expected: return code 0, and index structure - ''' - ids = connect.insert(binary_collection, default_binary_entities) - connect.flush([binary_collection]) - connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - stats = connect.get_collection_stats(binary_collection) - assert stats["row_count"] == default_nb - for partition in stats["partitions"]: - segments = partition["segments"] - if segments: - for segment in segments: - for file in segment["files"]: - if "index_type" in file: - assert file["index_type"] == get_jaccard_index["index_type"] - - @pytest.mark.skip("get_collection_stats does not impl") - def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index): - ''' - target: test describe index interface - method: create collection, create partition and add entities in it, create index, call describe index - expected: return code 0, and index structure - ''' - connect.create_partition(binary_collection, default_tag) - ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag) - connect.flush([binary_collection]) - connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - stats = connect.get_collection_stats(binary_collection) - logging.getLogger().info(stats) - assert stats["row_count"] == default_nb - assert len(stats["partitions"]) == 2 - for partition in stats["partitions"]: - segments = partition["segments"] - if segments: - for segment in segments: - for file in segment["files"]: - if "index_type" in file: - assert file["index_type"] == get_jaccard_index["index_type"] - - """ - ****************************************************************** - The following cases are used to test `drop_index` function - ****************************************************************** - """ - - @pytest.mark.skip("get_collection_stats") - def test_drop_index(self, connect, binary_collection, get_jaccard_index): - ''' - target: test drop index interface - method: create collection and add entities in it, create index, call drop index - expected: return code 0, and default index param - ''' - connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - stats = connect.get_collection_stats(binary_collection) - logging.getLogger().info(stats) - connect.drop_index(binary_collection, binary_field_name) - stats = connect.get_collection_stats(binary_collection) - # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type - assert not stats["partitions"][0]["segments"] - - @pytest.mark.skip("get_collection_stats does not impl") - def 
test_drop_index_partition(self, connect, binary_collection, get_jaccard_index): - ''' - target: test drop index interface - method: create collection, create partition and add entities in it, create index on collection, call drop collection index - expected: return code 0, and default index param - ''' - connect.create_partition(binary_collection, default_tag) - ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag) - connect.flush([binary_collection]) - connect.create_index(binary_collection, binary_field_name, get_jaccard_index) - stats = connect.get_collection_stats(binary_collection) - connect.drop_index(binary_collection, binary_field_name) - stats = connect.get_collection_stats(binary_collection) - assert stats["row_count"] == default_nb - for partition in stats["partitions"]: - segments = partition["segments"] - if segments: - for segment in segments: - for file in segment["files"]: - if "index_type" not in file: - continue - if file["index_type"] == get_jaccard_index["index_type"]: - assert False - - -class TestIndexInvalid(object): - """ - Test create / describe / drop index interfaces with invalid collection names - """ - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.mark.level(1) - def test_create_index_with_invalid_collectionname(self, connect, get_collection_name): - collection_name = get_collection_name - with pytest.raises(Exception) as e: - connect.create_index(collection_name, field_name, default_index) - - @pytest.mark.level(1) - def test_drop_index_with_invalid_collectionname(self, connect, get_collection_name): - collection_name = get_collection_name - with pytest.raises(Exception) as e: - connect.drop_index(collection_name) - - @pytest.fixture( - scope="function", - params=gen_invalid_index() - ) - def get_index(self, request): - yield request.param - - @pytest.mark.level(2) - def test_create_index_with_invalid_index_params(self, connect, collection, get_index): - logging.getLogger().info(get_index) - with pytest.raises(Exception) as e: - connect.create_index(collection, field_name, get_simple_index) - - -class TestIndexAsync: - @pytest.fixture(scope="function", autouse=True) - def skip_http_check(self, args): - if args["handler"] == "HTTP": - pytest.skip("skip in http mode") - - """ - ****************************************************************** - The following cases are used to test `create_index` function - ****************************************************************** - """ - - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - # TODO: Determine the service mode - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return request.param - - def check_result(self, res): - logging.getLogger().info("In callback check search result") - logging.getLogger().info(res) - - """ - ****************************************************************** - The following cases are used to test `create_index` function - ****************************************************************** - """ - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index - expected: return search success - ''' - ids = 
connect.insert(collection, default_entities) - logging.getLogger().info("start index") - future = connect.create_index(collection, field_name, get_simple_index, _async=True) - logging.getLogger().info("before result") - res = future.result() - # TODO: - logging.getLogger().info(res) - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_drop(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index asynchronously, then drop the collection - expected: no error raised - ''' - ids = connect.insert(collection, default_entities) - logging.getLogger().info("start index") - future = connect.create_index(collection, field_name, get_simple_index, _async=True) - logging.getLogger().info("DROP") - connect.drop_collection(collection) - - @pytest.mark.level(2) - def test_create_index_with_invalid_collectionname(self, connect): - collection_name = " " - with pytest.raises(Exception) as e: - future = connect.create_index(collection_name, field_name, default_index, _async=True) - res = future.result() - - @pytest.mark.timeout(BUILD_TIMEOUT) - def test_create_index_callback(self, connect, collection, get_simple_index): - ''' - target: test create index interface - method: create collection and add entities in it, create index asynchronously with a callback - expected: return search success - ''' - ids = connect.insert(collection, default_entities) - logging.getLogger().info("start index") - future = connect.create_index(collection, field_name, get_simple_index, _async=True, - _callback=self.check_result) - logging.getLogger().info("before result") - res = future.result() - # TODO: - logging.getLogger().info(res) diff --git a/tests/python/test_insert.py b/tests/python/test_insert.py deleted file mode 100644 index 9b3276f57b..0000000000 --- a/tests/python/test_insert.py +++ /dev/null @@ -1,1148 +0,0 @@ -import pytest -from .utils import * -from .constants import * - -ADD_TIMEOUT = 600 -uid = "test_insert" -field_name = default_float_vec_field_name -binary_field_name = default_binary_vec_field_name -default_single_query = { - "bool": { - "must": [ - {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type": "L2", - "params": {"nprobe": 10}}}} - ] - } -} - - -class TestInsertBase: - """ - ****************************************************************** - The following cases are used to test `insert` function - ****************************************************************** - """ - - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("CPU does not support index_type: ivf_sq8h") - return request.param - - @pytest.fixture( - scope="function", - params=gen_single_filter_fields() - ) - def get_filter_field(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_single_vector_fields() - ) - def get_vector_field(self, request): - yield request.param - - def test_add_vector_with_empty_vector(self, connect, collection): - ''' - target: test add vectors with an empty vectors list - method: set an empty vectors list as the add method params - expected: raises an Exception - ''' - vector = [] - with pytest.raises(Exception) as e: - status, ids = connect.insert(collection, vector)
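- # --- Illustrative sketch (editor's addition, not part of the original suite). ---
- # The default_entity / default_entities payloads used throughout this file are
- # lists of per-field dicts; schematically (field names assumed from the default
- # schema in .constants):
- #
- #     entities = [
- #         {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
- #         {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
- #         {"name": field_name, "type": DataType.FLOAT_VECTOR, "values": gen_vectors(nb, default_dim)},
- #     ]
- #
- # The negative cases below mutate the "name"/"type"/"values" keys of these
- # dicts to provoke schema mismatch errors.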
def test_add_vector_with_None(self, connect, collection): - ''' - target: test add vectors with None - method: set None as the add method params - expected: raises an Exception - ''' - vector = None - with pytest.raises(Exception) as e: - status, ids = connect.insert(collection, vector) - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_collection_not_existed(self, connect): - ''' - target: test insert, with a collection that does not exist - method: insert entity into a random named collection - expected: error raised - ''' - collection_name = gen_unique_str(uid) - with pytest.raises(Exception) as e: - connect.insert(collection_name, default_entities) - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_drop_collection(self, connect, collection): - ''' - target: test drop collection after inserting a vector - method: insert vector and drop collection - expected: no error raised - ''' - ids = connect.insert(collection, default_entity) - assert len(ids) == 1 - connect.drop_collection(collection) - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_sleep_drop_collection(self, connect, collection): - ''' - target: test drop collection after inserting a vector and flushing - method: insert vector, flush, and drop collection - expected: no error raised - ''' - ids = connect.insert(collection, default_entity) - assert len(ids) == 1 - connect.flush([collection]) - connect.drop_collection(collection) - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_create_index(self, connect, collection, get_simple_index): - ''' - target: test build index after inserting vectors - method: insert vectors and build index - expected: no error raised - ''' - ids = connect.insert(collection, default_entities) - assert len(ids) == default_nb - connect.flush([collection]) - connect.create_index(collection, field_name, get_simple_index) - info = connect.describe_index(collection, field_name) - assert info == get_simple_index - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_create_index_new(self, connect, collection, get_simple_index): - ''' - target: test build index after inserting vectors - method: insert vectors and build index - expected: no error raised - ''' - ids = connect.insert(collection, default_entities_new) - assert len(ids) == default_nb - connect.flush([collection]) - connect.create_index(collection, field_name, get_simple_index) - info = connect.describe_index(collection, field_name) - assert info == get_simple_index - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_after_create_index(self, connect, collection, get_simple_index): - ''' - target: test insert vectors after building index - method: build index and insert vectors - expected: no error raised - ''' - connect.create_index(collection, field_name, get_simple_index) - ids = connect.insert(collection, default_entities) - assert len(ids) == default_nb - info = connect.describe_index(collection, field_name) - assert info == get_simple_index - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_search(self, connect, collection): - ''' - target: test search after inserting vectors - method: insert vectors, flush, and search collection - expected: no error raised - ''' - ids = connect.insert(collection, default_entities) - connect.flush([collection]) - res = connect.search(collection, default_single_query) - logging.getLogger().debug(res) - assert res - - @pytest.mark.skip("segment row count") - def test_insert_segment_row_count(self, connect, collection): - nb = default_segment_row_limit + 1 - res_ids = connect.insert(collection, gen_entities(nb)) - connect.flush([collection]) - assert len(res_ids) == nb - stats = 
connect.get_collection_stats(collection) - assert len(stats['partitions'][0]['segments']) == 2 - for segment in stats['partitions'][0]['segments']: - assert segment['row_count'] in [default_segment_row_limit, 1] - - @pytest.fixture( - scope="function", - params=[ - 1, - 2000 - ], - ) - def insert_count(self, request): - yield request.param - - @pytest.mark.skip(" todo support count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_ids(self, connect, id_collection, insert_count): - ''' - target: test insert vectors in collection, use customize ids - method: create collection and insert vectors in it, check the ids returned and the collection length after vectors inserted - expected: the length of ids and the collection row count - ''' - nb = insert_count - ids = [i for i in range(nb)] - res_ids = connect.insert(id_collection, gen_entities(nb), ids) - connect.flush([id_collection]) - assert len(res_ids) == nb - assert res_ids == ids - res_count = connect.count_entities(id_collection) - assert res_count == nb - - @pytest.mark.skip(" todo support count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_the_same_ids(self, connect, id_collection, insert_count): - ''' - target: test insert vectors in collection, use customize the same ids - method: create collection and insert vectors in it, check the ids returned and the collection length after vectors inserted - expected: the length of ids and the collection row count - ''' - nb = insert_count - ids = [1 for i in range(nb)] - res_ids = connect.insert(id_collection, gen_entities(nb), ids) - connect.flush([id_collection]) - assert len(res_ids) == nb - assert res_ids == ids - res_count = connect.count_entities(id_collection) - assert res_count == nb - - @pytest.mark.skip(" todo support count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field): - ''' - target: test create normal collection with different fields, insert entities into id with ids - method: create collection with diff fields: metric/field_type/..., insert, and count - expected: row count correct - ''' - nb = 5 - filter_field = get_filter_field - vector_field = get_vector_field - collection_name = gen_unique_str("test_collection") - fields = { - "fields": [filter_field, vector_field], - "segment_row_limit": default_segment_row_limit, - "auto_id": True - } - connect.create_collection(collection_name, fields) - ids = [i for i in range(nb)] - entities = gen_entities_by_fields(fields["fields"], nb, default_dim) - res_ids = connect.insert(collection_name, entities, ids) - assert res_ids == ids - connect.flush([collection_name]) - res_count = connect.count_entities(collection_name) - assert res_count == nb - - # TODO: assert exception && enable - @pytest.mark.level(2) - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_twice_ids_no_ids(self, connect, id_collection): - ''' - target: check the result of insert, with params ids and no ids - method: test insert vectors twice, use customize ids first, and then use no ids - expected: error raised - ''' - ids = [i for i in range(default_nb)] - res_ids = connect.insert(id_collection, default_entities, ids) - with pytest.raises(Exception) as e: - res_ids_new = connect.insert(id_collection, default_entities) - - # TODO: assert exception && enable - @pytest.mark.level(2) - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_twice_not_ids_ids(self, connect, id_collection): - ''' - target: check the result of insert, with params ids and no ids - 
method: test insert vectors twice, use not ids first, and then use customize ids - expected: error raised - ''' - with pytest.raises(Exception) as e: - res_ids = connect.insert(id_collection, default_entities) - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_ids_length_not_match_batch(self, connect, id_collection): - ''' - target: test insert vectors in collection, use customize ids, len(ids) != len(vectors) - method: create collection and insert vectors in it - expected: raise an exception - ''' - ids = [i for i in range(1, default_nb)] - logging.getLogger().info(len(ids)) - with pytest.raises(Exception) as e: - res_ids = connect.insert(id_collection, default_entities, ids) - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_ids_length_not_match_single(self, connect, collection): - ''' - target: test insert vectors in collection, use customize ids, len(ids) != len(vectors) - method: create collection and insert vectors in it - expected: raise an exception - ''' - ids = [i for i in range(1, default_nb)] - logging.getLogger().info(len(ids)) - with pytest.raises(Exception) as e: - res_ids = connect.insert(collection, default_entity, ids) - - @pytest.mark.skip(" todo support count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field): - ''' - target: test create normal collection with different fields, insert entities into id without ids - method: create collection with diff fields: metric/field_type/..., insert, and count - expected: row count correct - ''' - nb = 5 - filter_field = get_filter_field - vector_field = get_vector_field - collection_name = gen_unique_str("test_collection") - fields = { - "fields": [filter_field, vector_field], - "segment_row_limit": default_segment_row_limit - } - connect.create_collection(collection_name, fields) - entities = gen_entities_by_fields(fields["fields"], nb, default_dim) - res_ids = connect.insert(collection_name, entities) - connect.flush([collection_name]) - res_count = connect.count_entities(collection_name) - assert res_count == nb - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_tag(self, connect, collection): - ''' - target: test insert entities in collection created before - method: create collection and insert entities in it, with the partition_tag param - expected: the collection row count equals to nq - ''' - connect.create_partition(collection, default_tag) - ids = connect.insert(collection, default_entities, partition_tag=default_tag) - assert len(ids) == default_nb - assert connect.has_partition(collection, default_tag) - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_tag_with_ids(self, connect, id_collection): - ''' - target: test insert entities in collection created before, insert with ids - method: create collection and insert entities in it, with the partition_tag param - expected: the collection row count equals to nq - ''' - connect.create_partition(id_collection, default_tag) - ids = [i for i in range(default_nb)] - res_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag) - assert res_ids == ids - - - @pytest.mark.skip(" todo support count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_default_tag(self, connect, collection): - ''' - target: test insert entities into default partition - method: create partition and insert info collection without tag params - expected: the collection row count equals to nb - ''' - connect.create_partition(collection, default_tag) - ids = 
connect.insert(collection, default_entities) - connect.flush([collection]) - assert len(ids) == default_nb - res_count = connect.count_entities(collection) - assert res_count == default_nb - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_tag_not_existed(self, connect, collection): - ''' - target: test insert entities in collection created before - method: create collection and insert entities in it, with a partition_tag that does not exist - expected: error raised - ''' - tag = gen_unique_str() - with pytest.raises(Exception) as e: - ids = connect.insert(collection, default_entities, partition_tag=tag) - - @pytest.mark.skip(" not support count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_tag_existed(self, connect, collection): - ''' - target: test insert entities in collection created before - method: create collection and insert entities in it repeatedly, with the partition_tag param - expected: the collection row count equals to 2 * nb - ''' - connect.create_partition(collection, default_tag) - ids = connect.insert(collection, default_entities, partition_tag=default_tag) - ids = connect.insert(collection, default_entities, partition_tag=default_tag) - connect.flush([collection]) - res_count = connect.count_entities(collection) - assert res_count == 2 * default_nb - - @pytest.mark.level(2) - def test_insert_without_connect(self, dis_connect, collection): - ''' - target: test insert entities without connection - method: create collection and insert entities in it, check if inserted successfully - expected: raise exception - ''' - with pytest.raises(Exception) as e: - ids = dis_connect.insert(collection, default_entities) - - def test_insert_collection_not_existed(self, connect): - ''' - target: test insert entities into a collection that does not exist - method: insert entities with a random collection name, check the status - expected: error raised - ''' - with pytest.raises(Exception) as e: - ids = connect.insert(gen_unique_str("not_exist_collection"), default_entities) - - @pytest.mark.skip("to do add dim check ") - def test_insert_dim_not_matched(self, connect, collection): - ''' - target: test insert entities, the vector dimension is not equal to the collection dimension - method: the entities dimension is half of the collection dimension, check the status - expected: error raised - ''' - vectors = gen_vectors(default_nb, int(default_dim) // 2) - insert_entities = copy.deepcopy(default_entities) - insert_entities[-1]["values"] = vectors - with pytest.raises(Exception) as e: - ids = connect.insert(collection, insert_entities) - - def test_insert_with_field_name_not_match(self, connect, collection): - ''' - target: test insert entities, with the entity field name updated - method: update entity field name - expected: error raised - ''' - tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", "int64new") - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - # @pytest.mark.skip(" todo support type check") - def test_insert_with_field_type_not_match(self, connect, collection): - ''' - target: test insert entities, with the entity field type updated - method: update entity field type - expected: error raised - ''' - tmp_entity = update_field_type(copy.deepcopy(default_entity), "int64", DataType.FLOAT) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - @pytest.mark.skip("to do add field_type check ") - @pytest.mark.level(2) - def test_insert_with_field_type_not_match_B(self, connect, collection): - ''' - 
target: test insert entities, with the entity field type updated - method: update entity field type - expected: error raised - ''' - tmp_entity = update_field_type(copy.deepcopy(default_entity), "int64", DataType.DOUBLE) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - @pytest.mark.level(2) - def test_insert_with_field_value_not_match(self, connect, collection): - ''' - target: test insert entities, with the entity field value updated - method: update entity field value - expected: error raised - ''' - tmp_entity = update_field_value(copy.deepcopy(default_entity), DataType.FLOAT, 's') - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_field_more(self, connect, collection): - ''' - target: test insert entities, with more fields than collection schema - method: add entity field - expected: error raised - ''' - tmp_entity = add_field(copy.deepcopy(default_entity)) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_field_vector_more(self, connect, collection): - ''' - target: test insert entities, with more fields than collection schema - method: add entity vector field - expected: error raised - ''' - tmp_entity = add_vector_field(default_nb, default_dim) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_field_less(self, connect, collection): - ''' - target: test insert entities, with less fields than collection schema - method: remove entity field - expected: error raised - ''' - tmp_entity = remove_field(copy.deepcopy(default_entity)) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_field_vector_less(self, connect, collection): - ''' - target: test insert entities, with less fields than collection schema - method: remove entity vector field - expected: error raised - ''' - tmp_entity = remove_vector_field(copy.deepcopy(default_entity)) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_no_field_vector_value(self, connect, collection): - ''' - target: test insert entities, with no vector field value - method: remove entity vector field - expected: error raised - ''' - tmp_entity = copy.deepcopy(default_entity) - del tmp_entity[-1]["values"] - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_no_field_vector_type(self, connect, collection): - ''' - target: test insert entities, with no vector field type - method: remove entity vector field - expected: error raised - ''' - tmp_entity = copy.deepcopy(default_entity) - del tmp_entity[-1]["type"] - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_no_field_vector_name(self, connect, collection): - ''' - target: test insert entities, with no vector field name - method: remove entity vector field - expected: error raised - ''' - tmp_entity = copy.deepcopy(default_entity) - del tmp_entity[-1]["name"] - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - @pytest.mark.skip("support count entities") - @pytest.mark.level(2) - @pytest.mark.timeout(30) - def test_collection_insert_rows_count_multi_threading(self, args, collection): - ''' - target: test collection rows_count is correct or not with multi threading - method: create collection and insert entities in it(idmap), - assert the value returned by count_entities method is equal to length of entities - expected: the 
count is equal to the length of entities - ''' - if args["handler"] == "HTTP": - pytest.skip("Skip test in http mode") - thread_num = 8 - threads = [] - milvus = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"], try_connect=False) - - def insert(thread_i): - logging.getLogger().info("In thread-%d" % thread_i) - milvus.insert(collection, default_entities) - milvus.flush([collection]) - - for i in range(thread_num): - t = MilvusTestThread(target=insert, args=(i,)) - threads.append(t) - t.start() - for t in threads: - t.join() - res_count = milvus.count_entities(collection) - assert res_count == thread_num * default_nb - - # TODO: unable to set config - @pytest.mark.skip("get entity by id") - @pytest.mark.level(2) - def _test_insert_disable_auto_flush(self, connect, collection): - ''' - target: test insert entities, with disable autoflush - method: disable autoflush and insert, get entity - expected: the count is equal to 0 - ''' - delete_nums = 500 - disable_flush(connect) - ids = connect.insert(collection, default_entities) - res = connect.get_entity_by_id(collection, ids[:delete_nums]) - assert len(res) == delete_nums - assert res[0] is None - - - -class TestInsertBinary: - @pytest.fixture( - scope="function", - params=gen_binary_index() - ) - def get_binary_index(self, request): - request.param["metric_type"] = "JACCARD" - return request.param - - @pytest.mark.skip("count entities") - def test_insert_binary_entities(self, connect, binary_collection): - ''' - target: test insert entities in binary collection - method: create collection and insert binary entities in it - expected: the collection row count equals to nb - ''' - ids = connect.insert(binary_collection, default_binary_entities) - assert len(ids) == default_nb - connect.flush() - assert connect.count_entities(binary_collection) == default_nb - - @pytest.mark.skip("count entities") - def test_insert_binary_entities_new(self, connect, binary_collection): - ''' - target: test insert entities in binary collection - method: create collection and insert binary entities in it - expected: the collection row count equals to nb - ''' - ids = connect.insert(binary_collection, default_binary_entities_new) - assert len(ids) == default_nb - connect.flush() - assert connect.count_entities(binary_collection) == default_nb - - # @pytest.mark.skip - def test_insert_binary_tag(self, connect, binary_collection): - ''' - target: test insert entities and create partition tag - method: create collection and insert binary entities in it, with the partition_tag param - expected: the collection row count equals to nb - ''' - connect.create_partition(binary_collection, default_tag) - ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag) - assert len(ids) == default_nb - assert connect.has_partition(binary_collection, default_tag) - - @pytest.mark.skip("count entities") - @pytest.mark.level(2) - def test_insert_binary_multi_times(self, connect, binary_collection): - ''' - target: test insert entities multi times and final flush - method: create collection and insert binary entity multi and final flush - expected: the collection row count equals to nb - ''' - for i in range(default_nb): - ids = connect.insert(binary_collection, default_binary_entity) - assert len(ids) == 1 - connect.flush([binary_collection]) - assert connect.count_entities(binary_collection) == default_nb - - def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index): - ''' - target: test insert 
binary entities after building index - method: build index and insert entities - expected: no error raised - ''' - connect.create_index(binary_collection, binary_field_name, get_binary_index) - ids = connect.insert(binary_collection, default_binary_entities) - assert len(ids) == default_nb - connect.flush([binary_collection]) - info = connect.describe_index(binary_collection, binary_field_name) - assert info == get_binary_index - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index): - ''' - target: test build index after inserting binary entities - method: insert entities and build index - expected: no error raised - ''' - ids = connect.insert(binary_collection, default_binary_entities) - assert len(ids) == default_nb - connect.flush([binary_collection]) - connect.create_index(binary_collection, binary_field_name, get_binary_index) - info = connect.describe_index(binary_collection, binary_field_name) - assert info == get_binary_index - - @pytest.mark.skip("binary search") - def test_insert_binary_search(self, connect, binary_collection): - ''' - target: test search after inserting binary entities - method: insert entities, flush, and search collection - expected: no error raised - ''' - ids = connect.insert(binary_collection, default_binary_entities) - connect.flush([binary_collection]) - query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD") - res = connect.search(binary_collection, query) - logging.getLogger().debug(res) - assert res - - -class TestInsertAsync: - @pytest.fixture(scope="function", autouse=True) - def skip_http_check(self, args): - if args["handler"] == "HTTP": - pytest.skip("skip in http mode") - - @pytest.fixture( - scope="function", - params=[ - 1, - 1000 - ], - ) - def insert_count(self, request): - yield request.param - - def check_status(self, result): - logging.getLogger().info("In callback check status") - assert not result - - def check_result(self, result): - logging.getLogger().info("In callback check result") - assert result - - def test_insert_async(self, connect, collection, insert_count): - ''' - target: test insert vectors asynchronously, with different numbers of vectors - method: insert with _async=True and wait on the returned future - expected: length of ids is equal to the length of vectors - ''' - nb = insert_count - future = connect.insert(collection, gen_entities(nb), _async=True) - ids = future.result() - connect.flush([collection]) - assert len(ids) == nb - - @pytest.mark.level(2) - def test_insert_async_false(self, connect, collection, insert_count): - ''' - target: test insert vectors with _async explicitly set to False - method: insert with _async=False, which returns the ids directly - expected: length of ids is equal to the length of vectors - ''' - nb = insert_count - ids = connect.insert(collection, gen_entities(nb), _async=False) - # ids = future.result() - connect.flush([collection]) - assert len(ids) == nb - - def test_insert_async_callback(self, connect, collection, insert_count): - ''' - target: test insert vectors asynchronously with a callback - method: insert with _async=True and _callback, then wait on the future - expected: length of ids is equal to the length of vectors - ''' - nb = insert_count - future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result) - future.done() - ids = future.result() - assert len(ids) == nb
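- # --- Illustrative sketch (editor's addition, not part of the original suite). ---
- # With _async=True the client returns a future immediately; an optional
- # _callback fires with the ids once the server responds, and .result()
- # blocks until completion and re-raises any server-side error:
- #
- #     future = connect.insert(collection, gen_entities(nb), _async=True,
- #                             _callback=lambda ids: logging.info(len(ids)))
- #     ids = future.result()
- #     connect.flush([collection])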
@pytest.mark.skip("count entities") - @pytest.mark.level(2) - def test_insert_async_long(self, connect, collection): - ''' - target: test async insert with a large number of entities - method: insert 50000 entities with _async=True and wait on the future - expected: length of ids is equal to nb, and the row count matches - ''' - nb = 50000 - future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result) - result = future.result() - assert len(result) == nb - connect.flush([collection]) - count = connect.count_entities(collection) - logging.getLogger().info(count) - assert count == nb - - @pytest.mark.skip("count entities") - @pytest.mark.level(2) - def test_insert_async_callback_timeout(self, connect, collection): - ''' - target: test async insert with a timeout that is too short - method: insert 100000 entities with timeout=1 and a status callback - expected: future.result() raises, and no entities are counted - ''' - nb = 100000 - future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status, timeout=1) - with pytest.raises(Exception) as e: - result = future.result() - count = connect.count_entities(collection) - assert count == 0 - - def test_insert_async_invalid_params(self, connect): - ''' - target: test async insert into a collection that does not exist - method: insert entities asynchronously with a random collection name - expected: raise exception - ''' - collection_new = gen_unique_str() - with pytest.raises(Exception) as e: - future = connect.insert(collection_new, default_entities, _async=True) - result = future.result() - - def test_insert_async_invalid_params_raise_exception(self, connect, collection): - ''' - target: test async insert with an empty entities list - method: insert an empty list asynchronously - expected: raise exception - ''' - entities = [] - with pytest.raises(Exception) as e: - future = connect.insert(collection, entities, _async=True) - future.result() - - -class TestInsertMultiCollections: - """ - ****************************************************************** - The following cases are used to test `insert` function - ****************************************************************** - """ - - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - logging.getLogger().info(request.param) - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not supported in CPU mode") - return request.param - - @pytest.mark.skip("count entities") - def test_insert_vector_multi_collections(self, connect): - ''' - target: test insert entities - method: create 10 collections and insert entities into them in turn - expected: row count correct - ''' - collection_num = 10 - collection_list = [] - for i in range(collection_num): - collection_name = gen_unique_str(uid) - collection_list.append(collection_name) - connect.create_collection(collection_name, default_fields) - ids = connect.insert(collection_name, default_entities) - connect.flush([collection_name]) - assert len(ids) == default_nb - count = connect.count_entities(collection_name) - assert count == default_nb - - @pytest.mark.timeout(ADD_TIMEOUT) - def test_drop_collection_insert_vector_another(self, connect, collection): - ''' - target: test insert vector to collection_1 after collection_2 deleted - method: delete collection_2 and insert vector to collection_1 - expected: row count equals the length of entities inserted - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - 
connect.drop_collection(collection) - ids = connect.insert(collection_name, default_entity) - connect.flush([collection_name]) - assert len(ids) == 1 - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_create_index_insert_vector_another(self, connect, collection, get_simple_index): - ''' - target: test insert vector to collection_2 after build index for collection_1 - method: build index and insert vector - expected: status ok - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - connect.create_index(collection, field_name, get_simple_index) - ids = connect.insert(collection, default_entity) - connect.drop_collection(collection_name) - - @pytest.mark.skip("count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_vector_create_index_another(self, connect, collection, get_simple_index): - ''' - target: test insert vector to collection_2 after build index for collection_1 - method: build index and insert vector - expected: status ok - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - ids = connect.insert(collection, default_entity) - connect.create_index(collection, field_name, get_simple_index) - count = connect.count_entities(collection_name) - assert count == 0 - - @pytest.mark.skip("count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_vector_sleep_create_index_another(self, connect, collection, get_simple_index): - ''' - target: test insert vector to collection_2 after build index for collection_1 for a while - method: build index and insert vector - expected: status ok - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - ids = connect.insert(collection, default_entity) - connect.flush([collection]) - connect.create_index(collection, field_name, get_simple_index) - count = connect.count_entities(collection) - assert count == 1 - - @pytest.mark.skip("count entities") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_search_vector_insert_vector_another(self, connect, collection): - ''' - target: test insert vector to collection_1 after search collection_2 - method: search collection and insert vector - expected: status ok - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - res = connect.search(collection, default_single_query) - logging.getLogger().debug(res) - ids = connect.insert(collection_name, default_entity) - connect.flush() - count = connect.count_entities(collection_name) - assert count == 1 - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_vector_search_vector_another(self, connect, collection): - ''' - target: test insert vector to collection_1 after search collection_2 - method: search collection and insert vector - expected: status ok - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - ids = connect.insert(collection, default_entity) - result = connect.search(collection_name, default_single_query) - - @pytest.mark.skip("r0.3-test") - @pytest.mark.timeout(ADD_TIMEOUT) - def test_insert_vector_sleep_search_vector_another(self, connect, collection): - ''' - target: test insert vector to collection_1 after search collection_2 a while - method: search collection , sleep, and insert vector - expected: status ok - ''' - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, 
default_fields) - ids = connect.insert(collection, default_entity) - connect.flush([collection]) - result = connect.search(collection_name, default_single_query) - - -class TestInsertInvalid(object): - """ - Test inserting vectors with invalid collection names - """ - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_tag_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_field_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_field_type(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_field_int_value(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_ints() - ) - def get_entity_id(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_vectors() - ) - def get_field_vectors_value(self, request): - yield request.param - - def test_insert_ids_invalid(self, connect, id_collection, get_entity_id): - ''' - target: test insert, with using customize ids, which are not int64 - method: create collection and insert entities in it - expected: raise an exception - ''' - entity_id = get_entity_id - ids = [entity_id for _ in range(default_nb)] - with pytest.raises(Exception): - connect.insert(id_collection, default_entities, ids) - - def test_insert_with_invalid_collection_name(self, connect, get_collection_name): - collection_name = get_collection_name - with pytest.raises(Exception): - connect.insert(collection_name, default_entity) - - def test_insert_with_invalid_tag_name(self, connect, collection, get_tag_name): - tag_name = get_tag_name - connect.create_partition(collection, default_tag) - if tag_name is not None: - with pytest.raises(Exception): - connect.insert(collection, default_entity, partition_tag=tag_name) - else: - connect.insert(collection, default_entity, partition_tag=tag_name) - - def test_insert_with_invalid_field_name(self, connect, collection, get_field_name): - field_name = get_field_name - tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - @pytest.mark.skip("laster add check of field type") - def test_insert_with_invalid_field_type(self, connect, collection, get_field_type): - field_type = get_field_type - tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - @pytest.mark.skip("laster add check of field value") - def test_insert_with_invalid_field_value(self, connect, collection, get_field_int_value): - field_value = get_field_int_value - tmp_entity = update_field_type(copy.deepcopy(default_entity), 'int64', field_value) - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - def test_insert_with_invalid_field_vector_value(self, connect, collection, get_field_vectors_value): - tmp_entity = copy.deepcopy(default_entity) - src_vector = tmp_entity[-1]["values"] - src_vector[0][1] = get_field_vectors_value - with pytest.raises(Exception): - connect.insert(collection, tmp_entity) - - -class TestInsertInvalidBinary(object): - """ - Test inserting vectors with invalid 
collection names - """ - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_tag_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_field_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_field_type(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_field_int_value(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_ints() - ) - def get_entity_id(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_vectors() - ) - def get_field_vectors_value(self, request): - yield request.param - - @pytest.mark.level(2) - def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name): - tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name) - with pytest.raises(Exception): - connect.insert(binary_collection, tmp_entity) - - @pytest.mark.skip("todo support row data check") - @pytest.mark.level(2) - def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value): - tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value) - with pytest.raises(Exception): - connect.insert(binary_collection, tmp_entity) - - @pytest.mark.skip("todo support row data check") - @pytest.mark.level(2) - def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value): - tmp_entity = copy.deepcopy(default_binary_entity) - src_vector = tmp_entity[-1]["values"] - src_vector[0][1] = get_field_vectors_value - with pytest.raises(Exception): - connect.insert(binary_collection, tmp_entity) - - @pytest.mark.level(2) - def test_insert_ids_invalid(self, connect, binary_id_collection, get_entity_id): - ''' - target: test insert, with using customize ids, which are not int64 - method: create collection and insert entities in it - expected: raise an exception - ''' - entity_id = get_entity_id - ids = [entity_id for _ in range(default_nb)] - with pytest.raises(Exception): - connect.insert(binary_id_collection, default_binary_entities, ids) - - @pytest.mark.skip("check filed") - @pytest.mark.level(2) - def test_insert_with_invalid_field_type(self, connect, binary_collection, get_field_type): - field_type = get_field_type - tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_type) - with pytest.raises(Exception): - connect.insert(binary_collection, tmp_entity) - - @pytest.mark.skip("check field") - @pytest.mark.level(2) - def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value): - tmp_entity = copy.deepcopy(default_binary_entities) - src_vector = tmp_entity[-1]["values"] - src_vector[1] = get_field_vectors_value - with pytest.raises(Exception): - connect.insert(binary_collection, tmp_entity) diff --git a/tests/python/test_list_collections.py b/tests/python/test_list_collections.py deleted file mode 100644 index 2c363595c9..0000000000 --- a/tests/python/test_list_collections.py +++ /dev/null @@ -1,88 +0,0 @@ -import pytest -from .utils import * -from .constants import * - -uid = "list_collections" - -class 
TestListCollections: - """ - ****************************************************************** - The following cases are used to test `list_collections` function - ****************************************************************** - """ - def test_list_collections(self, connect, collection): - ''' - target: test list collections - method: create collection, assert the value returned by list_collections method - expected: True - ''' - assert collection in connect.list_collections() - - def test_list_collections_multi_collections(self, connect): - ''' - target: test list collections - method: create collection, assert the value returned by list_collections method - expected: True - ''' - collection_num = 50 - for i in range(collection_num): - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - assert collection_name in connect.list_collections() - - @pytest.mark.level(2) - def test_list_collections_without_connection(self, dis_connect): - ''' - target: test list collections, without connection - method: calling list collections with correct params, with a disconnected instance - expected: list collections raise exception - ''' - with pytest.raises(Exception) as e: - dis_connect.list_collections() - - def test_list_collections_not_existed(self, connect): - ''' - target: test if collection not created - method: random a collection name, which not existed in db, - assert the value returned by list_collections method - expected: False - ''' - collection_name = gen_unique_str(uid) - assert collection_name not in connect.list_collections() - - - @pytest.mark.level(2) - @pytest.mark.skip("can't run in parallel") - def test_list_collections_no_collection(self, connect): - ''' - target: test show collections is correct or not, if no collection in db - method: delete all collections, - assert the value returned by list_collections method is equal to [] - expected: the status is ok, and the result is equal to [] - ''' - result = connect.list_collections() - if result: - for collection_name in result: - assert connect.has_collection(collection_name) - - @pytest.mark.level(2) - def test_list_collections_multithread(self, connect): - ''' - target: test create collection with multithread - method: create collection using multithread, - expected: collections are created - ''' - threads_num = 4 - threads = [] - collection_name = gen_unique_str(uid) - connect.create_collection(collection_name, default_fields) - - def _list(): - assert collection_name in connect.list_collections() - for i in range(threads_num): - t = threading.Thread(target=_list, args=()) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() diff --git a/tests/python/test_load_collection.py b/tests/python/test_load_collection.py deleted file mode 100644 index e0f2e8b52b..0000000000 --- a/tests/python/test_load_collection.py +++ /dev/null @@ -1,22 +0,0 @@ -from tests.utils import * -from tests.constants import * - -uniq_id = "load_collection" - -class TestLoadCollection: - """ - ****************************************************************** - The following cases are used to test `load_collection` function - ****************************************************************** - """ - def test_load_collection(self, connect, collection_without_loading): - ''' - target: test load collection and wait for loading collection - method: insert then flush, when flushed, try load collection - expected: no errors - ''' - collection = collection_without_loading - ids = 
diff --git a/tests/python/test_load_collection.py b/tests/python/test_load_collection.py deleted file mode 100644 index e0f2e8b52b..0000000000 --- a/tests/python/test_load_collection.py +++ /dev/null @@ -1,22 +0,0 @@ -from .utils import * -from .constants import * - -uniq_id = "load_collection" - -class TestLoadCollection: - """ - ****************************************************************** - The following cases are used to test `load_collection` function - ****************************************************************** - """ - def test_load_collection(self, connect, collection_without_loading): - ''' - target: test load collection and wait for loading collection - method: insert then flush, when flushed, try load collection - expected: no errors - ''' - collection = collection_without_loading - ids = connect.insert(collection, default_entities) - ids = connect.insert(collection, default_entity) - connect.flush([collection]) - connect.load_collection(collection) \ No newline at end of file diff --git a/tests/python/test_load_partitions.py b/tests/python/test_load_partitions.py deleted file mode 100644 index 128b04139c..0000000000 --- a/tests/python/test_load_partitions.py +++ /dev/null @@ -1,26 +0,0 @@ -from .utils import * -from .constants import * - -uniq_id = "load_partitions" - -class TestLoadPartitions: - """ - ****************************************************************** - The following cases are used to test `load_partitions` function - ****************************************************************** - """ - def test_load_partitions(self, connect, collection): - ''' - target: test load partitions and wait for loading partitions - method: insert then flush, when flushed, try load partitions - expected: no errors - ''' - partition_tag = "lvn9pq34u8rasjk" - connect.create_partition(collection, partition_tag + "1") - ids = connect.insert(collection, default_entities, partition_tag=partition_tag + "1") - - connect.create_partition(collection, partition_tag + "2") - ids = connect.insert(collection, default_entity, partition_tag=partition_tag + "2") - - connect.flush([collection]) - connect.load_partitions(collection, [partition_tag + "2"])
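Both loading tests exercise the same insert → flush → load lifecycle. A hedged sketch of that flow, where connect stands for the suite's client fixture and the method names simply mirror the calls above:

# Sketch of the lifecycle these tests exercise, under the assumption that
# `connect` exposes the same methods the tests call.
def load_after_flush(connect, collection, entities, partition_tag=None):
    if partition_tag is not None:
        connect.create_partition(collection, partition_tag)
        connect.insert(collection, entities, partition_tag=partition_tag)
        connect.flush([collection])          # make inserts durable/visible
        connect.load_partitions(collection, [partition_tag])
    else:
        connect.insert(collection, entities)
        connect.flush([collection])
        connect.load_collection(collection)  # bring sealed segments into memory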
diff --git a/tests/python/test_partition.py b/tests/python/test_partition.py deleted file mode 100644 index 8a412dcdc2..0000000000 --- a/tests/python/test_partition.py +++ /dev/null @@ -1,396 +0,0 @@ -import pytest -from .utils import * -from .constants import * - -TIMEOUT = 120 - -class TestCreateBase: - """ - ****************************************************************** - The following cases are used to test `create_partition` function - ****************************************************************** - """ - def test_create_partition(self, connect, collection): - ''' - target: test create partition, check status returned - method: call function: create_partition - expected: status ok - ''' - connect.create_partition(collection, default_tag) - - @pytest.mark.level(2) - @pytest.mark.timeout(600) - @pytest.mark.skip - def test_create_partition_limit(self, connect, collection, args): - ''' - target: test create partitions, check status returned - method: call function: create_partition for 4097 times - expected: exception raised - ''' - threads_num = 8 - threads = [] - if args["handler"] == "HTTP": - pytest.skip("skip in http mode") - - def create(connect, threads_num): - for i in range(max_partition_num // threads_num): - tag_tmp = gen_unique_str() - connect.create_partition(collection, tag_tmp) - - for i in range(threads_num): - m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"]) - t = threading.Thread(target=create, args=(m, threads_num, )) - threads.append(t) - t.start() - for t in threads: - t.join() - tag_tmp = gen_unique_str() - with pytest.raises(Exception) as e: - connect.create_partition(collection, tag_tmp) - - def test_create_partition_repeat(self, connect, collection): - ''' - target: test create partition twice with the same name - method: call function: create_partition twice with the same tag - expected: exception raised on the second call - ''' - connect.create_partition(collection, default_tag) - with pytest.raises(Exception) as e: - connect.create_partition(collection, default_tag) - - def test_create_partition_collection_not_existed(self, connect): - ''' - target: test create partition whose owner collection does not exist in db, check status returned - method: call function: create_partition - expected: status not ok - ''' - collection_name = gen_unique_str() - with pytest.raises(Exception) as e: - connect.create_partition(collection_name, default_tag) - - def test_create_partition_tag_name_None(self, connect, collection): - ''' - target: test create partition, tag name set None, check status returned - method: call function: create_partition - expected: exception raised - ''' - tag_name = None - with pytest.raises(Exception) as e: - connect.create_partition(collection, tag_name) - - def test_create_different_partition_tags(self, connect, collection): - ''' - target: test create partition twice with different names - method: call function: create_partition, and again - expected: status ok - ''' - connect.create_partition(collection, default_tag) - tag_name = gen_unique_str() - connect.create_partition(collection, tag_name) - tag_list = connect.list_partitions(collection) - assert default_tag in tag_list - assert tag_name in tag_list - assert "_default" in tag_list - - @pytest.mark.skip("not support custom id") - def test_create_partition_insert_default(self, connect, id_collection): - ''' - target: test create partition, and insert vectors, check status returned - method: call function: create_partition - expected: status ok - ''' - connect.create_partition(id_collection, default_tag) - ids = [i for i in range(default_nb)] - insert_ids = connect.insert(id_collection, default_entities, ids) - assert len(insert_ids) == len(ids) - - @pytest.mark.skip("not support custom id") - def test_create_partition_insert_with_tag(self, connect, id_collection): - ''' - target: test create partition, and insert vectors, check status returned - method: call function: create_partition - expected: status ok - ''' - connect.create_partition(id_collection, default_tag) - ids = [i for i in range(default_nb)] - insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag) - assert len(insert_ids) == len(ids) - - def test_create_partition_insert_with_tag_not_existed(self, connect, collection): - ''' - target: test create partition, and insert vectors, check status returned - method: call function: create_partition - expected: status not ok - ''' - tag_new = "tag_new" - connect.create_partition(collection, default_tag) - ids = [i for i in range(default_nb)] - with pytest.raises(Exception) as e: - insert_ids = connect.insert(collection, default_entities, ids, partition_tag=tag_new) - - @pytest.mark.skip("not support custom id") - def test_create_partition_insert_same_tags(self, connect, id_collection): - ''' - target: test create partition, and insert vectors, check status returned - method: call function: create_partition - expected: status ok - ''' - connect.create_partition(id_collection, default_tag) - ids = [i for i in range(default_nb)] - insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag) - ids = [(i+default_nb) for i in range(default_nb)] - new_insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag) - connect.flush([id_collection]) - res = connect.count_entities(id_collection) - assert res == default_nb * 2 - - @pytest.mark.level(2) - @pytest.mark.skip("not support count entities") - def test_create_partition_insert_same_tags_two_collections(self, connect, collection): - ''' - target: test create two partitions, and insert vectors with the same tag to each collection, check status returned - method: call function: create_partition - 
expected: status ok, collection length is correct - ''' - connect.create_partition(collection, default_tag) - collection_new = gen_unique_str() - connect.create_collection(collection_new, default_fields) - connect.create_partition(collection_new, default_tag) - ids = connect.insert(collection, default_entities, partition_tag=default_tag) - ids = connect.insert(collection_new, default_entities, partition_tag=default_tag) - connect.flush([collection, collection_new]) - res = connect.count_entities(collection) - assert res == default_nb - res = connect.count_entities(collection_new) - assert res == default_nb - - -class TestShowBase: - - """ - ****************************************************************** - The following cases are used to test `list_partitions` function - ****************************************************************** - """ - def test_list_partitions(self, connect, collection): - ''' - target: test show partitions, check status and partitions returned - method: create partition first, then call function: list_partitions - expected: status ok, partition correct - ''' - connect.create_partition(collection, default_tag) - res = connect.list_partitions(collection) - assert default_tag in res - - def test_list_partitions_no_partition(self, connect, collection): - ''' - target: test show partitions with collection name, check status and partitions returned - method: call function: list_partitions - expected: status ok, partitions correct - ''' - res = connect.list_partitions(collection) - assert len(res) == 1 - - def test_show_multi_partitions(self, connect, collection): - ''' - target: test show partitions, check status and partitions returned - method: create partitions first, then call function: list_partitions - expected: status ok, partitions correct - ''' - tag_new = gen_unique_str() - connect.create_partition(collection, default_tag) - connect.create_partition(collection, tag_new) - res = connect.list_partitions(collection) - assert default_tag in res - assert tag_new in res - - -class TestHasBase: - - """ - ****************************************************************** - The following cases are used to test `has_partition` function - ****************************************************************** - """ - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_tag_name(self, request): - yield request.param - - def test_has_partition(self, connect, collection): - ''' - target: test has_partition, check status and result - method: create partition first, then call function: has_partition - expected: status ok, result true - ''' - connect.create_partition(collection, default_tag) - res = connect.has_partition(collection, default_tag) - logging.getLogger().info(res) - assert res - - def test_has_partition_multi_partitions(self, connect, collection): - ''' - target: test has_partition, check status and result - method: create partition first, then call function: has_partition - expected: status ok, result true - ''' - for tag_name in [default_tag, "tag_new", "tag_new_new"]: - connect.create_partition(collection, tag_name) - for tag_name in [default_tag, "tag_new", "tag_new_new"]: - res = connect.has_partition(collection, tag_name) - assert res - - def test_has_partition_tag_not_existed(self, connect, collection): - ''' - target: test has_partition, check status and result - method: then call function: has_partition, with tag not existed - expected: status ok, result empty - ''' - res = connect.has_partition(collection, default_tag) - 
logging.getLogger().info(res) - assert not res - - def test_has_partition_collection_not_existed(self, connect, collection): - ''' - target: test has_partition, check status and result - method: then call function: has_partition, with collection not existed - expected: status not ok - ''' - with pytest.raises(Exception) as e: - res = connect.has_partition("not_existed_collection", default_tag) - - @pytest.mark.level(2) - def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name): - ''' - target: test has partition, with invalid tag name, check status returned - method: call function: has_partition - expected: exception raised - ''' - tag_name = get_tag_name - connect.create_partition(collection, default_tag) - with pytest.raises(Exception) as e: - res = connect.has_partition(collection, tag_name) - - -class TestDropBase: - - """ - ****************************************************************** - The following cases are used to test `drop_partition` function - ****************************************************************** - """ - def test_drop_partition(self, connect, collection): - ''' - target: test drop partition, check status and partition if existed - method: create partitions first, then call function: drop_partition - expected: status ok, no partitions in db - ''' - connect.create_partition(collection, default_tag) - connect.drop_partition(collection, default_tag) - tag_list = connect.list_partitions(collection) - assert default_tag not in tag_list - - def test_drop_partition_tag_not_existed(self, connect, collection): - ''' - target: test drop partition, but tag not existed - method: create partitions first, then call function: drop_partition - expected: status not ok - ''' - connect.create_partition(collection, default_tag) - new_tag = "new_tag" - with pytest.raises(Exception) as e: - connect.drop_partition(collection, new_tag) - - def test_drop_partition_tag_not_existed_A(self, connect, collection): - ''' - target: test drop partition, but collection not existed - method: create partitions first, then call function: drop_partition - expected: status not ok - ''' - connect.create_partition(collection, default_tag) - new_collection = gen_unique_str() - with pytest.raises(Exception) as e: - connect.drop_partition(new_collection, default_tag) - - @pytest.mark.level(2) - def test_drop_partition_repeatedly(self, connect, collection): - ''' - target: test drop partition twice, check status and partition if existed - method: create partitions first, then call function: drop_partition - expected: status not ok, no partitions in db - ''' - connect.create_partition(collection, default_tag) - connect.drop_partition(collection, default_tag) - time.sleep(2) - with pytest.raises(Exception) as e: - connect.drop_partition(collection, default_tag) - tag_list = connect.list_partitions(collection) - assert default_tag not in tag_list - - def test_drop_partition_create(self, connect, collection): - ''' - target: test drop partition, and create again, check status - method: create partitions first, then call function: drop_partition, create_partition - expected: status not ok, partition in db - ''' - connect.create_partition(collection, default_tag) - connect.drop_partition(collection, default_tag) - time.sleep(2) - connect.create_partition(collection, default_tag) - tag_list = connect.list_partitions(collection) - assert default_tag in tag_list - - -class TestNameInvalid(object): - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def 
get_tag_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.mark.level(2) - def test_drop_partition_with_invalid_collection_name(self, connect, collection, get_collection_name): - ''' - target: test drop partition, with invalid collection name, check status returned - method: call function: drop_partition - expected: status not ok - ''' - collection_name = get_collection_name - connect.create_partition(collection, default_tag) - with pytest.raises(Exception) as e: - connect.drop_partition(collection_name, default_tag) - - @pytest.mark.level(2) - def test_drop_partition_with_invalid_tag_name(self, connect, collection, get_tag_name): - ''' - target: test drop partition, with invalid tag name, check status returned - method: call function: drop_partition - expected: status not ok - ''' - tag_name = get_tag_name - connect.create_partition(collection, default_tag) - with pytest.raises(Exception) as e: - connect.drop_partition(collection, tag_name) - - @pytest.mark.level(2) - def test_list_partitions_with_invalid_collection_name(self, connect, collection, get_collection_name): - ''' - target: test show partitions, with invalid collection name, check status returned - method: call function: list_partitions - expected: status not ok - ''' - collection_name = get_collection_name - connect.create_partition(collection, default_tag) - with pytest.raises(Exception) as e: - res = connect.list_partitions(collection_name) diff --git a/tests/python/test_search.py b/tests/python/test_search.py deleted file mode 100644 index 78a07cd77b..0000000000 --- a/tests/python/test_search.py +++ /dev/null @@ -1,1818 +0,0 @@ -import time -import pdb -import copy -import logging -from multiprocessing import Pool, Process -import pytest -import numpy as np - -from milvus import DataType -from .utils import * -from .constants import * - -uid = "test_search" -nq = 1 -epsilon = 0.001 -field_name = default_float_vec_field_name -binary_field_name = default_binary_vec_field_name -search_param = {"nprobe": 1} - -entity = gen_entities(1, is_normal=True) -entities = gen_entities(default_nb, is_normal=True) -raw_vectors, binary_entities = gen_binary_entities(default_nb) -default_query, default_query_vecs = gen_query_vectors(field_name, entities, default_top_k, nq) -default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, default_top_k, - nq) - - -def init_data(connect, collection, nb=1200, partition_tags=None, auto_id=True): - ''' - Generate entities and add it in collection - ''' - global entities - if nb == 1200: - insert_entities = entities - else: - insert_entities = gen_entities(nb, is_normal=True) - if partition_tags is None: - if auto_id: - ids = connect.insert(collection, insert_entities) - else: - ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)]) - else: - if auto_id: - ids = connect.insert(collection, insert_entities, partition_tag=partition_tags) - else: - ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags) - # connect.flush([collection]) - return insert_entities, ids - - -def init_binary_data(connect, collection, nb=1200, insert=True, partition_tags=None): - ''' - Generate entities and add it in collection - ''' - ids = [] - global binary_entities - global raw_vectors - if nb == 1200: - insert_entities = binary_entities - insert_raw_vectors = raw_vectors 
- else: - insert_raw_vectors, insert_entities = gen_binary_entities(nb) - if insert is True: - if partition_tags is None: - ids = connect.insert(collection, insert_entities) - else: - ids = connect.insert(collection, insert_entities, partition_tag=partition_tags) - connect.flush([collection]) - return insert_raw_vectors, insert_entities, ids - - -class TestSearchBase: - """ - generate valid create_index params - """ - - @pytest.fixture( - scope="function", - params=gen_index() - ) - def get_index(self, request, connect): - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return request.param - - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - import copy - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return copy.deepcopy(request.param) - - @pytest.fixture( - scope="function", - params=gen_binary_index() - ) - def get_jaccard_index(self, request, connect): - logging.getLogger().info(request.param) - if request.param["index_type"] in binary_support(): - return request.param - else: - pytest.skip("Skip index Temporary") - - @pytest.fixture( - scope="function", - params=gen_binary_index() - ) - def get_hamming_index(self, request, connect): - logging.getLogger().info(request.param) - if request.param["index_type"] in binary_support(): - return request.param - else: - pytest.skip("Skip index Temporary") - - @pytest.fixture( - scope="function", - params=gen_binary_index() - ) - def get_structure_index(self, request, connect): - logging.getLogger().info(request.param) - if request.param["index_type"] == "FLAT": - return request.param - else: - pytest.skip("Skip index Temporary") - - """ - generate top-k params - """ - - @pytest.fixture( - scope="function", - params=[1, 10] - ) - def get_top_k(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=[1, 10, 1100] - ) - def get_nq(self, request): - yield request.param - - # PASS - def test_search_flat(self, connect, collection, get_top_k, get_nq): - ''' - target: test basic search function, all the search params is corrent, change top-k value - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = get_nq - entities, ids = init_data(connect, collection) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq) - if top_k <= max_top_k: - res = connect.search(collection, query) - assert len(res[0]) == top_k - assert res[0]._distances[0] <= epsilon - assert check_id_result(res[0], ids[0]) - else: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # milvus-distributed dose not have the limitation of top_k - def test_search_flat_top_k(self, connect, collection, get_nq): - ''' - target: test basic search function, all the search params is corrent, change top-k value - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' - top_k = 16385 - nq = get_nq - entities, ids = init_data(connect, collection) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq) - if top_k <= max_top_k: - res = connect.search(collection, query) - assert len(res[0]) == top_k - assert res[0]._distances[0] <= epsilon - assert check_id_result(res[0], ids[0]) - else: - with pytest.raises(Exception) 
as e: - res = connect.search(collection, query) - - # TODO: reopen after we support targetEntry - @pytest.mark.skip("search_field") - def test_search_field(self, connect, collection, get_top_k, get_nq): - ''' - target: test basic search function, all the search params are correct, change top-k value - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = get_nq - entities, ids = init_data(connect, collection) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq) - if top_k <= max_top_k: - res = connect.search(collection, query, fields=["float_vector"]) - assert len(res[0]) == top_k - assert res[0]._distances[0] <= epsilon - assert check_id_result(res[0], ids[0]) - res = connect.search(collection, query, fields=["float"]) - for i in range(nq): - assert entities[1]["values"][:nq][i] in [r.entity.get('float') for r in res[i]] - else: - with pytest.raises(Exception): - connect.search(collection, query) - - @pytest.mark.skip("search_after_delete") - def test_search_after_delete(self, connect, collection, get_top_k, get_nq): - ''' - target: test basic search function before and after deletion, all the search params are - correct, change top-k value. - check issue #4200 - method: search with the given vectors, check the result - expected: the deleted entities do not exist in the result. - ''' - top_k = get_top_k - nq = get_nq - - entities, ids = init_data(connect, collection, nb=10000) - first_int64_value = entities[0]["values"][0] - first_vector = entities[2]["values"][0] - - search_param = get_search_param("FLAT") - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) - vecs[:] = [] - vecs.append(first_vector) - - res = None - if top_k > max_top_k: - with pytest.raises(Exception): - connect.search(collection, query, fields=['int64']) - pytest.skip("top_k value is larger than max_top_k") - else: - res = connect.search(collection, query, fields=['int64']) - assert len(res) == 1 - assert len(res[0]) >= top_k - assert res[0][0].id == ids[0] - assert res[0][0].entity.get("int64") == first_int64_value - assert res[0]._distances[0] < epsilon - assert check_id_result(res[0], ids[0]) - - connect.delete_entity_by_id(collection, ids[:1]) - connect.flush([collection]) - - res2 = connect.search(collection, query, fields=['int64']) - assert len(res2) == 1 - assert len(res2[0]) >= top_k - assert res2[0][0].id != ids[0] - if top_k > 1: - assert res2[0][0].id == res[0][1].id - assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
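gen_query_vectors builds the search DSL these cases send, but its definition is not part of this diff. The following is a hypothetical reconstruction inferred from how update_query_expr and gen_default_vector_expr are used below; the key names are assumptions, not a verified API:

import random

# Hypothetical reconstruction of the query dict produced by
# utils.gen_query_vectors; field/parameter names follow the calls in these
# tests, not a confirmed schema.
def make_vector_query(field_name, vectors, top_k, metric_type="L2", params=None):
    return {
        "bool": {
            "must": [{
                "vector": {
                    field_name: {
                        "topk": top_k,
                        "query": vectors,            # nq query vectors
                        "metric_type": metric_type,
                        "params": params or {"nprobe": 1},
                    }
                }
            }]
        }
    }

query = make_vector_query("float_vector", [[random.random() for _ in range(128)]], 10)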
 - - # Pass - @pytest.mark.level(2) - def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): - ''' - target: test basic search function, all the search params are correct, test all index params, and build - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = get_nq - - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - entities, ids = init_data(connect, collection) - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - else: - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) >= top_k - assert res[0]._distances[0] < epsilon - assert check_id_result(res[0], ids[0]) - - # DOG: TODO INVALID TYPE UNKNOWN - @pytest.mark.skip("search_after_index_different_metric_type") - def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index): - ''' - target: test search with different metric_type - method: build index with L2, and search using IP - expected: search ok - ''' - search_metric_type = "IP" - index_type = get_simple_index["index_type"] - entities, ids = init_data(connect, collection) - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, metric_type=search_metric_type, - search_params=search_param) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - - # pass - @pytest.mark.level(2) - @pytest.mark.skip("r0.3-test") - def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): - ''' - target: test basic search function, all the search params are correct, test all index params, and build - method: add vectors into collection, search with the given vectors, check the result - expected: the length of the result is top_k, search collection with partition tag return empty - ''' - top_k = get_top_k - nq = get_nq - - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - connect.create_partition(collection, default_tag) - entities, ids = init_data(connect, collection) - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - else: - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) >= top_k - assert res[0]._distances[0] < epsilon - assert check_id_result(res[0], ids[0]) - res = connect.search(collection, query, partition_tags=[default_tag]) - assert len(res) == nq - - # PASS - @pytest.mark.level(2) - def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq): - ''' - target: test basic search function, all the search params is correct, test all index params, and build - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = get_nq - - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - connect.create_partition(collection, default_tag) - entities, ids = init_data(connect, collection, partition_tags=default_tag) - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) - for tags in [[default_tag], [default_tag, "new_tag"]]: - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query, partition_tags=tags) - else: - res = connect.search(collection, query, partition_tags=tags) - assert len(res) == nq - assert len(res[0]) >= top_k - assert res[0]._distances[0] < epsilon - assert check_id_result(res[0], ids[0]) - - @pytest.mark.skip("search_index_partition_C") - @pytest.mark.level(2) - def test_search_index_partition_C(self, connect, collection, 
get_top_k, get_nq): - ''' - target: test basic search function, all the search params is corrent, test all index params, and build - method: search with the given vectors and tag (tag name not existed in collection), check the result - expected: error raised - ''' - top_k = get_top_k - nq = get_nq - entities, ids = init_data(connect, collection) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query, partition_tags=["new_tag"]) - else: - res = connect.search(collection, query, partition_tags=["new_tag"]) - assert len(res) == nq - assert len(res[0]) == 0 - - # PASS - @pytest.mark.skip("r0.3-test") - @pytest.mark.level(2) - def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k): - ''' - target: test basic search function, all the search params is corrent, test all index params, and build - method: search collection with the given vectors and tags, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = 2 - new_tag = "new_tag" - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - connect.create_partition(collection, default_tag) - connect.create_partition(collection, new_tag) - entities, ids = init_data(connect, collection, partition_tags=default_tag) - new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag) - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - else: - res = connect.search(collection, query) - assert check_id_result(res[0], ids[0]) - assert not check_id_result(res[1], new_ids[0]) - assert res[0]._distances[0] < epsilon - assert res[1]._distances[0] < epsilon - res = connect.search(collection, query, partition_tags=["new_tag"]) - assert res[0]._distances[0] > epsilon - assert res[1]._distances[0] > epsilon - - # Pass - @pytest.mark.level(2) - def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k): - ''' - target: test basic search function, all the search params is corrent, test all index params, and build - method: search collection with the given vectors and tags, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = 2 - tag = "tag" - new_tag = "new_tag" - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - connect.create_partition(collection, tag) - connect.create_partition(collection, new_tag) - entities, ids = init_data(connect, collection, partition_tags=tag) - new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag) - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - else: - res = connect.search(collection, query, partition_tags=["(.*)tag"]) - assert not check_id_result(res[0], ids[0]) - assert res[0]._distances[0] < epsilon - assert res[1]._distances[0] < epsilon - res = connect.search(collection, query, partition_tags=["new(.*)"]) - 
assert res[0]._distances[0] < epsilon - assert res[1]._distances[0] < epsilon - - # pass - # test for ip metric - # - # TODO: reopen after we supporting ip flat - # DOG: TODO REDUCE - @pytest.mark.skip("search_ip_flat") - @pytest.mark.level(2) - def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq): - ''' - target: test basic search function, all the search params is corrent, change top-k value - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = get_nq - entities, ids = init_data(connect, collection) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP") - if top_k <= max_top_k: - res = connect.search(collection, query) - assert len(res[0]) == top_k - assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) - assert check_id_result(res[0], ids[0]) - else: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - @pytest.mark.skip("r0.3-test") - @pytest.mark.level(2) - def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): - ''' - target: test basic search function, all the search params is corrent, test all index params, and build - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = get_nq - - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - entities, ids = init_data(connect, collection) - get_simple_index["metric_type"] = "IP" - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - else: - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) >= top_k - assert check_id_result(res[0], ids[0]) - assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) - - @pytest.mark.level(2) - def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): - ''' - target: test basic search function, all the search params is corrent, test all index params, and build - method: add vectors into collection, search with the given vectors, check the result - expected: the length of the result is top_k, search collection with partition tag return empty - ''' - top_k = get_top_k - nq = get_nq - metric_type = "IP" - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - connect.create_partition(collection, default_tag) - entities, ids = init_data(connect, collection) - get_simple_index["metric_type"] = metric_type - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, - search_params=search_param) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - else: - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) >= top_k - assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) - assert check_id_result(res[0], ids[0]) - res = connect.search(collection, query, partition_tags=[default_tag]) - assert len(res) == 
nq - - # PASS - @pytest.mark.level(2) - def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k): - ''' - target: test basic search function, all the search params is corrent, test all index params, and build - method: search collection with the given vectors and tags, check the result - expected: the length of the result is top_k - ''' - top_k = get_top_k - nq = 2 - metric_type = "IP" - new_tag = "new_tag" - index_type = get_simple_index["index_type"] - if index_type in skip_pq(): - pytest.skip("Skip PQ") - connect.create_partition(collection, default_tag) - connect.create_partition(collection, new_tag) - entities, ids = init_data(connect, collection, partition_tags=default_tag) - new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag) - get_simple_index["metric_type"] = metric_type - connect.create_index(collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param) - if top_k > max_top_k: - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - else: - res = connect.search(collection, query) - assert check_id_result(res[0], ids[0]) - assert not check_id_result(res[1], new_ids[0]) - assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) - assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0]) - res = connect.search(collection, query, partition_tags=["new_tag"]) - assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0]) - # TODO: - # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0]) - - # PASS - @pytest.mark.level(2) - def test_search_without_connect(self, dis_connect, collection): - ''' - target: test search vectors without connection - method: use dis connected instance, call search method and check if search successfully - expected: raise exception - ''' - with pytest.raises(Exception) as e: - res = dis_connect.search(collection, default_query) - - # PASS - # TODO: proxy or SDK checks if collection exists - def test_search_collection_name_not_existed(self, connect): - ''' - target: search collection not existed - method: search with the random collection_name, which is not in db - expected: status not ok - ''' - collection_name = gen_unique_str(uid) - with pytest.raises(Exception) as e: - res = connect.search(collection_name, default_query) - - # PASS - def test_search_distance_l2(self, connect, collection): - ''' - target: search collection, and check the result: distance - method: compare the return distance value with value computed with Euclidean - expected: the return distance equals to the computed value - ''' - nq = 2 - search_param = {"nprobe": 1} - entities, ids = init_data(connect, collection, nb=nq) - query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, - search_params=search_param) - inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, - search_params=search_param) - distance_0 = l2(vecs[0], inside_vecs[0]) - distance_1 = l2(vecs[0], inside_vecs[1]) - res = connect.search(collection, query) - assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) - - # Pass - def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index): - ''' - target: search collection, and check the result: distance - method: compare the return distance 
value with value computed with L2 - expected: the return distance equals to the computed value - ''' - index_type = get_simple_index["index_type"] - nq = 2 - entities, ids = init_data(connect, id_collection, auto_id=False) - connect.create_index(id_collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, - search_params=search_param) - inside_vecs = entities[-1]["values"] - min_distance = 1.0 - min_id = None - for i in range(default_nb): - tmp_dis = l2(vecs[0], inside_vecs[i]) - if min_distance > tmp_dis: - min_distance = tmp_dis - min_id = ids[i] - res = connect.search(id_collection, query) - tmp_epsilon = epsilon - check_id_result(res[0], min_id) - # if index_type in ["ANNOY", "IVF_PQ"]: - # tmp_epsilon = 0.1 - # TODO: - # assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon - - # DOG: TODO REDUCE - # TODO: reopen after we supporting ip flat - @pytest.mark.skip("search_distance_ip") - @pytest.mark.level(2) - def test_search_distance_ip(self, connect, collection): - ''' - target: search collection, and check the result: distance - method: compare the return distance value with value computed with Inner product - expected: the return distance equals to the computed value - ''' - nq = 2 - metric_type = "IP" - search_param = {"nprobe": 1} - entities, ids = init_data(connect, collection, nb=nq) - query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, - metric_type=metric_type, - search_params=search_param) - inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, - search_params=search_param) - distance_0 = ip(vecs[0], inside_vecs[0]) - distance_1 = ip(vecs[0], inside_vecs[1]) - res = connect.search(collection, query) - assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon - - # Pass - def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index): - ''' - target: search collection, and check the result: distance - method: compare the return distance value with value computed with Inner product - expected: the return distance equals to the computed value - ''' - index_type = get_simple_index["index_type"] - nq = 2 - metric_type = "IP" - entities, ids = init_data(connect, id_collection, auto_id=False) - get_simple_index["metric_type"] = metric_type - connect.create_index(id_collection, field_name, get_simple_index) - search_param = get_search_param(index_type) - query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, - metric_type=metric_type, - search_params=search_param) - inside_vecs = entities[-1]["values"] - max_distance = 0 - max_id = None - for i in range(default_nb): - tmp_dis = ip(vecs[0], inside_vecs[i]) - if max_distance < tmp_dis: - max_distance = tmp_dis - max_id = ids[i] - res = connect.search(id_collection, query) - tmp_epsilon = epsilon - check_id_result(res[0], max_id) - # if index_type in ["ANNOY", "IVF_PQ"]: - # tmp_epsilon = 0.1 - # TODO: - # assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
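The binary-metric cases below validate returned distances against locally computed references (jaccard, hamming, tanimoto from utils.py, not shown in this diff). Plausible stand-in implementations for uint8-packed binary vectors, for illustration only:

import numpy as np

# Illustrative stand-ins for the reference metrics in tests/python/utils.py.
def hamming(x, y):
    return np.count_nonzero(np.unpackbits(np.bitwise_xor(x, y)))

def jaccard(x, y):
    inter = np.count_nonzero(np.unpackbits(np.bitwise_and(x, y)))
    union = np.count_nonzero(np.unpackbits(np.bitwise_or(x, y)))
    return 1.0 - inter / union

def tanimoto(x, y):
    # for bit vectors the Tanimoto coefficient coincides with Jaccard similarity
    return jaccard(x, y)

x = np.packbits(np.random.randint(0, 2, 128).astype(np.uint8))
y = np.packbits(np.random.randint(0, 2, 128).astype(np.uint8))
assert 0.0 <= jaccard(x, y) <= 1.0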
 - - # PASS - @pytest.mark.skip("r0.3-test") - def test_search_distance_jaccard_flat_index(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: compare the return distance value with value computed with Jaccard - expected: the return distance equals to the computed value - ''' - nq = 1 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) - distance_0 = jaccard(query_int_vectors[0], int_vectors[0]) - distance_1 = jaccard(query_int_vectors[0], int_vectors[1]) - query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="JACCARD") - res = connect.search(binary_collection, query) - assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon - - # DOG: TODO INVALID TYPE - @pytest.mark.skip("search_distance_jaccard_flat_index_L2") - @pytest.mark.level(2) - def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: search binary vectors with the mismatched L2 metric type - expected: throw error of mismatched metric type - ''' - nq = 1 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) - distance_0 = jaccard(query_int_vectors[0], int_vectors[0]) - distance_1 = jaccard(query_int_vectors[0], int_vectors[1]) - query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="L2") - with pytest.raises(Exception) as e: - res = connect.search(binary_collection, query) - - # PASS - @pytest.mark.skip("r0.3-test") - @pytest.mark.level(2) - def test_search_distance_hamming_flat_index(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: compare the return distance value with value computed with Hamming - expected: the return distance equals to the computed value - ''' - nq = 1 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) - distance_0 = hamming(query_int_vectors[0], int_vectors[0]) - distance_1 = hamming(query_int_vectors[0], int_vectors[1]) - query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="HAMMING") - res = connect.search(binary_collection, query) - assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon - - # PASS - @pytest.mark.level(2) - def test_search_distance_substructure_flat_index(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: search with a random query under the SUBSTRUCTURE metric - expected: search returns empty result - ''' - nq = 1 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) - distance_0 = substructure(query_int_vectors[0], int_vectors[0]) - distance_1 = substructure(query_int_vectors[0], int_vectors[1]) - query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, - metric_type="SUBSTRUCTURE") - res = connect.search(binary_collection, query) - assert len(res[0]) == 0 - - # PASS - @pytest.mark.level(2) - def test_search_distance_substructure_flat_index_B(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: compare the return distance value with value computed with SUB - expected: the return 
distance equals to the computed value - ''' - top_k = 3 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2) - query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUBSTRUCTURE", - replace_vecs=query_vecs) - res = connect.search(binary_collection, query) - assert res[0][0].distance <= epsilon - assert res[0][0].id == ids[0] - assert res[1][0].distance <= epsilon - assert res[1][0].id == ids[1] - - # PASS - @pytest.mark.level(2) - def test_search_distance_superstructure_flat_index(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: search with a random query under the SUPERSTRUCTURE metric - expected: search returns empty result - ''' - nq = 1 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) - distance_0 = superstructure(query_int_vectors[0], int_vectors[0]) - distance_1 = superstructure(query_int_vectors[0], int_vectors[1]) - query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, - metric_type="SUPERSTRUCTURE") - res = connect.search(binary_collection, query) - assert len(res[0]) == 0 - - # PASS - @pytest.mark.level(2) - def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: compare the return distance value with value computed with SUPER - expected: the return distance equals to the computed value - ''' - top_k = 3 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2) - query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUPERSTRUCTURE", - replace_vecs=query_vecs) - res = connect.search(binary_collection, query) - assert len(res[0]) == 2 - assert len(res[1]) == 2 - assert res[0][0].id in ids - assert res[0][0].distance <= epsilon - assert res[1][0].id in ids - assert res[1][0].distance <= epsilon - - # PASS - @pytest.mark.level(2) - def test_search_distance_tanimoto_flat_index(self, connect, binary_collection): - ''' - target: search binary_collection, and check the result: distance - method: compare the return distance value with value computed with Tanimoto - expected: the return distance equals to the computed value - ''' - nq = 1 - int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) - query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) - distance_0 = tanimoto(query_int_vectors[0], int_vectors[0]) - distance_1 = tanimoto(query_int_vectors[0], int_vectors[1]) - query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="TANIMOTO") - res = connect.search(binary_collection, query) - assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon
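The empty results asserted in the SUBSTRUCTURE/SUPERSTRUCTURE cases above follow from the containment semantics these tests assume for those metrics; a small illustration of that assumption:

import numpy as np

# Assumed semantics: a query q matches a stored vector v under SUBSTRUCTURE
# only if every bit set in q is also set in v, so a random query returns
# nothing; SUPERSTRUCTURE is the reverse direction.
def is_substructure(q, v):
    return np.array_equal(np.bitwise_and(q, v), q)

q = np.array([0b1010], dtype=np.uint8)
v = np.array([0b1110], dtype=np.uint8)
assert is_substructure(q, v)      # q's bits are contained in v
assert not is_substructure(v, q)  # SUPERSTRUCTURE would match this direction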
 - - # PASS - @pytest.mark.skip("r0.3-test") - @pytest.mark.level(2) - @pytest.mark.timeout(30) - def test_search_concurrent_multithreads(self, connect, args): - ''' - target: test concurrent search with multiple threads - method: search with multiple threads, each thread uses an independent connection - expected: status ok and the returned vectors should be query_records - ''' - nb = 100 - top_k = 10 - threads_num = 4 - threads = [] - collection = gen_unique_str(uid) - uri = "tcp://%s:%s" % (args["ip"], args["port"]) - # create collection - milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) - milvus.create_collection(collection, default_fields) - entities, ids = init_data(milvus, collection) - - def search(milvus): - res = milvus.search(collection, default_query) - assert len(res) == 1 - assert res[0]._entities[0].id in ids - assert res[0]._distances[0] < epsilon - - for i in range(threads_num): - milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) - t = MilvusTestThread(target=search, args=(milvus,)) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - # PASS - @pytest.mark.skip("r0.3-test") - @pytest.mark.level(2) - @pytest.mark.timeout(30) - def test_search_concurrent_multithreads_single_connection(self, connect, args): - ''' - target: test concurrent search with multiple threads sharing a single connection - method: search with multiple threads, all sharing one connection - expected: status ok and the returned vectors should be query_records - ''' - nb = 100 - top_k = 10 - threads_num = 4 - threads = [] - collection = gen_unique_str(uid) - uri = "tcp://%s:%s" % (args["ip"], args["port"]) - # create collection - milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) - milvus.create_collection(collection, default_fields) - entities, ids = init_data(milvus, collection) - - def search(milvus): - res = milvus.search(collection, default_query) - assert len(res) == 1 - assert res[0]._entities[0].id in ids - assert res[0]._distances[0] < epsilon - - for i in range(threads_num): - t = MilvusTestThread(target=search, args=(milvus,)) - threads.append(t) - t.start() - time.sleep(0.2) - for t in threads: - t.join() - - # PASS - @pytest.mark.skip("r0.3-test") - @pytest.mark.level(2) - def test_search_multi_collections(self, connect, args): - ''' - target: test search multi collections of L2 - method: add vectors into 10 collections, and search - expected: search status ok, the length of result is correct - ''' - num = 10 - top_k = 10 - nq = 20 - for i in range(num): - collection = gen_unique_str(uid + str(i)) - connect.create_collection(collection, default_fields) - entities, ids = init_data(connect, collection) - assert len(ids) == default_nb - query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param) - res = connect.search(collection, query) - assert len(res) == nq - for i in range(nq): - assert check_id_result(res[i], ids[i]) - assert res[i]._distances[0] < epsilon - assert res[i]._distances[1] > epsilon - - @pytest.mark.skip("test_query_entities_with_field_less_than_top_k") - def test_query_entities_with_field_less_than_top_k(self, connect, id_collection): - """ - target: test search with field, and let return entities less than topk - method: insert entities and build ivf_ index, and search with field, n_probe=1 - expected: - """ - entities, ids = init_data(connect, id_collection, auto_id=False) - simple_index = {"index_type": "IVF_FLAT", "params": {"nlist": 200}, "metric_type": "L2"} - connect.create_index(id_collection, field_name, simple_index) - # logging.getLogger().info(connect.get_collection_info(id_collection)) - top_k = 300 - default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 1}) - expr = {"must": [gen_default_vector_expr(default_query)]} - query = update_query_expr(default_query, expr=expr) - res = 
connect.search(id_collection, query, fields=["int64"]) - assert len(res) == nq - for r in res[0]: - assert getattr(r.entity, "int64") == getattr(r.entity, "id") - - -class TestSearchDSL(object): - """ - ****************************************************************** - # The following cases are used to build invalid query expr - ****************************************************************** - """ - - # PASS - def test_query_no_must(self, connect, collection): - ''' - method: build query without must expr - expected: error raised - ''' - # entities, ids = init_data(connect, collection) - query = update_query_expr(default_query, keep_old=False) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_no_vector_term_only(self, connect, collection): - ''' - method: build query without vector only term - expected: error raised - ''' - # entities, ids = init_data(connect, collection) - expr = { - "must": [gen_default_term_expr] - } - query = update_query_expr(default_query, keep_old=False, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_no_vector_range_only(self, connect, collection): - ''' - method: build query without vector only range - expected: error raised - ''' - # entities, ids = init_data(connect, collection) - expr = { - "must": [gen_default_range_expr] - } - query = update_query_expr(default_query, keep_old=False, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_vector_only(self, connect, collection): - entities, ids = init_data(connect, collection) - res = connect.search(collection, default_query) - assert len(res) == nq - assert len(res[0]) == default_top_k - - # PASS - def test_query_wrong_format(self, connect, collection): - ''' - method: build query without must expr, with wrong expr name - expected: error raised - ''' - # entities, ids = init_data(connect, collection) - expr = { - "must1": [gen_default_term_expr] - } - query = update_query_expr(default_query, keep_old=False, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_empty(self, connect, collection): - ''' - method: search with empty query - expected: error raised - ''' - query = {} - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - """ - ****************************************************************** - # The following cases are used to build valid query expr - ****************************************************************** - """ - - # PASS - @pytest.mark.level(2) - def test_query_term_value_not_in(self, connect, collection): - ''' - method: build query with vector and term expr, with no term can be filtered - expected: filter pass - ''' - entities, ids = init_data(connect, collection) - expr = { - "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - # TODO: - - # PASS - @pytest.mark.level(2) - def test_query_term_value_all_in(self, connect, collection): - ''' - method: build query with vector and term expr, with all term can be filtered - expected: filter pass - ''' - entities, ids = init_data(connect, collection) - expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]} - query = 
update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 1 - # TODO: - - # PASS - @pytest.mark.level(2) - def test_query_term_values_not_in(self, connect, collection): - ''' - method: build query with vector and term expr, with no term can be filtered - expected: filter pass - ''' - entities, ids = init_data(connect, collection) - expr = {"must": [gen_default_vector_expr(default_query), - gen_default_term_expr(values=[i for i in range(100000, 100010)])]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - # TODO: - - # PASS - def test_query_term_values_all_in(self, connect, collection): - ''' - method: build query with vector and term expr, with all term can be filtered - expected: filter pass - ''' - entities, ids = init_data(connect, collection) - expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - limit = default_nb // 2 - for i in range(nq): - for result in res[i]: - logging.getLogger().info(result.id) - assert result.id in ids[:limit] - # TODO: - - # PASS - def test_query_term_values_parts_in(self, connect, collection): - ''' - method: build query with vector and term expr, with parts of term can be filtered - expected: filter pass - ''' - entities, ids = init_data(connect, collection) - expr = {"must": [gen_default_vector_expr(default_query), - gen_default_term_expr( - values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - # TODO: - - # PASS - @pytest.mark.level(2) - def test_query_term_values_repeat(self, connect, collection): - ''' - method: build query with vector and term expr, with the same values - expected: filter pass - ''' - entities, ids = init_data(connect, collection) - expr = { - "must": [gen_default_vector_expr(default_query), - gen_default_term_expr(values=[1 for i in range(1, default_nb)])]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 1 - # TODO: - - # DOG: BUG, please fix - @pytest.mark.skip("query_term_value_empty") - def test_query_term_value_empty(self, connect, collection): - ''' - method: build query with term value empty - expected: return null - ''' - expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - - # PASS - def test_query_complex_dsl(self, connect, collection): - ''' - method: query with complicated dsl - expected: no error raised - ''' - expr = {"must": [ - {"must": [{"should": [gen_default_term_expr(values=[1]), gen_default_range_expr()]}]}, - {"must": [gen_default_vector_expr(default_query)]} - ]} - logging.getLogger().info(expr) - query = update_query_expr(default_query, expr=expr) - logging.getLogger().info(query) - res = connect.search(collection, query) - logging.getLogger().info(res) - - """ - ****************************************************************** - # The following cases are used to build invalid term query expr - 
****************************************************************** - """ - - # PASS - @pytest.mark.level(2) - def test_query_term_key_error(self, connect, collection): - ''' - method: build query with a misspelled term keyword - expected: Exception raised - ''' - expr = {"must": [gen_default_vector_expr(default_query), - gen_default_term_expr(keyword="terrm", values=[i for i in range(default_nb // 2)])]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - @pytest.fixture( - scope="function", - params=gen_invalid_term() - ) - def get_invalid_term(self, request): - return request.param - - # PASS - @pytest.mark.level(2) - def test_query_term_wrong_format(self, connect, collection, get_invalid_term): - ''' - method: build query with a wrongly formatted term expr - expected: Exception raised - ''' - entities, ids = init_data(connect, collection) - term = get_invalid_term - expr = {"must": [gen_default_vector_expr(default_query), term]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # FIXME: PLEASE IMPLEMENT connect.count_entities - # TODO - @pytest.mark.skip("query_term_field_named_term") - @pytest.mark.level(2) - def test_query_term_field_named_term(self, connect, collection): - ''' - method: build query with a field named "term" - expected: search passes - ''' - term_fields = add_field_default(default_fields, field_name="term") - collection_term = gen_unique_str("term") - connect.create_collection(collection_term, term_fields) - term_entities = add_field(entities, field_name="term") - ids = connect.insert(collection_term, term_entities) - assert len(ids) == default_nb - connect.flush([collection_term]) - count = connect.count_entities(collection_term) # count_entities is not implemented - assert count == default_nb # if these two lines are removed, this test passes - term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}} - expr = {"must": [gen_default_vector_expr(default_query), - term_param]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection_term, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - connect.drop_collection(collection_term) - - # PASS - @pytest.mark.level(2) - def test_query_term_one_field_not_existed(self, connect, collection): - ''' - method: build query with a term expr on two fields, one of which does not exist - expected: exception raised - ''' - entities, ids = init_data(connect, collection) - term = gen_default_term_expr() - term["term"].update({"a": [0]}) - expr = {"must": [gen_default_vector_expr(default_query), term]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - """ - ****************************************************************** - # The following cases are used to build invalid and valid range query expr - ****************************************************************** - """ - - # PASS - def test_query_range_key_error(self, connect, collection): - ''' - method: build query with a misspelled range keyword - expected: Exception raised - ''' - range = gen_default_range_expr(keyword="ranges") - expr = {"must": [gen_default_vector_expr(default_query), range]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - @pytest.fixture( - scope="function", - params=gen_invalid_range() - ) - def 
get_invalid_range(self, request): - return request.param - - # PASS - @pytest.mark.level(2) - def test_query_range_wrong_format(self, connect, collection, get_invalid_range): - ''' - method: build query with a wrongly formatted range expr - expected: Exception raised - ''' - entities, ids = init_data(connect, collection) - range = get_invalid_range - expr = {"must": [gen_default_vector_expr(default_query), range]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - @pytest.mark.level(2) - def test_query_range_string_ranges(self, connect, collection): - ''' - method: build query with string range bounds - expected: raise Exception - ''' - entities, ids = init_data(connect, collection) - ranges = {"GT": "0", "LT": "1000"} - range = gen_default_range_expr(ranges=ranges) - expr = {"must": [gen_default_vector_expr(default_query), range]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - @pytest.mark.level(2) - def test_query_range_invalid_ranges(self, connect, collection): - ''' - method: build query with an inverted range (GT above LT) - expected: exception raised - ''' - entities, ids = init_data(connect, collection) - ranges = {"GT": default_nb, "LT": 0} - range = gen_default_range_expr(ranges=ranges) - expr = {"must": [gen_default_vector_expr(default_query), range]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception): - res = connect.search(collection, query) - - @pytest.fixture( - scope="function", - params=gen_valid_ranges() - ) - def get_valid_ranges(self, request): - return request.param - - # PASS - @pytest.mark.level(2) - def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges): - ''' - method: build query with valid ranges - expected: pass - ''' - entities, ids = init_data(connect, collection) - ranges = get_valid_ranges - range = gen_default_range_expr(ranges=ranges) - expr = {"must": [gen_default_vector_expr(default_query), range]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - - # PASS - def test_query_range_one_field_not_existed(self, connect, collection): - ''' - method: build query with a range expr on two fields, one of which does not exist - expected: exception raised - ''' - entities, ids = init_data(connect, collection) - range = gen_default_range_expr() - range["range"].update({"a": {"GT": 1, "LT": default_nb // 2}}) - expr = {"must": [gen_default_vector_expr(default_query), range]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - """ - ************************************************************************ - # The following cases are used to build query expr multi range and term - ************************************************************************ - """ - - # PASS - def test_query_multi_term_has_common(self, connect, collection): - ''' - method: build query with multi term exprs on the same field, with overlapping values - expected: pass - ''' - entities, ids = init_data(connect, collection) - term_first = gen_default_term_expr() - term_second = gen_default_term_expr(values=[i for i in range(default_nb // 3)]) - expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} - query = update_query_expr(default_query, expr=expr) - res = 
connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - - # PASS - @pytest.mark.level(2) - def test_query_multi_term_no_common(self, connect, collection): - ''' - method: build query with multi term exprs on the same field, with no values in common - expected: pass - ''' - entities, ids = init_data(connect, collection) - term_first = gen_default_term_expr() - term_second = gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)]) - expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - - # PASS - def test_query_multi_term_different_fields(self, connect, collection): - ''' - method: build query with multi term exprs on different fields, with no values in common - expected: pass - ''' - entities, ids = init_data(connect, collection) - term_first = gen_default_term_expr() - term_second = gen_default_term_expr(field="float", - values=[float(i) for i in range(default_nb // 2, default_nb)]) - expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - - # PASS - @pytest.mark.level(2) - def test_query_single_term_multi_fields(self, connect, collection): - ''' - method: build query with a single term expr containing multiple fields - expected: exception raised - ''' - entities, ids = init_data(connect, collection) - term_first = {"int64": {"values": [i for i in range(default_nb // 2)]}} - term_second = {"float": {"values": [float(i) for i in range(default_nb // 2, default_nb)]}} - term = update_term_expr({"term": {}}, [term_first, term_second]) - expr = {"must": [gen_default_vector_expr(default_query), term]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - @pytest.mark.level(2) - def test_query_multi_range_has_common(self, connect, collection): - ''' - method: build query with multi range exprs on the same field, with overlapping ranges - expected: pass - ''' - entities, ids = init_data(connect, collection) - range_one = gen_default_range_expr() - range_two = gen_default_range_expr(ranges={"GT": 1, "LT": default_nb // 3}) - expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - - # PASS - @pytest.mark.level(2) - def test_query_multi_range_no_common(self, connect, collection): - ''' - method: build query with multi range exprs on the same field, with no overlap - expected: pass - ''' - entities, ids = init_data(connect, collection) - range_one = gen_default_range_expr() - range_two = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb}) - expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - - # PASS - @pytest.mark.level(2) - def test_query_multi_range_different_fields(self, connect, collection): - ''' - method: build query with multi range exprs, a different field in each range - expected: pass - ''' - entities, ids = init_data(connect, collection) - range_first = 
gen_default_range_expr() - range_second = gen_default_range_expr(field="float", ranges={"GT": default_nb // 2, "LT": default_nb}) - expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - - # PASS - @pytest.mark.level(2) - def test_query_single_range_multi_fields(self, connect, collection): - ''' - method: build query with a single range expr containing multiple fields - expected: exception raised - ''' - entities, ids = init_data(connect, collection) - range_first = {"int64": {"GT": 0, "LT": default_nb // 2}} - range_second = {"float": {"GT": default_nb / 2, "LT": float(default_nb)}} - range = update_range_expr({"range": {}}, [range_first, range_second]) - expr = {"must": [gen_default_vector_expr(default_query), range]} - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - """ - ****************************************************************** - # The following cases are used to build query expr both term and range - ****************************************************************** - """ - - # PASS - @pytest.mark.level(2) - def test_query_single_term_range_has_common(self, connect, collection): - ''' - method: build query with a single term and a single range expr that overlap - expected: pass - ''' - entities, ids = init_data(connect, collection) - term = gen_default_term_expr() - range = gen_default_range_expr(ranges={"GT": -1, "LT": default_nb // 2}) - expr = {"must": [gen_default_vector_expr(default_query), term, range]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == default_top_k - - # PASS - def test_query_single_term_range_no_common(self, connect, collection): - ''' - method: build query with a single term and a single range expr that do not overlap - expected: pass - ''' - entities, ids = init_data(connect, collection) - term = gen_default_term_expr() - range = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb}) - expr = {"must": [gen_default_vector_expr(default_query), term, range]} - query = update_query_expr(default_query, expr=expr) - res = connect.search(collection, query) - assert len(res) == nq - assert len(res[0]) == 0 - - """ - ****************************************************************** - # The following cases are used to build multi vectors query expr - ****************************************************************** - """ - - # PASS - def test_query_multi_vectors_same_field(self, connect, collection): - ''' - method: build query with two vector exprs on the same field - expected: error raised - ''' - entities, ids = init_data(connect, collection) - vector1 = default_query - vector2 = gen_query_vectors(field_name, entities, default_top_k, nq=2) - expr = { - "must": [vector1, vector2] - } - query = update_query_expr(default_query, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - -class TestSearchDSLBools(object): - """ - ****************************************************************** - # The following cases are used to build invalid query expr - ****************************************************************** - """ - - # PASS - @pytest.mark.level(2) - def test_query_no_bool(self, connect, collection): - ''' - method: build query without bool expr - expected: error raised - ''' - entities, ids = init_data(connect, collection) - 
expr = {"bool1": {}} - query = expr - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_should_only_term(self, connect, collection): - ''' - method: build query without must, with should.term instead - expected: error raised - ''' - expr = {"should": gen_default_term_expr} - query = update_query_expr(default_query, keep_old=False, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_should_only_vector(self, connect, collection): - ''' - method: build query without must, with should.vector instead - expected: error raised - ''' - expr = {"should": default_query["bool"]["must"]} - query = update_query_expr(default_query, keep_old=False, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_must_not_only_term(self, connect, collection): - ''' - method: build query without must, with must_not.term instead - expected: error raised - ''' - expr = {"must_not": gen_default_term_expr} - query = update_query_expr(default_query, keep_old=False, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_must_not_vector(self, connect, collection): - ''' - method: build query without must, with must_not.vector instead - expected: error raised - ''' - expr = {"must_not": default_query["bool"]["must"]} - query = update_query_expr(default_query, keep_old=False, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # PASS - def test_query_must_should(self, connect, collection): - ''' - method: build query must, and with should.term - expected: error raised - ''' - expr = {"should": gen_default_term_expr} - query = update_query_expr(default_query, keep_old=True, expr=expr) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - -""" -****************************************************************** -# The following cases are used to test `search` function -# with invalid collection_name, or invalid query expr -****************************************************************** -""" - - -class TestSearchInvalid(object): - """ - Test search collection with invalid collection names - """ - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_invalid_tag(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_invalid_strs() - ) - def get_invalid_field(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=gen_simple_index() - ) - def get_simple_index(self, request, connect): - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return request.param - - # PASS - @pytest.mark.level(2) - def test_search_with_invalid_collection(self, connect, get_collection_name): - collection_name = get_collection_name - with pytest.raises(Exception) as e: - res = connect.search(collection_name, default_query) - - # PASS - # TODO(yukun) - @pytest.mark.level(2) - def test_search_with_invalid_tag(self, connect, collection): - tag = " " - with pytest.raises(Exception) as e: - res = connect.search(collection, default_query, partition_tags=tag) - - # TODO: reopen after we supporting targetEntry - 
@pytest.mark.skip("search_with_invalid_field_name") - @pytest.mark.level(2) - def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field): - fields = [get_invalid_field] - with pytest.raises(Exception) as e: - res = connect.search(collection, default_query, fields=fields) - - # TODO: reopen after we supporting targetEntry - @pytest.mark.skip("search_with_not_existed_field_name") - @pytest.mark.level(1) - def test_search_with_not_existed_field_name(self, connect, collection): - fields = [gen_unique_str("field_name")] - with pytest.raises(Exception) as e: - res = connect.search(collection, default_query, fields=fields) - - """ - Test search collection with invalid query - """ - - @pytest.fixture( - scope="function", - params=gen_invalid_ints() - ) - def get_top_k(self, request): - yield request.param - - @pytest.mark.level(1) - def test_search_with_invalid_top_k(self, connect, collection, get_top_k): - ''' - target: test search function, with the wrong top_k - method: search with top_k - expected: raise an error, and the connection is normal - ''' - top_k = get_top_k - default_query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k - with pytest.raises(Exception) as e: - res = connect.search(collection, default_query) - - """ - Test search collection with invalid search params - """ - - @pytest.fixture( - scope="function", - params=gen_invaild_search_params() - ) - def get_search_params(self, request): - yield request.param - - # Pass - @pytest.mark.level(2) - def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params): - ''' - target: test search function, with the wrong nprobe - method: search with nprobe - expected: raise an error, and the connection is normal - ''' - search_params = get_search_params - index_type = get_simple_index["index_type"] - if index_type in ["FLAT"]: - pytest.skip("skip in FLAT index") - if index_type != search_params["index_type"]: - pytest.skip("skip if index_type not matched") - entities, ids = init_data(connect, collection) - connect.create_index(collection, field_name, get_simple_index) - query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, - search_params=search_params["search_params"]) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - # pass - @pytest.mark.level(2) - def test_search_with_invalid_params_binary(self, connect, binary_collection): - ''' - target: test search function, with the wrong nprobe - method: search with nprobe - expected: raise an error, and the connection is normal - ''' - nq = 1 - index_type = "BIN_IVF_FLAT" - int_vectors, entities, ids = init_binary_data(connect, binary_collection) - query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False) - connect.create_index(binary_collection, binary_field_name, - {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 128}}) - query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, - search_params={"nprobe": 0}, metric_type="JACCARD") - with pytest.raises(Exception) as e: - res = connect.search(binary_collection, query) - - # Pass - @pytest.mark.level(2) - def test_search_with_empty_params(self, connect, collection, args, get_simple_index): - ''' - target: test search function, with empty search params - method: search with params - expected: raise an error, and the connection is normal - ''' - index_type = get_simple_index["index_type"] - if args["handler"] == "HTTP": - 
pytest.skip("skip in http mode") - if index_type == "FLAT": - pytest.skip("skip in FLAT index") - entities, ids = init_data(connect, collection) - connect.create_index(collection, field_name, get_simple_index) - query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params={}) - with pytest.raises(Exception) as e: - res = connect.search(collection, query) - - -def check_id_result(result, id): - limit_in = 5 - ids = [entity.id for entity in result] - if len(result) >= limit_in: - return id in ids[:limit_in] - else: - return id in ids diff --git a/tests/python/utils.py b/tests/python/utils.py deleted file mode 100644 index 2d5f72507d..0000000000 --- a/tests/python/utils.py +++ /dev/null @@ -1,1004 +0,0 @@ -import grpc - -import os -import sys -import random -import pdb -import string -import struct -import logging -import threading -import time -import copy -import numpy as np -from sklearn import preprocessing -from milvus import Milvus, DataType - -port = 19530 -epsilon = 0.000001 -namespace = "milvus" - -default_flush_interval = 1 -big_flush_interval = 1000 -default_drop_interval = 3 -default_dim = 128 -default_nb = 1200 -default_top_k = 10 -max_top_k = 16384 -max_partition_num = 4096 -default_segment_row_limit = 1000 -default_server_segment_row_limit = 1024 * 512 -default_float_vec_field_name = "float_vector" -default_binary_vec_field_name = "binary_vector" -default_partition_name = "_default" -default_tag = "1970_01_01" - -# TODO: -# TODO: disable RHNSW_SQ/PQ in 0.11.0 -all_index_types = [ - "FLAT", - "IVF_FLAT", - "IVF_SQ8", - "IVF_SQ8_HYBRID", - "IVF_PQ", - "HNSW", - # "NSG", - "ANNOY", - "RHNSW_PQ", - "RHNSW_SQ", - "BIN_FLAT", - "BIN_IVF_FLAT" -] - -default_index_params = [ - {"nlist": 128}, - {"nlist": 128}, - {"nlist": 128}, - {"nlist": 128}, - {"nlist": 128, "m": 16, "nbits": 8}, - {"M": 48, "efConstruction": 500}, - # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50}, - {"n_trees": 50}, - {"M": 48, "efConstruction": 500, "PQM": 64}, - {"M": 48, "efConstruction": 500}, - {"nlist": 128}, - {"nlist": 128} -] - - -def index_cpu_not_support(): - return ["IVF_SQ8_HYBRID"] - - -def binary_support(): - return ["BIN_FLAT", "BIN_IVF_FLAT"] - - -def delete_support(): - return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"] - - -def ivf(): - return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"] - - -def skip_pq(): - return ["IVF_PQ", "RHNSW_PQ", "RHNSW_SQ"] - - -def binary_metrics(): - return ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"] - - -def structure_metrics(): - return ["SUBSTRUCTURE", "SUPERSTRUCTURE"] - - -def l2(x, y): - return np.linalg.norm(np.array(x) - np.array(y)) - - -def ip(x, y): - return np.inner(np.array(x), np.array(y)) - - -def jaccard(x, y): - x = np.asarray(x, np.bool) - y = np.asarray(y, np.bool) - return 1 - np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum()) - - -def hamming(x, y): - x = np.asarray(x, np.bool) - y = np.asarray(y, np.bool) - return np.bitwise_xor(x, y).sum() - - -def tanimoto(x, y): - x = np.asarray(x, np.bool) - y = np.asarray(y, np.bool) - return -np.log2(np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum())) - - -def substructure(x, y): - x = np.asarray(x, np.bool) - y = np.asarray(y, np.bool) - return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(y) - - -def superstructure(x, y): - x = np.asarray(x, np.bool) - y = np.asarray(y, np.bool) - return 1 - np.double(np.bitwise_and(x, 
y).sum()) / np.count_nonzero(x) - - -def get_milvus(host, port, uri=None, handler=None, **kwargs): - if handler is None: - handler = "GRPC" - try_connect = kwargs.get("try_connect", True) - if uri is not None: - milvus = Milvus(uri=uri, handler=handler, try_connect=try_connect) - else: - milvus = Milvus(host=host, port=port, handler=handler, try_connect=try_connect) - return milvus - - -def reset_build_index_threshold(connect): - connect.set_config("engine", "build_index_threshold", 1024) - - -def disable_flush(connect): - connect.set_config("storage", "auto_flush_interval", big_flush_interval) - - -def enable_flush(connect): - # reset auto_flush_interval=1 - connect.set_config("storage", "auto_flush_interval", default_flush_interval) - config_value = connect.get_config("storage", "auto_flush_interval") - assert config_value == str(default_flush_interval) - - -def gen_inaccuracy(num): - return num / 255.0 - - -def gen_vectors(num, dim, is_normal=True): - vectors = [[random.random() for _ in range(dim)] for _ in range(num)] - vectors = preprocessing.normalize(vectors, axis=1, norm='l2') - return vectors.tolist() - - -# def gen_vectors(num, dim, seed=np.random.RandomState(1234), is_normal=False): -# xb = seed.rand(num, dim).astype("float32") -# xb = preprocessing.normalize(xb, axis=1, norm='l2') -# return xb.tolist() - - -def gen_binary_vectors(num, dim): - raw_vectors = [] - binary_vectors = [] - for i in range(num): - raw_vector = [random.randint(0, 1) for i in range(dim)] - raw_vectors.append(raw_vector) - binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) - return raw_vectors, binary_vectors - - -def gen_binary_sub_vectors(vectors, length): - raw_vectors = [] - binary_vectors = [] - dim = len(vectors[0]) - for i in range(length): - raw_vector = [0 for i in range(dim)] - vector = vectors[i] - for index, j in enumerate(vector): - if j == 1: - raw_vector[index] = 1 - raw_vectors.append(raw_vector) - binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) - return raw_vectors, binary_vectors - - -def gen_binary_super_vectors(vectors, length): - raw_vectors = [] - binary_vectors = [] - dim = len(vectors[0]) - for i in range(length): - cnt_1 = np.count_nonzero(vectors[i]) - raw_vector = [1 for i in range(dim)] - raw_vectors.append(raw_vector) - binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) - return raw_vectors, binary_vectors - - -def gen_int_attr(row_num): - return [random.randint(0, 255) for _ in range(row_num)] - - -def gen_float_attr(row_num): - return [random.uniform(0, 255) for _ in range(row_num)] - - -def gen_unique_str(str_value=None): - prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) - return "test_" + prefix if str_value is None else str_value + "_" + prefix - - -def gen_single_filter_fields(): - fields = [] - for data_type in DataType: - if data_type in [DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE]: - fields.append({"name": data_type.name, "type": data_type}) - return fields - - -def gen_single_vector_fields(): - fields = [] - for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: - field = {"name": data_type.name, "type": data_type, "params": {"dim": default_dim}} - fields.append(field) - return fields - - -def gen_default_fields(auto_id=True): - default_fields = { - "fields": [ - {"name": "int64", "type": DataType.INT64}, - {"name": "float", "type": DataType.FLOAT}, - {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, - "params": 
{"dim": default_dim}}, - ], - "segment_row_limit": default_segment_row_limit, - "auto_id": auto_id - } - return default_fields - - -def gen_binary_default_fields(auto_id=True): - default_fields = { - "fields": [ - {"name": "int64", "type": DataType.INT64}, - {"name": "float", "type": DataType.FLOAT}, - {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": default_dim}} - ], - "segment_row_limit": default_segment_row_limit, - "auto_id": auto_id - } - return default_fields - - -def gen_entities(nb, is_normal=False): - vectors = gen_vectors(nb, default_dim, is_normal) - entities = [ - {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]}, - {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]}, - {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors} - ] - return entities - - -def gen_entities_new(nb, is_normal=False): - vectors = gen_vectors(nb, default_dim, is_normal) - entities = [ - {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]}, - {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]}, - {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors} - ] - return entities - - -def gen_entities_rows(nb, is_normal=False, _id=True): - vectors = gen_vectors(nb, default_dim, is_normal) - entities = [] - if not _id: - for i in range(nb): - entity = { - "_id": i, - "int64": i, - "float": float(i), - default_float_vec_field_name: vectors[i] - } - entities.append(entity) - else: - for i in range(nb): - entity = { - "int64": i, - "float": float(i), - default_float_vec_field_name: vectors[i] - } - entities.append(entity) - return entities - - -def gen_binary_entities(nb): - raw_vectors, vectors = gen_binary_vectors(nb, default_dim) - entities = [ - {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]}, - {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]}, - {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "values": vectors} - ] - return raw_vectors, entities - - -def gen_binary_entities_new(nb): - raw_vectors, vectors = gen_binary_vectors(nb, default_dim) - entities = [ - {"name": "int64", "values": [i for i in range(nb)]}, - {"name": "float", "values": [float(i) for i in range(nb)]}, - {"name": default_binary_vec_field_name, "values": vectors} - ] - return raw_vectors, entities - - -def gen_binary_entities_rows(nb, _id=True): - raw_vectors, vectors = gen_binary_vectors(nb, default_dim) - entities = [] - if not _id: - for i in range(nb): - entity = { - "_id": i, - "int64": i, - "float": float(i), - default_binary_vec_field_name: vectors[i] - } - entities.append(entity) - else: - for i in range(nb): - entity = { - "int64": i, - "float": float(i), - default_binary_vec_field_name: vectors[i] - } - entities.append(entity) - return raw_vectors, entities - - -def gen_entities_by_fields(fields, nb, dim): - entities = [] - for field in fields: - if field["type"] in [DataType.INT32, DataType.INT64]: - field_value = [1 for i in range(nb)] - elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]: - field_value = [3.0 for i in range(nb)] - elif field["type"] == DataType.BINARY_VECTOR: - field_value = gen_binary_vectors(nb, dim)[1] - elif field["type"] == DataType.FLOAT_VECTOR: - field_value = gen_vectors(nb, dim) - field.update({"values": field_value}) - entities.append(field) - return entities - - -def assert_equal_entity(a, b): - 
pass - - -def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False, - metric_type="L2", replace_vecs=None): - if rand_vector is True: - dimension = len(entities[-1]["values"][0]) - query_vectors = gen_vectors(nq, dimension) - else: - query_vectors = entities[-1]["values"][:nq] - if replace_vecs: - query_vectors = replace_vecs - must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}} - must_param["vector"][field_name]["metric_type"] = metric_type - query = { - "bool": { - "must": [must_param] - } - } - return query, query_vectors - - -def update_query_expr(src_query, keep_old=True, expr=None): - tmp_query = copy.deepcopy(src_query) - if expr is not None: - tmp_query["bool"].update(expr) - if keep_old is not True: - tmp_query["bool"].pop("must") - return tmp_query - - -def gen_default_vector_expr(default_query): - return default_query["bool"]["must"][0] - - -def gen_default_term_expr(keyword="term", field="int64", values=None): - if values is None: - values = [i for i in range(default_nb // 2)] - expr = {keyword: {field: {"values": values}}} - return expr - - -def update_term_expr(src_term, terms): - tmp_term = copy.deepcopy(src_term) - for term in terms: - tmp_term["term"].update(term) - return tmp_term - - -def gen_default_range_expr(keyword="range", field="int64", ranges=None): - if ranges is None: - ranges = {"GT": 1, "LT": default_nb // 2} - expr = {keyword: {field: ranges}} - return expr - - -def update_range_expr(src_range, ranges): - tmp_range = copy.deepcopy(src_range) - for range in ranges: - tmp_range["range"].update(range) - return tmp_range - - -def gen_invalid_range(): - range = [ - {"range": 1}, - {"range": {}}, - {"range": []}, - {"range": {"range": {"int64": {"GT": 0, "LT": default_nb // 2}}}} - ] - return range - - -def gen_valid_ranges(): - ranges = [ - {"GT": 0, "LT": default_nb // 2}, - {"GT": default_nb // 2, "LT": default_nb * 2}, - {"GT": 0}, - {"LT": default_nb}, - {"GT": -1, "LT": default_top_k}, - ] - return ranges - - -def gen_invalid_term(): - terms = [ - {"term": 1}, - {"term": []}, - {"term": {}}, - {"term": {"term": {"int64": {"values": [i for i in range(default_nb // 2)]}}}} - ] - return terms - - -def add_field_default(default_fields, type=DataType.INT64, field_name=None): - tmp_fields = copy.deepcopy(default_fields) - if field_name is None: - field_name = gen_unique_str() - field = { - "name": field_name, - "type": type - } - tmp_fields["fields"].append(field) - return tmp_fields - - -def add_field(entities, field_name=None): - nb = len(entities[0]["values"]) - tmp_entities = copy.deepcopy(entities) - if field_name is None: - field_name = gen_unique_str() - field = { - "name": field_name, - "type": DataType.INT64, - "values": [i for i in range(nb)] - } - tmp_entities.append(field) - return tmp_entities - - -def add_vector_field(entities, is_normal=False): - nb = len(entities[0]["values"]) - vectors = gen_vectors(nb, default_dim, is_normal) - field = { - "name": gen_unique_str(), - "type": DataType.FLOAT_VECTOR, - "values": vectors - } - entities.append(field) - return entities - - -# def update_fields_metric_type(fields, metric_type): -# tmp_fields = copy.deepcopy(fields) -# if metric_type in ["L2", "IP"]: -# tmp_fields["fields"][-1]["type"] = DataType.FLOAT_VECTOR -# else: -# tmp_fields["fields"][-1]["type"] = DataType.BINARY_VECTOR -# tmp_fields["fields"][-1]["params"]["metric_type"] = metric_type -# return tmp_fields - - -def remove_field(entities): - del 
entities[0] - return entities - - -def remove_vector_field(entities): - del entities[-1] - return entities - - -def update_field_name(entities, old_name, new_name): - tmp_entities = copy.deepcopy(entities) - for item in tmp_entities: - if item["name"] == old_name: - item["name"] = new_name - return tmp_entities - - -def update_field_type(entities, old_name, new_name): - tmp_entities = copy.deepcopy(entities) - for item in tmp_entities: - if item["name"] == old_name: - item["type"] = new_name - return tmp_entities - - -def update_field_value(entities, old_type, new_value): - tmp_entities = copy.deepcopy(entities) - for item in tmp_entities: - if item["type"] == old_type: - for index, value in enumerate(item["values"]): - item["values"][index] = new_value - return tmp_entities - - -def update_field_name_row(entities, old_name, new_name): - tmp_entities = copy.deepcopy(entities) - for item in tmp_entities: - if old_name in item: - item[new_name] = item[old_name] - item.pop(old_name) - else: - raise Exception("Field %s not in field" % old_name) - return tmp_entities - - -def update_field_type_row(entities, old_name, new_name): - tmp_entities = copy.deepcopy(entities) - for item in tmp_entities: - if old_name in item: - item["type"] = new_name - return tmp_entities - - -def add_vector_field(nb, dimension=default_dim): - field_name = gen_unique_str() - field = { - "name": field_name, - "type": DataType.FLOAT_VECTOR, - "values": gen_vectors(nb, dimension) - } - return field_name - - -def gen_segment_row_limits(): - sizes = [ - 1024, - 4096 - ] - return sizes - - -def gen_invalid_ips(): - ips = [ - # "255.0.0.0", - # "255.255.0.0", - # "255.255.255.0", - # "255.255.255.255", - "127.0.0", - # "123.0.0.2", - "12-s", - " ", - "12 s", - "BB。A", - " siede ", - "(mn)", - "中文", - "a".join("a" for _ in range(256)) - ] - return ips - - -def gen_invalid_uris(): - ip = None - uris = [ - " ", - "中文", - # invalid protocol - # "tc://%s:%s" % (ip, port), - # "tcp%s:%s" % (ip, port), - - # # invalid port - # "tcp://%s:100000" % ip, - # "tcp://%s: " % ip, - # "tcp://%s:19540" % ip, - # "tcp://%s:-1" % ip, - # "tcp://%s:string" % ip, - - # invalid ip - "tcp:// :19530", - # "tcp://123.0.0.1:%s" % port, - "tcp://127.0.0:19530", - # "tcp://255.0.0.0:%s" % port, - # "tcp://255.255.0.0:%s" % port, - # "tcp://255.255.255.0:%s" % port, - # "tcp://255.255.255.255:%s" % port, - "tcp://\n:19530", - ] - return uris - - -def gen_invalid_strs(): - strings = [ - 1, - [1], - None, - "12-s", - # " ", - # "", - # None, - "12 s", - "(mn)", - "中文", - "a".join("a" for i in range(256)) - ] - return strings - - -def gen_invalid_field_types(): - field_types = [ - # 1, - "=c", - # 0, - None, - "", - "a".join("a" for i in range(256)) - ] - return field_types - - -def gen_invalid_metric_types(): - metric_types = [ - 1, - "=c", - 0, - None, - "", - "a".join("a" for i in range(256)) - ] - return metric_types - - -# TODO: -def gen_invalid_ints(): - int_values = [ - # 1.0, - None, - [1, 2, 3], - " ", - "", - -1, - "String", - "=c", - "中文", - "a".join("a" for i in range(256)) - ] - return int_values - - -def gen_invalid_params(): - params = [ - 9999999999, - -1, - # None, - [1, 2, 3], - " ", - "", - "String", - "中文" - ] - return params - - -def gen_invalid_vectors(): - invalid_vectors = [ - "1*2", - [], - [1], - [1, 2], - [" "], - ['a'], - [None], - None, - (1, 2), - {"a": 1}, - " ", - "", - "String", - " siede ", - "中文", - "a".join("a" for i in range(256)) - ] - return invalid_vectors - - -def gen_invaild_search_params(): - invalid_search_key = 
100 - search_params = [] - for index_type in all_index_types: - if index_type == "FLAT": - continue - search_params.append({"index_type": index_type, "search_params": {"invalid_key": invalid_search_key}}) - if index_type in delete_support(): - for nprobe in gen_invalid_params(): - ivf_search_params = {"index_type": index_type, "search_params": {"nprobe": nprobe}} - search_params.append(ivf_search_params) - elif index_type in ["HNSW", "RHNSW_PQ", "RHNSW_SQ"]: - for ef in gen_invalid_params(): - hnsw_search_param = {"index_type": index_type, "search_params": {"ef": ef}} - search_params.append(hnsw_search_param) - elif index_type == "NSG": - for search_length in gen_invalid_params(): - nsg_search_param = {"index_type": index_type, "search_params": {"search_length": search_length}} - search_params.append(nsg_search_param) - search_params.append({"index_type": index_type, "search_params": {"invalid_key": 100}}) - elif index_type == "ANNOY": - for search_k in gen_invalid_params(): - if isinstance(search_k, int): - continue - annoy_search_param = {"index_type": index_type, "search_params": {"search_k": search_k}} - search_params.append(annoy_search_param) - return search_params - - -def gen_invalid_index(): - index_params = [] - for index_type in gen_invalid_strs(): - index_param = {"index_type": index_type, "params": {"nlist": 1024}} - index_params.append(index_param) - for nlist in gen_invalid_params(): - index_param = {"index_type": "IVF_FLAT", "params": {"nlist": nlist}} - index_params.append(index_param) - for M in gen_invalid_params(): - # append each variant; previously only the last assignment survived - index_params.append({"index_type": "HNSW", "params": {"M": M, "efConstruction": 100}}) - index_params.append({"index_type": "RHNSW_PQ", "params": {"M": M, "efConstruction": 100}}) - index_params.append({"index_type": "RHNSW_SQ", "params": {"M": M, "efConstruction": 100}}) - for efConstruction in gen_invalid_params(): - index_params.append({"index_type": "HNSW", "params": {"M": 16, "efConstruction": efConstruction}}) - index_params.append({"index_type": "RHNSW_PQ", "params": {"M": 16, "efConstruction": efConstruction}}) - index_params.append({"index_type": "RHNSW_SQ", "params": {"M": 16, "efConstruction": efConstruction}}) - for search_length in gen_invalid_params(): - index_param = {"index_type": "NSG", - "params": {"search_length": search_length, "out_degree": 40, "candidate_pool_size": 50, - "knng": 100}} - index_params.append(index_param) - for out_degree in gen_invalid_params(): - index_param = {"index_type": "NSG", - "params": {"search_length": 100, "out_degree": out_degree, "candidate_pool_size": 50, - "knng": 100}} - index_params.append(index_param) - for candidate_pool_size in gen_invalid_params(): - index_param = {"index_type": "NSG", "params": {"search_length": 100, "out_degree": 40, - "candidate_pool_size": candidate_pool_size, - "knng": 100}} - index_params.append(index_param) - index_params.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}}) - index_params.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}}) - index_params.append({"index_type": "RHNSW_PQ", "params": {"invalid_key": 16, "efConstruction": 100}}) - index_params.append({"index_type": "RHNSW_SQ", "params": {"invalid_key": 16, "efConstruction": 100}}) - index_params.append({"index_type": "NSG", - "params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300, - "knng": 100}}) - for invalid_n_trees in gen_invalid_params(): - index_params.append({"index_type": "ANNOY", "params": {"n_trees": 
invalid_n_trees}}) - - return index_params - - -def gen_index(): - nlists = [1, 1024, 16384] - pq_ms = [128, 64, 32, 16, 8, 4] - Ms = [5, 24, 48] - efConstructions = [100, 300, 500] - search_lengths = [10, 100, 300] - out_degrees = [5, 40, 300] - candidate_pool_sizes = [50, 100, 300] - knngs = [5, 100, 300] - - index_params = [] - for index_type in all_index_types: - if index_type in ["FLAT", "BIN_FLAT", "BIN_IVF_FLAT"]: - index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}}) - elif index_type in ["IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID"]: - ivf_params = [{"index_type": index_type, "index_param": {"nlist": nlist}} \ - for nlist in nlists] - index_params.extend(ivf_params) - elif index_type == "IVF_PQ": - IVFPQ_params = [{"index_type": index_type, "index_param": {"nlist": nlist, "m": m}} \ - for nlist in nlists \ - for m in pq_ms] - index_params.extend(IVFPQ_params) - elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]: - hnsw_params = [{"index_type": index_type, "index_param": {"M": M, "efConstruction": efConstruction}} \ - for M in Ms \ - for efConstruction in efConstructions] - index_params.extend(hnsw_params) - elif index_type == "NSG": - nsg_params = [{"index_type": index_type, - "index_param": {"search_length": search_length, "out_degree": out_degree, - "candidate_pool_size": candidate_pool_size, "knng": knng}} \ - for search_length in search_lengths \ - for out_degree in out_degrees \ - for candidate_pool_size in candidate_pool_sizes \ - for knng in knngs] - index_params.extend(nsg_params) - - return index_params - - -def gen_simple_index(): - index_params = [] - for i in range(len(all_index_types)): - if all_index_types[i] in binary_support(): - continue - dic = {"index_type": all_index_types[i], "metric_type": "L2"} - dic.update({"params": default_index_params[i]}) - index_params.append(dic) - return index_params - - -def gen_binary_index(): - index_params = [] - for i in range(len(all_index_types)): - if all_index_types[i] in binary_support(): - dic = {"index_type": all_index_types[i]} - dic.update({"params": default_index_params[i]}) - index_params.append(dic) - return index_params - - -def get_search_param(index_type, metric_type="L2"): - search_params = {"metric_type": metric_type} - if index_type in ivf() or index_type in binary_support(): - search_params.update({"nprobe": 64}) - elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]: - search_params.update({"ef": 64}) - elif index_type == "NSG": - search_params.update({"search_length": 100}) - elif index_type == "ANNOY": - search_params.update({"search_k": 1000}) - else: - logging.getLogger().error("Invalid index_type.") - raise Exception("Invalid index_type.") - return search_params - - -def assert_equal_vector(v1, v2): - if len(v1) != len(v2): - assert False - for i in range(len(v1)): - assert abs(v1[i] - v2[i]) < epsilon - - -def restart_server(helm_release_name): - res = True - timeout = 120 - from kubernetes import client, config - client.rest.logger.setLevel(logging.WARNING) - - # service_name = "%s.%s.svc.cluster.local" % (helm_release_name, namespace) - config.load_kube_config() - v1 = client.CoreV1Api() - pod_name = None - # config_map_names = v1.list_namespaced_config_map(namespace, pretty='true') - # body = {"replicas": 0} - pods = v1.list_namespaced_pod(namespace) - for i in pods.items: - if i.metadata.name.find(helm_release_name) != -1 and i.metadata.name.find("mysql") == -1: - pod_name = i.metadata.name - break - # v1.patch_namespaced_config_map(config_map_name, namespace, body, 
pretty='true') - # status_res = v1.read_namespaced_service_status(helm_release_name, namespace, pretty='true') - logging.getLogger().debug("Pod name: %s" % pod_name) - if pod_name is not None: - try: - v1.delete_namespaced_pod(pod_name, namespace) - except Exception as e: - logging.error(str(e)) - logging.error("Exception when calling CoreV1Api->delete_namespaced_pod") - res = False - return res - logging.error("Sleep 10s after pod deleted") - time.sleep(10) - # check if restart successfully - pods = v1.list_namespaced_pod(namespace) - for i in pods.items: - pod_name_tmp = i.metadata.name - logging.error(pod_name_tmp) - if pod_name_tmp == pod_name: - continue - elif pod_name_tmp.find(helm_release_name) == -1 or pod_name_tmp.find("mysql") != -1: - continue - else: - status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true') - logging.error(status_res.status.phase) - start_time = time.time() - ready_break = False - while time.time() - start_time <= timeout: - logging.error(time.time()) - status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true') - if status_res.status.phase == "Running": - logging.error("Already running") - ready_break = True - time.sleep(10) - break - else: - time.sleep(1) - if time.time() - start_time > timeout: - logging.error("Restart pod: %s timeout" % pod_name_tmp) - res = False - return res - if ready_break: - break - else: - raise Exception("Pod: %s not found" % pod_name) - follow = True - pretty = True - previous = True # bool | Return previous terminated container logs. Defaults to false. (optional) - since_seconds = 56 # int | A relative time in seconds before the current time from which to show logs. If this value precedes the time a pod was started, only logs since the pod start will be returned. If this value is in the future, no logs will be returned. Only one of sinceSeconds or sinceTime may be specified. (optional) - timestamps = True # bool | If true, add an RFC3339 or RFC3339Nano timestamp at the beginning of every line of log output. Defaults to false. (optional) - container = "milvus" - # start_time = time.time() - # while time.time() - start_time <= timeout: - # try: - # api_response = v1.read_namespaced_pod_log(pod_name_tmp, namespace, container=container, follow=follow, - # pretty=pretty, previous=previous, since_seconds=since_seconds, - # timestamps=timestamps) - # logging.error(api_response) - # return res - # except Exception as e: - # logging.error("Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e) - # # waiting for server start - # time.sleep(5) - # # res = False - # # return res - # if time.time() - start_time > timeout: - # logging.error("Restart pod: %s timeout" % pod_name_tmp) - # res = False - return res - - -class MilvusTestThread(threading.Thread): - def __init__(self, target, args=()): - threading.Thread.__init__(self, target=target, args=args) - - def run(self): - self.exc = None - try: - super(MilvusTestThread, self).run() - except BaseException as e: - self.exc = e - - def join(self): - super(MilvusTestThread, self).join() - if self.exc: - raise self.exc - - -class MockGrpcError(grpc.RpcError): - def __init__(self, code=1, details="error"): - self._code = code - self._details = details - - def code(self): - return self._code - - def details(self): - return self._details
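
For reference, the DSL helpers deleted above (gen_query_vectors, gen_default_term_expr, gen_default_range_expr, update_query_expr) all compose a single bool/must dict that connect.search consumes. Below is a minimal sketch of the assembled query shape, built from the defaults defined in utils.py (default_nb = 1200, so default_nb // 2 = 600; default_top_k = 10; default_dim = 128; field name "float_vector"); the host, port, and collection name "demo" are hypothetical and only illustrate usage.

import random

nq, top_k, dim = 2, 10, 128
query_vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]

# The structure gen_query_vectors() returns, with a term and a range expr
# merged into "must" the way update_query_expr() would merge them:
query = {
    "bool": {
        "must": [
            # vector expr, as built by gen_query_vectors()
            {"vector": {"float_vector": {"topk": top_k, "query": query_vectors,
                                         "params": {"nprobe": 10}, "metric_type": "L2"}}},
            # term expr, as built by gen_default_term_expr()
            {"term": {"int64": {"values": list(range(600))}}},
            # range expr, as built by gen_default_range_expr()
            {"range": {"int64": {"GT": 1, "LT": 600}}},
        ]
    }
}

# connect = get_milvus("127.0.0.1", 19530)  # hypothetical host/port
# res = connect.search("demo", query)       # nq result lists, each holding up to top_k hits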