From bebbc841e10e49fd404eece0da0ab8df9fa2c849 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Wed, 16 Feb 2022 19:35:49 +0800 Subject: [PATCH] [test]Fix testcase assertion (#15595) Signed-off-by: zhuwenxing --- .../python_client/testcases/test_search_20.py | 23 ++++++++++++------- tests/python_client/utils/util_pymilvus.py | 4 ++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/python_client/testcases/test_search_20.py b/tests/python_client/testcases/test_search_20.py index e42694b2b3..b65774440f 100644 --- a/tests/python_client/testcases/test_search_20.py +++ b/tests/python_client/testcases/test_search_20.py @@ -2071,7 +2071,7 @@ class TestCollectionSearch(TestcaseBase): """ -def init_data(connect, collection, nb=3000, partition_names=None, auto_id=True): +def init_data(connect, collection, start=0, nb=3000, partition_names=None, auto_id=True): """ Generate entities and add it in collection """ @@ -2079,7 +2079,7 @@ def init_data(connect, collection, nb=3000, partition_names=None, auto_id=True): if nb == 3000: insert_entities = entities else: - insert_entities = gen_entities(nb, is_normal=True) + insert_entities = gen_entities(nb, start=start, is_normal=True) if partition_names is None: res = connect.insert(collection, insert_entities) else: @@ -2336,20 +2336,21 @@ class TestSearchBase: connect.create_partition(collection, default_tag) connect.create_partition(collection, new_tag) entities, ids = init_data(connect, collection, partition_names=default_tag) - new_entities, new_ids = init_data(connect, collection, nb=6001, partition_names=new_tag) + start = max(ids) + 1 + new_entities, new_ids = init_data(connect, collection, start=start, nb=6001, partition_names=new_tag) connect.create_index(collection, field_name, get_simple_index) search_param = get_search_param(index_type) query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, search_params=search_param) if top_k > max_top_k: with pytest.raises(Exception) as e: - res = 
connect.search(collection, **query) + res = connect.search(collection, partition_names=[default_tag], **query) else: connect.load_collection(collection) - res = connect.search(collection, **query) + res = connect.search(collection, partition_names=[default_tag], **query) assert check_id_result(res[0], ids[0]) assert res[0]._distances[0] < epsilon assert res[1]._distances[0] < epsilon - res = connect.search(collection, **query, partition_names=[new_tag]) + res = connect.search(collection, partition_names=[new_tag], **query) assert res[0]._distances[0] > epsilon assert res[1]._distances[0] > epsilon connect.release_collection(collection) @@ -2447,18 +2448,24 @@ class TestSearchBase: connect.create_partition(collection, default_tag) connect.create_partition(collection, new_tag) entities, ids = init_data(connect, collection, partition_names=default_tag) - new_entities, new_ids = init_data(connect, collection, nb=6001, partition_names=new_tag) + start = max(ids) + 1 + new_entities, new_ids = init_data(connect, collection, start=start, nb=6001, partition_names=new_tag) get_simple_index["metric_type"] = metric_type connect.create_index(collection, field_name, get_simple_index) search_param = get_search_param(index_type) + # query vectors are selected from the default partition query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param) connect.load_collection(collection) - res = connect.search(collection, **query) + # search in the default partition, so the result ids should be in the default partition + res = connect.search(collection, partition_names=[default_tag], **query) assert check_id_result(res[0], ids[0]) + assert not check_id_result(res[1], new_ids[0]) + # the top-1 results of res[0] and res[1] are the query vectors themselves, so the distance is 1 (when metric_type is IP, a distance closer to 1 means more similar) assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) assert res[1]._distances[0] >= 1 - 
gen_inaccuracy(res[1]._distances[0]) res = connect.search(collection, **query, partition_names=["new_tag"]) + # the query vector is selected from the default partition, so the top 1 can't be itself when searching in the new_tag partition, which means the distance is less than 1 assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0]) # TODO: # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0]) diff --git a/tests/python_client/utils/util_pymilvus.py b/tests/python_client/utils/util_pymilvus.py index fe83ff190a..0ff1e5661b 100644 --- a/tests/python_client/utils/util_pymilvus.py +++ b/tests/python_client/utils/util_pymilvus.py @@ -274,10 +274,10 @@ def gen_binary_default_fields(auto_id=True): return default_fields -def gen_entities(nb, is_normal=False): +def gen_entities(nb, start=0, is_normal=False): vectors = gen_vectors(nb, default_dim, is_normal) entities = [ - {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]}, + {"name": "int64", "type": DataType.INT64, "values": [i for i in range(start, nb+start)]}, {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]}, {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors} ]