From 75ce32dcad6bb81c9575b973d612bc326eacce14 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Sat, 18 Sep 2021 20:11:52 +0800 Subject: [PATCH] Merge part query cases of pymilvus and orm (#8253) Signed-off-by: ThreadDao --- .../testcases/entity/test_query.py | 124 ---------------- .../python_client/testcases/test_query_20.py | 133 ++++++++++++++++++ 2 files changed, 133 insertions(+), 124 deletions(-) diff --git a/tests/python_client/testcases/entity/test_query.py b/tests/python_client/testcases/entity/test_query.py index 088a7d8c5d..614e06539c 100644 --- a/tests/python_client/testcases/entity/test_query.py +++ b/tests/python_client/testcases/entity/test_query.py @@ -58,130 +58,6 @@ def init_binary_data(connect, collection, nb=3000, insert=True, partition_names= return insert_raw_vectors, insert_entities, ids -class TestQueryBase: - """ - test Query interface - query(collection_name, expr, output_fields=None, partition_names=None, timeout=None) - """ - - @pytest.fixture( - scope="function", - params=ut.gen_invalid_strs() - ) - def get_collection_name(self, request): - yield request.param - - @pytest.fixture( - scope="function", - params=ut.gen_simple_index() - ) - def get_simple_index(self, request, connect): - return request.param - - @pytest.mark.tags(CaseLabel.L0) - def test_query_invalid(self, connect, collection): - """ - target: test query - method: query with term expr - expected: verify query result - """ - entities, ids = init_data(connect, collection) - assert len(ids) == ut.default_nb - connect.load_collection(collection) - term_expr = f'{default_int_field_name} in {entities[:default_pos]}' - with pytest.raises(Exception): - res = connect.query(collection, term_expr) - - @pytest.mark.tags(CaseLabel.L0) - def test_query_valid(self, connect, collection): - """ - target: test query - method: query with term expr - expected: verify query result - """ - entities, ids = init_data(connect, collection) - assert len(ids) == ut.default_nb - connect.load_collection(collection) - term_expr = f'{default_int_field_name} in {ids[:default_pos]}' - res = connect.query(collection, term_expr, output_fields=["*", "%"]) - assert len(res) == default_pos - for _id, index in enumerate(ids[:default_pos]): - if res[index][default_int_field_name] == entities[0]["values"][index]: - assert res[index][default_float_field_name] == entities[1]["values"][index] - res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name]) - assert len(res) == default_pos - for _id, index in enumerate(ids[:default_pos]): - if res[index][default_int_field_name] == entities[0]["values"][index]: - ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index]) - - @pytest.mark.tags(CaseLabel.L0) - def test_query_collection_not_existed(self, connect): - """ - target: test query not existed collection - method: query not existed collection - expected: raise exception - """ - collection = "not_exist" - with pytest.raises(Exception): - connect.query(collection, default_term_expr) - - @pytest.mark.tags(CaseLabel.L0) - def test_query_invalid_collection_name(self, connect, get_collection_name): - """ - target: test query with invalid collection name - method: query with invalid collection name - expected: raise exception - """ - collection_name = get_collection_name - with pytest.raises(Exception): - connect.query(collection_name, default_term_expr) - - @pytest.mark.tags(CaseLabel.L0) - @pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()]) - def test_query_expr_invalid_string(self, connect, collection, expr): - """ - target: test query with non-string expr - method: query with non-string expr, eg 1, [] .. - expected: raise exception - """ - # entities, ids = init_data(connect, collection) - # assert len(ids) == ut.default_nb - connect.load_collection(collection) - with pytest.raises(Exception): - connect.query(collection, expr) - - @pytest.mark.xfail(reason="#6072") - @pytest.mark.tags(CaseLabel.L0) - def test_query_binary_expr_single_term_array(self, connect, binary_collection): - """ - target: test query with single array term expr - method: query with single array value - expected: query result is one entity - """ - _, binary_entities, ids = init_binary_data(connect, binary_collection) - assert len(ids) == ut.default_nb - connect.load_collection(binary_collection) - term_expr = f'{default_int_field_name} in [0]' - res = connect.query(binary_collection, term_expr, output_fields=["*", "%"]) - assert len(res) == 1 - assert res[0][default_int_field_name] == binary_entities[0]["values"][0] - assert res[1][default_float_field_name] == binary_entities[1]["values"][0] - assert res[2][ut.default_float_vec_field_name] == binary_entities[2]["values"][0] - - @pytest.mark.parametrize("fields", ut.gen_invalid_strs()) - @pytest.mark.tags(CaseLabel.L0) - def test_query_invalid_output_fields(self, connect, collection, fields): - """ - target: test query with invalid output fields - method: query with invalid field fields - expected: raise exception - """ - init_data(connect, collection) - connect.load_collection(collection) - with pytest.raises(Exception): - connect.query(collection, default_term_expr, output_fields=[fields]) - - class TestQueryPartition: """ test Query interface diff --git a/tests/python_client/testcases/test_query_20.py b/tests/python_client/testcases/test_query_20.py index 99384153ba..0a2436a500 100644 --- a/tests/python_client/testcases/test_query_20.py +++ b/tests/python_client/testcases/test_query_20.py @@ -11,6 +11,7 @@ from common import common_func as cf from common import common_type as ct from common.common_type import CaseLabel, CheckTasks from utils.util_log import test_log as log +import utils.utils as ut prefix = "query" exp_res = "exp_res" @@ -18,6 +19,11 @@ default_term_expr = f'{ct.default_int64_field_name} in [0, 1]' default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}} binary_index_params = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}} +default_entities = ut.gen_entities(ut.default_nb, is_normal=True) +default_pos = 5 +default_int_field_name = "int64" +default_float_field_name = "float" + class TestQueryBase(TestcaseBase): """ @@ -1056,3 +1062,130 @@ class TestQueryOperation(TestcaseBase): res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name]) assert len(res) == 1 assert res[0][ct.default_int64_field_name] == half + + """ + ****************************************************************** + The following classes are copied from pymilvus test + ****************************************************************** + """ + + +def init_data(connect, collection, nb=ut.default_nb, partition_names=None, auto_id=True): + """ + Generate entities and add it in collection + """ + if nb == 3000: + insert_entities = default_entities + else: + insert_entities = ut.gen_entities(nb, is_normal=True) + if partition_names is None: + if auto_id: + res = connect.insert(collection, insert_entities) + else: + res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)]) + else: + if auto_id: + res = connect.insert(collection, insert_entities, partition_name=partition_names) + else: + res = connect.insert(collection, insert_entities, ids=[i for i in range(nb)], + partition_name=partition_names) + connect.flush([collection]) + ids = res.primary_keys + return insert_entities, ids + + +class TestQueryBase: + """ + test Query interface + query(collection_name, expr, output_fields=None, partition_names=None, timeout=None) + """ + + @pytest.fixture( + scope="function", + params=ut.gen_invalid_strs() + ) + def get_collection_name(self, request): + yield request.param + + @pytest.mark.tags(CaseLabel.L0) + def test_query_invalid(self, connect, collection): + """ + target: test query + method: query with term expr + expected: verify query result + """ + entities, ids = init_data(connect, collection) + assert len(ids) == ut.default_nb + connect.load_collection(collection) + term_expr = f'{default_int_field_name} in {entities[:default_pos]}' + with pytest.raises(Exception): + connect.query(collection, term_expr) + + @pytest.mark.tags(CaseLabel.L0) + def test_query_valid(self, connect, collection): + """ + target: test query + method: query with term expr + expected: verify query result + """ + entities, ids = init_data(connect, collection) + assert len(ids) == ut.default_nb + connect.load_collection(collection) + term_expr = f'{default_int_field_name} in {ids[:default_pos]}' + res = connect.query(collection, term_expr, output_fields=["*", "%"]) + assert len(res) == default_pos + for _id, index in enumerate(ids[:default_pos]): + if res[index][default_int_field_name] == entities[0]["values"][index]: + assert res[index][default_float_field_name] == entities[1]["values"][index] + res = connect.query(collection, term_expr, output_fields=[ut.default_float_vec_field_name]) + assert len(res) == default_pos + for _id, index in enumerate(ids[:default_pos]): + if res[index][default_int_field_name] == entities[0]["values"][index]: + ut.assert_equal_vector(res[index][ut.default_float_vec_field_name], entities[2]["values"][index]) + + @pytest.mark.tags(CaseLabel.L0) + def test_query_collection_not_existed(self, connect): + """ + target: test query not existed collection + method: query not existed collection + expected: raise exception + """ + collection = "not_exist" + with pytest.raises(Exception): + connect.query(collection, default_term_expr) + + @pytest.mark.tags(CaseLabel.L0) + def test_query_invalid_collection_name(self, connect, get_collection_name): + """ + target: test query with invalid collection name + method: query with invalid collection name + expected: raise exception + """ + collection_name = get_collection_name + with pytest.raises(Exception): + connect.query(collection_name, default_term_expr) + + @pytest.mark.tags(CaseLabel.L0) + @pytest.mark.parametrize("expr", [1, "1", "12-s", "中文", [], {}, ()]) + def test_query_expr_invalid_string(self, connect, collection, expr): + """ + target: test query with non-string expr + method: query with non-string expr, eg 1, [] .. + expected: raise exception + """ + connect.load_collection(collection) + with pytest.raises(Exception): + connect.query(collection, expr) + + @pytest.mark.parametrize("fields", ut.gen_invalid_strs()) + @pytest.mark.tags(CaseLabel.L0) + def test_query_invalid_output_fields(self, connect, collection, fields): + """ + target: test query with invalid output fields + method: query with invalid field fields + expected: raise exception + """ + init_data(connect, collection) + connect.load_collection(collection) + with pytest.raises(Exception): + connect.query(collection, default_term_expr, output_fields=[fields])