mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Fix term bug and add todo note
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
This commit is contained in:
parent
fc054874d3
commit
02f855359a
@ -105,10 +105,12 @@ std::unique_ptr<Expr>
|
||||
ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
|
||||
auto expr = std::make_unique<TermExprImpl<T>>();
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(body.is_array());
|
||||
Assert(body.is_object());
|
||||
auto values = body["values"];
|
||||
|
||||
expr->field_id_ = field_name;
|
||||
expr->data_type_ = data_type;
|
||||
for (auto& value : body) {
|
||||
for (auto& value : values) {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(value.is_boolean());
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
|
||||
@ -343,6 +343,7 @@ TEST(Expr, TestTerm) {
|
||||
{R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }},
|
||||
{R"([2000])", [](int v) { return v == 2000; }},
|
||||
{R"([3000])", [](int v) { return v == 3000; }},
|
||||
{R"([])", [](int v) { return false; }},
|
||||
{vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }},
|
||||
};
|
||||
|
||||
@ -352,7 +353,9 @@ TEST(Expr, TestTerm) {
|
||||
"must": [
|
||||
{
|
||||
"term": {
|
||||
"age": @@@@
|
||||
"age": {
|
||||
"values": @@@@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -406,4 +409,4 @@ TEST(Expr, TestTerm) {
|
||||
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -259,6 +259,58 @@ TEST(Query, ExecWithPredicate) {
|
||||
])");
|
||||
ASSERT_EQ(json.dump(2), ref.dump(2));
|
||||
}
|
||||
// Runs a vector search combined with a term predicate whose "values" list is
// empty, then checks only the result's query count and topK.
TEST(Query, ExecTerm) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
// Schema: a VECTOR_FLOAT field "fakevec" with METRIC_L2 (the 16 is presumably
// the vector dimension - TODO confirm) and a FLOAT field "age".
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("age", DataType::FLOAT);
|
||||
// DSL: a "must" clause with a term filter on "age" (empty "values" array)
// and a vector query on "fakevec" bound to placeholder "$0" with topk 5.
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"term": {
|
||||
"age": {
|
||||
"values": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
// Generate N (one million) rows and insert them into a fresh segment.
int64_t N = 1000 * 1000;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = std::make_unique<SegmentSmallIndex>(schema);
|
||||
segment->PreInsert(N);
|
||||
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
|
||||
// Build the plan from the DSL and a placeholder group for 3 queries.
// NOTE(review): the 16 presumably matches the vector dimension above; the
// meaning of 1024 is not visible here - TODO confirm.
auto plan = CreatePlan(*schema, dsl);
|
||||
auto num_queries = 3;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
QueryResult qr;
|
||||
Timestamp time = 1000000;
|
||||
// Search with one placeholder group and one timestamp (the literal 1 is
// presumably the number of groups - TODO confirm); output goes into qr.
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
|
||||
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
|
||||
// NOTE(review): `results` is never used in this test.
std::vector<std::vector<std::string>> results;
|
||||
// Expected topK; the DSL above also requests "topk": 5.
int topk = 5;
|
||||
// The JSON value is not inspected afterwards; this call only exercises
// QueryResultToJson on the search result.
auto json = QueryResultToJson(qr);
|
||||
ASSERT_EQ(qr.num_queries_, num_queries);
|
||||
ASSERT_EQ(qr.topK_, topk);
|
||||
// NOTE(review): unfinished stub below; the returned ids/distances are not
// verified by this test.
// for(auto x: )
|
||||
}
|
||||
|
||||
TEST(Query, ExecWithoutPredicate) {
|
||||
using namespace milvus::query;
|
||||
|
||||
@ -283,6 +283,7 @@ class TestSearchBase:
|
||||
assert res[0]._distances[0] < epsilon
|
||||
assert check_id_result(res[0], ids[0])
|
||||
|
||||
# DOG: TODO INVALID TYPE UNKNOWN
|
||||
@pytest.mark.skip("search_after_index_different_metric_type")
|
||||
def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
@ -454,6 +455,7 @@ class TestSearchBase:
|
||||
# test for ip metric
|
||||
#
|
||||
# TODO: reopen after we support ip flat
|
||||
# DOG: TODO REDUCE
|
||||
@pytest.mark.skip("search_ip_flat")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
|
||||
@ -646,6 +648,7 @@ class TestSearchBase:
|
||||
# TODO:
|
||||
# assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
|
||||
|
||||
# DOG: TODO REDUCE
|
||||
# TODO: reopen after we support ip flat
|
||||
@pytest.mark.skip("search_distance_ip")
|
||||
@pytest.mark.level(2)
|
||||
@ -702,6 +705,7 @@ class TestSearchBase:
|
||||
# TODO:
|
||||
# assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
|
||||
|
||||
# DOG: TODO BINARY
|
||||
@pytest.mark.skip("search_distance_jaccard_flat_index")
|
||||
def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
|
||||
'''
|
||||
@ -718,13 +722,14 @@ class TestSearchBase:
|
||||
res = connect.search(binary_collection, query)
|
||||
assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
|
||||
|
||||
# DOG: TODO INVALID TYPE
|
||||
@pytest.mark.skip("search_distance_jaccard_flat_index_L2")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection):
|
||||
'''
|
||||
target: search binary_collection, and check the result: distance
|
||||
method: compare the return distance value with value computed with L2
|
||||
expected: the return distance equals to the computed value
|
||||
expected: throw error of mismatched metric type
|
||||
'''
|
||||
nq = 1
|
||||
int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
|
||||
@ -735,6 +740,7 @@ class TestSearchBase:
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(binary_collection, query)
|
||||
|
||||
# DOG: TODO BINARY
|
||||
@pytest.mark.skip("search_distance_hamming_flat_index")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_hamming_flat_index(self, connect, binary_collection):
|
||||
@ -752,6 +758,7 @@ class TestSearchBase:
|
||||
res = connect.search(binary_collection, query)
|
||||
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
|
||||
|
||||
# DOG: TODO BINARY
|
||||
@pytest.mark.skip("search_distance_substructure_flat_index")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_substructure_flat_index(self, connect, binary_collection):
|
||||
@ -770,6 +777,7 @@ class TestSearchBase:
|
||||
res = connect.search(binary_collection, query)
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# DOG: TODO BINARY
|
||||
@pytest.mark.skip("search_distance_substructure_flat_index_B")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
|
||||
@ -789,6 +797,7 @@ class TestSearchBase:
|
||||
assert res[1][0].distance <= epsilon
|
||||
assert res[1][0].id == ids[1]
|
||||
|
||||
# DOG: TODO BINARY
|
||||
@pytest.mark.skip("search_distance_superstructure_flat_index")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
|
||||
@ -807,6 +816,7 @@ class TestSearchBase:
|
||||
res = connect.search(binary_collection, query)
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# DOG: TODO BINARY
|
||||
@pytest.mark.skip("search_distance_superstructure_flat_index_B")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
|
||||
@ -828,6 +838,7 @@ class TestSearchBase:
|
||||
assert res[1][0].id in ids
|
||||
assert res[1][0].distance <= epsilon
|
||||
|
||||
# DOG: TODO BINARY
|
||||
@pytest.mark.skip("search_distance_tanimoto_flat_index")
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
|
||||
@ -966,6 +977,7 @@ class TestSearchDSL(object):
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_no_must")
|
||||
def test_query_no_must(self, connect, collection):
|
||||
'''
|
||||
@ -977,6 +989,7 @@ class TestSearchDSL(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_no_vector_term_only")
|
||||
def test_query_no_vector_term_only(self, connect, collection):
|
||||
'''
|
||||
@ -1012,6 +1025,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_wrong_format")
|
||||
def test_query_wrong_format(self, connect, collection):
|
||||
'''
|
||||
@ -1042,7 +1056,7 @@ class TestSearchDSL(object):
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("query_term_value_not_in")
|
||||
# PASS
|
||||
@pytest.mark.level(2)
|
||||
def test_query_term_value_not_in(self, connect, collection):
|
||||
'''
|
||||
@ -1058,8 +1072,7 @@ class TestSearchDSL(object):
|
||||
assert len(res[0]) == 0
|
||||
# TODO:
|
||||
|
||||
# TODO:
|
||||
@pytest.mark.skip("query_term_value_all_in")
|
||||
# PASS
|
||||
@pytest.mark.level(2)
|
||||
def test_query_term_value_all_in(self, connect, collection):
|
||||
'''
|
||||
@ -1074,8 +1087,7 @@ class TestSearchDSL(object):
|
||||
assert len(res[0]) == 1
|
||||
# TODO:
|
||||
|
||||
# TODO:
|
||||
@pytest.mark.skip("query_term_values_not_in")
|
||||
# PASS
|
||||
@pytest.mark.level(2)
|
||||
def test_query_term_values_not_in(self, connect, collection):
|
||||
'''
|
||||
@ -1091,7 +1103,7 @@ class TestSearchDSL(object):
|
||||
assert len(res[0]) == 0
|
||||
# TODO:
|
||||
|
||||
@pytest.mark.skip("query_term_values_all_in")
|
||||
# PASS
|
||||
def test_query_term_values_all_in(self, connect, collection):
|
||||
'''
|
||||
method: build query with vector and term expr, with all term can be filtered
|
||||
@ -1110,7 +1122,7 @@ class TestSearchDSL(object):
|
||||
assert result.id in ids[:limit]
|
||||
# TODO:
|
||||
|
||||
@pytest.mark.skip("query_term_values_parts_in")
|
||||
# PASS
|
||||
def test_query_term_values_parts_in(self, connect, collection):
|
||||
'''
|
||||
method: build query with vector and term expr, with parts of term can be filtered
|
||||
@ -1126,8 +1138,7 @@ class TestSearchDSL(object):
|
||||
assert len(res[0]) == default_top_k
|
||||
# TODO:
|
||||
|
||||
# TODO:
|
||||
@pytest.mark.skip("query_term_values_repeat")
|
||||
# PASS
|
||||
@pytest.mark.level(2)
|
||||
def test_query_term_values_repeat(self, connect, collection):
|
||||
'''
|
||||
@ -1144,6 +1155,7 @@ class TestSearchDSL(object):
|
||||
assert len(res[0]) == 1
|
||||
# TODO:
|
||||
|
||||
# DOG: BUG, please fix
|
||||
@pytest.mark.skip("query_term_value_empty")
|
||||
def test_query_term_value_empty(self, connect, collection):
|
||||
'''
|
||||
@ -1156,6 +1168,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# DOG: TODO TRC
|
||||
@pytest.mark.skip("query_complex_dsl")
|
||||
def test_query_complex_dsl(self, connect, collection):
|
||||
'''
|
||||
@ -1178,6 +1191,7 @@ class TestSearchDSL(object):
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
# TODO
|
||||
@pytest.mark.skip("query_term_key_error")
|
||||
@pytest.mark.level(2)
|
||||
@ -1199,6 +1213,7 @@ class TestSearchDSL(object):
|
||||
def get_invalid_term(self, request):
|
||||
return request.param
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_term_wrong_format")
|
||||
@pytest.mark.level(2)
|
||||
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
|
||||
@ -1213,6 +1228,7 @@ class TestSearchDSL(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO UNKNOWN
|
||||
# TODO
|
||||
@pytest.mark.skip("query_term_field_named_term")
|
||||
@pytest.mark.level(2)
|
||||
@ -1239,6 +1255,7 @@ class TestSearchDSL(object):
|
||||
assert len(res[0]) == default_top_k
|
||||
connect.drop_collection(collection_term)
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_term_one_field_not_existed")
|
||||
@pytest.mark.level(2)
|
||||
def test_query_term_one_field_not_existed(self, connect, collection):
|
||||
@ -1349,6 +1366,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_range_one_field_not_existed")
|
||||
def test_query_range_one_field_not_existed(self, connect, collection):
|
||||
'''
|
||||
@ -1369,6 +1387,7 @@ class TestSearchDSL(object):
|
||||
************************************************************************
|
||||
"""
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_multi_term_has_common")
|
||||
@pytest.mark.level(2)
|
||||
@ -1386,6 +1405,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_multi_term_no_common")
|
||||
@pytest.mark.level(2)
|
||||
@ -1403,6 +1423,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_multi_term_different_fields")
|
||||
def test_query_multi_term_different_fields(self, connect, collection):
|
||||
@ -1420,6 +1441,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_single_term_multi_fields")
|
||||
@pytest.mark.level(2)
|
||||
@ -1437,6 +1459,7 @@ class TestSearchDSL(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_multi_range_has_common")
|
||||
@pytest.mark.level(2)
|
||||
@ -1454,6 +1477,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_multi_range_no_common")
|
||||
@pytest.mark.level(2)
|
||||
@ -1471,6 +1495,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_multi_range_different_fields")
|
||||
@pytest.mark.level(2)
|
||||
@ -1488,6 +1513,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == 0
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_single_range_multi_fields")
|
||||
@pytest.mark.level(2)
|
||||
@ -1511,6 +1537,7 @@ class TestSearchDSL(object):
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_single_term_range_has_common")
|
||||
@pytest.mark.level(2)
|
||||
@ -1528,6 +1555,7 @@ class TestSearchDSL(object):
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
|
||||
# DOG: TODO TRC
|
||||
# TODO
|
||||
@pytest.mark.skip("query_single_term_range_no_common")
|
||||
def test_query_single_term_range_no_common(self, connect, collection):
|
||||
@ -1588,6 +1616,7 @@ class TestSearchDSLBools(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_should_only_term")
|
||||
def test_query_should_only_term(self, connect, collection):
|
||||
'''
|
||||
@ -1599,6 +1628,7 @@ class TestSearchDSLBools(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_should_only_vector")
|
||||
def test_query_should_only_vector(self, connect, collection):
|
||||
'''
|
||||
@ -1610,6 +1640,7 @@ class TestSearchDSLBools(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_must_not_only_term")
|
||||
def test_query_must_not_only_term(self, connect, collection):
|
||||
'''
|
||||
@ -1621,6 +1652,7 @@ class TestSearchDSLBools(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_must_not_vector")
|
||||
def test_query_must_not_vector(self, connect, collection):
|
||||
'''
|
||||
@ -1632,6 +1664,7 @@ class TestSearchDSLBools(object):
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.search(collection, query)
|
||||
|
||||
# DOG: TODO INVALID DSL
|
||||
@pytest.mark.skip("query_must_should")
|
||||
def test_query_must_should(self, connect, collection):
|
||||
'''
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user