diff --git a/tests/milvus_python_test/entity/test_insert.py b/tests/milvus_python_test/entity/test_insert.py index b673b17ff8..c736377a62 100644 --- a/tests/milvus_python_test/entity/test_insert.py +++ b/tests/milvus_python_test/entity/test_insert.py @@ -1,3 +1,4 @@ +import logging import time import pdb import copy @@ -11,6 +12,7 @@ from constants import * ADD_TIMEOUT = 60 uid = "test_insert" field_name = default_float_vec_field_name +binary_field_name = default_binary_vec_field_name default_single_query = { "bool": { "must": [ @@ -20,6 +22,7 @@ default_single_query = { } } + class TestInsertBase: """ ****************************************************************** @@ -116,6 +119,11 @@ class TestInsertBase: assert len(ids) == default_nb connect.flush([collection]) connect.create_index(collection, field_name, get_simple_index) + info = connect.get_collection_info(collection) + fields = info["fields"] + for field in fields: + if field["field"] == field_name: + assert field["indexes"][0] == get_simple_index @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_after_create_index(self, connect, collection, get_simple_index): @@ -127,6 +135,11 @@ class TestInsertBase: connect.create_index(collection, field_name, get_simple_index) ids = connect.insert(collection, default_entities) assert len(ids) == default_nb + info = connect.get_collection_info(collection) + fields = info["fields"] + for field in fields: + if field["field"] == field_name: + assert field["indexes"][0] == get_simple_index @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_search(self, connect, collection): @@ -300,6 +313,7 @@ class TestInsertBase: connect.create_partition(collection, default_tag) ids = connect.insert(collection, default_entities, partition_tag=default_tag) assert len(ids) == default_nb + assert connect.has_partition(collection, default_tag) @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_tag_with_ids(self, connect, id_collection): @@ -533,6 +547,96 @@ class TestInsertBase: assert res[0] is None +class TestInsertBinary: + @pytest.fixture( + scope="function", + params=gen_binary_index() + ) + def get_binary_index(self, request): + request.param["metric_type"] = "JACCARD" + return request.param + + def test_insert_binary_entities(self, connect, binary_collection): + ''' + target: test insert entities in binary collection + method: create collection and insert binary entities in it + expected: the collection row count equals to nb + ''' + ids = connect.insert(binary_collection, default_binary_entities) + assert len(ids) == default_nb + connect.flush() + assert connect.count_entities(binary_collection) == default_nb + + def test_insert_binary_tag(self, connect, binary_collection): + ''' + target: test insert entities and create partition tag + method: create collection and insert binary entities in it, with the partition_tag param + expected: the collection row count equals to nb + ''' + connect.create_partition(binary_collection, default_tag) + ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag) + assert len(ids) == default_nb + assert connect.has_partition(binary_collection, default_tag) + + def test_insert_binary_multi_times(self, connect, binary_collection): + ''' + target: test insert entities multi times and final flush + method: create collection and insert binary entity multi and final flush + expected: the collection row count equals to nb + ''' + for i in range(default_nb): + ids = connect.insert(binary_collection, default_binary_entity) + assert len(ids) == 1 + connect.flush([binary_collection]) + assert connect.count_entities(binary_collection) == default_nb + + def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index): + ''' + target: test insert binary entities after build index + method: build index and insert entities + expected: no error raised + ''' + connect.create_index(binary_collection, binary_field_name, get_binary_index) + ids = connect.insert(binary_collection, default_binary_entities) + assert len(ids) == default_nb + connect.flush([binary_collection]) + info = connect.get_collection_info(binary_collection) + fields = info["fields"] + for field in fields: + if field["field"] == binary_field_name: + assert field["indexes"][0] == get_binary_index + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index): + ''' + target: test build index insert after vector + method: insert vector and build index + expected: no error raised + ''' + ids = connect.insert(binary_collection, default_binary_entities) + assert len(ids) == default_nb + connect.flush([binary_collection]) + connect.create_index(binary_collection, binary_field_name, get_binary_index) + info = connect.get_collection_info(binary_collection) + fields = info["fields"] + for field in fields: + if field["field"] == binary_field_name: + assert field["indexes"][0] == get_binary_index + + def test_insert_binary_search(self, connect, binary_collection): + ''' + target: test search vector after insert vector after a while + method: insert vector, sleep, and search collection + expected: no error raised + ''' + ids = connect.insert(binary_collection, default_binary_entities) + connect.flush([binary_collection]) + query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD") + res = connect.search(binary_collection, query) + logging.getLogger().debug(res) + assert res + + class TestInsertAsync: @pytest.fixture(scope="function", autouse=True) def skip_http_check(self, args): @@ -940,15 +1044,13 @@ class TestInsertInvalidBinary(object): @pytest.mark.level(2) def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name): - field_name = get_field_name tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name) with pytest.raises(Exception): connect.insert(binary_collection, tmp_entity) @pytest.mark.level(2) def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value): - field_value = get_field_int_value - tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_value) + tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value) with pytest.raises(Exception): connect.insert(binary_collection, tmp_entity) @@ -972,13 +1074,6 @@ class TestInsertInvalidBinary(object): with pytest.raises(Exception): connect.insert(binary_id_collection, default_binary_entities, ids) - @pytest.mark.level(2) - def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name): - field_name = get_field_name - tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name) - with pytest.raises(Exception): - connect.insert(binary_collection, tmp_entity) - @pytest.mark.level(2) def test_insert_with_invalid_field_type(self, connect, binary_collection, get_field_type): field_type = get_field_type @@ -986,13 +1081,6 @@ class TestInsertInvalidBinary(object): with pytest.raises(Exception): connect.insert(binary_collection, tmp_entity) - @pytest.mark.level(2) - def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value): - field_value = get_field_int_value - tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_value) - with pytest.raises(Exception): - connect.insert(binary_collection, tmp_entity) - @pytest.mark.level(2) def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value): tmp_entity = copy.deepcopy(default_binary_entities)