From e379b1f0f4317d1599ec3dea653324356ec75ee6 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Wed, 24 Dec 2025 00:39:19 +0800 Subject: [PATCH] enhance: moved query optimization to proxy, added various optimizations (#45526) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit issue: https://github.com/milvus-io/milvus/issues/45525 see added README.md for added optimizations ## Summary by CodeRabbit * **New Features** * Added query expression optimization feature with a new `optimizeExpr` configuration flag to enable automatic simplification of filter predicates, including range predicate optimization, merging of IN/NOT IN conditions, and flattening of nested logical operators. * **Bug Fixes** * Adjusted delete operation behavior to correctly handle expression evaluation. ✏️ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: Buqian Zheng --- internal/core/src/exec/expression/Expr.cpp | 226 +---- internal/core/unittest/test_exec.cpp | 701 ------------- .../parser/planparserv2/plan_parser_v2.go | 2 + .../parser/planparserv2/rewriter/README.md | 138 +++ .../parser/planparserv2/rewriter/entry.go | 200 ++++ .../parser/planparserv2/rewriter/range.go | 923 ++++++++++++++++++ .../rewriter/range_binary_test.go | 412 ++++++++ .../planparserv2/rewriter/range_json_test.go | 294 ++++++ .../planparserv2/rewriter/range_test.go | 492 ++++++++++ .../parser/planparserv2/rewriter/term_in.go | 666 +++++++++++++ .../planparserv2/rewriter/term_in_test.go | 356 +++++++ .../planparserv2/rewriter/text_match.go | 78 ++ .../planparserv2/rewriter/text_match_test.go | 122 +++ internal/parser/planparserv2/rewriter/util.go | 290 ++++++ internal/proxy/task_search_test.go | 2 +- internal/util/exprutil/expr_checker_test.go | 8 +- tests/go_client/testcases/delete_test.go | 30 - tests/python_client/testcases/test_delete.py | 20 +- 18 files changed, 3982 insertions(+), 978 deletions(-) create mode 100644 
internal/parser/planparserv2/rewriter/README.md create mode 100644 internal/parser/planparserv2/rewriter/entry.go create mode 100644 internal/parser/planparserv2/rewriter/range.go create mode 100644 internal/parser/planparserv2/rewriter/range_binary_test.go create mode 100644 internal/parser/planparserv2/rewriter/range_json_test.go create mode 100644 internal/parser/planparserv2/rewriter/range_test.go create mode 100644 internal/parser/planparserv2/rewriter/term_in.go create mode 100644 internal/parser/planparserv2/rewriter/term_in_test.go create mode 100644 internal/parser/planparserv2/rewriter/text_match.go create mode 100644 internal/parser/planparserv2/rewriter/text_match_test.go create mode 100644 internal/parser/planparserv2/rewriter/util.go diff --git a/internal/core/src/exec/expression/Expr.cpp b/internal/core/src/exec/expression/Expr.cpp index 83f12116f2..15d7da1918 100644 --- a/internal/core/src/exec/expression/Expr.cpp +++ b/internal/core/src/exec/expression/Expr.cpp @@ -18,7 +18,6 @@ #include "common/EasyAssert.h" #include "common/Tracer.h" -#include "fmt/format.h" #include "exec/expression/AlwaysTrueExpr.h" #include "exec/expression/BinaryArithOpEvalRangeExpr.h" #include "exec/expression/BinaryRangeExpr.h" @@ -84,75 +83,19 @@ CompileExpressions(const std::vector& sources, return exprs; } -static std::optional -ShouldFlatten(const expr::TypedExprPtr& expr, - const std::unordered_set& flat_candidates = {}) { - if (auto call = - std::dynamic_pointer_cast(expr)) { - if (call->op_type_ == expr::LogicalBinaryExpr::OpType::And || - call->op_type_ == expr::LogicalBinaryExpr::OpType::Or) { - return call->name(); - } - } - return std::nullopt; -} - -static bool -IsCall(const expr::TypedExprPtr& expr, const std::string& name) { - if (auto call = - std::dynamic_pointer_cast(expr)) { - return call->name() == name; - } - return false; -} - -static bool -AllInputTypeEqual(const expr::TypedExprPtr& expr) { - const auto& inputs = expr->inputs(); - for (int i = 1; i < 
inputs.size(); i++) { - if (inputs[0]->type() != inputs[i]->type()) { - return false; - } - } - return true; -} - -static void -FlattenInput(const expr::TypedExprPtr& input, - const std::string& flatten_call, - std::vector& flat) { - if (IsCall(input, flatten_call) && AllInputTypeEqual(input)) { - for (auto& child : input->inputs()) { - FlattenInput(child, flatten_call, flat); - } - } else { - flat.emplace_back(input); - } -} - std::vector CompileInputs(const expr::TypedExprPtr& expr, QueryContext* context, const std::unordered_set& flatten_cadidates) { std::vector compiled_inputs; - auto flatten = ShouldFlatten(expr); for (auto& input : expr->inputs()) { if (dynamic_cast(input.get())) { AssertInfo( dynamic_cast(expr.get()), "An InputReference can only occur under a FieldReference"); } else { - if (flatten.has_value()) { - std::vector flat_exprs; - FlattenInput(input, flatten.value(), flat_exprs); - for (auto& input : flat_exprs) { - compiled_inputs.push_back(CompileExpression( - input, context, flatten_cadidates, false)); - } - } else { - compiled_inputs.push_back(CompileExpression( - input, context, flatten_cadidates, false)); - } + compiled_inputs.push_back( + CompileExpression(input, context, flatten_cadidates, false)); } } return compiled_inputs; @@ -507,170 +450,6 @@ ReorderConjunctExpr(std::shared_ptr& expr, expr->Reorder(reorder); } -inline std::shared_ptr -ConvertMultiNotEqualToNotInExpr(std::vector>& exprs, - std::vector indices, - ExecContext* context) { - std::vector values; - auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET; - for (auto& i : indices) { - auto expr = std::static_pointer_cast(exprs[i]) - ->GetLogicalExpr(); - if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) { - type = expr->val_.val_case(); - } - if (type != expr->val_.val_case()) { - return nullptr; - } - values.push_back(expr->val_); - } - auto logical_expr = std::make_shared( - exprs[indices[0]]->GetColumnInfo().value(), values); - auto query_context = 
context->get_query_context(); - auto term_expr = std::make_shared( - std::vector>{}, - logical_expr, - "PhyTermFilterExpr", - query_context->get_op_context(), - query_context->get_segment(), - query_context->get_active_count(), - query_context->get_query_timestamp(), - query_context->query_config()->get_expr_batch_size(), - query_context->get_consistency_level()); - return std::make_shared( - std::vector>{term_expr}, - std::make_shared( - milvus::expr::LogicalUnaryExpr::OpType::LogicalNot, logical_expr), - "PhyLogicalUnaryExpr", - query_context->get_op_context()); -} - -inline std::shared_ptr -ConvertMultiOrToInExpr(std::vector>& exprs, - std::vector indices, - ExecContext* context) { - std::vector values; - auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET; - for (auto& i : indices) { - auto expr = std::static_pointer_cast(exprs[i]) - ->GetLogicalExpr(); - if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) { - type = expr->val_.val_case(); - } - if (type != expr->val_.val_case()) { - return nullptr; - } - values.push_back(expr->val_); - } - auto logical_expr = std::make_shared( - exprs[indices[0]]->GetColumnInfo().value(), values); - auto query_context = context->get_query_context(); - return std::make_shared( - std::vector>{}, - logical_expr, - "PhyTermFilterExpr", - query_context->get_op_context(), - query_context->get_segment(), - query_context->get_active_count(), - query_context->get_query_timestamp(), - query_context->query_config()->get_expr_batch_size(), - query_context->get_consistency_level()); -} - -inline void -RewriteConjunctExpr(std::shared_ptr& expr, - ExecContext* context) { - // covert A = .. or A = .. or A = .. to A in (.., .., ..) 
- if (expr->IsOr()) { - auto& inputs = expr->GetInputsRef(); - std::map> expr_indices; - for (size_t i = 0; i < inputs.size(); i++) { - auto input = inputs[i]; - if (input->name() == "PhyUnaryRangeFilterExpr") { - auto phy_expr = - std::static_pointer_cast(input); - if (phy_expr->GetOpType() == proto::plan::OpType::Equal) { - auto column = phy_expr->GetColumnInfo().value(); - if (expr_indices.find(column) != expr_indices.end()) { - expr_indices[column].push_back(i); - } else { - expr_indices[column] = {i}; - } - } - } - - if (input->name() == "PhyConjunctFilterExpr") { - auto expr = std::static_pointer_cast< - milvus::exec::PhyConjunctFilterExpr>(input); - RewriteConjunctExpr(expr, context); - } - } - - for (auto& [column, indices] : expr_indices) { - // For numeric type, if or column greater than 150, then using in expr replace. - // For other type, all convert to in expr. - if ((IsNumericDataType(column.data_type_) && - indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) || - (!IsNumericDataType(column.data_type_) && indices.size() > 1)) { - auto new_expr = - ConvertMultiOrToInExpr(inputs, indices, context); - if (new_expr) { - inputs[indices[0]] = new_expr; - for (size_t j = 1; j < indices.size(); j++) { - inputs[indices[j]] = nullptr; - } - } - } - } - inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr), - inputs.end()); - } - - // convert A != .. and A != .. and A != .. to not A in (.., .., ..) 
- if (expr->IsAnd()) { - auto& inputs = expr->GetInputsRef(); - std::map> expr_indices; - for (size_t i = 0; i < inputs.size(); i++) { - auto input = inputs[i]; - if (input->name() == "PhyUnaryRangeFilterExpr") { - auto phy_expr = - std::static_pointer_cast(input); - if (phy_expr->GetOpType() == proto::plan::OpType::NotEqual) { - auto column = phy_expr->GetColumnInfo().value(); - if (expr_indices.find(column) != expr_indices.end()) { - expr_indices[column].push_back(i); - } else { - expr_indices[column] = {i}; - } - } - - if (input->name() == "PhyConjunctFilterExpr") { - auto expr = std::static_pointer_cast< - milvus::exec::PhyConjunctFilterExpr>(input); - RewriteConjunctExpr(expr, context); - } - } - } - - for (auto& [column, indices] : expr_indices) { - if ((IsNumericDataType(column.data_type_) && - indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) || - (!IsNumericDataType(column.data_type_) && indices.size() > 1)) { - auto new_expr = - ConvertMultiNotEqualToNotInExpr(inputs, indices, context); - if (new_expr) { - inputs[indices[0]] = new_expr; - for (size_t j = 1; j < indices.size(); j++) { - inputs[indices[j]] = nullptr; - } - } - } - } - inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr), - inputs.end()); - } -} - inline void SetNamespaceSkipIndex(std::shared_ptr conjunct_expr, ExecContext* context) { @@ -727,7 +506,6 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector& exprs) { LOG_DEBUG("before reoder filter expression: {}", expr->ToString()); auto conjunct_expr = std::static_pointer_cast(expr); - RewriteConjunctExpr(conjunct_expr, context); bool has_heavy_operation = false; ReorderConjunctExpr(conjunct_expr, context, has_heavy_operation); LOG_DEBUG("after reorder filter expression: {}", expr->ToString()); diff --git a/internal/core/unittest/test_exec.cpp b/internal/core/unittest/test_exec.cpp index a4cf6ce583..c7100209d8 100644 --- a/internal/core/unittest/test_exec.cpp +++ b/internal/core/unittest/test_exec.cpp @@ -256,183 
+256,6 @@ TEST_P(TaskTest, LogicalExpr) { EXPECT_EQ(num_rows, num_rows_); } -TEST_P(TaskTest, CompileInputs_and) { - using namespace milvus; - using namespace milvus::query; - using namespace milvus::segcore; - auto schema = std::make_shared(); - auto vec_fid = - schema->AddDebugField("fakevec", GetParam(), 16, knowhere::metric::L2); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - proto::plan::GenericValue val; - val.set_int64_val(10); - // expr: (int64_fid < 10 and int64_fid < 10) and (int64_fid < 10 and int64_fid < 10) - auto expr1 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr2 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - auto expr4 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr5 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr6 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - auto expr7 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr3, expr6); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - auto exprs = milvus::exec::CompileInputs(expr7, query_context.get(), {}); - EXPECT_EQ(exprs.size(), 4); - for (int i = 0; i < exprs.size(); ++i) { - std::cout << exprs[i]->name() << std::endl; - EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr"); - } -} - -TEST_P(TaskTest, CompileInputs_or_with_and) { - using namespace milvus; - using namespace milvus::query; - using namespace milvus::segcore; - auto schema = std::make_shared(); - auto vec_fid = - schema->AddDebugField("fakevec", GetParam(), 16, 
knowhere::metric::L2); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - proto::plan::GenericValue val; - val.set_int64_val(10); - { - // expr: (int64_fid > 10 and int64_fid > 10) or (int64_fid > 10 and int64_fid > 10) - auto expr1 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr2 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - auto expr4 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr5 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr6 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - auto expr7 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr3, expr6); - auto exprs = - milvus::exec::CompileInputs(expr7, query_context.get(), {}); - EXPECT_EQ(exprs.size(), 2); - for (int i = 0; i < exprs.size(); ++i) { - std::cout << exprs[i]->name() << std::endl; - EXPECT_STREQ(exprs[i]->name().c_str(), "PhyConjunctFilterExpr"); - } - } - { - // expr: (int64_fid < 10 or int64_fid < 10) or (int64_fid > 10 and int64_fid > 10) - auto expr1 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr2 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - auto expr4 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - 
proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr5 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr6 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - auto expr7 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr3, expr6); - auto exprs = - milvus::exec::CompileInputs(expr7, query_context.get(), {}); - std::cout << exprs.size() << std::endl; - EXPECT_EQ(exprs.size(), 3); - for (int i = 0; i < exprs.size() - 1; ++i) { - std::cout << exprs[i]->name() << std::endl; - EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr"); - } - EXPECT_STREQ(exprs[2]->name().c_str(), "PhyConjunctFilterExpr"); - } - { - // expr: (int64_fid > 10 or int64_fid > 10) and (int64_fid > 10 and int64_fid > 10) - auto expr1 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr2 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - auto expr4 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr5 = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val, - std::vector{}); - auto expr6 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - auto expr7 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr3, expr6); - auto exprs = - milvus::exec::CompileInputs(expr7, query_context.get(), {}); - std::cout << exprs.size() << 
std::endl; - EXPECT_EQ(exprs.size(), 3); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - for (int i = 1; i < exprs.size(); ++i) { - std::cout << exprs[i]->name() << std::endl; - EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr"); - } - } -} - TEST_P(TaskTest, Test_reorder) { using namespace milvus; using namespace milvus::query; @@ -698,530 +521,6 @@ TEST_P(TaskTest, Test_reorder) { } } -TEST_P(TaskTest, Test_MultiNotEqualConvert) { - using namespace milvus; - using namespace milvus::query; - using namespace milvus::segcore; - using namespace milvus::exec; - - { - // expr: string2 != "111" and string2 != "222" and string2 != "333" - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::NotEqual, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::NotEqual, - val2, - std::vector{}); - proto::plan::GenericValue val3; - val3.set_string_val("333"); - auto expr3 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::NotEqual, - val3, - std::vector{}); - auto expr4 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - auto expr5 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr4, expr3); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 1); - 
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyLogicalUnaryExpr"); - EXPECT_EQ(inputs[0]->GetInputsRef().size(), 1); - EXPECT_STREQ(inputs[0]->GetInputsRef()[0]->name().c_str(), - "PhyTermFilterExpr"); - } - - { - // expr: int64 != 111 and int64 != 222 and int64 != 333 - proto::plan::GenericValue val1; - val1.set_int64_val(111); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::NotEqual, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_int64_val(222); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::NotEqual, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - proto::plan::GenericValue val3; - val3.set_int64_val(333); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::NotEqual, - val3, - std::vector{}); - auto expr5 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr3, expr4); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 3); - EXPECT_STREQ(inputs[0]->name().c_str(), "PhyUnaryRangeFilterExpr"); - EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr"); - EXPECT_STREQ(inputs[2]->name().c_str(), "PhyUnaryRangeFilterExpr"); - } - - { - // expr: string2 != "111" and string2 != "222" and (int64 > 10 && int64 < 100) and string2 != "333" - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], 
DataType::VARCHAR), - proto::plan::OpType::NotEqual, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::NotEqual, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - - proto::plan::GenericValue val3; - val3.set_int64_val(10); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::GreaterThan, - val3, - std::vector{}); - proto::plan::GenericValue val4; - val4.set_int64_val(100); - auto expr5 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::LessThan, - val4, - std::vector{}); - auto expr6 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr4, expr5); - - auto expr7 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr6, expr3); - - proto::plan::GenericValue val5; - val5.set_string_val("333"); - auto expr8 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::NotEqual, - val5, - std::vector{}); - auto expr9 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr7, expr8); - - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr9}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 2); - EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr"); - EXPECT_STREQ(inputs[1]->name().c_str(), "PhyLogicalUnaryExpr"); - auto phy_expr1 = - std::static_pointer_cast( - inputs[1]); - EXPECT_EQ(phy_expr1->GetInputsRef().size(), 1); - 
EXPECT_STREQ(phy_expr1->GetInputsRef()[0]->name().c_str(), - "PhyTermFilterExpr"); - phy_expr = - std::static_pointer_cast( - inputs[0]); - inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 2); - } - - { - // expr: json['a'] != "111" and json['a'] != "222" and json['a'] != "333" - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::NotEqual, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::NotEqual, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr1, expr2); - proto::plan::GenericValue val3; - val3.set_string_val("333"); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::NotEqual, - val3, - std::vector{}); - auto expr5 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr3, expr4); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 1); - EXPECT_STREQ(inputs[0]->name().c_str(), "PhyLogicalUnaryExpr"); - auto phy_expr1 = - std::static_pointer_cast( - inputs[0]); - EXPECT_EQ(phy_expr1->GetInputsRef().size(), 1); - EXPECT_STREQ(phy_expr1->GetInputsRef()[0]->name().c_str(), - "PhyTermFilterExpr"); - } -} - -TEST_P(TaskTest, Test_MultiInConvert) { - using namespace milvus; - using namespace milvus::query; - 
using namespace milvus::segcore; - using namespace milvus::exec; - - { - // expr: string2 == '111' or string2 == '222' or string2 == "333" - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::Equal, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::Equal, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - proto::plan::GenericValue val3; - val3.set_string_val("333"); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::Equal, - val3, - std::vector{}); - auto expr5 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 1); - EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); - } - - { - // expr: string2 == '111' or string2 == '222' or (int64 > 10 && int64 < 100) or string2 == "333" - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::Equal, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::Equal, - val2, - 
std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - - proto::plan::GenericValue val3; - val3.set_int64_val(10); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::GreaterThan, - val3, - std::vector{}); - proto::plan::GenericValue val4; - val4.set_int64_val(100); - auto expr5 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::LessThan, - val4, - std::vector{}); - auto expr6 = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, expr4, expr5); - - auto expr7 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr6, expr3); - - proto::plan::GenericValue val5; - val5.set_string_val("333"); - auto expr8 = std::make_shared( - expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), - proto::plan::OpType::Equal, - val5, - std::vector{}); - auto expr9 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr7, expr8); - - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr9}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 2); - EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr"); - EXPECT_STREQ(inputs[1]->name().c_str(), "PhyTermFilterExpr"); - } - { - // expr: json['a'] == "111" or json['a'] == "222" or json['3'] = "333" - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::Equal, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = 
std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::Equal, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - proto::plan::GenericValue val3; - val3.set_string_val("333"); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::Equal, - val3, - std::vector{}); - auto expr5 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 1); - EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); - } - - { - // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == "333" - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::Equal, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'b'}), - proto::plan::OpType::Equal, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - proto::plan::GenericValue val3; - val3.set_string_val("333"); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::Equal, - val3, - std::vector{}); - auto expr5 = std::make_shared( - 
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 2); - EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); - EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr"); - } - - { - // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == 1 - proto::plan::GenericValue val1; - val1.set_string_val("111"); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::Equal, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_string_val("222"); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'b'}), - proto::plan::OpType::Equal, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - proto::plan::GenericValue val3; - val3.set_int64_val(1); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["json"], - DataType::JSON, - std::vector{'a'}), - proto::plan::OpType::Equal, - val3, - std::vector{}); - auto expr5 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = 
phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 3); - } - - { - // expr: int1 == 11 or int1 == 22 or int3 == 33 - proto::plan::GenericValue val1; - val1.set_int64_val(11); - auto expr1 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::Equal, - val1, - std::vector{}); - proto::plan::GenericValue val2; - val2.set_int64_val(222); - auto expr2 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::Equal, - val2, - std::vector{}); - auto expr3 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); - proto::plan::GenericValue val3; - val3.set_int64_val(1); - auto expr4 = std::make_shared( - expr::ColumnInfo(field_map_["int64"], DataType::INT64), - proto::plan::OpType::Equal, - val3, - std::vector{}); - auto expr5 = std::make_shared( - expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); - ExecContext context(query_context.get()); - auto exprs = - milvus::exec::CompileExpressions({expr5}, &context, {}, false); - EXPECT_EQ(exprs.size(), 1); - EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); - auto phy_expr = - std::static_pointer_cast( - exprs[0]); - auto inputs = phy_expr->GetInputsRef(); - EXPECT_EQ(inputs.size(), 3); - } -} - // This test verifies the fix for https://github.com/milvus-io/milvus/issues/46053. 
// // Bug scenario: diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index cf897b908c..5c67efbb35 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -15,6 +15,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated" + "github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter" "github.com/milvus-io/milvus/internal/util/function/rerank" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" @@ -149,6 +150,7 @@ func parseExprInner(schema *typeutil.SchemaHelper, exprStr string, exprTemplateV return nil, err } + predicate.expr = rewriter.RewriteExpr(predicate.expr) return predicate.expr, nil } diff --git a/internal/parser/planparserv2/rewriter/README.md b/internal/parser/planparserv2/rewriter/README.md new file mode 100644 index 0000000000..3bfa4cce70 --- /dev/null +++ b/internal/parser/planparserv2/rewriter/README.md @@ -0,0 +1,138 @@ +## Expression Rewriter (planparserv2/rewriter) + +This module performs rule-based logical rewrites on parsed `planpb.Expr` trees right after template value filling and before planning/execution. + +### Entry +- `RewriteExpr(*planpb.Expr) *planpb.Expr` (in `entry.go`) + - Recursively visits the expression tree and applies a set of composable, side-effect-free rewrite rules. + - Uses global configuration from `paramtable.Get().CommonCfg.EnabledOptimizeExpr` +- `RewriteExprWithConfig(*planpb.Expr, bool) *planpb.Expr` (in `entry.go`) + - Same as `RewriteExpr` but allows custom configuration for testing or special cases. 
+ +### Configuration + +The rewriter can be configured via the following parameter (refreshable at runtime): + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `common.enabledOptimizeExpr` | `true` | Enable query expression optimization including range simplification, IN/NOT IN merge, TEXT_MATCH merge, and all other optimizations | + +**IMPORTANT**: IN/NOT IN value list sorting and deduplication **always** runs regardless of this configuration setting, because the execution engine depends on sorted value lists. + +### Implemented Rules + +1) IN / NOT IN normalization and merges (`term_in.go`) +- OR-equals to IN (same column): + - `a == v1 OR a == v2 ...` → `a IN (v1, v2, ...)` + - Numeric columns only merge when count > threshold (default 150); others when count > 1. +- AND-not-equals to NOT IN (same column): + - `a != v1 AND a != v2 ...` → `NOT (a IN (v1, v2, ...))` + - Same thresholds as above. +- IN vs Equal redundancy elimination (same column): + - AND: `(a ∈ S) AND (a = v)`: + - if `v ∈ S` → `a = v` + - if `v ∉ S` → contradiction → constant `false` + - OR: `(a ∈ S) OR (a = v)` → `a ∈ (S ∪ {v})` (always union) +- IN with IN union: + - OR: `(a ∈ S1) OR (a ∈ S2)` → `a ∈ (S1 ∪ S2)` with sorting/dedup + - AND: `(a ∈ S1) AND (a ∈ S2)` → `a ∈ (S1 ∩ S2)`; empty intersection → constant `false` +- Sort and deduplicate `IN` / `NOT IN` value lists (supported types: bool, int64, float64, string). + +2) TEXT_MATCH OR merge (`text_match.go`) +- Merge ORs of `TEXT_MATCH(field, "literal")` on the same column (no options): + - Concatenate literals with a single space in the order they appear; no tokenization, deduplication, or sorting is performed. + - Example: `TEXT_MATCH(f, "A C") OR TEXT_MATCH(f, "B D")` → `TEXT_MATCH(f, "A C B D")` +- If any `TEXT_MATCH` in the group has options (e.g., `minimum_should_match`), this optimization is skipped for that group. 
+ +3) Range predicate simplification (`range.go`) +- AND tighten (same column): + - Lower bounds: `a > 10 AND a > 20` → `a > 20` (pick strongest lower) + - Upper bounds: `a < 50 AND a < 60` → `a < 50` (pick strongest upper) + - Mixed lower and upper: `a > 10 AND a < 50` → `10 < a < 50` (BinaryRangeExpr) + - Inclusion respected (>, >=, <, <=). On ties, exclusive is considered stronger than inclusive for tightening. +- OR weaken (same column, same direction): + - Lower bounds: `a > 10 OR a > 20` → `a > 10` (pick weakest lower) + - Upper bounds: `a < 10 OR a < 20` → `a < 20` (pick weakest upper) + - Inclusion respected, preferring inclusive for weakening in ties. +- Mixed-direction OR (lower vs upper) is not merged. +- Equivalent-bound collapses (same column, same value): + - AND: `a ≥ x AND a > x` → `a > x`; `a ≤ y AND a < y` → `a < y` + - OR: `a ≥ x OR a > x` → `a ≥ x`; `a ≤ y OR a < y` → `a ≤ y` + - Symmetric dedup: `a > 10 AND a ≥ 10` → `a > 10`; `a < 5 OR a ≤ 5` → `a ≤ 5` +- IN ∩ range filtering: + - AND: `(a ∈ {…}) AND (range)` → keep only values in the set that satisfy the range + - e.g., `{1,3,5} AND a > 3` → `{5}` +- Supported columns for range optimization: + - Scalar: Int8/Int16/Int32/Int64, Float/Double, VarChar + - Array element access: when indexing an element (e.g., `ArrayInt[0]`), the element type above applies + - JSON/dynamic fields with nested paths (e.g., `JSONField["price"]`, `$meta["age"]`) are range-optimized + - Type determined from literal value (int, float, string) + - Numeric types (int and float) are compatible and normalized to Double for merging + - Different type categories are not merged (e.g., `json["a"] > 10` and `json["a"] > "hello"` remain separate) + - Bool literals are not optimized (no meaningful ranges) +- Literal compatibility: + - Integer columns require integer literals (e.g., `Int64Field > 10`) + - Float/Double columns accept both integer and float literals (e.g., `FloatField > 10` or `> 10.5`) +- Column identity: + - Merges 
only happen within the same `ColumnInfo` (including nested path and element index). For example, `ArrayInt[0]` and `ArrayInt[1]` are different columns and are not merged with each other. +- BinaryRangeExpr merging: + - AND: Merge multiple `BinaryRangeExpr` nodes on the same column to compute intersection (max lower, min upper) + - `(10 < x < 50) AND (20 < x < 40)` → `(20 < x < 40)` + - Empty intersection → constant `false` + - AND with UnaryRangeExpr: Update appropriate bound of `BinaryRangeExpr` + - `(10 < x < 50) AND (x > 30)` → `(30 < x < 50)` + - OR: Merge overlapping or adjacent `BinaryRangeExpr` nodes into wider interval + - `(10 < x < 25) OR (20 < x < 40)` → `(10 < x < 40)` (overlapping) + - `(10 < x <= 20) OR (20 <= x < 30)` → `(10 < x < 30)` (adjacent with inclusive) + - Disjoint intervals remain separate: `(10 < x < 20) OR (30 < x < 40)` → remains as OR + - Inclusivity handling: AND prefers exclusive on equal bounds (stronger), OR prefers inclusive (weaker) + +### General Notes +- All merges require operands to target the same column (same `ColumnInfo`, including nested path/element type). +- Rewrite runs after template value filling; template placeholders do not appear here. +- Sorting/dedup for IN/NOT IN is deterministic; duplicates are removed post-sort. +- Numeric-threshold for OR→IN / AND≠→NOT IN is defined in `util.go` (`defaultConvertOrToInNumericLimit`, default 150). + +### Pass Ordering (current) +- OR branch: + 1. Flatten + 2. OR `==` → IN + 3. TEXT_MATCH merge (no options) + 4. Range weaken (same-direction bounds) + 5. BinaryRangeExpr merge (overlapping/adjacent intervals) + 6. IN with `!=` short-circuiting + 7. IN ∪ IN union + 8. IN vs Equal redundancy elimination + 9. Fold back to BinaryExpr +- AND branch: + 1. Flatten + 2. Range tighten / interval construction + 3. BinaryRangeExpr merge (intersection, also with UnaryRangeExpr) + 4. IN ∪ IN intersection (if any) + 5. IN with `!=` filtering + 6. IN ∩ range filtering + 7. 
IN vs Equal redundancy elimination + 8. AND `!=` → NOT IN + 9. Fold back to BinaryExpr + +Each construction of IN will be normalized (sorted and deduplicated). TEXT_MATCH OR merge concatenates literals with a single space; no tokenization, deduplication, or sorting is performed. + +### File Structure +- `entry.go` — rewrite entry and visitor orchestration +- `util.go` — shared helpers (column keying, value classification, sorting/dedup, constructors) +- `term_in.go` — IN/NOT IN normalization and conversions +- `text_match.go` — TEXT_MATCH OR merge (no options) +- `range.go` — range tightening/weakening and interval construction + +### Future Extensions +- More IN-range algebra (e.g., `IN` vs exact equality propagation across subtrees). +- Merging phrase_match or other string ops with clearly-defined token rules. +- More algebraic simplifications around equality and null checks: + - Contradiction detection: `(a == 1) AND (a == 2)` → `false`; `(a > 10) AND (a == 5)` → `false` + - Tautology detection: `(a > 10) OR (a <= 10)` → `true` (for non-NULL values) + - Absorption laws: `(a > 10) OR ((a > 10) AND (b > 20))` → `a > 10` +- Advanced BinaryRangeExpr merging: + - OR with 3+ intervals: Currently limited to 2 intervals. Full interval merging algorithm needed for `(10 < x < 20) OR (15 < x < 25) OR (22 < x < 30)` → `(10 < x < 30)`. + - OR with unbounded + bounded: Currently skipped. Could optimize `(x > 10) OR (5 < x < 15)` → `x > 5`. 


diff --git a/internal/parser/planparserv2/rewriter/entry.go b/internal/parser/planparserv2/rewriter/entry.go
new file mode 100644
index 0000000000..d6bc73a4c1
--- /dev/null
+++ b/internal/parser/planparserv2/rewriter/entry.go
@@ -0,0 +1,200 @@
package rewriter

import (
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)

// RewriteExpr applies the rule-based logical rewrites to a parsed plan
// expression tree, reading the enable switch from the global
// `common.enabledOptimizeExpr` config (refreshable at runtime).
func RewriteExpr(e *planpb.Expr) *planpb.Expr {
	optimizeEnabled := paramtable.Get().CommonCfg.EnabledOptimizeExpr.GetAsBool()
	return RewriteExprWithConfig(e, optimizeEnabled)
}

// RewriteExprWithConfig is like RewriteExpr but takes the enable flag
// explicitly (used by tests and special callers). It never returns nil for a
// non-nil input: if the visitor produces anything unexpected, the original
// expression is returned unchanged.
func RewriteExprWithConfig(e *planpb.Expr, optimizeEnabled bool) *planpb.Expr {
	if e == nil {
		return nil
	}
	v := &visitor{optimizeEnabled: optimizeEnabled}
	res := v.visitExpr(e)
	if out, ok := res.(*planpb.Expr); ok && out != nil {
		return out
	}
	return e
}

// visitor walks the expression tree and applies the rewrite passes.
type visitor struct {
	// optimizeEnabled gates all optimization passes; IN/NOT-IN value sorting
	// in visitTermExpr runs regardless of this flag.
	optimizeEnabled bool
}

// visitExpr dispatches on the concrete expression kind. Only binary, unary
// and term expressions participate in rewriting; all other node kinds are
// returned untouched.
func (v *visitor) visitExpr(expr *planpb.Expr) interface{} {
	switch real := expr.GetExpr().(type) {
	case *planpb.Expr_BinaryExpr:
		return v.visitBinaryExpr(real.BinaryExpr)
	case *planpb.Expr_UnaryExpr:
		return v.visitUnaryExpr(real.UnaryExpr)
	case *planpb.Expr_TermExpr:
		return v.visitTermExpr(real.TermExpr)
	// no optimization for other types
	default:
		return expr
	}
}

// visitBinaryExpr rewrites AND/OR nodes: children are rewritten first, then
// the (flattened) operand list is run through a fixed sequence of combine
// passes and folded back into a left-deep binary chain. The pass order below
// is significant and is documented in the package README; do not reorder
// without updating both.
func (v *visitor) visitBinaryExpr(expr *planpb.BinaryExpr) interface{} {
	left := v.visitExpr(expr.GetLeft()).(*planpb.Expr)
	right := v.visitExpr(expr.GetRight()).(*planpb.Expr)
	switch expr.GetOp() {
	case planpb.BinaryExpr_LogicalOr:
		parts := flattenOr(left, right)
		if v.optimizeEnabled {
			parts = v.combineOrEqualsToIn(parts)
			parts = v.combineOrTextMatchToMerged(parts)
			parts = v.combineOrRangePredicates(parts)
			parts = v.combineOrBinaryRanges(parts)
			parts = v.combineOrInWithNotEqual(parts)
			parts = v.combineOrInWithIn(parts)
			parts = v.combineOrInWithEqual(parts)
		}
		return foldBinary(planpb.BinaryExpr_LogicalOr, parts)
	case planpb.BinaryExpr_LogicalAnd:
		parts := flattenAnd(left, right)
		if v.optimizeEnabled {
			parts = v.combineAndRangePredicates(parts)
			parts = v.combineAndBinaryRanges(parts)
			parts = v.combineAndInWithIn(parts)
			parts = v.combineAndInWithNotEqual(parts)
			parts = v.combineAndInWithRange(parts)
			parts = v.combineAndInWithEqual(parts)
			parts = v.combineAndNotEqualsToNotIn(parts)
		}
		return foldBinary(planpb.BinaryExpr_LogicalAnd, parts)
	default:
		// Non-logical binary ops: rebuild with rewritten children only.
		return &planpb.Expr{
			Expr: &planpb.Expr_BinaryExpr{
				BinaryExpr: &planpb.BinaryExpr{
					Left:  left,
					Right: right,
					Op:    expr.GetOp(),
				},
			},
		}
	}
}

// visitUnaryExpr rewrites the child, then folds NOT over a constant-false
// child to constant-true.
func (v *visitor) visitUnaryExpr(expr *planpb.UnaryExpr) interface{} {
	child := v.visitExpr(expr.GetChild()).(*planpb.Expr)

	// Optimize double negation: NOT (NOT AlwaysTrue) → AlwaysTrue
	// (NOTE(review): this presumes the always-false constant is represented
	// as NOT(AlwaysTrue) by IsAlwaysFalseExpr/newAlwaysFalseExpr in util.go —
	// confirm there; under that representation this is double-negation
	// elimination and the symmetric case needs no handling.)
	if expr.GetOp() == planpb.UnaryExpr_Not {
		if IsAlwaysFalseExpr(child) {
			return newAlwaysTrueExpr()
		}
	}

	return &planpb.Expr{
		Expr: &planpb.Expr_UnaryExpr{
			UnaryExpr: &planpb.UnaryExpr{
				Op:    expr.GetOp(),
				Child: child,
			},
		},
	}
}

// visitTermExpr normalizes IN/NOT-IN value lists. Sorting/dedup always runs,
// independent of optimizeEnabled, because downstream execution relies on
// sorted term values.
func (v *visitor) visitTermExpr(expr *planpb.TermExpr) interface{} {
	sortTermValues(expr)
	return &planpb.Expr{Expr: &planpb.Expr_TermExpr{TermExpr: expr}}
}

// flattenOr collects the operands of an OR tree rooted at (a OR b) into a
// flat slice, recursing through nested ORs.
func flattenOr(a, b *planpb.Expr) []*planpb.Expr {
	out := make([]*planpb.Expr, 0, 4)
	collectOr(a, &out)
	collectOr(b, &out)
	return out
}

// collectOr appends e's OR-operands (or e itself if not an OR) to out,
// preserving left-to-right order.
func collectOr(e *planpb.Expr, out *[]*planpb.Expr) {
	if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalOr {
		collectOr(be.GetLeft(), out)
		collectOr(be.GetRight(), out)
		return
	}
	*out = append(*out, e)
}

// flattenAnd collects the operands of an AND tree rooted at (a AND b) into a
// flat slice, recursing through nested ANDs.
func flattenAnd(a, b *planpb.Expr) []*planpb.Expr {
	out := make([]*planpb.Expr, 0, 4)
	collectAnd(a, &out)
	collectAnd(b, &out)
	return out
}

// collectAnd appends e's AND-operands (or e itself if not an AND) to out,
// preserving left-to-right order.
func collectAnd(e *planpb.Expr, out *[]*planpb.Expr) {
	if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalAnd {
		collectAnd(be.GetLeft(), out)
		collectAnd(be.GetRight(), out)
		return
	}
	*out = append(*out, e)
}

+func foldBinary(op planpb.BinaryExpr_BinaryOp, exprs []*planpb.Expr) *planpb.Expr { + if len(exprs) == 0 { + return nil + } + + // Handle AlwaysTrue and AlwaysFalse optimizations (single-pass) + switch op { + case planpb.BinaryExpr_LogicalAnd: + filtered := make([]*planpb.Expr, 0, len(exprs)) + for _, e := range exprs { + if IsAlwaysFalseExpr(e) { + // AND: any AlwaysFalse → entire expression is AlwaysFalse + return newAlwaysFalseExpr() + } + if !IsAlwaysTrueExpr(e) { + // Filter out AlwaysTrue (since AlwaysTrue AND X = X) + filtered = append(filtered, e) + } + } + exprs = filtered + // If all were AlwaysTrue, return AlwaysTrue + if len(exprs) == 0 { + return newAlwaysTrueExpr() + } + case planpb.BinaryExpr_LogicalOr: + filtered := make([]*planpb.Expr, 0, len(exprs)) + for _, e := range exprs { + if IsAlwaysTrueExpr(e) { + // OR: any AlwaysTrue → entire expression is AlwaysTrue + return newAlwaysTrueExpr() + } + if !IsAlwaysFalseExpr(e) { + // Filter out AlwaysFalse (since AlwaysFalse OR X = X) + filtered = append(filtered, e) + } + } + exprs = filtered + // If all were AlwaysFalse, return AlwaysFalse + if len(exprs) == 0 { + return newAlwaysFalseExpr() + } + } + + if len(exprs) == 1 { + return exprs[0] + } + cur := exprs[0] + for i := 1; i < len(exprs); i++ { + cur = &planpb.Expr{ + Expr: &planpb.Expr_BinaryExpr{ + BinaryExpr: &planpb.BinaryExpr{ + Left: cur, + Right: exprs[i], + Op: op, + }, + }, + } + } + return cur +} diff --git a/internal/parser/planparserv2/rewriter/range.go b/internal/parser/planparserv2/rewriter/range.go new file mode 100644 index 0000000000..67ee66cacb --- /dev/null +++ b/internal/parser/planparserv2/rewriter/range.go @@ -0,0 +1,923 @@ +package rewriter + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" +) + +type bound struct { + value *planpb.GenericValue + inclusive bool + isLower bool + exprIndex int +} + +func isSupportedScalarForRange(dt 
schemapb.DataType) bool { + switch dt { + case schemapb.DataType_Int8, + schemapb.DataType_Int16, + schemapb.DataType_Int32, + schemapb.DataType_Int64, + schemapb.DataType_Float, + schemapb.DataType_Double, + schemapb.DataType_VarChar: + return true + default: + return false + } +} + +// resolveEffectiveType returns (dt, ok) where ok indicates this column is eligible for range optimization. +// Eligible when the column is a supported scalar, or an array whose element type is a supported scalar, +// or a JSON field with a nested path (type will be determined from literal values). +func resolveEffectiveType(col *planpb.ColumnInfo) (schemapb.DataType, bool) { + if col == nil { + return schemapb.DataType_None, false + } + dt := col.GetDataType() + if isSupportedScalarForRange(dt) { + return dt, true + } + if dt == schemapb.DataType_Array { + et := col.GetElementType() + if isSupportedScalarForRange(et) { + return et, true + } + } + // JSON fields with nested paths are eligible; the effective type will be + // determined from the literal value in the comparison. + if dt == schemapb.DataType_JSON && len(col.GetNestedPath()) > 0 { + // Return a placeholder type; actual type checking happens in resolveJSONEffectiveType + return schemapb.DataType_JSON, true + } + return schemapb.DataType_None, false +} + +// resolveJSONEffectiveType returns the effective type for a JSON field based on the literal value. +// Returns (type, ok) where ok is false if the value is not suitable for range optimization. +// For numeric types (int and float), we normalize to Double to allow mixing, similar to scalar Float/Double columns. 
func resolveJSONEffectiveType(v *planpb.GenericValue) (schemapb.DataType, bool) {
	if v == nil || v.GetVal() == nil {
		return schemapb.DataType_None, false
	}
	switch v.GetVal().(type) {
	case *planpb.GenericValue_Int64Val:
		// Normalize int to Double for JSON to allow mixing with float literals
		return schemapb.DataType_Double, true
	case *planpb.GenericValue_FloatVal:
		return schemapb.DataType_Double, true
	case *planpb.GenericValue_StringVal:
		return schemapb.DataType_VarChar, true
	case *planpb.GenericValue_BoolVal:
		// Boolean comparisons don't have meaningful ranges
		return schemapb.DataType_None, false
	default:
		return schemapb.DataType_None, false
	}
}

// valueMatchesType reports whether the literal v is usable as a comparison
// value for a column of type dt: integer columns need int literals, float
// columns accept int or float literals, varchar needs string literals, and
// JSON accepts anything resolveJSONEffectiveType can classify.
func valueMatchesType(dt schemapb.DataType, v *planpb.GenericValue) bool {
	if v == nil || v.GetVal() == nil {
		return false
	}
	switch dt {
	case schemapb.DataType_Int8,
		schemapb.DataType_Int16,
		schemapb.DataType_Int32,
		schemapb.DataType_Int64:
		_, ok := v.GetVal().(*planpb.GenericValue_Int64Val)
		return ok
	case schemapb.DataType_Float, schemapb.DataType_Double:
		// For float columns, accept both float and int literal values
		switch v.GetVal().(type) {
		case *planpb.GenericValue_FloatVal, *planpb.GenericValue_Int64Val:
			return true
		default:
			return false
		}
	case schemapb.DataType_VarChar:
		_, ok := v.GetVal().(*planpb.GenericValue_StringVal)
		return ok
	case schemapb.DataType_JSON:
		// For JSON, check if we can determine a valid type from the value
		_, ok := resolveJSONEffectiveType(v)
		return ok
	default:
		return false
	}
}

// combineAndRangePredicates tightens AND-connected unary range predicates on
// the same column: strongest lower and strongest upper bound win; a lower and
// an upper bound together become a BinaryRangeExpr, and an impossible
// combination (lower > upper, or equal with an exclusive endpoint) collapses
// to constant false. Non-eligible parts pass through unchanged.
func (v *visitor) combineAndRangePredicates(parts []*planpb.Expr) []*planpb.Expr {
	type group struct {
		col    *planpb.ColumnInfo
		effDt  schemapb.DataType // effective type for comparison
		lowers []bound
		uppers []bound
	}
	groups := map[string]*group{}
	// exprs not eligible for range optimization
	others := []int{}
	isRangeOp := func(op planpb.OpType) bool {
		return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
			op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
	}
	for idx, e := range parts {
		u := e.GetUnaryRangeExpr()
		if u == nil || !isRangeOp(u.GetOp()) || u.GetValue() == nil {
			others = append(others, idx)
			continue
		}
		col := u.GetColumnInfo()
		if col == nil {
			others = append(others, idx)
			continue
		}
		// Only optimize for supported types and matching value type
		effDt, ok := resolveEffectiveType(col)
		if !ok || !valueMatchesType(effDt, u.GetValue()) {
			others = append(others, idx)
			continue
		}
		// For JSON fields, determine the actual effective type from the literal value
		if effDt == schemapb.DataType_JSON {
			var typeOk bool
			effDt, typeOk = resolveJSONEffectiveType(u.GetValue())
			if !typeOk {
				others = append(others, idx)
				continue
			}
		}
		// Group by column + effective type (for JSON, type depends on literal)
		key := columnKey(col) + fmt.Sprintf("|%d", effDt)
		g, ok := groups[key]
		if !ok {
			g = &group{col: col, effDt: effDt}
			groups[key] = g
		}
		b := bound{
			value:     u.GetValue(),
			inclusive: u.GetOp() == planpb.OpType_GreaterEqual || u.GetOp() == planpb.OpType_LessEqual,
			isLower:   u.GetOp() == planpb.OpType_GreaterThan || u.GetOp() == planpb.OpType_GreaterEqual,
			exprIndex: idx,
		}
		if b.isLower {
			g.lowers = append(g.lowers, b)
		} else {
			g.uppers = append(g.uppers, b)
		}
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		// Use the effective type stored in the group.
		// Strongest lower bound: largest value; on ties exclusive beats inclusive.
		var bestLower *bound
		for i := range g.lowers {
			if bestLower == nil || cmpGeneric(g.effDt, g.lowers[i].value, bestLower.value) > 0 ||
				(cmpGeneric(g.effDt, g.lowers[i].value, bestLower.value) == 0 && !g.lowers[i].inclusive && bestLower.inclusive) {
				b := g.lowers[i]
				bestLower = &b
			}
		}
		// Strongest upper bound: smallest value; on ties exclusive beats inclusive.
		var bestUpper *bound
		for i := range g.uppers {
			if bestUpper == nil || cmpGeneric(g.effDt, g.uppers[i].value, bestUpper.value) < 0 ||
				(cmpGeneric(g.effDt, g.uppers[i].value, bestUpper.value) == 0 && !g.uppers[i].inclusive && bestUpper.inclusive) {
				b := g.uppers[i]
				bestUpper = &b
			}
		}
		if bestLower != nil && bestUpper != nil {
			// Check if the interval is valid (non-empty)
			c := cmpGeneric(g.effDt, bestLower.value, bestUpper.value)
			isEmpty := false
			if c > 0 {
				// lower > upper: always empty
				isEmpty = true
			} else if c == 0 {
				// lower == upper: only valid if both bounds are inclusive
				if !bestLower.inclusive || !bestUpper.inclusive {
					isEmpty = true
				}
			}

			for _, b := range g.lowers {
				used[b.exprIndex] = true
			}
			for _, b := range g.uppers {
				used[b.exprIndex] = true
			}

			if isEmpty {
				// Empty interval → constant false
				out = append(out, newAlwaysFalseExpr())
			} else {
				out = append(out, newBinaryRangeExpr(g.col, bestLower.inclusive, bestUpper.inclusive, bestLower.value, bestUpper.value))
			}
		} else if bestLower != nil {
			for _, b := range g.lowers {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_GreaterThan
			if bestLower.inclusive {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, bestLower.value))
		} else if bestUpper != nil {
			for _, b := range g.uppers {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_LessThan
			if bestUpper.inclusive {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, bestUpper.value))
		}
	}
	// Preserve any parts not consumed by a merge, in original order.
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}

// combineOrRangePredicates weakens OR-connected unary range predicates on the
// same column and same direction: the weakest lower (smallest) or weakest
// upper (largest) bound wins; on ties inclusive beats exclusive. Bounds of
// mixed direction are never merged here.
func (v *visitor) combineOrRangePredicates(parts []*planpb.Expr) []*planpb.Expr {
	type key struct {
		colKey  string
		isLower bool
		effDt   schemapb.DataType // effective type for JSON fields
	}
	type group struct {
		col      *planpb.ColumnInfo
		effDt    schemapb.DataType
		dirLower bool
		bounds   []bound
	}
	groups := map[key]*group{}
	others := []int{}
	isRangeOp := func(op planpb.OpType) bool {
		return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
			op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
	}
	for idx, e := range parts {
		u := e.GetUnaryRangeExpr()
		if u == nil || !isRangeOp(u.GetOp()) || u.GetValue() == nil {
			others = append(others, idx)
			continue
		}
		col := u.GetColumnInfo()
		if col == nil {
			others = append(others, idx)
			continue
		}
		effDt, ok := resolveEffectiveType(col)
		if !ok || !valueMatchesType(effDt, u.GetValue()) {
			others = append(others, idx)
			continue
		}
		// For JSON fields, determine the actual effective type from the literal value
		if effDt == schemapb.DataType_JSON {
			var typeOk bool
			effDt, typeOk = resolveJSONEffectiveType(u.GetValue())
			if !typeOk {
				others = append(others, idx)
				continue
			}
		}
		isLower := u.GetOp() == planpb.OpType_GreaterThan || u.GetOp() == planpb.OpType_GreaterEqual
		k := key{colKey: columnKey(col), isLower: isLower, effDt: effDt}
		g, ok := groups[k]
		if !ok {
			g = &group{col: col, effDt: effDt, dirLower: isLower}
			groups[k] = g
		}
		g.bounds = append(g.bounds, bound{
			value:     u.GetValue(),
			inclusive: u.GetOp() == planpb.OpType_GreaterEqual || u.GetOp() == planpb.OpType_LessEqual,
			isLower:   isLower,
			exprIndex: idx,
		})
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.bounds) <= 1 {
			// Single bound: leave the original expression in place (not "used",
			// so it is re-emitted by the trailing loop).
			continue
		}
		if g.dirLower {
			// Weakest lower bound: smallest value; on ties inclusive wins.
			var best *bound
			for i := range g.bounds {
				// Use the effective type stored in the group
				if best == nil || cmpGeneric(g.effDt, g.bounds[i].value, best.value) < 0 ||
					(cmpGeneric(g.effDt, g.bounds[i].value, best.value) == 0 && g.bounds[i].inclusive && !best.inclusive) {
					b := g.bounds[i]
					best = &b
				}
			}
			for _, b := range g.bounds {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_GreaterThan
			if best.inclusive {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, best.value))
		} else {
			// Weakest upper bound: largest value; on ties inclusive wins.
			var best *bound
			for i := range g.bounds {
				// Use the effective type stored in the group
				if best == nil || cmpGeneric(g.effDt, g.bounds[i].value, best.value) > 0 ||
					(cmpGeneric(g.effDt, g.bounds[i].value, best.value) == 0 && g.bounds[i].inclusive && !best.inclusive) {
					b := g.bounds[i]
					best = &b
				}
			}
			for _, b := range g.bounds {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_LessThan
			if best.inclusive {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, best.value))
		}
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}

// newBinaryRangeExpr builds a two-sided range expression (lower op col op upper)
// over the given column.
func newBinaryRangeExpr(col *planpb.ColumnInfo, lowerInclusive bool, upperInclusive bool, lower *planpb.GenericValue, upper *planpb.GenericValue) *planpb.Expr {
	return &planpb.Expr{
		Expr: &planpb.Expr_BinaryRangeExpr{
			BinaryRangeExpr: &planpb.BinaryRangeExpr{
				ColumnInfo:     col,
				LowerInclusive: lowerInclusive,
				UpperInclusive: upperInclusive,
				LowerValue:     lower,
				UpperValue:     upper,
			},
		},
	}
}

// -1 means a < b, 0 means a == b, 1 means a > b
func cmpGeneric(dt schemapb.DataType, a, b *planpb.GenericValue) int {
	switch dt {
	case schemapb.DataType_Int8,
		schemapb.DataType_Int16,
		schemapb.DataType_Int32,
		schemapb.DataType_Int64:
		ai, bi := a.GetInt64Val(), b.GetInt64Val()
		if ai < bi {
			return -1
		}
		if ai > bi {
			return 1
		}
		return 0
	case schemapb.DataType_Float, schemapb.DataType_Double:
		// Allow comparing int and float by promoting to float
		toFloat := func(g *planpb.GenericValue) float64 {
			switch g.GetVal().(type) {
			case *planpb.GenericValue_FloatVal:
				return g.GetFloatVal()
			case *planpb.GenericValue_Int64Val:
				return float64(g.GetInt64Val())
			default:
				// Should not happen due to gate; treat as 0 deterministically
				return 0
			}
		}
		af, bf := toFloat(a), toFloat(b)
		if af < bf {
			return -1
		}
		if af > bf {
			return 1
		}
		return 0
	case schemapb.DataType_String,
		schemapb.DataType_VarChar:
		as, bs := a.GetStringVal(), b.GetStringVal()
		if as < bs {
			return -1
		}
		if as > bs {
			return 1
		}
		return 0
	default:
		// Unsupported types are not optimized; callers gate with resolveEffectiveType.
		return 0
	}
}

// combineAndBinaryRanges merges BinaryRangeExpr nodes with AND semantics (intersection).
// Also handles mixing BinaryRangeExpr with UnaryRangeExpr.
func (v *visitor) combineAndBinaryRanges(parts []*planpb.Expr) []*planpb.Expr {
	// interval is a (possibly half-open) range; nil lower/upper means unbounded
	// on that side (contributed by a UnaryRangeExpr).
	type interval struct {
		lower         *planpb.GenericValue
		lowerInc      bool
		upper         *planpb.GenericValue
		upperInc      bool
		exprIndex     int
		isBinaryRange bool
	}
	type group struct {
		col       *planpb.ColumnInfo
		effDt     schemapb.DataType
		intervals []interval
	}
	groups := map[string]*group{}
	others := []int{}

	for idx, e := range parts {
		// Try BinaryRangeExpr
		if bre := e.GetBinaryRangeExpr(); bre != nil {
			col := bre.GetColumnInfo()
			if col == nil {
				others = append(others, idx)
				continue
			}
			effDt, ok := resolveEffectiveType(col)
			if !ok {
				others = append(others, idx)
				continue
			}
			// For JSON, determine actual type from lower value
			if effDt == schemapb.DataType_JSON {
				var typeOk bool
				effDt, typeOk = resolveJSONEffectiveType(bre.GetLowerValue())
				if !typeOk {
					others = append(others, idx)
					continue
				}
			}
			key := columnKey(col) + fmt.Sprintf("|%d", effDt)
			g, exists := groups[key]
			if !exists {
				g = &group{col: col, effDt: effDt}
				groups[key] = g
			}
			g.intervals = append(g.intervals, interval{
				lower:         bre.GetLowerValue(),
				lowerInc:      bre.GetLowerInclusive(),
				upper:         bre.GetUpperValue(),
				upperInc:      bre.GetUpperInclusive(),
				exprIndex:     idx,
				isBinaryRange: true,
			})
			continue
		}

		// Try UnaryRangeExpr (range ops only)
		if ure := e.GetUnaryRangeExpr(); ure != nil {
			op := ure.GetOp()
			if op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
				op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual {
				col := ure.GetColumnInfo()
				if col == nil {
					others = append(others, idx)
					continue
				}
				effDt, ok := resolveEffectiveType(col)
				if !ok || !valueMatchesType(effDt, ure.GetValue()) {
					others = append(others, idx)
					continue
				}
				if effDt == schemapb.DataType_JSON {
					var typeOk bool
					effDt, typeOk = resolveJSONEffectiveType(ure.GetValue())
					if !typeOk {
						others = append(others, idx)
						continue
					}
				}
				key := columnKey(col) + fmt.Sprintf("|%d", effDt)
				g, exists := groups[key]
				if !exists {
					g = &group{col: col, effDt: effDt}
					groups[key] = g
				}
				isLower := op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual
				inc := op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessEqual
				if isLower {
					g.intervals = append(g.intervals, interval{
						lower:         ure.GetValue(),
						lowerInc:      inc,
						upper:         nil,
						upperInc:      false,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				} else {
					g.intervals = append(g.intervals, interval{
						lower:         nil,
						lowerInc:      false,
						upper:         ure.GetValue(),
						upperInc:      inc,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				}
				continue
			}
		}

		// Not a range expr we can optimize
		others = append(others, idx)
	}

	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}

	for _, g := range groups {
		if len(g.intervals) == 0 {
			continue
		}
		if len(g.intervals) == 1 {
			// Single interval, keep as is
			continue
		}

		// Compute intersection: max lower, min upper
		var finalLower *planpb.GenericValue
		var finalLowerInc bool
		var finalUpper *planpb.GenericValue
		var finalUpperInc bool

		for _, iv := range g.intervals {
			if iv.lower != nil {
				if finalLower == nil {
					finalLower = iv.lower
					finalLowerInc = iv.lowerInc
				} else {
					// On equal bounds exclusive is stronger for AND.
					c := cmpGeneric(g.effDt, iv.lower, finalLower)
					if c > 0 || (c == 0 && !iv.lowerInc) {
						finalLower = iv.lower
						finalLowerInc = iv.lowerInc
					}
				}
			}
			if iv.upper != nil {
				if finalUpper == nil {
					finalUpper = iv.upper
					finalUpperInc = iv.upperInc
				} else {
					c := cmpGeneric(g.effDt, iv.upper, finalUpper)
					if c < 0 || (c == 0 && !iv.upperInc) {
						finalUpper = iv.upper
						finalUpperInc = iv.upperInc
					}
				}
			}
		}

		// Check if intersection is empty
		if finalLower != nil && finalUpper != nil {
			c := cmpGeneric(g.effDt, finalLower, finalUpper)
			isEmpty := false
			if c > 0 {
				isEmpty = true
			} else if c == 0 {
				// Equal bounds: only valid if both inclusive
				if !finalLowerInc || !finalUpperInc {
					isEmpty = true
				}
			}
			if isEmpty {
				// Empty intersection → constant false
				for _, iv := range g.intervals {
					used[iv.exprIndex] = true
				}
				out = append(out, newAlwaysFalseExpr())
				continue
			}
		}

		// Mark all intervals as used
		for _, iv := range g.intervals {
			used[iv.exprIndex] = true
		}

		// Emit the merged interval
		if finalLower != nil && finalUpper != nil {
			out = append(out, newBinaryRangeExpr(g.col, finalLowerInc, finalUpperInc, finalLower, finalUpper))
		} else if finalLower != nil {
			op := planpb.OpType_GreaterThan
			if finalLowerInc {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, finalLower))
		} else if finalUpper != nil {
			op := planpb.OpType_LessThan
			if finalUpperInc {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, finalUpper))
		}
	}

	// Add unused parts
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}

	return out
}

// combineOrBinaryRanges merges BinaryRangeExpr nodes with OR semantics (union if overlapping/adjacent).
// Also handles mixing BinaryRangeExpr with UnaryRangeExpr.
+func (v *visitor) combineOrBinaryRanges(parts []*planpb.Expr) []*planpb.Expr {
+    type interval struct {
+        lower         *planpb.GenericValue
+        lowerInc      bool
+        upper         *planpb.GenericValue
+        upperInc      bool
+        exprIndex     int
+        isBinaryRange bool
+    }
+    type group struct {
+        col       *planpb.ColumnInfo
+        effDt     schemapb.DataType
+        intervals []interval
+    }
+    groups := map[string]*group{}
+    others := []int{}
+
+    for idx, e := range parts {
+        // Try BinaryRangeExpr
+        if bre := e.GetBinaryRangeExpr(); bre != nil {
+            col := bre.GetColumnInfo()
+            if col == nil {
+                others = append(others, idx)
+                continue
+            }
+            effDt, ok := resolveEffectiveType(col)
+            if !ok {
+                others = append(others, idx)
+                continue
+            }
+            if effDt == schemapb.DataType_JSON {
+                var typeOk bool
+                effDt, typeOk = resolveJSONEffectiveType(bre.GetLowerValue())
+                if !typeOk {
+                    others = append(others, idx)
+                    continue
+                }
+            }
+            key := columnKey(col) + fmt.Sprintf("|%d", effDt)
+            g, exists := groups[key]
+            if !exists {
+                g = &group{col: col, effDt: effDt}
+                groups[key] = g
+            }
+            g.intervals = append(g.intervals, interval{
+                lower:         bre.GetLowerValue(),
+                lowerInc:      bre.GetLowerInclusive(),
+                upper:         bre.GetUpperValue(),
+                upperInc:      bre.GetUpperInclusive(),
+                exprIndex:     idx,
+                isBinaryRange: true,
+            })
+            continue
+        }
+
+        // Try UnaryRangeExpr
+        if ure := e.GetUnaryRangeExpr(); ure != nil {
+            op := ure.GetOp()
+            if op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
+                op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual {
+                col := ure.GetColumnInfo()
+                if col == nil {
+                    others = append(others, idx)
+                    continue
+                }
+                effDt, ok := resolveEffectiveType(col)
+                if !ok || !valueMatchesType(effDt, ure.GetValue()) {
+                    others = append(others, idx)
+                    continue
+                }
+                if effDt == schemapb.DataType_JSON {
+                    var typeOk bool
+                    effDt, typeOk = resolveJSONEffectiveType(ure.GetValue())
+                    if !typeOk {
+                        others = append(others, idx)
+                        continue
+                    }
+                }
+                key := columnKey(col) + fmt.Sprintf("|%d", effDt)
+                g, exists := groups[key]
+                if !exists {
+                    g = &group{col: col, effDt: effDt}
+                    groups[key] = g
+                }
+                isLower := op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual
+                inc := op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessEqual
+                if isLower {
+                    g.intervals = append(g.intervals, interval{
+                        lower:         ure.GetValue(),
+                        lowerInc:      inc,
+                        upper:         nil,
+                        upperInc:      false,
+                        exprIndex:     idx,
+                        isBinaryRange: false,
+                    })
+                } else {
+                    g.intervals = append(g.intervals, interval{
+                        lower:         nil,
+                        lowerInc:      false,
+                        upper:         ure.GetValue(),
+                        upperInc:      inc,
+                        exprIndex:     idx,
+                        isBinaryRange: false,
+                    })
+                }
+                continue
+            }
+        }
+
+        others = append(others, idx)
+    }
+
+    used := make([]bool, len(parts))
+    out := make([]*planpb.Expr, 0, len(parts))
+    for _, idx := range others {
+        out = append(out, parts[idx])
+        used[idx] = true
+    }
+
+    for _, g := range groups {
+        if len(g.intervals) == 0 {
+            continue
+        }
+        if len(g.intervals) == 1 {
+            // Single interval, keep as is
+            continue
+        }
+
+        // For OR, try to merge overlapping/adjacent intervals
+        // If any interval is unbounded on one side, check if it subsumes others
+        // For simplicity, we'll handle the common cases:
+        // 1. All bounded intervals: try to merge if overlapping/adjacent
+        // 2. Mix of bounded/unbounded: merge unbounded with compatible bounds
+
+        var hasUnboundedLower, hasUnboundedUpper bool
+        var unboundedLowerVal *planpb.GenericValue
+        var unboundedLowerInc bool
+        var unboundedUpperVal *planpb.GenericValue
+        var unboundedUpperInc bool
+
+        // Check for unbounded intervals
+        for _, iv := range g.intervals {
+            if iv.lower != nil && iv.upper == nil {
+                // Lower bound only (x > a)
+                if !hasUnboundedLower {
+                    hasUnboundedLower = true
+                    unboundedLowerVal = iv.lower
+                    unboundedLowerInc = iv.lowerInc
+                } else {
+                    // Multiple lower-only bounds: take weakest (minimum)
+                    c := cmpGeneric(g.effDt, iv.lower, unboundedLowerVal)
+                    if c < 0 || (c == 0 && iv.lowerInc && !unboundedLowerInc) {
+                        unboundedLowerVal = iv.lower
+                        unboundedLowerInc = iv.lowerInc
+                    }
+                }
+            }
+            if iv.lower == nil && iv.upper != nil {
+                // Upper bound only (x < b)
+                if !hasUnboundedUpper {
+                    hasUnboundedUpper = true
+                    unboundedUpperVal = iv.upper
+                    unboundedUpperInc = iv.upperInc
+                } else {
+                    // Multiple upper-only bounds: take weakest (maximum)
+                    c := cmpGeneric(g.effDt, iv.upper, unboundedUpperVal)
+                    if c > 0 || (c == 0 && iv.upperInc && !unboundedUpperInc) {
+                        unboundedUpperVal = iv.upper
+                        unboundedUpperInc = iv.upperInc
+                    }
+                }
+            }
+        }
+
+        // Case 1: Have both unbounded lower and upper → entire domain (always true, but we can't express that simply)
+        // For now, keep them separate
+        // Case 2: Have one unbounded side → merge with compatible bounded intervals
+        // Case 3: All bounded → try to merge overlapping/adjacent
+
+        if hasUnboundedLower && hasUnboundedUpper {
+            // Both unbounded sides - this likely covers most values
+            // Keep as separate predicates for now (more advanced merging could be done)
+            continue
+        }
+
+        if hasUnboundedLower || hasUnboundedUpper {
+            // Merge unbounded with bounded intervals where applicable
+            // For unbounded lower (x > a): can merge with binary ranges that have compatible upper bounds
+            // For unbounded upper (x < b): can merge with binary ranges that have compatible lower bounds
+            // This is complex, so for now we'll keep it simple and just skip merging
+            // In practice, unbounded intervals often dominate
+            continue
+        }
+
+        // All bounded intervals: try to merge overlapping/adjacent ones
+        // This requires sorting and checking overlap
+        // For simplicity in this initial implementation, we'll check if there are exactly 2 intervals
+        // and try to merge them if they overlap or are adjacent
+
+        if len(g.intervals) == 2 {
+            iv1, iv2 := g.intervals[0], g.intervals[1]
+            if iv1.lower == nil || iv1.upper == nil || iv2.lower == nil || iv2.upper == nil {
+                // One is not fully bounded, skip
+                continue
+            }
+
+            // Check if they overlap or are adjacent
+            // They overlap if: iv1.lower <= iv2.upper AND iv2.lower <= iv1.upper
+            // They are adjacent if: iv1.upper == iv2.lower (or vice versa) with at least one inclusive
+
+            // Determine order: which has smaller lower bound
+            var first, second interval
+            c := cmpGeneric(g.effDt, iv1.lower, iv2.lower)
+            if c <= 0 {
+                first, second = iv1, iv2
+            } else {
+                first, second = iv2, iv1
+            }
+
+            // Check if they can be merged
+            // Overlap: first.upper >= second.lower
+            cmpUpperLower := cmpGeneric(g.effDt, first.upper, second.lower)
+            canMerge := false
+            if cmpUpperLower > 0 {
+                // Overlap
+                canMerge = true
+            } else if cmpUpperLower == 0 {
+                // Adjacent: at least one bound must be inclusive
+                if first.upperInc || second.lowerInc {
+                    canMerge = true
+                }
+            }
+
+            if canMerge {
+                // Merge: take min lower (inclusive if either side includes an equal bound) and max upper
+                mergedLower := first.lower
+                mergedLowerInc := first.lowerInc || (c == 0 && second.lowerInc)
+                mergedUpper := second.upper
+                mergedUpperInc := second.upperInc
+
+                // Upper bound: take maximum
+                cmpUppers := cmpGeneric(g.effDt, first.upper, second.upper)
+                if cmpUppers > 0 {
+                    mergedUpper = first.upper
+                    mergedUpperInc = first.upperInc
+                } else if cmpUppers == 0 {
+                    // Same value: prefer inclusive
+                    if first.upperInc {
+                        mergedUpperInc = true
+                    }
+                }
+
+                // Mark both as used
+                used[first.exprIndex] = true
+                used[second.exprIndex] = true
+
+                // Emit merged interval
+                out = append(out, newBinaryRangeExpr(g.col, mergedLowerInc, mergedUpperInc, mergedLower, mergedUpper))
+            }
+        }
+        // For more than 2 intervals, we'd need more sophisticated merging logic
+        // For now, we'll leave them separate
+    }
+
+    // Add unused parts
+    for i := range parts {
+        if !used[i] {
+            out = append(out, parts[i])
+        }
+    }
+
+    return out
+}
diff --git a/internal/parser/planparserv2/rewriter/range_binary_test.go b/internal/parser/planparserv2/rewriter/range_binary_test.go
new file mode 100644
index 0000000000..b106af1863
--- /dev/null
+++ b/internal/parser/planparserv2/rewriter/range_binary_test.go
@@ -0,0 +1,412 @@
+package rewriter_test
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+
+    parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
+    "github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
+    "github.com/milvus-io/milvus/pkg/v2/proto/planpb"
+)
+
+// Test BinaryRangeExpr AND BinaryRangeExpr - intersection
+func TestRewrite_BinaryRange_AND_BinaryRange_Intersection(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 50) AND (20 < x < 40) → (20 < x < 40)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 20 and Int64Field < 40)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre, "should merge to single binary range")
+    require.Equal(t, false, bre.GetLowerInclusive())
+    require.Equal(t, false, bre.GetUpperInclusive())
+    require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(40), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr AND BinaryRangeExpr - tighter lower
+func TestRewrite_BinaryRange_AND_BinaryRange_TighterLower(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 50) AND (5 < x < 40) → (10 < x < 40)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 5 and Int64Field < 40)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(40), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr AND BinaryRangeExpr - tighter upper
+func TestRewrite_BinaryRange_AND_BinaryRange_TighterUpper(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 50) AND (15 < x < 60) → (15 < x < 50)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 15 and Int64Field < 60)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(15), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr AND BinaryRangeExpr - empty intersection
+func TestRewrite_BinaryRange_AND_BinaryRange_EmptyIntersection(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 20) AND (30 < x < 40) → false
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) and (Int64Field > 30 and Int64Field < 40)`, nil)
+    require.NoError(t, err)
+    require.True(t, rewriter.IsAlwaysFalseExpr(expr))
+}
+
+// Test BinaryRangeExpr AND BinaryRangeExpr - equal bounds, both inclusive
+func TestRewrite_BinaryRange_AND_BinaryRange_EqualBounds_BothInclusive(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 <= x <= 10) AND (10 <= x <= 20) → (x == 10) which is (10 <= x <= 10)
+    expr, err := parser.ParseExpr(helper, `(Int64Field >= 10 and Int64Field <= 10) and (Int64Field >= 10 and Int64Field <= 20)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, true, bre.GetLowerInclusive())
+    require.Equal(t, true, bre.GetUpperInclusive())
+    require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(10), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr AND BinaryRangeExpr - equal bounds, one exclusive → false
+func TestRewrite_BinaryRange_AND_BinaryRange_EqualBounds_Exclusive(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x <= 20) AND (5 <= x < 10) → false (bounds meet at 10 but exclusive)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) and (Int64Field >= 5 and Int64Field < 10)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    require.True(t, rewriter.IsAlwaysFalseExpr(expr))
+}
+
+// Test BinaryRangeExpr AND UnaryRangeExpr - tighten lower bound
+func TestRewrite_BinaryRange_AND_UnaryRange_TightenLower(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 50) AND (x > 30) → (30 < x < 50)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and Int64Field > 30`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(30), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr AND UnaryRangeExpr - tighten upper bound
+func TestRewrite_BinaryRange_AND_UnaryRange_TightenUpper(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 50) AND (x < 25) → (10 < x < 25)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and Int64Field < 25`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(25), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr AND UnaryRangeExpr - weaker bound (no change)
+func TestRewrite_BinaryRange_AND_UnaryRange_WeakerBound(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (20 < x < 30) AND (x > 10) → (20 < x < 30) (10 is weaker than 20)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 20 and Int64Field < 30) and Int64Field > 10`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(30), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr OR BinaryRangeExpr - overlapping → union
+func TestRewrite_BinaryRange_OR_BinaryRange_Overlapping(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 25) OR (20 < x < 40) → (10 < x < 40)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 25) or (Int64Field > 20 and Int64Field < 40)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre, "overlapping intervals should merge")
+    require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(40), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr OR BinaryRangeExpr - adjacent (inclusive) → union
+func TestRewrite_BinaryRange_OR_BinaryRange_Adjacent(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x <= 20) OR (20 <= x < 30) → (10 < x < 30)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) or (Int64Field >= 20 and Int64Field < 30)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre, "adjacent intervals should merge")
+    require.Equal(t, false, bre.GetLowerInclusive())
+    require.Equal(t, false, bre.GetUpperInclusive())
+    require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(30), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRangeExpr OR BinaryRangeExpr - disjoint (no merge)
+func TestRewrite_BinaryRange_OR_BinaryRange_Disjoint(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 20) OR (30 < x < 40) → remains as OR
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 30 and Int64Field < 40)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    // Should remain as BinaryExpr OR since intervals are disjoint
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "disjoint intervals should not merge")
+    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
+}
+
+// Test BinaryRangeExpr OR BinaryRangeExpr - adjacent but both exclusive (no merge)
+func TestRewrite_BinaryRange_OR_BinaryRange_AdjacentExclusive(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 20) OR (20 < x < 30) → remains as OR (gap at 20)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 20 and Int64Field < 30)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "exclusive adjacent intervals should not merge")
+    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
+}
+
+// Test BinaryRangeExpr OR BinaryRangeExpr - prefer inclusive when merging
+func TestRewrite_BinaryRange_OR_BinaryRange_PreferInclusive(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 <= x < 25) OR (15 <= x <= 30) → (10 <= x <= 30)
+    expr, err := parser.ParseExpr(helper, `(Int64Field >= 10 and Int64Field < 25) or (Int64Field >= 15 and Int64Field <= 30)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    // Lower should be 10 (from first), upper should be 30 (from second)
+    // Both should prefer inclusive where available
+    require.Equal(t, true, bre.GetLowerInclusive())
+    require.Equal(t, true, bre.GetUpperInclusive())
+    require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(30), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test with Float fields
+func TestRewrite_BinaryRange_AND_Float(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (1.0 < x < 5.0) AND (2.0 < x < 4.0) → (2.0 < x < 4.0)
+    expr, err := parser.ParseExpr(helper, `(FloatField > 1.0 and FloatField < 5.0) and (FloatField > 2.0 and FloatField < 4.0)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.InDelta(t, 2.0, bre.GetLowerValue().GetFloatVal(), 1e-9)
+    require.InDelta(t, 4.0, bre.GetUpperValue().GetFloatVal(), 1e-9)
+}
+
+// Test with VarChar fields
+func TestRewrite_BinaryRange_OR_VarChar_Overlapping(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // ("b" < x < "m") OR ("j" < x < "z") → ("b" < x < "z")
+    expr, err := parser.ParseExpr(helper, `(VarCharField > "b" and VarCharField < "m") or (VarCharField > "j" and VarCharField < "z")`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, "b", bre.GetLowerValue().GetStringVal())
+    require.Equal(t, "z", bre.GetUpperValue().GetStringVal())
+}
+
+// Test with JSON fields
+func TestRewrite_BinaryRange_AND_JSON(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    // (10 < json["price"] < 100) AND (20 < json["price"] < 80) → (20 < json["price"] < 80)
+    expr, err := parser.ParseExpr(helper, `(JSONField["price"] > 10 and JSONField["price"] < 100) and (JSONField["price"] > 20 and JSONField["price"] < 80)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(80), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test with JSON fields - OR overlapping
+func TestRewrite_BinaryRange_OR_JSON_Overlapping(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    // (1.0 < json["score"] < 3.0) OR (2.5 < json["score"] < 5.0) → (1.0 < json["score"] < 5.0)
+    expr, err := parser.ParseExpr(helper, `(JSONField["score"] > 1.0 and JSONField["score"] < 3.0) or (JSONField["score"] > 2.5 and JSONField["score"] < 5.0)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.InDelta(t, 1.0, bre.GetLowerValue().GetFloatVal(), 1e-9)
+    require.InDelta(t, 5.0, bre.GetUpperValue().GetFloatVal(), 1e-9)
+}
+
+// Test mixing BinaryRange with multiple UnaryRanges
+func TestRewrite_BinaryRange_AND_MultipleUnary(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 100) AND (x > 20) AND (x < 80) → (20 < x < 80)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 100) and Int64Field > 20 and Int64Field < 80`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(80), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test three BinaryRanges with AND
+func TestRewrite_BinaryRange_AND_ThreeBinaryRanges(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (5 < x < 100) AND (10 < x < 90) AND (15 < x < 80) → (15 < x < 80)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 5 and Int64Field < 100) and (Int64Field > 10 and Int64Field < 90) and (Int64Field > 15 and Int64Field < 80)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, int64(15), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(80), bre.GetUpperValue().GetInt64Val())
+}
+
+// Test BinaryRange on different columns should not merge
+func TestRewrite_BinaryRange_AND_DifferentColumns(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < Int64Field < 50) AND (20 < FloatField < 40) → both remain
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (FloatField > 20 and FloatField < 40)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "different columns should not merge")
+    require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
+}
+
+// Test BinaryRange with JSON different paths should not merge
+func TestRewrite_BinaryRange_AND_JSON_DifferentPaths(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    // (10 < json["a"] < 50) AND (20 < json["b"] < 40) → both remain
+    expr, err := parser.ParseExpr(helper, `(JSONField["a"] > 10 and JSONField["a"] < 50) and (JSONField["b"] > 20 and JSONField["b"] < 40)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "different JSON paths should not merge")
+    require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
+}
+
+// Test BinaryRangeExpr OR with 3 overlapping intervals
+// NOTE: Current implementation limitation - only merges 2 intervals at a time
+func TestRewrite_BinaryRange_OR_ThreeOverlapping_CurrentLimitation(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 20) OR (15 < x < 25) OR (22 < x < 30)
+    // Ideally should merge to (10 < x < 30)
+    // Currently: may only partially merge due to limitation
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 15 and Int64Field < 25) or (Int64Field > 22 and Int64Field < 30)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+
+    // Due to current limitation, result may vary depending on tree structure
+    // This test documents the current behavior rather than ideal behavior
+    // When enhancement is implemented, this test should be updated to verify (10 < x < 30)
+
+    // For now, just verify it doesn't crash and produces valid output
+    require.NotNil(t, expr)
+    // Could be BinaryRangeExpr (if some merged) or BinaryExpr OR (if not merged)
+    // We document that 3+ intervals are NOT fully optimized yet
+}
+
+// Test BinaryRangeExpr OR with 3 fully overlapping intervals
+func TestRewrite_BinaryRange_OR_ThreeFullyOverlapping(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x < 30) OR (12 < x < 28) OR (15 < x < 25)
+    // The second and third are fully contained in the first
+    // Ideally should merge to (10 < x < 30)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 30) or (Int64Field > 12 and Int64Field < 28) or (Int64Field > 15 and Int64Field < 25)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+
+    // Current limitation: may not fully optimize
+    // This test documents that 3+ interval merging is not yet complete
+    require.NotNil(t, expr)
+}
+
+// Test BinaryRangeExpr OR with 4 adjacent intervals
+func TestRewrite_BinaryRange_OR_FourAdjacent_CurrentLimitation(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (10 < x <= 20) OR (20 < x <= 30) OR (30 < x <= 40) OR (40 < x <= 50)
+    // Ideally should merge to (10 < x <= 50)
+    expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) or (Int64Field > 20 and Int64Field <= 30) or (Int64Field > 30 and Int64Field <= 40) or (Int64Field > 40 and Int64Field <= 50)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+
+    // Current limitation: only pairs of adjacent intervals may merge
+    // Full chain merging not implemented
+    require.NotNil(t, expr)
+}
+
+// Test OR with unbounded lower + bounded interval
+// NOTE: Current implementation limitation - unbounded intervals not merged with bounded
+func TestRewrite_BinaryRange_OR_UnboundedLower_Bounded_CurrentLimitation(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (x > 10) OR (5 < x < 15)
+    // Ideally should merge to (x > 5)
+    // Currently: both predicates remain separate
+    expr, err := parser.ParseExpr(helper, `Int64Field > 10 or (Int64Field > 5 and Int64Field < 15)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+
+    // Current limitation: unbounded + bounded intervals not optimized
+    // Result should be a BinaryExpr OR with both predicates
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "should remain as OR (not merged)")
+    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
+}
+
+// Test OR with unbounded upper + bounded interval
+func TestRewrite_BinaryRange_OR_UnboundedUpper_Bounded_CurrentLimitation(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (x < 20) OR (10 < x < 30)
+    // Ideally should merge to (x < 30)
+    // Currently: both predicates remain separate
+    expr, err := parser.ParseExpr(helper, `Int64Field < 20 or (Int64Field > 10 and Int64Field < 30)`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+
+    // Current limitation: unbounded + bounded intervals not optimized
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "should remain as OR (not merged)")
+    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
+}
+
+// Test OR with unbounded lower + unbounded upper
+func TestRewrite_BinaryRange_OR_UnboundedBoth_CurrentLimitation(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (x > 10) OR (x < 20)
+    // This covers most values (gap only between 10 and 20 if both exclusive)
+    // Currently: both predicates remain separate
+    expr, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field < 20`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+
+    // Current limitation: unbounded intervals in OR not merged
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "should remain as OR")
+    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
+}
+
+// Test OR with multiple unbounded lower bounds - these DO get optimized (weakening)
+func TestRewrite_BinaryRange_OR_MultipleUnboundedLower(t *testing.T) {
+    helper := buildSchemaHelperForRewriteT(t)
+    // (x > 10) OR (x > 20) → (x > 10) [weaker bound]
+    expr, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field > 20`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+
+    // This SHOULD be optimized (weakening works for same direction)
+    ure := expr.GetUnaryRangeExpr()
+    require.NotNil(t, ure, "should merge to single unary range")
+    require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
+    require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
+}
diff --git a/internal/parser/planparserv2/rewriter/range_json_test.go b/internal/parser/planparserv2/rewriter/range_json_test.go
new file mode 100644
index 0000000000..863de8ca9a
--- /dev/null
+++ b/internal/parser/planparserv2/rewriter/range_json_test.go
@@ -0,0 +1,294 @@
+package rewriter_test
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+
+    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+    parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
+    "github.com/milvus-io/milvus/pkg/v2/proto/planpb"
+    "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
+)
+
+func buildSchemaHelperWithJSON(t *testing.T) *typeutil.SchemaHelper {
+    fields := []*schemapb.FieldSchema{
+        {FieldID: 101, Name: "Int64Field", DataType: schemapb.DataType_Int64},
+        {FieldID: 102, Name: "JSONField", DataType: schemapb.DataType_JSON},
+        {FieldID: 103, Name: "$meta", DataType: schemapb.DataType_JSON, IsDynamic: true},
+    }
+    schema := &schemapb.CollectionSchema{
+        Name:               "rewrite_json_test",
+        AutoID:             false,
+        Fields:             fields,
+        EnableDynamicField: true,
+    }
+    helper, err := typeutil.CreateSchemaHelper(schema)
+    require.NoError(t, err)
+    return helper
+}
+
+// Test JSON field with int comparison - AND tightening
+func TestRewrite_JSON_Int_AND_Strengthen(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["price"] > 10 and JSONField["price"] > 20`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    ure := expr.GetUnaryRangeExpr()
+    require.NotNil(t, ure, "should merge to single unary range")
+    require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
+    require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
+    // Verify column info
+    require.Equal(t, schemapb.DataType_JSON, ure.GetColumnInfo().GetDataType())
+    require.Equal(t, []string{"price"}, ure.GetColumnInfo().GetNestedPath())
+}
+
+// Test JSON field with int comparison - AND to BinaryRange
+func TestRewrite_JSON_Int_AND_ToBinaryRange(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["price"] > 10 and JSONField["price"] < 50`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre, "should create binary range")
+    require.Equal(t, false, bre.GetLowerInclusive())
+    require.Equal(t, false, bre.GetUpperInclusive())
+    require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
+    require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
+    // Verify column info
+    require.Equal(t, schemapb.DataType_JSON, bre.GetColumnInfo().GetDataType())
+    require.Equal(t, []string{"price"}, bre.GetColumnInfo().GetNestedPath())
+}
+
+// Test JSON field with float comparison - AND tightening with mixed int/float
+func TestRewrite_JSON_Float_AND_Strengthen_Mixed(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["score"] > 10 and JSONField["score"] > 15.5`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    ure := expr.GetUnaryRangeExpr()
+    require.NotNil(t, ure)
+    require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
+    require.InDelta(t, 15.5, ure.GetValue().GetFloatVal(), 1e-9)
+}
+
+// Test JSON field with string comparison - AND tightening
+func TestRewrite_JSON_String_AND_Strengthen(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["name"] > "alice" and JSONField["name"] > "bob"`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    ure := expr.GetUnaryRangeExpr()
+    require.NotNil(t, ure)
+    require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
+    require.Equal(t, "bob", ure.GetValue().GetStringVal())
+}
+
+// Test JSON field with string comparison - AND to BinaryRange
+func TestRewrite_JSON_String_AND_ToBinaryRange(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["name"] >= "alice" and JSONField["name"] <= "zebra"`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    bre := expr.GetBinaryRangeExpr()
+    require.NotNil(t, bre)
+    require.Equal(t, true, bre.GetLowerInclusive())
+    require.Equal(t, true, bre.GetUpperInclusive())
+    require.Equal(t, "alice", bre.GetLowerValue().GetStringVal())
+    require.Equal(t, "zebra", bre.GetUpperValue().GetStringVal())
+}
+
+// Test JSON field with OR - weakening lower bounds
+func TestRewrite_JSON_Int_OR_Weaken(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["age"] > 10 or JSONField["age"] > 20`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    ure := expr.GetUnaryRangeExpr()
+    require.NotNil(t, ure)
+    require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
+    require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
+}
+
+// Test JSON field with OR - weakening upper bounds
+func TestRewrite_JSON_String_OR_Weaken_Upper(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["category"] < "electronics" or JSONField["category"] < "sports"`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    ure := expr.GetUnaryRangeExpr()
+    require.NotNil(t, ure)
+    require.Equal(t, planpb.OpType_LessThan, ure.GetOp())
+    require.Equal(t, "sports", ure.GetValue().GetStringVal())
+}
+
+// Test JSON field with numeric types (int and float) - SHOULD merge
+func TestRewrite_JSON_NumericTypes_Merge(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    // JSONField["value"] > 10.5 (float) and JSONField["value"] > 20 (int)
+    // Numeric types should merge - both treated as Double
+    expr, err := parser.ParseExpr(helper, `JSONField["value"] > 10.5 and JSONField["value"] > 20`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    ure := expr.GetUnaryRangeExpr()
+    require.NotNil(t, ure, "numeric types should merge to single predicate")
+    require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
+    // Should pick the stronger bound (20)
+    switch ure.GetValue().GetVal().(type) {
+    case *planpb.GenericValue_Int64Val:
+        require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
+    case *planpb.GenericValue_FloatVal:
+        require.InDelta(t, 20.0, ure.GetValue().GetFloatVal(), 1e-9)
+    default:
+        t.Fatalf("unexpected value type")
+    }
+}
+
+// Test JSON field with mixed type categories - should NOT merge
+// Note: Numeric types (int and float) ARE compatible, but numeric and string are not
+func TestRewrite_JSON_MixedTypes_NoMerge(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    // JSONField["value"] > 10 (numeric) and JSONField["value"] > "hello" (string)
+    // These should remain separate as they have different type categories
+    expr, err := parser.ParseExpr(helper, `JSONField["value"] > 10 and JSONField["value"] > "hello"`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "should remain as binary expr (AND)")
+    require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
+    // Both sides should be UnaryRangeExpr
+    require.NotNil(t, be.GetLeft().GetUnaryRangeExpr())
+    require.NotNil(t, be.GetRight().GetUnaryRangeExpr())
+}
+
+// Test JSON field with mixed type categories in OR - should NOT merge
+func TestRewrite_JSON_MixedTypes_OR_NoMerge(t *testing.T) {
+    helper := buildSchemaHelperWithJSON(t)
+    expr, err := parser.ParseExpr(helper, `JSONField["data"] > 100 or JSONField["data"] > "text"`, nil)
+    require.NoError(t, err)
+    require.NotNil(t, expr)
+    be := expr.GetBinaryExpr()
+    require.NotNil(t, be, "should remain as binary expr (OR)")
+    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
+}
+
+// Test dynamic field ($meta) - same as JSON field +func TestRewrite_DynamicField_Int_AND_Strengthen(t *testing.T) { + helper := buildSchemaHelperWithJSON(t) + // $meta is the dynamic field + expr, err := parser.ParseExpr(helper, `$meta["count"] > 5 and $meta["count"] > 15`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(15), ure.GetValue().GetInt64Val()) +} + +// Test dynamic field - AND to BinaryRange +func TestRewrite_DynamicField_Float_AND_ToBinaryRange(t *testing.T) { + helper := buildSchemaHelperWithJSON(t) + expr, err := parser.ParseExpr(helper, `$meta["rating"] >= 1.0 and $meta["rating"] <= 5.0`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + bre := expr.GetBinaryRangeExpr() + require.NotNil(t, bre) + require.Equal(t, true, bre.GetLowerInclusive()) + require.Equal(t, true, bre.GetUpperInclusive()) + require.InDelta(t, 1.0, bre.GetLowerValue().GetFloatVal(), 1e-9) + require.InDelta(t, 5.0, bre.GetUpperValue().GetFloatVal(), 1e-9) +} + +// Test different nested paths - should NOT merge +func TestRewrite_JSON_DifferentPaths_NoMerge(t *testing.T) { + helper := buildSchemaHelperWithJSON(t) + expr, err := parser.ParseExpr(helper, `JSONField["a"] > 10 and JSONField["b"] > 20`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + be := expr.GetBinaryExpr() + require.NotNil(t, be, "different paths should not merge") + require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp()) +} + +// Test nested JSON path +func TestRewrite_JSON_NestedPath_AND_Strengthen(t *testing.T) { + helper := buildSchemaHelperWithJSON(t) + expr, err := parser.ParseExpr(helper, `JSONField["user"]["age"] > 18 and JSONField["user"]["age"] > 21`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + 
require.Equal(t, int64(21), ure.GetValue().GetInt64Val()) + require.Equal(t, []string{"user", "age"}, ure.GetColumnInfo().GetNestedPath()) +} + +// Test inclusive vs exclusive bounds +func TestRewrite_JSON_AND_EquivalentBounds(t *testing.T) { + helper := buildSchemaHelperWithJSON(t) + // JSONField["x"] >= 10 AND JSONField["x"] > 10 → JSONField["x"] > 10 (exclusive is stronger) + expr, err := parser.ParseExpr(helper, `JSONField["x"] >= 10 and JSONField["x"] > 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) +} + +// Test OR with inclusive bounds preference +func TestRewrite_JSON_OR_EquivalentBounds(t *testing.T) { + helper := buildSchemaHelperWithJSON(t) + // JSONField["x"] >= 10 OR JSONField["x"] > 10 → JSONField["x"] >= 10 (inclusive is weaker) + expr, err := parser.ParseExpr(helper, `JSONField["x"] >= 10 or JSONField["x"] > 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterEqual, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) +} + +// Test that scalar field and JSON field don't interfere +func TestRewrite_JSON_And_Scalar_Independent(t *testing.T) { + helper := buildSchemaHelperWithJSON(t) + expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field > 20 and JSONField["price"] > 5 and JSONField["price"] > 15`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + // Should result in AND of two optimized predicates + be := expr.GetBinaryExpr() + require.NotNil(t, be) + require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp()) + + // Collect all predicates + var predicates []*planpb.Expr + var collect func(*planpb.Expr) + collect = func(e *planpb.Expr) { + if be := e.GetBinaryExpr(); be != nil && be.GetOp() == 
planpb.BinaryExpr_LogicalAnd { + collect(be.GetLeft()) + collect(be.GetRight()) + } else { + predicates = append(predicates, e) + } + } + collect(expr) + + require.Equal(t, 2, len(predicates), "should have two optimized predicates") + + // Check that both are optimized to the stronger bounds + var hasInt64, hasJSON bool + for _, p := range predicates { + if ure := p.GetUnaryRangeExpr(); ure != nil { + col := ure.GetColumnInfo() + if col.GetDataType() == schemapb.DataType_Int64 { + hasInt64 = true + require.Equal(t, int64(20), ure.GetValue().GetInt64Val()) + } else if col.GetDataType() == schemapb.DataType_JSON { + hasJSON = true + require.Equal(t, int64(15), ure.GetValue().GetInt64Val()) + } + } + } + require.True(t, hasInt64, "should have optimized Int64Field predicate") + require.True(t, hasJSON, "should have optimized JSONField predicate") +} diff --git a/internal/parser/planparserv2/rewriter/range_test.go b/internal/parser/planparserv2/rewriter/range_test.go new file mode 100644 index 0000000000..4813dd196a --- /dev/null +++ b/internal/parser/planparserv2/rewriter/range_test.go @@ -0,0 +1,492 @@ +package rewriter_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + parser "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +func TestRewrite_Range_AND_Strengthen(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field > 20`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(20), ure.GetValue().GetInt64Val()) +} + +func TestRewrite_Range_AND_Strengthen_Upper(t *testing.T) { + 
helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field < 50 and Int64Field < 60`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_LessThan, ure.GetOp()) + require.Equal(t, int64(50), ure.GetValue().GetInt64Val()) +} + +func TestRewrite_Range_AND_EquivalentBounds(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // a ≥ x AND a > x → a > x + expr, err := parser.ParseExpr(helper, `Int64Field >= 10 and Int64Field > 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) + // a ≤ y AND a < y → a < y + expr, err = parser.ParseExpr(helper, `Int64Field <= 10 and Int64Field < 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure = expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_LessThan, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) +} + +func TestRewrite_Range_OR_Weaken(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field > 20`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) +} + +func TestRewrite_Range_OR_Weaken_Upper(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field < 10 or Int64Field < 20`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_LessThan, ure.GetOp()) + require.Equal(t, int64(20), ure.GetValue().GetInt64Val()) +} + +func 
TestRewrite_Range_OR_EquivalentBounds(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // a ≥ x OR a > x → a ≥ x + expr, err := parser.ParseExpr(helper, `Int64Field >= 10 or Int64Field > 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterEqual, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) + // a ≤ y OR a < y → a ≤ y + expr, err = parser.ParseExpr(helper, `Int64Field <= 10 or Int64Field < 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure = expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_LessEqual, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) +} + +func TestRewrite_Range_AND_ToBinaryRange(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field < 50`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + bre := expr.GetBinaryRangeExpr() + require.NotNil(t, bre) + require.Equal(t, false, bre.GetLowerInclusive()) + require.Equal(t, false, bre.GetUpperInclusive()) + require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val()) + require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val()) +} + +func TestRewrite_Range_AND_ToBinaryRange_Inclusive(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field >= 10 and Int64Field <= 50`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + bre := expr.GetBinaryRangeExpr() + require.NotNil(t, bre) + require.Equal(t, true, bre.GetLowerInclusive()) + require.Equal(t, true, bre.GetUpperInclusive()) + require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val()) + require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val()) +} + +func TestRewrite_Range_OR_MixedDirection_NoMerge(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := 
parser.ParseExpr(helper, `Int64Field > 10 or Int64Field < 5`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + be := expr.GetBinaryExpr() + require.NotNil(t, be) + require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp()) + require.NotNil(t, be.GetLeft().GetUnaryRangeExpr()) + require.NotNil(t, be.GetRight().GetUnaryRangeExpr()) +} + +// Edge cases for Float/Double columns: allow mixing int and float literals. +func TestRewrite_Range_AND_Strengthen_Float_Mixed(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `FloatField > 10 and FloatField > 15.0`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.InDelta(t, 15.0, ure.GetValue().GetFloatVal(), 1e-9) +} + +func TestRewrite_Range_AND_ToBinaryRange_Float_Mixed(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `FloatField > 10 and FloatField < 20.5`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + bre := expr.GetBinaryRangeExpr() + require.NotNil(t, bre) + require.False(t, bre.GetLowerInclusive()) + require.False(t, bre.GetUpperInclusive()) + // lower may be encoded as int or float; assert either encoding equals 10 + lv := bre.GetLowerValue() + switch lv.GetVal().(type) { + case *planpb.GenericValue_Int64Val: + require.Equal(t, int64(10), lv.GetInt64Val()) + case *planpb.GenericValue_FloatVal: + require.InDelta(t, 10.0, lv.GetFloatVal(), 1e-9) + default: + t.Fatalf("unexpected lower value type") + } + // upper is float literal 20.5 + require.InDelta(t, 20.5, bre.GetUpperValue().GetFloatVal(), 1e-9) +} + +func TestRewrite_Range_OR_Weaken_Float_Mixed(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `FloatField > 10 or FloatField > 20.5`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := 
expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + // weakest lower is 10; value may be encoded as int or float + switch ure.GetValue().GetVal().(type) { + case *planpb.GenericValue_Int64Val: + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) + case *planpb.GenericValue_FloatVal: + require.InDelta(t, 10.0, ure.GetValue().GetFloatVal(), 1e-9) + default: + t.Fatalf("unexpected value type") + } +} + +func TestRewrite_Range_AND_NoMerge_Int_WithFloatLiteral(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + _, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field > 15.0`, nil) + // keeping this test so that we know the parser will not accept this expression, so we + // don't need to optimize it. + require.Error(t, err) +} + +func TestRewrite_Range_Tie_Inclusive_Float(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // OR: >=10 or >10 -> >=10 (weaken prefers inclusive) + expr, err := parser.ParseExpr(helper, `FloatField >= 10 or FloatField > 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterEqual, ure.GetOp()) + // AND: >=10 and >10 -> >10 (tighten prefers strict) + expr, err = parser.ParseExpr(helper, `FloatField >= 10 and FloatField > 10`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure = expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) +} + +// VarChar range optimization tests +func TestRewrite_Range_VarChar_AND_Strengthen(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `VarCharField > "a" and VarCharField > "b"`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, "b", 
ure.GetValue().GetStringVal()) +} + +func TestRewrite_Range_VarChar_OR_Weaken_Upper(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `VarCharField < "m" or VarCharField < "z"`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_LessThan, ure.GetOp()) + require.Equal(t, "z", ure.GetValue().GetStringVal()) +} + +// Array fields: ensure parser rejects direct range comparison on arrays for different element types. +// If in the future parser supports range on arrays, these tests can be updated accordingly. +func buildSchemaHelperWithArraysT(t *testing.T) *typeutil.SchemaHelper { + fields := []*schemapb.FieldSchema{ + {FieldID: 201, Name: "ArrayInt", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}, + {FieldID: 202, Name: "ArrayFloat", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}, + {FieldID: 203, Name: "ArrayVarchar", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_VarChar}, + } + schema := &schemapb.CollectionSchema{ + Name: "rewrite_array_test", + AutoID: false, + Fields: fields, + } + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + return helper +} + +func TestRewrite_Range_Array_Int_NotSupported(t *testing.T) { + helper := buildSchemaHelperWithArraysT(t) + _, err := parser.ParseExpr(helper, `ArrayInt > 10`, nil) + require.Error(t, err) +} + +func TestRewrite_Range_Array_Float_NotSupported(t *testing.T) { + helper := buildSchemaHelperWithArraysT(t) + _, err := parser.ParseExpr(helper, `ArrayFloat > 10.5`, nil) + require.Error(t, err) +} + +func TestRewrite_Range_Array_VarChar_NotSupported(t *testing.T) { + helper := buildSchemaHelperWithArraysT(t) + _, err := parser.ParseExpr(helper, `ArrayVarchar > "a"`, nil) + require.Error(t, err) +} + +// Array index access optimizations +func 
TestRewrite_Range_ArrayInt_Index_AND_Strengthen(t *testing.T) { + helper := buildSchemaHelperWithArraysT(t) + expr, err := parser.ParseExpr(helper, `ArrayInt[0] > 10 and ArrayInt[0] > 20`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(20), ure.GetValue().GetInt64Val()) +} + +func TestRewrite_Range_ArrayFloat_Index_OR_Weaken_Mixed(t *testing.T) { + helper := buildSchemaHelperWithArraysT(t) + expr, err := parser.ParseExpr(helper, `ArrayFloat[0] > 10 or ArrayFloat[0] > 20.5`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + switch ure.GetValue().GetVal().(type) { + case *planpb.GenericValue_Int64Val: + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) + case *planpb.GenericValue_FloatVal: + require.InDelta(t, 10.0, ure.GetValue().GetFloatVal(), 1e-9) + default: + t.Fatalf("unexpected value type") + } +} + +func TestRewrite_Range_ArrayVarChar_Index_ToBinaryRange(t *testing.T) { + helper := buildSchemaHelperWithArraysT(t) + expr, err := parser.ParseExpr(helper, `ArrayVarchar[0] > "a" and ArrayVarchar[0] < "m"`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + bre := expr.GetBinaryRangeExpr() + require.NotNil(t, bre) + require.False(t, bre.GetLowerInclusive()) + require.False(t, bre.GetUpperInclusive()) + require.Equal(t, "a", bre.GetLowerValue().GetStringVal()) + require.Equal(t, "m", bre.GetUpperValue().GetStringVal()) +} + +// helper to flatten AND tree into list of exprs +func collectAndExprs(e *planpb.Expr, out *[]*planpb.Expr) { + if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalAnd { + collectAndExprs(be.GetLeft(), out) + collectAndExprs(be.GetRight(), out) + return + } + *out = append(*out, e) +} + +func 
TestRewrite_Range_Array_Index_Different_NoMerge(t *testing.T) { + helper := buildSchemaHelperWithArraysT(t) + expr, err := parser.ParseExpr(helper, `ArrayInt[0] > 10 and ArrayInt[1] > 20 and ArrayInt[0] < 20`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + be := expr.GetBinaryExpr() + require.NotNil(t, be) + require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp()) + parts := []*planpb.Expr{} + collectAndExprs(expr, &parts) + // expect exactly two parts after rewrite: interval on index 0, and lower bound on index 1 + require.Equal(t, 2, len(parts)) + var seenInterval, seenLower bool + for _, p := range parts { + if bre := p.GetBinaryRangeExpr(); bre != nil { + seenInterval = true + // 10 < ArrayInt[0] < 20 + require.False(t, bre.GetLowerInclusive()) + require.False(t, bre.GetUpperInclusive()) + require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val()) + require.Equal(t, int64(20), bre.GetUpperValue().GetInt64Val()) + continue + } + if ure := p.GetUnaryRangeExpr(); ure != nil { + seenLower = true + require.True(t, ure.GetOp() == planpb.OpType_GreaterThan || ure.GetOp() == planpb.OpType_GreaterEqual) + // bound value 20 on index 1 lower side + require.Equal(t, int64(20), ure.GetValue().GetInt64Val()) + continue + } + // should not reach here: only BinaryRangeExpr and UnaryRangeExpr expected + t.Fatalf("unexpected expr kind in AND parts") + } + require.True(t, seenInterval) + require.True(t, seenLower) +} + +// Test invalid BinaryRangeExpr: lower > upper → false +func TestRewrite_Range_AND_InvalidRange_LowerGreaterThanUpper(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field > 100 AND Int64Field < 50 → false (impossible range) + expr, err := parser.ParseExpr(helper, `Int64Field > 100 and Int64Field < 50`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +// Test invalid BinaryRangeExpr: lower == upper with exclusive bounds → false +func 
TestRewrite_Range_AND_InvalidRange_EqualBoundsExclusive(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field > 50 AND Int64Field < 50 → false (exclusive on equal bounds) + expr, err := parser.ParseExpr(helper, `Int64Field > 50 and Int64Field < 50`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +// Test invalid BinaryRangeExpr: lower == upper with one exclusive → false +func TestRewrite_Range_AND_InvalidRange_EqualBoundsOneExclusive(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field >= 50 AND Int64Field < 50 → false (one exclusive on equal bounds) + expr, err := parser.ParseExpr(helper, `Int64Field >= 50 and Int64Field < 50`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +// Test valid BinaryRangeExpr: lower == upper with both inclusive → valid +func TestRewrite_Range_AND_ValidRange_EqualBoundsBothInclusive(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field >= 50 AND Int64Field <= 50 → (50 <= x <= 50), which is valid (x == 50) + expr, err := parser.ParseExpr(helper, `Int64Field >= 50 and Int64Field <= 50`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + bre := expr.GetBinaryRangeExpr() + require.NotNil(t, bre, "should create valid binary range for x == 50") + require.Equal(t, true, bre.GetLowerInclusive()) + require.Equal(t, true, bre.GetUpperInclusive()) + require.Equal(t, int64(50), bre.GetLowerValue().GetInt64Val()) + require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val()) +} + +// Test invalid BinaryRangeExpr with float: lower > upper → false +func TestRewrite_Range_AND_InvalidRange_Float_LowerGreaterThanUpper(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // FloatField > 99.9 AND FloatField < 10.5 → false + expr, err := parser.ParseExpr(helper, `FloatField > 99.9 and FloatField < 10.5`, nil) + require.NoError(t, err) + 
require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +// Test invalid BinaryRangeExpr with string: lower > upper → false +func TestRewrite_Range_AND_InvalidRange_String_LowerGreaterThanUpper(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // VarCharField > "zebra" AND VarCharField < "apple" → false + expr, err := parser.ParseExpr(helper, `VarCharField > "zebra" and VarCharField < "apple"`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +// Test AlwaysFalse propagation through nested AND expressions +func TestRewrite_AlwaysFalse_Propagation_DeepNesting(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Deep nesting: (Int64Field > 10) AND ((Int64Field > 20) AND (Int64Field > 100 AND Int64Field < 50)) + // The innermost (Int64Field > 100 AND Int64Field < 50) should become AlwaysFalse + // This AlwaysFalse should propagate up through all ANDs, making the entire expression AlwaysFalse + expr, err := parser.ParseExpr(helper, `(Int64Field > 10) and ((Int64Field > 20) and (Int64Field > 100 and Int64Field < 50))`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + // Should propagate to become AlwaysFalse at top level + require.True(t, rewriter.IsAlwaysFalseExpr(expr), "AlwaysFalse should propagate to top level") +} + +// Test AlwaysFalse elimination in OR expressions +func TestRewrite_AlwaysFalse_Elimination_InOR(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // (Int64Field > 10) OR ((Int64Field > 20) OR (Int64Field > 100 AND Int64Field < 50)) + // The innermost becomes AlwaysFalse, should be eliminated from OR + // Result should be: Int64Field > 10 OR Int64Field > 20 → Int64Field > 10 (weaker bound) + expr, err := parser.ParseExpr(helper, `(Int64Field > 10) or ((Int64Field > 20) or (Int64Field > 100 and Int64Field < 50))`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + // Should simplify to single range condition: 
Int64Field > 10 + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure, "AlwaysFalse should be eliminated, leaving simplified range") + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) +} + +// Test negation of an always-false range: NOT AlwaysFalse → AlwaysTrue +func TestRewrite_DoubleNegation_ToAlwaysTrue(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `not (Int64Field > 100 and Int64Field < 50)`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysTrueExpr(expr)) +} + +// Test negation of nested always-false ranges combined with OR +func TestRewrite_ComplexDoubleNegation_MultiLayer(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `not ((Int64Field > 100 and Int64Field < 50) or (FloatField > 99.9 and FloatField < 10.5))`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysTrueExpr(expr)) +} + +// Test AlwaysTrue in AND with normal conditions gets eliminated +func TestRewrite_AlwaysTrue_Elimination_InAND(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // (Int64Field > 10) AND NOT(Int64Field > 100 AND Int64Field < 50) + // The second part becomes AlwaysTrue, should be eliminated from AND + // Result should be just: Int64Field > 10 + expr, err := parser.ParseExpr(helper, `(Int64Field > 10) and not (Int64Field > 100 and Int64Field < 50)`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp()) + require.Equal(t, int64(10), ure.GetValue().GetInt64Val()) +} diff --git a/internal/parser/planparserv2/rewriter/term_in.go b/internal/parser/planparserv2/rewriter/term_in.go new file mode 100644 index 0000000000..c10009bc72 --- /dev/null +++ b/internal/parser/planparserv2/rewriter/term_in.go @@ -0,0 
+1,666 @@ +package rewriter + +import ( + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" +) + +func (v *visitor) combineOrEqualsToIn(parts []*planpb.Expr) []*planpb.Expr { + type group struct { + col *planpb.ColumnInfo + values []*planpb.GenericValue + origIndices []int + valCase string + } + others := make([]*planpb.Expr, 0, len(parts)) + groups := make(map[string]*group) + indexToExpr := parts + for idx, e := range parts { + u := e.GetUnaryRangeExpr() + if u == nil || u.GetOp() != planpb.OpType_Equal || u.GetValue() == nil { + others = append(others, e) + continue + } + col := u.GetColumnInfo() + if col == nil { + others = append(others, e) + continue + } + key := columnKey(col) + g, ok := groups[key] + valCase := valueCase(u.GetValue()) + if !ok { + g = &group{col: col, values: []*planpb.GenericValue{}, origIndices: []int{}, valCase: valCase} + groups[key] = g + } + if g.valCase != valCase { + others = append(others, e) + continue + } + g.values = append(g.values, u.GetValue()) + g.origIndices = append(g.origIndices, idx) + } + out := make([]*planpb.Expr, 0, len(parts)) + out = append(out, others...) 
+ for _, g := range groups { + if shouldMergeToIn(g.col.GetDataType(), len(g.values)) { + g.values = sortGenericValues(g.values) + out = append(out, newTermExpr(g.col, g.values)) + } else { + for _, i := range g.origIndices { + out = append(out, indexToExpr[i]) + } + } + } + return out +} + +func (v *visitor) combineAndNotEqualsToNotIn(parts []*planpb.Expr) []*planpb.Expr { + type group struct { + col *planpb.ColumnInfo + values []*planpb.GenericValue + origIndices []int + valCase string + } + others := make([]*planpb.Expr, 0, len(parts)) + groups := make(map[string]*group) + indexToExpr := parts + for idx, e := range parts { + u := e.GetUnaryRangeExpr() + if u == nil || u.GetOp() != planpb.OpType_NotEqual || u.GetValue() == nil { + others = append(others, e) + continue + } + col := u.GetColumnInfo() + if col == nil { + others = append(others, e) + continue + } + key := columnKey(col) + g, ok := groups[key] + valCase := valueCase(u.GetValue()) + if !ok { + g = &group{col: col, values: []*planpb.GenericValue{}, origIndices: []int{}, valCase: valCase} + groups[key] = g + } + if g.valCase != valCase { + others = append(others, e) + continue + } + g.values = append(g.values, u.GetValue()) + g.origIndices = append(g.origIndices, idx) + } + out := make([]*planpb.Expr, 0, len(parts)) + out = append(out, others...) 
+ for _, g := range groups { + if shouldMergeToIn(g.col.GetDataType(), len(g.values)) { + g.values = sortGenericValues(g.values) + in := newTermExpr(g.col, g.values) + out = append(out, notExpr(in)) + } else { + for _, i := range g.origIndices { + out = append(out, indexToExpr[i]) + } + } + } + return out +} + +func notExpr(child *planpb.Expr) *planpb.Expr { + return &planpb.Expr{ + Expr: &planpb.Expr_UnaryExpr{ + UnaryExpr: &planpb.UnaryExpr{ + Op: planpb.UnaryExpr_Not, + Child: child, + }, + }, + } +} + +// AND: (a IN S) AND (a = v) with v in S -> a = v +func (v *visitor) combineAndInWithEqual(parts []*planpb.Expr) []*planpb.Expr { + type agg struct { + termIdxs []int + eqIdxs []int + term *planpb.TermExpr + eqValues []*planpb.GenericValue + col *planpb.ColumnInfo + } + groups := map[string]*agg{} + others := []int{} + for idx, e := range parts { + if te := e.GetTermExpr(); te != nil { + k := columnKey(te.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &agg{col: te.GetColumnInfo()} + } + g.termIdxs = append(g.termIdxs, idx) + g.term = te + groups[k] = g + continue + } + if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_Equal && ue.GetValue() != nil && ue.GetColumnInfo() != nil { + k := columnKey(ue.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &agg{col: ue.GetColumnInfo()} + } + g.eqIdxs = append(g.eqIdxs, idx) + g.eqValues = append(g.eqValues, ue.GetValue()) + groups[k] = g + continue + } + others = append(others, idx) + } + used := make([]bool, len(parts)) + out := make([]*planpb.Expr, 0, len(parts)) + for _, idx := range others { + out = append(out, parts[idx]) + used[idx] = true + } + for _, g := range groups { + if g.term == nil || len(g.eqIdxs) == 0 { + continue + } + // Build set of eq values and check presence in term set. 
+ termVals := g.term.GetValues() + eqUnique := []*planpb.GenericValue{} + for _, ev := range g.eqValues { + dup := false + for _, u := range eqUnique { + if equalsGeneric(u, ev) { + dup = true + break + } + } + if !dup { + eqUnique = append(eqUnique, ev) + } + } + // If multiple different equals present, AND implies contradiction unless identical. + if len(eqUnique) > 1 { + for _, ti := range g.termIdxs { + used[ti] = true + } + for _, ei := range g.eqIdxs { + used[ei] = true + } + // emit constant false + out = append(out, newAlwaysFalseExpr()) + continue + } + // Single equal value + ev := eqUnique[0] + inSet := false + for _, tv := range termVals { + if equalsGeneric(tv, ev) { + inSet = true + break + } + } + for _, ti := range g.termIdxs { + used[ti] = true + } + for _, ei := range g.eqIdxs { + used[ei] = true + } + if inSet { + // reduce to equality + out = append(out, newUnaryRangeExpr(g.col, planpb.OpType_Equal, ev)) + } else { + // contradiction -> false + out = append(out, newAlwaysFalseExpr()) + } + } + for i := range parts { + if !used[i] { + out = append(out, parts[i]) + } + } + return out +} + +// OR: (a IN S) OR (a = v) with v in S -> keep a IN S (drop equal) +// Optional extension (not enabled here): if v not in S, could union. 
+func (v *visitor) combineOrInWithEqual(parts []*planpb.Expr) []*planpb.Expr { + type agg struct { + termIdx int + term *planpb.TermExpr + col *planpb.ColumnInfo + eqIdxs []int + eqVals []*planpb.GenericValue + } + groups := map[string]*agg{} + others := []int{} + for idx, e := range parts { + if te := e.GetTermExpr(); te != nil { + k := columnKey(te.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &agg{col: te.GetColumnInfo()} + groups[k] = g + } + g.termIdx = idx + g.term = te + continue + } + if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_Equal && ue.GetValue() != nil && ue.GetColumnInfo() != nil { + k := columnKey(ue.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &agg{col: ue.GetColumnInfo()} + groups[k] = g + } + g.eqIdxs = append(g.eqIdxs, idx) + g.eqVals = append(g.eqVals, ue.GetValue()) + continue + } + others = append(others, idx) + } + used := make([]bool, len(parts)) + out := make([]*planpb.Expr, 0, len(parts)) + for _, idx := range others { + out = append(out, parts[idx]) + used[idx] = true + } + for _, g := range groups { + if g.term == nil || len(g.eqIdxs) == 0 { + continue + } + // union all equal values into term set + union := g.term.GetValues() + for i, ev := range g.eqVals { + union = append(union, ev) + used[g.eqIdxs[i]] = true + } + union = sortGenericValues(union) + used[g.termIdx] = true + out = append(out, newTermExpr(g.col, union)) + } + for i := range parts { + if !used[i] { + out = append(out, parts[i]) + } + } + return out +} + +// AND: (a IN S) AND (range) -> filter S by range +func (v *visitor) combineAndInWithRange(parts []*planpb.Expr) []*planpb.Expr { + type group struct { + col *planpb.ColumnInfo + termIdx int + term *planpb.TermExpr + lower *planpb.GenericValue + lowerInc bool + upper *planpb.GenericValue + upperInc bool + rangeIdxs []int + } + groups := map[string]*group{} + others := []int{} + isRange := func(op planpb.OpType) bool { + return op == planpb.OpType_GreaterThan || op == 
planpb.OpType_GreaterEqual || op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual + } + for idx, e := range parts { + if te := e.GetTermExpr(); te != nil { + k := columnKey(te.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &group{col: te.GetColumnInfo()} + groups[k] = g + } + g.term = te + g.termIdx = idx + continue + } + if ue := e.GetUnaryRangeExpr(); ue != nil && isRange(ue.GetOp()) && ue.GetValue() != nil && ue.GetColumnInfo() != nil { + k := columnKey(ue.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &group{col: ue.GetColumnInfo()} + groups[k] = g + } + if ue.GetOp() == planpb.OpType_GreaterThan || ue.GetOp() == planpb.OpType_GreaterEqual { + if g.lower == nil || cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.lower) > 0 || (cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.lower) == 0 && ue.GetOp() == planpb.OpType_GreaterThan && g.lowerInc) { + g.lower = ue.GetValue() + g.lowerInc = ue.GetOp() == planpb.OpType_GreaterEqual + } + } else { + if g.upper == nil || cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.upper) < 0 || (cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.upper) == 0 && ue.GetOp() == planpb.OpType_LessThan && g.upperInc) { + g.upper = ue.GetValue() + g.upperInc = ue.GetOp() == planpb.OpType_LessEqual + } + } + g.rangeIdxs = append(g.rangeIdxs, idx) + continue + } + others = append(others, idx) + } + used := make([]bool, len(parts)) + out := make([]*planpb.Expr, 0, len(parts)) + for _, idx := range others { + out = append(out, parts[idx]) + used[idx] = true + } + for _, g := range groups { + if g.term == nil || (g.lower == nil && g.upper == nil) { + continue + } + // Skip optimization if any term value is not comparable with the provided bounds + termVals := g.term.GetValues() + comparable := true + for _, tv := range termVals { + if g.lower != nil { + if !(areComparableCases(valueCaseWithNil(tv), valueCaseWithNil(g.lower)) || (isNumericCase(valueCaseWithNil(tv)) && 
isNumericCase(valueCaseWithNil(g.lower)))) { + comparable = false + break + } + } + if comparable && g.upper != nil { + if !(areComparableCases(valueCaseWithNil(tv), valueCaseWithNil(g.upper)) || (isNumericCase(valueCaseWithNil(tv)) && isNumericCase(valueCaseWithNil(g.upper)))) { + comparable = false + break + } + } + } + if !comparable { + continue + } + filtered := filterValuesByRange(effectiveDataType(g.col), termVals, g.lower, g.lowerInc, g.upper, g.upperInc) + used[g.termIdx] = true + for _, ri := range g.rangeIdxs { + used[ri] = true + } + if len(filtered) == 0 { + // Empty IN list after filtering → AlwaysFalse + out = append(out, newAlwaysFalseExpr()) + } else { + out = append(out, newTermExpr(g.col, filtered)) + } + } + for i := range parts { + if !used[i] { + out = append(out, parts[i]) + } + } + return out +} + +// OR: (a IN S1) OR (a IN S2) -> a IN union(S1, S2) +func (v *visitor) combineOrInWithIn(parts []*planpb.Expr) []*planpb.Expr { + type agg struct { + col *planpb.ColumnInfo + idxs []int + values [][]*planpb.GenericValue + } + groups := map[string]*agg{} + others := []int{} + for idx, e := range parts { + if te := e.GetTermExpr(); te != nil { + k := columnKey(te.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &agg{col: te.GetColumnInfo()} + groups[k] = g + } + g.idxs = append(g.idxs, idx) + g.values = append(g.values, te.GetValues()) + continue + } + others = append(others, idx) + } + used := make([]bool, len(parts)) + out := make([]*planpb.Expr, 0, len(parts)) + for _, idx := range others { + out = append(out, parts[idx]) + used[idx] = true + } + for _, g := range groups { + if len(g.idxs) <= 1 { + continue + } + union := []*planpb.GenericValue{} + for _, vs := range g.values { + union = append(union, vs...) 
+ } + union = sortGenericValues(union) + for _, i := range g.idxs { + used[i] = true + } + out = append(out, newTermExpr(g.col, union)) + } + for i := range parts { + if !used[i] { + out = append(out, parts[i]) + } + } + return out +} + +// AND: (a IN S1) AND (a IN S2) ... -> a IN intersection(S1, S2, ...) +func (v *visitor) combineAndInWithIn(parts []*planpb.Expr) []*planpb.Expr { + type agg struct { + col *planpb.ColumnInfo + idxs []int + values [][]*planpb.GenericValue + } + groups := map[string]*agg{} + others := []int{} + for idx, e := range parts { + if te := e.GetTermExpr(); te != nil { + k := columnKey(te.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &agg{col: te.GetColumnInfo()} + groups[k] = g + } + g.idxs = append(g.idxs, idx) + g.values = append(g.values, te.GetValues()) + continue + } + others = append(others, idx) + } + used := make([]bool, len(parts)) + out := make([]*planpb.Expr, 0, len(parts)) + for _, idx := range others { + out = append(out, parts[idx]) + used[idx] = true + } + for _, g := range groups { + if len(g.idxs) <= 1 { + continue + } + // compute intersection; start from first set + inter := make([]*planpb.GenericValue, 0, len(g.values[0])) + outer: + for _, v := range g.values[0] { + // check in every other set + ok := true + for i := 1; i < len(g.values); i++ { + found := false + for _, w := range g.values[i] { + if equalsGeneric(v, w) { + found = true + break + } + } + if !found { + continue outer + } + } + if ok { + inter = append(inter, v) + } + } + for _, i := range g.idxs { + used[i] = true + } + if len(inter) == 0 { + out = append(out, newAlwaysFalseExpr()) + } else { + out = append(out, newTermExpr(g.col, inter)) + } + } + for i := range parts { + if !used[i] { + out = append(out, parts[i]) + } + } + return out +} + +// AND: (a IN S) AND (a != d) -> remove d from S; empty -> false +func (v *visitor) combineAndInWithNotEqual(parts []*planpb.Expr) []*planpb.Expr { + type group struct { + col *planpb.ColumnInfo + termIdx 
int + term *planpb.TermExpr + neqIdxs []int + neqVals []*planpb.GenericValue + } + groups := map[string]*group{} + others := []int{} + for idx, e := range parts { + if te := e.GetTermExpr(); te != nil { + k := columnKey(te.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &group{col: te.GetColumnInfo()} + groups[k] = g + } + g.term = te + g.termIdx = idx + continue + } + if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_NotEqual && ue.GetValue() != nil && ue.GetColumnInfo() != nil { + k := columnKey(ue.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &group{col: ue.GetColumnInfo()} + groups[k] = g + } + g.neqIdxs = append(g.neqIdxs, idx) + g.neqVals = append(g.neqVals, ue.GetValue()) + continue + } + others = append(others, idx) + } + used := make([]bool, len(parts)) + out := make([]*planpb.Expr, 0, len(parts)) + for _, i := range others { + out = append(out, parts[i]) + used[i] = true + } + for _, g := range groups { + if g.term == nil || len(g.neqIdxs) == 0 { + continue + } + filtered := []*planpb.GenericValue{} + for _, tv := range g.term.GetValues() { + excluded := false + for _, dv := range g.neqVals { + if equalsGeneric(tv, dv) { + excluded = true + break + } + } + if !excluded { + filtered = append(filtered, tv) + } + } + used[g.termIdx] = true + for _, ni := range g.neqIdxs { + used[ni] = true + } + if len(filtered) == 0 { + out = append(out, newAlwaysFalseExpr()) + } else { + out = append(out, newTermExpr(g.col, filtered)) + } + } + for i := range parts { + if !used[i] { + out = append(out, parts[i]) + } + } + return out +} + +// OR: (a IN S) OR (a != d) -> if d ∈ S then true else (a != d) +func (v *visitor) combineOrInWithNotEqual(parts []*planpb.Expr) []*planpb.Expr { + type group struct { + col *planpb.ColumnInfo + termIdx int + term *planpb.TermExpr + neqIdxs []int + neqVals []*planpb.GenericValue + } + groups := map[string]*group{} + others := []int{} + for idx, e := range parts { + if te := e.GetTermExpr(); te != 
nil { + k := columnKey(te.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &group{col: te.GetColumnInfo()} + groups[k] = g + } + g.term = te + g.termIdx = idx + continue + } + if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_NotEqual && ue.GetValue() != nil && ue.GetColumnInfo() != nil { + k := columnKey(ue.GetColumnInfo()) + g := groups[k] + if g == nil { + g = &group{col: ue.GetColumnInfo()} + groups[k] = g + } + g.neqIdxs = append(g.neqIdxs, idx) + g.neqVals = append(g.neqVals, ue.GetValue()) + continue + } + others = append(others, idx) + } + used := make([]bool, len(parts)) + out := make([]*planpb.Expr, 0, len(parts)) + for _, i := range others { + out = append(out, parts[i]) + used[i] = true + } + for _, g := range groups { + if g.term == nil || len(g.neqIdxs) == 0 { + continue + } + // if any neq value is inside IN set -> true + containsAny := false + for _, dv := range g.neqVals { + for _, tv := range g.term.GetValues() { + if equalsGeneric(tv, dv) { + containsAny = true + break + } + } + if containsAny { + break + } + } + if containsAny { + used[g.termIdx] = true + for _, ni := range g.neqIdxs { + used[ni] = true + } + out = append(out, newBoolConstExpr(true)) + } else { + // drop the IN; keep != as-is + used[g.termIdx] = true + } + } + for i := range parts { + if !used[i] { + out = append(out, parts[i]) + } + } + return out +} diff --git a/internal/parser/planparserv2/rewriter/term_in_test.go b/internal/parser/planparserv2/rewriter/term_in_test.go new file mode 100644 index 0000000000..596e6aa848 --- /dev/null +++ b/internal/parser/planparserv2/rewriter/term_in_test.go @@ -0,0 +1,356 @@ +package rewriter_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + parser "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter" + 
"github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +func buildSchemaHelperForRewriteT(t *testing.T) *typeutil.SchemaHelper { + fields := []*schemapb.FieldSchema{ + {FieldID: 101, Name: "Int64Field", DataType: schemapb.DataType_Int64}, + {FieldID: 102, Name: "VarCharField", DataType: schemapb.DataType_VarChar}, + {FieldID: 103, Name: "StringField", DataType: schemapb.DataType_String}, + {FieldID: 104, Name: "FloatField", DataType: schemapb.DataType_Double}, + {FieldID: 105, Name: "BoolField", DataType: schemapb.DataType_Bool}, + } + schema := &schemapb.CollectionSchema{ + Name: "rewrite_test", + AutoID: false, + Fields: fields, + } + // enable text_match on string-like fields + for _, f := range schema.Fields { + if typeutil.IsStringType(f.DataType) { + f.TypeParams = append(f.TypeParams, &commonpb.KeyValuePair{ + Key: "enable_match", + Value: "True", + }) + } + } + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + return helper +} + +func TestRewrite_OREquals_ToIN_NonNumeric(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `VarCharField == "a" or VarCharField == "b"`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term, "expected OR-equals to be rewritten to TermExpr(IN ...)") + require.Equal(t, 2, len(term.GetValues())) + require.Equal(t, "a", term.GetValues()[0].GetStringVal()) + require.Equal(t, "b", term.GetValues()[1].GetStringVal()) +} + +func TestRewrite_OREquals_NotMerged_OnNumericBelowThreshold(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field == 1 or Int64Field == 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.Nil(t, expr.GetTermExpr(), "numeric OR-equals should not merge to IN under threshold") + be := expr.GetBinaryExpr() + require.NotNil(t, be) + require.Equal(t, 
planpb.BinaryExpr_LogicalOr, be.GetOp()) + require.NotNil(t, be.GetLeft().GetUnaryRangeExpr()) + require.NotNil(t, be.GetRight().GetUnaryRangeExpr()) + require.Equal(t, planpb.OpType_Equal, be.GetLeft().GetUnaryRangeExpr().GetOp()) + require.Equal(t, planpb.OpType_Equal, be.GetRight().GetUnaryRangeExpr().GetOp()) +} + +func TestRewrite_Term_SortAndDedup_String(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `VarCharField in ["b","a","b","a"]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, 2, len(term.GetValues())) + require.Equal(t, "a", term.GetValues()[0].GetStringVal()) + require.Equal(t, "b", term.GetValues()[1].GetStringVal()) +} + +func TestRewrite_Term_SortAndDedup_Int(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [9,4,6,6,7]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, 4, len(term.GetValues())) + got := []int64{ + term.GetValues()[0].GetInt64Val(), + term.GetValues()[1].GetInt64Val(), + term.GetValues()[2].GetInt64Val(), + term.GetValues()[3].GetInt64Val(), + } + require.ElementsMatch(t, []int64{4, 6, 7, 9}, got) +} + +func TestRewrite_NotIn_SortAndDedup_Int(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field not in [4,4,3]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + un := expr.GetUnaryExpr() + require.NotNil(t, un) + term := un.GetChild().GetTermExpr() + require.NotNil(t, term) + require.Equal(t, 2, len(term.GetValues())) + require.Equal(t, int64(3), term.GetValues()[0].GetInt64Val()) + require.Equal(t, int64(4), term.GetValues()[1].GetInt64Val()) +} + +func TestRewrite_NotIn_SortAndDedup_Float(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, 
`FloatField not in [4.0,4,3.0]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + un := expr.GetUnaryExpr() + require.NotNil(t, un) + term := un.GetChild().GetTermExpr() + require.NotNil(t, term) + require.Equal(t, 2, len(term.GetValues())) + require.Equal(t, 3.0, term.GetValues()[0].GetFloatVal()) + require.Equal(t, 4.0, term.GetValues()[1].GetFloatVal()) +} + +func TestRewrite_In_SortAndDedup_Bool(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `BoolField in [true,false,false,true]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, 2, len(term.GetValues())) + require.Equal(t, false, term.GetValues()[0].GetBoolVal()) + require.Equal(t, true, term.GetValues()[1].GetBoolVal()) +} + +func TestRewrite_Flatten_Then_OR_ToIN(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `VarCharField == "a" or (VarCharField == "b" or VarCharField == "c") or VarCharField == "d"`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term, "nested OR-equals should flatten and merge to IN") + require.Equal(t, 4, len(term.GetValues())) + got := []string{ + term.GetValues()[0].GetStringVal(), + term.GetValues()[1].GetStringVal(), + term.GetValues()[2].GetStringVal(), + term.GetValues()[3].GetStringVal(), + } + require.ElementsMatch(t, []string{"a", "b", "c", "d"}, got) +} + +func TestRewrite_And_In_And_Equal_VInSet_ReducesToEqual(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field == 3`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_Equal, ure.GetOp()) + require.Equal(t, int64(3), ure.GetValue().GetInt64Val()) +} + +func 
TestRewrite_And_In_And_Equal_VNotInSet_False(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field == 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +func TestRewrite_Or_In_Or_Equal_Union(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,3] or Int64Field == 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, []int64{1, 2, 3}, []int64{ + term.GetValues()[0].GetInt64Val(), + term.GetValues()[1].GetInt64Val(), + term.GetValues()[2].GetInt64Val(), + }) +} + +func TestRewrite_And_In_With_Range_Filter(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field > 3`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, 1, len(term.GetValues())) + require.Equal(t, int64(5), term.GetValues()[0].GetInt64Val()) +} + +func TestRewrite_Or_In_Union(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,3] or Int64Field in [3,4]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, []int64{1, 3, 4}, []int64{ + term.GetValues()[0].GetInt64Val(), + term.GetValues()[1].GetInt64Val(), + term.GetValues()[2].GetInt64Val(), + }) +} + +func TestRewrite_And_In_Intersection(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,2,3] and Int64Field in [2,3,4]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, []int64{2, 3}, []int64{ + 
term.GetValues()[0].GetInt64Val(), + term.GetValues()[1].GetInt64Val(), + }) +} + +func TestRewrite_And_In_Intersection_Empty_ToFalse(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1] and Int64Field in [2]`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +func TestRewrite_And_In_And_NotEqual_Remove(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,2,3] and Int64Field != 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + term := expr.GetTermExpr() + require.NotNil(t, term) + require.Equal(t, []int64{1, 3}, []int64{ + term.GetValues()[0].GetInt64Val(), + term.GetValues()[1].GetInt64Val(), + }) +} + +func TestRewrite_And_In_And_NotEqual_AllRemoved_ToFalse(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [2] and Int64Field != 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + require.True(t, rewriter.IsAlwaysFalseExpr(expr)) +} + +func TestRewrite_Or_In_Or_NotEqual_VInSet_ToTrue(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1,2] or Int64Field != 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + val := expr.GetValueExpr() + require.NotNil(t, val) + require.Equal(t, true, val.GetValue().GetBoolVal()) +} + +func TestRewrite_Or_In_Or_NotEqual_VNotInSet_ToNotEqual(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `Int64Field in [1] or Int64Field != 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure) + require.Equal(t, planpb.OpType_NotEqual, ure.GetOp()) + require.Equal(t, int64(2), ure.GetValue().GetInt64Val()) +} + +// Test contradictory equals: (a == 1) AND (a == 2) → false +// NOTE: 
This is a known limitation - currently NOT optimized +func TestRewrite_And_Equal_And_Equal_Contradiction_CurrentLimitation(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field == 1 AND Int64Field == 2 → false (contradiction) + // Currently NOT optimized because equals don't convert to IN in AND context + expr, err := parser.ParseExpr(helper, `Int64Field == 1 and Int64Field == 2`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + + // Current limitation: NOT optimized to false + // Remains as BinaryExpr AND with two equal predicates + be := expr.GetBinaryExpr() + require.NotNil(t, be, "should remain as AND (not optimized)") + require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp()) +} + +// Test contradictory equals with three values +func TestRewrite_And_Equal_ThreeWay_Contradiction_CurrentLimitation(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field == 1 AND Int64Field == 2 AND Int64Field == 3 → false + // Currently NOT optimized + expr, err := parser.ParseExpr(helper, `Int64Field == 1 and Int64Field == 2 and Int64Field == 3`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + + // Current limitation: NOT optimized to false + be := expr.GetBinaryExpr() + require.NotNil(t, be, "should remain as AND chain (not optimized)") + require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp()) +} + +// Test range + contradictory equal: (a > 10) AND (a == 5) → false +// NOTE: Current limitation - NOT optimized +func TestRewrite_And_Range_And_Equal_Contradiction_CurrentLimitation(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field > 10 AND Int64Field == 5 → false (5 is not > 10) + // This requires combining range and equality checks, which is partially implemented + expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field == 5`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + + // Current behavior: may not optimize to false + // If it does optimize via IN+range 
filtering, it would become false + // This test documents current limitation + // When fully optimized, should be constant false + _ = expr // Test documents that this case exists +} + +// Test non-contradictory range + equal: (a > 10) AND (a == 15) → stays as is +// NOTE: This requires Equal to be in IN form first, which happens via combineAndInWithEqual +// But Equal alone doesn't convert to IN in AND context, so this optimization doesn't happen +func TestRewrite_And_Range_And_Equal_NonContradiction_CurrentLimitation(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // Int64Field > 10 AND Int64Field == 15 → should simplify to Int64Field == 15 + // But currently NOT optimized without IN involved + expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field == 15`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + + // Current limitation: remains as AND with range and equal + be := expr.GetBinaryExpr() + require.NotNil(t, be, "should remain as AND (not optimized)") + require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp()) +} + +// Test string contradictory equals - current limitation +func TestRewrite_And_Equal_String_Contradiction_CurrentLimitation(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + // VarCharField == "apple" AND VarCharField == "banana" → false + // Currently NOT optimized + expr, err := parser.ParseExpr(helper, `VarCharField == "apple" and VarCharField == "banana"`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + + // Current limitation: NOT optimized to false + be := expr.GetBinaryExpr() + require.NotNil(t, be, "should remain as AND (not optimized)") + require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp()) +} diff --git a/internal/parser/planparserv2/rewriter/text_match.go b/internal/parser/planparserv2/rewriter/text_match.go new file mode 100644 index 0000000000..53b51a33cd --- /dev/null +++ b/internal/parser/planparserv2/rewriter/text_match.go @@ -0,0 +1,78 @@ +package rewriter + 
+import ( + "strings" + + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" +) + +func (v *visitor) combineOrTextMatchToMerged(parts []*planpb.Expr) []*planpb.Expr { + type group struct { + col *planpb.ColumnInfo + origIndices []int + literals []string + } + others := make([]*planpb.Expr, 0, len(parts)) + groups := make(map[string]*group) + indexToExpr := parts + for idx, e := range parts { + u := e.GetUnaryRangeExpr() + if u == nil || u.GetOp() != planpb.OpType_TextMatch || u.GetValue() == nil { + others = append(others, e) + continue + } + if len(u.GetExtraValues()) > 0 { + others = append(others, e) + continue + } + col := u.GetColumnInfo() + if col == nil { + others = append(others, e) + continue + } + key := columnKey(col) + g, ok := groups[key] + if !ok { + g = &group{col: col} + groups[key] = g + } + literal := u.GetValue().GetStringVal() + g.literals = append(g.literals, literal) + g.origIndices = append(g.origIndices, idx) + } + out := make([]*planpb.Expr, 0, len(parts)) + out = append(out, others...) 
+ for _, g := range groups { + if len(g.origIndices) <= 1 { + for _, i := range g.origIndices { + out = append(out, indexToExpr[i]) + } + continue + } + if len(g.literals) == 0 { + for _, i := range g.origIndices { + out = append(out, indexToExpr[i]) + } + continue + } + merged := strings.Join(g.literals, " ") + out = append(out, newTextMatchExpr(g.col, merged)) + } + return out +} + +func newTextMatchExpr(col *planpb.ColumnInfo, literal string) *planpb.Expr { + return &planpb.Expr{ + Expr: &planpb.Expr_UnaryRangeExpr{ + UnaryRangeExpr: &planpb.UnaryRangeExpr{ + ColumnInfo: col, + Op: planpb.OpType_TextMatch, + Value: &planpb.GenericValue{ + Val: &planpb.GenericValue_StringVal{ + StringVal: literal, + }, + }, + }, + }, + } +} diff --git a/internal/parser/planparserv2/rewriter/text_match_test.go b/internal/parser/planparserv2/rewriter/text_match_test.go new file mode 100644 index 0000000000..fd659d415d --- /dev/null +++ b/internal/parser/planparserv2/rewriter/text_match_test.go @@ -0,0 +1,122 @@ +package rewriter_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + parser "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" +) + +// collectTextMatchLiterals walks the provided expr and collects all OpType_TextMatch literals (grouped by fieldId) into the returned map. 
+func collectTextMatchLiterals(expr *planpb.Expr) map[int64]string { + colToLiteral := map[int64]string{} + var collect func(e *planpb.Expr) + collect = func(e *planpb.Expr) { + if e == nil { + return + } + if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_TextMatch { + col := ue.GetColumnInfo() + colToLiteral[col.GetFieldId()] = ue.GetValue().GetStringVal() + return + } + if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalOr { + collect(be.GetLeft()) + collect(be.GetRight()) + return + } + } + collect(expr) + return colToLiteral +} + +func TestRewrite_TextMatch_OR_Merge(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A C") or text_match(VarCharField, "B D")`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + ure := expr.GetUnaryRangeExpr() + require.NotNil(t, ure, "should merge to single text_match") + require.Equal(t, planpb.OpType_TextMatch, ure.GetOp()) + require.Equal(t, "A C B D", ure.GetValue().GetStringVal()) +} + +func TestRewrite_TextMatch_OR_DifferentField_NoMerge(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or text_match(StringField, "B")`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + be := expr.GetBinaryExpr() + require.NotNil(t, be) + require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp()) + require.NotNil(t, be.GetLeft().GetUnaryRangeExpr()) + require.NotNil(t, be.GetRight().GetUnaryRangeExpr()) +} + +func TestRewrite_TextMatch_OR_MultiFields_Merge(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or text_match(StringField, "A") or text_match(VarCharField, "B")`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + // should be text_match(VarCharField, "A B") or text_match(StringField, "A") + be := expr.GetBinaryExpr() + 
require.NotNil(t, be) + require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp()) + + // collect all text_match literals grouped by column id + colToLiteral := collectTextMatchLiterals(expr) + require.Equal(t, 2, len(colToLiteral)) + // one of them must be "A B", and the other is "A" + hasAB := false + hasA := false + for _, lit := range colToLiteral { + if lit == "A B" { + hasAB = true + } + if lit == "A" { + hasA = true + } + } + require.True(t, hasAB) + require.True(t, hasA) +} + +func TestRewrite_TextMatch_OR_MoreMultiFields_Merge(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or (text_match(StringField, "C") or text_match(VarCharField, "B")) or text_match(StringField, "D")`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + // should be text_match(VarCharField, "A B") or text_match(StringField, "C D") + be := expr.GetBinaryExpr() + require.NotNil(t, be) + require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp()) + + // collect all text_match literals grouped by column id + colToLiteral := collectTextMatchLiterals(expr) + require.Equal(t, 2, len(colToLiteral)) + // one of them must be "A B", and the other is "C D" + hasAB := false + hasCD := false + for _, lit := range colToLiteral { + if lit == "A B" { + hasAB = true + } + if lit == "C D" { + hasCD = true + } + } + require.True(t, hasAB) + require.True(t, hasCD) +} + +func TestRewrite_TextMatch_OR_WithOption_NoMerge(t *testing.T) { + helper := buildSchemaHelperForRewriteT(t) + expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A", minimum_should_match=1) or text_match(VarCharField, "B")`, nil) + require.NoError(t, err) + require.NotNil(t, expr) + be := expr.GetBinaryExpr() + require.NotNil(t, be) + require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp()) +} diff --git a/internal/parser/planparserv2/rewriter/util.go b/internal/parser/planparserv2/rewriter/util.go new file mode 100644 index 
0000000000..ea741005c5 --- /dev/null +++ b/internal/parser/planparserv2/rewriter/util.go @@ -0,0 +1,290 @@ +package rewriter + +import ( + "fmt" + "sort" + "strings" + + "github.com/samber/lo" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +func columnKey(c *planpb.ColumnInfo) string { + var b strings.Builder + b.WriteString(fmt.Sprintf("%d|%d|%d|%t|%t|%t|", + c.GetFieldId(), + int32(c.GetDataType()), + int32(c.GetElementType()), + c.GetIsPrimaryKey(), + c.GetIsAutoID(), + c.GetIsPartitionKey(), + )) + for _, p := range c.GetNestedPath() { + b.WriteString(p) + b.WriteByte('|') + } + return b.String() +} + +// effectiveDataType returns the real scalar type to be used for comparisons. +// For JSON/Array columns with a concrete element_type, use element_type; +// otherwise fall back to the column data_type. +func effectiveDataType(c *planpb.ColumnInfo) schemapb.DataType { + if c == nil { + return schemapb.DataType_None + } + dt := c.GetDataType() + if dt == schemapb.DataType_JSON || dt == schemapb.DataType_Array { + et := c.GetElementType() + // Treat 0 (None/Invalid) as not specified; otherwise use element type. 
+ if et != schemapb.DataType_None { + return et + } + } + return dt +} + +func valueCase(v *planpb.GenericValue) string { + switch v.GetVal().(type) { + case *planpb.GenericValue_BoolVal: + return "bool" + case *planpb.GenericValue_Int64Val: + return "int64" + case *planpb.GenericValue_FloatVal: + return "float" + case *planpb.GenericValue_StringVal: + return "string" + case *planpb.GenericValue_ArrayVal: + return "array" + default: + return "other" + } +} + +func valueCaseWithNil(v *planpb.GenericValue) string { + if v == nil || v.GetVal() == nil { + return "nil" + } + return valueCase(v) +} + +func isNumericCase(k string) bool { + return k == "int64" || k == "float" +} + +func areComparableCases(a, b string) bool { + if a == "nil" || b == "nil" { + return false + } + if isNumericCase(a) && isNumericCase(b) { + return true + } + return a == b && (a == "bool" || a == "string") +} + +func isNumericType(dt schemapb.DataType) bool { + if typeutil.IsBoolType(dt) || typeutil.IsStringType(dt) || typeutil.IsJSONType(dt) { + return false + } + return typeutil.IsArithmetic(dt) +} + +const defaultConvertOrToInNumericLimit = 150 + +func shouldMergeToIn(dt schemapb.DataType, count int) bool { + if isNumericType(dt) { + return count > defaultConvertOrToInNumericLimit + } + return count > 1 +} + +func sortTermValues(term *planpb.TermExpr) { + if term == nil || len(term.GetValues()) <= 1 { + return + } + term.Values = sortGenericValues(term.Values) +} + +// sort and deduplicate a list of generic values. 
+func sortGenericValues(values []*planpb.GenericValue) []*planpb.GenericValue { + if len(values) <= 1 { + return values + } + var kind string + for _, v := range values { + if v == nil || v.GetVal() == nil { + continue + } + kind = valueCase(v) + if kind != "" && kind != "other" && kind != "array" { + break + } + } + switch kind { + case "bool": + sort.Slice(values, func(i, j int) bool { + return !values[i].GetBoolVal() && values[j].GetBoolVal() + }) + values = lo.UniqBy(values, func(v *planpb.GenericValue) bool { return v.GetBoolVal() }) + case "int64": + sort.Slice(values, func(i, j int) bool { + return values[i].GetInt64Val() < values[j].GetInt64Val() + }) + values = lo.UniqBy(values, func(v *planpb.GenericValue) int64 { return v.GetInt64Val() }) + case "float": + sort.Slice(values, func(i, j int) bool { + return values[i].GetFloatVal() < values[j].GetFloatVal() + }) + values = lo.UniqBy(values, func(v *planpb.GenericValue) float64 { return v.GetFloatVal() }) + case "string": + sort.Slice(values, func(i, j int) bool { + return values[i].GetStringVal() < values[j].GetStringVal() + }) + values = lo.UniqBy(values, func(v *planpb.GenericValue) string { return v.GetStringVal() }) + } + return values +} + +func newTermExpr(col *planpb.ColumnInfo, values []*planpb.GenericValue) *planpb.Expr { + return &planpb.Expr{ + Expr: &planpb.Expr_TermExpr{ + TermExpr: &planpb.TermExpr{ + ColumnInfo: col, + Values: values, + }, + }, + } +} + +func newUnaryRangeExpr(col *planpb.ColumnInfo, op planpb.OpType, val *planpb.GenericValue) *planpb.Expr { + return &planpb.Expr{ + Expr: &planpb.Expr_UnaryRangeExpr{ + UnaryRangeExpr: &planpb.UnaryRangeExpr{ + ColumnInfo: col, + Op: op, + Value: val, + }, + }, + } +} + +func newBoolConstExpr(v bool) *planpb.Expr { + return &planpb.Expr{ + Expr: &planpb.Expr_ValueExpr{ + ValueExpr: &planpb.ValueExpr{ + Value: &planpb.GenericValue{ + Val: &planpb.GenericValue_BoolVal{ + BoolVal: v, + }, + }, + }, + }, + } +} + +func newAlwaysTrueExpr() 
*planpb.Expr { + return &planpb.Expr{ + Expr: &planpb.Expr_AlwaysTrueExpr{ + AlwaysTrueExpr: &planpb.AlwaysTrueExpr{}, + }, + } +} + +func newAlwaysFalseExpr() *planpb.Expr { + return &planpb.Expr{ + Expr: &planpb.Expr_UnaryExpr{ + UnaryExpr: &planpb.UnaryExpr{ + Op: planpb.UnaryExpr_Not, + Child: newAlwaysTrueExpr(), + }, + }, + } +} + +// IsAlwaysTrueExpr checks if the expression is an AlwaysTrueExpr +func IsAlwaysTrueExpr(e *planpb.Expr) bool { + if e == nil { + return false + } + return e.GetAlwaysTrueExpr() != nil +} + +func IsAlwaysFalseExpr(e *planpb.Expr) bool { + if e == nil { + return false + } + ue := e.GetUnaryExpr() + if ue == nil || ue.GetOp() != planpb.UnaryExpr_Not { + return false + } + return IsAlwaysTrueExpr(ue.GetChild()) +} + +// equalsGeneric compares two GenericValue by content (bool/int/float/string). +func equalsGeneric(a, b *planpb.GenericValue) bool { + if a.GetVal() == nil || b.GetVal() == nil { + return false + } + + switch a.GetVal().(type) { + case *planpb.GenericValue_BoolVal: + if _, ok := b.GetVal().(*planpb.GenericValue_BoolVal); ok { + return a.GetBoolVal() == b.GetBoolVal() + } + case *planpb.GenericValue_Int64Val: + if _, ok := b.GetVal().(*planpb.GenericValue_Int64Val); ok { + return a.GetInt64Val() == b.GetInt64Val() + } + case *planpb.GenericValue_FloatVal: + if _, ok := b.GetVal().(*planpb.GenericValue_FloatVal); ok { + return a.GetFloatVal() == b.GetFloatVal() + } + case *planpb.GenericValue_StringVal: + if _, ok := b.GetVal().(*planpb.GenericValue_StringVal); ok { + return a.GetStringVal() == b.GetStringVal() + } + } + return false +} + +func satisfiesLower(dt schemapb.DataType, v, lower *planpb.GenericValue, inclusive bool) bool { + c := cmpGeneric(dt, v, lower) + if inclusive { + return c >= 0 + } + return c > 0 +} + +func satisfiesUpper(dt schemapb.DataType, v, upper *planpb.GenericValue, inclusive bool) bool { + c := cmpGeneric(dt, v, upper) + if inclusive { + return c <= 0 + } + return c < 0 +} + +func 
filterValuesByRange(dt schemapb.DataType, values []*planpb.GenericValue, lower *planpb.GenericValue, lowerInc bool, upper *planpb.GenericValue, upperInc bool) []*planpb.GenericValue { + out := make([]*planpb.GenericValue, 0, len(values)) + for _, v := range values { + pass := true + if lower != nil && !satisfiesLower(dt, v, lower, lowerInc) { + pass = false + } + if pass && upper != nil && !satisfiesUpper(dt, v, upper, upperInc) { + pass = false + } + if pass { + out = append(out, v) + } + } + return out +} + +func unionValues(valuesA, valuesB []*planpb.GenericValue) []*planpb.GenericValue { + all := append([]*planpb.GenericValue{}, valuesA...) + all = append(all, valuesB...) + return sortGenericValues(all) +} diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index a811b30c5d..b7ddbaf187 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -4709,7 +4709,7 @@ func (s *MaterializedViewTestSuite) TestMvEnabledPartitionKeyOnVarCharWithIsolat schema := ConstructCollectionSchemaWithPartitionKey(s.colName, s.fieldName2Types, testInt64Field, testVarCharField, false) schemaInfo := newSchemaInfo(schema) s.mockMetaCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schemaInfo, nil) - s.ErrorContains(task.PreExecute(s.ctx), "partition key isolation does not support OR") + s.ErrorContains(task.PreExecute(s.ctx), "partition key isolation does not support IN") } } diff --git a/internal/util/exprutil/expr_checker_test.go b/internal/util/exprutil/expr_checker_test.go index e64259990a..224ae5cdbc 100644 --- a/internal/util/exprutil/expr_checker_test.go +++ b/internal/util/exprutil/expr_checker_test.go @@ -80,9 +80,9 @@ func TestParsePartitionKeys(t *testing.T) { { name: "binary_expr_and with partition key in range", expr: "partition_key_field in [7, 8] && partition_key_field > 9", - expected: 2, - validPartitionKeys: []int64{7, 8}, - invalidPartitionKeys: []int64{9}, + 
expected: 0, + validPartitionKeys: []int64{}, + invalidPartitionKeys: []int64{}, }, { name: "binary_expr_and with partition key in range2", @@ -295,7 +295,7 @@ func TestValidatePartitionKeyIsolation(t *testing.T) { { name: "partition key isolation equal AND with same field term", expr: "key_field == 10 && key_field in [10]", - expectedErrorString: "partition key isolation does not support IN", + expectedErrorString: "", }, { name: "partition key isolation equal OR with same field equal", diff --git a/tests/go_client/testcases/delete_test.go b/tests/go_client/testcases/delete_test.go index b6f750f910..e026c5261e 100644 --- a/tests/go_client/testcases/delete_test.go +++ b/tests/go_client/testcases/delete_test.go @@ -589,33 +589,3 @@ func TestDeleteInvalidExpr(t *testing.T) { common.CheckErr(t, err, _invalidExpr.ErrNil, _invalidExpr.ErrMsg) } } - -// test delete with duplicated data ids -func TestDeleteDuplicatedPks(t *testing.T) { - ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) - mc := hp.CreateDefaultMilvusClient(ctx, t) - - // create collection and a partition - cp := hp.NewCreateCollectionParams(hp.Int64Vec) - prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption().TWithIsDynamic(true), hp.TNewSchemaOption()) - - // insert [0, 3000) into default - prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithMaxCapacity(common.TestCapacity)) - prepare.FlushData(ctx, t, mc, schema.CollectionName) - - // index and load - prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) - prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) - - // delete - deleteIDs := []int64{0, 0, 0, 0, 0} - delRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithInt64IDs(common.DefaultInt64FieldName, deleteIDs)) - common.CheckErr(t, err, true) - require.Equal(t, 5, int(delRes.DeleteCount)) - - // query, verify delete success - expr := fmt.Sprintf("%s >= 0 ", 
common.DefaultInt64FieldName) - resQuery, errQuery := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithConsistencyLevel(entity.ClStrong)) - common.CheckErr(t, errQuery, true) - require.Equal(t, common.DefaultNb-1, resQuery.ResultCount) -} diff --git a/tests/python_client/testcases/test_delete.py b/tests/python_client/testcases/test_delete.py index 456657ca93..1f957c5ead 100644 --- a/tests/python_client/testcases/test_delete.py +++ b/tests/python_client/testcases/test_delete.py @@ -565,22 +565,6 @@ class TestDeleteOperation(TestcaseBase): # since the search requests arrived query nodes earlier than query nodes consume the delete requests. assert len(inter) == 0 - @pytest.mark.tags(CaseLabel.L1) - def test_delete_expr_repeated_values(self): - """ - target: test delete with repeated values - method: 1.insert data with unique primary keys - 2.delete with repeated values: 'id in [0, 0]' - expected: delete one entity - """ - # init collection with nb default data - collection_w = self.init_collection_general(prefix, nb=tmp_nb, insert_data=True)[0] - expr = f'{ct.default_int64_field_name} in {[0, 0, 0]}' - del_res, _ = collection_w.delete(expr) - assert del_res.delete_count == 3 - collection_w.num_entities - collection_w.query(expr, check_task=CheckTasks.check_query_empty) - @pytest.mark.tags(CaseLabel.L1) def test_delete_duplicate_primary_keys(self): """ @@ -1433,7 +1417,7 @@ class TestDeleteString(TestcaseBase): self.init_collection_general(prefix, nb=tmp_nb, insert_data=True, primary_field=ct.default_string_field_name)[0] expr = f'{ct.default_string_field_name} in ["0", "0", "0"]' del_res, _ = collection_w.delete(expr) - assert del_res.delete_count == 3 + assert del_res.delete_count == 1 collection_w.num_entities collection_w.query(expr, check_task=CheckTasks.check_query_empty) @@ -1939,7 +1923,7 @@ class TestDeleteString(TestcaseBase): # delete string_expr = "varchar in [\"\", \"\"]" del_res, _ = collection_w.delete(string_expr) - assert 
del_res.delete_count == 2 + assert del_res.delete_count == 1 # load and query with id collection_w.load()