Fix term bug and add todo note

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
This commit is contained in:
FluorineDog 2020-12-10 20:13:37 +08:00 committed by yefu.chen
parent fc054874d3
commit 02f855359a
4 changed files with 104 additions and 14 deletions

View File

@ -105,10 +105,12 @@ std::unique_ptr<Expr>
ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
auto expr = std::make_unique<TermExprImpl<T>>();
auto data_type = schema[field_name].get_data_type();
Assert(body.is_array());
Assert(body.is_object());
auto values = body["values"];
expr->field_id_ = field_name;
expr->data_type_ = data_type;
for (auto& value : body) {
for (auto& value : values) {
if constexpr (std::is_same_v<T, bool>) {
Assert(value.is_boolean());
} else if constexpr (std::is_integral_v<T>) {

View File

@ -343,6 +343,7 @@ TEST(Expr, TestTerm) {
{R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }},
{R"([2000])", [](int v) { return v == 2000; }},
{R"([3000])", [](int v) { return v == 3000; }},
{R"([])", [](int v) { return false; }},
{vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }},
};
@ -352,7 +353,9 @@ TEST(Expr, TestTerm) {
"must": [
{
"term": {
"age": @@@@
"age": {
"values": @@@@
}
}
},
{
@ -406,4 +409,4 @@ TEST(Expr, TestTerm) {
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
}
}
}
}

View File

@ -259,6 +259,58 @@ TEST(Query, ExecWithPredicate) {
])");
ASSERT_EQ(json.dump(2), ref.dump(2));
}
// Runs a vector search whose scalar filter is a `term` clause with an empty
// `values` list and checks the shape (query count and topK) of the result.
// The contents of the hits are not verified yet.
TEST(Query, ExecTerm) {
    using namespace milvus::query;
    using namespace milvus::segcore;

    // Schema: one 16-dim float vector field (L2 metric) plus one float scalar.
    auto schema = std::make_shared<Schema>();
    schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
    schema->AddField("age", DataType::FLOAT);

    // `term.age.values` is intentionally empty.
    std::string dsl = R"({
"bool": {
"must": [
{
"term": {
"age": {
"values": []
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";

    // Generate N rows and load them into a fresh segment.
    int64_t N = 1000 * 1000;
    auto dataset = DataGen(schema, N);
    auto segment = std::make_unique<SegmentSmallIndex>(schema);
    segment->PreInsert(N);
    segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);

    // Build the plan and a placeholder group holding `num_queries` query vectors.
    auto plan = CreatePlan(*schema, dsl);
    auto num_queries = 3;
    // NOTE(review): 16 matches the vector dimension above; the meaning of the
    // third argument (1024) is not visible here - confirm against the helper.
    auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
    auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

    QueryResult qr;
    Timestamp time = 1000000;
    std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
    segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);

    // Must stay in sync with "topk" in the DSL above.
    int topk = 5;
    // The JSON conversion is still invoked, but its output is not inspected.
    auto json = QueryResultToJson(qr);
    ASSERT_EQ(qr.num_queries_, num_queries);
    ASSERT_EQ(qr.topK_, topk);
    // TODO: assert on the returned hits once the expected result for an empty
    // term list is pinned down.
}
TEST(Query, ExecWithoutPredicate) {
using namespace milvus::query;

View File

@ -283,6 +283,7 @@ class TestSearchBase:
assert res[0]._distances[0] < epsilon
assert check_id_result(res[0], ids[0])
# DOG: TODO INVALID TYPE UNKNOWN
@pytest.mark.skip("search_after_index_different_metric_type")
def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index):
'''
@ -454,6 +455,7 @@ class TestSearchBase:
# test for ip metric
#
# TODO: reopen after we support ip flat
# DOG: TODO REDUCE
@pytest.mark.skip("search_ip_flat")
@pytest.mark.level(2)
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
@ -646,6 +648,7 @@ class TestSearchBase:
# TODO:
# assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
# DOG: TODO REDUCE
# TODO: reopen after we support ip flat
@pytest.mark.skip("search_distance_ip")
@pytest.mark.level(2)
@ -702,6 +705,7 @@ class TestSearchBase:
# TODO:
# assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_jaccard_flat_index")
def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
'''
@ -718,13 +722,14 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
# DOG: TODO INVALID TYPE
@pytest.mark.skip("search_distance_jaccard_flat_index_L2")
@pytest.mark.level(2)
def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection):
'''
target: search binary_collection, and check the result: distance
method: compare the return distance value with value computed with L2
expected: the return distance equals to the computed value
expected: throw error of mismatched metric type
'''
nq = 1
int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
@ -735,6 +740,7 @@ class TestSearchBase:
with pytest.raises(Exception) as e:
res = connect.search(binary_collection, query)
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_hamming_flat_index")
@pytest.mark.level(2)
def test_search_distance_hamming_flat_index(self, connect, binary_collection):
@ -752,6 +758,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_substructure_flat_index")
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index(self, connect, binary_collection):
@ -770,6 +777,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_substructure_flat_index_B")
@pytest.mark.level(2)
def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
@ -789,6 +797,7 @@ class TestSearchBase:
assert res[1][0].distance <= epsilon
assert res[1][0].id == ids[1]
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_superstructure_flat_index")
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
@ -807,6 +816,7 @@ class TestSearchBase:
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_superstructure_flat_index_B")
@pytest.mark.level(2)
def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
@ -828,6 +838,7 @@ class TestSearchBase:
assert res[1][0].id in ids
assert res[1][0].distance <= epsilon
# DOG: TODO BINARY
@pytest.mark.skip("search_distance_tanimoto_flat_index")
@pytest.mark.level(2)
def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
@ -966,6 +977,7 @@ class TestSearchDSL(object):
******************************************************************
"""
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_no_must")
def test_query_no_must(self, connect, collection):
'''
@ -977,6 +989,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_no_vector_term_only")
def test_query_no_vector_term_only(self, connect, collection):
'''
@ -1012,6 +1025,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_wrong_format")
def test_query_wrong_format(self, connect, collection):
'''
@ -1042,7 +1056,7 @@ class TestSearchDSL(object):
******************************************************************
"""
@pytest.mark.skip("query_term_value_not_in")
# PASS
@pytest.mark.level(2)
def test_query_term_value_not_in(self, connect, collection):
'''
@ -1058,8 +1072,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 0
# TODO:
# TODO:
@pytest.mark.skip("query_term_value_all_in")
# PASS
@pytest.mark.level(2)
def test_query_term_value_all_in(self, connect, collection):
'''
@ -1074,8 +1087,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 1
# TODO:
# TODO:
@pytest.mark.skip("query_term_values_not_in")
# PASS
@pytest.mark.level(2)
def test_query_term_values_not_in(self, connect, collection):
'''
@ -1091,7 +1103,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 0
# TODO:
@pytest.mark.skip("query_term_values_all_in")
# PASS
def test_query_term_values_all_in(self, connect, collection):
'''
method: build query with vector and term expr, with all term can be filtered
@ -1110,7 +1122,7 @@ class TestSearchDSL(object):
assert result.id in ids[:limit]
# TODO:
@pytest.mark.skip("query_term_values_parts_in")
# PASS
def test_query_term_values_parts_in(self, connect, collection):
'''
method: build query with vector and term expr, with parts of term can be filtered
@ -1126,8 +1138,7 @@ class TestSearchDSL(object):
assert len(res[0]) == default_top_k
# TODO:
# TODO:
@pytest.mark.skip("query_term_values_repeat")
# PASS
@pytest.mark.level(2)
def test_query_term_values_repeat(self, connect, collection):
'''
@ -1144,6 +1155,7 @@ class TestSearchDSL(object):
assert len(res[0]) == 1
# TODO:
# DOG: BUG, please fix
@pytest.mark.skip("query_term_value_empty")
def test_query_term_value_empty(self, connect, collection):
'''
@ -1156,6 +1168,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
@pytest.mark.skip("query_complex_dsl")
def test_query_complex_dsl(self, connect, collection):
'''
@ -1178,6 +1191,7 @@ class TestSearchDSL(object):
******************************************************************
"""
# DOG: TODO INVALID DSL
# TODO
@pytest.mark.skip("query_term_key_error")
@pytest.mark.level(2)
@ -1199,6 +1213,7 @@ class TestSearchDSL(object):
def get_invalid_term(self, request):
return request.param
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_term_wrong_format")
@pytest.mark.level(2)
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
@ -1213,6 +1228,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO UNKNOWN
# TODO
@pytest.mark.skip("query_term_field_named_term")
@pytest.mark.level(2)
@ -1239,6 +1255,7 @@ class TestSearchDSL(object):
assert len(res[0]) == default_top_k
connect.drop_collection(collection_term)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_term_one_field_not_existed")
@pytest.mark.level(2)
def test_query_term_one_field_not_existed(self, connect, collection):
@ -1349,6 +1366,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_range_one_field_not_existed")
def test_query_range_one_field_not_existed(self, connect, collection):
'''
@ -1369,6 +1387,7 @@ class TestSearchDSL(object):
************************************************************************
"""
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_term_has_common")
@pytest.mark.level(2)
@ -1386,6 +1405,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_term_no_common")
@pytest.mark.level(2)
@ -1403,6 +1423,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_term_different_fields")
def test_query_multi_term_different_fields(self, connect, collection):
@ -1420,6 +1441,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_term_multi_fields")
@pytest.mark.level(2)
@ -1437,6 +1459,7 @@ class TestSearchDSL(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_range_has_common")
@pytest.mark.level(2)
@ -1454,6 +1477,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_range_no_common")
@pytest.mark.level(2)
@ -1471,6 +1495,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_multi_range_different_fields")
@pytest.mark.level(2)
@ -1488,6 +1513,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == 0
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_range_multi_fields")
@pytest.mark.level(2)
@ -1511,6 +1537,7 @@ class TestSearchDSL(object):
******************************************************************
"""
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_term_range_has_common")
@pytest.mark.level(2)
@ -1528,6 +1555,7 @@ class TestSearchDSL(object):
assert len(res) == nq
assert len(res[0]) == default_top_k
# DOG: TODO TRC
# TODO
@pytest.mark.skip("query_single_term_range_no_common")
def test_query_single_term_range_no_common(self, connect, collection):
@ -1588,6 +1616,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_should_only_term")
def test_query_should_only_term(self, connect, collection):
'''
@ -1599,6 +1628,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_should_only_vector")
def test_query_should_only_vector(self, connect, collection):
'''
@ -1610,6 +1640,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_must_not_only_term")
def test_query_must_not_only_term(self, connect, collection):
'''
@ -1621,6 +1652,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_must_not_vector")
def test_query_must_not_vector(self, connect, collection):
'''
@ -1632,6 +1664,7 @@ class TestSearchDSLBools(object):
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
# DOG: TODO INVALID DSL
@pytest.mark.skip("query_must_should")
def test_query_must_should(self, connect, collection):
'''