diff --git a/tests/python/test_search.py b/tests/python/test_search.py
index 605bd8507d..26a09d6348 100644
--- a/tests/python/test_search.py
+++ b/tests/python/test_search.py
@@ -147,7 +147,6 @@ class TestSearchBase:
         yield request.param

     # PASS
-    @pytest.mark.skip("r0.3-test")
     def test_search_flat(self, connect, collection, get_top_k, get_nq):
         '''
         target: test basic search function, all the search params is corrent, change top-k value
@@ -258,7 +257,6 @@ class TestSearchBase:
         assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")

     # Pass
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -306,8 +304,8 @@ class TestSearchBase:
         assert len(res[0]) == default_top_k

     # pass
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
+    @pytest.mark.skip("r0.3-test")
     def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
         target: test basic search function, all the search params is corrent, test all index params, and build
@@ -338,7 +336,6 @@ class TestSearchBase:
         assert len(res) == nq

     # PASS
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -424,7 +421,6 @@ class TestSearchBase:
         assert res[1]._distances[0] > epsilon

     # Pass
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
         '''
@@ -514,7 +510,6 @@ class TestSearchBase:
         assert check_id_result(res[0], ids[0])
         assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])

-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -548,7 +543,6 @@ class TestSearchBase:
         assert len(res) == nq

     # PASS
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
         '''
@@ -609,7 +603,6 @@ class TestSearchBase:
         res = connect.search(collection_name, default_query)

     # PASS
-    @pytest.mark.skip("r0.3-test")
     def test_search_distance_l2(self, connect, collection):
         '''
         target: search collection, and check the result: distance
@@ -629,7 +622,6 @@ class TestSearchBase:
         assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])

     # Pass
-    @pytest.mark.skip("r0.3-test")
     def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
         '''
         target: search collection, and check the result: distance
@@ -684,7 +676,6 @@ class TestSearchBase:
         assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon

     # Pass
-    @pytest.mark.skip("r0.3-test")
     def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
         '''
         target: search collection, and check the result: distance
@@ -771,7 +762,6 @@ class TestSearchBase:
         assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon

     # PASS
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_distance_substructure_flat_index(self, connect, binary_collection):
         '''
@@ -790,7 +780,6 @@ class TestSearchBase:
         assert len(res[0]) == 0

     # PASS
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
         '''
@@ -810,7 +799,6 @@ class TestSearchBase:
         assert res[1][0].id == ids[1]

     # PASS
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
         '''
@@ -829,7 +817,6 @@ class TestSearchBase:
         assert len(res[0]) == 0

     # PASS
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
         '''
@@ -851,7 +838,6 @@ class TestSearchBase:
         assert res[1][0].distance <= epsilon

     # PASS
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
         '''
@@ -985,7 +971,6 @@ class TestSearchBase:
         assert getattr(r.entity, "int64") == getattr(r.entity, "id")


-@pytest.mark.skip("r0.3-test")
 class TestSearchDSL(object):
     """
     ******************************************************************
@@ -1579,7 +1564,6 @@ class TestSearchDSL(object):
         res = connect.search(collection, query)


-@pytest.mark.skip("r0.3-test")
 class TestSearchDSLBools(object):
     """
     ******************************************************************
@@ -1766,7 +1750,6 @@ class TestSearchInvalid(object):
         yield request.param

     # Pass
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
         '''
@@ -1788,7 +1771,6 @@ class TestSearchInvalid(object):
         res = connect.search(collection, query)

     # pass
-    @pytest.mark.skip("r0.3-test")
     @pytest.mark.level(2)
     def test_search_with_invalid_params_binary(self, connect, binary_collection):
         '''
diff --git a/tests/python_test/collection/test_collection_stats.py b/tests/python_test/collection/test_collection_stats.py
index 8c16aaaa54..c13b391c7c 100644
--- a/tests/python_test/collection/test_collection_stats.py
+++ b/tests/python_test/collection/test_collection_stats.py
@@ -10,18 +10,19 @@ from constants import *

 uid = "get_collection_stats"

+
 class TestGetCollectionStats:
     """
     ******************************************************************
       The following cases are used to test `collection_stats` function
     ******************************************************************
     """
-
+
     @pytest.fixture(
         scope="function",
         params=gen_invalid_strs()
     )
-    def get_collection_name(self, request):
+    def get_invalid_collection_name(self, request):
         yield request.param

     @pytest.fixture(
@@ -46,6 +47,17 @@ class TestGetCollectionStats:
         else:
             pytest.skip("Skip index Temporary")

+    @pytest.fixture(
+        scope="function",
+        params=[
+            1,
+            1000,
+            2001
+        ],
+    )
+    def insert_count(self, request):
+        yield request.param
+
     def test_get_collection_stats_name_not_existed(self, connect, collection):
         '''
         target: get collection stats where collection name does not exist
@@ -53,22 +65,19 @@ class TestGetCollectionStats:
         expected: status not ok
         '''
         collection_name = gen_unique_str(uid)
-        connect.create_collection(collection_name, default_fields)
-        connect.get_collection_stats(collection_name)
-        connect.drop_collection(collection_name)
         with pytest.raises(Exception) as e:
             connect.get_collection_stats(collection_name)

     @pytest.mark.level(2)
-    def test_get_collection_stats_name_invalid(self, connect, get_collection_name):
+    def test_get_collection_stats_name_invalid(self, connect, get_invalid_collection_name):
         '''
         target: get collection stats where collection name is invalid
         method: call collection_stats with invalid collection_name
         expected: status not ok
         '''
-        collection_name = get_collection_name
+        collection_name = get_invalid_collection_name
         with pytest.raises(Exception) as e:
-            stats = connect.get_collection_stats(collection_name)
+            connect.get_collection_stats(collection_name)

     def test_get_collection_stats_empty(self, connect, collection):
         '''
@@ -77,10 +86,17 @@ class TestGetCollectionStats:
         expected: segment = []
         '''
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == 0
-        # assert len(stats["partitions"]) == 1
-        # assert stats["partitions"][0]["tag"] == default_partition_name
-        # assert stats["partitions"][0]["row_count"] == 0
+        connect.flush([collection])
+        assert stats[row_count] == 0
+
+    def test_get_collection_stats_without_connection(self, collection, dis_connect):
+        '''
+        target: test count_entities, without connection
+        method: calling count_entities with correct params, with a disconnected instance
+        expected: count_entities raise exception
+        '''
+        with pytest.raises(Exception) as e:
+            dis_connect.get_collection_stats(collection)

     def test_get_collection_stats_batch(self, connect, collection):
         '''
@@ -89,12 +105,10 @@ class TestGetCollectionStats:
         expected: count as expected
         '''
         ids = connect.insert(collection, default_entities)
+        assert len(ids) == default_nb
         connect.flush([collection])
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == default_nb
-        # assert len(stats["partitions"]) == 1
-        # assert stats["partitions"][0]["tag"] == default_partition_name
-        # assert stats["partitions"][0]["row_count"] == default_nb
+        assert int(stats[row_count]) == default_nb

     def test_get_collection_stats_single(self, connect, collection):
         '''
@@ -104,13 +118,10 @@ class TestGetCollectionStats:
         '''
         nb = 10
         for i in range(nb):
-            ids = connect.insert(collection, default_entity)
+            connect.insert(collection, default_entity)
         connect.flush([collection])
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == nb
-        # assert len(stats["partitions"]) == 1
-        # assert stats["partitions"][0]["tag"] == default_partition_name
-        # assert stats["partitions"][0]["row_count"] == nb
+        assert stats[row_count] == nb

     @pytest.mark.skip("delete_by_id not support yet")
     def test_get_collection_stats_after_delete(self, connect, collection):
@@ -184,12 +195,10 @@ class TestGetCollectionStats:
         '''
         connect.create_partition(collection, default_tag)
         ids = connect.insert(collection, default_entities, partition_tag=default_tag)
+        assert len(ids) == default_nb
         connect.flush([collection])
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == default_nb
-        # assert len(stats["partitions"]) == 2
-        # assert stats["partitions"][1]["tag"] == default_tag
-        # assert stats["partitions"][1]["row_count"] == default_nb
+        assert stats[row_count] == default_nb

     def test_get_collection_stats_partitions(self, connect, collection):
         '''
@@ -200,26 +209,88 @@ class TestGetCollectionStats:
         new_tag = "new_tag"
         connect.create_partition(collection, default_tag)
         connect.create_partition(collection, new_tag)
-        ids = connect.insert(collection, default_entities, partition_tag=default_tag)
+        connect.insert(collection, default_entities, partition_tag=default_tag)
         connect.flush([collection])
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == default_nb
-        # for partition in stats["partitions"]:
-        #     if partition["tag"] == default_tag:
-        #         assert partition["row_count"] == default_nb
-        #     else:
-        #         assert partition["row_count"] == 0
-        ids = connect.insert(collection, default_entities, partition_tag=new_tag)
+        assert stats[row_count] == default_nb
+        connect.insert(collection, default_entities, partition_tag=new_tag)
         connect.flush([collection])
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == default_nb * 2
-        # for partition in stats["partitions"]:
-        #     if partition["tag"] in [default_tag, new_tag]:
-        #         assert partition["row_count"] == default_nb
-        ids = connect.insert(collection, default_entities)
+        assert stats[row_count] == default_nb * 2
+        connect.insert(collection, default_entities)
         connect.flush([collection])
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == default_nb * 3
+        assert stats[row_count] == default_nb * 3
+
+    # @pytest.mark.tags("0331")
+    def test_get_collection_stats_partitions_A(self, connect, collection, insert_count):
+        '''
+        target: test collection rows_count is correct or not
+        method: create collection, create partitions and add entities in it,
+            assert the value returned by count_entities method is equal to length of entities
+        expected: the count is equal to the length of entities
+        '''
+        new_tag = "new_tag"
+        entities = gen_entities(insert_count)
+        connect.create_partition(collection, default_tag)
+        connect.create_partition(collection, new_tag)
+        connect.insert(collection, entities)
+        connect.flush([collection])
+        stats = connect.get_collection_stats(collection)
+        assert stats[row_count] == insert_count
+
+    # @pytest.mark.tags("0331")
+    def test_get_collection_stats_partitions_B(self, connect, collection, insert_count):
+        '''
+        target: test collection rows_count is correct or not
+        method: create collection, create partitions and add entities in one of the partitions,
+            assert the value returned by count_entities method is equal to length of entities
+        expected: the count is equal to the length of entities
+        '''
+        new_tag = "new_tag"
+        entities = gen_entities(insert_count)
+        connect.create_partition(collection, default_tag)
+        connect.create_partition(collection, new_tag)
+        connect.insert(collection, entities, partition_tag=default_tag)
+        connect.flush([collection])
+        stats = connect.get_collection_stats(collection)
+        assert stats[row_count] == insert_count
+
+    # @pytest.mark.tags("0331")
+    def test_get_collection_stats_partitions_C(self, connect, collection, insert_count):
+        '''
+        target: test collection rows_count is correct or not
+        method: create collection, create partitions and add entities in one of the partitions,
+            assert the value returned by count_entities method is equal to length of entities
+        expected: the count is equal to the length of vectors
+        '''
+        new_tag = "new_tag"
+        entities = gen_entities(insert_count)
+        connect.create_partition(collection, default_tag)
+        connect.create_partition(collection, new_tag)
+        connect.insert(collection, entities)
+        connect.insert(collection, entities, partition_tag=default_tag)
+        connect.flush([collection])
+        stats = connect.get_collection_stats(collection)
+        assert stats[row_count] == insert_count*2
+
+    # @pytest.mark.tags("0331")
+    def test_get_collection_stats_partitions_D(self, connect, collection, insert_count):
+        '''
+        target: test collection rows_count is correct or not
+        method: create collection, create partitions and add entities in one of the partitions,
+            assert the value returned by count_entities method is equal to length of entities
+        expected: the collection count is equal to the length of entities
+        '''
+        new_tag = "new_tag"
+        entities = gen_entities(insert_count)
+        connect.create_partition(collection, default_tag)
+        connect.create_partition(collection, new_tag)
+        connect.insert(collection, entities, partition_tag=default_tag)
+        connect.insert(collection, entities, partition_tag=new_tag)
+        connect.flush([collection])
+        stats = connect.get_collection_stats(collection)
+        assert stats[row_count] == insert_count*2

     # TODO: assert metric type in stats response
     def test_get_collection_stats_after_index_created(self, connect, collection, get_simple_index):
@@ -228,17 +299,11 @@ class TestGetCollectionStats:
         method: create collection, add vectors, create index and call collection_stats
         expected: status ok, index created and shown in segments
         '''
-        ids = connect.insert(collection, default_entities)
+        connect.insert(collection, default_entities)
         connect.flush([collection])
         connect.create_index(collection, default_float_vec_field_name, get_simple_index)
         stats = connect.get_collection_stats(collection)
-        logging.getLogger().info(stats)
-        assert stats["row_count"] == default_nb
-        # for file in stats["partitions"][0]["segments"][0]["files"]:
-        #     if file["name"] == default_float_vec_field_name and "index_type" in file:
-        #         assert file["data_size"] > 0
-        #         assert file["index_type"] == get_simple_index["index_type"]
-        #         break
+        assert stats[row_count] == default_nb

     # TODO: assert metric type in stats response
     def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
@@ -249,16 +314,12 @@ class TestGetCollectionStats:
         '''
         get_simple_index["metric_type"] = "IP"
         ids = connect.insert(collection, default_entities)
+        assert len(ids) == default_nb
         connect.flush([collection])
         get_simple_index.update({"metric_type": "IP"})
         connect.create_index(collection, default_float_vec_field_name, get_simple_index)
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == default_nb
-        # for file in stats["partitions"][0]["segments"][0]["files"]:
-        #     if file["name"] == default_float_vec_field_name and "index_type" in file:
-        #         assert file["data_size"] > 0
-        #         assert file["index_type"] == get_simple_index["index_type"]
-        #         break
+        assert stats[row_count] == default_nb

     # TODO: assert metric type in stats response
     def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
@@ -269,14 +330,9 @@ class TestGetCollectionStats:
         '''
         ids = connect.insert(binary_collection, default_binary_entities)
         connect.flush([binary_collection])
-        connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
+        connect.create_index(binary_collection, default_binary_vec_field_name, get_jaccard_index)
         stats = connect.get_collection_stats(binary_collection)
-        assert stats["row_count"] == default_nb
-        # for file in stats["partitions"][0]["segments"][0]["files"]:
-        #     if file["name"] == default_float_vec_field_name and "index_type" in file:
-        #         assert file["data_size"] > 0
-        #         assert file["index_type"] == get_simple_index["index_type"]
-        #         break
+        assert stats[row_count] == default_nb

     def test_get_collection_stats_after_create_different_index(self, connect, collection):
         '''
@@ -288,14 +344,9 @@ class TestGetCollectionStats:
         connect.flush([collection])
         for index_type in ["IVF_FLAT", "IVF_SQ8"]:
             connect.create_index(collection, default_float_vec_field_name,
-                                 {"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
+                                 {"index_type": index_type, "params": {"nlist": 1024}, "metric_type": "L2"})
             stats = connect.get_collection_stats(collection)
-            assert stats["row_count"] == default_nb
-            # for file in stats["partitions"][0]["segments"][0]["files"]:
-            #     if file["name"] == default_float_vec_field_name and "index_type" in file:
-            #         assert file["data_size"] > 0
-            #         assert file["index_type"] == index_type
-            #         break
+            assert stats[row_count] == default_nb

     def test_collection_count_multi_collections(self, connect):
         '''
@@ -310,12 +361,11 @@ class TestGetCollectionStats:
             collection_name = gen_unique_str(uid)
             collection_list.append(collection_name)
             connect.create_collection(collection_name, default_fields)
-            res = connect.insert(collection_name, default_entities)
+            ids = connect.insert(collection_name, default_entities)
         connect.flush(collection_list)
         for i in range(collection_num):
             stats = connect.get_collection_stats(collection_list[i])
-            # assert stats["partitions"][0]["row_count"] == default_nb
-            assert stats["row_count"] == default_nb
+            assert stats[row_count] == default_nb
             connect.drop_collection(collection_list[i])

     @pytest.mark.level(2)
@@ -334,23 +384,19 @@ class TestGetCollectionStats:
             connect.create_collection(collection_name, default_fields)
             res = connect.insert(collection_name, default_entities)
             connect.flush(collection_list)
+            index_1 = {"index_type": "IVF_SQ8", "params": {"nlist": 1024}, "metric_type": "L2"}
+            index_2 = {"index_type": "IVF_FLAT", "params": {"nlist": 1024}, "metric_type": "L2"}
             if i % 2:
-                connect.create_index(collection_name, default_float_vec_field_name,
-                                     {"index_type": "IVF_SQ8", "params":{"nlist": 1024}, "metric_type": "L2"})
+                connect.create_index(collection_name, default_float_vec_field_name, index_1)
             else:
-                connect.create_index(collection_name, default_float_vec_field_name,
-                                     {"index_type": "IVF_FLAT","params":{"nlist": 1024}, "metric_type": "L2"})
+                connect.create_index(collection_name, default_float_vec_field_name, index_2)
         for i in range(collection_num):
             stats = connect.get_collection_stats(collection_list[i])
-            assert stats["row_count"] == default_nb
-            # if i % 2:
-            #     for file in stats["partitions"][0]["segments"][0]["files"]:
-            #         if file["name"] == default_float_vec_field_name and "index_type" in file:
-            #             assert file["index_type"] == "IVF_SQ8"
-            #             break
-            # else:
-            #     for file in stats["partitions"][0]["segments"][0]["files"]:
-            #         if file["name"] == default_float_vec_field_name and "index_type" in file:
-            #             assert file["index_type"] == "IVF_FLAT"
-            #             break
+            assert stats[row_count] == default_nb
+            index = connect.describe_index(collection_list[i], default_float_vec_field_name)
+            if i % 2:
+                assert index == index_1
+            else:
+                assert index == index_2
+            # break
             connect.drop_collection(collection_list[i])

diff --git a/tests/python_test/collection/test_create_collection.py b/tests/python_test/collection/test_create_collection.py
index e958f59688..cd812de008 100644
--- a/tests/python_test/collection/test_create_collection.py
+++ b/tests/python_test/collection/test_create_collection.py
@@ -6,19 +6,20 @@ import time
 import threading
 from multiprocessing import Process
 import sklearn.preprocessing
-
 import pytest
 from utils import *
 from constants import *

 uid = "create_collection"

+
 class TestCreateCollection:
     """
     ******************************************************************
       The following cases are used to test `create_collection` function
     ******************************************************************
     """
+
     @pytest.fixture(
         scope="function",
         params=gen_single_filter_fields()
     )
@@ -52,8 +53,8 @@ class TestCreateCollection:
         vector_field = get_vector_field
         collection_name = gen_unique_str(uid)
         fields = {
-            "fields": [filter_field, vector_field],
-            # "segment_row_limit": default_segment_row_limit
+            "fields": [filter_field, vector_field],
+            # "segment_row_limit": default_segment_row_limit
         }
         logging.getLogger().info(fields)
         connect.create_collection(collection_name, fields)
@@ -93,7 +94,7 @@ class TestCreateCollection:
         expected: error raised
         '''
         connect.insert(collection, default_entity)
-        connect.flush([collection])
+        # connect.flush([collection])
         with pytest.raises(Exception) as e:
             connect.create_collection(collection, default_fields)

@@ -140,7 +141,7 @@ class TestCreateCollection:
         method: create collection using multithread,
         expected: collections are created
         '''
-        threads_num = 8
+        threads_num = 8
         threads = []
         collection_names = []

@@ -148,6 +149,7 @@ class TestCreateCollection:
             collection_name = gen_unique_str(uid)
             collection_names.append(collection_name)
             connect.create_collection(collection_name, default_fields)
+
         for i in range(threads_num):
             t = TestThread(target=create, args=())
             threads.append(t)
@@ -155,7 +157,7 @@ class TestCreateCollection:
             time.sleep(0.2)
         for t in threads:
             t.join()
-
+
         for item in collection_names:
             assert item in connect.list_collections()
             connect.drop_collection(item)
@@ -165,6 +167,7 @@ class TestCreateCollectionInvalid(object):
     """
     Test creating collections with invalid params
     """
+
     @pytest.fixture(
         scope="function",
         params=gen_invalid_metric_types()
     )
@@ -217,7 +220,7 @@ class TestCreateCollectionInvalid(object):
         fields = copy.deepcopy(default_fields)
         fields["fields"][-1]["params"]["dim"] = dimension
         with pytest.raises(Exception) as e:
-            connect.create_collection(collection_name, fields)
+            connect.create_collection(collection_name, fields)

     @pytest.mark.level(2)
     @pytest.mark.tags("0331")
diff --git a/tests/python_test/entity/test_insert.py b/tests/python_test/entity/test_insert.py
index bbbc478a0d..cf11fc07de 100644
--- a/tests/python_test/entity/test_insert.py
+++ b/tests/python_test/entity/test_insert.py
@@ -5,7 +5,7 @@ import copy
 import threading
 from multiprocessing import Pool, Process
 import pytest
-from milvus import DataType
+from milvus import DataType, ParamError, BaseException
 from utils import *
 from constants import *

@@ -36,8 +36,9 @@ class TestInsertBase:
     )
     def get_simple_index(self, request, connect):
         # if str(connect._cmd("mode")) == "CPU":
-        #     if request.param["index_type"] in index_cpu_not_support():
-        #         pytest.skip("CPU not support index_type: ivf_sq8h")
+        if request.param["index_type"] in index_cpu_not_support():
+            pytest.skip("CPU not support index_type: ivf_sq8h")
+        logging.getLogger().info(request.param)
         return request.param

     @pytest.fixture(
@@ -54,6 +55,7 @@ class TestInsertBase:
     def get_vector_field(self, request):
         yield request.param

+    @pytest.mark.tags("0331")
     def test_insert_with_empty_entity(self, connect, collection):
         '''
         target: test insert with empty entity list
@@ -62,7 +64,7 @@ class TestInsertBase:
         '''
         entities = []
         with pytest.raises(ParamError) as e:
-            status, ids = connect.insert(collection, entities)
+            connect.insert(collection, entities)

     def test_insert_with_None(self, connect, collection):
         '''
@@ -71,10 +73,11 @@ class TestInsertBase:
         expected: raises a ParamError
         '''
         entity = None
-        with pytest.raises(Exception) as e:
-            status, ids = connect.insert(collection, entity)
+        with pytest.raises(ParamError) as e:
+            connect.insert(collection, entity)

     @pytest.mark.timeout(ADD_TIMEOUT)
+    @pytest.mark.tags("0331")
     def test_insert_collection_not_existed(self, connect):
         '''
         target: test insert, with collection not existed
@@ -82,10 +85,11 @@ class TestInsertBase:
         expected: raise a BaseException
         '''
         collection_name = gen_unique_str(uid)
-        with pytest.raises(Exception) as e:
+        with pytest.raises(BaseException) as e:
             connect.insert(collection_name, default_entities)

     @pytest.mark.level(2)
+    @pytest.mark.tags("0331")
     def test_insert_without_connect(self, dis_connect, collection):
         '''
         target: test insert entities without connection
@@ -93,28 +97,30 @@ class TestInsertBase:
         expected: raise exception
         '''
         with pytest.raises(Exception) as e:
-            ids = dis_connect.insert(collection, default_entities)
+            dis_connect.insert(collection, default_entities)

     @pytest.mark.timeout(ADD_TIMEOUT)
+    @pytest.mark.tags("0331")
     def test_insert_drop_collection(self, connect, collection):
         '''
         target: test delete collection after insert entities
         method: insert entities and drop collection
         expected: has_collection false
         '''
-        ids = connect.insert(collection, default_entity_row)
+        ids = connect.insert(collection, default_entity)
         assert len(ids) == 1
         connect.drop_collection(collection)
         assert connect.has_collection(collection) == False

     @pytest.mark.timeout(ADD_TIMEOUT)
+    @pytest.mark.tags("0331")
     def test_insert_flush_drop_collection(self, connect, collection):
         '''
         target: test drop collection after insert entities for a while
         method: insert entities, sleep, and delete collection
         expected: has_collection false
         '''
-        ids = connect.insert(collection, default_entity_row)
+        ids = connect.insert(collection, default_entity)
         assert len(ids) == 1
         connect.flush([collection])
         connect.drop_collection(collection)
@@ -133,10 +139,6 @@ class TestInsertBase:
         connect.create_index(collection, field_name, get_simple_index)
         index = connect.describe_index(collection, field_name)
         assert index == get_simple_index
-        # fields = info["fields"]
-        # for field in fields:
-        #     if field["name"] == field_name:
-        #         assert field["indexes"][0] == get_simple_index

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_after_create_index(self, connect, collection, get_simple_index):
@@ -200,9 +202,10 @@ class TestInsertBase:
         assert len(res_ids) == nb
         assert res_ids == ids
         stats = connect.get_collection_stats(id_collection)
-        assert stats["row_count"] == nb
+        assert stats[row_count] == nb

     @pytest.mark.timeout(ADD_TIMEOUT)
+    # @pytest.mark.tags("0331")
     def test_insert_the_same_ids(self, connect, id_collection, insert_count):
         '''
         target: test insert vectors in collection, use customize the same ids
@@ -216,7 +219,7 @@ class TestInsertBase:
         assert len(res_ids) == nb
         assert res_ids == ids
         stats = connect.get_collection_stats(id_collection)
-        assert stats["row_count"] == nb
+        assert stats[row_count] == nb

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field):
@@ -231,16 +234,17 @@ class TestInsertBase:
         collection_name = gen_unique_str("test_collection")
         fields = {
             "fields": [filter_field, vector_field],
-            "auto_id": True
+            "auto_id": False
         }
         connect.create_collection(collection_name, fields)
         ids = [i for i in range(nb)]
-        entities = gen_entities_by_fields(fields["fields"], nb, dim)
+        entities = gen_entities_by_fields(fields["fields"], nb, default_dim)
+        logging.getLogger().info(entities)
         res_ids = connect.insert(collection_name, entities, ids)
         assert res_ids == ids
         connect.flush([collection_name])
-        stats = connect.get_collection_stats(id_collection)
-        assert stats["row_count"] == nb
+        stats = connect.get_collection_stats(collection_name)
+        assert stats[row_count] == nb

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_ids_not_match(self, connect, id_collection, insert_count):
@@ -250,7 +254,7 @@ class TestInsertBase:
         expected: exception raised
         '''
         nb = insert_count
-        with pytest.raises(Exception) as e:
+        with pytest.raises(BaseException) as e:
             connect.insert(id_collection, gen_entities(nb))

     # TODO
@@ -262,9 +266,9 @@ class TestInsertBase:
         expected: BaseException raised
         '''
         ids = [i for i in range(default_nb)]
-        res_ids = connect.insert(id_collection, default_entities, ids)
-        with pytest.raises(Exception) as e:
-            res_ids_new = connect.insert(id_collection, default_entities)
+        connect.insert(collection, default_entities, ids)
+        with pytest.raises(BaseException) as e:
+            connect.insert(collection, default_entities)

     # TODO: assert exception && enable
     @pytest.mark.level(2)
@@ -275,8 +279,8 @@ class TestInsertBase:
         method: test insert vectors twice, use not ids first, and then use customize ids
         expected: error raised
         '''
-        with pytest.raises(Exception) as e:
-            res_ids = connect.insert(id_collection, default_entities)
+        with pytest.raises(BaseException) as e:
+            connect.insert(id_collection, default_entities)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_ids_length_not_match_batch(self, connect, id_collection):
@@ -287,8 +291,8 @@ class TestInsertBase:
         '''
         ids = [i for i in range(1, default_nb)]
         logging.getLogger().info(len(ids))
-        with pytest.raises(Exception) as e:
-            res_ids = connect.insert(id_collection, default_entities, ids)
+        with pytest.raises(BaseException) as e:
+            connect.insert(id_collection, default_entities, ids)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_ids_length_not_match_single(self, connect, id_collection):
@@ -299,8 +303,8 @@ class TestInsertBase:
         '''
         ids = [i for i in range(1, default_nb)]
         logging.getLogger().info(len(ids))
-        with pytest.raises(Exception) as e:
-            res_ids = connect.insert(id_collection, default_entity, ids)
+        with pytest.raises(BaseException) as e:
+            connect.insert(id_collection, default_entity, ids)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_partition(self, connect, collection):
@@ -313,9 +317,9 @@ class TestInsertBase:
         ids = connect.insert(collection, default_entities, partition_tag=default_tag)
         assert len(ids) == default_nb
         assert connect.has_partition(collection, default_tag)
-        connect.flush([collection_name])
-        stats = connect.get_collection_stats(id_collection)
-        assert stats["row_count"] == default_nb
+        connect.flush([collection])
+        stats = connect.get_collection_stats(collection)
+        assert stats[row_count] == default_nb

     # TODO
     @pytest.mark.timeout(ADD_TIMEOUT)
@@ -331,17 +335,18 @@ class TestInsertBase:
         assert res_ids == ids

     @pytest.mark.timeout(ADD_TIMEOUT)
+    @pytest.mark.tags("0331")
     def test_insert_default_partition(self, connect, collection):
         '''
         target: test insert entities into default partition
         method: create partition and insert info collection without tag params
         expected: the collection row count equals to nb
         '''
-        default_tag = "_default"
-        with pytest.raises(Exception) as e:
-            connect.create_partition(collection, default_tag)
+        with pytest.raises(BaseException) as e:
+            connect.create_partition(collection, default_partition_name)

     @pytest.mark.timeout(ADD_TIMEOUT)
+    @pytest.mark.tags("0331")
     def test_insert_partition_not_existed(self, connect, collection):
         '''
         target: test insert entities in collection created before
@@ -350,7 +355,7 @@ class TestInsertBase:
         '''
         tag = gen_unique_str()
         with pytest.raises(Exception) as e:
-            ids = connect.insert(collection, default_entities, partition_tag=tag)
+            connect.insert(collection, default_entities, partition_tag=tag)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_partition_repeatedly(self, connect, collection):
@@ -364,7 +369,7 @@ class TestInsertBase:
         ids = connect.insert(collection, default_entities, partition_tag=default_tag)
         connect.flush([collection])
         res = connect.get_collection_stats(collection)
-        assert res["row_count"] == 2 * default_nb
+        assert res[row_count] == 2 * default_nb

     def test_insert_dim_not_matched(self, connect, collection):
         '''
@@ -375,9 +380,11 @@ class TestInsertBase:
         vectors = gen_vectors(default_nb, int(default_dim) // 2)
         insert_entities = copy.deepcopy(default_entities)
         insert_entities[-1][default_float_vec_field_name] = vectors
+        # logging.getLogger().info(len(insert_entities[-1][default_float_vec_field_name][0]))
         with pytest.raises(Exception) as e:
-            ids = connect.insert(collection, insert_entities)
+            connect.insert(collection, insert_entities)

+    @pytest.mark.tags("0331")
     def test_insert_with_field_name_not_match(self, connect, collection):
         '''
         target: test insert entities, with the entity field name updated
@@ -400,6 +407,7 @@ class TestInsertBase:
             connect.insert(collection, tmp_entity)

     @pytest.mark.level(2)
+    @pytest.mark.tags("0331")
     def test_insert_with_field_value_not_match(self, connect, collection):
         '''
         target: test insert entities, with the entity field value updated
@@ -410,6 +418,7 @@ class TestInsertBase:
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_field_more(self, connect, collection):
         '''
         target: test insert entities, with more fields than collection schema
@@ -420,6 +429,7 @@ class TestInsertBase:
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_field_vector_more(self, connect, collection):
         '''
         target: test insert entities, with more fields than collection schema
@@ -430,6 +440,7 @@ class TestInsertBase:
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_field_less(self, connect, collection):
         '''
         target: test insert entities, with less fields than collection schema
@@ -440,6 +451,7 @@ class TestInsertBase:
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_field_vector_less(self, connect, collection):
         '''
         target: test insert entities, with less fields than collection schema
@@ -450,6 +462,7 @@ class TestInsertBase:
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_no_field_vector_value(self, connect, collection):
         '''
         target: test insert entities, with no vector field value
@@ -461,6 +474,7 @@ class TestInsertBase:
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_no_field_vector_type(self, connect, collection):
         '''
         target: test insert entities, with no vector field type
@@ -472,6 +486,7 @@ class TestInsertBase:
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_no_field_vector_name(self, connect, collection):
         '''
         target: test insert entities, with no vector field name
@@ -537,6 +552,7 @@ class TestInsertBinary:
         request.param["metric_type"] = "JACCARD"
         return request.param

+    # @pytest.mark.tags("0331")
     def test_insert_binary_entities(self, connect, binary_collection):
         '''
         target: test insert entities in binary collection
@@ -545,10 +561,11 @@ class TestInsertBinary:
         '''
         ids = connect.insert(binary_collection, default_binary_entities)
         assert len(ids) == default_nb
-        connect.flush()
+        connect.flush([binary_collection])
         stats = connect.get_collection_stats(binary_collection)
-        assert stats["row_count"] == default_nb
+        assert stats[row_count] == default_nb

+    # @pytest.mark.tags("0331")
     def test_insert_binary_partition(self, connect, binary_collection):
         '''
         target: test insert entities and create partition tag
@@ -559,8 +576,9 @@ class TestInsertBinary:
         ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
         assert len(ids) == default_nb
         assert connect.has_partition(binary_collection, default_tag)
+        connect.flush([binary_collection])
         stats = connect.get_collection_stats(binary_collection)
-        assert stats["row_count"] == default_nb
+        assert stats[row_count] == default_nb

     def test_insert_binary_multi_times(self, connect, binary_collection):
         '''
@@ -573,7 +591,7 @@ class TestInsertBinary:
         assert len(ids) == 1
         connect.flush([binary_collection])
         stats = connect.get_collection_stats(binary_collection)
-        assert stats["row_count"] == default_nb
+        assert stats[row_count] == default_nb

     def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
         '''
@@ -610,7 +628,8 @@ class TestInsertBinary:
         '''
         ids = connect.insert(binary_collection, default_binary_entities)
         connect.flush([binary_collection])
-        query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD")
+        query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1,
+                                        metric_type="JACCARD")
         connect.load_collection(binary_collection)
         res = connect.search(binary_collection, query)
         logging.getLogger().debug(res)
@@ -638,9 +657,10 @@ class TestInsertAsync:
         assert not result

     def check_result(self, result):
-        logging.getLogger().info("In callback check status")
+        logging.getLogger().info("In callback check results")
         assert result

+    @pytest.mark.tags("0331")
     def test_insert_async(self, connect, collection, insert_count):
         '''
         target: test insert vectors with different length of vectors
@@ -654,6 +674,7 @@ class TestInsertAsync:
         assert len(ids) == nb

     @pytest.mark.level(2)
+    @pytest.mark.tags("0331")
     def test_insert_async_false(self, connect, collection, insert_count):
         '''
         target: test insert vectors with different length of vectors
@@ -666,6 +687,7 @@ class TestInsertAsync:
         connect.flush([collection])
         assert len(ids) == nb

+    # @pytest.mark.tags("0331")
     def test_insert_async_callback(self, connect, collection, insert_count):
         '''
         target: test insert vectors with different length of vectors
@@ -675,6 +697,8 @@ class TestInsertAsync:
         nb = insert_count
         future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status)
         future.done()
+        ids = future.result()
+        assert len(ids) == nb

     @pytest.mark.level(2)
     def test_insert_async_long(self, connect, collection):
@@ -685,14 +709,15 @@ class TestInsertAsync:
         '''
         nb = 50000
         future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_result)
-        result = future.result()
-        assert len(result) == nb
+        ids = future.result()
+        assert len(ids) == nb
         connect.flush([collection])
         stats = connect.get_collection_stats(collection)
         logging.getLogger().info(stats)
-        assert stats["row_count"] == nb
+        assert stats[row_count] == nb

     @pytest.mark.level(2)
+    # @pytest.mark.tags("0331")
     def test_insert_async_callback_timeout(self, connect, collection):
         '''
         target: test insert vectors with different length of vectors
@@ -704,7 +729,7 @@ class TestInsertAsync:
         with pytest.raises(Exception) as e:
             result = future.result()
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == 0
+        assert stats[row_count] == 0

     def test_insert_async_invalid_params(self, connect):
         '''
@@ -714,8 +739,9 @@ class TestInsertAsync:
         '''
         collection_new = gen_unique_str()
         future = connect.insert(collection_new, default_entities, _async=True)
+        future.done()
         with pytest.raises(Exception) as e:
-            result = future.result()
+            ids = future.result()

     def test_insert_async_invalid_params_raise_exception(self, connect, collection):
         '''
@@ -747,6 +773,7 @@ class TestInsertMultiCollections:
         #     pytest.skip("sq8h not support in CPU mode")
         return request.param

+    # @pytest.mark.tags("0331")
     def test_insert_entity_multi_collections(self, connect):
         '''
         target: test insert entities
@@ -763,9 +790,10 @@ class TestInsertMultiCollections:
             connect.flush([collection_name])
             assert len(ids) == default_nb
             stats = connect.get_collection_stats(collection_name)
-            assert stats["row_count"] == default_nb
+            assert stats[row_count] == default_nb

     @pytest.mark.timeout(ADD_TIMEOUT)
+    @pytest.mark.tags("0331")
     def test_drop_collection_insert_entity_another(self, connect, collection):
         '''
         target: test insert vector to collection_1 after collection_2 deleted
@@ -780,6 +808,7 @@ class TestInsertMultiCollections:
         assert len(ids) == 1

     @pytest.mark.timeout(ADD_TIMEOUT)
+    @pytest.mark.tags("0331")
     def test_create_index_insert_entity_another(self, connect, collection, get_simple_index):
         '''
         target: test insert vector to collection_2 after build index for collection_1
@@ -807,7 +836,7 @@ class TestInsertMultiCollections:
         index = connect.describe_index(collection_name, field_name)
         assert index == get_simple_index
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == 1
+        assert stats[row_count] == 1

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_entity_sleep_create_index_another(self, connect, collection, get_simple_index):
@@ -822,10 +851,10 @@ class TestInsertMultiCollections:
         connect.flush([collection])
         connect.create_index(collection_name, field_name, get_simple_index)
         stats = connect.get_collection_stats(collection)
-        assert stats["row_count"] == 1
+        assert stats[row_count] == 1

     @pytest.mark.timeout(ADD_TIMEOUT)
-    def test_search_entity_insert_vector_another(self, connect, collection):
+    def test_search_entity_insert_entity_another(self, connect, collection):
         '''
         target: test insert entity to collection_1 after search collection_2
         method: search collection and insert entity
@@ -838,7 +867,7 @@ class TestInsertMultiCollections:
         ids = connect.insert(collection_name, default_entity)
         connect.flush()
         stats = connect.get_collection_stats(collection_name)
-        assert stats["row_count"] == 1
+        assert stats[row_count] == 1

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_entity_search_entity_another(self, connect, collection):
         '''
@@ -876,9 +905,11 @@ class TestInsertMultiCollections:
         connect.insert(collection, default_entities)
         connect.flush([collection])
         connect.load_collection(collection)
+
         def release():
             connect.release_collection(collection)
-        t = threading.Thread(target=release, args=())
+
+        t = threading.Thread(target=release, args=(collection,))
         t.start()
         ids = connect.insert(collection, default_entities)
         assert len(ids) == default_nb
@@ -938,6 +969,7 @@ class TestInsertInvalid(object):
     def get_field_vectors_value(self, request):
         yield request.param

+    @pytest.mark.tags("0331")
     def test_insert_ids_invalid(self, connect, id_collection, get_entity_id):
         '''
         target: test insert, with using customize ids, which are not int64
@@ -949,11 +981,13 @@ class TestInsertInvalid(object):
         with pytest.raises(Exception):
             connect.insert(id_collection, default_entities, ids)

+    @pytest.mark.tags("0331")
     def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
         collection_name = get_collection_name
         with pytest.raises(Exception):
             connect.insert(collection_name, default_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_invalid_partition_name(self, connect, collection, get_tag_name):
         tag_name = get_tag_name
         connect.create_partition(collection, default_tag)
@@ -963,11 +997,13 @@ class TestInsertInvalid(object):
         else:
             connect.insert(collection, default_entity, partition_tag=tag_name)

+    @pytest.mark.tags("0331")
     def test_insert_with_invalid_field_name(self, connect, collection, get_field_name):
         tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name)
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    # @pytest.mark.tags("0331")
     def test_insert_with_invalid_field_type(self, connect, collection, get_field_type):
         field_type = get_field_type
         tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type)
@@ -980,6 +1016,7 @@ class TestInsertInvalid(object):
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)

+    @pytest.mark.tags("0331")
     def test_insert_with_invalid_field_entity_value(self, connect, collection, get_field_vectors_value):
         tmp_entity = copy.deepcopy(default_entity)
         src_vector = tmp_entity[-1]["values"]
@@ -1043,12 +1080,14 @@ class TestInsertInvalidBinary(object):
         yield request.param

     @pytest.mark.level(2)
+    @pytest.mark.tags("0331")
     def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
         tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
         with pytest.raises(Exception):
             connect.insert(binary_collection, tmp_entity)

     @pytest.mark.level(2)
+    # @pytest.mark.tags("0331")
     def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
         tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
         with pytest.raises(Exception):
             connect.insert(binary_collection, tmp_entity)
@@ -1063,6 +1102,7 @@ class TestInsertInvalidBinary(object):
             connect.insert(binary_collection, tmp_entity)

     @pytest.mark.level(2)
+    @pytest.mark.tags("0331")
     def test_insert_ids_invalid(self, connect, binary_id_collection, get_entity_id):
         '''
         target: test insert, with using customize ids, which are not int64
diff --git a/tests/python_test/test_connect.py b/tests/python_test/test_connect.py
index 4aed94f7f4..eb7459f149 100644
--- a/tests/python_test/test_connect.py
+++ b/tests/python_test/test_connect.py
@@ -38,7 +38,7 @@ class TestConnect:
         expected: raise an error after disconnected
         '''
         with pytest.raises(Exception) as e:
-            connect.close()
+            dis_connect.close()

     @pytest.mark.tags("0331")
     def test_connect_correct_ip_port(self, args):
diff --git a/tests/python_test/utils.py b/tests/python_test/utils.py
index 7ad9893f3c..c3a99de079 100644
--- a/tests/python_test/utils.py
+++ b/tests/python_test/utils.py
@@ -31,6 +31,7 @@ default_float_vec_field_name = "float_vector"
 default_binary_vec_field_name = "binary_vector"
 default_partition_name = "_default"
 default_tag = "1970_01_01"
+row_count = "row_count"

 # TODO:
 # TODO: disable RHNSW_SQ/PQ in 0.11.0
@@ -43,6 +44,7 @@ all_index_types = [
     "HNSW",
     # "NSG",
     "ANNOY",
+    "RHNSW_FLAT",
     "RHNSW_PQ",
     "RHNSW_SQ",
     "BIN_FLAT",
@@ -54,10 +56,11 @@ default_index_params = [
     {"nlist": 128},
     {"nlist": 128},
     # {"nlist": 128},
-    {"nlist": 128, "m": 16},
+    {"nlist": 128, "m": 16, "nbits": 8},
     {"M": 48, "efConstruction": 500},
     # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
     {"n_trees": 50},
+    {"M": 48, "efConstruction": 500},
     {"M": 48, "efConstruction": 500, "PQM": 64},
     {"M": 48, "efConstruction": 500},
     {"nlist": 128},