enhance: moved query optimization to proxy, added various optimizations (#45526)

issue: https://github.com/milvus-io/milvus/issues/45525

See the newly added README.md for details on the optimizations.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added query expression optimization feature with a new `optimizeExpr`
configuration flag to enable automatic simplification of filter
predicates, including range predicate optimization, merging of IN/NOT IN
conditions, and flattening of nested logical operators.

* **Bug Fixes**
* Adjusted delete operation behavior to correctly handle expression
evaluation.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
This commit is contained in:
Buqian Zheng 2025-12-24 00:39:19 +08:00 committed by GitHub
parent 7fca6e759f
commit e379b1f0f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 3982 additions and 978 deletions

View File

@ -18,7 +18,6 @@
#include "common/EasyAssert.h"
#include "common/Tracer.h"
#include "fmt/format.h"
#include "exec/expression/AlwaysTrueExpr.h"
#include "exec/expression/BinaryArithOpEvalRangeExpr.h"
#include "exec/expression/BinaryRangeExpr.h"
@ -84,75 +83,19 @@ CompileExpressions(const std::vector<expr::TypedExprPtr>& sources,
return exprs;
}
// Returns the name of `expr` when it is a logical AND/OR node whose
// same-operator children may be flattened into a single n-ary conjunct;
// std::nullopt for every other expression.
// `flat_candidates` is kept for interface compatibility but is currently
// unused by the implementation.
static std::optional<std::string>
ShouldFlatten(const expr::TypedExprPtr& expr,
              [[maybe_unused]] const std::unordered_set<std::string>&
                  flat_candidates = {}) {
    if (auto call =
            std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr)) {
        if (call->op_type_ == expr::LogicalBinaryExpr::OpType::And ||
            call->op_type_ == expr::LogicalBinaryExpr::OpType::Or) {
            return call->name();
        }
    }
    return std::nullopt;
}
// True when `expr` is a LogicalBinaryExpr whose name equals `name`;
// false for null, non-logical, or differently named expressions.
static bool
IsCall(const expr::TypedExprPtr& expr, const std::string& name) {
    auto logical =
        std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr);
    return logical != nullptr && logical->name() == name;
}
// Returns true when every input of `expr` has the same type as the first
// input. Vacuously true for zero or one input.
static bool
AllInputTypeEqual(const expr::TypedExprPtr& expr) {
    const auto& inputs = expr->inputs();
    // size_t index avoids the signed/unsigned comparison warning the previous
    // `int` index produced against inputs.size().
    for (size_t i = 1; i < inputs.size(); i++) {
        if (inputs[0]->type() != inputs[i]->type()) {
            return false;
        }
    }
    return true;
}
// Collects the leaves of nested `flatten_call` expressions into `flat`.
// A node is expanded only when it is the same logical call AND all of its
// inputs share one type; otherwise it is appended unchanged.
static void
FlattenInput(const expr::TypedExprPtr& input,
             const std::string& flatten_call,
             std::vector<expr::TypedExprPtr>& flat) {
    const bool expandable =
        IsCall(input, flatten_call) && AllInputTypeEqual(input);
    if (!expandable) {
        flat.emplace_back(input);
        return;
    }
    for (const auto& child : input->inputs()) {
        FlattenInput(child, flatten_call, flat);
    }
}
std::vector<ExprPtr>
CompileInputs(const expr::TypedExprPtr& expr,
QueryContext* context,
const std::unordered_set<std::string>& flatten_cadidates) {
std::vector<ExprPtr> compiled_inputs;
auto flatten = ShouldFlatten(expr);
for (auto& input : expr->inputs()) {
if (dynamic_cast<const expr::InputTypeExpr*>(input.get())) {
AssertInfo(
dynamic_cast<const expr::FieldAccessTypeExpr*>(expr.get()),
"An InputReference can only occur under a FieldReference");
} else {
if (flatten.has_value()) {
std::vector<expr::TypedExprPtr> flat_exprs;
FlattenInput(input, flatten.value(), flat_exprs);
for (auto& input : flat_exprs) {
compiled_inputs.push_back(CompileExpression(
input, context, flatten_cadidates, false));
}
} else {
compiled_inputs.push_back(CompileExpression(
input, context, flatten_cadidates, false));
}
compiled_inputs.push_back(
CompileExpression(input, context, flatten_cadidates, false));
}
}
return compiled_inputs;
@ -507,170 +450,6 @@ ReorderConjunctExpr(std::shared_ptr<milvus::exec::PhyConjunctFilterExpr>& expr,
expr->Reorder(reorder);
}
inline std::shared_ptr<PhyLogicalUnaryExpr>
ConvertMultiNotEqualToNotInExpr(std::vector<std::shared_ptr<Expr>>& exprs,
std::vector<size_t> indices,
ExecContext* context) {
std::vector<proto::plan::GenericValue> values;
auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
for (auto& i : indices) {
auto expr = std::static_pointer_cast<PhyUnaryRangeFilterExpr>(exprs[i])
->GetLogicalExpr();
if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) {
type = expr->val_.val_case();
}
if (type != expr->val_.val_case()) {
return nullptr;
}
values.push_back(expr->val_);
}
auto logical_expr = std::make_shared<milvus::expr::TermFilterExpr>(
exprs[indices[0]]->GetColumnInfo().value(), values);
auto query_context = context->get_query_context();
auto term_expr = std::make_shared<PhyTermFilterExpr>(
std::vector<std::shared_ptr<Expr>>{},
logical_expr,
"PhyTermFilterExpr",
query_context->get_op_context(),
query_context->get_segment(),
query_context->get_active_count(),
query_context->get_query_timestamp(),
query_context->query_config()->get_expr_batch_size(),
query_context->get_consistency_level());
return std::make_shared<PhyLogicalUnaryExpr>(
std::vector<std::shared_ptr<Expr>>{term_expr},
std::make_shared<milvus::expr::LogicalUnaryExpr>(
milvus::expr::LogicalUnaryExpr::OpType::LogicalNot, logical_expr),
"PhyLogicalUnaryExpr",
query_context->get_op_context());
}
inline std::shared_ptr<PhyTermFilterExpr>
ConvertMultiOrToInExpr(std::vector<std::shared_ptr<Expr>>& exprs,
std::vector<size_t> indices,
ExecContext* context) {
std::vector<proto::plan::GenericValue> values;
auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
for (auto& i : indices) {
auto expr = std::static_pointer_cast<PhyUnaryRangeFilterExpr>(exprs[i])
->GetLogicalExpr();
if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) {
type = expr->val_.val_case();
}
if (type != expr->val_.val_case()) {
return nullptr;
}
values.push_back(expr->val_);
}
auto logical_expr = std::make_shared<milvus::expr::TermFilterExpr>(
exprs[indices[0]]->GetColumnInfo().value(), values);
auto query_context = context->get_query_context();
return std::make_shared<PhyTermFilterExpr>(
std::vector<std::shared_ptr<Expr>>{},
logical_expr,
"PhyTermFilterExpr",
query_context->get_op_context(),
query_context->get_segment(),
query_context->get_active_count(),
query_context->get_query_timestamp(),
query_context->query_config()->get_expr_batch_size(),
query_context->get_consistency_level());
}
// Rewrites eligible children of a conjunct in place:
//   * OR node:  A = v1 or A = v2 or ...    ->  A in (v1, v2, ...)
//   * AND node: A != v1 and A != v2 ...    ->  not (A in (v1, v2, ...))
// Numeric columns are rewritten only when more than
// DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT predicates hit the same column;
// every other column type is rewritten as soon as two predicates match.
// Nested PhyConjunctFilterExpr children are rewritten recursively.
inline void
RewriteConjunctExpr(std::shared_ptr<milvus::exec::PhyConjunctFilterExpr>& expr,
                    ExecContext* context) {
    // convert A = .. or A = .. or A = .. to A in (.., .., ..)
    if (expr->IsOr()) {
        auto& inputs = expr->GetInputsRef();
        std::map<expr::ColumnInfo, std::vector<size_t>> expr_indices;
        for (size_t i = 0; i < inputs.size(); i++) {
            auto input = inputs[i];
            if (input->name() == "PhyUnaryRangeFilterExpr") {
                auto phy_expr =
                    std::static_pointer_cast<PhyUnaryRangeFilterExpr>(input);
                if (phy_expr->GetOpType() == proto::plan::OpType::Equal) {
                    // operator[] default-constructs the index list on first
                    // use, replacing the previous find-then-insert dance.
                    expr_indices[phy_expr->GetColumnInfo().value()]
                        .push_back(i);
                }
            } else if (input->name() == "PhyConjunctFilterExpr") {
                auto nested = std::static_pointer_cast<
                    milvus::exec::PhyConjunctFilterExpr>(input);
                RewriteConjunctExpr(nested, context);
            }
        }
        for (auto& [column, indices] : expr_indices) {
            // Numeric columns only benefit past
            // DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT merged predicates;
            // other types convert whenever more than one predicate matches.
            if ((IsNumericDataType(column.data_type_) &&
                 indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) ||
                (!IsNumericDataType(column.data_type_) &&
                 indices.size() > 1)) {
                auto new_expr =
                    ConvertMultiOrToInExpr(inputs, indices, context);
                if (new_expr) {
                    // Replace the first merged slot; null out the rest for
                    // the erase-remove sweep below.
                    inputs[indices[0]] = new_expr;
                    for (size_t j = 1; j < indices.size(); j++) {
                        inputs[indices[j]] = nullptr;
                    }
                }
            }
        }
        inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr),
                     inputs.end());
    }
    // convert A != .. and A != .. and A != .. to not A in (.., .., ..)
    if (expr->IsAnd()) {
        auto& inputs = expr->GetInputsRef();
        std::map<expr::ColumnInfo, std::vector<size_t>> expr_indices;
        for (size_t i = 0; i < inputs.size(); i++) {
            auto input = inputs[i];
            if (input->name() == "PhyUnaryRangeFilterExpr") {
                auto phy_expr =
                    std::static_pointer_cast<PhyUnaryRangeFilterExpr>(input);
                if (phy_expr->GetOpType() == proto::plan::OpType::NotEqual) {
                    expr_indices[phy_expr->GetColumnInfo().value()]
                        .push_back(i);
                }
            } else if (input->name() == "PhyConjunctFilterExpr") {
                // BUG FIX: this recursion was previously nested inside the
                // PhyUnaryRangeFilterExpr branch and therefore unreachable
                // (an input's name cannot match both); it is now a sibling
                // check, mirroring the OR branch above.
                auto nested = std::static_pointer_cast<
                    milvus::exec::PhyConjunctFilterExpr>(input);
                RewriteConjunctExpr(nested, context);
            }
        }
        for (auto& [column, indices] : expr_indices) {
            if ((IsNumericDataType(column.data_type_) &&
                 indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) ||
                (!IsNumericDataType(column.data_type_) &&
                 indices.size() > 1)) {
                auto new_expr =
                    ConvertMultiNotEqualToNotInExpr(inputs, indices, context);
                if (new_expr) {
                    inputs[indices[0]] = new_expr;
                    for (size_t j = 1; j < indices.size(); j++) {
                        inputs[indices[j]] = nullptr;
                    }
                }
            }
        }
        inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr),
                     inputs.end());
    }
}
inline void
SetNamespaceSkipIndex(std::shared_ptr<PhyConjunctFilterExpr> conjunct_expr,
ExecContext* context) {
@ -727,7 +506,6 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector<ExprPtr>& exprs) {
LOG_DEBUG("before reoder filter expression: {}", expr->ToString());
auto conjunct_expr =
std::static_pointer_cast<PhyConjunctFilterExpr>(expr);
RewriteConjunctExpr(conjunct_expr, context);
bool has_heavy_operation = false;
ReorderConjunctExpr(conjunct_expr, context, has_heavy_operation);
LOG_DEBUG("after reorder filter expression: {}", expr->ToString());

View File

@ -256,183 +256,6 @@ TEST_P(TaskTest, LogicalExpr) {
EXPECT_EQ(num_rows, num_rows_);
}
// Verifies that nested same-operator ANDs over one column are flattened by
// CompileInputs into a flat list of leaf range filters.
// Fixes from the previous version: the structure comment claimed `< 10`
// while the operators were GreaterThan, and expr4/expr5 were constructed
// but never used (the second AND reused expr1/expr2 by mistake).
TEST_P(TaskTest, CompileInputs_and) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;
    auto schema = std::make_shared<Schema>();
    // The vector field only makes the schema realistic; the filter below
    // touches the int64 field exclusively.
    auto vec_fid =
        schema->AddDebugField("fakevec", GetParam(), 16, knowhere::metric::L2);
    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
    proto::plan::GenericValue val;
    val.set_int64_val(10);
    // Builds one `int64_fid > 10` leaf predicate.
    auto make_leaf = [&]() {
        return std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(int64_fid, DataType::INT64),
            proto::plan::OpType::GreaterThan,
            val,
            std::vector<proto::plan::GenericValue>{});
    };
    // expr: (int64_fid > 10 and int64_fid > 10) and (int64_fid > 10 and int64_fid > 10)
    auto lhs = std::make_shared<expr::LogicalBinaryExpr>(
        expr::LogicalBinaryExpr::OpType::And, make_leaf(), make_leaf());
    auto rhs = std::make_shared<expr::LogicalBinaryExpr>(
        expr::LogicalBinaryExpr::OpType::And, make_leaf(), make_leaf());
    auto root = std::make_shared<expr::LogicalBinaryExpr>(
        expr::LogicalBinaryExpr::OpType::And, lhs, rhs);
    auto query_context = std::make_shared<milvus::exec::QueryContext>(
        DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
    auto exprs = milvus::exec::CompileInputs(root, query_context.get(), {});
    // All four leaves share column and operator, so both nested ANDs are
    // expected to flatten into four sibling range filters.
    EXPECT_EQ(exprs.size(), 4);
    for (size_t i = 0; i < exprs.size(); ++i) {
        EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr");
    }
}
// Verifies how CompileInputs flattens mixed OR/AND trees: a child is only
// absorbed into its parent when both use the same logical operator.
// Fixes from the previous version: several structure comments contradicted
// the operators actually used (e.g. `< 10` for GreaterThan, `and` for Or),
// expr4/expr5 were built but never used in each scope, and loop indices
// were signed ints compared against size().
TEST_P(TaskTest, CompileInputs_or_with_and) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;
    auto schema = std::make_shared<Schema>();
    // The vector field only makes the schema realistic; the filters below
    // touch the int64 field exclusively.
    auto vec_fid =
        schema->AddDebugField("fakevec", GetParam(), 16, knowhere::metric::L2);
    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
    proto::plan::GenericValue val;
    val.set_int64_val(10);
    // Builds one `int64_fid > 10` leaf predicate; every case combines
    // identical leaves and varies only the logical structure.
    auto make_leaf = [&]() {
        return std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(int64_fid, DataType::INT64),
            proto::plan::OpType::GreaterThan,
            val,
            std::vector<proto::plan::GenericValue>{});
    };
    {
        // expr: (int64_fid > 10 and int64_fid > 10) or (int64_fid > 10 and int64_fid > 10)
        auto lhs = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, make_leaf(), make_leaf());
        auto rhs = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, make_leaf(), make_leaf());
        auto root = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, lhs, rhs);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        auto exprs =
            milvus::exec::CompileInputs(root, query_context.get(), {});
        // An OR parent cannot absorb AND children: both stay as conjuncts.
        EXPECT_EQ(exprs.size(), 2);
        for (size_t i = 0; i < exprs.size(); ++i) {
            EXPECT_STREQ(exprs[i]->name().c_str(), "PhyConjunctFilterExpr");
        }
    }
    {
        // expr: (int64_fid > 10 or int64_fid > 10) or (int64_fid > 10 and int64_fid > 10)
        auto lhs = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, make_leaf(), make_leaf());
        auto rhs = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, make_leaf(), make_leaf());
        auto root = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, lhs, rhs);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        auto exprs =
            milvus::exec::CompileInputs(root, query_context.get(), {});
        // The inner OR flattens into the outer OR (two unary leaves); the
        // AND subtree survives as a conjunct.
        EXPECT_EQ(exprs.size(), 3);
        for (size_t i = 0; i + 1 < exprs.size(); ++i) {
            EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr");
        }
        EXPECT_STREQ(exprs[2]->name().c_str(), "PhyConjunctFilterExpr");
    }
    {
        // expr: (int64_fid > 10 or int64_fid > 10) and (int64_fid > 10 and int64_fid > 10)
        auto lhs = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, make_leaf(), make_leaf());
        auto rhs = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, make_leaf(), make_leaf());
        auto root = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, lhs, rhs);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        auto exprs =
            milvus::exec::CompileInputs(root, query_context.get(), {});
        // The inner AND flattens into the outer AND (two unary leaves); the
        // OR subtree survives as a conjunct.
        EXPECT_EQ(exprs.size(), 3);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        for (size_t i = 1; i < exprs.size(); ++i) {
            EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr");
        }
    }
}
TEST_P(TaskTest, Test_reorder) {
using namespace milvus;
using namespace milvus::query;
@ -698,530 +521,6 @@ TEST_P(TaskTest, Test_reorder) {
}
}
// Verifies the AND-of-NotEqual -> NOT IN rewrite across column types.
// Fixes from the previous version: the JSON scope built the nested path as
// std::vector<std::string>{'a'} — a char does not convert to std::string,
// so the vector count constructor was selected and the path became 97
// ('a') empty strings instead of {"a"}; the scope-3 comment also said
// `&&` where the code builds an Or.
TEST_P(TaskTest, Test_MultiNotEqualConvert) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;
    using namespace milvus::exec;
    {
        // expr: string2 != "111" and string2 != "222" and string2 != "333"
        // Three NotEqual predicates on one varchar column merge into a
        // single `not string2 in (...)`.
        proto::plan::GenericValue val1;
        val1.set_string_val("111");
        auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
            proto::plan::OpType::NotEqual,
            val1,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val2;
        val2.set_string_val("222");
        auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
            proto::plan::OpType::NotEqual,
            val2,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val3;
        val3.set_string_val("333");
        auto expr3 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
            proto::plan::OpType::NotEqual,
            val3,
            std::vector<proto::plan::GenericValue>{});
        auto expr4 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
        auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr4, expr3);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        ExecContext context(query_context.get());
        auto exprs =
            milvus::exec::CompileExpressions({expr5}, &context, {}, false);
        EXPECT_EQ(exprs.size(), 1);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        auto phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                exprs[0]);
        auto inputs = phy_expr->GetInputsRef();
        // The rewrite leaves one NOT(term) child under the conjunct.
        EXPECT_EQ(inputs.size(), 1);
        EXPECT_STREQ(inputs[0]->name().c_str(), "PhyLogicalUnaryExpr");
        EXPECT_EQ(inputs[0]->GetInputsRef().size(), 1);
        EXPECT_STREQ(inputs[0]->GetInputsRef()[0]->name().c_str(),
                     "PhyTermFilterExpr");
    }
    {
        // expr: int64 != 111 and int64 != 222 and int64 != 333
        // Numeric columns convert only past the numeric limit, so the three
        // range filters are expected to stay separate.
        proto::plan::GenericValue val1;
        val1.set_int64_val(111);
        auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::NotEqual,
            val1,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val2;
        val2.set_int64_val(222);
        auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::NotEqual,
            val2,
            std::vector<proto::plan::GenericValue>{});
        auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
        proto::plan::GenericValue val3;
        val3.set_int64_val(333);
        auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::NotEqual,
            val3,
            std::vector<proto::plan::GenericValue>{});
        auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr3, expr4);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        ExecContext context(query_context.get());
        auto exprs =
            milvus::exec::CompileExpressions({expr5}, &context, {}, false);
        EXPECT_EQ(exprs.size(), 1);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        auto phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                exprs[0]);
        auto inputs = phy_expr->GetInputsRef();
        EXPECT_EQ(inputs.size(), 3);
        EXPECT_STREQ(inputs[0]->name().c_str(), "PhyUnaryRangeFilterExpr");
        EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr");
        EXPECT_STREQ(inputs[2]->name().c_str(), "PhyUnaryRangeFilterExpr");
    }
    {
        // expr: string2 != "111" and string2 != "222" and (int64 > 10 || int64 < 100) and string2 != "333"
        // The OR subtree must not block the NOT-IN rewrite of the three
        // varchar NotEqual predicates.
        proto::plan::GenericValue val1;
        val1.set_string_val("111");
        auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
            proto::plan::OpType::NotEqual,
            val1,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val2;
        val2.set_string_val("222");
        auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
            proto::plan::OpType::NotEqual,
            val2,
            std::vector<proto::plan::GenericValue>{});
        auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
        proto::plan::GenericValue val3;
        val3.set_int64_val(10);
        auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::GreaterThan,
            val3,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val4;
        val4.set_int64_val(100);
        auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::LessThan,
            val4,
            std::vector<proto::plan::GenericValue>{});
        auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, expr4, expr5);
        auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr6, expr3);
        proto::plan::GenericValue val5;
        val5.set_string_val("333");
        auto expr8 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
            proto::plan::OpType::NotEqual,
            val5,
            std::vector<proto::plan::GenericValue>{});
        auto expr9 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr7, expr8);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        ExecContext context(query_context.get());
        auto exprs =
            milvus::exec::CompileExpressions({expr9}, &context, {}, false);
        EXPECT_EQ(exprs.size(), 1);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        auto phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                exprs[0]);
        auto inputs = phy_expr->GetInputsRef();
        // One OR conjunct plus one NOT(term) from the merged varchar preds.
        EXPECT_EQ(inputs.size(), 2);
        EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr");
        EXPECT_STREQ(inputs[1]->name().c_str(), "PhyLogicalUnaryExpr");
        auto phy_expr1 =
            std::static_pointer_cast<milvus::exec::PhyLogicalUnaryExpr>(
                inputs[1]);
        EXPECT_EQ(phy_expr1->GetInputsRef().size(), 1);
        EXPECT_STREQ(phy_expr1->GetInputsRef()[0]->name().c_str(),
                     "PhyTermFilterExpr");
        phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                inputs[0]);
        inputs = phy_expr->GetInputsRef();
        EXPECT_EQ(inputs.size(), 2);
    }
    {
        // expr: json['a'] != "111" and json['a'] != "222" and json['a'] != "333"
        // JSON paths on the same key also merge into a single NOT-IN.
        proto::plan::GenericValue val1;
        val1.set_string_val("111");
        auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{"a"}),
            proto::plan::OpType::NotEqual,
            val1,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val2;
        val2.set_string_val("222");
        auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{"a"}),
            proto::plan::OpType::NotEqual,
            val2,
            std::vector<proto::plan::GenericValue>{});
        auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
        proto::plan::GenericValue val3;
        val3.set_string_val("333");
        auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{"a"}),
            proto::plan::OpType::NotEqual,
            val3,
            std::vector<proto::plan::GenericValue>{});
        auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::And, expr3, expr4);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        ExecContext context(query_context.get());
        auto exprs =
            milvus::exec::CompileExpressions({expr5}, &context, {}, false);
        EXPECT_EQ(exprs.size(), 1);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        auto phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                exprs[0]);
        auto inputs = phy_expr->GetInputsRef();
        EXPECT_EQ(inputs.size(), 1);
        EXPECT_STREQ(inputs[0]->name().c_str(), "PhyLogicalUnaryExpr");
        auto phy_expr1 =
            std::static_pointer_cast<milvus::exec::PhyLogicalUnaryExpr>(
                inputs[0]);
        EXPECT_EQ(phy_expr1->GetInputsRef().size(), 1);
        EXPECT_STREQ(phy_expr1->GetInputsRef()[0]->name().c_str(),
                     "PhyTermFilterExpr");
    }
}
TEST_P(TaskTest, Test_MultiInConvert) {
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
using namespace milvus::exec;
{
// expr: string2 == '111' or string2 == '222' or string2 == "333"
proto::plan::GenericValue val1;
val1.set_string_val("111");
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
proto::plan::OpType::Equal,
val1,
std::vector<proto::plan::GenericValue>{});
proto::plan::GenericValue val2;
val2.set_string_val("222");
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
proto::plan::OpType::Equal,
val2,
std::vector<proto::plan::GenericValue>{});
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
proto::plan::GenericValue val3;
val3.set_string_val("333");
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
proto::plan::OpType::Equal,
val3,
std::vector<proto::plan::GenericValue>{});
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
auto query_context = std::make_shared<milvus::exec::QueryContext>(
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
ExecContext context(query_context.get());
auto exprs =
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
EXPECT_EQ(exprs.size(), 1);
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
auto phy_expr =
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
exprs[0]);
auto inputs = phy_expr->GetInputsRef();
EXPECT_EQ(inputs.size(), 1);
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr");
}
{
// expr: string2 == '111' or string2 == '222' or (int64 > 10 && int64 < 100) or string2 == "333"
proto::plan::GenericValue val1;
val1.set_string_val("111");
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
proto::plan::OpType::Equal,
val1,
std::vector<proto::plan::GenericValue>{});
proto::plan::GenericValue val2;
val2.set_string_val("222");
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
proto::plan::OpType::Equal,
val2,
std::vector<proto::plan::GenericValue>{});
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
proto::plan::GenericValue val3;
val3.set_int64_val(10);
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
proto::plan::OpType::GreaterThan,
val3,
std::vector<proto::plan::GenericValue>{});
proto::plan::GenericValue val4;
val4.set_int64_val(100);
auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
proto::plan::OpType::LessThan,
val4,
std::vector<proto::plan::GenericValue>{});
auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::And, expr4, expr5);
auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::Or, expr6, expr3);
proto::plan::GenericValue val5;
val5.set_string_val("333");
auto expr8 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
proto::plan::OpType::Equal,
val5,
std::vector<proto::plan::GenericValue>{});
auto expr9 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::Or, expr7, expr8);
auto query_context = std::make_shared<milvus::exec::QueryContext>(
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
ExecContext context(query_context.get());
auto exprs =
milvus::exec::CompileExpressions({expr9}, &context, {}, false);
EXPECT_EQ(exprs.size(), 1);
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
auto phy_expr =
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
exprs[0]);
auto inputs = phy_expr->GetInputsRef();
EXPECT_EQ(inputs.size(), 2);
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr");
EXPECT_STREQ(inputs[1]->name().c_str(), "PhyTermFilterExpr");
}
{
        // expr: json['a'] == "111" or json['a'] == "222" or json['a'] == "333"
proto::plan::GenericValue val1;
val1.set_string_val("111");
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["json"],
DataType::JSON,
std::vector<std::string>{'a'}),
proto::plan::OpType::Equal,
val1,
std::vector<proto::plan::GenericValue>{});
proto::plan::GenericValue val2;
val2.set_string_val("222");
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["json"],
DataType::JSON,
std::vector<std::string>{'a'}),
proto::plan::OpType::Equal,
val2,
std::vector<proto::plan::GenericValue>{});
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
proto::plan::GenericValue val3;
val3.set_string_val("333");
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(field_map_["json"],
DataType::JSON,
std::vector<std::string>{'a'}),
proto::plan::OpType::Equal,
val3,
std::vector<proto::plan::GenericValue>{});
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
auto query_context = std::make_shared<milvus::exec::QueryContext>(
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
ExecContext context(query_context.get());
auto exprs =
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
EXPECT_EQ(exprs.size(), 1);
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
auto phy_expr =
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
exprs[0]);
auto inputs = phy_expr->GetInputsRef();
EXPECT_EQ(inputs.size(), 1);
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr");
}
    {
        // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == "333"
        // Expect the two string equalities on json['a'] to be merged into one
        // PhyTermFilterExpr, while json['b'] stays a separate
        // PhyUnaryRangeFilterExpr — 2 conjunct inputs in total.
        proto::plan::GenericValue val1;
        val1.set_string_val("111");
        auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             // char literal 'a' yields the string "a" via
                             // std::string's initializer_list<char> ctor
                             std::vector<std::string>{'a'}),
            proto::plan::OpType::Equal,
            val1,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val2;
        val2.set_string_val("222");
        auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{'b'}),
            proto::plan::OpType::Equal,
            val2,
            std::vector<proto::plan::GenericValue>{});
        auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
        proto::plan::GenericValue val3;
        val3.set_string_val("333");
        auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{'a'}),
            proto::plan::OpType::Equal,
            val3,
            std::vector<proto::plan::GenericValue>{});
        auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        ExecContext context(query_context.get());
        auto exprs =
            milvus::exec::CompileExpressions({expr5}, &context, {}, false);
        EXPECT_EQ(exprs.size(), 1);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        auto phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                exprs[0]);
        auto inputs = phy_expr->GetInputsRef();
        EXPECT_EQ(inputs.size(), 2);
        EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr");
        EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr");
    }
    {
        // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == 1
        // The two json['a'] predicates compare against different value types
        // (string vs int64), so they are NOT merged into a term filter —
        // all three predicates remain as separate conjunct inputs.
        proto::plan::GenericValue val1;
        val1.set_string_val("111");
        auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{'a'}),
            proto::plan::OpType::Equal,
            val1,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val2;
        val2.set_string_val("222");
        auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{'b'}),
            proto::plan::OpType::Equal,
            val2,
            std::vector<proto::plan::GenericValue>{});
        auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
        proto::plan::GenericValue val3;
        val3.set_int64_val(1);
        auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["json"],
                             DataType::JSON,
                             std::vector<std::string>{'a'}),
            proto::plan::OpType::Equal,
            val3,
            std::vector<proto::plan::GenericValue>{});
        auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        ExecContext context(query_context.get());
        auto exprs =
            milvus::exec::CompileExpressions({expr5}, &context, {}, false);
        EXPECT_EQ(exprs.size(), 1);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        auto phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                exprs[0]);
        auto inputs = phy_expr->GetInputsRef();
        EXPECT_EQ(inputs.size(), 3);
    }
    {
        // expr: int64 == 11 or int64 == 222 or int64 == 1
        // NOTE(review): the original comment read "int1 == 11 or int1 == 22
        // or int3 == 33", which did not match the code below.
        // All three equalities target the same INT64 column, yet 3 separate
        // inputs are expected — presumably because numeric OR-equals are only
        // folded into a single term filter above a count threshold; confirm
        // against the conjunct-merge rules in the expression compiler.
        proto::plan::GenericValue val1;
        val1.set_int64_val(11);
        auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::Equal,
            val1,
            std::vector<proto::plan::GenericValue>{});
        proto::plan::GenericValue val2;
        val2.set_int64_val(222);
        auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::Equal,
            val2,
            std::vector<proto::plan::GenericValue>{});
        auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
        proto::plan::GenericValue val3;
        val3.set_int64_val(1);
        auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(field_map_["int64"], DataType::INT64),
            proto::plan::OpType::Equal,
            val3,
            std::vector<proto::plan::GenericValue>{});
        auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
            expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
        auto query_context = std::make_shared<milvus::exec::QueryContext>(
            DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
        ExecContext context(query_context.get());
        auto exprs =
            milvus::exec::CompileExpressions({expr5}, &context, {}, false);
        EXPECT_EQ(exprs.size(), 1);
        EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
        auto phy_expr =
            std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
                exprs[0]);
        auto inputs = phy_expr->GetInputsRef();
        EXPECT_EQ(inputs.size(), 3);
    }
}
// This test verifies the fix for https://github.com/milvus-io/milvus/issues/46053.
//
// Bug scenario:

View File

@ -15,6 +15,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated"
"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
"github.com/milvus-io/milvus/internal/util/function/rerank"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -149,6 +150,7 @@ func parseExprInner(schema *typeutil.SchemaHelper, exprStr string, exprTemplateV
return nil, err
}
predicate.expr = rewriter.RewriteExpr(predicate.expr)
return predicate.expr, nil
}

View File

@ -0,0 +1,138 @@
## Expression Rewriter (planparserv2/rewriter)
This module performs rule-based logical rewrites on parsed `planpb.Expr` trees right after template value filling and before planning/execution.
### Entry
- `RewriteExpr(*planpb.Expr) *planpb.Expr` (in `entry.go`)
- Recursively visits the expression tree and applies a set of composable, side-effect-free rewrite rules.
- Uses global configuration from `paramtable.Get().CommonCfg.EnabledOptimizeExpr`
- `RewriteExprWithConfig(*planpb.Expr, bool) *planpb.Expr` (in `entry.go`)
- Same as `RewriteExpr` but allows custom configuration for testing or special cases.
### Configuration
The rewriter can be configured via the following parameter (refreshable at runtime):
| Parameter | Default | Description |
|-----------|---------|-------------|
| `common.enabledOptimizeExpr` | `true` | Enable query expression optimization including range simplification, IN/NOT IN merge, TEXT_MATCH merge, and all other optimizations |
**IMPORTANT**: IN/NOT IN value list sorting and deduplication **always** runs regardless of this configuration setting, because the execution engine depends on sorted value lists.
### Implemented Rules
1) IN / NOT IN normalization and merges (`term_in.go`)
- OR-equals to IN (same column):
- `a == v1 OR a == v2 ...` → `a IN (v1, v2, ...)`
- Numeric columns only merge when count > threshold (default 150); others when count > 1.
- AND-not-equals to NOT IN (same column):
- `a != v1 AND a != v2 ...` → `NOT (a IN (v1, v2, ...))`
- Same thresholds as above.
- IN vs Equal redundancy elimination (same column):
- AND: `(a ∈ S) AND (a = v)`:
- if `v ∈ S` → `a = v`
- if `v ∉ S` → contradiction → constant `false`
- OR: `(a ∈ S) OR (a = v)` → `a ∈ (S ∪ {v})` (always union)
- IN with IN union:
- OR: `(a ∈ S1) OR (a ∈ S2)` → `a ∈ (S1 ∪ S2)` with sorting/dedup
- AND: `(a ∈ S1) AND (a ∈ S2)` → `a ∈ (S1 ∩ S2)`; empty intersection → constant `false`
- Sort and deduplicate `IN` / `NOT IN` value lists (supported types: bool, int64, float64, string).
2) TEXT_MATCH OR merge (`text_match.go`)
- Merge ORs of `TEXT_MATCH(field, "literal")` on the same column (no options):
- Concatenate literals with a single space in the order they appear; no tokenization, deduplication, or sorting is performed.
- Example: `TEXT_MATCH(f, "A C") OR TEXT_MATCH(f, "B D")` → `TEXT_MATCH(f, "A C B D")`
- If any `TEXT_MATCH` in the group has options (e.g., `minimum_should_match`), this optimization is skipped for that group.
3) Range predicate simplification (`range.go`)
- AND tighten (same column):
- Lower bounds: `a > 10 AND a > 20` → `a > 20` (pick strongest lower)
- Upper bounds: `a < 50 AND a < 60` → `a < 50` (pick strongest upper)
- Mixed lower and upper: `a > 10 AND a < 50` → `10 < a < 50` (BinaryRangeExpr)
- Inclusion respected (>, >=, <, <=). On ties, exclusive is considered stronger than inclusive for tightening.
- OR weaken (same column, same direction):
- Lower bounds: `a > 10 OR a > 20` → `a > 10` (pick weakest lower)
- Upper bounds: `a < 10 OR a < 20` → `a < 20` (pick weakest upper)
- Inclusion respected, preferring inclusive for weakening in ties.
- Mixed-direction OR (lower vs upper) is not merged.
- Equivalent-bound collapses (same column, same value):
- AND: `a ≥ x AND a > x` → `a > x`; `a ≤ y AND a < y` → `a < y`
- OR: `a ≥ x OR a > x` → `a ≥ x`; `a ≤ y OR a < y` → `a ≤ y`
- Symmetric dedup: `a > 10 AND a ≥ 10` → `a > 10`; `a < 5 OR a ≤ 5` → `a ≤ 5`
- IN ∩ range filtering:
- AND: `(a ∈ {…}) AND (range)` → keep only values in the set that satisfy the range
- e.g., `{1,3,5} AND a > 3` → `{5}`
- Supported columns for range optimization:
- Scalar: Int8/Int16/Int32/Int64, Float/Double, VarChar
- Array element access: when indexing an element (e.g., `ArrayInt[0]`), the element type above applies
- JSON/dynamic fields with nested paths (e.g., `JSONField["price"]`, `$meta["age"]`) are range-optimized
- Type determined from literal value (int, float, string)
- Numeric types (int and float) are compatible and normalized to Double for merging
- Different type categories are not merged (e.g., `json["a"] > 10` and `json["a"] > "hello"` remain separate)
- Bool literals are not optimized (no meaningful ranges)
- Literal compatibility:
- Integer columns require integer literals (e.g., `Int64Field > 10`)
- Float/Double columns accept both integer and float literals (e.g., `FloatField > 10` or `> 10.5`)
- Column identity:
- Merges only happen within the same `ColumnInfo` (including nested path and element index). For example, `ArrayInt[0]` and `ArrayInt[1]` are different columns and are not merged with each other.
- BinaryRangeExpr merging:
- AND: Merge multiple `BinaryRangeExpr` nodes on the same column to compute intersection (max lower, min upper)
- `(10 < x < 50) AND (20 < x < 40)` → `(20 < x < 40)`
- Empty intersection → constant `false`
- AND with UnaryRangeExpr: Update appropriate bound of `BinaryRangeExpr`
- `(10 < x < 50) AND (x > 30)` → `(30 < x < 50)`
- OR: Merge overlapping or adjacent `BinaryRangeExpr` nodes into wider interval
- `(10 < x < 25) OR (20 < x < 40)` → `(10 < x < 40)` (overlapping)
- `(10 < x <= 20) OR (20 <= x < 30)` → `(10 < x < 30)` (adjacent with inclusive)
- Disjoint intervals remain separate: `(10 < x < 20) OR (30 < x < 40)` → remains as OR
- Inclusivity handling: AND prefers exclusive on equal bounds (stronger), OR prefers inclusive (weaker)
### General Notes
- All merges require operands to target the same column (same `ColumnInfo`, including nested path/element type).
- Rewrite runs after template value filling; template placeholders do not appear here.
- Sorting/dedup for IN/NOT IN is deterministic; duplicates are removed post-sort.
- Numeric-threshold for OR→IN / AND≠→NOT IN is defined in `util.go` (`defaultConvertOrToInNumericLimit`, default 150).
### Pass Ordering (current)
- OR branch:
1. Flatten
2. OR `==` → IN
3. TEXT_MATCH merge (no options)
4. Range weaken (same-direction bounds)
5. BinaryRangeExpr merge (overlapping/adjacent intervals)
6. IN with `!=` short-circuiting
7. IN ∪ IN union
8. IN vs Equal redundancy elimination
9. Fold back to BinaryExpr
- AND branch:
1. Flatten
2. Range tighten / interval construction
3. BinaryRangeExpr merge (intersection, also with UnaryRangeExpr)
4. IN ∩ IN intersection (if any)
5. IN with `!=` filtering
6. IN ∩ range filtering
7. IN vs Equal redundancy elimination
8. AND `!=` → NOT IN
9. Fold back to BinaryExpr
Each construction of IN will be normalized (sorted and deduplicated). TEXT_MATCH OR merge concatenates literals with a single space; no tokenization, deduplication, or sorting is performed.
### File Structure
- `entry.go` — rewrite entry and visitor orchestration
- `util.go` — shared helpers (column keying, value classification, sorting/dedup, constructors)
- `term_in.go` — IN/NOT IN normalization and conversions
- `text_match.go` — TEXT_MATCH OR merge (no options)
- `range.go` — range tightening/weakening and interval construction
### Future Extensions
- More IN-range algebra (e.g., `IN` vs exact equality propagation across subtrees).
- Merging phrase_match or other string ops with clearly-defined token rules.
- More algebraic simplifications around equality and null checks:
- Contradiction detection: `(a == 1) AND (a == 2)` → `false`; `(a > 10) AND (a == 5)` → `false`
- Tautology detection: `(a > 10) OR (a <= 10)` → `true` (for non-NULL values)
- Absorption laws: `(a > 10) OR ((a > 10) AND (b > 20))` → `a > 10`
- Advanced BinaryRangeExpr merging:
- OR with 3+ intervals: Currently limited to 2 intervals. Full interval merging algorithm needed for `(10 < x < 20) OR (15 < x < 25) OR (22 < x < 30)` → `(10 < x < 30)`.
- OR with unbounded + bounded: Currently skipped. Could optimize `(x > 10) OR (5 < x < 15)` → `x > 5`.

View File

@ -0,0 +1,200 @@
package rewriter
import (
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// RewriteExpr applies the rule-based rewrites to e, reading the global
// `common.enabledOptimizeExpr` switch to decide whether the optional
// optimization passes run.
func RewriteExpr(e *planpb.Expr) *planpb.Expr {
	enabled := paramtable.Get().CommonCfg.EnabledOptimizeExpr.GetAsBool()
	return RewriteExprWithConfig(e, enabled)
}
// RewriteExprWithConfig rewrites e with an explicit optimization switch,
// useful for tests and special cases. A nil input yields nil; if the visitor
// produces anything other than a non-nil *planpb.Expr, the original
// expression is returned unchanged.
func RewriteExprWithConfig(e *planpb.Expr, optimizeEnabled bool) *planpb.Expr {
	if e == nil {
		return nil
	}
	rewritten := (&visitor{optimizeEnabled: optimizeEnabled}).visitExpr(e)
	if expr, ok := rewritten.(*planpb.Expr); ok && expr != nil {
		return expr
	}
	return e
}
// visitor walks a planpb.Expr tree bottom-up and applies rewrite rules.
type visitor struct {
	// optimizeEnabled gates the optional optimization passes; IN/NOT IN
	// value-list sorting and dedup always runs regardless of this flag.
	optimizeEnabled bool
}
// visitExpr dispatches to the rewrite handler for the expression's concrete
// node kind; node kinds without rewrite rules are returned unchanged.
func (v *visitor) visitExpr(expr *planpb.Expr) interface{} {
	switch node := expr.GetExpr().(type) {
	case *planpb.Expr_BinaryExpr:
		return v.visitBinaryExpr(node.BinaryExpr)
	case *planpb.Expr_UnaryExpr:
		return v.visitUnaryExpr(node.UnaryExpr)
	case *planpb.Expr_TermExpr:
		return v.visitTermExpr(node.TermExpr)
	default:
		// No optimization exists for other expression types.
		return expr
	}
}
// visitBinaryExpr rewrites both children first, then — for AND/OR — flattens
// the nested tree into a flat operand list, runs the optimization passes in a
// fixed order (earlier passes may create IN nodes that later passes merge
// further; see README "Pass Ordering"), and folds the list back into a
// left-deep binary tree. Non-logical binary ops are rebuilt with the
// (possibly rewritten) children only.
func (v *visitor) visitBinaryExpr(expr *planpb.BinaryExpr) interface{} {
	left := v.visitExpr(expr.GetLeft()).(*planpb.Expr)
	right := v.visitExpr(expr.GetRight()).(*planpb.Expr)
	switch expr.GetOp() {
	case planpb.BinaryExpr_LogicalOr:
		parts := flattenOr(left, right)
		if v.optimizeEnabled {
			parts = v.combineOrEqualsToIn(parts)
			parts = v.combineOrTextMatchToMerged(parts)
			parts = v.combineOrRangePredicates(parts)
			parts = v.combineOrBinaryRanges(parts)
			parts = v.combineOrInWithNotEqual(parts)
			parts = v.combineOrInWithIn(parts)
			parts = v.combineOrInWithEqual(parts)
		}
		return foldBinary(planpb.BinaryExpr_LogicalOr, parts)
	case planpb.BinaryExpr_LogicalAnd:
		parts := flattenAnd(left, right)
		if v.optimizeEnabled {
			parts = v.combineAndRangePredicates(parts)
			parts = v.combineAndBinaryRanges(parts)
			parts = v.combineAndInWithIn(parts)
			parts = v.combineAndInWithNotEqual(parts)
			parts = v.combineAndInWithRange(parts)
			parts = v.combineAndInWithEqual(parts)
			parts = v.combineAndNotEqualsToNotIn(parts)
		}
		return foldBinary(planpb.BinaryExpr_LogicalAnd, parts)
	default:
		// Non-logical binary ops carry no rewrite rules; rebuild as-is.
		return &planpb.Expr{
			Expr: &planpb.Expr_BinaryExpr{
				BinaryExpr: &planpb.BinaryExpr{
					Left:  left,
					Right: right,
					Op:    expr.GetOp(),
				},
			},
		}
	}
}
// visitUnaryExpr rewrites the child and folds double negation into a
// constant. The existing comment suggests the "always false" constant is
// represented as NOT(AlwaysTrue) (so NOT of it is a double negation);
// NOTE(review): confirm against IsAlwaysFalseExpr/newAlwaysFalseExpr in
// util.go — this representation is inferred, not visible here.
func (v *visitor) visitUnaryExpr(expr *planpb.UnaryExpr) interface{} {
	child := v.visitExpr(expr.GetChild()).(*planpb.Expr)
	// Optimize double negation: NOT (NOT AlwaysTrue) → AlwaysTrue
	if expr.GetOp() == planpb.UnaryExpr_Not {
		if IsAlwaysFalseExpr(child) {
			return newAlwaysTrueExpr()
		}
	}
	// Rebuild the unary node around the rewritten child.
	return &planpb.Expr{
		Expr: &planpb.Expr_UnaryExpr{
			UnaryExpr: &planpb.UnaryExpr{
				Op:    expr.GetOp(),
				Child: child,
			},
		},
	}
}
// visitTermExpr normalizes an IN/NOT IN node by sorting and deduplicating its
// value list. This always runs, independent of optimizeEnabled, because the
// execution engine depends on sorted value lists.
func (v *visitor) visitTermExpr(expr *planpb.TermExpr) interface{} {
	sortTermValues(expr)
	return &planpb.Expr{Expr: &planpb.Expr_TermExpr{TermExpr: expr}}
}
// flattenOr collects the leaves of two OR subtrees into one flat operand
// list, preserving left-to-right order.
func flattenOr(a, b *planpb.Expr) []*planpb.Expr {
	parts := make([]*planpb.Expr, 0, 4)
	collectOr(a, &parts)
	collectOr(b, &parts)
	return parts
}
// collectOr appends e's non-OR leaves to out, recursing through nested
// LogicalOr nodes depth-first (left before right).
func collectOr(e *planpb.Expr, out *[]*planpb.Expr) {
	be := e.GetBinaryExpr()
	if be == nil || be.GetOp() != planpb.BinaryExpr_LogicalOr {
		*out = append(*out, e)
		return
	}
	collectOr(be.GetLeft(), out)
	collectOr(be.GetRight(), out)
}
// flattenAnd collects the leaves of two AND subtrees into one flat operand
// list, preserving left-to-right order.
func flattenAnd(a, b *planpb.Expr) []*planpb.Expr {
	parts := make([]*planpb.Expr, 0, 4)
	collectAnd(a, &parts)
	collectAnd(b, &parts)
	return parts
}
// collectAnd appends e's non-AND leaves to out, recursing through nested
// LogicalAnd nodes depth-first (left before right).
func collectAnd(e *planpb.Expr, out *[]*planpb.Expr) {
	be := e.GetBinaryExpr()
	if be == nil || be.GetOp() != planpb.BinaryExpr_LogicalAnd {
		*out = append(*out, e)
		return
	}
	collectAnd(be.GetLeft(), out)
	collectAnd(be.GetRight(), out)
}
// foldBinary folds a flat operand list back into a left-deep binary tree of
// op. For AND/OR it first applies constant folding in a single pass:
//   - AND: any AlwaysFalse operand absorbs the whole expression to
//     AlwaysFalse; AlwaysTrue operands are dropped (identity element).
//   - OR: any AlwaysTrue operand absorbs to AlwaysTrue; AlwaysFalse
//     operands are dropped.
//
// Returns nil for an empty list and the sole operand for a singleton list.
func foldBinary(op planpb.BinaryExpr_BinaryOp, exprs []*planpb.Expr) *planpb.Expr {
	if len(exprs) == 0 {
		return nil
	}
	// Handle AlwaysTrue and AlwaysFalse optimizations (single-pass)
	switch op {
	case planpb.BinaryExpr_LogicalAnd:
		filtered := make([]*planpb.Expr, 0, len(exprs))
		for _, e := range exprs {
			if IsAlwaysFalseExpr(e) {
				// AND: any AlwaysFalse → entire expression is AlwaysFalse
				return newAlwaysFalseExpr()
			}
			if !IsAlwaysTrueExpr(e) {
				// Filter out AlwaysTrue (since AlwaysTrue AND X = X)
				filtered = append(filtered, e)
			}
		}
		exprs = filtered
		// If all were AlwaysTrue, return AlwaysTrue
		if len(exprs) == 0 {
			return newAlwaysTrueExpr()
		}
	case planpb.BinaryExpr_LogicalOr:
		filtered := make([]*planpb.Expr, 0, len(exprs))
		for _, e := range exprs {
			if IsAlwaysTrueExpr(e) {
				// OR: any AlwaysTrue → entire expression is AlwaysTrue
				return newAlwaysTrueExpr()
			}
			if !IsAlwaysFalseExpr(e) {
				// Filter out AlwaysFalse (since AlwaysFalse OR X = X)
				filtered = append(filtered, e)
			}
		}
		exprs = filtered
		// If all were AlwaysFalse, return AlwaysFalse
		if len(exprs) == 0 {
			return newAlwaysFalseExpr()
		}
	}
	if len(exprs) == 1 {
		return exprs[0]
	}
	// Left-deep fold preserves the operands' relative order.
	cur := exprs[0]
	for i := 1; i < len(exprs); i++ {
		cur = &planpb.Expr{
			Expr: &planpb.Expr_BinaryExpr{
				BinaryExpr: &planpb.BinaryExpr{
					Left:  cur,
					Right: exprs[i],
					Op:    op,
				},
			},
		}
	}
	return cur
}

View File

@ -0,0 +1,923 @@
package rewriter
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)
// bound records one side of a comparison predicate collected from a
// UnaryRangeExpr, plus where it came from in the operand list.
type bound struct {
	value     *planpb.GenericValue // literal the column is compared against
	inclusive bool                 // true for >= / <=, false for > / <
	isLower   bool                 // true for > / >=, false for < / <=
	exprIndex int                  // index of the source expr in the parts slice
}
// isSupportedScalarForRange reports whether dt is a scalar type eligible for
// range-predicate optimization (integers, floats, and VarChar).
func isSupportedScalarForRange(dt schemapb.DataType) bool {
	switch dt {
	case schemapb.DataType_Int8, schemapb.DataType_Int16,
		schemapb.DataType_Int32, schemapb.DataType_Int64,
		schemapb.DataType_Float, schemapb.DataType_Double,
		schemapb.DataType_VarChar:
		return true
	}
	return false
}
// resolveEffectiveType returns (dt, ok) where ok indicates this column is
// eligible for range optimization: a supported scalar, an array whose element
// type is a supported scalar, or a JSON field with a nested path (in which
// case DataType_JSON is returned as a placeholder and the concrete type is
// derived later from the literal via resolveJSONEffectiveType).
func resolveEffectiveType(col *planpb.ColumnInfo) (schemapb.DataType, bool) {
	if col == nil {
		return schemapb.DataType_None, false
	}
	dt := col.GetDataType()
	switch {
	case isSupportedScalarForRange(dt):
		return dt, true
	case dt == schemapb.DataType_Array && isSupportedScalarForRange(col.GetElementType()):
		return col.GetElementType(), true
	case dt == schemapb.DataType_JSON && len(col.GetNestedPath()) > 0:
		// Placeholder: actual type checking happens in resolveJSONEffectiveType.
		return schemapb.DataType_JSON, true
	}
	return schemapb.DataType_None, false
}
// resolveJSONEffectiveType derives the effective comparison type for a JSON
// field from the literal value. Numeric literals (int and float) are both
// normalized to Double so they can mix, mirroring scalar Float/Double
// columns. Bool literals (no meaningful ranges) and nil values return
// (None, false).
func resolveJSONEffectiveType(v *planpb.GenericValue) (schemapb.DataType, bool) {
	if v == nil {
		return schemapb.DataType_None, false
	}
	switch v.GetVal().(type) {
	case *planpb.GenericValue_Int64Val, *planpb.GenericValue_FloatVal:
		// Normalize both numeric literal kinds to Double.
		return schemapb.DataType_Double, true
	case *planpb.GenericValue_StringVal:
		return schemapb.DataType_VarChar, true
	default:
		// Covers bool literals, a nil Val, and any other kinds.
		return schemapb.DataType_None, false
	}
}
// valueMatchesType reports whether literal v is compatible with effective
// column type dt: integer columns require integer literals, float columns
// accept int or float literals, VarChar requires strings, and JSON is
// accepted whenever a valid effective type can be derived from the value.
func valueMatchesType(dt schemapb.DataType, v *planpb.GenericValue) bool {
	if v == nil || v.GetVal() == nil {
		return false
	}
	switch dt {
	case schemapb.DataType_Int8, schemapb.DataType_Int16,
		schemapb.DataType_Int32, schemapb.DataType_Int64:
		_, isInt := v.GetVal().(*planpb.GenericValue_Int64Val)
		return isInt
	case schemapb.DataType_Float, schemapb.DataType_Double:
		// Float columns accept both float and int literals.
		switch v.GetVal().(type) {
		case *planpb.GenericValue_FloatVal, *planpb.GenericValue_Int64Val:
			return true
		}
		return false
	case schemapb.DataType_VarChar:
		_, isStr := v.GetVal().(*planpb.GenericValue_StringVal)
		return isStr
	case schemapb.DataType_JSON:
		_, ok := resolveJSONEffectiveType(v)
		return ok
	}
	return false
}
// combineAndRangePredicates tightens same-column unary range predicates that
// are ANDed together: the strongest lower bound (largest value; exclusive
// wins ties) and the strongest upper bound (smallest value; exclusive wins
// ties) survive. When both directions are present they are fused into a
// BinaryRangeExpr, or into constant false when the resulting interval is
// empty. Non-eligible operands pass through unchanged and are emitted first.
func (v *visitor) combineAndRangePredicates(parts []*planpb.Expr) []*planpb.Expr {
	type group struct {
		col    *planpb.ColumnInfo
		effDt  schemapb.DataType // effective type for comparison
		lowers []bound
		uppers []bound
	}
	groups := map[string]*group{}
	// exprs not eligible for range optimization
	others := []int{}
	isRangeOp := func(op planpb.OpType) bool {
		return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
			op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
	}
	for idx, e := range parts {
		u := e.GetUnaryRangeExpr()
		if u == nil || !isRangeOp(u.GetOp()) || u.GetValue() == nil {
			others = append(others, idx)
			continue
		}
		col := u.GetColumnInfo()
		if col == nil {
			others = append(others, idx)
			continue
		}
		// Only optimize for supported types and matching value type
		effDt, ok := resolveEffectiveType(col)
		if !ok || !valueMatchesType(effDt, u.GetValue()) {
			others = append(others, idx)
			continue
		}
		// For JSON fields, determine the actual effective type from the literal value
		if effDt == schemapb.DataType_JSON {
			var typeOk bool
			effDt, typeOk = resolveJSONEffectiveType(u.GetValue())
			if !typeOk {
				others = append(others, idx)
				continue
			}
		}
		// Group by column + effective type (for JSON, type depends on literal)
		key := columnKey(col) + fmt.Sprintf("|%d", effDt)
		g, ok := groups[key]
		if !ok {
			g = &group{col: col, effDt: effDt}
			groups[key] = g
		}
		b := bound{
			value:     u.GetValue(),
			inclusive: u.GetOp() == planpb.OpType_GreaterEqual || u.GetOp() == planpb.OpType_LessEqual,
			isLower:   u.GetOp() == planpb.OpType_GreaterThan || u.GetOp() == planpb.OpType_GreaterEqual,
			exprIndex: idx,
		}
		if b.isLower {
			g.lowers = append(g.lowers, b)
		} else {
			g.uppers = append(g.uppers, b)
		}
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		// Use the effective type stored in the group.
		// Strongest lower bound: largest value; on equal values exclusive (>)
		// beats inclusive (>=) since it admits fewer rows under AND.
		var bestLower *bound
		for i := range g.lowers {
			if bestLower == nil || cmpGeneric(g.effDt, g.lowers[i].value, bestLower.value) > 0 ||
				(cmpGeneric(g.effDt, g.lowers[i].value, bestLower.value) == 0 && !g.lowers[i].inclusive && bestLower.inclusive) {
				b := g.lowers[i]
				bestLower = &b
			}
		}
		// Strongest upper bound: smallest value; exclusive beats inclusive on ties.
		var bestUpper *bound
		for i := range g.uppers {
			if bestUpper == nil || cmpGeneric(g.effDt, g.uppers[i].value, bestUpper.value) < 0 ||
				(cmpGeneric(g.effDt, g.uppers[i].value, bestUpper.value) == 0 && !g.uppers[i].inclusive && bestUpper.inclusive) {
				b := g.uppers[i]
				bestUpper = &b
			}
		}
		if bestLower != nil && bestUpper != nil {
			// Check if the interval is valid (non-empty)
			c := cmpGeneric(g.effDt, bestLower.value, bestUpper.value)
			isEmpty := false
			if c > 0 {
				// lower > upper: always empty
				isEmpty = true
			} else if c == 0 {
				// lower == upper: only valid if both bounds are inclusive
				if !bestLower.inclusive || !bestUpper.inclusive {
					isEmpty = true
				}
			}
			for _, b := range g.lowers {
				used[b.exprIndex] = true
			}
			for _, b := range g.uppers {
				used[b.exprIndex] = true
			}
			if isEmpty {
				// Empty interval → constant false
				out = append(out, newAlwaysFalseExpr())
			} else {
				out = append(out, newBinaryRangeExpr(g.col, bestLower.inclusive, bestUpper.inclusive, bestLower.value, bestUpper.value))
			}
		} else if bestLower != nil {
			for _, b := range g.lowers {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_GreaterThan
			if bestLower.inclusive {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, bestLower.value))
		} else if bestUpper != nil {
			for _, b := range g.uppers {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_LessThan
			if bestUpper.inclusive {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, bestUpper.value))
		}
	}
	// Keep any remaining operands in their original order.
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// combineOrRangePredicates weakens same-column, same-direction unary range
// predicates joined by OR: among several lower bounds, keep the weakest
// (smallest value; inclusive wins ties); among several upper bounds, keep the
// weakest (largest value; inclusive wins ties). Mixed-direction bounds
// (a lower and an upper on the same column) are never merged under OR.
func (v *visitor) combineOrRangePredicates(parts []*planpb.Expr) []*planpb.Expr {
	type key struct {
		colKey  string
		isLower bool
		effDt   schemapb.DataType // effective type for JSON fields
	}
	type group struct {
		col      *planpb.ColumnInfo
		effDt    schemapb.DataType
		dirLower bool
		bounds   []bound
	}
	groups := map[key]*group{}
	others := []int{}
	isRangeOp := func(op planpb.OpType) bool {
		return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
			op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
	}
	for idx, e := range parts {
		u := e.GetUnaryRangeExpr()
		if u == nil || !isRangeOp(u.GetOp()) || u.GetValue() == nil {
			others = append(others, idx)
			continue
		}
		col := u.GetColumnInfo()
		if col == nil {
			others = append(others, idx)
			continue
		}
		effDt, ok := resolveEffectiveType(col)
		if !ok || !valueMatchesType(effDt, u.GetValue()) {
			others = append(others, idx)
			continue
		}
		// For JSON fields, determine the actual effective type from the literal value
		if effDt == schemapb.DataType_JSON {
			var typeOk bool
			effDt, typeOk = resolveJSONEffectiveType(u.GetValue())
			if !typeOk {
				others = append(others, idx)
				continue
			}
		}
		// Group by (column, direction, effective type) — only same-direction
		// bounds on the same column may be merged under OR.
		isLower := u.GetOp() == planpb.OpType_GreaterThan || u.GetOp() == planpb.OpType_GreaterEqual
		k := key{colKey: columnKey(col), isLower: isLower, effDt: effDt}
		g, ok := groups[k]
		if !ok {
			g = &group{col: col, effDt: effDt, dirLower: isLower}
			groups[k] = g
		}
		g.bounds = append(g.bounds, bound{
			value:     u.GetValue(),
			inclusive: u.GetOp() == planpb.OpType_GreaterEqual || u.GetOp() == planpb.OpType_LessEqual,
			isLower:   isLower,
			exprIndex: idx,
		})
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.bounds) <= 1 {
			// Nothing to merge; the lone bound is emitted by the tail loop below.
			continue
		}
		if g.dirLower {
			// Weakest lower bound: smallest value; inclusive beats exclusive
			// on ties since it admits more rows under OR.
			var best *bound
			for i := range g.bounds {
				// Use the effective type stored in the group
				if best == nil || cmpGeneric(g.effDt, g.bounds[i].value, best.value) < 0 ||
					(cmpGeneric(g.effDt, g.bounds[i].value, best.value) == 0 && g.bounds[i].inclusive && !best.inclusive) {
					b := g.bounds[i]
					best = &b
				}
			}
			for _, b := range g.bounds {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_GreaterThan
			if best.inclusive {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, best.value))
		} else {
			// Weakest upper bound: largest value; inclusive beats exclusive on ties.
			var best *bound
			for i := range g.bounds {
				// Use the effective type stored in the group
				if best == nil || cmpGeneric(g.effDt, g.bounds[i].value, best.value) > 0 ||
					(cmpGeneric(g.effDt, g.bounds[i].value, best.value) == 0 && g.bounds[i].inclusive && !best.inclusive) {
					b := g.bounds[i]
					best = &b
				}
			}
			for _, b := range g.bounds {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_LessThan
			if best.inclusive {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, best.value))
		}
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// newBinaryRangeExpr builds a two-sided range node (lower ≶ col ≶ upper)
// with the given per-side inclusivity.
func newBinaryRangeExpr(col *planpb.ColumnInfo, lowerInclusive bool, upperInclusive bool, lower *planpb.GenericValue, upper *planpb.GenericValue) *planpb.Expr {
	rangeExpr := &planpb.BinaryRangeExpr{
		ColumnInfo:     col,
		LowerInclusive: lowerInclusive,
		UpperInclusive: upperInclusive,
		LowerValue:     lower,
		UpperValue:     upper,
	}
	return &planpb.Expr{
		Expr: &planpb.Expr_BinaryRangeExpr{BinaryRangeExpr: rangeExpr},
	}
}
// cmpGeneric compares two literals under effective type dt:
// -1 means a < b, 0 means a == b, 1 means a > b.
// Callers gate inputs with resolveEffectiveType/valueMatchesType, so
// unsupported types deterministically compare equal.
func cmpGeneric(dt schemapb.DataType, a, b *planpb.GenericValue) int {
	switch dt {
	case schemapb.DataType_Int8,
		schemapb.DataType_Int16,
		schemapb.DataType_Int32,
		schemapb.DataType_Int64:
		av, bv := a.GetInt64Val(), b.GetInt64Val()
		switch {
		case av < bv:
			return -1
		case av > bv:
			return 1
		}
		return 0
	case schemapb.DataType_Float, schemapb.DataType_Double:
		// int and float literals may mix on float columns; promote both to float64.
		asFloat := func(g *planpb.GenericValue) float64 {
			switch g.GetVal().(type) {
			case *planpb.GenericValue_FloatVal:
				return g.GetFloatVal()
			case *planpb.GenericValue_Int64Val:
				return float64(g.GetInt64Val())
			}
			// Should not happen due to gating; 0 keeps the result deterministic.
			return 0
		}
		av, bv := asFloat(a), asFloat(b)
		switch {
		case av < bv:
			return -1
		case av > bv:
			return 1
		}
		return 0
	case schemapb.DataType_String,
		schemapb.DataType_VarChar:
		av, bv := a.GetStringVal(), b.GetStringVal()
		switch {
		case av < bv:
			return -1
		case av > bv:
			return 1
		}
		return 0
	default:
		// Unsupported types are not optimized.
		return 0
	}
}
// combineAndBinaryRanges merges BinaryRangeExpr nodes with AND semantics (intersection).
// Also handles mixing BinaryRangeExpr with UnaryRangeExpr: a unary bound is
// treated as a half-open interval with one side nil (unbounded). Per column,
// the intersection keeps the maximum lower bound and minimum upper bound;
// an empty intersection collapses to constant false.
func (v *visitor) combineAndBinaryRanges(parts []*planpb.Expr) []*planpb.Expr {
	type interval struct {
		lower         *planpb.GenericValue // nil means unbounded below
		lowerInc      bool
		upper         *planpb.GenericValue // nil means unbounded above
		upperInc      bool
		exprIndex     int
		isBinaryRange bool
	}
	type group struct {
		col       *planpb.ColumnInfo
		effDt     schemapb.DataType
		intervals []interval
	}
	groups := map[string]*group{}
	others := []int{}
	for idx, e := range parts {
		// Try BinaryRangeExpr
		if bre := e.GetBinaryRangeExpr(); bre != nil {
			col := bre.GetColumnInfo()
			if col == nil {
				others = append(others, idx)
				continue
			}
			effDt, ok := resolveEffectiveType(col)
			if !ok {
				others = append(others, idx)
				continue
			}
			// For JSON, determine actual type from lower value
			if effDt == schemapb.DataType_JSON {
				var typeOk bool
				effDt, typeOk = resolveJSONEffectiveType(bre.GetLowerValue())
				if !typeOk {
					others = append(others, idx)
					continue
				}
			}
			key := columnKey(col) + fmt.Sprintf("|%d", effDt)
			g, exists := groups[key]
			if !exists {
				g = &group{col: col, effDt: effDt}
				groups[key] = g
			}
			g.intervals = append(g.intervals, interval{
				lower:         bre.GetLowerValue(),
				lowerInc:      bre.GetLowerInclusive(),
				upper:         bre.GetUpperValue(),
				upperInc:      bre.GetUpperInclusive(),
				exprIndex:     idx,
				isBinaryRange: true,
			})
			continue
		}
		// Try UnaryRangeExpr (range ops only)
		if ure := e.GetUnaryRangeExpr(); ure != nil {
			op := ure.GetOp()
			if op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
				op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual {
				col := ure.GetColumnInfo()
				if col == nil {
					others = append(others, idx)
					continue
				}
				effDt, ok := resolveEffectiveType(col)
				if !ok || !valueMatchesType(effDt, ure.GetValue()) {
					others = append(others, idx)
					continue
				}
				if effDt == schemapb.DataType_JSON {
					var typeOk bool
					effDt, typeOk = resolveJSONEffectiveType(ure.GetValue())
					if !typeOk {
						others = append(others, idx)
						continue
					}
				}
				key := columnKey(col) + fmt.Sprintf("|%d", effDt)
				g, exists := groups[key]
				if !exists {
					g = &group{col: col, effDt: effDt}
					groups[key] = g
				}
				// Represent the unary bound as a half-open interval.
				isLower := op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual
				inc := op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessEqual
				if isLower {
					g.intervals = append(g.intervals, interval{
						lower:         ure.GetValue(),
						lowerInc:      inc,
						upper:         nil,
						upperInc:      false,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				} else {
					g.intervals = append(g.intervals, interval{
						lower:         nil,
						lowerInc:      false,
						upper:         ure.GetValue(),
						upperInc:      inc,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				}
				continue
			}
		}
		// Not a range expr we can optimize
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.intervals) == 0 {
			continue
		}
		if len(g.intervals) == 1 {
			// Single interval, keep as is
			continue
		}
		// Compute intersection: max lower, min upper
		var finalLower *planpb.GenericValue
		var finalLowerInc bool
		var finalUpper *planpb.GenericValue
		var finalUpperInc bool
		for _, iv := range g.intervals {
			if iv.lower != nil {
				if finalLower == nil {
					finalLower = iv.lower
					finalLowerInc = iv.lowerInc
				} else {
					c := cmpGeneric(g.effDt, iv.lower, finalLower)
					// On equal bounds, exclusive wins (stronger under AND).
					if c > 0 || (c == 0 && !iv.lowerInc) {
						finalLower = iv.lower
						finalLowerInc = iv.lowerInc
					}
				}
			}
			if iv.upper != nil {
				if finalUpper == nil {
					finalUpper = iv.upper
					finalUpperInc = iv.upperInc
				} else {
					c := cmpGeneric(g.effDt, iv.upper, finalUpper)
					if c < 0 || (c == 0 && !iv.upperInc) {
						finalUpper = iv.upper
						finalUpperInc = iv.upperInc
					}
				}
			}
		}
		// Check if intersection is empty
		if finalLower != nil && finalUpper != nil {
			c := cmpGeneric(g.effDt, finalLower, finalUpper)
			isEmpty := false
			if c > 0 {
				isEmpty = true
			} else if c == 0 {
				// Equal bounds: only valid if both inclusive
				if !finalLowerInc || !finalUpperInc {
					isEmpty = true
				}
			}
			if isEmpty {
				// Empty intersection → constant false
				for _, iv := range g.intervals {
					used[iv.exprIndex] = true
				}
				out = append(out, newAlwaysFalseExpr())
				continue
			}
		}
		// Mark all intervals as used
		for _, iv := range g.intervals {
			used[iv.exprIndex] = true
		}
		// Emit the merged interval
		if finalLower != nil && finalUpper != nil {
			out = append(out, newBinaryRangeExpr(g.col, finalLowerInc, finalUpperInc, finalLower, finalUpper))
		} else if finalLower != nil {
			op := planpb.OpType_GreaterThan
			if finalLowerInc {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, finalLower))
		} else if finalUpper != nil {
			op := planpb.OpType_LessThan
			if finalUpperInc {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, finalUpper))
		}
	}
	// Add unused parts
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// combineOrBinaryRanges merges BinaryRangeExpr nodes with OR semantics (union if overlapping/adjacent).
// Also handles mixing BinaryRangeExpr with UnaryRangeExpr.
//
// Range predicates over the same column (and same effective data type) are
// grouped; two fully bounded intervals that overlap, or touch at an endpoint
// covered by at least one inclusive bound, are replaced by their union as a
// single BinaryRangeExpr. Groups that contain a half-bounded interval
// (x > a / x < b) and groups of 3+ intervals are left untouched for now.
// Expressions that cannot participate are passed through unchanged.
func (v *visitor) combineOrBinaryRanges(parts []*planpb.Expr) []*planpb.Expr {
	// interval is one harvested range; a nil lower/upper means unbounded on
	// that side (half-bounded intervals come from UnaryRangeExpr).
	type interval struct {
		lower         *planpb.GenericValue
		lowerInc      bool
		upper         *planpb.GenericValue
		upperInc      bool
		exprIndex     int  // index into parts of the originating expression
		isBinaryRange bool // true if harvested from a BinaryRangeExpr
	}
	type group struct {
		col       *planpb.ColumnInfo
		effDt     schemapb.DataType
		intervals []interval
	}
	groups := map[string]*group{}
	others := []int{}
	for idx, e := range parts {
		// Try BinaryRangeExpr
		if bre := e.GetBinaryRangeExpr(); bre != nil {
			col := bre.GetColumnInfo()
			if col == nil {
				others = append(others, idx)
				continue
			}
			effDt, ok := resolveEffectiveType(col)
			if !ok {
				others = append(others, idx)
				continue
			}
			if effDt == schemapb.DataType_JSON {
				// NOTE(review): the effective JSON type is derived from the
				// lower value only; assumes both bounds of a BinaryRangeExpr
				// share a type category — confirm against the planner.
				var typeOk bool
				effDt, typeOk = resolveJSONEffectiveType(bre.GetLowerValue())
				if !typeOk {
					others = append(others, idx)
					continue
				}
			}
			key := columnKey(col) + fmt.Sprintf("|%d", effDt)
			g, exists := groups[key]
			if !exists {
				g = &group{col: col, effDt: effDt}
				groups[key] = g
			}
			g.intervals = append(g.intervals, interval{
				lower:         bre.GetLowerValue(),
				lowerInc:      bre.GetLowerInclusive(),
				upper:         bre.GetUpperValue(),
				upperInc:      bre.GetUpperInclusive(),
				exprIndex:     idx,
				isBinaryRange: true,
			})
			continue
		}
		// Try UnaryRangeExpr (range comparison ops only)
		if ure := e.GetUnaryRangeExpr(); ure != nil {
			op := ure.GetOp()
			if op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
				op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual {
				col := ure.GetColumnInfo()
				if col == nil {
					others = append(others, idx)
					continue
				}
				effDt, ok := resolveEffectiveType(col)
				if !ok || !valueMatchesType(effDt, ure.GetValue()) {
					others = append(others, idx)
					continue
				}
				if effDt == schemapb.DataType_JSON {
					var typeOk bool
					effDt, typeOk = resolveJSONEffectiveType(ure.GetValue())
					if !typeOk {
						others = append(others, idx)
						continue
					}
				}
				key := columnKey(col) + fmt.Sprintf("|%d", effDt)
				g, exists := groups[key]
				if !exists {
					g = &group{col: col, effDt: effDt}
					groups[key] = g
				}
				isLower := op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual
				inc := op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessEqual
				if isLower {
					g.intervals = append(g.intervals, interval{
						lower:         ure.GetValue(),
						lowerInc:      inc,
						upper:         nil,
						upperInc:      false,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				} else {
					g.intervals = append(g.intervals, interval{
						lower:         nil,
						lowerInc:      false,
						upper:         ure.GetValue(),
						upperInc:      inc,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				}
				continue
			}
		}
		// Not a range expr we can optimize
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.intervals) == 0 {
			continue
		}
		if len(g.intervals) == 1 {
			// Single interval, keep as is
			continue
		}
		// If the group contains any half-bounded interval, skip merging: the
		// union of a half-bounded interval with anything else is not, in
		// general, expressible as a single range predicate. (Weakening of
		// same-direction unary bounds, e.g. x > 10 OR x > 20, is handled by a
		// separate pass.)
		var hasUnboundedLower, hasUnboundedUpper bool
		for _, iv := range g.intervals {
			if iv.lower != nil && iv.upper == nil {
				// Lower bound only (x > a)
				hasUnboundedLower = true
			}
			if iv.lower == nil && iv.upper != nil {
				// Upper bound only (x < b)
				hasUnboundedUpper = true
			}
		}
		if hasUnboundedLower || hasUnboundedUpper {
			continue
		}
		// All bounded intervals: try to merge overlapping/adjacent ones.
		// For simplicity in this initial implementation, only the
		// two-interval case is handled; 3+ interval chains would need a
		// sort-and-sweep and are left unoptimized for now.
		if len(g.intervals) == 2 {
			iv1, iv2 := g.intervals[0], g.intervals[1]
			if iv1.lower == nil || iv1.upper == nil || iv2.lower == nil || iv2.upper == nil {
				// One is not fully bounded, skip
				continue
			}
			// Order by lower bound so "first" starts no later than "second".
			var first, second interval
			c := cmpGeneric(g.effDt, iv1.lower, iv2.lower)
			if c <= 0 {
				first, second = iv1, iv2
			} else {
				first, second = iv2, iv1
			}
			// Mergeable iff the intervals overlap (first.upper > second.lower)
			// or touch at a point covered by at least one inclusive bound.
			cmpUpperLower := cmpGeneric(g.effDt, first.upper, second.lower)
			canMerge := false
			if cmpUpperLower > 0 {
				// Overlap
				canMerge = true
			} else if cmpUpperLower == 0 {
				// Adjacent: at least one bound must be inclusive
				if first.upperInc || second.lowerInc {
					canMerge = true
				}
			}
			if canMerge {
				// Union: min lower and max upper, preferring inclusive on ties.
				mergedLower := first.lower
				mergedLowerInc := first.lowerInc
				if c == 0 && second.lowerInc {
					// Equal lower bounds: the union contains the endpoint if
					// either interval includes it (mirrors the tie handling
					// for upper bounds below). Without this, e.g.
					// (10 <= x < 25) OR (10 < x < 30) would wrongly drop x=10.
					mergedLowerInc = true
				}
				mergedUpper := second.upper
				mergedUpperInc := second.upperInc
				// Upper bound: take maximum
				cmpUppers := cmpGeneric(g.effDt, first.upper, second.upper)
				if cmpUppers > 0 {
					mergedUpper = first.upper
					mergedUpperInc = first.upperInc
				} else if cmpUppers == 0 {
					// Same value: prefer inclusive
					if first.upperInc {
						mergedUpperInc = true
					}
				}
				// Consume both source predicates and emit the union.
				used[first.exprIndex] = true
				used[second.exprIndex] = true
				out = append(out, newBinaryRangeExpr(g.col, mergedLowerInc, mergedUpperInc, mergedLower, mergedUpper))
			}
		}
	}
	// Add unused parts
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}

View File

@ -0,0 +1,412 @@
package rewriter_test
import (
"testing"
"github.com/stretchr/testify/require"
parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)
// AND of two binary ranges collapses to their intersection.
func TestRewrite_BinaryRange_AND_BinaryRange_Intersection(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (20 < x < 40) → (20 < x < 40)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 20 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr, "should merge to single binary range")
	require.Equal(t, int64(20), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(40), rangeExpr.GetUpperValue().GetInt64Val())
	require.False(t, rangeExpr.GetLowerInclusive())
	require.False(t, rangeExpr.GetUpperInclusive())
}

// The tighter of two lower bounds wins under AND.
func TestRewrite_BinaryRange_AND_BinaryRange_TighterLower(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (5 < x < 40) → (10 < x < 40)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 5 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(10), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(40), rangeExpr.GetUpperValue().GetInt64Val())
}

// The tighter of two upper bounds wins under AND.
func TestRewrite_BinaryRange_AND_BinaryRange_TighterUpper(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (15 < x < 60) → (15 < x < 50)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 15 and Int64Field < 60)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(15), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), rangeExpr.GetUpperValue().GetInt64Val())
}

// Non-overlapping ranges ANDed together fold to constant false.
func TestRewrite_BinaryRange_AND_BinaryRange_EmptyIntersection(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) AND (30 < x < 40) → false
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) and (Int64Field > 30 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed))
}

// Intersection that degenerates to a single point stays valid when both bounds are inclusive.
func TestRewrite_BinaryRange_AND_BinaryRange_EqualBounds_BothInclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 <= x <= 10) AND (10 <= x <= 20) → (x == 10) which is (10 <= x <= 10)
	parsed, err := parser.ParseExpr(helper, `(Int64Field >= 10 and Int64Field <= 10) and (Int64Field >= 10 and Int64Field <= 20)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.True(t, rangeExpr.GetLowerInclusive())
	require.True(t, rangeExpr.GetUpperInclusive())
	require.Equal(t, int64(10), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(10), rangeExpr.GetUpperValue().GetInt64Val())
}
// Bounds meeting at a point with an exclusive side yield an empty set → constant false.
func TestRewrite_BinaryRange_AND_BinaryRange_EqualBounds_Exclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x <= 20) AND (5 <= x < 10) → false (bounds meet at 10 but exclusive)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) and (Int64Field >= 5 and Int64Field < 10)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed))
}

// A unary lower bound tightens an existing binary range under AND.
func TestRewrite_BinaryRange_AND_UnaryRange_TightenLower(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (x > 30) → (30 < x < 50)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and Int64Field > 30`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(30), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), rangeExpr.GetUpperValue().GetInt64Val())
}

// A unary upper bound tightens an existing binary range under AND.
func TestRewrite_BinaryRange_AND_UnaryRange_TightenUpper(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (x < 25) → (10 < x < 25)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and Int64Field < 25`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(10), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(25), rangeExpr.GetUpperValue().GetInt64Val())
}

// A weaker unary bound is absorbed without changing the binary range.
func TestRewrite_BinaryRange_AND_UnaryRange_WeakerBound(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (20 < x < 30) AND (x > 10) → (20 < x < 30) (10 is weaker than 20)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 20 and Int64Field < 30) and Int64Field > 10`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(20), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(30), rangeExpr.GetUpperValue().GetInt64Val())
}
// Overlapping binary ranges ORed together merge into their union.
func TestRewrite_BinaryRange_OR_BinaryRange_Overlapping(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 25) OR (20 < x < 40) → (10 < x < 40)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 25) or (Int64Field > 20 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr, "overlapping intervals should merge")
	require.Equal(t, int64(10), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(40), rangeExpr.GetUpperValue().GetInt64Val())
}

// Ranges that touch at an inclusive endpoint merge under OR.
func TestRewrite_BinaryRange_OR_BinaryRange_Adjacent(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x <= 20) OR (20 <= x < 30) → (10 < x < 30)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) or (Int64Field >= 20 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr, "adjacent intervals should merge")
	require.False(t, rangeExpr.GetLowerInclusive())
	require.False(t, rangeExpr.GetUpperInclusive())
	require.Equal(t, int64(10), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(30), rangeExpr.GetUpperValue().GetInt64Val())
}

// Disjoint ranges must stay as an OR of two predicates.
func TestRewrite_BinaryRange_OR_BinaryRange_Disjoint(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) OR (30 < x < 40) → remains as OR
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 30 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	boolExpr := parsed.GetBinaryExpr()
	require.NotNil(t, boolExpr, "disjoint intervals should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, boolExpr.GetOp())
}

// Touching at an endpoint that both sides exclude leaves a gap — no merge.
func TestRewrite_BinaryRange_OR_BinaryRange_AdjacentExclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) OR (20 < x < 30) → remains as OR (gap at 20)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 20 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	boolExpr := parsed.GetBinaryExpr()
	require.NotNil(t, boolExpr, "exclusive adjacent intervals should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, boolExpr.GetOp())
}
// Merging under OR keeps the inclusive variant of each surviving bound.
func TestRewrite_BinaryRange_OR_BinaryRange_PreferInclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 <= x < 25) OR (15 <= x <= 30) → (10 <= x <= 30)
	parsed, err := parser.ParseExpr(helper, `(Int64Field >= 10 and Int64Field < 25) or (Int64Field >= 15 and Int64Field <= 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	// Lower comes from the first interval (10), upper from the second (30);
	// inclusivity is preserved on both ends.
	require.True(t, rangeExpr.GetLowerInclusive())
	require.True(t, rangeExpr.GetUpperInclusive())
	require.Equal(t, int64(10), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(30), rangeExpr.GetUpperValue().GetInt64Val())
}

// Intersection under AND also works for float columns.
func TestRewrite_BinaryRange_AND_Float(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (1.0 < x < 5.0) AND (2.0 < x < 4.0) → (2.0 < x < 4.0)
	parsed, err := parser.ParseExpr(helper, `(FloatField > 1.0 and FloatField < 5.0) and (FloatField > 2.0 and FloatField < 4.0)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.InDelta(t, 2.0, rangeExpr.GetLowerValue().GetFloatVal(), 1e-9)
	require.InDelta(t, 4.0, rangeExpr.GetUpperValue().GetFloatVal(), 1e-9)
}

// Union under OR also works for varchar columns (lexicographic order).
func TestRewrite_BinaryRange_OR_VarChar_Overlapping(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// ("b" < x < "m") OR ("j" < x < "z") → ("b" < x < "z")
	parsed, err := parser.ParseExpr(helper, `(VarCharField > "b" and VarCharField < "m") or (VarCharField > "j" and VarCharField < "z")`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, "b", rangeExpr.GetLowerValue().GetStringVal())
	require.Equal(t, "z", rangeExpr.GetUpperValue().GetStringVal())
}
// AND intersection applies to JSON path predicates as well.
func TestRewrite_BinaryRange_AND_JSON(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	// (10 < json["price"] < 100) AND (20 < json["price"] < 80) → (20 < json["price"] < 80)
	parsed, err := parser.ParseExpr(helper, `(JSONField["price"] > 10 and JSONField["price"] < 100) and (JSONField["price"] > 20 and JSONField["price"] < 80)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(20), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(80), rangeExpr.GetUpperValue().GetInt64Val())
}

// OR union applies to overlapping JSON float ranges.
func TestRewrite_BinaryRange_OR_JSON_Overlapping(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	// (1.0 < json["score"] < 3.0) OR (2.5 < json["score"] < 5.0) → (1.0 < json["score"] < 5.0)
	parsed, err := parser.ParseExpr(helper, `(JSONField["score"] > 1.0 and JSONField["score"] < 3.0) or (JSONField["score"] > 2.5 and JSONField["score"] < 5.0)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.InDelta(t, 1.0, rangeExpr.GetLowerValue().GetFloatVal(), 1e-9)
	require.InDelta(t, 5.0, rangeExpr.GetUpperValue().GetFloatVal(), 1e-9)
}

// A binary range ANDed with several unary bounds collapses into one range.
func TestRewrite_BinaryRange_AND_MultipleUnary(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 100) AND (x > 20) AND (x < 80) → (20 < x < 80)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 100) and Int64Field > 20 and Int64Field < 80`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(20), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(80), rangeExpr.GetUpperValue().GetInt64Val())
}

// Three nested binary ranges ANDed together reduce to the tightest one.
func TestRewrite_BinaryRange_AND_ThreeBinaryRanges(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (5 < x < 100) AND (10 < x < 90) AND (15 < x < 80) → (15 < x < 80)
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 5 and Int64Field < 100) and (Int64Field > 10 and Int64Field < 90) and (Int64Field > 15 and Int64Field < 80)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.Equal(t, int64(15), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(80), rangeExpr.GetUpperValue().GetInt64Val())
}
// Ranges over distinct columns must never be combined.
func TestRewrite_BinaryRange_AND_DifferentColumns(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < Int64Field < 50) AND (20 < FloatField < 40) → both remain
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (FloatField > 20 and FloatField < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	boolExpr := parsed.GetBinaryExpr()
	require.NotNil(t, boolExpr, "different columns should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, boolExpr.GetOp())
}

// Distinct JSON paths on the same field are separate predicates — no merge.
func TestRewrite_BinaryRange_AND_JSON_DifferentPaths(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	// (10 < json["a"] < 50) AND (20 < json["b"] < 40) → both remain
	parsed, err := parser.ParseExpr(helper, `(JSONField["a"] > 10 and JSONField["a"] < 50) and (JSONField["b"] > 20 and JSONField["b"] < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	boolExpr := parsed.GetBinaryExpr()
	require.NotNil(t, boolExpr, "different JSON paths should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, boolExpr.GetOp())
}
// Documents a known limitation: OR merging handles only pairs of intervals,
// so three overlapping ranges are not guaranteed to collapse fully.
func TestRewrite_BinaryRange_OR_ThreeOverlapping_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) OR (15 < x < 25) OR (22 < x < 30)
	// Ideal result would be (10 < x < 30); current output depends on how the
	// expression tree was associated and may only merge a subset.
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 15 and Int64Field < 25) or (Int64Field > 22 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	// We only assert that parsing + rewriting succeed and yield a non-nil
	// plan; once full chain merging lands this should assert (10 < x < 30).
	require.NotNil(t, parsed)
	// Output may be a BinaryRangeExpr (partially merged) or an OR tree.
}

// Documents the same limitation with fully nested intervals.
func TestRewrite_BinaryRange_OR_ThreeFullyOverlapping(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 30) OR (12 < x < 28) OR (15 < x < 25)
	// The inner two are contained in the first; ideal result is (10 < x < 30).
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 30) or (Int64Field > 12 and Int64Field < 28) or (Int64Field > 15 and Int64Field < 25)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	// 3+ interval merging is not yet complete; only sanity-check the output.
	require.NotNil(t, parsed)
}

// Documents the limitation for a chain of four touching intervals.
func TestRewrite_BinaryRange_OR_FourAdjacent_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x <= 20) OR (20 < x <= 30) OR (30 < x <= 40) OR (40 < x <= 50)
	// Ideal result would be (10 < x <= 50).
	parsed, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) or (Int64Field > 20 and Int64Field <= 30) or (Int64Field > 30 and Int64Field <= 40) or (Int64Field > 40 and Int64Field <= 50)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	// Only adjacent pairs may merge today; full chain merging is future work.
	require.NotNil(t, parsed)
}
// Documents a known limitation: a half-bounded predicate ORed with a bounded
// range is not merged even when the union would be a single predicate.
func TestRewrite_BinaryRange_OR_UnboundedLower_Bounded_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x > 10) OR (5 < x < 15); ideal result would be (x > 5).
	parsed, err := parser.ParseExpr(helper, `Int64Field > 10 or (Int64Field > 5 and Int64Field < 15)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	// Today both predicates survive as an OR tree.
	boolExpr := parsed.GetBinaryExpr()
	require.NotNil(t, boolExpr, "should remain as OR (not merged)")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, boolExpr.GetOp())
}

// Same limitation for a half-bounded upper predicate.
func TestRewrite_BinaryRange_OR_UnboundedUpper_Bounded_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x < 20) OR (10 < x < 30); ideal result would be (x < 30).
	parsed, err := parser.ParseExpr(helper, `Int64Field < 20 or (Int64Field > 10 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	// Unbounded + bounded intervals are not optimized yet.
	boolExpr := parsed.GetBinaryExpr()
	require.NotNil(t, boolExpr, "should remain as OR (not merged)")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, boolExpr.GetOp())
}

// Half-bounded predicates in opposite directions also stay separate.
func TestRewrite_BinaryRange_OR_UnboundedBoth_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x > 10) OR (x < 20) covers everything when the ranges overlap,
	// but the rewriter keeps the predicates separate today.
	parsed, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field < 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	boolExpr := parsed.GetBinaryExpr()
	require.NotNil(t, boolExpr, "should remain as OR")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, boolExpr.GetOp())
}

// Same-direction unary bounds under OR DO merge: the weaker bound wins.
func TestRewrite_BinaryRange_OR_MultipleUnboundedLower(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x > 10) OR (x > 20) → (x > 10) [weaker bound]
	parsed, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field > 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	// Weakening for same-direction bounds is supported.
	unary := parsed.GetUnaryRangeExpr()
	require.NotNil(t, unary, "should merge to single unary range")
	require.Equal(t, planpb.OpType_GreaterThan, unary.GetOp())
	require.Equal(t, int64(10), unary.GetValue().GetInt64Val())
}

View File

@ -0,0 +1,294 @@
package rewriter_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// buildSchemaHelperWithJSON builds a schema helper for a collection with an
// Int64 field, a JSON field, and a dynamic ($meta) JSON field, for use by the
// JSON rewrite tests below.
func buildSchemaHelperWithJSON(t *testing.T) *typeutil.SchemaHelper {
	collection := &schemapb.CollectionSchema{
		Name:   "rewrite_json_test",
		AutoID: false,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 101, Name: "Int64Field", DataType: schemapb.DataType_Int64},
			{FieldID: 102, Name: "JSONField", DataType: schemapb.DataType_JSON},
			{FieldID: 103, Name: "$meta", DataType: schemapb.DataType_JSON, IsDynamic: true},
		},
		EnableDynamicField: true,
	}
	helper, err := typeutil.CreateSchemaHelper(collection)
	require.NoError(t, err)
	return helper
}
// Two lower bounds on the same JSON int path AND-merge to the stronger one.
func TestRewrite_JSON_Int_AND_Strengthen(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	parsed, err := parser.ParseExpr(helper, `JSONField["price"] > 10 and JSONField["price"] > 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	unary := parsed.GetUnaryRangeExpr()
	require.NotNil(t, unary, "should merge to single unary range")
	require.Equal(t, planpb.OpType_GreaterThan, unary.GetOp())
	require.Equal(t, int64(20), unary.GetValue().GetInt64Val())
	// Column info must still point at the JSON path.
	require.Equal(t, schemapb.DataType_JSON, unary.GetColumnInfo().GetDataType())
	require.Equal(t, []string{"price"}, unary.GetColumnInfo().GetNestedPath())
}

// A lower and an upper bound on the same JSON int path become a binary range.
func TestRewrite_JSON_Int_AND_ToBinaryRange(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	parsed, err := parser.ParseExpr(helper, `JSONField["price"] > 10 and JSONField["price"] < 50`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr, "should create binary range")
	require.False(t, rangeExpr.GetLowerInclusive())
	require.False(t, rangeExpr.GetUpperInclusive())
	require.Equal(t, int64(10), rangeExpr.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), rangeExpr.GetUpperValue().GetInt64Val())
	// Column info must still point at the JSON path.
	require.Equal(t, schemapb.DataType_JSON, rangeExpr.GetColumnInfo().GetDataType())
	require.Equal(t, []string{"price"}, rangeExpr.GetColumnInfo().GetNestedPath())
}
// Mixed int/float bounds on a JSON path still strengthen under AND.
func TestRewrite_JSON_Float_AND_Strengthen_Mixed(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	parsed, err := parser.ParseExpr(helper, `JSONField["score"] > 10 and JSONField["score"] > 15.5`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	unary := parsed.GetUnaryRangeExpr()
	require.NotNil(t, unary)
	require.Equal(t, planpb.OpType_GreaterThan, unary.GetOp())
	require.InDelta(t, 15.5, unary.GetValue().GetFloatVal(), 1e-9)
}

// String bounds on a JSON path strengthen under AND (lexicographic order).
func TestRewrite_JSON_String_AND_Strengthen(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	parsed, err := parser.ParseExpr(helper, `JSONField["name"] > "alice" and JSONField["name"] > "bob"`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	unary := parsed.GetUnaryRangeExpr()
	require.NotNil(t, unary)
	require.Equal(t, planpb.OpType_GreaterThan, unary.GetOp())
	require.Equal(t, "bob", unary.GetValue().GetStringVal())
}

// Inclusive string bounds on a JSON path fold into one binary range.
func TestRewrite_JSON_String_AND_ToBinaryRange(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	parsed, err := parser.ParseExpr(helper, `JSONField["name"] >= "alice" and JSONField["name"] <= "zebra"`, nil)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	rangeExpr := parsed.GetBinaryRangeExpr()
	require.NotNil(t, rangeExpr)
	require.True(t, rangeExpr.GetLowerInclusive())
	require.True(t, rangeExpr.GetUpperInclusive())
	require.Equal(t, "alice", rangeExpr.GetLowerValue().GetStringVal())
	require.Equal(t, "zebra", rangeExpr.GetUpperValue().GetStringVal())
}
// Test JSON field with OR - weakening lower bounds
func TestRewrite_JSON_Int_OR_Weaken(t *testing.T) {
helper := buildSchemaHelperWithJSON(t)
expr, err := parser.ParseExpr(helper, `JSONField["age"] > 10 or JSONField["age"] > 20`, nil)
require.NoError(t, err)
require.NotNil(t, expr)
ure := expr.GetUnaryRangeExpr()
require.NotNil(t, ure)
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
}
// OR of two string upper bounds on a JSON path keeps the weaker (larger) bound.
func TestRewrite_JSON_String_OR_Weaken_Upper(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["category"] < "electronics" or JSONField["category"] < "sports"`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_LessThan, ur.GetOp())
	require.Equal(t, "sports", ur.GetValue().GetStringVal())
}
// Int and float literals on the same JSON path are both numeric, so the two
// lower bounds merge into a single predicate carrying the stronger bound (20).
func TestRewrite_JSON_NumericTypes_Merge(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["value"] > 10.5 and JSONField["value"] > 20`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur, "numeric types should merge to single predicate")
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	// The surviving bound may be encoded either as an int or a float.
	switch ur.GetValue().GetVal().(type) {
	case *planpb.GenericValue_Int64Val:
		require.Equal(t, int64(20), ur.GetValue().GetInt64Val())
	case *planpb.GenericValue_FloatVal:
		require.InDelta(t, 20.0, ur.GetValue().GetFloatVal(), 1e-9)
	default:
		t.Fatalf("unexpected value type")
	}
}
// A numeric bound and a string bound on the same JSON path belong to
// different type categories and must stay separate under AND.
func TestRewrite_JSON_MixedTypes_NoMerge(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["value"] > 10 and JSONField["value"] > "hello"`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	bin := parsed.GetBinaryExpr()
	require.NotNil(t, bin, "should remain as binary expr (AND)")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, bin.GetOp())
	// Each side stays an unmerged unary range predicate.
	require.NotNil(t, bin.GetLeft().GetUnaryRangeExpr())
	require.NotNil(t, bin.GetRight().GetUnaryRangeExpr())
}
// A numeric bound OR a string bound on the same JSON path must not merge.
func TestRewrite_JSON_MixedTypes_OR_NoMerge(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["data"] > 100 or JSONField["data"] > "text"`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	bin := parsed.GetBinaryExpr()
	require.NotNil(t, bin, "should remain as binary expr (OR)")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, bin.GetOp())
}
// The dynamic field ($meta) is rewritten exactly like an explicit JSON field:
// AND of two lower bounds keeps the tighter one.
func TestRewrite_DynamicField_Int_AND_Strengthen(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `$meta["count"] > 5 and $meta["count"] > 15`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(15), ur.GetValue().GetInt64Val())
}
// Dynamic-field float bounds fold into an inclusive BinaryRangeExpr under AND.
func TestRewrite_DynamicField_Float_AND_ToBinaryRange(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `$meta["rating"] >= 1.0 and $meta["rating"] <= 5.0`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	br := parsed.GetBinaryRangeExpr()
	require.NotNil(t, br)
	require.True(t, br.GetLowerInclusive())
	require.True(t, br.GetUpperInclusive())
	require.InDelta(t, 1.0, br.GetLowerValue().GetFloatVal(), 1e-9)
	require.InDelta(t, 5.0, br.GetUpperValue().GetFloatVal(), 1e-9)
}
// Predicates on different JSON keys of the same field must not merge.
func TestRewrite_JSON_DifferentPaths_NoMerge(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["a"] > 10 and JSONField["b"] > 20`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	bin := parsed.GetBinaryExpr()
	require.NotNil(t, bin, "different paths should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, bin.GetOp())
}
// A multi-level JSON path merges like a flat one, and the optimized predicate
// preserves the full nested path.
func TestRewrite_JSON_NestedPath_AND_Strengthen(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["user"]["age"] > 18 and JSONField["user"]["age"] > 21`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(21), ur.GetValue().GetInt64Val())
	require.Equal(t, []string{"user", "age"}, ur.GetColumnInfo().GetNestedPath())
}
// With equal bound values under AND, the exclusive comparison wins:
// x >= 10 AND x > 10 → x > 10.
func TestRewrite_JSON_AND_EquivalentBounds(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["x"] >= 10 and JSONField["x"] > 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
}
// With equal bound values under OR, the inclusive comparison wins:
// x >= 10 OR x > 10 → x >= 10.
func TestRewrite_JSON_OR_EquivalentBounds(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `JSONField["x"] >= 10 or JSONField["x"] > 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterEqual, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
}
// A conjunction mixing a scalar column and a JSON path is optimized
// per-column: each column's two lower bounds collapse to the tighter one and
// the result stays an AND of exactly two predicates.
func TestRewrite_JSON_And_Scalar_Independent(t *testing.T) {
	sh := buildSchemaHelperWithJSON(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field > 10 and Int64Field > 20 and JSONField["price"] > 5 and JSONField["price"] > 15`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	bin := parsed.GetBinaryExpr()
	require.NotNil(t, bin)
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, bin.GetOp())
	// Flatten the AND tree into its non-AND leaves.
	var leaves []*planpb.Expr
	var walk func(*planpb.Expr)
	walk = func(node *planpb.Expr) {
		if inner := node.GetBinaryExpr(); inner != nil && inner.GetOp() == planpb.BinaryExpr_LogicalAnd {
			walk(inner.GetLeft())
			walk(inner.GetRight())
			return
		}
		leaves = append(leaves, node)
	}
	walk(parsed)
	require.Equal(t, 2, len(leaves), "should have two optimized predicates")
	// Each leaf must carry the stronger of its column's two bounds.
	var sawInt64, sawJSON bool
	for _, leaf := range leaves {
		ur := leaf.GetUnaryRangeExpr()
		if ur == nil {
			continue
		}
		switch ur.GetColumnInfo().GetDataType() {
		case schemapb.DataType_Int64:
			sawInt64 = true
			require.Equal(t, int64(20), ur.GetValue().GetInt64Val())
		case schemapb.DataType_JSON:
			sawJSON = true
			require.Equal(t, int64(15), ur.GetValue().GetInt64Val())
		}
	}
	require.True(t, sawInt64, "should have optimized Int64Field predicate")
	require.True(t, sawJSON, "should have optimized JSONField predicate")
}

View File

@ -0,0 +1,492 @@
package rewriter_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// Two lower bounds on one column under AND collapse to the tighter one.
func TestRewrite_Range_AND_Strengthen(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field > 10 and Int64Field > 20`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(20), ur.GetValue().GetInt64Val())
}
// Two upper bounds on one column under AND collapse to the tighter one.
func TestRewrite_Range_AND_Strengthen_Upper(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field < 50 and Int64Field < 60`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_LessThan, ur.GetOp())
	require.Equal(t, int64(50), ur.GetValue().GetInt64Val())
}
// With equal bound values under AND, the exclusive comparison wins on both
// the lower side (a >= x AND a > x → a > x) and the upper side
// (a <= y AND a < y → a < y).
func TestRewrite_Range_AND_EquivalentBounds(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field >= 10 and Int64Field > 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
	parsed, parseErr = parser.ParseExpr(sh, `Int64Field <= 10 and Int64Field < 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur = parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_LessThan, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
}
// Two lower bounds on one column under OR collapse to the weaker one.
func TestRewrite_Range_OR_Weaken(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field > 10 or Int64Field > 20`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
}
// Two upper bounds on one column under OR collapse to the weaker one.
func TestRewrite_Range_OR_Weaken_Upper(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field < 10 or Int64Field < 20`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_LessThan, ur.GetOp())
	require.Equal(t, int64(20), ur.GetValue().GetInt64Val())
}
// With equal bound values under OR, the inclusive comparison wins on both
// the lower side (a >= x OR a > x → a >= x) and the upper side
// (a <= y OR a < y → a <= y).
func TestRewrite_Range_OR_EquivalentBounds(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field >= 10 or Int64Field > 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterEqual, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
	parsed, parseErr = parser.ParseExpr(sh, `Int64Field <= 10 or Int64Field < 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur = parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_LessEqual, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
}
// A strict lower bound AND a strict upper bound fold into an exclusive
// BinaryRangeExpr.
func TestRewrite_Range_AND_ToBinaryRange(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field > 10 and Int64Field < 50`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	br := parsed.GetBinaryRangeExpr()
	require.NotNil(t, br)
	require.False(t, br.GetLowerInclusive())
	require.False(t, br.GetUpperInclusive())
	require.Equal(t, int64(10), br.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), br.GetUpperValue().GetInt64Val())
}
// An inclusive lower bound AND an inclusive upper bound fold into an
// inclusive BinaryRangeExpr.
func TestRewrite_Range_AND_ToBinaryRange_Inclusive(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field >= 10 and Int64Field <= 50`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	br := parsed.GetBinaryRangeExpr()
	require.NotNil(t, br)
	require.True(t, br.GetLowerInclusive())
	require.True(t, br.GetUpperInclusive())
	require.Equal(t, int64(10), br.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), br.GetUpperValue().GetInt64Val())
}
// A lower bound OR an upper bound cannot merge; both sides stay unary ranges.
func TestRewrite_Range_OR_MixedDirection_NoMerge(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field > 10 or Int64Field < 5`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	bin := parsed.GetBinaryExpr()
	require.NotNil(t, bin)
	require.Equal(t, planpb.BinaryExpr_LogicalOr, bin.GetOp())
	require.NotNil(t, bin.GetLeft().GetUnaryRangeExpr())
	require.NotNil(t, bin.GetRight().GetUnaryRangeExpr())
}
// Edge cases for Float/Double columns: int and float literals may mix.
// AND of two lower bounds keeps the tighter one.
func TestRewrite_Range_AND_Strengthen_Float_Mixed(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `FloatField > 10 and FloatField > 15.0`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.InDelta(t, 15.0, ur.GetValue().GetFloatVal(), 1e-9)
}
// Mixed int/float bounds on a float column still fold into a BinaryRangeExpr;
// the int-literal end may be encoded as either an int or a float value.
func TestRewrite_Range_AND_ToBinaryRange_Float_Mixed(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `FloatField > 10 and FloatField < 20.5`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	br := parsed.GetBinaryRangeExpr()
	require.NotNil(t, br)
	require.False(t, br.GetLowerInclusive())
	require.False(t, br.GetUpperInclusive())
	// The lower bound must equal 10 under either encoding.
	lower := br.GetLowerValue()
	switch lower.GetVal().(type) {
	case *planpb.GenericValue_Int64Val:
		require.Equal(t, int64(10), lower.GetInt64Val())
	case *planpb.GenericValue_FloatVal:
		require.InDelta(t, 10.0, lower.GetFloatVal(), 1e-9)
	default:
		t.Fatalf("unexpected lower value type")
	}
	// The upper bound comes from the float literal 20.5.
	require.InDelta(t, 20.5, br.GetUpperValue().GetFloatVal(), 1e-9)
}
// OR of mixed int/float lower bounds keeps the weaker bound (10), under
// either value encoding.
func TestRewrite_Range_OR_Weaken_Float_Mixed(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `FloatField > 10 or FloatField > 20.5`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	switch ur.GetValue().GetVal().(type) {
	case *planpb.GenericValue_Int64Val:
		require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
	case *planpb.GenericValue_FloatVal:
		require.InDelta(t, 10.0, ur.GetValue().GetFloatVal(), 1e-9)
	default:
		t.Fatalf("unexpected value type")
	}
}
// The parser rejects a float literal bound on an integer column, so the
// rewriter never has to optimize this combination. Kept as a guard test.
func TestRewrite_Range_AND_NoMerge_Int_WithFloatLiteral(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	_, parseErr := parser.ParseExpr(sh, `Int64Field > 10 and Int64Field > 15.0`, nil)
	require.Error(t, parseErr)
}
// Tie-breaking on equal float bounds: OR prefers the inclusive comparison,
// AND prefers the strict one.
func TestRewrite_Range_Tie_Inclusive_Float(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	// OR: >=10 or >10 -> >=10 (weaken prefers inclusive)
	parsed, parseErr := parser.ParseExpr(sh, `FloatField >= 10 or FloatField > 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterEqual, ur.GetOp())
	// AND: >=10 and >10 -> >10 (tighten prefers strict)
	parsed, parseErr = parser.ParseExpr(sh, `FloatField >= 10 and FloatField > 10`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur = parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
}
// VarChar range optimization: AND of two string lower bounds keeps the
// lexicographically stronger one.
func TestRewrite_Range_VarChar_AND_Strengthen(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `VarCharField > "a" and VarCharField > "b"`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, "b", ur.GetValue().GetStringVal())
}
// OR of two string upper bounds keeps the weaker (larger) one.
func TestRewrite_Range_VarChar_OR_Weaken_Upper(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `VarCharField < "m" or VarCharField < "z"`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_LessThan, ur.GetOp())
	require.Equal(t, "z", ur.GetValue().GetStringVal())
}
// buildSchemaHelperWithArraysT builds a schema helper with three array fields
// (int64, double, and varchar elements) for the array rewrite tests. Direct
// range comparison on whole arrays is rejected by the parser; if that ever
// becomes supported, the dependent tests should be updated.
func buildSchemaHelperWithArraysT(t *testing.T) *typeutil.SchemaHelper {
	collection := &schemapb.CollectionSchema{
		Name:   "rewrite_array_test",
		AutoID: false,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 201, Name: "ArrayInt", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
			{FieldID: 202, Name: "ArrayFloat", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double},
			{FieldID: 203, Name: "ArrayVarchar", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_VarChar},
		},
	}
	sh, err := typeutil.CreateSchemaHelper(collection)
	require.NoError(t, err)
	return sh
}
// Direct range comparison on a whole int array is rejected by the parser.
func TestRewrite_Range_Array_Int_NotSupported(t *testing.T) {
	sh := buildSchemaHelperWithArraysT(t)
	_, parseErr := parser.ParseExpr(sh, `ArrayInt > 10`, nil)
	require.Error(t, parseErr)
}
// Direct range comparison on a whole float array is rejected by the parser.
func TestRewrite_Range_Array_Float_NotSupported(t *testing.T) {
	sh := buildSchemaHelperWithArraysT(t)
	_, parseErr := parser.ParseExpr(sh, `ArrayFloat > 10.5`, nil)
	require.Error(t, parseErr)
}
// Direct range comparison on a whole varchar array is rejected by the parser.
func TestRewrite_Range_Array_VarChar_NotSupported(t *testing.T) {
	sh := buildSchemaHelperWithArraysT(t)
	_, parseErr := parser.ParseExpr(sh, `ArrayVarchar > "a"`, nil)
	require.Error(t, parseErr)
}
// Array index access optimizations: AND of two lower bounds on the same
// element index keeps the tighter one.
func TestRewrite_Range_ArrayInt_Index_AND_Strengthen(t *testing.T) {
	sh := buildSchemaHelperWithArraysT(t)
	parsed, parseErr := parser.ParseExpr(sh, `ArrayInt[0] > 10 and ArrayInt[0] > 20`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(20), ur.GetValue().GetInt64Val())
}
// OR of mixed int/float lower bounds on a float array element keeps the
// weaker bound (10), under either value encoding.
func TestRewrite_Range_ArrayFloat_Index_OR_Weaken_Mixed(t *testing.T) {
	sh := buildSchemaHelperWithArraysT(t)
	parsed, parseErr := parser.ParseExpr(sh, `ArrayFloat[0] > 10 or ArrayFloat[0] > 20.5`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	switch ur.GetValue().GetVal().(type) {
	case *planpb.GenericValue_Int64Val:
		require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
	case *planpb.GenericValue_FloatVal:
		require.InDelta(t, 10.0, ur.GetValue().GetFloatVal(), 1e-9)
	default:
		t.Fatalf("unexpected value type")
	}
}
// String bounds on the same varchar array element fold into an exclusive
// BinaryRangeExpr.
func TestRewrite_Range_ArrayVarChar_Index_ToBinaryRange(t *testing.T) {
	sh := buildSchemaHelperWithArraysT(t)
	parsed, parseErr := parser.ParseExpr(sh, `ArrayVarchar[0] > "a" and ArrayVarchar[0] < "m"`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	br := parsed.GetBinaryRangeExpr()
	require.NotNil(t, br)
	require.False(t, br.GetLowerInclusive())
	require.False(t, br.GetUpperInclusive())
	require.Equal(t, "a", br.GetLowerValue().GetStringVal())
	require.Equal(t, "m", br.GetUpperValue().GetStringVal())
}
// collectAndExprs flattens an AND tree into its non-AND leaves, appending
// them to out in left-to-right order.
func collectAndExprs(e *planpb.Expr, out *[]*planpb.Expr) {
	bin := e.GetBinaryExpr()
	if bin == nil || bin.GetOp() != planpb.BinaryExpr_LogicalAnd {
		*out = append(*out, e)
		return
	}
	collectAndExprs(bin.GetLeft(), out)
	collectAndExprs(bin.GetRight(), out)
}
// AND over one array field but different indices: only predicates sharing the
// exact index merge. The ArrayInt[0] bounds fold into one interval while the
// ArrayInt[1] bound stays a separate unary range.
func TestRewrite_Range_Array_Index_Different_NoMerge(t *testing.T) {
	sh := buildSchemaHelperWithArraysT(t)
	parsed, parseErr := parser.ParseExpr(sh, `ArrayInt[0] > 10 and ArrayInt[1] > 20 and ArrayInt[0] < 20`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	bin := parsed.GetBinaryExpr()
	require.NotNil(t, bin)
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, bin.GetOp())
	leaves := []*planpb.Expr{}
	collectAndExprs(parsed, &leaves)
	// Exactly two leaves survive: the interval on index 0 and the bound on index 1.
	require.Equal(t, 2, len(leaves))
	var sawInterval, sawLower bool
	for _, leaf := range leaves {
		switch {
		case leaf.GetBinaryRangeExpr() != nil:
			br := leaf.GetBinaryRangeExpr()
			sawInterval = true
			// 10 < ArrayInt[0] < 20, both ends exclusive.
			require.False(t, br.GetLowerInclusive())
			require.False(t, br.GetUpperInclusive())
			require.Equal(t, int64(10), br.GetLowerValue().GetInt64Val())
			require.Equal(t, int64(20), br.GetUpperValue().GetInt64Val())
		case leaf.GetUnaryRangeExpr() != nil:
			ur := leaf.GetUnaryRangeExpr()
			sawLower = true
			require.True(t, ur.GetOp() == planpb.OpType_GreaterThan || ur.GetOp() == planpb.OpType_GreaterEqual)
			// Bound value 20 on the index-1 lower side.
			require.Equal(t, int64(20), ur.GetValue().GetInt64Val())
		default:
			t.Fatalf("unexpected expr kind in AND parts")
		}
	}
	require.True(t, sawInterval)
	require.True(t, sawLower)
}
// A lower bound above the upper bound is an impossible range and collapses to
// a constant false.
func TestRewrite_Range_AND_InvalidRange_LowerGreaterThanUpper(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field > 100 and Int64Field < 50`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed))
}
// Equal bounds with both comparisons exclusive admit no value, so the
// conjunction collapses to a constant false.
func TestRewrite_Range_AND_InvalidRange_EqualBoundsExclusive(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field > 50 and Int64Field < 50`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed))
}
// Equal bounds with one exclusive comparison still admit no value, so the
// conjunction collapses to a constant false.
func TestRewrite_Range_AND_InvalidRange_EqualBoundsOneExclusive(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field >= 50 and Int64Field < 50`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed))
}
// Equal bounds with both comparisons inclusive describe the single value 50
// and remain a valid BinaryRangeExpr.
func TestRewrite_Range_AND_ValidRange_EqualBoundsBothInclusive(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `Int64Field >= 50 and Int64Field <= 50`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	br := parsed.GetBinaryRangeExpr()
	require.NotNil(t, br, "should create valid binary range for x == 50")
	require.True(t, br.GetLowerInclusive())
	require.True(t, br.GetUpperInclusive())
	require.Equal(t, int64(50), br.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), br.GetUpperValue().GetInt64Val())
}
// An impossible float range (lower above upper) collapses to a constant false.
func TestRewrite_Range_AND_InvalidRange_Float_LowerGreaterThanUpper(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `FloatField > 99.9 and FloatField < 10.5`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed))
}
// An impossible string range (lower above upper lexicographically) collapses
// to a constant false.
func TestRewrite_Range_AND_InvalidRange_String_LowerGreaterThanUpper(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `VarCharField > "zebra" and VarCharField < "apple"`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed))
}
// An AlwaysFalse produced deep inside nested ANDs must propagate upward and
// turn the entire conjunction into AlwaysFalse.
func TestRewrite_AlwaysFalse_Propagation_DeepNesting(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `(Int64Field > 10) and ((Int64Field > 20) and (Int64Field > 100 and Int64Field < 50))`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysFalseExpr(parsed), "AlwaysFalse should propagate to top level")
}
// An AlwaysFalse branch inside an OR is dropped; the remaining range
// predicates then merge to the weaker bound (Int64Field > 10).
func TestRewrite_AlwaysFalse_Elimination_InOR(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `(Int64Field > 10) or ((Int64Field > 20) or (Int64Field > 100 and Int64Field < 50))`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur, "AlwaysFalse should be eliminated, leaving simplified range")
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
}
// Negating an impossible range yields AlwaysTrue.
func TestRewrite_DoubleNegation_ToAlwaysTrue(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `not (Int64Field > 100 and Int64Field < 50)`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysTrueExpr(parsed))
}
// Negating an OR of two impossible ranges yields AlwaysTrue.
func TestRewrite_ComplexDoubleNegation_MultiLayer(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `not ((Int64Field > 100 and Int64Field < 50) or (FloatField > 99.9 and FloatField < 10.5))`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	require.True(t, rewriter.IsAlwaysTrueExpr(parsed))
}
// NOT of an impossible range becomes AlwaysTrue, which is eliminated from the
// AND, leaving only the remaining predicate Int64Field > 10.
func TestRewrite_AlwaysTrue_Elimination_InAND(t *testing.T) {
	sh := buildSchemaHelperForRewriteT(t)
	parsed, parseErr := parser.ParseExpr(sh, `(Int64Field > 10) and not (Int64Field > 100 and Int64Field < 50)`, nil)
	require.NoError(t, parseErr)
	require.NotNil(t, parsed)
	ur := parsed.GetUnaryRangeExpr()
	require.NotNil(t, ur)
	require.Equal(t, planpb.OpType_GreaterThan, ur.GetOp())
	require.Equal(t, int64(10), ur.GetValue().GetInt64Val())
}

View File

@ -0,0 +1,666 @@
package rewriter
import (
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)
func (v *visitor) combineOrEqualsToIn(parts []*planpb.Expr) []*planpb.Expr {
type group struct {
col *planpb.ColumnInfo
values []*planpb.GenericValue
origIndices []int
valCase string
}
others := make([]*planpb.Expr, 0, len(parts))
groups := make(map[string]*group)
indexToExpr := parts
for idx, e := range parts {
u := e.GetUnaryRangeExpr()
if u == nil || u.GetOp() != planpb.OpType_Equal || u.GetValue() == nil {
others = append(others, e)
continue
}
col := u.GetColumnInfo()
if col == nil {
others = append(others, e)
continue
}
key := columnKey(col)
g, ok := groups[key]
valCase := valueCase(u.GetValue())
if !ok {
g = &group{col: col, values: []*planpb.GenericValue{}, origIndices: []int{}, valCase: valCase}
groups[key] = g
}
if g.valCase != valCase {
others = append(others, e)
continue
}
g.values = append(g.values, u.GetValue())
g.origIndices = append(g.origIndices, idx)
}
out := make([]*planpb.Expr, 0, len(parts))
out = append(out, others...)
for _, g := range groups {
if shouldMergeToIn(g.col.GetDataType(), len(g.values)) {
g.values = sortGenericValues(g.values)
out = append(out, newTermExpr(g.col, g.values))
} else {
for _, i := range g.origIndices {
out = append(out, indexToExpr[i])
}
}
}
return out
}
func (v *visitor) combineAndNotEqualsToNotIn(parts []*planpb.Expr) []*planpb.Expr {
type group struct {
col *planpb.ColumnInfo
values []*planpb.GenericValue
origIndices []int
valCase string
}
others := make([]*planpb.Expr, 0, len(parts))
groups := make(map[string]*group)
indexToExpr := parts
for idx, e := range parts {
u := e.GetUnaryRangeExpr()
if u == nil || u.GetOp() != planpb.OpType_NotEqual || u.GetValue() == nil {
others = append(others, e)
continue
}
col := u.GetColumnInfo()
if col == nil {
others = append(others, e)
continue
}
key := columnKey(col)
g, ok := groups[key]
valCase := valueCase(u.GetValue())
if !ok {
g = &group{col: col, values: []*planpb.GenericValue{}, origIndices: []int{}, valCase: valCase}
groups[key] = g
}
if g.valCase != valCase {
others = append(others, e)
continue
}
g.values = append(g.values, u.GetValue())
g.origIndices = append(g.origIndices, idx)
}
out := make([]*planpb.Expr, 0, len(parts))
out = append(out, others...)
for _, g := range groups {
if shouldMergeToIn(g.col.GetDataType(), len(g.values)) {
g.values = sortGenericValues(g.values)
in := newTermExpr(g.col, g.values)
out = append(out, notExpr(in))
} else {
for _, i := range g.origIndices {
out = append(out, indexToExpr[i])
}
}
}
return out
}
// notExpr wraps child in a logical NOT (UnaryExpr_Not) node.
func notExpr(child *planpb.Expr) *planpb.Expr {
	negated := &planpb.UnaryExpr{Op: planpb.UnaryExpr_Not, Child: child}
	return &planpb.Expr{Expr: &planpb.Expr_UnaryExpr{UnaryExpr: negated}}
}
// AND: (a IN S1) AND ... AND (a IN Sk) AND (a = v) -> a = v when v is a
// member of EVERY Si, otherwise an always-false constant. Multiple distinct
// equality values under AND are likewise a contradiction.
//
// Note: the previous implementation kept only the last TermExpr per column
// while marking all of them used, so `a IN (1) AND a IN (2) AND a = 2` was
// wrongly rewritten to `a = 2`; the equality value is now checked against
// every term set for the column.
func (v *visitor) combineAndInWithEqual(parts []*planpb.Expr) []*planpb.Expr {
	type agg struct {
		termIdxs []int
		eqIdxs   []int
		terms    []*planpb.TermExpr
		eqValues []*planpb.GenericValue
		col      *planpb.ColumnInfo
	}
	groups := map[string]*agg{}
	others := []int{}
	for idx, e := range parts {
		if te := e.GetTermExpr(); te != nil {
			k := columnKey(te.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &agg{col: te.GetColumnInfo()}
			}
			g.termIdxs = append(g.termIdxs, idx)
			g.terms = append(g.terms, te)
			groups[k] = g
			continue
		}
		if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_Equal && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
			k := columnKey(ue.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &agg{col: ue.GetColumnInfo()}
			}
			g.eqIdxs = append(g.eqIdxs, idx)
			g.eqValues = append(g.eqValues, ue.GetValue())
			groups[k] = g
			continue
		}
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.terms) == 0 || len(g.eqIdxs) == 0 {
			// Nothing to combine without both a term and an equality.
			continue
		}
		// Deduplicate the equality values.
		eqUnique := []*planpb.GenericValue{}
		for _, ev := range g.eqValues {
			dup := false
			for _, u := range eqUnique {
				if equalsGeneric(u, ev) {
					dup = true
					break
				}
			}
			if !dup {
				eqUnique = append(eqUnique, ev)
			}
		}
		// All participating parts are consumed regardless of the outcome.
		for _, ti := range g.termIdxs {
			used[ti] = true
		}
		for _, ei := range g.eqIdxs {
			used[ei] = true
		}
		// Two different equality values under AND can never both hold.
		if len(eqUnique) > 1 {
			out = append(out, newAlwaysFalseExpr())
			continue
		}
		// Single equality value: it must belong to every term set, otherwise
		// the conjunction is a contradiction.
		ev := eqUnique[0]
		inAll := true
		for _, te := range g.terms {
			found := false
			for _, tv := range te.GetValues() {
				if equalsGeneric(tv, ev) {
					found = true
					break
				}
			}
			if !found {
				inAll = false
				break
			}
		}
		if inAll {
			out = append(out, newUnaryRangeExpr(g.col, planpb.OpType_Equal, ev))
		} else {
			out = append(out, newAlwaysFalseExpr())
		}
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// OR: (a IN S1) OR ... OR (a IN Sk) OR (a = v1) OR ... ->
// a IN (S1 ∪ ... ∪ Sk ∪ {v1, ...}).
// All term sets and equality values for a column are unioned into one term.
// The union is built in a fresh slice so the values of the original TermExpr
// protos are never mutated through a shared backing array, and every TermExpr
// for the column participates (previously only the last one was merged).
func (v *visitor) combineOrInWithEqual(parts []*planpb.Expr) []*planpb.Expr {
	type agg struct {
		termIdxs []int
		terms    []*planpb.TermExpr
		col      *planpb.ColumnInfo
		eqIdxs   []int
		eqVals   []*planpb.GenericValue
	}
	groups := map[string]*agg{}
	others := []int{}
	for idx, e := range parts {
		if te := e.GetTermExpr(); te != nil {
			k := columnKey(te.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &agg{col: te.GetColumnInfo()}
				groups[k] = g
			}
			g.termIdxs = append(g.termIdxs, idx)
			g.terms = append(g.terms, te)
			continue
		}
		if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_Equal && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
			k := columnKey(ue.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &agg{col: ue.GetColumnInfo()}
				groups[k] = g
			}
			g.eqIdxs = append(g.eqIdxs, idx)
			g.eqVals = append(g.eqVals, ue.GetValue())
			continue
		}
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.terms) == 0 || len(g.eqIdxs) == 0 {
			// Nothing to combine without both a term and an equality.
			continue
		}
		// Union every term set and equality value into a fresh slice.
		union := make([]*planpb.GenericValue, 0)
		for i, te := range g.terms {
			union = append(union, te.GetValues()...)
			used[g.termIdxs[i]] = true
		}
		for i, ev := range g.eqVals {
			union = append(union, ev)
			used[g.eqIdxs[i]] = true
		}
		union = sortGenericValues(union)
		out = append(out, newTermExpr(g.col, union))
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// AND: (a IN S) AND (range) -> filter S by range
// Per column, collect one TermExpr plus the tightest lower/upper bound taken
// from unary range predicates (>, >=, <, <=), then drop every IN value that
// falls outside those bounds.
func (v *visitor) combineAndInWithRange(parts []*planpb.Expr) []*planpb.Expr {
	type group struct {
		col       *planpb.ColumnInfo
		termIdx   int                  // index of the (last) TermExpr for this column
		term      *planpb.TermExpr
		lower     *planpb.GenericValue // tightest lower bound seen so far (nil if none)
		lowerInc  bool                 // lower bound is inclusive (>=)
		upper     *planpb.GenericValue // tightest upper bound seen so far (nil if none)
		upperInc  bool                 // upper bound is inclusive (<=)
		rangeIdxs []int                // indices of every range predicate folded into the bounds
	}
	groups := map[string]*group{}
	others := []int{} // predicates that cannot participate in this rewrite
	isRange := func(op planpb.OpType) bool {
		return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
	}
	for idx, e := range parts {
		if te := e.GetTermExpr(); te != nil {
			k := columnKey(te.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &group{col: te.GetColumnInfo()}
				groups[k] = g
			}
			g.term = te
			g.termIdx = idx
			continue
		}
		if ue := e.GetUnaryRangeExpr(); ue != nil && isRange(ue.GetOp()) && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
			k := columnKey(ue.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &group{col: ue.GetColumnInfo()}
				groups[k] = g
			}
			if ue.GetOp() == planpb.OpType_GreaterThan || ue.GetOp() == planpb.OpType_GreaterEqual {
				// Keep the larger lower bound; on a tie prefer the exclusive (>) form.
				if g.lower == nil || cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.lower) > 0 || (cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.lower) == 0 && ue.GetOp() == planpb.OpType_GreaterThan && g.lowerInc) {
					g.lower = ue.GetValue()
					g.lowerInc = ue.GetOp() == planpb.OpType_GreaterEqual
				}
			} else {
				// Keep the smaller upper bound; on a tie prefer the exclusive (<) form.
				if g.upper == nil || cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.upper) < 0 || (cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.upper) == 0 && ue.GetOp() == planpb.OpType_LessThan && g.upperInc) {
					g.upper = ue.GetValue()
					g.upperInc = ue.GetOp() == planpb.OpType_LessEqual
				}
			}
			g.rangeIdxs = append(g.rangeIdxs, idx)
			continue
		}
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		// Need both an IN and at least one bound to rewrite.
		if g.term == nil || (g.lower == nil && g.upper == nil) {
			continue
		}
		// Skip optimization if any term value is not comparable with the provided bounds
		termVals := g.term.GetValues()
		comparable := true
		for _, tv := range termVals {
			if g.lower != nil {
				if !(areComparableCases(valueCaseWithNil(tv), valueCaseWithNil(g.lower)) || (isNumericCase(valueCaseWithNil(tv)) && isNumericCase(valueCaseWithNil(g.lower)))) {
					comparable = false
					break
				}
			}
			if comparable && g.upper != nil {
				if !(areComparableCases(valueCaseWithNil(tv), valueCaseWithNil(g.upper)) || (isNumericCase(valueCaseWithNil(tv)) && isNumericCase(valueCaseWithNil(g.upper)))) {
					comparable = false
					break
				}
			}
		}
		if !comparable {
			continue
		}
		filtered := filterValuesByRange(effectiveDataType(g.col), termVals, g.lower, g.lowerInc, g.upper, g.upperInc)
		used[g.termIdx] = true
		for _, ri := range g.rangeIdxs {
			used[ri] = true
		}
		if len(filtered) == 0 {
			// Empty IN list after filtering → AlwaysFalse
			out = append(out, newAlwaysFalseExpr())
		} else {
			out = append(out, newTermExpr(g.col, filtered))
		}
	}
	// Pass through everything not consumed by the rewrite.
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// OR: (a IN S1) OR (a IN S2) -> a IN union(S1, S2)
// All IN predicates on the same column are merged into a single IN whose
// value list is the sorted, deduplicated union of the originals.
func (v *visitor) combineOrInWithIn(parts []*planpb.Expr) []*planpb.Expr {
	type termGroup struct {
		col    *planpb.ColumnInfo
		idxs   []int
		values [][]*planpb.GenericValue
	}
	byCol := map[string]*termGroup{}
	passthrough := []int{}
	for i, part := range parts {
		te := part.GetTermExpr()
		if te == nil {
			passthrough = append(passthrough, i)
			continue
		}
		key := columnKey(te.GetColumnInfo())
		g, ok := byCol[key]
		if !ok {
			g = &termGroup{col: te.GetColumnInfo()}
			byCol[key] = g
		}
		g.idxs = append(g.idxs, i)
		g.values = append(g.values, te.GetValues())
	}
	consumed := make([]bool, len(parts))
	result := make([]*planpb.Expr, 0, len(parts))
	for _, i := range passthrough {
		result = append(result, parts[i])
		consumed[i] = true
	}
	for _, g := range byCol {
		if len(g.idxs) < 2 {
			continue // a single IN has nothing to merge with
		}
		merged := []*planpb.GenericValue{}
		for _, vals := range g.values {
			merged = append(merged, vals...)
		}
		for _, i := range g.idxs {
			consumed[i] = true
		}
		result = append(result, newTermExpr(g.col, sortGenericValues(merged)))
	}
	// Keep non-merged predicates in their original order.
	for i, part := range parts {
		if !consumed[i] {
			result = append(result, part)
		}
	}
	return result
}
// AND: (a IN S1) AND (a IN S2) ... -> a IN intersection(S1, S2, ...)
// All IN predicates on the same column are replaced by one IN over the
// intersection of their value lists (preserving the order of the first list);
// an empty intersection reduces to a constant-false expression.
func (v *visitor) combineAndInWithIn(parts []*planpb.Expr) []*planpb.Expr {
	type agg struct {
		col    *planpb.ColumnInfo
		idxs   []int
		values [][]*planpb.GenericValue
	}
	groups := map[string]*agg{}
	others := []int{}
	for idx, e := range parts {
		if te := e.GetTermExpr(); te != nil {
			k := columnKey(te.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &agg{col: te.GetColumnInfo()}
				groups[k] = g
			}
			g.idxs = append(g.idxs, idx)
			g.values = append(g.values, te.GetValues())
			continue
		}
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.idxs) <= 1 {
			continue // only one IN on this column: nothing to intersect
		}
		// Intersection: keep each value of the first set that appears in every
		// other set. (The old code carried a dead `ok := true` flag that was
		// never cleared; `continue outer` already skips excluded values.)
		inter := make([]*planpb.GenericValue, 0, len(g.values[0]))
	outer:
		for _, v := range g.values[0] {
			for i := 1; i < len(g.values); i++ {
				found := false
				for _, w := range g.values[i] {
					if equalsGeneric(v, w) {
						found = true
						break
					}
				}
				if !found {
					continue outer
				}
			}
			inter = append(inter, v)
		}
		for _, i := range g.idxs {
			used[i] = true
		}
		if len(inter) == 0 {
			// Empty intersection can never match any row.
			out = append(out, newAlwaysFalseExpr())
		} else {
			out = append(out, newTermExpr(g.col, inter))
		}
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// AND: (a IN S) AND (a != d) -> remove d from S; empty -> false
// Per column, collect one TermExpr (last wins) plus every != value, then drop
// the excluded values from the IN list.
func (v *visitor) combineAndInWithNotEqual(parts []*planpb.Expr) []*planpb.Expr {
	type group struct {
		col     *planpb.ColumnInfo
		termIdx int
		term    *planpb.TermExpr
		neqIdxs []int                  // indices of != predicates on this column
		neqVals []*planpb.GenericValue // excluded values, parallel to neqIdxs
	}
	groups := map[string]*group{}
	others := []int{}
	for idx, e := range parts {
		if te := e.GetTermExpr(); te != nil {
			k := columnKey(te.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &group{col: te.GetColumnInfo()}
				groups[k] = g
			}
			g.term = te
			g.termIdx = idx
			continue
		}
		if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_NotEqual && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
			k := columnKey(ue.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &group{col: ue.GetColumnInfo()}
				groups[k] = g
			}
			g.neqIdxs = append(g.neqIdxs, idx)
			g.neqVals = append(g.neqVals, ue.GetValue())
			continue
		}
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, i := range others {
		out = append(out, parts[i])
		used[i] = true
	}
	for _, g := range groups {
		// Rewrite only when the column has both an IN and at least one !=.
		if g.term == nil || len(g.neqIdxs) == 0 {
			continue
		}
		// Keep the IN values excluded by none of the != predicates.
		filtered := []*planpb.GenericValue{}
		for _, tv := range g.term.GetValues() {
			excluded := false
			for _, dv := range g.neqVals {
				if equalsGeneric(tv, dv) {
					excluded = true
					break
				}
			}
			if !excluded {
				filtered = append(filtered, tv)
			}
		}
		used[g.termIdx] = true
		for _, ni := range g.neqIdxs {
			used[ni] = true
		}
		if len(filtered) == 0 {
			// Every IN value was excluded -> the conjunction can never match.
			out = append(out, newAlwaysFalseExpr())
		} else {
			out = append(out, newTermExpr(g.col, filtered))
		}
	}
	// Pass through predicates not consumed by the rewrite.
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}
// OR: (a IN S) OR (a != d) -> if d ∈ S then true else (a != d)
// When any excluded value d lies inside S, the disjunction covers all values
// (a = d satisfies the IN, a != d satisfies the rest), so it folds to true.
// Otherwise the IN is implied by every != and can simply be dropped.
func (v *visitor) combineOrInWithNotEqual(parts []*planpb.Expr) []*planpb.Expr {
	type group struct {
		col     *planpb.ColumnInfo
		termIdx int
		term    *planpb.TermExpr
		neqIdxs []int                  // indices of != predicates on this column
		neqVals []*planpb.GenericValue // excluded values, parallel to neqIdxs
	}
	groups := map[string]*group{}
	others := []int{}
	for idx, e := range parts {
		if te := e.GetTermExpr(); te != nil {
			k := columnKey(te.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &group{col: te.GetColumnInfo()}
				groups[k] = g
			}
			g.term = te
			g.termIdx = idx
			continue
		}
		if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_NotEqual && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
			k := columnKey(ue.GetColumnInfo())
			g := groups[k]
			if g == nil {
				g = &group{col: ue.GetColumnInfo()}
				groups[k] = g
			}
			g.neqIdxs = append(g.neqIdxs, idx)
			g.neqVals = append(g.neqVals, ue.GetValue())
			continue
		}
		others = append(others, idx)
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, i := range others {
		out = append(out, parts[i])
		used[i] = true
	}
	for _, g := range groups {
		if g.term == nil || len(g.neqIdxs) == 0 {
			continue
		}
		// if any neq value is inside IN set -> true
		containsAny := false
		for _, dv := range g.neqVals {
			for _, tv := range g.term.GetValues() {
				if equalsGeneric(tv, dv) {
					containsAny = true
					break
				}
			}
			if containsAny {
				break
			}
		}
		if containsAny {
			used[g.termIdx] = true
			for _, ni := range g.neqIdxs {
				used[ni] = true
			}
			out = append(out, newBoolConstExpr(true))
		} else {
			// drop the IN; keep != as-is
			// (every d is outside S, so a IN S implies a != d for each d)
			used[g.termIdx] = true
		}
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}

View File

@ -0,0 +1,356 @@
package rewriter_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// buildSchemaHelperForRewriteT constructs the shared test schema (int64,
// varchar, string, double, bool fields) with text_match enabled on the
// string-like fields, and wraps it in a SchemaHelper.
func buildSchemaHelperForRewriteT(t *testing.T) *typeutil.SchemaHelper {
	schema := &schemapb.CollectionSchema{
		Name:   "rewrite_test",
		AutoID: false,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 101, Name: "Int64Field", DataType: schemapb.DataType_Int64},
			{FieldID: 102, Name: "VarCharField", DataType: schemapb.DataType_VarChar},
			{FieldID: 103, Name: "StringField", DataType: schemapb.DataType_String},
			{FieldID: 104, Name: "FloatField", DataType: schemapb.DataType_Double},
			{FieldID: 105, Name: "BoolField", DataType: schemapb.DataType_Bool},
		},
	}
	// enable text_match on string-like fields
	for _, field := range schema.Fields {
		if !typeutil.IsStringType(field.DataType) {
			continue
		}
		field.TypeParams = append(field.TypeParams, &commonpb.KeyValuePair{
			Key:   "enable_match",
			Value: "True",
		})
	}
	helper, err := typeutil.CreateSchemaHelper(schema)
	require.NoError(t, err)
	return helper
}
// OR of equals on a non-numeric column must merge into a single IN.
func TestRewrite_OREquals_ToIN_NonNumeric(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `VarCharField == "a" or VarCharField == "b"`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term, "expected OR-equals to be rewritten to TermExpr(IN ...)")
	require.Equal(t, 2, len(term.GetValues()))
	require.Equal(t, "a", term.GetValues()[0].GetStringVal())
	require.Equal(t, "b", term.GetValues()[1].GetStringVal())
}

// Numeric OR-equals below the merge threshold must stay a plain OR of equals.
func TestRewrite_OREquals_NotMerged_OnNumericBelowThreshold(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field == 1 or Int64Field == 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	require.Nil(t, expr.GetTermExpr(), "numeric OR-equals should not merge to IN under threshold")
	be := expr.GetBinaryExpr()
	require.NotNil(t, be)
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
	require.NotNil(t, be.GetLeft().GetUnaryRangeExpr())
	require.NotNil(t, be.GetRight().GetUnaryRangeExpr())
	require.Equal(t, planpb.OpType_Equal, be.GetLeft().GetUnaryRangeExpr().GetOp())
	require.Equal(t, planpb.OpType_Equal, be.GetRight().GetUnaryRangeExpr().GetOp())
}

// IN lists on string columns are sorted and deduplicated.
func TestRewrite_Term_SortAndDedup_String(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `VarCharField in ["b","a","b","a"]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, 2, len(term.GetValues()))
	require.Equal(t, "a", term.GetValues()[0].GetStringVal())
	require.Equal(t, "b", term.GetValues()[1].GetStringVal())
}

// IN lists on int columns are sorted and deduplicated.
func TestRewrite_Term_SortAndDedup_Int(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [9,4,6,6,7]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, 4, len(term.GetValues()))
	got := []int64{
		term.GetValues()[0].GetInt64Val(),
		term.GetValues()[1].GetInt64Val(),
		term.GetValues()[2].GetInt64Val(),
		term.GetValues()[3].GetInt64Val(),
	}
	require.ElementsMatch(t, []int64{4, 6, 7, 9}, got)
}

// NOT IN keeps its inner TermExpr sorted and deduplicated as well.
func TestRewrite_NotIn_SortAndDedup_Int(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field not in [4,4,3]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	un := expr.GetUnaryExpr()
	require.NotNil(t, un)
	term := un.GetChild().GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, 2, len(term.GetValues()))
	require.Equal(t, int64(3), term.GetValues()[0].GetInt64Val())
	require.Equal(t, int64(4), term.GetValues()[1].GetInt64Val())
}

// Float NOT IN lists dedup 4.0 and 4 into one value.
func TestRewrite_NotIn_SortAndDedup_Float(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `FloatField not in [4.0,4,3.0]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	un := expr.GetUnaryExpr()
	require.NotNil(t, un)
	term := un.GetChild().GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, 2, len(term.GetValues()))
	require.Equal(t, 3.0, term.GetValues()[0].GetFloatVal())
	require.Equal(t, 4.0, term.GetValues()[1].GetFloatVal())
}

// Bool IN lists sort false before true and deduplicate.
func TestRewrite_In_SortAndDedup_Bool(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `BoolField in [true,false,false,true]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, 2, len(term.GetValues()))
	require.Equal(t, false, term.GetValues()[0].GetBoolVal())
	require.Equal(t, true, term.GetValues()[1].GetBoolVal())
}

// Nested ORs must be flattened before the equals-to-IN merge applies.
func TestRewrite_Flatten_Then_OR_ToIN(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `VarCharField == "a" or (VarCharField == "b" or VarCharField == "c") or VarCharField == "d"`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term, "nested OR-equals should flatten and merge to IN")
	require.Equal(t, 4, len(term.GetValues()))
	got := []string{
		term.GetValues()[0].GetStringVal(),
		term.GetValues()[1].GetStringVal(),
		term.GetValues()[2].GetStringVal(),
		term.GetValues()[3].GetStringVal(),
	}
	require.ElementsMatch(t, []string{"a", "b", "c", "d"}, got)
}

// AND of IN and a member equality collapses to the single equality.
func TestRewrite_And_In_And_Equal_VInSet_ReducesToEqual(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field == 3`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure)
	require.Equal(t, planpb.OpType_Equal, ure.GetOp())
	require.Equal(t, int64(3), ure.GetValue().GetInt64Val())
}

// AND of IN and a non-member equality is a contradiction -> constant false.
func TestRewrite_And_In_And_Equal_VNotInSet_False(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field == 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	require.True(t, rewriter.IsAlwaysFalseExpr(expr))
}
// OR of IN and equality unions the equal value into the IN list.
func TestRewrite_Or_In_Or_Equal_Union(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,3] or Int64Field == 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, []int64{1, 2, 3}, []int64{
		term.GetValues()[0].GetInt64Val(),
		term.GetValues()[1].GetInt64Val(),
		term.GetValues()[2].GetInt64Val(),
	})
}

// AND of IN and a range keeps only the IN values inside the range.
func TestRewrite_And_In_With_Range_Filter(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field > 3`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, 1, len(term.GetValues()))
	require.Equal(t, int64(5), term.GetValues()[0].GetInt64Val())
}

// OR of two INs on the same column is their (deduplicated) union.
func TestRewrite_Or_In_Union(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,3] or Int64Field in [3,4]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, []int64{1, 3, 4}, []int64{
		term.GetValues()[0].GetInt64Val(),
		term.GetValues()[1].GetInt64Val(),
		term.GetValues()[2].GetInt64Val(),
	})
}

// AND of two INs on the same column is their intersection.
func TestRewrite_And_In_Intersection(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,2,3] and Int64Field in [2,3,4]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, []int64{2, 3}, []int64{
		term.GetValues()[0].GetInt64Val(),
		term.GetValues()[1].GetInt64Val(),
	})
}

// Disjoint INs under AND intersect to the empty set -> constant false.
func TestRewrite_And_In_Intersection_Empty_ToFalse(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1] and Int64Field in [2]`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	require.True(t, rewriter.IsAlwaysFalseExpr(expr))
}

// AND of IN and != removes the excluded value from the IN list.
func TestRewrite_And_In_And_NotEqual_Remove(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,2,3] and Int64Field != 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	term := expr.GetTermExpr()
	require.NotNil(t, term)
	require.Equal(t, []int64{1, 3}, []int64{
		term.GetValues()[0].GetInt64Val(),
		term.GetValues()[1].GetInt64Val(),
	})
}

// != excluding the entire IN list yields constant false.
func TestRewrite_And_In_And_NotEqual_AllRemoved_ToFalse(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [2] and Int64Field != 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	require.True(t, rewriter.IsAlwaysFalseExpr(expr))
}

// OR of IN and != where the excluded value is a member covers all values -> true.
func TestRewrite_Or_In_Or_NotEqual_VInSet_ToTrue(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1,2] or Int64Field != 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	val := expr.GetValueExpr()
	require.NotNil(t, val)
	require.Equal(t, true, val.GetValue().GetBoolVal())
}

// OR of IN and != where the excluded value is not a member drops the IN.
func TestRewrite_Or_In_Or_NotEqual_VNotInSet_ToNotEqual(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `Int64Field in [1] or Int64Field != 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure)
	require.Equal(t, planpb.OpType_NotEqual, ure.GetOp())
	require.Equal(t, int64(2), ure.GetValue().GetInt64Val())
}
// Test contradictory equals: (a == 1) AND (a == 2) → false
// NOTE: This is a known limitation - currently NOT optimized
func TestRewrite_And_Equal_And_Equal_Contradiction_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// Int64Field == 1 AND Int64Field == 2 → false (contradiction)
	// Currently NOT optimized because equals don't convert to IN in AND context
	expr, err := parser.ParseExpr(helper, `Int64Field == 1 and Int64Field == 2`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// Current limitation: NOT optimized to false
	// Remains as BinaryExpr AND with two equal predicates
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "should remain as AND (not optimized)")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
}

// Test contradictory equals with three values
// Documents the same limitation as above for a longer AND chain.
func TestRewrite_And_Equal_ThreeWay_Contradiction_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// Int64Field == 1 AND Int64Field == 2 AND Int64Field == 3 → false
	// Currently NOT optimized
	expr, err := parser.ParseExpr(helper, `Int64Field == 1 and Int64Field == 2 and Int64Field == 3`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// Current limitation: NOT optimized to false
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "should remain as AND chain (not optimized)")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
}

// Test range + contradictory equal: (a > 10) AND (a == 5) → false
// NOTE: Current limitation - NOT optimized
func TestRewrite_And_Range_And_Equal_Contradiction_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// Int64Field > 10 AND Int64Field == 5 → false (5 is not > 10)
	// This requires combining range and equality checks, which is partially implemented
	expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field == 5`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// Current behavior: may not optimize to false
	// If it does optimize via IN+range filtering, it would become false
	// This test documents current limitation
	// When fully optimized, should be constant false
	_ = expr // Test documents that this case exists; no shape assertion on purpose
}

// Test non-contradictory range + equal: (a > 10) AND (a == 15) → stays as is
// NOTE: This requires Equal to be in IN form first, which happens via combineAndInWithEqual
// But Equal alone doesn't convert to IN in AND context, so this optimization doesn't happen
func TestRewrite_And_Range_And_Equal_NonContradiction_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// Int64Field > 10 AND Int64Field == 15 → should simplify to Int64Field == 15
	// But currently NOT optimized without IN involved
	expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field == 15`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// Current limitation: remains as AND with range and equal
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "should remain as AND (not optimized)")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
}

// Test string contradictory equals - current limitation
func TestRewrite_And_Equal_String_Contradiction_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// VarCharField == "apple" AND VarCharField == "banana" → false
	// Currently NOT optimized
	expr, err := parser.ParseExpr(helper, `VarCharField == "apple" and VarCharField == "banana"`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// Current limitation: NOT optimized to false
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "should remain as AND (not optimized)")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
}

View File

@ -0,0 +1,78 @@
package rewriter
import (
"strings"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)
// combineOrTextMatchToMerged merges OR-ed text_match predicates on the same
// column into a single text_match whose query string is the space-joined
// literals (matching the rewriter's assumption that OR of text_match equals
// one text_match over the union of tokens). Only bare text_match predicates
// participate: calls carrying extra values (e.g. minimum_should_match) or
// missing a value/column are passed through untouched.
//
// Fixes over the previous version: removed the redundant `indexToExpr`
// alias of `parts`, and removed the unreachable `len(g.literals) == 0`
// branch (literals and origIndices always grow in lockstep, so a group
// with more than one index always has literals).
func (v *visitor) combineOrTextMatchToMerged(parts []*planpb.Expr) []*planpb.Expr {
	type group struct {
		col         *planpb.ColumnInfo
		origIndices []int    // positions of the group's predicates in parts
		literals    []string // query literals, parallel to origIndices
	}
	others := make([]*planpb.Expr, 0, len(parts))
	groups := make(map[string]*group)
	for idx, e := range parts {
		u := e.GetUnaryRangeExpr()
		// Only bare text_match predicates with a value and column are mergeable.
		if u == nil || u.GetOp() != planpb.OpType_TextMatch || u.GetValue() == nil ||
			len(u.GetExtraValues()) > 0 || u.GetColumnInfo() == nil {
			others = append(others, e)
			continue
		}
		col := u.GetColumnInfo()
		key := columnKey(col)
		g, ok := groups[key]
		if !ok {
			g = &group{col: col}
			groups[key] = g
		}
		g.literals = append(g.literals, u.GetValue().GetStringVal())
		g.origIndices = append(g.origIndices, idx)
	}
	out := make([]*planpb.Expr, 0, len(parts))
	out = append(out, others...)
	for _, g := range groups {
		// A single text_match on a column has nothing to merge with.
		if len(g.origIndices) <= 1 {
			for _, i := range g.origIndices {
				out = append(out, parts[i])
			}
			continue
		}
		out = append(out, newTextMatchExpr(g.col, strings.Join(g.literals, " ")))
	}
	return out
}
// newTextMatchExpr builds a text_match unary-range predicate on col whose
// query is the given literal.
func newTextMatchExpr(col *planpb.ColumnInfo, literal string) *planpb.Expr {
	query := &planpb.GenericValue{
		Val: &planpb.GenericValue_StringVal{StringVal: literal},
	}
	predicate := &planpb.UnaryRangeExpr{
		ColumnInfo: col,
		Op:         planpb.OpType_TextMatch,
		Value:      query,
	}
	return &planpb.Expr{
		Expr: &planpb.Expr_UnaryRangeExpr{UnaryRangeExpr: predicate},
	}
}

View File

@ -0,0 +1,122 @@
package rewriter_test
import (
"testing"
"github.com/stretchr/testify/require"
parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)
// collectTextMatchLiterals walks the provided expr and collects all OpType_TextMatch literals (grouped by fieldId) into the returned map.
// Only OR branches are descended; other expression forms terminate the walk.
func collectTextMatchLiterals(expr *planpb.Expr) map[int64]string {
	literals := map[int64]string{}
	var walk func(e *planpb.Expr)
	walk = func(e *planpb.Expr) {
		if e == nil {
			return
		}
		if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_TextMatch {
			literals[ue.GetColumnInfo().GetFieldId()] = ue.GetValue().GetStringVal()
			return
		}
		if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalOr {
			walk(be.GetLeft())
			walk(be.GetRight())
		}
	}
	walk(expr)
	return literals
}
// OR of two text_match calls on the same column merges into one call whose
// query is the space-joined literals.
func TestRewrite_TextMatch_OR_Merge(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A C") or text_match(VarCharField, "B D")`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure, "should merge to single text_match")
	require.Equal(t, planpb.OpType_TextMatch, ure.GetOp())
	require.Equal(t, "A C B D", ure.GetValue().GetStringVal())
}

// text_match calls on different columns must not merge.
func TestRewrite_TextMatch_OR_DifferentField_NoMerge(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or text_match(StringField, "B")`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	be := expr.GetBinaryExpr()
	require.NotNil(t, be)
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
	require.NotNil(t, be.GetLeft().GetUnaryRangeExpr())
	require.NotNil(t, be.GetRight().GetUnaryRangeExpr())
}

// Mixed columns: only same-column text_match calls merge; others stay OR-ed.
func TestRewrite_TextMatch_OR_MultiFields_Merge(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or text_match(StringField, "A") or text_match(VarCharField, "B")`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// should be text_match(VarCharField, "A B") or text_match(StringField, "A")
	be := expr.GetBinaryExpr()
	require.NotNil(t, be)
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
	// collect all text_match literals grouped by column id
	colToLiteral := collectTextMatchLiterals(expr)
	require.Equal(t, 2, len(colToLiteral))
	// one of them must be "A B", and the other is "A"
	hasAB := false
	hasA := false
	for _, lit := range colToLiteral {
		if lit == "A B" {
			hasAB = true
		}
		if lit == "A" {
			hasA = true
		}
	}
	require.True(t, hasAB)
	require.True(t, hasA)
}

// Same as above but with nested ORs: flattening happens before merging.
func TestRewrite_TextMatch_OR_MoreMultiFields_Merge(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or (text_match(StringField, "C") or text_match(VarCharField, "B")) or text_match(StringField, "D")`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// should be text_match(VarCharField, "A B") or text_match(StringField, "C D")
	be := expr.GetBinaryExpr()
	require.NotNil(t, be)
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
	// collect all text_match literals grouped by column id
	colToLiteral := collectTextMatchLiterals(expr)
	require.Equal(t, 2, len(colToLiteral))
	// one of them must be "A B", and the other is "C D"
	hasAB := false
	hasCD := false
	for _, lit := range colToLiteral {
		if lit == "A B" {
			hasAB = true
		}
		if lit == "C D" {
			hasCD = true
		}
	}
	require.True(t, hasAB)
	require.True(t, hasCD)
}

// text_match carrying an option (extra values) must never be merged.
func TestRewrite_TextMatch_OR_WithOption_NoMerge(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A", minimum_should_match=1) or text_match(VarCharField, "B")`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	be := expr.GetBinaryExpr()
	require.NotNil(t, be)
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
}

View File

@ -0,0 +1,290 @@
package rewriter
import (
"fmt"
"sort"
"strings"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// columnKey returns a canonical string identity for a column, used to group
// predicates that target the same column. Two ColumnInfos produce the same
// key iff field id, data/element type, flags and nested path all match.
//
// Fix: nested-path elements are length-prefixed before being joined, so a
// '|' inside a JSON key can no longer make distinct paths collide (e.g.
// ["a|b"] vs ["a","b"] previously encoded identically and could merge
// predicates on different columns).
func columnKey(c *planpb.ColumnInfo) string {
	var b strings.Builder
	b.WriteString(fmt.Sprintf("%d|%d|%d|%t|%t|%t|",
		c.GetFieldId(),
		int32(c.GetDataType()),
		int32(c.GetElementType()),
		c.GetIsPrimaryKey(),
		c.GetIsAutoID(),
		c.GetIsPartitionKey(),
	))
	for _, p := range c.GetNestedPath() {
		// length prefix disambiguates delimiter characters inside the element
		b.WriteString(fmt.Sprintf("%d:", len(p)))
		b.WriteString(p)
		b.WriteByte('|')
	}
	return b.String()
}
// effectiveDataType returns the real scalar type to be used for comparisons.
// For JSON/Array columns with a concrete element_type, use element_type;
// otherwise fall back to the column data_type.
func effectiveDataType(c *planpb.ColumnInfo) schemapb.DataType {
	if c == nil {
		return schemapb.DataType_None
	}
	dt := c.GetDataType()
	switch dt {
	case schemapb.DataType_JSON, schemapb.DataType_Array:
		// Treat 0 (None/Invalid) as "element type not specified".
		if et := c.GetElementType(); et != schemapb.DataType_None {
			return et
		}
	}
	return dt
}
// valueCase names the dynamic kind of a GenericValue's payload
// ("bool", "int64", "float", "string", "array", or "other").
func valueCase(v *planpb.GenericValue) string {
	switch v.GetVal().(type) {
	case *planpb.GenericValue_StringVal:
		return "string"
	case *planpb.GenericValue_Int64Val:
		return "int64"
	case *planpb.GenericValue_FloatVal:
		return "float"
	case *planpb.GenericValue_BoolVal:
		return "bool"
	case *planpb.GenericValue_ArrayVal:
		return "array"
	}
	return "other"
}
// valueCaseWithNil is valueCase extended to classify nil/empty values as "nil".
func valueCaseWithNil(v *planpb.GenericValue) string {
	if v != nil && v.GetVal() != nil {
		return valueCase(v)
	}
	return "nil"
}
// isNumericCase reports whether the kind string names a numeric value kind.
func isNumericCase(k string) bool {
	switch k {
	case "int64", "float":
		return true
	}
	return false
}
// areComparableCases reports whether two value kinds can be compared:
// both numeric (int64 and float interoperate), or both bool, or both string.
// nil compares with nothing; "array"/"other" are never comparable.
func areComparableCases(a, b string) bool {
	if a == "nil" || b == "nil" {
		return false
	}
	numeric := func(k string) bool { return k == "int64" || k == "float" }
	if numeric(a) && numeric(b) {
		return true
	}
	return a == b && (a == "bool" || a == "string")
}
// isNumericType reports whether dt is an arithmetic schema type, explicitly
// excluding the bool/string/JSON families.
func isNumericType(dt schemapb.DataType) bool {
	switch {
	case typeutil.IsBoolType(dt), typeutil.IsStringType(dt), typeutil.IsJSONType(dt):
		return false
	default:
		return typeutil.IsArithmetic(dt)
	}
}
const defaultConvertOrToInNumericLimit = 150

// shouldMergeToIn decides whether `count` OR-ed equalities on a column of
// type dt should be merged into one IN: numeric columns only past the
// configured threshold, all other types whenever more than one value exists.
func shouldMergeToIn(dt schemapb.DataType, count int) bool {
	threshold := 1
	if isNumericType(dt) {
		threshold = defaultConvertOrToInNumericLimit
	}
	return count > threshold
}
// sortTermValues sorts and deduplicates a TermExpr's value list in place.
// nil terms and lists shorter than two elements are left untouched.
func sortTermValues(term *planpb.TermExpr) {
	if term == nil {
		return
	}
	if len(term.GetValues()) > 1 {
		term.Values = sortGenericValues(term.Values)
	}
}
// sort and deduplicate a list of generic values.
// The sort key kind is inferred from the first non-nil, sortable value;
// the list is assumed homogeneous (mixed int64/float lists would sort via
// the zero value of the mismatched accessor — assumes upstream coercion
// normalizes kinds; TODO confirm). Dedup keeps the first occurrence.
// Lists whose kind is "other"/"array" (or all-nil) are returned unchanged.
func sortGenericValues(values []*planpb.GenericValue) []*planpb.GenericValue {
	if len(values) <= 1 {
		return values
	}
	var kind string
	// Probe for the first usable kind, skipping nils and unsortable kinds.
	for _, v := range values {
		if v == nil || v.GetVal() == nil {
			continue
		}
		kind = valueCase(v)
		if kind != "" && kind != "other" && kind != "array" {
			break
		}
	}
	switch kind {
	case "bool":
		// false sorts before true
		sort.Slice(values, func(i, j int) bool {
			return !values[i].GetBoolVal() && values[j].GetBoolVal()
		})
		values = lo.UniqBy(values, func(v *planpb.GenericValue) bool { return v.GetBoolVal() })
	case "int64":
		sort.Slice(values, func(i, j int) bool {
			return values[i].GetInt64Val() < values[j].GetInt64Val()
		})
		values = lo.UniqBy(values, func(v *planpb.GenericValue) int64 { return v.GetInt64Val() })
	case "float":
		sort.Slice(values, func(i, j int) bool {
			return values[i].GetFloatVal() < values[j].GetFloatVal()
		})
		values = lo.UniqBy(values, func(v *planpb.GenericValue) float64 { return v.GetFloatVal() })
	case "string":
		sort.Slice(values, func(i, j int) bool {
			return values[i].GetStringVal() < values[j].GetStringVal()
		})
		values = lo.UniqBy(values, func(v *planpb.GenericValue) string { return v.GetStringVal() })
	}
	return values
}
// newTermExpr builds a `col IN values` term expression.
func newTermExpr(col *planpb.ColumnInfo, values []*planpb.GenericValue) *planpb.Expr {
	term := &planpb.TermExpr{
		ColumnInfo: col,
		Values:     values,
	}
	return &planpb.Expr{Expr: &planpb.Expr_TermExpr{TermExpr: term}}
}
// newUnaryRangeExpr builds a single-bound comparison (`col <op> val`).
func newUnaryRangeExpr(col *planpb.ColumnInfo, op planpb.OpType, val *planpb.GenericValue) *planpb.Expr {
	rangeExpr := &planpb.UnaryRangeExpr{
		ColumnInfo: col,
		Op:         op,
		Value:      val,
	}
	return &planpb.Expr{Expr: &planpb.Expr_UnaryRangeExpr{UnaryRangeExpr: rangeExpr}}
}
// newBoolConstExpr wraps a constant boolean into a value expression.
func newBoolConstExpr(v bool) *planpb.Expr {
	constant := &planpb.GenericValue{
		Val: &planpb.GenericValue_BoolVal{BoolVal: v},
	}
	return &planpb.Expr{
		Expr: &planpb.Expr_ValueExpr{
			ValueExpr: &planpb.ValueExpr{Value: constant},
		},
	}
}
// newAlwaysTrueExpr builds the canonical always-true predicate.
func newAlwaysTrueExpr() *planpb.Expr {
	return &planpb.Expr{
		Expr: &planpb.Expr_AlwaysTrueExpr{AlwaysTrueExpr: &planpb.AlwaysTrueExpr{}},
	}
}
// newAlwaysFalseExpr builds the canonical always-false predicate,
// represented as NOT(always-true) — the shape IsAlwaysFalseExpr detects.
func newAlwaysFalseExpr() *planpb.Expr {
	negation := &planpb.UnaryExpr{
		Op:    planpb.UnaryExpr_Not,
		Child: newAlwaysTrueExpr(),
	}
	return &planpb.Expr{Expr: &planpb.Expr_UnaryExpr{UnaryExpr: negation}}
}
// IsAlwaysTrueExpr reports whether e is an AlwaysTrueExpr node.
func IsAlwaysTrueExpr(e *planpb.Expr) bool {
	return e != nil && e.GetAlwaysTrueExpr() != nil
}
// IsAlwaysFalseExpr reports whether e has the canonical always-false
// shape produced by newAlwaysFalseExpr: NOT(always-true).
func IsAlwaysFalseExpr(e *planpb.Expr) bool {
	if e == nil {
		return false
	}
	unary := e.GetUnaryExpr()
	if unary == nil {
		return false
	}
	return unary.GetOp() == planpb.UnaryExpr_Not && IsAlwaysTrueExpr(unary.GetChild())
}
// equalsGeneric reports whether two GenericValues carry the same scalar
// content. Only matching bool/int64/float/string variants compare equal;
// nil values, mismatched variants, and array/other variants never do.
// Note int64 and float are treated as distinct kinds here.
func equalsGeneric(a, b *planpb.GenericValue) bool {
	if a.GetVal() == nil || b.GetVal() == nil {
		return false
	}
	switch a.GetVal().(type) {
	case *planpb.GenericValue_BoolVal:
		other, ok := b.GetVal().(*planpb.GenericValue_BoolVal)
		return ok && a.GetBoolVal() == other.BoolVal
	case *planpb.GenericValue_Int64Val:
		other, ok := b.GetVal().(*planpb.GenericValue_Int64Val)
		return ok && a.GetInt64Val() == other.Int64Val
	case *planpb.GenericValue_FloatVal:
		other, ok := b.GetVal().(*planpb.GenericValue_FloatVal)
		return ok && a.GetFloatVal() == other.FloatVal
	case *planpb.GenericValue_StringVal:
		other, ok := b.GetVal().(*planpb.GenericValue_StringVal)
		return ok && a.GetStringVal() == other.StringVal
	default:
		return false
	}
}
// satisfiesLower reports whether v lies above the lower bound under dt's
// comparison rules (>= when inclusive, strictly > otherwise).
func satisfiesLower(dt schemapb.DataType, v, lower *planpb.GenericValue, inclusive bool) bool {
	if inclusive {
		return cmpGeneric(dt, v, lower) >= 0
	}
	return cmpGeneric(dt, v, lower) > 0
}
// satisfiesUpper reports whether v lies below the upper bound under dt's
// comparison rules (<= when inclusive, strictly < otherwise).
func satisfiesUpper(dt schemapb.DataType, v, upper *planpb.GenericValue, inclusive bool) bool {
	if inclusive {
		return cmpGeneric(dt, v, upper) <= 0
	}
	return cmpGeneric(dt, v, upper) < 0
}
// filterValuesByRange returns the values that fall within the given
// bounds; a nil bound is treated as open on that side, and lowerInc /
// upperInc control whether each bound itself is admitted.
func filterValuesByRange(dt schemapb.DataType, values []*planpb.GenericValue, lower *planpb.GenericValue, lowerInc bool, upper *planpb.GenericValue, upperInc bool) []*planpb.GenericValue {
	kept := make([]*planpb.GenericValue, 0, len(values))
	for _, v := range values {
		if lower != nil && !satisfiesLower(dt, v, lower, lowerInc) {
			continue
		}
		if upper != nil && !satisfiesUpper(dt, v, upper, upperInc) {
			continue
		}
		kept = append(kept, v)
	}
	return kept
}
// unionValues merges two value lists into one sorted, deduplicated list.
// Neither input slice is mutated.
func unionValues(valuesA, valuesB []*planpb.GenericValue) []*planpb.GenericValue {
	merged := make([]*planpb.GenericValue, 0, len(valuesA)+len(valuesB))
	merged = append(merged, valuesA...)
	merged = append(merged, valuesB...)
	return sortGenericValues(merged)
}

View File

@ -4709,7 +4709,7 @@ func (s *MaterializedViewTestSuite) TestMvEnabledPartitionKeyOnVarCharWithIsolat
schema := ConstructCollectionSchemaWithPartitionKey(s.colName, s.fieldName2Types, testInt64Field, testVarCharField, false)
schemaInfo := newSchemaInfo(schema)
s.mockMetaCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schemaInfo, nil)
s.ErrorContains(task.PreExecute(s.ctx), "partition key isolation does not support OR")
s.ErrorContains(task.PreExecute(s.ctx), "partition key isolation does not support IN")
}
}

View File

@ -80,9 +80,9 @@ func TestParsePartitionKeys(t *testing.T) {
{
name: "binary_expr_and with partition key in range",
expr: "partition_key_field in [7, 8] && partition_key_field > 9",
expected: 2,
validPartitionKeys: []int64{7, 8},
invalidPartitionKeys: []int64{9},
expected: 0,
validPartitionKeys: []int64{},
invalidPartitionKeys: []int64{},
},
{
name: "binary_expr_and with partition key in range2",
@ -295,7 +295,7 @@ func TestValidatePartitionKeyIsolation(t *testing.T) {
{
name: "partition key isolation equal AND with same field term",
expr: "key_field == 10 && key_field in [10]",
expectedErrorString: "partition key isolation does not support IN",
expectedErrorString: "",
},
{
name: "partition key isolation equal OR with same field equal",

View File

@ -589,33 +589,3 @@ func TestDeleteInvalidExpr(t *testing.T) {
common.CheckErr(t, err, _invalidExpr.ErrNil, _invalidExpr.ErrMsg)
}
}
// test delete with duplicated data ids
func TestDeleteDuplicatedPks(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create collection and a partition
cp := hp.NewCreateCollectionParams(hp.Int64Vec)
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption().TWithIsDynamic(true), hp.TNewSchemaOption())
// insert [0, 3000) into default
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithMaxCapacity(common.TestCapacity))
prepare.FlushData(ctx, t, mc, schema.CollectionName)
// index and load
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// delete
deleteIDs := []int64{0, 0, 0, 0, 0}
delRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithInt64IDs(common.DefaultInt64FieldName, deleteIDs))
common.CheckErr(t, err, true)
require.Equal(t, 5, int(delRes.DeleteCount))
// query, verify delete success
expr := fmt.Sprintf("%s >= 0 ", common.DefaultInt64FieldName)
resQuery, errQuery := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, errQuery, true)
require.Equal(t, common.DefaultNb-1, resQuery.ResultCount)
}

View File

@ -565,22 +565,6 @@ class TestDeleteOperation(TestcaseBase):
# since the search requests arrived query nodes earlier than query nodes consume the delete requests.
assert len(inter) == 0
@pytest.mark.tags(CaseLabel.L1)
def test_delete_expr_repeated_values(self):
"""
target: test delete with repeated values
method: 1.insert data with unique primary keys
2.delete with repeated values: 'id in [0, 0]'
expected: delete one entity
"""
# init collection with nb default data
collection_w = self.init_collection_general(prefix, nb=tmp_nb, insert_data=True)[0]
expr = f'{ct.default_int64_field_name} in {[0, 0, 0]}'
del_res, _ = collection_w.delete(expr)
assert del_res.delete_count == 3
collection_w.num_entities
collection_w.query(expr, check_task=CheckTasks.check_query_empty)
@pytest.mark.tags(CaseLabel.L1)
def test_delete_duplicate_primary_keys(self):
"""
@ -1433,7 +1417,7 @@ class TestDeleteString(TestcaseBase):
self.init_collection_general(prefix, nb=tmp_nb, insert_data=True, primary_field=ct.default_string_field_name)[0]
expr = f'{ct.default_string_field_name} in ["0", "0", "0"]'
del_res, _ = collection_w.delete(expr)
assert del_res.delete_count == 3
assert del_res.delete_count == 1
collection_w.num_entities
collection_w.query(expr, check_task=CheckTasks.check_query_empty)
@ -1939,7 +1923,7 @@ class TestDeleteString(TestcaseBase):
# delete
string_expr = "varchar in [\"\", \"\"]"
del_res, _ = collection_w.delete(string_expr)
assert del_res.delete_count == 2
assert del_res.delete_count == 1
# load and query with id
collection_w.load()