From 02f855359a8e21e593f1601bfd45ce52147ed317 Mon Sep 17 00:00:00 2001 From: FluorineDog Date: Thu, 10 Dec 2020 20:13:37 +0800 Subject: [PATCH] Fix term bug and add todo note Signed-off-by: FluorineDog --- internal/core/src/query/Plan.cpp | 6 ++- internal/core/unittest/test_expr.cpp | 7 +++- internal/core/unittest/test_query.cpp | 52 ++++++++++++++++++++++++++ tests/python/test_search.py | 53 ++++++++++++++++++++++----- 4 files changed, 104 insertions(+), 14 deletions(-) diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 8e2a89cfb0..bcf1090712 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -105,10 +105,12 @@ std::unique_ptr ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) { auto expr = std::make_unique>(); auto data_type = schema[field_name].get_data_type(); - Assert(body.is_array()); + Assert(body.is_object()); + auto values = body["values"]; + expr->field_id_ = field_name; expr->data_type_ = data_type; - for (auto& value : body) { + for (auto& value : values) { if constexpr (std::is_same_v) { Assert(value.is_boolean()); } else if constexpr (std::is_integral_v) { diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 2d968f09cf..22471640f2 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -343,6 +343,7 @@ TEST(Expr, TestTerm) { {R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }}, {R"([2000])", [](int v) { return v == 2000; }}, {R"([3000])", [](int v) { return v == 3000; }}, + {R"([])", [](int v) { return false; }}, {vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }}, }; @@ -352,7 +353,9 @@ TEST(Expr, TestTerm) { "must": [ { "term": { - "age": @@@@ + "age": { + "values": @@@@ + } } }, { @@ -406,4 +409,4 @@ TEST(Expr, TestTerm) { ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } } -} \ No newline at end of file +} diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 5bb3a75a4d..c51935ff0d 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -259,6 +259,58 @@ TEST(Query, ExecWithPredicate) { ])"); ASSERT_EQ(json.dump(2), ref.dump(2)); } +TEST(Query, ExecTerm) { + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("age", DataType::FLOAT); + std::string dsl = R"({ + "bool": { + "must": [ + { + "term": { + "age": { + "values": [] + } + } + }, + { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 5 + } + } + } + ] + } + })"; + int64_t N = 1000 * 1000; + auto dataset = DataGen(schema, N); + auto segment = std::make_unique(schema); + segment->PreInsert(N); + segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_); + + auto plan = CreatePlan(*schema, dsl); + auto num_queries = 3; + auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); + auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + QueryResult qr; + Timestamp time = 1000000; + std::vector ph_group_arr = {ph_group.get()}; + segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr); + std::vector> results; + int topk = 5; + auto json = QueryResultToJson(qr); + ASSERT_EQ(qr.num_queries_, num_queries); + ASSERT_EQ(qr.topK_, topk); + // for(auto x: ) +} TEST(Query, ExecWithoutPredicate) { using namespace milvus::query; diff --git a/tests/python/test_search.py b/tests/python/test_search.py index 3b546d15f3..947de240c3 100644 --- a/tests/python/test_search.py +++ b/tests/python/test_search.py @@ -283,6 +283,7 @@ class TestSearchBase: assert res[0]._distances[0] < epsilon assert check_id_result(res[0], ids[0]) + # DOG: TODO INVALID TYPE UNKNOWN @pytest.mark.skip("search_after_index_different_metric_type") def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index): ''' @@ -454,6 +455,7 @@ class TestSearchBase: # test for ip metric # # TODO: reopen after we supporting ip flat + # DOG: TODO REDUCE @pytest.mark.skip("search_ip_flat") @pytest.mark.level(2) def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq): @@ -646,6 +648,7 @@ class TestSearchBase: # TODO: # assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon + # DOG: TODO REDUCE # TODO: reopen after we supporting ip flat @pytest.mark.skip("search_distance_ip") @pytest.mark.level(2) @@ -702,6 +705,7 @@ class TestSearchBase: # TODO: # assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon + # DOG: TODO BINARY @pytest.mark.skip("search_distance_jaccard_flat_index") def test_search_distance_jaccard_flat_index(self, connect, binary_collection): ''' @@ -718,13 +722,14 @@ class TestSearchBase: res = connect.search(binary_collection, query) assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon + # DOG: TODO INVALID TYPE @pytest.mark.skip("search_distance_jaccard_flat_index_L2") @pytest.mark.level(2) def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection): ''' target: search binary_collection, and check the result: distance method: compare the return distance value with value computed with L2 - expected: the return distance equals to the computed value + expected: throw error of mismatched metric type ''' nq = 1 int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2) @@ -735,6 +740,7 @@ class TestSearchBase: with pytest.raises(Exception) as e: res = connect.search(binary_collection, query) + # DOG: TODO BINARY @pytest.mark.skip("search_distance_hamming_flat_index") @pytest.mark.level(2) def test_search_distance_hamming_flat_index(self, connect, binary_collection): @@ -752,6 +758,7 @@ class TestSearchBase: res = connect.search(binary_collection, query) assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon + # DOG: TODO BINARY @pytest.mark.skip("search_distance_substructure_flat_index") @pytest.mark.level(2) def test_search_distance_substructure_flat_index(self, connect, binary_collection): @@ -770,6 +777,7 @@ class TestSearchBase: res = connect.search(binary_collection, query) assert len(res[0]) == 0 + # DOG: TODO BINARY @pytest.mark.skip("search_distance_substructure_flat_index_B") @pytest.mark.level(2) def test_search_distance_substructure_flat_index_B(self, connect, binary_collection): @@ -789,6 +797,7 @@ class TestSearchBase: assert res[1][0].distance <= epsilon assert res[1][0].id == ids[1] + # DOG: TODO BINARY @pytest.mark.skip("search_distance_superstructure_flat_index") @pytest.mark.level(2) def test_search_distance_superstructure_flat_index(self, connect, binary_collection): @@ -807,6 +816,7 @@ class TestSearchBase: res = connect.search(binary_collection, query) assert len(res[0]) == 0 + # DOG: TODO BINARY @pytest.mark.skip("search_distance_superstructure_flat_index_B") @pytest.mark.level(2) def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection): @@ -828,6 +838,7 @@ class TestSearchBase: assert res[1][0].id in ids assert res[1][0].distance <= epsilon + # DOG: TODO BINARY @pytest.mark.skip("search_distance_tanimoto_flat_index") @pytest.mark.level(2) def test_search_distance_tanimoto_flat_index(self, connect, binary_collection): @@ -966,6 +977,7 @@ class TestSearchDSL(object): ****************************************************************** """ + # DOG: TODO INVALID DSL @pytest.mark.skip("query_no_must") def test_query_no_must(self, connect, collection): ''' @@ -977,6 +989,7 @@ class TestSearchDSL(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO INVALID DSL @pytest.mark.skip("query_no_vector_term_only") def test_query_no_vector_term_only(self, connect, collection): ''' @@ -1012,6 +1025,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == default_top_k + # DOG: TODO INVALID DSL @pytest.mark.skip("query_wrong_format") def test_query_wrong_format(self, connect, collection): ''' @@ -1042,7 +1056,7 @@ class TestSearchDSL(object): ****************************************************************** """ - @pytest.mark.skip("query_term_value_not_in") + # PASS @pytest.mark.level(2) def test_query_term_value_not_in(self, connect, collection): ''' @@ -1058,8 +1072,7 @@ class TestSearchDSL(object): assert len(res[0]) == 0 # TODO: - # TODO: - @pytest.mark.skip("query_term_value_all_in") + # PASS @pytest.mark.level(2) def test_query_term_value_all_in(self, connect, collection): ''' @@ -1074,8 +1087,7 @@ class TestSearchDSL(object): assert len(res[0]) == 1 # TODO: - # TODO: - @pytest.mark.skip("query_term_values_not_in") + # PASS @pytest.mark.level(2) def test_query_term_values_not_in(self, connect, collection): ''' @@ -1091,7 +1103,7 @@ class TestSearchDSL(object): assert len(res[0]) == 0 # TODO: - @pytest.mark.skip("query_term_values_all_in") + # PASS def test_query_term_values_all_in(self, connect, collection): ''' method: build query with vector and term expr, with all term can be filtered @@ -1110,7 +1122,7 @@ class TestSearchDSL(object): assert result.id in ids[:limit] # TODO: - @pytest.mark.skip("query_term_values_parts_in") + # PASS def test_query_term_values_parts_in(self, connect, collection): ''' method: build query with vector and term expr, with parts of term can be filtered @@ -1126,8 +1138,7 @@ class TestSearchDSL(object): assert len(res[0]) == default_top_k # TODO: - # TODO: - @pytest.mark.skip("query_term_values_repeat") + # PASS @pytest.mark.level(2) def test_query_term_values_repeat(self, connect, collection): ''' @@ -1144,6 +1155,7 @@ class TestSearchDSL(object): assert len(res[0]) == 1 # TODO: + # DOG: BUG, please fix @pytest.mark.skip("query_term_value_empty") def test_query_term_value_empty(self, connect, collection): ''' @@ -1156,6 +1168,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == 0 + # DOG: TODO TRC @pytest.mark.skip("query_complex_dsl") def test_query_complex_dsl(self, connect, collection): ''' @@ -1178,6 +1191,7 @@ class TestSearchDSL(object): ****************************************************************** """ + # DOG: TODO INVALID DSL # TODO @pytest.mark.skip("query_term_key_error") @pytest.mark.level(2) @@ -1199,6 +1213,7 @@ class TestSearchDSL(object): def get_invalid_term(self, request): return request.param + # DOG: TODO INVALID DSL @pytest.mark.skip("query_term_wrong_format") @pytest.mark.level(2) def test_query_term_wrong_format(self, connect, collection, get_invalid_term): @@ -1213,6 +1228,7 @@ class TestSearchDSL(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO UNKNOWN # TODO @pytest.mark.skip("query_term_field_named_term") @pytest.mark.level(2) @@ -1239,6 +1255,7 @@ class TestSearchDSL(object): assert len(res[0]) == default_top_k connect.drop_collection(collection_term) + # DOG: TODO INVALID DSL @pytest.mark.skip("query_term_one_field_not_existed") @pytest.mark.level(2) def test_query_term_one_field_not_existed(self, connect, collection): @@ -1349,6 +1366,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == default_top_k + # DOG: TODO INVALID DSL @pytest.mark.skip("query_range_one_field_not_existed") def test_query_range_one_field_not_existed(self, connect, collection): ''' @@ -1369,6 +1387,7 @@ class TestSearchDSL(object): ************************************************************************ """ + # DOG: TODO TRC # TODO @pytest.mark.skip("query_multi_term_has_common") @pytest.mark.level(2) @@ -1386,6 +1405,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == default_top_k + # DOG: TODO TRC # TODO @pytest.mark.skip("query_multi_term_no_common") @pytest.mark.level(2) @@ -1403,6 +1423,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == 0 + # DOG: TODO TRC # TODO @pytest.mark.skip("query_multi_term_different_fields") def test_query_multi_term_different_fields(self, connect, collection): @@ -1420,6 +1441,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == 0 + # DOG: TODO TRC # TODO @pytest.mark.skip("query_single_term_multi_fields") @pytest.mark.level(2) @@ -1437,6 +1459,7 @@ class TestSearchDSL(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO TRC # TODO @pytest.mark.skip("query_multi_range_has_common") @pytest.mark.level(2) @@ -1454,6 +1477,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == default_top_k + # DOG: TODO TRC # TODO @pytest.mark.skip("query_multi_range_no_common") @pytest.mark.level(2) @@ -1471,6 +1495,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == 0 + # DOG: TODO TRC # TODO @pytest.mark.skip("query_multi_range_different_fields") @pytest.mark.level(2) @@ -1488,6 +1513,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == 0 + # DOG: TODO TRC # TODO @pytest.mark.skip("query_single_range_multi_fields") @pytest.mark.level(2) @@ -1511,6 +1537,7 @@ class TestSearchDSL(object): ****************************************************************** """ + # DOG: TODO TRC # TODO @pytest.mark.skip("query_single_term_range_has_common") @pytest.mark.level(2) @@ -1528,6 +1555,7 @@ class TestSearchDSL(object): assert len(res) == nq assert len(res[0]) == default_top_k + # DOG: TODO TRC # TODO @pytest.mark.skip("query_single_term_range_no_common") def test_query_single_term_range_no_common(self, connect, collection): @@ -1588,6 +1616,7 @@ class TestSearchDSLBools(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO INVALID DSL @pytest.mark.skip("query_should_only_term") def test_query_should_only_term(self, connect, collection): ''' @@ -1599,6 +1628,7 @@ class TestSearchDSLBools(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO INVALID DSL @pytest.mark.skip("query_should_only_vector") def test_query_should_only_vector(self, connect, collection): ''' @@ -1610,6 +1640,7 @@ class TestSearchDSLBools(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO INVALID DSL @pytest.mark.skip("query_must_not_only_term") def test_query_must_not_only_term(self, connect, collection): ''' @@ -1621,6 +1652,7 @@ class TestSearchDSLBools(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO INVALID DSL @pytest.mark.skip("query_must_not_vector") def test_query_must_not_vector(self, connect, collection): ''' @@ -1632,6 +1664,7 @@ class TestSearchDSLBools(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) + # DOG: TODO INVALID DSL @pytest.mark.skip("query_must_should") def test_query_must_should(self, connect, collection): '''