Mirror of https://gitee.com/milvus-io/milvus.git, synced 2025-12-28 14:35:27 +08:00
enhance: moved query optimization to proxy, added various optimizations (#45526)
issue: https://github.com/milvus-io/milvus/issues/45525

See the added README.md for a description of the added optimizations.

## Summary by CodeRabbit

* **New Features**
  * Added a query expression optimization feature with a new `optimizeExpr` configuration flag to enable automatic simplification of filter predicates, including range predicate optimization, merging of IN/NOT IN conditions, and flattening of nested logical operators.
* **Bug Fixes**
  * Adjusted delete operation behavior to correctly handle expression evaluation.

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
parent: 7fca6e759f, commit: e379b1f0f4
@@ -18,7 +18,6 @@
#include "common/EasyAssert.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "fmt/format.h"
|
||||
#include "exec/expression/AlwaysTrueExpr.h"
|
||||
#include "exec/expression/BinaryArithOpEvalRangeExpr.h"
|
||||
#include "exec/expression/BinaryRangeExpr.h"
|
||||
@@ -84,75 +83,19 @@ CompileExpressions(const std::vector<expr::TypedExprPtr>& sources,
|
||||
return exprs;
|
||||
}
|
||||
|
||||
static std::optional<std::string>
|
||||
ShouldFlatten(const expr::TypedExprPtr& expr,
|
||||
const std::unordered_set<std::string>& flat_candidates = {}) {
|
||||
if (auto call =
|
||||
std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr)) {
|
||||
if (call->op_type_ == expr::LogicalBinaryExpr::OpType::And ||
|
||||
call->op_type_ == expr::LogicalBinaryExpr::OpType::Or) {
|
||||
return call->name();
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static bool
|
||||
IsCall(const expr::TypedExprPtr& expr, const std::string& name) {
|
||||
if (auto call =
|
||||
std::dynamic_pointer_cast<const expr::LogicalBinaryExpr>(expr)) {
|
||||
return call->name() == name;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool
|
||||
AllInputTypeEqual(const expr::TypedExprPtr& expr) {
|
||||
const auto& inputs = expr->inputs();
|
||||
for (int i = 1; i < inputs.size(); i++) {
|
||||
if (inputs[0]->type() != inputs[i]->type()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void
|
||||
FlattenInput(const expr::TypedExprPtr& input,
|
||||
const std::string& flatten_call,
|
||||
std::vector<expr::TypedExprPtr>& flat) {
|
||||
if (IsCall(input, flatten_call) && AllInputTypeEqual(input)) {
|
||||
for (auto& child : input->inputs()) {
|
||||
FlattenInput(child, flatten_call, flat);
|
||||
}
|
||||
} else {
|
||||
flat.emplace_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ExprPtr>
|
||||
CompileInputs(const expr::TypedExprPtr& expr,
|
||||
QueryContext* context,
|
||||
const std::unordered_set<std::string>& flatten_cadidates) {
|
||||
std::vector<ExprPtr> compiled_inputs;
|
||||
auto flatten = ShouldFlatten(expr);
|
||||
for (auto& input : expr->inputs()) {
|
||||
if (dynamic_cast<const expr::InputTypeExpr*>(input.get())) {
|
||||
AssertInfo(
|
||||
dynamic_cast<const expr::FieldAccessTypeExpr*>(expr.get()),
|
||||
"An InputReference can only occur under a FieldReference");
|
||||
} else {
|
||||
if (flatten.has_value()) {
|
||||
std::vector<expr::TypedExprPtr> flat_exprs;
|
||||
FlattenInput(input, flatten.value(), flat_exprs);
|
||||
for (auto& input : flat_exprs) {
|
||||
compiled_inputs.push_back(CompileExpression(
|
||||
input, context, flatten_cadidates, false));
|
||||
}
|
||||
} else {
|
||||
compiled_inputs.push_back(CompileExpression(
|
||||
input, context, flatten_cadidates, false));
|
||||
}
|
||||
compiled_inputs.push_back(
|
||||
CompileExpression(input, context, flatten_cadidates, false));
|
||||
}
|
||||
}
|
||||
return compiled_inputs;
|
||||
@@ -507,170 +450,6 @@ ReorderConjunctExpr(std::shared_ptr<milvus::exec::PhyConjunctFilterExpr>& expr,
|
||||
expr->Reorder(reorder);
|
||||
}
|
||||
|
||||
inline std::shared_ptr<PhyLogicalUnaryExpr>
|
||||
ConvertMultiNotEqualToNotInExpr(std::vector<std::shared_ptr<Expr>>& exprs,
|
||||
std::vector<size_t> indices,
|
||||
ExecContext* context) {
|
||||
std::vector<proto::plan::GenericValue> values;
|
||||
auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
|
||||
for (auto& i : indices) {
|
||||
auto expr = std::static_pointer_cast<PhyUnaryRangeFilterExpr>(exprs[i])
|
||||
->GetLogicalExpr();
|
||||
if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) {
|
||||
type = expr->val_.val_case();
|
||||
}
|
||||
if (type != expr->val_.val_case()) {
|
||||
return nullptr;
|
||||
}
|
||||
values.push_back(expr->val_);
|
||||
}
|
||||
auto logical_expr = std::make_shared<milvus::expr::TermFilterExpr>(
|
||||
exprs[indices[0]]->GetColumnInfo().value(), values);
|
||||
auto query_context = context->get_query_context();
|
||||
auto term_expr = std::make_shared<PhyTermFilterExpr>(
|
||||
std::vector<std::shared_ptr<Expr>>{},
|
||||
logical_expr,
|
||||
"PhyTermFilterExpr",
|
||||
query_context->get_op_context(),
|
||||
query_context->get_segment(),
|
||||
query_context->get_active_count(),
|
||||
query_context->get_query_timestamp(),
|
||||
query_context->query_config()->get_expr_batch_size(),
|
||||
query_context->get_consistency_level());
|
||||
return std::make_shared<PhyLogicalUnaryExpr>(
|
||||
std::vector<std::shared_ptr<Expr>>{term_expr},
|
||||
std::make_shared<milvus::expr::LogicalUnaryExpr>(
|
||||
milvus::expr::LogicalUnaryExpr::OpType::LogicalNot, logical_expr),
|
||||
"PhyLogicalUnaryExpr",
|
||||
query_context->get_op_context());
|
||||
}
|
||||
|
||||
inline std::shared_ptr<PhyTermFilterExpr>
|
||||
ConvertMultiOrToInExpr(std::vector<std::shared_ptr<Expr>>& exprs,
|
||||
std::vector<size_t> indices,
|
||||
ExecContext* context) {
|
||||
std::vector<proto::plan::GenericValue> values;
|
||||
auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
|
||||
for (auto& i : indices) {
|
||||
auto expr = std::static_pointer_cast<PhyUnaryRangeFilterExpr>(exprs[i])
|
||||
->GetLogicalExpr();
|
||||
if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) {
|
||||
type = expr->val_.val_case();
|
||||
}
|
||||
if (type != expr->val_.val_case()) {
|
||||
return nullptr;
|
||||
}
|
||||
values.push_back(expr->val_);
|
||||
}
|
||||
auto logical_expr = std::make_shared<milvus::expr::TermFilterExpr>(
|
||||
exprs[indices[0]]->GetColumnInfo().value(), values);
|
||||
auto query_context = context->get_query_context();
|
||||
return std::make_shared<PhyTermFilterExpr>(
|
||||
std::vector<std::shared_ptr<Expr>>{},
|
||||
logical_expr,
|
||||
"PhyTermFilterExpr",
|
||||
query_context->get_op_context(),
|
||||
query_context->get_segment(),
|
||||
query_context->get_active_count(),
|
||||
query_context->get_query_timestamp(),
|
||||
query_context->query_config()->get_expr_batch_size(),
|
||||
query_context->get_consistency_level());
|
||||
}
|
||||
|
||||
inline void
|
||||
RewriteConjunctExpr(std::shared_ptr<milvus::exec::PhyConjunctFilterExpr>& expr,
|
||||
ExecContext* context) {
|
||||
// convert A = .. or A = .. or A = .. to A in (.., .., ..)
|
||||
if (expr->IsOr()) {
|
||||
auto& inputs = expr->GetInputsRef();
|
||||
std::map<expr::ColumnInfo, std::vector<size_t>> expr_indices;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto input = inputs[i];
|
||||
if (input->name() == "PhyUnaryRangeFilterExpr") {
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<PhyUnaryRangeFilterExpr>(input);
|
||||
if (phy_expr->GetOpType() == proto::plan::OpType::Equal) {
|
||||
auto column = phy_expr->GetColumnInfo().value();
|
||||
if (expr_indices.find(column) != expr_indices.end()) {
|
||||
expr_indices[column].push_back(i);
|
||||
} else {
|
||||
expr_indices[column] = {i};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (input->name() == "PhyConjunctFilterExpr") {
|
||||
auto expr = std::static_pointer_cast<
|
||||
milvus::exec::PhyConjunctFilterExpr>(input);
|
||||
RewriteConjunctExpr(expr, context);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& [column, indices] : expr_indices) {
|
||||
// For numeric columns, convert to an IN expr only when the OR chain on the column
// has more than DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT (150) equality terms.
// For all other column types, convert whenever there is more than one term.
|
||||
if ((IsNumericDataType(column.data_type_) &&
|
||||
indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) ||
|
||||
(!IsNumericDataType(column.data_type_) && indices.size() > 1)) {
|
||||
auto new_expr =
|
||||
ConvertMultiOrToInExpr(inputs, indices, context);
|
||||
if (new_expr) {
|
||||
inputs[indices[0]] = new_expr;
|
||||
for (size_t j = 1; j < indices.size(); j++) {
|
||||
inputs[indices[j]] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr),
|
||||
inputs.end());
|
||||
}
|
||||
|
||||
// convert A != .. and A != .. and A != .. to not A in (.., .., ..)
|
||||
if (expr->IsAnd()) {
|
||||
auto& inputs = expr->GetInputsRef();
|
||||
std::map<expr::ColumnInfo, std::vector<size_t>> expr_indices;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto input = inputs[i];
|
||||
if (input->name() == "PhyUnaryRangeFilterExpr") {
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<PhyUnaryRangeFilterExpr>(input);
|
||||
if (phy_expr->GetOpType() == proto::plan::OpType::NotEqual) {
|
||||
auto column = phy_expr->GetColumnInfo().value();
|
||||
if (expr_indices.find(column) != expr_indices.end()) {
|
||||
expr_indices[column].push_back(i);
|
||||
} else {
|
||||
expr_indices[column] = {i};
|
||||
}
|
||||
}
|
||||
|
||||
if (input->name() == "PhyConjunctFilterExpr") {
|
||||
auto expr = std::static_pointer_cast<
|
||||
milvus::exec::PhyConjunctFilterExpr>(input);
|
||||
RewriteConjunctExpr(expr, context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& [column, indices] : expr_indices) {
|
||||
if ((IsNumericDataType(column.data_type_) &&
|
||||
indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) ||
|
||||
(!IsNumericDataType(column.data_type_) && indices.size() > 1)) {
|
||||
auto new_expr =
|
||||
ConvertMultiNotEqualToNotInExpr(inputs, indices, context);
|
||||
if (new_expr) {
|
||||
inputs[indices[0]] = new_expr;
|
||||
for (size_t j = 1; j < indices.size(); j++) {
|
||||
inputs[indices[j]] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr),
|
||||
inputs.end());
|
||||
}
|
||||
}
|
||||
|
||||
inline void
|
||||
SetNamespaceSkipIndex(std::shared_ptr<PhyConjunctFilterExpr> conjunct_expr,
|
||||
ExecContext* context) {
|
||||
@@ -727,7 +506,6 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector<ExprPtr>& exprs) {
|
||||
LOG_DEBUG("before reorder filter expression: {}", expr->ToString());
|
||||
auto conjunct_expr =
|
||||
std::static_pointer_cast<PhyConjunctFilterExpr>(expr);
|
||||
RewriteConjunctExpr(conjunct_expr, context);
|
||||
bool has_heavy_operation = false;
|
||||
ReorderConjunctExpr(conjunct_expr, context, has_heavy_operation);
|
||||
LOG_DEBUG("after reorder filter expression: {}", expr->ToString());
|
||||
|
||||
@@ -256,183 +256,6 @@ TEST_P(TaskTest, LogicalExpr) {
|
||||
EXPECT_EQ(num_rows, num_rows_);
|
||||
}
|
||||
|
||||
TEST_P(TaskTest, CompileInputs_and) {
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto vec_fid =
|
||||
schema->AddDebugField("fakevec", GetParam(), 16, knowhere::metric::L2);
|
||||
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
|
||||
proto::plan::GenericValue val;
|
||||
val.set_int64_val(10);
|
||||
// expr: (int64_fid > 10 and int64_fid > 10) and (int64_fid > 10 and int64_fid > 10)
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr3, expr6);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
auto exprs = milvus::exec::CompileInputs(expr7, query_context.get(), {});
|
||||
EXPECT_EQ(exprs.size(), 4);
|
||||
for (int i = 0; i < exprs.size(); ++i) {
|
||||
std::cout << exprs[i]->name() << std::endl;
|
||||
EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(TaskTest, CompileInputs_or_with_and) {
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto vec_fid =
|
||||
schema->AddDebugField("fakevec", GetParam(), 16, knowhere::metric::L2);
|
||||
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
|
||||
proto::plan::GenericValue val;
|
||||
val.set_int64_val(10);
|
||||
{
|
||||
// expr: (int64_fid > 10 and int64_fid > 10) or (int64_fid > 10 and int64_fid > 10)
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr3, expr6);
|
||||
auto exprs =
|
||||
milvus::exec::CompileInputs(expr7, query_context.get(), {});
|
||||
EXPECT_EQ(exprs.size(), 2);
|
||||
for (int i = 0; i < exprs.size(); ++i) {
|
||||
std::cout << exprs[i]->name() << std::endl;
|
||||
EXPECT_STREQ(exprs[i]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
}
|
||||
}
|
||||
{
|
||||
// expr: (int64_fid > 10 or int64_fid > 10) or (int64_fid > 10 and int64_fid > 10)
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr3, expr6);
|
||||
auto exprs =
|
||||
milvus::exec::CompileInputs(expr7, query_context.get(), {});
|
||||
std::cout << exprs.size() << std::endl;
|
||||
EXPECT_EQ(exprs.size(), 3);
|
||||
for (int i = 0; i < exprs.size() - 1; ++i) {
|
||||
std::cout << exprs[i]->name() << std::endl;
|
||||
EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr");
|
||||
}
|
||||
EXPECT_STREQ(exprs[2]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
}
|
||||
{
|
||||
// expr: (int64_fid > 10 or int64_fid > 10) and (int64_fid > 10 and int64_fid > 10)
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(int64_fid, DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr3, expr6);
|
||||
auto exprs =
|
||||
milvus::exec::CompileInputs(expr7, query_context.get(), {});
|
||||
std::cout << exprs.size() << std::endl;
|
||||
EXPECT_EQ(exprs.size(), 3);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
for (int i = 1; i < exprs.size(); ++i) {
|
||||
std::cout << exprs[i]->name() << std::endl;
|
||||
EXPECT_STREQ(exprs[i]->name().c_str(), "PhyUnaryRangeFilterExpr");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(TaskTest, Test_reorder) {
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
@@ -698,530 +521,6 @@ TEST_P(TaskTest, Test_reorder) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(TaskTest, Test_MultiNotEqualConvert) {
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
using namespace milvus::exec;
|
||||
|
||||
{
|
||||
// expr: string2 != "111" and string2 != "222" and string2 != "333"
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_string_val("333");
|
||||
auto expr3 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr4 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr4, expr3);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 1);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyLogicalUnaryExpr");
|
||||
EXPECT_EQ(inputs[0]->GetInputsRef().size(), 1);
|
||||
EXPECT_STREQ(inputs[0]->GetInputsRef()[0]->name().c_str(),
|
||||
"PhyTermFilterExpr");
|
||||
}
|
||||
|
||||
{
|
||||
// expr: int64 != 111 and int64 != 222 and int64 != 333
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_int64_val(111);
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_int64_val(222);
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_int64_val(333);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr3, expr4);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 3);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyUnaryRangeFilterExpr");
|
||||
EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr");
|
||||
EXPECT_STREQ(inputs[2]->name().c_str(), "PhyUnaryRangeFilterExpr");
|
||||
}
|
||||
|
||||
{
|
||||
// expr: string2 != "111" and string2 != "222" and (int64 > 10 || int64 < 100) and string2 != "333"
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_int64_val(10);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val4;
|
||||
val4.set_int64_val(100);
|
||||
auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::LessThan,
|
||||
val4,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr4, expr5);
|
||||
|
||||
auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr6, expr3);
|
||||
|
||||
proto::plan::GenericValue val5;
|
||||
val5.set_string_val("333");
|
||||
auto expr8 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val5,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr9 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr7, expr8);
|
||||
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr9}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 2);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
EXPECT_STREQ(inputs[1]->name().c_str(), "PhyLogicalUnaryExpr");
|
||||
auto phy_expr1 =
|
||||
std::static_pointer_cast<milvus::exec::PhyLogicalUnaryExpr>(
|
||||
inputs[1]);
|
||||
EXPECT_EQ(phy_expr1->GetInputsRef().size(), 1);
|
||||
EXPECT_STREQ(phy_expr1->GetInputsRef()[0]->name().c_str(),
|
||||
"PhyTermFilterExpr");
|
||||
phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
inputs[0]);
|
||||
inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 2);
|
||||
}
|
||||
|
||||
{
|
||||
// expr: json['a'] != "111" and json['a'] != "222" and json['a'] != "333"
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr1, expr2);
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_string_val("333");
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::NotEqual,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr3, expr4);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 1);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyLogicalUnaryExpr");
|
||||
auto phy_expr1 =
|
||||
std::static_pointer_cast<milvus::exec::PhyLogicalUnaryExpr>(
|
||||
inputs[0]);
|
||||
EXPECT_EQ(phy_expr1->GetInputsRef().size(), 1);
|
||||
EXPECT_STREQ(phy_expr1->GetInputsRef()[0]->name().c_str(),
|
||||
"PhyTermFilterExpr");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(TaskTest, Test_MultiInConvert) {
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
using namespace milvus::exec;
|
||||
|
||||
{
|
||||
// expr: string2 == '111' or string2 == '222' or string2 == "333"
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::Equal,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::Equal,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_string_val("333");
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::Equal,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 1);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr");
|
||||
}
|
||||
|
||||
{
|
||||
// expr: string2 == '111' or string2 == '222' or (int64 > 10 && int64 < 100) or string2 == "333"
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::Equal,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::Equal,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_int64_val(10);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::GreaterThan,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val4;
|
||||
val4.set_int64_val(100);
|
||||
auto expr5 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::LessThan,
|
||||
val4,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr6 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::And, expr4, expr5);
|
||||
|
||||
auto expr7 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr6, expr3);
|
||||
|
||||
proto::plan::GenericValue val5;
|
||||
val5.set_string_val("333");
|
||||
auto expr8 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR),
|
||||
proto::plan::OpType::Equal,
|
||||
val5,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr9 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr7, expr8);
|
||||
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr9}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 2);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
EXPECT_STREQ(inputs[1]->name().c_str(), "PhyTermFilterExpr");
|
||||
}
|
||||
{
|
||||
// expr: json['a'] == "111" or json['a'] == "222" or json['a'] == "333"
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_string_val("333");
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 1);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr");
|
||||
}
|
||||
|
||||
{
|
||||
// expr: json['a'] == "111" or json['b'] == "222" or json['a'] == "333"
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'b'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_string_val("333");
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 2);
|
||||
EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr");
|
||||
EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr");
|
||||
}
|
||||
|
||||
{
|
||||
// expr: json['a'] == "111" or json['b'] == "222" or json['a'] == 1
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_string_val("111");
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_string_val("222");
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'b'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_int64_val(1);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["json"],
|
||||
DataType::JSON,
|
||||
std::vector<std::string>{'a'}),
|
||||
proto::plan::OpType::Equal,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 3);
|
||||
}
|
||||
|
||||
{
|
||||
// expr: int64 == 11 or int64 == 222 or int64 == 1
|
||||
proto::plan::GenericValue val1;
|
||||
val1.set_int64_val(11);
|
||||
auto expr1 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::Equal,
|
||||
val1,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
proto::plan::GenericValue val2;
|
||||
val2.set_int64_val(222);
|
||||
auto expr2 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::Equal,
|
||||
val2,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr3 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr1, expr2);
|
||||
proto::plan::GenericValue val3;
|
||||
val3.set_int64_val(1);
|
||||
auto expr4 = std::make_shared<expr::UnaryRangeFilterExpr>(
|
||||
expr::ColumnInfo(field_map_["int64"], DataType::INT64),
|
||||
proto::plan::OpType::Equal,
|
||||
val3,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto expr5 = std::make_shared<expr::LogicalBinaryExpr>(
|
||||
expr::LogicalBinaryExpr::OpType::Or, expr3, expr4);
|
||||
auto query_context = std::make_shared<milvus::exec::QueryContext>(
|
||||
DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP);
|
||||
ExecContext context(query_context.get());
|
||||
auto exprs =
|
||||
milvus::exec::CompileExpressions({expr5}, &context, {}, false);
|
||||
EXPECT_EQ(exprs.size(), 1);
|
||||
EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr");
|
||||
auto phy_expr =
|
||||
std::static_pointer_cast<milvus::exec::PhyConjunctFilterExpr>(
|
||||
exprs[0]);
|
||||
auto inputs = phy_expr->GetInputsRef();
|
||||
EXPECT_EQ(inputs.size(), 3);
|
||||
}
|
||||
}
|
||||
|
||||
// This test verifies the fix for https://github.com/milvus-io/milvus/issues/46053.
|
||||
//
|
||||
// Bug scenario:
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
@@ -149,6 +150,7 @@ func parseExprInner(schema *typeutil.SchemaHelper, exprStr string, exprTemplateV
|
||||
return nil, err
|
||||
}
|
||||
|
||||
predicate.expr = rewriter.RewriteExpr(predicate.expr)
|
||||
return predicate.expr, nil
|
||||
}
|
||||
|
||||
|
||||
internal/parser/planparserv2/rewriter/README.md (new file, 138 lines)
@@ -0,0 +1,138 @@
## Expression Rewriter (planparserv2/rewriter)

This module performs rule-based logical rewrites on parsed `planpb.Expr` trees right after template value filling and before planning/execution.

### Entry
- `RewriteExpr(*planpb.Expr) *planpb.Expr` (in `entry.go`)
  - Recursively visits the expression tree and applies a set of composable, side-effect-free rewrite rules.
  - Uses global configuration from `paramtable.Get().CommonCfg.EnabledOptimizeExpr`
- `RewriteExprWithConfig(*planpb.Expr, bool) *planpb.Expr` (in `entry.go`)
  - Same as `RewriteExpr` but allows custom configuration for testing or special cases.
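
For illustration only (this sketch is not part of the shipped package), a minimal driver that feeds a hand-built expression through the rewriter. The `rewriter`, `planpb`, and `schemapb` import paths are taken from this diff; the generated `planpb` field names used below (`ColumnInfo`, `Values`, `FieldId`, `Int64Val`) are assumed from the plan proto rather than documented here, and the expected result is what the IN-with-IN union rule (see Implemented Rules) should produce:

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)

// int64Val wraps an int64 literal into a planpb.GenericValue.
func int64Val(v int64) *planpb.GenericValue {
	return &planpb.GenericValue{Val: &planpb.GenericValue_Int64Val{Int64Val: v}}
}

// inList builds `col IN (values...)` as a planpb.Expr (a TermExpr leaf).
func inList(col *planpb.ColumnInfo, values ...*planpb.GenericValue) *planpb.Expr {
	return &planpb.Expr{Expr: &planpb.Expr_TermExpr{
		TermExpr: &planpb.TermExpr{ColumnInfo: col, Values: values},
	}}
}

func main() {
	col := &planpb.ColumnInfo{FieldId: 101, DataType: schemapb.DataType_Int64}

	// (col IN (3, 1)) OR (col IN (2, 3)): same column, so the IN-with-IN union
	// rule should collapse this into a single IN with a sorted, deduplicated list.
	orExpr := &planpb.Expr{Expr: &planpb.Expr_BinaryExpr{BinaryExpr: &planpb.BinaryExpr{
		Op:    planpb.BinaryExpr_LogicalOr,
		Left:  inList(col, int64Val(3), int64Val(1)),
		Right: inList(col, int64Val(2), int64Val(3)),
	}}}

	// RewriteExpr would read common.enabledOptimizeExpr from paramtable;
	// RewriteExprWithConfig pins the flag explicitly, which is handy in tests.
	rewritten := rewriter.RewriteExprWithConfig(orExpr, true)
	fmt.Println(rewritten.GetTermExpr().GetValues()) // expected: values 1, 2, 3
}
```

In the proxy, `parseExprInner` calls `RewriteExpr` on the parsed predicate (see the hunk above), so the paramtable flag governs production behavior; `RewriteExprWithConfig` is the variant to use when the flag must be fixed.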
### Configuration

The rewriter can be configured via the following parameter (refreshable at runtime):

| Parameter | Default | Description |
|-----------|---------|-------------|
| `common.enabledOptimizeExpr` | `true` | Enable query expression optimization including range simplification, IN/NOT IN merge, TEXT_MATCH merge, and all other optimizations |

**IMPORTANT**: IN/NOT IN value list sorting and deduplication **always** runs regardless of this configuration setting, because the execution engine depends on sorted value lists.
### Implemented Rules

1) IN / NOT IN normalization and merges (`term_in.go`)
   - OR-equals to IN (same column):
     - `a == v1 OR a == v2 ...` → `a IN (v1, v2, ...)`
     - Numeric columns only merge when count > threshold (default 150); others when count > 1.
   - AND-not-equals to NOT IN (same column):
     - `a != v1 AND a != v2 ...` → `NOT (a IN (v1, v2, ...))`
     - Same thresholds as above.
   - IN vs Equal redundancy elimination (same column):
     - AND: `(a ∈ S) AND (a = v)`:
       - if `v ∈ S` → `a = v`
       - if `v ∉ S` → contradiction → constant `false`
     - OR: `(a ∈ S) OR (a = v)` → `a ∈ (S ∪ {v})` (always union)
   - IN with IN union:
     - OR: `(a ∈ S1) OR (a ∈ S2)` → `a ∈ (S1 ∪ S2)` with sorting/dedup
     - AND: `(a ∈ S1) AND (a ∈ S2)` → `a ∈ (S1 ∩ S2)`; empty intersection → constant `false`
   - Sort and deduplicate `IN` / `NOT IN` value lists (supported types: bool, int64, float64, string).

2) TEXT_MATCH OR merge (`text_match.go`)
   - Merge ORs of `TEXT_MATCH(field, "literal")` on the same column (no options):
     - Concatenate literals with a single space in the order they appear; no tokenization, deduplication, or sorting is performed.
     - Example: `TEXT_MATCH(f, "A C") OR TEXT_MATCH(f, "B D")` → `TEXT_MATCH(f, "A C B D")`
   - If any `TEXT_MATCH` in the group has options (e.g., `minimum_should_match`), this optimization is skipped for that group.

3) Range predicate simplification (`range.go`)
   - AND tighten (same column):
     - Lower bounds: `a > 10 AND a > 20` → `a > 20` (pick strongest lower)
     - Upper bounds: `a < 50 AND a < 60` → `a < 50` (pick strongest upper)
     - Mixed lower and upper: `a > 10 AND a < 50` → `10 < a < 50` (BinaryRangeExpr)
     - Inclusion respected (>, >=, <, <=). On ties, exclusive is considered stronger than inclusive for tightening.
   - OR weaken (same column, same direction):
     - Lower bounds: `a > 10 OR a > 20` → `a > 10` (pick weakest lower)
     - Upper bounds: `a < 10 OR a < 20` → `a < 20` (pick weakest upper)
     - Inclusion respected, preferring inclusive for weakening in ties.
     - Mixed-direction OR (lower vs upper) is not merged.
   - Equivalent-bound collapses (same column, same value):
     - AND: `a ≥ x AND a > x` → `a > x`; `a ≤ y AND a < y` → `a < y`
     - OR: `a ≥ x OR a > x` → `a ≥ x`; `a ≤ y OR a < y` → `a ≤ y`
     - Symmetric dedup: `a > 10 AND a ≥ 10` → `a > 10`; `a < 5 OR a ≤ 5` → `a ≤ 5`
   - IN ∩ range filtering:
     - AND: `(a ∈ {…}) AND (range)` → keep only values in the set that satisfy the range
     - e.g., `{1,3,5} AND a > 3` → `{5}`
   - Supported columns for range optimization:
     - Scalar: Int8/Int16/Int32/Int64, Float/Double, VarChar
     - Array element access: when indexing an element (e.g., `ArrayInt[0]`), the element type above applies
     - JSON/dynamic fields with nested paths (e.g., `JSONField["price"]`, `$meta["age"]`) are range-optimized
       - Type determined from literal value (int, float, string)
       - Numeric types (int and float) are compatible and normalized to Double for merging
       - Different type categories are not merged (e.g., `json["a"] > 10` and `json["a"] > "hello"` remain separate)
       - Bool literals are not optimized (no meaningful ranges)
   - Literal compatibility:
     - Integer columns require integer literals (e.g., `Int64Field > 10`)
     - Float/Double columns accept both integer and float literals (e.g., `FloatField > 10` or `> 10.5`)
   - Column identity:
     - Merges only happen within the same `ColumnInfo` (including nested path and element index). For example, `ArrayInt[0]` and `ArrayInt[1]` are different columns and are not merged with each other.
   - BinaryRangeExpr merging:
     - AND: Merge multiple `BinaryRangeExpr` nodes on the same column to compute intersection (max lower, min upper)
       - `(10 < x < 50) AND (20 < x < 40)` → `(20 < x < 40)`
       - Empty intersection → constant `false`
     - AND with UnaryRangeExpr: Update appropriate bound of `BinaryRangeExpr`
       - `(10 < x < 50) AND (x > 30)` → `(30 < x < 50)`
     - OR: Merge overlapping or adjacent `BinaryRangeExpr` nodes into wider interval
       - `(10 < x < 25) OR (20 < x < 40)` → `(10 < x < 40)` (overlapping)
       - `(10 < x <= 20) OR (20 <= x < 30)` → `(10 < x < 30)` (adjacent with inclusive)
       - Disjoint intervals remain separate: `(10 < x < 20) OR (30 < x < 40)` → remains as OR
     - Inclusivity handling: AND prefers exclusive on equal bounds (stronger), OR prefers inclusive (weaker)
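
The bound selection described in (3) can be pictured with a small standalone sketch. This is not the `range.go` implementation (which works on `planpb.GenericValue` bounds in both directions); it only shows the tie-breaking logic for lower bounds under AND (tighten, exclusive wins ties) and OR (weaken, inclusive wins ties):

```go
package main

import "fmt"

// lowerBound models a same-column lower bound such as `a > v` or `a >= v`.
type lowerBound struct {
	value     float64
	inclusive bool // true for >=, false for >
}

// tightenAnd picks the stronger of two lower bounds for an AND:
// larger value wins; on equal values the exclusive form (>) is stronger.
func tightenAnd(x, y lowerBound) lowerBound {
	if x.value != y.value {
		if x.value > y.value {
			return x
		}
		return y
	}
	if !x.inclusive {
		return x
	}
	return y
}

// weakenOr picks the weaker of two lower bounds for an OR:
// smaller value wins; on equal values the inclusive form (>=) is weaker and is kept.
func weakenOr(x, y lowerBound) lowerBound {
	if x.value != y.value {
		if x.value < y.value {
			return x
		}
		return y
	}
	if x.inclusive {
		return x
	}
	return y
}

func main() {
	a := lowerBound{value: 10, inclusive: false} // a > 10
	b := lowerBound{value: 20, inclusive: false} // a > 20
	fmt.Println(tightenAnd(a, b)) // {20 false}: a > 10 AND a > 20 → a > 20
	fmt.Println(weakenOr(a, b))   // {10 false}: a > 10 OR a > 20 → a > 10

	c := lowerBound{value: 10, inclusive: true} // a >= 10
	fmt.Println(tightenAnd(a, c)) // {10 false}: a > 10 AND a >= 10 → a > 10
	fmt.Println(weakenOr(a, c))   // {10 true}:  a > 10 OR a >= 10 → a >= 10
}
```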
### General Notes
- All merges require operands to target the same column (same `ColumnInfo`, including nested path/element type).
- Rewrite runs after template value filling; template placeholders do not appear here.
- Sorting/dedup for IN/NOT IN is deterministic; duplicates are removed post-sort.
- Numeric threshold for OR→IN / AND≠→NOT IN is defined in `util.go` (`defaultConvertOrToInNumericLimit`, default 150).

### Pass Ordering (current)
- OR branch:
  1. Flatten
  2. OR `==` → IN
  3. TEXT_MATCH merge (no options)
  4. Range weaken (same-direction bounds)
  5. BinaryRangeExpr merge (overlapping/adjacent intervals)
  6. IN with `!=` short-circuiting
  7. IN ∪ IN union
  8. IN vs Equal redundancy elimination
  9. Fold back to BinaryExpr
- AND branch:
  1. Flatten
  2. Range tighten / interval construction
  3. BinaryRangeExpr merge (intersection, also with UnaryRangeExpr)
  4. IN ∩ IN intersection (if any)
  5. IN with `!=` filtering
  6. IN ∩ range filtering
  7. IN vs Equal redundancy elimination
  8. AND `!=` → NOT IN
  9. Fold back to BinaryExpr

Each construction of IN will be normalized (sorted and deduplicated). TEXT_MATCH OR merge concatenates literals with a single space; no tokenization, deduplication, or sorting is performed.
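
As an illustration of the AND-branch ordering (not a test from this patch), the sketch below builds `(col IN (1, 3, 5)) AND (col > 1) AND (col > 3)`: after flattening, range tightening should reduce the two lower bounds to `col > 3`, and IN ∩ range filtering should then keep only the value 5. The `UnaryRangeExpr` field names (`Op`, `Value`) and `planpb.OpType_GreaterThan` are assumed from the plan proto, and the final shape is the expected outcome per the rules above, not a guaranteed API contract:

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)

// i64 wraps an int64 literal into a planpb.GenericValue.
func i64(v int64) *planpb.GenericValue {
	return &planpb.GenericValue{Val: &planpb.GenericValue_Int64Val{Int64Val: v}}
}

// and joins two expressions with a logical AND node.
func and(l, r *planpb.Expr) *planpb.Expr {
	return &planpb.Expr{Expr: &planpb.Expr_BinaryExpr{BinaryExpr: &planpb.BinaryExpr{
		Op: planpb.BinaryExpr_LogicalAnd, Left: l, Right: r,
	}}}
}

func main() {
	col := &planpb.ColumnInfo{FieldId: 101, DataType: schemapb.DataType_Int64}

	// col IN (1, 3, 5)
	inExpr := &planpb.Expr{Expr: &planpb.Expr_TermExpr{
		TermExpr: &planpb.TermExpr{ColumnInfo: col, Values: []*planpb.GenericValue{i64(1), i64(3), i64(5)}},
	}}
	// col > v (UnaryRangeExpr field names are assumptions based on the proto)
	gt := func(v int64) *planpb.Expr {
		return &planpb.Expr{Expr: &planpb.Expr_UnaryRangeExpr{UnaryRangeExpr: &planpb.UnaryRangeExpr{
			ColumnInfo: col, Op: planpb.OpType_GreaterThan, Value: i64(v),
		}}}
	}

	// (col IN (1, 3, 5)) AND (col > 1) AND (col > 3)
	expr := and(and(inExpr, gt(1)), gt(3))

	out := rewriter.RewriteExprWithConfig(expr, true)
	fmt.Println(out) // expected to reduce to a predicate equivalent to col IN (5)
}
```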
### File Structure
- `entry.go` — rewrite entry and visitor orchestration
- `util.go` — shared helpers (column keying, value classification, sorting/dedup, constructors)
- `term_in.go` — IN/NOT IN normalization and conversions
- `text_match.go` — TEXT_MATCH OR merge (no options)
- `range.go` — range tightening/weakening and interval construction

### Future Extensions
- More IN-range algebra (e.g., `IN` vs exact equality propagation across subtrees).
- Merging phrase_match or other string ops with clearly-defined token rules.
- More algebraic simplifications around equality and null checks:
  - Contradiction detection: `(a == 1) AND (a == 2)` → `false`; `(a > 10) AND (a == 5)` → `false`
  - Tautology detection: `(a > 10) OR (a <= 10)` → `true` (for non-NULL values)
  - Absorption laws: `(a > 10) OR ((a > 10) AND (b > 20))` → `a > 10`
- Advanced BinaryRangeExpr merging:
  - OR with 3+ intervals: Currently limited to 2 intervals. Full interval merging algorithm needed for `(10 < x < 20) OR (15 < x < 25) OR (22 < x < 30)` → `(10 < x < 30)`.
  - OR with unbounded + bounded: Currently skipped. Could optimize `(x > 10) OR (5 < x < 15)` → `x > 5`.
internal/parser/planparserv2/rewriter/entry.go (new file, 200 lines)
@@ -0,0 +1,200 @@
package rewriter
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
func RewriteExpr(e *planpb.Expr) *planpb.Expr {
|
||||
optimizeEnabled := paramtable.Get().CommonCfg.EnabledOptimizeExpr.GetAsBool()
|
||||
return RewriteExprWithConfig(e, optimizeEnabled)
|
||||
}
|
||||
|
||||
func RewriteExprWithConfig(e *planpb.Expr, optimizeEnabled bool) *planpb.Expr {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
v := &visitor{optimizeEnabled: optimizeEnabled}
|
||||
res := v.visitExpr(e)
|
||||
if out, ok := res.(*planpb.Expr); ok && out != nil {
|
||||
return out
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
type visitor struct {
|
||||
optimizeEnabled bool
|
||||
}
|
||||
|
||||
func (v *visitor) visitExpr(expr *planpb.Expr) interface{} {
|
||||
switch real := expr.GetExpr().(type) {
|
||||
case *planpb.Expr_BinaryExpr:
|
||||
return v.visitBinaryExpr(real.BinaryExpr)
|
||||
case *planpb.Expr_UnaryExpr:
|
||||
return v.visitUnaryExpr(real.UnaryExpr)
|
||||
case *planpb.Expr_TermExpr:
|
||||
return v.visitTermExpr(real.TermExpr)
|
||||
// no optimization for other types
|
||||
default:
|
||||
return expr
|
||||
}
|
||||
}
|
||||
|
||||
func (v *visitor) visitBinaryExpr(expr *planpb.BinaryExpr) interface{} {
|
||||
left := v.visitExpr(expr.GetLeft()).(*planpb.Expr)
|
||||
right := v.visitExpr(expr.GetRight()).(*planpb.Expr)
|
||||
switch expr.GetOp() {
|
||||
case planpb.BinaryExpr_LogicalOr:
|
||||
parts := flattenOr(left, right)
|
||||
if v.optimizeEnabled {
|
||||
parts = v.combineOrEqualsToIn(parts)
|
||||
parts = v.combineOrTextMatchToMerged(parts)
|
||||
parts = v.combineOrRangePredicates(parts)
|
||||
parts = v.combineOrBinaryRanges(parts)
|
||||
parts = v.combineOrInWithNotEqual(parts)
|
||||
parts = v.combineOrInWithIn(parts)
|
||||
parts = v.combineOrInWithEqual(parts)
|
||||
}
|
||||
return foldBinary(planpb.BinaryExpr_LogicalOr, parts)
|
||||
case planpb.BinaryExpr_LogicalAnd:
|
||||
parts := flattenAnd(left, right)
|
||||
if v.optimizeEnabled {
|
||||
parts = v.combineAndRangePredicates(parts)
|
||||
parts = v.combineAndBinaryRanges(parts)
|
||||
parts = v.combineAndInWithIn(parts)
|
||||
parts = v.combineAndInWithNotEqual(parts)
|
||||
parts = v.combineAndInWithRange(parts)
|
||||
parts = v.combineAndInWithEqual(parts)
|
||||
parts = v.combineAndNotEqualsToNotIn(parts)
|
||||
}
|
||||
return foldBinary(planpb.BinaryExpr_LogicalAnd, parts)
|
||||
default:
|
||||
return &planpb.Expr{
|
||||
Expr: &planpb.Expr_BinaryExpr{
|
||||
BinaryExpr: &planpb.BinaryExpr{
|
||||
Left: left,
|
||||
Right: right,
|
||||
Op: expr.GetOp(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (v *visitor) visitUnaryExpr(expr *planpb.UnaryExpr) interface{} {
|
||||
child := v.visitExpr(expr.GetChild()).(*planpb.Expr)
|
||||
|
||||
// Constant-fold negation: NOT of an always-false child → AlwaysTrue
|
||||
if expr.GetOp() == planpb.UnaryExpr_Not {
|
||||
if IsAlwaysFalseExpr(child) {
|
||||
return newAlwaysTrueExpr()
|
||||
}
|
||||
}
|
||||
|
||||
return &planpb.Expr{
|
||||
Expr: &planpb.Expr_UnaryExpr{
|
||||
UnaryExpr: &planpb.UnaryExpr{
|
||||
Op: expr.GetOp(),
|
||||
Child: child,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *visitor) visitTermExpr(expr *planpb.TermExpr) interface{} {
	sortTermValues(expr)
	return &planpb.Expr{Expr: &planpb.Expr_TermExpr{TermExpr: expr}}
}

func flattenOr(a, b *planpb.Expr) []*planpb.Expr {
	out := make([]*planpb.Expr, 0, 4)
	collectOr(a, &out)
	collectOr(b, &out)
	return out
}

func collectOr(e *planpb.Expr, out *[]*planpb.Expr) {
	if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalOr {
		collectOr(be.GetLeft(), out)
		collectOr(be.GetRight(), out)
		return
	}
	*out = append(*out, e)
}

func flattenAnd(a, b *planpb.Expr) []*planpb.Expr {
	out := make([]*planpb.Expr, 0, 4)
	collectAnd(a, &out)
	collectAnd(b, &out)
	return out
}

func collectAnd(e *planpb.Expr, out *[]*planpb.Expr) {
	if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalAnd {
		collectAnd(be.GetLeft(), out)
		collectAnd(be.GetRight(), out)
		return
	}
	*out = append(*out, e)
}

func foldBinary(op planpb.BinaryExpr_BinaryOp, exprs []*planpb.Expr) *planpb.Expr {
	if len(exprs) == 0 {
		return nil
	}

	// Handle AlwaysTrue and AlwaysFalse optimizations (single-pass)
	switch op {
	case planpb.BinaryExpr_LogicalAnd:
		filtered := make([]*planpb.Expr, 0, len(exprs))
		for _, e := range exprs {
			if IsAlwaysFalseExpr(e) {
				// AND: any AlwaysFalse → entire expression is AlwaysFalse
				return newAlwaysFalseExpr()
			}
			if !IsAlwaysTrueExpr(e) {
				// Filter out AlwaysTrue (since AlwaysTrue AND X = X)
				filtered = append(filtered, e)
			}
		}
		exprs = filtered
		// If all were AlwaysTrue, return AlwaysTrue
		if len(exprs) == 0 {
			return newAlwaysTrueExpr()
		}
	case planpb.BinaryExpr_LogicalOr:
		filtered := make([]*planpb.Expr, 0, len(exprs))
		for _, e := range exprs {
			if IsAlwaysTrueExpr(e) {
				// OR: any AlwaysTrue → entire expression is AlwaysTrue
				return newAlwaysTrueExpr()
			}
			if !IsAlwaysFalseExpr(e) {
				// Filter out AlwaysFalse (since AlwaysFalse OR X = X)
				filtered = append(filtered, e)
			}
		}
		exprs = filtered
		// If all were AlwaysFalse, return AlwaysFalse
		if len(exprs) == 0 {
			return newAlwaysFalseExpr()
		}
	}

	if len(exprs) == 1 {
		return exprs[0]
	}
	cur := exprs[0]
	for i := 1; i < len(exprs); i++ {
		cur = &planpb.Expr{
			Expr: &planpb.Expr_BinaryExpr{
				BinaryExpr: &planpb.BinaryExpr{
					Left:  cur,
					Right: exprs[i],
					Op:    op,
				},
			},
		}
	}
	return cur
}

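// For example, folding [x > 1, AlwaysTrue, x < 5] with LogicalAnd drops the AlwaysTrue
// operand and yields `(x > 1) and (x < 5)`, while folding [AlwaysTrue, x < 5] with
// LogicalOr short-circuits to AlwaysTrue; if every operand is filtered out, the result
// collapses to the corresponding constant expression.
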
923  internal/parser/planparserv2/rewriter/range.go  Normal file
@@ -0,0 +1,923 @@
package rewriter

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)

type bound struct {
	value     *planpb.GenericValue
	inclusive bool
	isLower   bool
	exprIndex int
}

func isSupportedScalarForRange(dt schemapb.DataType) bool {
	switch dt {
	case schemapb.DataType_Int8,
		schemapb.DataType_Int16,
		schemapb.DataType_Int32,
		schemapb.DataType_Int64,
		schemapb.DataType_Float,
		schemapb.DataType_Double,
		schemapb.DataType_VarChar:
		return true
	default:
		return false
	}
}

// resolveEffectiveType returns (dt, ok) where ok indicates this column is eligible for range optimization.
// Eligible when the column is a supported scalar, or an array whose element type is a supported scalar,
// or a JSON field with a nested path (type will be determined from literal values).
func resolveEffectiveType(col *planpb.ColumnInfo) (schemapb.DataType, bool) {
	if col == nil {
		return schemapb.DataType_None, false
	}
	dt := col.GetDataType()
	if isSupportedScalarForRange(dt) {
		return dt, true
	}
	if dt == schemapb.DataType_Array {
		et := col.GetElementType()
		if isSupportedScalarForRange(et) {
			return et, true
		}
	}
	// JSON fields with nested paths are eligible; the effective type will be
	// determined from the literal value in the comparison.
	if dt == schemapb.DataType_JSON && len(col.GetNestedPath()) > 0 {
		// Return a placeholder type; actual type checking happens in resolveJSONEffectiveType
		return schemapb.DataType_JSON, true
	}
	return schemapb.DataType_None, false
}

// resolveJSONEffectiveType returns the effective type for a JSON field based on the literal value.
// Returns (type, ok) where ok is false if the value is not suitable for range optimization.
// For numeric types (int and float), we normalize to Double to allow mixing, similar to scalar Float/Double columns.
func resolveJSONEffectiveType(v *planpb.GenericValue) (schemapb.DataType, bool) {
	if v == nil || v.GetVal() == nil {
		return schemapb.DataType_None, false
	}
	switch v.GetVal().(type) {
	case *planpb.GenericValue_Int64Val:
		// Normalize int to Double for JSON to allow mixing with float literals
		return schemapb.DataType_Double, true
	case *planpb.GenericValue_FloatVal:
		return schemapb.DataType_Double, true
	case *planpb.GenericValue_StringVal:
		return schemapb.DataType_VarChar, true
	case *planpb.GenericValue_BoolVal:
		// Boolean comparisons don't have meaningful ranges
		return schemapb.DataType_None, false
	default:
		return schemapb.DataType_None, false
	}
}

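// For example, the literals in `json["price"] > 10` and `json["price"] < 99.5` both
// resolve to Double, so the two predicates fall into the same group and can be merged,
// whereas `json["name"] > "a"` resolves to VarChar and is grouped separately.
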
func valueMatchesType(dt schemapb.DataType, v *planpb.GenericValue) bool {
	if v == nil || v.GetVal() == nil {
		return false
	}
	switch dt {
	case schemapb.DataType_Int8,
		schemapb.DataType_Int16,
		schemapb.DataType_Int32,
		schemapb.DataType_Int64:
		_, ok := v.GetVal().(*planpb.GenericValue_Int64Val)
		return ok
	case schemapb.DataType_Float, schemapb.DataType_Double:
		// For float columns, accept both float and int literal values
		switch v.GetVal().(type) {
		case *planpb.GenericValue_FloatVal, *planpb.GenericValue_Int64Val:
			return true
		default:
			return false
		}
	case schemapb.DataType_VarChar:
		_, ok := v.GetVal().(*planpb.GenericValue_StringVal)
		return ok
	case schemapb.DataType_JSON:
		// For JSON, check if we can determine a valid type from the value
		_, ok := resolveJSONEffectiveType(v)
		return ok
	default:
		return false
	}
}

func (v *visitor) combineAndRangePredicates(parts []*planpb.Expr) []*planpb.Expr {
	type group struct {
		col    *planpb.ColumnInfo
		effDt  schemapb.DataType // effective type for comparison
		lowers []bound
		uppers []bound
	}
	groups := map[string]*group{}
	// exprs not eligible for range optimization
	others := []int{}
	isRangeOp := func(op planpb.OpType) bool {
		return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
			op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
	}
	for idx, e := range parts {
		u := e.GetUnaryRangeExpr()
		if u == nil || !isRangeOp(u.GetOp()) || u.GetValue() == nil {
			others = append(others, idx)
			continue
		}
		col := u.GetColumnInfo()
		if col == nil {
			others = append(others, idx)
			continue
		}
		// Only optimize for supported types and matching value type
		effDt, ok := resolveEffectiveType(col)
		if !ok || !valueMatchesType(effDt, u.GetValue()) {
			others = append(others, idx)
			continue
		}
		// For JSON fields, determine the actual effective type from the literal value
		if effDt == schemapb.DataType_JSON {
			var typeOk bool
			effDt, typeOk = resolveJSONEffectiveType(u.GetValue())
			if !typeOk {
				others = append(others, idx)
				continue
			}
		}
		// Group by column + effective type (for JSON, type depends on literal)
		key := columnKey(col) + fmt.Sprintf("|%d", effDt)
		g, ok := groups[key]
		if !ok {
			g = &group{col: col, effDt: effDt}
			groups[key] = g
		}
		b := bound{
			value:     u.GetValue(),
			inclusive: u.GetOp() == planpb.OpType_GreaterEqual || u.GetOp() == planpb.OpType_LessEqual,
			isLower:   u.GetOp() == planpb.OpType_GreaterThan || u.GetOp() == planpb.OpType_GreaterEqual,
			exprIndex: idx,
		}
		if b.isLower {
			g.lowers = append(g.lowers, b)
		} else {
			g.uppers = append(g.uppers, b)
		}
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		// Use the effective type stored in the group
		var bestLower *bound
		for i := range g.lowers {
			if bestLower == nil || cmpGeneric(g.effDt, g.lowers[i].value, bestLower.value) > 0 ||
				(cmpGeneric(g.effDt, g.lowers[i].value, bestLower.value) == 0 && !g.lowers[i].inclusive && bestLower.inclusive) {
				b := g.lowers[i]
				bestLower = &b
			}
		}
		var bestUpper *bound
		for i := range g.uppers {
			if bestUpper == nil || cmpGeneric(g.effDt, g.uppers[i].value, bestUpper.value) < 0 ||
				(cmpGeneric(g.effDt, g.uppers[i].value, bestUpper.value) == 0 && !g.uppers[i].inclusive && bestUpper.inclusive) {
				b := g.uppers[i]
				bestUpper = &b
			}
		}
		if bestLower != nil && bestUpper != nil {
			// Check if the interval is valid (non-empty)
			c := cmpGeneric(g.effDt, bestLower.value, bestUpper.value)
			isEmpty := false
			if c > 0 {
				// lower > upper: always empty
				isEmpty = true
			} else if c == 0 {
				// lower == upper: only valid if both bounds are inclusive
				if !bestLower.inclusive || !bestUpper.inclusive {
					isEmpty = true
				}
			}

			for _, b := range g.lowers {
				used[b.exprIndex] = true
			}
			for _, b := range g.uppers {
				used[b.exprIndex] = true
			}

			if isEmpty {
				// Empty interval → constant false
				out = append(out, newAlwaysFalseExpr())
			} else {
				out = append(out, newBinaryRangeExpr(g.col, bestLower.inclusive, bestUpper.inclusive, bestLower.value, bestUpper.value))
			}
		} else if bestLower != nil {
			for _, b := range g.lowers {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_GreaterThan
			if bestLower.inclusive {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, bestLower.value))
		} else if bestUpper != nil {
			for _, b := range g.uppers {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_LessThan
			if bestUpper.inclusive {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, bestUpper.value))
		}
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}

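// Worked example: for `x > 10 and x >= 20 and x < 50` the lower bounds {> 10, >= 20}
// reduce to the tighter `>= 20`, the single upper bound stays `< 50`, and the group is
// emitted as one BinaryRangeExpr equivalent to `20 <= x < 50`; for `x > 10 and x < 5`
// the interval is empty and the conjunct becomes AlwaysFalse.
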
func (v *visitor) combineOrRangePredicates(parts []*planpb.Expr) []*planpb.Expr {
	type key struct {
		colKey  string
		isLower bool
		effDt   schemapb.DataType // effective type for JSON fields
	}
	type group struct {
		col      *planpb.ColumnInfo
		effDt    schemapb.DataType
		dirLower bool
		bounds   []bound
	}
	groups := map[key]*group{}
	others := []int{}
	isRangeOp := func(op planpb.OpType) bool {
		return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
			op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
	}
	for idx, e := range parts {
		u := e.GetUnaryRangeExpr()
		if u == nil || !isRangeOp(u.GetOp()) || u.GetValue() == nil {
			others = append(others, idx)
			continue
		}
		col := u.GetColumnInfo()
		if col == nil {
			others = append(others, idx)
			continue
		}
		effDt, ok := resolveEffectiveType(col)
		if !ok || !valueMatchesType(effDt, u.GetValue()) {
			others = append(others, idx)
			continue
		}
		// For JSON fields, determine the actual effective type from the literal value
		if effDt == schemapb.DataType_JSON {
			var typeOk bool
			effDt, typeOk = resolveJSONEffectiveType(u.GetValue())
			if !typeOk {
				others = append(others, idx)
				continue
			}
		}
		isLower := u.GetOp() == planpb.OpType_GreaterThan || u.GetOp() == planpb.OpType_GreaterEqual
		k := key{colKey: columnKey(col), isLower: isLower, effDt: effDt}
		g, ok := groups[k]
		if !ok {
			g = &group{col: col, effDt: effDt, dirLower: isLower}
			groups[k] = g
		}
		g.bounds = append(g.bounds, bound{
			value:     u.GetValue(),
			inclusive: u.GetOp() == planpb.OpType_GreaterEqual || u.GetOp() == planpb.OpType_LessEqual,
			isLower:   isLower,
			exprIndex: idx,
		})
	}
	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}
	for _, g := range groups {
		if len(g.bounds) <= 1 {
			continue
		}
		if g.dirLower {
			var best *bound
			for i := range g.bounds {
				// Use the effective type stored in the group
				if best == nil || cmpGeneric(g.effDt, g.bounds[i].value, best.value) < 0 ||
					(cmpGeneric(g.effDt, g.bounds[i].value, best.value) == 0 && g.bounds[i].inclusive && !best.inclusive) {
					b := g.bounds[i]
					best = &b
				}
			}
			for _, b := range g.bounds {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_GreaterThan
			if best.inclusive {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, best.value))
		} else {
			var best *bound
			for i := range g.bounds {
				// Use the effective type stored in the group
				if best == nil || cmpGeneric(g.effDt, g.bounds[i].value, best.value) > 0 ||
					(cmpGeneric(g.effDt, g.bounds[i].value, best.value) == 0 && g.bounds[i].inclusive && !best.inclusive) {
					b := g.bounds[i]
					best = &b
				}
			}
			for _, b := range g.bounds {
				used[b.exprIndex] = true
			}
			op := planpb.OpType_LessThan
			if best.inclusive {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, best.value))
		}
	}
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}
	return out
}

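// Worked example: `x > 10 or x >= 20` keeps only the weaker `x > 10`. Bounds pointing in
// opposite directions are grouped separately, so `x > 10 or x < 5` is left untouched by
// this pass (and combineOrBinaryRanges below also skips that shape).
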
func newBinaryRangeExpr(col *planpb.ColumnInfo, lowerInclusive bool, upperInclusive bool, lower *planpb.GenericValue, upper *planpb.GenericValue) *planpb.Expr {
	return &planpb.Expr{
		Expr: &planpb.Expr_BinaryRangeExpr{
			BinaryRangeExpr: &planpb.BinaryRangeExpr{
				ColumnInfo:     col,
				LowerInclusive: lowerInclusive,
				UpperInclusive: upperInclusive,
				LowerValue:     lower,
				UpperValue:     upper,
			},
		},
	}
}

// cmpGeneric compares two literal values under the given effective type:
// -1 means a < b, 0 means a == b, 1 means a > b
func cmpGeneric(dt schemapb.DataType, a, b *planpb.GenericValue) int {
	switch dt {
	case schemapb.DataType_Int8,
		schemapb.DataType_Int16,
		schemapb.DataType_Int32,
		schemapb.DataType_Int64:
		ai, bi := a.GetInt64Val(), b.GetInt64Val()
		if ai < bi {
			return -1
		}
		if ai > bi {
			return 1
		}
		return 0
	case schemapb.DataType_Float, schemapb.DataType_Double:
		// Allow comparing int and float by promoting to float
		toFloat := func(g *planpb.GenericValue) float64 {
			switch g.GetVal().(type) {
			case *planpb.GenericValue_FloatVal:
				return g.GetFloatVal()
			case *planpb.GenericValue_Int64Val:
				return float64(g.GetInt64Val())
			default:
				// Should not happen due to gate; treat as 0 deterministically
				return 0
			}
		}
		af, bf := toFloat(a), toFloat(b)
		if af < bf {
			return -1
		}
		if af > bf {
			return 1
		}
		return 0
	case schemapb.DataType_String,
		schemapb.DataType_VarChar:
		as, bs := a.GetStringVal(), b.GetStringVal()
		if as < bs {
			return -1
		}
		if as > bs {
			return 1
		}
		return 0
	default:
		// Unsupported types are not optimized; callers gate with resolveEffectiveType.
		return 0
	}
}

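// For example, under DataType_Double an Int64Val(3) literal is promoted to 3.0 before the
// comparison, so cmpGeneric(Double, Int64Val(3), FloatVal(3.5)) returns -1; this is what
// lets Float/Double columns and JSON paths mix int and float literals within one group.
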
// combineAndBinaryRanges merges BinaryRangeExpr nodes with AND semantics (intersection).
// Also handles mixing BinaryRangeExpr with UnaryRangeExpr.
func (v *visitor) combineAndBinaryRanges(parts []*planpb.Expr) []*planpb.Expr {
	type interval struct {
		lower         *planpb.GenericValue
		lowerInc      bool
		upper         *planpb.GenericValue
		upperInc      bool
		exprIndex     int
		isBinaryRange bool
	}
	type group struct {
		col       *planpb.ColumnInfo
		effDt     schemapb.DataType
		intervals []interval
	}
	groups := map[string]*group{}
	others := []int{}

	for idx, e := range parts {
		// Try BinaryRangeExpr
		if bre := e.GetBinaryRangeExpr(); bre != nil {
			col := bre.GetColumnInfo()
			if col == nil {
				others = append(others, idx)
				continue
			}
			effDt, ok := resolveEffectiveType(col)
			if !ok {
				others = append(others, idx)
				continue
			}
			// For JSON, determine actual type from lower value
			if effDt == schemapb.DataType_JSON {
				var typeOk bool
				effDt, typeOk = resolveJSONEffectiveType(bre.GetLowerValue())
				if !typeOk {
					others = append(others, idx)
					continue
				}
			}
			key := columnKey(col) + fmt.Sprintf("|%d", effDt)
			g, exists := groups[key]
			if !exists {
				g = &group{col: col, effDt: effDt}
				groups[key] = g
			}
			g.intervals = append(g.intervals, interval{
				lower:         bre.GetLowerValue(),
				lowerInc:      bre.GetLowerInclusive(),
				upper:         bre.GetUpperValue(),
				upperInc:      bre.GetUpperInclusive(),
				exprIndex:     idx,
				isBinaryRange: true,
			})
			continue
		}

		// Try UnaryRangeExpr (range ops only)
		if ure := e.GetUnaryRangeExpr(); ure != nil {
			op := ure.GetOp()
			if op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
				op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual {
				col := ure.GetColumnInfo()
				if col == nil {
					others = append(others, idx)
					continue
				}
				effDt, ok := resolveEffectiveType(col)
				if !ok || !valueMatchesType(effDt, ure.GetValue()) {
					others = append(others, idx)
					continue
				}
				if effDt == schemapb.DataType_JSON {
					var typeOk bool
					effDt, typeOk = resolveJSONEffectiveType(ure.GetValue())
					if !typeOk {
						others = append(others, idx)
						continue
					}
				}
				key := columnKey(col) + fmt.Sprintf("|%d", effDt)
				g, exists := groups[key]
				if !exists {
					g = &group{col: col, effDt: effDt}
					groups[key] = g
				}
				isLower := op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual
				inc := op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessEqual
				if isLower {
					g.intervals = append(g.intervals, interval{
						lower:         ure.GetValue(),
						lowerInc:      inc,
						upper:         nil,
						upperInc:      false,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				} else {
					g.intervals = append(g.intervals, interval{
						lower:         nil,
						lowerInc:      false,
						upper:         ure.GetValue(),
						upperInc:      inc,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				}
				continue
			}
		}

		// Not a range expr we can optimize
		others = append(others, idx)
	}

	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}

	for _, g := range groups {
		if len(g.intervals) == 0 {
			continue
		}
		if len(g.intervals) == 1 {
			// Single interval, keep as is
			continue
		}

		// Compute intersection: max lower, min upper
		var finalLower *planpb.GenericValue
		var finalLowerInc bool
		var finalUpper *planpb.GenericValue
		var finalUpperInc bool

		for _, iv := range g.intervals {
			if iv.lower != nil {
				if finalLower == nil {
					finalLower = iv.lower
					finalLowerInc = iv.lowerInc
				} else {
					c := cmpGeneric(g.effDt, iv.lower, finalLower)
					if c > 0 || (c == 0 && !iv.lowerInc) {
						finalLower = iv.lower
						finalLowerInc = iv.lowerInc
					}
				}
			}
			if iv.upper != nil {
				if finalUpper == nil {
					finalUpper = iv.upper
					finalUpperInc = iv.upperInc
				} else {
					c := cmpGeneric(g.effDt, iv.upper, finalUpper)
					if c < 0 || (c == 0 && !iv.upperInc) {
						finalUpper = iv.upper
						finalUpperInc = iv.upperInc
					}
				}
			}
		}

		// Check if intersection is empty
		if finalLower != nil && finalUpper != nil {
			c := cmpGeneric(g.effDt, finalLower, finalUpper)
			isEmpty := false
			if c > 0 {
				isEmpty = true
			} else if c == 0 {
				// Equal bounds: only valid if both inclusive
				if !finalLowerInc || !finalUpperInc {
					isEmpty = true
				}
			}
			if isEmpty {
				// Empty intersection → constant false
				for _, iv := range g.intervals {
					used[iv.exprIndex] = true
				}
				out = append(out, newAlwaysFalseExpr())
				continue
			}
		}

		// Mark all intervals as used
		for _, iv := range g.intervals {
			used[iv.exprIndex] = true
		}

		// Emit the merged interval
		if finalLower != nil && finalUpper != nil {
			out = append(out, newBinaryRangeExpr(g.col, finalLowerInc, finalUpperInc, finalLower, finalUpper))
		} else if finalLower != nil {
			op := planpb.OpType_GreaterThan
			if finalLowerInc {
				op = planpb.OpType_GreaterEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, finalLower))
		} else if finalUpper != nil {
			op := planpb.OpType_LessThan
			if finalUpperInc {
				op = planpb.OpType_LessEqual
			}
			out = append(out, newUnaryRangeExpr(g.col, op, finalUpper))
		}
	}

	// Add unused parts
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}

	return out
}

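// Worked example: with parts `(10 < x < 50)` and `x >= 20` on the same column, both land
// in one group, the intersection is max(lower) = 20 (inclusive) and min(upper) = 50
// (exclusive), and a single BinaryRangeExpr `20 <= x < 50` is emitted; a disjoint pair
// such as `(10 < x < 20)` and `x >= 30` collapses to AlwaysFalse.
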
// combineOrBinaryRanges merges BinaryRangeExpr nodes with OR semantics (union if overlapping/adjacent).
// Also handles mixing BinaryRangeExpr with UnaryRangeExpr.
func (v *visitor) combineOrBinaryRanges(parts []*planpb.Expr) []*planpb.Expr {
	type interval struct {
		lower         *planpb.GenericValue
		lowerInc      bool
		upper         *planpb.GenericValue
		upperInc      bool
		exprIndex     int
		isBinaryRange bool
	}
	type group struct {
		col       *planpb.ColumnInfo
		effDt     schemapb.DataType
		intervals []interval
	}
	groups := map[string]*group{}
	others := []int{}

	for idx, e := range parts {
		// Try BinaryRangeExpr
		if bre := e.GetBinaryRangeExpr(); bre != nil {
			col := bre.GetColumnInfo()
			if col == nil {
				others = append(others, idx)
				continue
			}
			effDt, ok := resolveEffectiveType(col)
			if !ok {
				others = append(others, idx)
				continue
			}
			if effDt == schemapb.DataType_JSON {
				var typeOk bool
				effDt, typeOk = resolveJSONEffectiveType(bre.GetLowerValue())
				if !typeOk {
					others = append(others, idx)
					continue
				}
			}
			key := columnKey(col) + fmt.Sprintf("|%d", effDt)
			g, exists := groups[key]
			if !exists {
				g = &group{col: col, effDt: effDt}
				groups[key] = g
			}
			g.intervals = append(g.intervals, interval{
				lower:         bre.GetLowerValue(),
				lowerInc:      bre.GetLowerInclusive(),
				upper:         bre.GetUpperValue(),
				upperInc:      bre.GetUpperInclusive(),
				exprIndex:     idx,
				isBinaryRange: true,
			})
			continue
		}

		// Try UnaryRangeExpr
		if ure := e.GetUnaryRangeExpr(); ure != nil {
			op := ure.GetOp()
			if op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual ||
				op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual {
				col := ure.GetColumnInfo()
				if col == nil {
					others = append(others, idx)
					continue
				}
				effDt, ok := resolveEffectiveType(col)
				if !ok || !valueMatchesType(effDt, ure.GetValue()) {
					others = append(others, idx)
					continue
				}
				if effDt == schemapb.DataType_JSON {
					var typeOk bool
					effDt, typeOk = resolveJSONEffectiveType(ure.GetValue())
					if !typeOk {
						others = append(others, idx)
						continue
					}
				}
				key := columnKey(col) + fmt.Sprintf("|%d", effDt)
				g, exists := groups[key]
				if !exists {
					g = &group{col: col, effDt: effDt}
					groups[key] = g
				}
				isLower := op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual
				inc := op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessEqual
				if isLower {
					g.intervals = append(g.intervals, interval{
						lower:         ure.GetValue(),
						lowerInc:      inc,
						upper:         nil,
						upperInc:      false,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				} else {
					g.intervals = append(g.intervals, interval{
						lower:         nil,
						lowerInc:      false,
						upper:         ure.GetValue(),
						upperInc:      inc,
						exprIndex:     idx,
						isBinaryRange: false,
					})
				}
				continue
			}
		}

		others = append(others, idx)
	}

	used := make([]bool, len(parts))
	out := make([]*planpb.Expr, 0, len(parts))
	for _, idx := range others {
		out = append(out, parts[idx])
		used[idx] = true
	}

	for _, g := range groups {
		if len(g.intervals) == 0 {
			continue
		}
		if len(g.intervals) == 1 {
			// Single interval, keep as is
			continue
		}

		// For OR, try to merge overlapping/adjacent intervals.
		// If any interval is unbounded on one side, check if it subsumes others.
		// For simplicity, we handle the common cases:
		// 1. All bounded intervals: try to merge if overlapping/adjacent
		// 2. Mix of bounded/unbounded: merge unbounded with compatible bounds

		var hasUnboundedLower, hasUnboundedUpper bool
		var unboundedLowerVal *planpb.GenericValue
		var unboundedLowerInc bool
		var unboundedUpperVal *planpb.GenericValue
		var unboundedUpperInc bool

		// Check for unbounded intervals
		for _, iv := range g.intervals {
			if iv.lower != nil && iv.upper == nil {
				// Lower bound only (x > a)
				if !hasUnboundedLower {
					hasUnboundedLower = true
					unboundedLowerVal = iv.lower
					unboundedLowerInc = iv.lowerInc
				} else {
					// Multiple lower-only bounds: take weakest (minimum)
					c := cmpGeneric(g.effDt, iv.lower, unboundedLowerVal)
					if c < 0 || (c == 0 && iv.lowerInc && !unboundedLowerInc) {
						unboundedLowerVal = iv.lower
						unboundedLowerInc = iv.lowerInc
					}
				}
			}
			if iv.lower == nil && iv.upper != nil {
				// Upper bound only (x < b)
				if !hasUnboundedUpper {
					hasUnboundedUpper = true
					unboundedUpperVal = iv.upper
					unboundedUpperInc = iv.upperInc
				} else {
					// Multiple upper-only bounds: take weakest (maximum)
					c := cmpGeneric(g.effDt, iv.upper, unboundedUpperVal)
					if c > 0 || (c == 0 && iv.upperInc && !unboundedUpperInc) {
						unboundedUpperVal = iv.upper
						unboundedUpperInc = iv.upperInc
					}
				}
			}
		}

		// Case 1: Have both unbounded lower and upper → entire domain (always true, but we can't express that simply)
		// For now, keep them separate
		// Case 2: Have one unbounded side → merge with compatible bounded intervals
		// Case 3: All bounded → try to merge overlapping/adjacent

		if hasUnboundedLower && hasUnboundedUpper {
			// Both unbounded sides - this likely covers most values
			// Keep as separate predicates for now (more advanced merging could be done)
			continue
		}

		if hasUnboundedLower || hasUnboundedUpper {
			// Merge unbounded with bounded intervals where applicable
			// For unbounded lower (x > a): can merge with binary ranges that have compatible upper bounds
			// For unbounded upper (x < b): can merge with binary ranges that have compatible lower bounds
			// This is complex, so for now we keep it simple and just skip merging
			// In practice, unbounded intervals often dominate
			continue
		}

		// All bounded intervals: try to merge overlapping/adjacent ones
		// This requires sorting and checking overlap
		// For simplicity in this initial implementation, we check if there are exactly 2 intervals
		// and try to merge them if they overlap or are adjacent

		if len(g.intervals) == 2 {
			iv1, iv2 := g.intervals[0], g.intervals[1]
			if iv1.lower == nil || iv1.upper == nil || iv2.lower == nil || iv2.upper == nil {
				// One is not fully bounded, skip
				continue
			}

			// Check if they overlap or are adjacent
			// They overlap if: iv1.lower <= iv2.upper AND iv2.lower <= iv1.upper
			// They are adjacent if: iv1.upper == iv2.lower (or vice versa) with at least one inclusive

			// Determine order: which has smaller lower bound
			var first, second interval
			c := cmpGeneric(g.effDt, iv1.lower, iv2.lower)
			if c <= 0 {
				first, second = iv1, iv2
			} else {
				first, second = iv2, iv1
			}

			// Check if they can be merged
			// Overlap: first.upper >= second.lower
			cmpUpperLower := cmpGeneric(g.effDt, first.upper, second.lower)
			canMerge := false
			if cmpUpperLower > 0 {
				// Overlap
				canMerge = true
			} else if cmpUpperLower == 0 {
				// Adjacent: at least one bound must be inclusive
				if first.upperInc || second.lowerInc {
					canMerge = true
				}
			}

			if canMerge {
				// Merge: take min lower and max upper
				mergedLower := first.lower
				mergedLowerInc := first.lowerInc
				mergedUpper := second.upper
				mergedUpperInc := second.upperInc

				// Upper bound: take maximum
				cmpUppers := cmpGeneric(g.effDt, first.upper, second.upper)
				if cmpUppers > 0 {
					mergedUpper = first.upper
					mergedUpperInc = first.upperInc
				} else if cmpUppers == 0 {
					// Same value: prefer inclusive
					if first.upperInc {
						mergedUpperInc = true
					}
				}

				// Mark both as used
				used[first.exprIndex] = true
				used[second.exprIndex] = true

				// Emit merged interval
				out = append(out, newBinaryRangeExpr(g.col, mergedLowerInc, mergedUpperInc, mergedLower, mergedUpper))
			}
		}
		// For more than 2 intervals, we'd need more sophisticated merging logic
		// For now, we leave them separate
	}

	// Add unused parts
	for i := range parts {
		if !used[i] {
			out = append(out, parts[i])
		}
	}

	return out
}

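// Current scope of the OR merge, as exercised by the tests below: exactly two fully
// bounded intervals on the same column are unioned when they overlap or touch with an
// inclusive endpoint (e.g. `(10 < x < 25) or (20 < x < 40)` → `10 < x < 40`); groups
// with three or more intervals, or with an unbounded side, are deliberately left as-is.
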
412  internal/parser/planparserv2/rewriter/range_binary_test.go  Normal file
@@ -0,0 +1,412 @@
package rewriter_test

import (
	"testing"

	"github.com/stretchr/testify/require"

	parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
	"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)

// Test BinaryRangeExpr AND BinaryRangeExpr - intersection
func TestRewrite_BinaryRange_AND_BinaryRange_Intersection(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (20 < x < 40) → (20 < x < 40)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 20 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre, "should merge to single binary range")
	require.Equal(t, false, bre.GetLowerInclusive())
	require.Equal(t, false, bre.GetUpperInclusive())
	require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(40), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr AND BinaryRangeExpr - tighter lower
func TestRewrite_BinaryRange_AND_BinaryRange_TighterLower(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (5 < x < 40) → (10 < x < 40)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 5 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(40), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr AND BinaryRangeExpr - tighter upper
func TestRewrite_BinaryRange_AND_BinaryRange_TighterUpper(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (15 < x < 60) → (15 < x < 50)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (Int64Field > 15 and Int64Field < 60)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(15), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr AND BinaryRangeExpr - empty intersection
func TestRewrite_BinaryRange_AND_BinaryRange_EmptyIntersection(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) AND (30 < x < 40) → false
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) and (Int64Field > 30 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.True(t, rewriter.IsAlwaysFalseExpr(expr))
}

// Test BinaryRangeExpr AND BinaryRangeExpr - equal bounds, both inclusive
func TestRewrite_BinaryRange_AND_BinaryRange_EqualBounds_BothInclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 <= x <= 10) AND (10 <= x <= 20) → (x == 10) which is (10 <= x <= 10)
	expr, err := parser.ParseExpr(helper, `(Int64Field >= 10 and Int64Field <= 10) and (Int64Field >= 10 and Int64Field <= 20)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, true, bre.GetLowerInclusive())
	require.Equal(t, true, bre.GetUpperInclusive())
	require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(10), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr AND BinaryRangeExpr - equal bounds, one exclusive → false
func TestRewrite_BinaryRange_AND_BinaryRange_EqualBounds_Exclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x <= 20) AND (5 <= x < 10) → false (bounds meet at 10 but exclusive)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) and (Int64Field >= 5 and Int64Field < 10)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	require.True(t, rewriter.IsAlwaysFalseExpr(expr))
}

// Test BinaryRangeExpr AND UnaryRangeExpr - tighten lower bound
func TestRewrite_BinaryRange_AND_UnaryRange_TightenLower(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (x > 30) → (30 < x < 50)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and Int64Field > 30`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(30), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr AND UnaryRangeExpr - tighten upper bound
func TestRewrite_BinaryRange_AND_UnaryRange_TightenUpper(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 50) AND (x < 25) → (10 < x < 25)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and Int64Field < 25`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(25), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr AND UnaryRangeExpr - weaker bound (no change)
func TestRewrite_BinaryRange_AND_UnaryRange_WeakerBound(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (20 < x < 30) AND (x > 10) → (20 < x < 30) (10 is weaker than 20)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 20 and Int64Field < 30) and Int64Field > 10`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(30), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr OR BinaryRangeExpr - overlapping → union
func TestRewrite_BinaryRange_OR_BinaryRange_Overlapping(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 25) OR (20 < x < 40) → (10 < x < 40)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 25) or (Int64Field > 20 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre, "overlapping intervals should merge")
	require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(40), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr OR BinaryRangeExpr - adjacent (inclusive) → union
func TestRewrite_BinaryRange_OR_BinaryRange_Adjacent(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x <= 20) OR (20 <= x < 30) → (10 < x < 30)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) or (Int64Field >= 20 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre, "adjacent intervals should merge")
	require.Equal(t, false, bre.GetLowerInclusive())
	require.Equal(t, false, bre.GetUpperInclusive())
	require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(30), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRangeExpr OR BinaryRangeExpr - disjoint (no merge)
func TestRewrite_BinaryRange_OR_BinaryRange_Disjoint(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) OR (30 < x < 40) → remains as OR
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 30 and Int64Field < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	// Should remain as BinaryExpr OR since intervals are disjoint
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "disjoint intervals should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
}

// Test BinaryRangeExpr OR BinaryRangeExpr - adjacent but both exclusive (no merge)
func TestRewrite_BinaryRange_OR_BinaryRange_AdjacentExclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) OR (20 < x < 30) → remains as OR (gap at 20)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 20 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "exclusive adjacent intervals should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
}

// Test BinaryRangeExpr OR BinaryRangeExpr - prefer inclusive when merging
func TestRewrite_BinaryRange_OR_BinaryRange_PreferInclusive(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 <= x < 25) OR (15 <= x <= 30) → (10 <= x <= 30)
	expr, err := parser.ParseExpr(helper, `(Int64Field >= 10 and Int64Field < 25) or (Int64Field >= 15 and Int64Field <= 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	// Lower should be 10 (from first), upper should be 30 (from second)
	// Both should prefer inclusive where available
	require.Equal(t, true, bre.GetLowerInclusive())
	require.Equal(t, true, bre.GetUpperInclusive())
	require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(30), bre.GetUpperValue().GetInt64Val())
}

// Test with Float fields
func TestRewrite_BinaryRange_AND_Float(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (1.0 < x < 5.0) AND (2.0 < x < 4.0) → (2.0 < x < 4.0)
	expr, err := parser.ParseExpr(helper, `(FloatField > 1.0 and FloatField < 5.0) and (FloatField > 2.0 and FloatField < 4.0)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.InDelta(t, 2.0, bre.GetLowerValue().GetFloatVal(), 1e-9)
	require.InDelta(t, 4.0, bre.GetUpperValue().GetFloatVal(), 1e-9)
}

// Test with VarChar fields
func TestRewrite_BinaryRange_OR_VarChar_Overlapping(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// ("b" < x < "m") OR ("j" < x < "z") → ("b" < x < "z")
	expr, err := parser.ParseExpr(helper, `(VarCharField > "b" and VarCharField < "m") or (VarCharField > "j" and VarCharField < "z")`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, "b", bre.GetLowerValue().GetStringVal())
	require.Equal(t, "z", bre.GetUpperValue().GetStringVal())
}

// Test with JSON fields
func TestRewrite_BinaryRange_AND_JSON(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	// (10 < json["price"] < 100) AND (20 < json["price"] < 80) → (20 < json["price"] < 80)
	expr, err := parser.ParseExpr(helper, `(JSONField["price"] > 10 and JSONField["price"] < 100) and (JSONField["price"] > 20 and JSONField["price"] < 80)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(80), bre.GetUpperValue().GetInt64Val())
}

// Test with JSON fields - OR overlapping
func TestRewrite_BinaryRange_OR_JSON_Overlapping(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	// (1.0 < json["score"] < 3.0) OR (2.5 < json["score"] < 5.0) → (1.0 < json["score"] < 5.0)
	expr, err := parser.ParseExpr(helper, `(JSONField["score"] > 1.0 and JSONField["score"] < 3.0) or (JSONField["score"] > 2.5 and JSONField["score"] < 5.0)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.InDelta(t, 1.0, bre.GetLowerValue().GetFloatVal(), 1e-9)
	require.InDelta(t, 5.0, bre.GetUpperValue().GetFloatVal(), 1e-9)
}

// Test mixing BinaryRange with multiple UnaryRanges
func TestRewrite_BinaryRange_AND_MultipleUnary(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 100) AND (x > 20) AND (x < 80) → (20 < x < 80)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 100) and Int64Field > 20 and Int64Field < 80`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(80), bre.GetUpperValue().GetInt64Val())
}

// Test three BinaryRanges with AND
func TestRewrite_BinaryRange_AND_ThreeBinaryRanges(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (5 < x < 100) AND (10 < x < 90) AND (15 < x < 80) → (15 < x < 80)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 5 and Int64Field < 100) and (Int64Field > 10 and Int64Field < 90) and (Int64Field > 15 and Int64Field < 80)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, int64(15), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(80), bre.GetUpperValue().GetInt64Val())
}

// Test BinaryRange on different columns should not merge
func TestRewrite_BinaryRange_AND_DifferentColumns(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < Int64Field < 50) AND (20 < FloatField < 40) → both remain
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 50) and (FloatField > 20 and FloatField < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "different columns should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
}

// Test BinaryRange with JSON different paths should not merge
func TestRewrite_BinaryRange_AND_JSON_DifferentPaths(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	// (10 < json["a"] < 50) AND (20 < json["b"] < 40) → both remain
	expr, err := parser.ParseExpr(helper, `(JSONField["a"] > 10 and JSONField["a"] < 50) and (JSONField["b"] > 20 and JSONField["b"] < 40)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "different JSON paths should not merge")
	require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
}

// Test BinaryRangeExpr OR with 3 overlapping intervals
// NOTE: Current implementation limitation - only merges 2 intervals at a time
func TestRewrite_BinaryRange_OR_ThreeOverlapping_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 20) OR (15 < x < 25) OR (22 < x < 30)
	// Ideally should merge to (10 < x < 30)
	// Currently: may only partially merge due to limitation
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 20) or (Int64Field > 15 and Int64Field < 25) or (Int64Field > 22 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)

	// Due to current limitation, result may vary depending on tree structure
	// This test documents the current behavior rather than ideal behavior
	// When enhancement is implemented, this test should be updated to verify (10 < x < 30)

	// For now, just verify it doesn't crash and produces valid output
	require.NotNil(t, expr)
	// Could be BinaryRangeExpr (if some merged) or BinaryExpr OR (if not merged)
	// We document that 3+ intervals are NOT fully optimized yet
}

// Test BinaryRangeExpr OR with 3 fully overlapping intervals
func TestRewrite_BinaryRange_OR_ThreeFullyOverlapping(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x < 30) OR (12 < x < 28) OR (15 < x < 25)
	// The second and third are fully contained in the first
	// Ideally should merge to (10 < x < 30)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field < 30) or (Int64Field > 12 and Int64Field < 28) or (Int64Field > 15 and Int64Field < 25)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)

	// Current limitation: may not fully optimize
	// This test documents that 3+ interval merging is not yet complete
	require.NotNil(t, expr)
}

// Test BinaryRangeExpr OR with 4 adjacent intervals
func TestRewrite_BinaryRange_OR_FourAdjacent_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 < x <= 20) OR (20 < x <= 30) OR (30 < x <= 40) OR (40 < x <= 50)
	// Ideally should merge to (10 < x <= 50)
	expr, err := parser.ParseExpr(helper, `(Int64Field > 10 and Int64Field <= 20) or (Int64Field > 20 and Int64Field <= 30) or (Int64Field > 30 and Int64Field <= 40) or (Int64Field > 40 and Int64Field <= 50)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)

	// Current limitation: only pairs of adjacent intervals may merge
	// Full chain merging not implemented
	require.NotNil(t, expr)
}

// Test OR with unbounded lower + bounded interval
// NOTE: Current implementation limitation - unbounded intervals not merged with bounded
func TestRewrite_BinaryRange_OR_UnboundedLower_Bounded_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x > 10) OR (5 < x < 15)
	// Ideally should merge to (x > 5)
	// Currently: both predicates remain separate
	expr, err := parser.ParseExpr(helper, `Int64Field > 10 or (Int64Field > 5 and Int64Field < 15)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)

	// Current limitation: unbounded + bounded intervals not optimized
	// Result should be a BinaryExpr OR with both predicates
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "should remain as OR (not merged)")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
}

// Test OR with unbounded upper + bounded interval
func TestRewrite_BinaryRange_OR_UnboundedUpper_Bounded_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x < 20) OR (10 < x < 30)
	// Ideally should merge to (x < 30)
	// Currently: both predicates remain separate
	expr, err := parser.ParseExpr(helper, `Int64Field < 20 or (Int64Field > 10 and Int64Field < 30)`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)

	// Current limitation: unbounded + bounded intervals not optimized
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "should remain as OR (not merged)")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
}

// Test OR with unbounded lower + unbounded upper
func TestRewrite_BinaryRange_OR_UnboundedBoth_CurrentLimitation(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x > 10) OR (x < 20)
	// This covers most values (gap only between 10 and 20 if both exclusive)
	// Currently: both predicates remain separate
	expr, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field < 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)

	// Current limitation: unbounded intervals in OR not merged
	be := expr.GetBinaryExpr()
	require.NotNil(t, be, "should remain as OR")
	require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
}

// Test OR with multiple unbounded lower bounds - these DO get optimized (weakening)
func TestRewrite_BinaryRange_OR_MultipleUnboundedLower(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (x > 10) OR (x > 20) → (x > 10) [weaker bound]
	expr, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field > 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)

	// This SHOULD be optimized (weakening works for same direction)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure, "should merge to single unary range")
	require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
	require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
}

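// A minimal additional sketch, assuming the buildSchemaHelperForRewriteT helper and the
// ParseExpr defaults used by the tests above: a unary lower bound should tighten an
// inclusive binary range through the AND pipeline.
func TestRewrite_BinaryRange_Sketch_MixedAnd(t *testing.T) {
	helper := buildSchemaHelperForRewriteT(t)
	// (10 <= x < 50) AND (x >= 20) → (20 <= x < 50)
	expr, err := parser.ParseExpr(helper, `(Int64Field >= 10 and Int64Field < 50) and Int64Field >= 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, true, bre.GetLowerInclusive())
	require.Equal(t, false, bre.GetUpperInclusive())
	require.Equal(t, int64(20), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
}
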
294  internal/parser/planparserv2/rewriter/range_json_test.go  Normal file
@@ -0,0 +1,294 @@
package rewriter_test

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

func buildSchemaHelperWithJSON(t *testing.T) *typeutil.SchemaHelper {
	fields := []*schemapb.FieldSchema{
		{FieldID: 101, Name: "Int64Field", DataType: schemapb.DataType_Int64},
		{FieldID: 102, Name: "JSONField", DataType: schemapb.DataType_JSON},
		{FieldID: 103, Name: "$meta", DataType: schemapb.DataType_JSON, IsDynamic: true},
	}
	schema := &schemapb.CollectionSchema{
		Name:               "rewrite_json_test",
		AutoID:             false,
		Fields:             fields,
		EnableDynamicField: true,
	}
	helper, err := typeutil.CreateSchemaHelper(schema)
	require.NoError(t, err)
	return helper
}

// Test JSON field with int comparison - AND tightening
func TestRewrite_JSON_Int_AND_Strengthen(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	expr, err := parser.ParseExpr(helper, `JSONField["price"] > 10 and JSONField["price"] > 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure, "should merge to single unary range")
	require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
	require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
	// Verify column info
	require.Equal(t, schemapb.DataType_JSON, ure.GetColumnInfo().GetDataType())
	require.Equal(t, []string{"price"}, ure.GetColumnInfo().GetNestedPath())
}

// Test JSON field with int comparison - AND to BinaryRange
func TestRewrite_JSON_Int_AND_ToBinaryRange(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	expr, err := parser.ParseExpr(helper, `JSONField["price"] > 10 and JSONField["price"] < 50`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre, "should create binary range")
	require.Equal(t, false, bre.GetLowerInclusive())
	require.Equal(t, false, bre.GetUpperInclusive())
	require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
	require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
	// Verify column info
	require.Equal(t, schemapb.DataType_JSON, bre.GetColumnInfo().GetDataType())
	require.Equal(t, []string{"price"}, bre.GetColumnInfo().GetNestedPath())
}

// Test JSON field with float comparison - AND tightening with mixed int/float
func TestRewrite_JSON_Float_AND_Strengthen_Mixed(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	expr, err := parser.ParseExpr(helper, `JSONField["score"] > 10 and JSONField["score"] > 15.5`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure)
	require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
	require.InDelta(t, 15.5, ure.GetValue().GetFloatVal(), 1e-9)
}

// Test JSON field with string comparison - AND tightening
func TestRewrite_JSON_String_AND_Strengthen(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	expr, err := parser.ParseExpr(helper, `JSONField["name"] > "alice" and JSONField["name"] > "bob"`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure)
	require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
	require.Equal(t, "bob", ure.GetValue().GetStringVal())
}

// Test JSON field with string comparison - AND to BinaryRange
func TestRewrite_JSON_String_AND_ToBinaryRange(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	expr, err := parser.ParseExpr(helper, `JSONField["name"] >= "alice" and JSONField["name"] <= "zebra"`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	bre := expr.GetBinaryRangeExpr()
	require.NotNil(t, bre)
	require.Equal(t, true, bre.GetLowerInclusive())
	require.Equal(t, true, bre.GetUpperInclusive())
	require.Equal(t, "alice", bre.GetLowerValue().GetStringVal())
	require.Equal(t, "zebra", bre.GetUpperValue().GetStringVal())
}

// Test JSON field with OR - weakening lower bounds
func TestRewrite_JSON_Int_OR_Weaken(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	expr, err := parser.ParseExpr(helper, `JSONField["age"] > 10 or JSONField["age"] > 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure)
	require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
	require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
}

// Test JSON field with OR - weakening upper bounds
func TestRewrite_JSON_String_OR_Weaken_Upper(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	expr, err := parser.ParseExpr(helper, `JSONField["category"] < "electronics" or JSONField["category"] < "sports"`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure)
	require.Equal(t, planpb.OpType_LessThan, ure.GetOp())
	require.Equal(t, "sports", ure.GetValue().GetStringVal())
}

// Test JSON field with numeric types (int and float) - SHOULD merge
func TestRewrite_JSON_NumericTypes_Merge(t *testing.T) {
	helper := buildSchemaHelperWithJSON(t)
	// JSONField["value"] > 10.5 (float) and JSONField["value"] > 20 (int)
	// Numeric types should merge - both treated as Double
	expr, err := parser.ParseExpr(helper, `JSONField["value"] > 10.5 and JSONField["value"] > 20`, nil)
	require.NoError(t, err)
	require.NotNil(t, expr)
	ure := expr.GetUnaryRangeExpr()
	require.NotNil(t, ure, "numeric types should merge to single predicate")
	require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
// Should pick the stronger bound (20)
|
||||
switch ure.GetValue().GetVal().(type) {
|
||||
case *planpb.GenericValue_Int64Val:
|
||||
require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
|
||||
case *planpb.GenericValue_FloatVal:
|
||||
require.InDelta(t, 20.0, ure.GetValue().GetFloatVal(), 1e-9)
|
||||
default:
|
||||
t.Fatalf("unexpected value type")
|
||||
}
|
||||
}
|
||||
|
||||
// Test JSON field with mixed type categories - should NOT merge
|
||||
// Note: Numeric types (int and float) ARE compatible, but numeric and string are not
|
||||
func TestRewrite_JSON_MixedTypes_NoMerge(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
// JSONField["value"] > 10 (numeric) and JSONField["value"] > "hello" (string)
|
||||
// These should remain separate as they have different type categories
|
||||
expr, err := parser.ParseExpr(helper, `JSONField["value"] > 10 and JSONField["value"] > "hello"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be, "should remain as binary expr (AND)")
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
// Both sides should be UnaryRangeExpr
|
||||
require.NotNil(t, be.GetLeft().GetUnaryRangeExpr())
|
||||
require.NotNil(t, be.GetRight().GetUnaryRangeExpr())
|
||||
}
|
||||
|
||||
// Test JSON field with mixed type categories in OR - should NOT merge
|
||||
func TestRewrite_JSON_MixedTypes_OR_NoMerge(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
expr, err := parser.ParseExpr(helper, `JSONField["data"] > 100 or JSONField["data"] > "text"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be, "should remain as binary expr (OR)")
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
|
||||
}
|
||||
|
||||
// Test dynamic field ($meta) - same as JSON field
|
||||
func TestRewrite_DynamicField_Int_AND_Strengthen(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
// $meta is the dynamic field
|
||||
expr, err := parser.ParseExpr(helper, `$meta["count"] > 5 and $meta["count"] > 15`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(15), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
// Test dynamic field - AND to BinaryRange
|
||||
func TestRewrite_DynamicField_Float_AND_ToBinaryRange(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
expr, err := parser.ParseExpr(helper, `$meta["rating"] >= 1.0 and $meta["rating"] <= 5.0`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
bre := expr.GetBinaryRangeExpr()
|
||||
require.NotNil(t, bre)
|
||||
require.Equal(t, true, bre.GetLowerInclusive())
|
||||
require.Equal(t, true, bre.GetUpperInclusive())
|
||||
require.InDelta(t, 1.0, bre.GetLowerValue().GetFloatVal(), 1e-9)
|
||||
require.InDelta(t, 5.0, bre.GetUpperValue().GetFloatVal(), 1e-9)
|
||||
}
|
||||
|
||||
// Test different nested paths - should NOT merge
|
||||
func TestRewrite_JSON_DifferentPaths_NoMerge(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
expr, err := parser.ParseExpr(helper, `JSONField["a"] > 10 and JSONField["b"] > 20`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be, "different paths should not merge")
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
}
|
||||
|
||||
// Test nested JSON path
|
||||
func TestRewrite_JSON_NestedPath_AND_Strengthen(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
expr, err := parser.ParseExpr(helper, `JSONField["user"]["age"] > 18 and JSONField["user"]["age"] > 21`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(21), ure.GetValue().GetInt64Val())
|
||||
require.Equal(t, []string{"user", "age"}, ure.GetColumnInfo().GetNestedPath())
|
||||
}
|
||||
|
||||
// Test inclusive vs exclusive bounds
|
||||
func TestRewrite_JSON_AND_EquivalentBounds(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
// JSONField["x"] >= 10 AND JSONField["x"] > 10 → JSONField["x"] > 10 (exclusive is stronger)
|
||||
expr, err := parser.ParseExpr(helper, `JSONField["x"] >= 10 and JSONField["x"] > 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
// Test OR with inclusive bounds preference
|
||||
func TestRewrite_JSON_OR_EquivalentBounds(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
// JSONField["x"] >= 10 OR JSONField["x"] > 10 → JSONField["x"] >= 10 (inclusive is weaker)
|
||||
expr, err := parser.ParseExpr(helper, `JSONField["x"] >= 10 or JSONField["x"] > 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterEqual, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
// Test that scalar field and JSON field don't interfere
|
||||
func TestRewrite_JSON_And_Scalar_Independent(t *testing.T) {
|
||||
helper := buildSchemaHelperWithJSON(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field > 20 and JSONField["price"] > 5 and JSONField["price"] > 15`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
// Should result in AND of two optimized predicates
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be)
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
|
||||
// Collect all predicates
|
||||
var predicates []*planpb.Expr
|
||||
var collect func(*planpb.Expr)
|
||||
collect = func(e *planpb.Expr) {
|
||||
if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalAnd {
|
||||
collect(be.GetLeft())
|
||||
collect(be.GetRight())
|
||||
} else {
|
||||
predicates = append(predicates, e)
|
||||
}
|
||||
}
|
||||
collect(expr)
|
||||
|
||||
require.Equal(t, 2, len(predicates), "should have two optimized predicates")
|
||||
|
||||
// Check that both are optimized to the stronger bounds
|
||||
var hasInt64, hasJSON bool
|
||||
for _, p := range predicates {
|
||||
if ure := p.GetUnaryRangeExpr(); ure != nil {
|
||||
col := ure.GetColumnInfo()
|
||||
if col.GetDataType() == schemapb.DataType_Int64 {
|
||||
hasInt64 = true
|
||||
require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
|
||||
} else if col.GetDataType() == schemapb.DataType_JSON {
|
||||
hasJSON = true
|
||||
require.Equal(t, int64(15), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, hasInt64, "should have optimized Int64Field predicate")
|
||||
require.True(t, hasJSON, "should have optimized JSONField predicate")
|
||||
}
internal/parser/planparserv2/rewriter/range_test.go (new file, 492 lines)
@@ -0,0 +1,492 @@
package rewriter_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func TestRewrite_Range_AND_Strengthen(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field > 20`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_AND_Strengthen_Upper(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field < 50 and Int64Field < 60`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_LessThan, ure.GetOp())
|
||||
require.Equal(t, int64(50), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_AND_EquivalentBounds(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// a ≥ x AND a > x → a > x
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field >= 10 and Int64Field > 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
// a ≤ y AND a < y → a < y
|
||||
expr, err = parser.ParseExpr(helper, `Int64Field <= 10 and Int64Field < 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure = expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_LessThan, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_OR_Weaken(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field > 20`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_OR_Weaken_Upper(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field < 10 or Int64Field < 20`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_LessThan, ure.GetOp())
|
||||
require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_OR_EquivalentBounds(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// a ≥ x OR a > x → a ≥ x
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field >= 10 or Int64Field > 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterEqual, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
// a ≤ y OR a < y → a ≤ y
|
||||
expr, err = parser.ParseExpr(helper, `Int64Field <= 10 or Int64Field < 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure = expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_LessEqual, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_AND_ToBinaryRange(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field < 50`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
bre := expr.GetBinaryRangeExpr()
|
||||
require.NotNil(t, bre)
|
||||
require.Equal(t, false, bre.GetLowerInclusive())
|
||||
require.Equal(t, false, bre.GetUpperInclusive())
|
||||
require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
|
||||
require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_AND_ToBinaryRange_Inclusive(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field >= 10 and Int64Field <= 50`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
bre := expr.GetBinaryRangeExpr()
|
||||
require.NotNil(t, bre)
|
||||
require.Equal(t, true, bre.GetLowerInclusive())
|
||||
require.Equal(t, true, bre.GetUpperInclusive())
|
||||
require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
|
||||
require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_OR_MixedDirection_NoMerge(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 10 or Int64Field < 5`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be)
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
|
||||
require.NotNil(t, be.GetLeft().GetUnaryRangeExpr())
|
||||
require.NotNil(t, be.GetRight().GetUnaryRangeExpr())
|
||||
}
|
||||
|
||||
// Edge cases for Float/Double columns: allow mixing int and float literals.
|
||||
func TestRewrite_Range_AND_Strengthen_Float_Mixed(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `FloatField > 10 and FloatField > 15.0`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.InDelta(t, 15.0, ure.GetValue().GetFloatVal(), 1e-9)
|
||||
}
|
||||
|
||||
func TestRewrite_Range_AND_ToBinaryRange_Float_Mixed(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `FloatField > 10 and FloatField < 20.5`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
bre := expr.GetBinaryRangeExpr()
|
||||
require.NotNil(t, bre)
|
||||
require.False(t, bre.GetLowerInclusive())
|
||||
require.False(t, bre.GetUpperInclusive())
|
||||
// lower may be encoded as int or float; assert either encoding equals 10
|
||||
lv := bre.GetLowerValue()
|
||||
switch lv.GetVal().(type) {
|
||||
case *planpb.GenericValue_Int64Val:
|
||||
require.Equal(t, int64(10), lv.GetInt64Val())
|
||||
case *planpb.GenericValue_FloatVal:
|
||||
require.InDelta(t, 10.0, lv.GetFloatVal(), 1e-9)
|
||||
default:
|
||||
t.Fatalf("unexpected lower value type")
|
||||
}
|
||||
// upper is float literal 20.5
|
||||
require.InDelta(t, 20.5, bre.GetUpperValue().GetFloatVal(), 1e-9)
|
||||
}
|
||||
|
||||
func TestRewrite_Range_OR_Weaken_Float_Mixed(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `FloatField > 10 or FloatField > 20.5`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
// weakest lower is 10; value may be encoded as int or float
|
||||
switch ure.GetValue().GetVal().(type) {
|
||||
case *planpb.GenericValue_Int64Val:
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
case *planpb.GenericValue_FloatVal:
|
||||
require.InDelta(t, 10.0, ure.GetValue().GetFloatVal(), 1e-9)
|
||||
default:
|
||||
t.Fatalf("unexpected value type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewrite_Range_AND_NoMerge_Int_WithFloatLiteral(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
_, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field > 15.0`, nil)
|
||||
// The parser rejects comparing an Int64 field against a float literal, so this
// combination never reaches the optimizer; this test documents that behavior.
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRewrite_Range_Tie_Inclusive_Float(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// OR: >=10 or >10 -> >=10 (weaken prefers inclusive)
|
||||
expr, err := parser.ParseExpr(helper, `FloatField >= 10 or FloatField > 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterEqual, ure.GetOp())
|
||||
// AND: >=10 and >10 -> >10 (tighten prefers strict)
|
||||
expr, err = parser.ParseExpr(helper, `FloatField >= 10 and FloatField > 10`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure = expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
}
|
||||
|
||||
// VarChar range optimization tests
|
||||
func TestRewrite_Range_VarChar_AND_Strengthen(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `VarCharField > "a" and VarCharField > "b"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, "b", ure.GetValue().GetStringVal())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_VarChar_OR_Weaken_Upper(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `VarCharField < "m" or VarCharField < "z"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_LessThan, ure.GetOp())
|
||||
require.Equal(t, "z", ure.GetValue().GetStringVal())
|
||||
}
|
||||
|
||||
// Array fields: ensure parser rejects direct range comparison on arrays for different element types.
|
||||
// If in the future parser supports range on arrays, these tests can be updated accordingly.
|
||||
func buildSchemaHelperWithArraysT(t *testing.T) *typeutil.SchemaHelper {
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{FieldID: 201, Name: "ArrayInt", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
|
||||
{FieldID: 202, Name: "ArrayFloat", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double},
|
||||
{FieldID: 203, Name: "ArrayVarchar", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_VarChar},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "rewrite_array_test",
|
||||
AutoID: false,
|
||||
Fields: fields,
|
||||
}
|
||||
helper, err := typeutil.CreateSchemaHelper(schema)
|
||||
require.NoError(t, err)
|
||||
return helper
|
||||
}
|
||||
|
||||
func TestRewrite_Range_Array_Int_NotSupported(t *testing.T) {
|
||||
helper := buildSchemaHelperWithArraysT(t)
|
||||
_, err := parser.ParseExpr(helper, `ArrayInt > 10`, nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRewrite_Range_Array_Float_NotSupported(t *testing.T) {
|
||||
helper := buildSchemaHelperWithArraysT(t)
|
||||
_, err := parser.ParseExpr(helper, `ArrayFloat > 10.5`, nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRewrite_Range_Array_VarChar_NotSupported(t *testing.T) {
|
||||
helper := buildSchemaHelperWithArraysT(t)
|
||||
_, err := parser.ParseExpr(helper, `ArrayVarchar > "a"`, nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// Array index access optimizations
|
||||
func TestRewrite_Range_ArrayInt_Index_AND_Strengthen(t *testing.T) {
|
||||
helper := buildSchemaHelperWithArraysT(t)
|
||||
expr, err := parser.ParseExpr(helper, `ArrayInt[0] > 10 and ArrayInt[0] > 20`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Range_ArrayFloat_Index_OR_Weaken_Mixed(t *testing.T) {
|
||||
helper := buildSchemaHelperWithArraysT(t)
|
||||
expr, err := parser.ParseExpr(helper, `ArrayFloat[0] > 10 or ArrayFloat[0] > 20.5`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
switch ure.GetValue().GetVal().(type) {
|
||||
case *planpb.GenericValue_Int64Val:
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
case *planpb.GenericValue_FloatVal:
|
||||
require.InDelta(t, 10.0, ure.GetValue().GetFloatVal(), 1e-9)
|
||||
default:
|
||||
t.Fatalf("unexpected value type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewrite_Range_ArrayVarChar_Index_ToBinaryRange(t *testing.T) {
|
||||
helper := buildSchemaHelperWithArraysT(t)
|
||||
expr, err := parser.ParseExpr(helper, `ArrayVarchar[0] > "a" and ArrayVarchar[0] < "m"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
bre := expr.GetBinaryRangeExpr()
|
||||
require.NotNil(t, bre)
|
||||
require.False(t, bre.GetLowerInclusive())
|
||||
require.False(t, bre.GetUpperInclusive())
|
||||
require.Equal(t, "a", bre.GetLowerValue().GetStringVal())
|
||||
require.Equal(t, "m", bre.GetUpperValue().GetStringVal())
|
||||
}
|
||||
|
||||
// helper to flatten AND tree into list of exprs
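// Usage sketch (mirrors TestRewrite_Range_Array_Index_Different_NoMerge below):
//
//	parts := []*planpb.Expr{}
//	collectAndExprs(expr, &parts) // parts now holds the non-AND leaves of the expression tree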
|
||||
func collectAndExprs(e *planpb.Expr, out *[]*planpb.Expr) {
|
||||
if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalAnd {
|
||||
collectAndExprs(be.GetLeft(), out)
|
||||
collectAndExprs(be.GetRight(), out)
|
||||
return
|
||||
}
|
||||
*out = append(*out, e)
|
||||
}
|
||||
|
||||
func TestRewrite_Range_Array_Index_Different_NoMerge(t *testing.T) {
|
||||
helper := buildSchemaHelperWithArraysT(t)
|
||||
expr, err := parser.ParseExpr(helper, `ArrayInt[0] > 10 and ArrayInt[1] > 20 and ArrayInt[0] < 20`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be)
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
parts := []*planpb.Expr{}
|
||||
collectAndExprs(expr, &parts)
|
||||
// expect exactly two parts after rewrite: interval on index 0, and lower bound on index 1
|
||||
require.Equal(t, 2, len(parts))
|
||||
var seenInterval, seenLower bool
|
||||
for _, p := range parts {
|
||||
if bre := p.GetBinaryRangeExpr(); bre != nil {
|
||||
seenInterval = true
|
||||
// 10 < ArrayInt[0] < 20
|
||||
require.False(t, bre.GetLowerInclusive())
|
||||
require.False(t, bre.GetUpperInclusive())
|
||||
require.Equal(t, int64(10), bre.GetLowerValue().GetInt64Val())
|
||||
require.Equal(t, int64(20), bre.GetUpperValue().GetInt64Val())
|
||||
continue
|
||||
}
|
||||
if ure := p.GetUnaryRangeExpr(); ure != nil {
|
||||
seenLower = true
|
||||
require.True(t, ure.GetOp() == planpb.OpType_GreaterThan || ure.GetOp() == planpb.OpType_GreaterEqual)
|
||||
// bound value 20 on index 1 lower side
|
||||
require.Equal(t, int64(20), ure.GetValue().GetInt64Val())
|
||||
continue
|
||||
}
|
||||
// should not reach here: only BinaryRangeExpr and UnaryRangeExpr expected
|
||||
t.Fatalf("unexpected expr kind in AND parts")
|
||||
}
|
||||
require.True(t, seenInterval)
|
||||
require.True(t, seenLower)
|
||||
}
|
||||
|
||||
// Test invalid BinaryRangeExpr: lower > upper → false
|
||||
func TestRewrite_Range_AND_InvalidRange_LowerGreaterThanUpper(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field > 100 AND Int64Field < 50 → false (impossible range)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 100 and Int64Field < 50`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
// Test invalid BinaryRangeExpr: lower == upper with exclusive bounds → false
|
||||
func TestRewrite_Range_AND_InvalidRange_EqualBoundsExclusive(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field > 50 AND Int64Field < 50 → false (exclusive on equal bounds)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 50 and Int64Field < 50`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
// Test invalid BinaryRangeExpr: lower == upper with one exclusive → false
|
||||
func TestRewrite_Range_AND_InvalidRange_EqualBoundsOneExclusive(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field >= 50 AND Int64Field < 50 → false (one exclusive on equal bounds)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field >= 50 and Int64Field < 50`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
// Test valid BinaryRangeExpr: lower == upper with both inclusive → valid
|
||||
func TestRewrite_Range_AND_ValidRange_EqualBoundsBothInclusive(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field >= 50 AND Int64Field <= 50 → (50 <= x <= 50), which is valid (x == 50)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field >= 50 and Int64Field <= 50`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
bre := expr.GetBinaryRangeExpr()
|
||||
require.NotNil(t, bre, "should create valid binary range for x == 50")
|
||||
require.Equal(t, true, bre.GetLowerInclusive())
|
||||
require.Equal(t, true, bre.GetUpperInclusive())
|
||||
require.Equal(t, int64(50), bre.GetLowerValue().GetInt64Val())
|
||||
require.Equal(t, int64(50), bre.GetUpperValue().GetInt64Val())
|
||||
}
|
||||
|
||||
// Test invalid BinaryRangeExpr with float: lower > upper → false
|
||||
func TestRewrite_Range_AND_InvalidRange_Float_LowerGreaterThanUpper(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// FloatField > 99.9 AND FloatField < 10.5 → false
|
||||
expr, err := parser.ParseExpr(helper, `FloatField > 99.9 and FloatField < 10.5`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
// Test invalid BinaryRangeExpr with string: lower > upper → false
|
||||
func TestRewrite_Range_AND_InvalidRange_String_LowerGreaterThanUpper(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// VarCharField > "zebra" AND VarCharField < "apple" → false
|
||||
expr, err := parser.ParseExpr(helper, `VarCharField > "zebra" and VarCharField < "apple"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
// Test AlwaysFalse propagation through nested AND expressions
|
||||
func TestRewrite_AlwaysFalse_Propagation_DeepNesting(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Deep nesting: (Int64Field > 10) AND ((Int64Field > 20) AND (Int64Field > 100 AND Int64Field < 50))
|
||||
// The innermost (Int64Field > 100 AND Int64Field < 50) should become AlwaysFalse
|
||||
// This AlwaysFalse should propagate up through all ANDs, making the entire expression AlwaysFalse
|
||||
expr, err := parser.ParseExpr(helper, `(Int64Field > 10) and ((Int64Field > 20) and (Int64Field > 100 and Int64Field < 50))`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
// Should propagate to become AlwaysFalse at top level
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr), "AlwaysFalse should propagate to top level")
|
||||
}
|
||||
|
||||
// Test AlwaysFalse elimination in OR expressions
|
||||
func TestRewrite_AlwaysFalse_Elimination_InOR(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// (Int64Field > 10) OR ((Int64Field > 20) OR (Int64Field > 100 AND Int64Field < 50))
|
||||
// The innermost becomes AlwaysFalse, should be eliminated from OR
|
||||
// Result should be: Int64Field > 10 OR Int64Field > 20 → Int64Field > 10 (weaker bound)
|
||||
expr, err := parser.ParseExpr(helper, `(Int64Field > 10) or ((Int64Field > 20) or (Int64Field > 100 and Int64Field < 50))`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
// Should simplify to single range condition: Int64Field > 10
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure, "AlwaysFalse should be eliminated, leaving simplified range")
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
// Test negation of an impossible range: NOT (AlwaysFalse) → AlwaysTrue
|
||||
func TestRewrite_DoubleNegation_ToAlwaysTrue(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `not (Int64Field > 100 and Int64Field < 50)`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysTrueExpr(expr))
|
||||
}
|
||||
|
||||
// Test NOT over an OR of two impossible ranges: both branches collapse to AlwaysFalse, so the negation yields AlwaysTrue
|
||||
func TestRewrite_ComplexDoubleNegation_MultiLayer(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `not ((Int64Field > 100 and Int64Field < 50) or (FloatField > 99.9 and FloatField < 10.5))`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysTrueExpr(expr))
|
||||
}
|
||||
|
||||
// Test AlwaysTrue in AND with normal conditions gets eliminated
|
||||
func TestRewrite_AlwaysTrue_Elimination_InAND(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// (Int64Field > 10) AND NOT(Int64Field > 100 AND Int64Field < 50)
|
||||
// The second part becomes AlwaysTrue, should be eliminated from AND
|
||||
// Result should be just: Int64Field > 10
|
||||
expr, err := parser.ParseExpr(helper, `(Int64Field > 10) and not (Int64Field > 100 and Int64Field < 50)`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_GreaterThan, ure.GetOp())
|
||||
require.Equal(t, int64(10), ure.GetValue().GetInt64Val())
|
||||
}
internal/parser/planparserv2/rewriter/term_in.go (new file, 666 lines)
@@ -0,0 +1,666 @@
package rewriter
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
)
|
||||
|
||||
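// combineOrEqualsToIn rewrites (a = v1) OR (a = v2) OR ... over the same column into
// a IN [v1, v2, ...] when shouldMergeToIn accepts the column type and value count.
//
// Illustrative sketch (mirrors TestRewrite_OREquals_ToIN_NonNumeric in term_in_test.go):
//
//	expr, _ := parser.ParseExpr(helper, `VarCharField == "a" or VarCharField == "b"`, nil)
//	// after rewriting: expr.GetTermExpr() holds the sorted values ["a", "b"]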
func (v *visitor) combineOrEqualsToIn(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type group struct {
|
||||
col *planpb.ColumnInfo
|
||||
values []*planpb.GenericValue
|
||||
origIndices []int
|
||||
valCase string
|
||||
}
|
||||
others := make([]*planpb.Expr, 0, len(parts))
|
||||
groups := make(map[string]*group)
|
||||
indexToExpr := parts
|
||||
for idx, e := range parts {
|
||||
u := e.GetUnaryRangeExpr()
|
||||
if u == nil || u.GetOp() != planpb.OpType_Equal || u.GetValue() == nil {
|
||||
others = append(others, e)
|
||||
continue
|
||||
}
|
||||
col := u.GetColumnInfo()
|
||||
if col == nil {
|
||||
others = append(others, e)
|
||||
continue
|
||||
}
|
||||
key := columnKey(col)
|
||||
g, ok := groups[key]
|
||||
valCase := valueCase(u.GetValue())
|
||||
if !ok {
|
||||
g = &group{col: col, values: []*planpb.GenericValue{}, origIndices: []int{}, valCase: valCase}
|
||||
groups[key] = g
|
||||
}
|
||||
if g.valCase != valCase {
|
||||
others = append(others, e)
|
||||
continue
|
||||
}
|
||||
g.values = append(g.values, u.GetValue())
|
||||
g.origIndices = append(g.origIndices, idx)
|
||||
}
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
out = append(out, others...)
|
||||
for _, g := range groups {
|
||||
if shouldMergeToIn(g.col.GetDataType(), len(g.values)) {
|
||||
g.values = sortGenericValues(g.values)
|
||||
out = append(out, newTermExpr(g.col, g.values))
|
||||
} else {
|
||||
for _, i := range g.origIndices {
|
||||
out = append(out, indexToExpr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
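// combineAndNotEqualsToNotIn rewrites (a != v1) AND (a != v2) AND ... over the same column
// into NOT (a IN [v1, v2, ...]) under the same shouldMergeToIn threshold as the OR-equals case.
//
// Illustrative sketch (assumed behavior, based on the grouping below):
//
//	VarCharField != "a" and VarCharField != "b"   ->   not (VarCharField in ["a", "b"])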
func (v *visitor) combineAndNotEqualsToNotIn(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type group struct {
|
||||
col *planpb.ColumnInfo
|
||||
values []*planpb.GenericValue
|
||||
origIndices []int
|
||||
valCase string
|
||||
}
|
||||
others := make([]*planpb.Expr, 0, len(parts))
|
||||
groups := make(map[string]*group)
|
||||
indexToExpr := parts
|
||||
for idx, e := range parts {
|
||||
u := e.GetUnaryRangeExpr()
|
||||
if u == nil || u.GetOp() != planpb.OpType_NotEqual || u.GetValue() == nil {
|
||||
others = append(others, e)
|
||||
continue
|
||||
}
|
||||
col := u.GetColumnInfo()
|
||||
if col == nil {
|
||||
others = append(others, e)
|
||||
continue
|
||||
}
|
||||
key := columnKey(col)
|
||||
g, ok := groups[key]
|
||||
valCase := valueCase(u.GetValue())
|
||||
if !ok {
|
||||
g = &group{col: col, values: []*planpb.GenericValue{}, origIndices: []int{}, valCase: valCase}
|
||||
groups[key] = g
|
||||
}
|
||||
if g.valCase != valCase {
|
||||
others = append(others, e)
|
||||
continue
|
||||
}
|
||||
g.values = append(g.values, u.GetValue())
|
||||
g.origIndices = append(g.origIndices, idx)
|
||||
}
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
out = append(out, others...)
|
||||
for _, g := range groups {
|
||||
if shouldMergeToIn(g.col.GetDataType(), len(g.values)) {
|
||||
g.values = sortGenericValues(g.values)
|
||||
in := newTermExpr(g.col, g.values)
|
||||
out = append(out, notExpr(in))
|
||||
} else {
|
||||
for _, i := range g.origIndices {
|
||||
out = append(out, indexToExpr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func notExpr(child *planpb.Expr) *planpb.Expr {
|
||||
return &planpb.Expr{
|
||||
Expr: &planpb.Expr_UnaryExpr{
|
||||
UnaryExpr: &planpb.UnaryExpr{
|
||||
Op: planpb.UnaryExpr_Not,
|
||||
Child: child,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// AND: (a IN S) AND (a = v) with v in S -> a = v
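// If v is not in S, or if two different equality values appear for the same column, the
// conjunction is a contradiction and is replaced with a constant-false expression.
//
// Illustrative sketch (assumed behavior, derived from the branches below):
//
//	Int64Field in [1, 2, 3] and Int64Field == 2   ->   Int64Field == 2
//	Int64Field in [1, 2, 3] and Int64Field == 9   ->   always false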
|
||||
func (v *visitor) combineAndInWithEqual(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type agg struct {
|
||||
termIdxs []int
|
||||
eqIdxs []int
|
||||
term *planpb.TermExpr
|
||||
eqValues []*planpb.GenericValue
|
||||
col *planpb.ColumnInfo
|
||||
}
|
||||
groups := map[string]*agg{}
|
||||
others := []int{}
|
||||
for idx, e := range parts {
|
||||
if te := e.GetTermExpr(); te != nil {
|
||||
k := columnKey(te.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &agg{col: te.GetColumnInfo()}
|
||||
}
|
||||
g.termIdxs = append(g.termIdxs, idx)
|
||||
g.term = te
|
||||
groups[k] = g
|
||||
continue
|
||||
}
|
||||
if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_Equal && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
|
||||
k := columnKey(ue.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &agg{col: ue.GetColumnInfo()}
|
||||
}
|
||||
g.eqIdxs = append(g.eqIdxs, idx)
|
||||
g.eqValues = append(g.eqValues, ue.GetValue())
|
||||
groups[k] = g
|
||||
continue
|
||||
}
|
||||
others = append(others, idx)
|
||||
}
|
||||
used := make([]bool, len(parts))
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
for _, idx := range others {
|
||||
out = append(out, parts[idx])
|
||||
used[idx] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if g.term == nil || len(g.eqIdxs) == 0 {
|
||||
continue
|
||||
}
|
||||
// Build set of eq values and check presence in term set.
|
||||
termVals := g.term.GetValues()
|
||||
eqUnique := []*planpb.GenericValue{}
|
||||
for _, ev := range g.eqValues {
|
||||
dup := false
|
||||
for _, u := range eqUnique {
|
||||
if equalsGeneric(u, ev) {
|
||||
dup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !dup {
|
||||
eqUnique = append(eqUnique, ev)
|
||||
}
|
||||
}
|
||||
// If multiple different equals present, AND implies contradiction unless identical.
|
||||
if len(eqUnique) > 1 {
|
||||
for _, ti := range g.termIdxs {
|
||||
used[ti] = true
|
||||
}
|
||||
for _, ei := range g.eqIdxs {
|
||||
used[ei] = true
|
||||
}
|
||||
// emit constant false
|
||||
out = append(out, newAlwaysFalseExpr())
|
||||
continue
|
||||
}
|
||||
// Single equal value
|
||||
ev := eqUnique[0]
|
||||
inSet := false
|
||||
for _, tv := range termVals {
|
||||
if equalsGeneric(tv, ev) {
|
||||
inSet = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, ti := range g.termIdxs {
|
||||
used[ti] = true
|
||||
}
|
||||
for _, ei := range g.eqIdxs {
|
||||
used[ei] = true
|
||||
}
|
||||
if inSet {
|
||||
// reduce to equality
|
||||
out = append(out, newUnaryRangeExpr(g.col, planpb.OpType_Equal, ev))
|
||||
} else {
|
||||
// contradiction -> false
|
||||
out = append(out, newAlwaysFalseExpr())
|
||||
}
|
||||
}
|
||||
for i := range parts {
|
||||
if !used[i] {
|
||||
out = append(out, parts[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// OR: (a IN S) OR (a = v) -> a IN (S ∪ {v}); the equal values are unioned into the term set.
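//
// Illustrative sketch (assumed behavior, derived from the union step below):
//
//	Int64Field in [1, 2] or Int64Field == 5   ->   Int64Field in [1, 2, 5]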
|
||||
func (v *visitor) combineOrInWithEqual(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type agg struct {
|
||||
termIdx int
|
||||
term *planpb.TermExpr
|
||||
col *planpb.ColumnInfo
|
||||
eqIdxs []int
|
||||
eqVals []*planpb.GenericValue
|
||||
}
|
||||
groups := map[string]*agg{}
|
||||
others := []int{}
|
||||
for idx, e := range parts {
|
||||
if te := e.GetTermExpr(); te != nil {
|
||||
k := columnKey(te.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &agg{col: te.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.termIdx = idx
|
||||
g.term = te
|
||||
continue
|
||||
}
|
||||
if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_Equal && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
|
||||
k := columnKey(ue.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &agg{col: ue.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.eqIdxs = append(g.eqIdxs, idx)
|
||||
g.eqVals = append(g.eqVals, ue.GetValue())
|
||||
continue
|
||||
}
|
||||
others = append(others, idx)
|
||||
}
|
||||
used := make([]bool, len(parts))
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
for _, idx := range others {
|
||||
out = append(out, parts[idx])
|
||||
used[idx] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if g.term == nil || len(g.eqIdxs) == 0 {
|
||||
continue
|
||||
}
|
||||
// union all equal values into term set
|
||||
union := g.term.GetValues()
|
||||
for i, ev := range g.eqVals {
|
||||
union = append(union, ev)
|
||||
used[g.eqIdxs[i]] = true
|
||||
}
|
||||
union = sortGenericValues(union)
|
||||
used[g.termIdx] = true
|
||||
out = append(out, newTermExpr(g.col, union))
|
||||
}
|
||||
for i := range parts {
|
||||
if !used[i] {
|
||||
out = append(out, parts[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AND: (a IN S) AND (range) -> filter S by range
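// The tightest lower/upper bounds among the unary ranges on the same column are used to
// filter the IN list; an empty result becomes a constant-false expression.
//
// Illustrative sketch (assumed behavior, derived from filterValuesByRange below):
//
//	Int64Field in [1, 5, 10] and Int64Field < 6     ->   Int64Field in [1, 5]
//	Int64Field in [1, 5, 10] and Int64Field > 100   ->   always false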
|
||||
func (v *visitor) combineAndInWithRange(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type group struct {
|
||||
col *planpb.ColumnInfo
|
||||
termIdx int
|
||||
term *planpb.TermExpr
|
||||
lower *planpb.GenericValue
|
||||
lowerInc bool
|
||||
upper *planpb.GenericValue
|
||||
upperInc bool
|
||||
rangeIdxs []int
|
||||
}
|
||||
groups := map[string]*group{}
|
||||
others := []int{}
|
||||
isRange := func(op planpb.OpType) bool {
|
||||
return op == planpb.OpType_GreaterThan || op == planpb.OpType_GreaterEqual || op == planpb.OpType_LessThan || op == planpb.OpType_LessEqual
|
||||
}
|
||||
for idx, e := range parts {
|
||||
if te := e.GetTermExpr(); te != nil {
|
||||
k := columnKey(te.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &group{col: te.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.term = te
|
||||
g.termIdx = idx
|
||||
continue
|
||||
}
|
||||
if ue := e.GetUnaryRangeExpr(); ue != nil && isRange(ue.GetOp()) && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
|
||||
k := columnKey(ue.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &group{col: ue.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
if ue.GetOp() == planpb.OpType_GreaterThan || ue.GetOp() == planpb.OpType_GreaterEqual {
|
||||
if g.lower == nil || cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.lower) > 0 || (cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.lower) == 0 && ue.GetOp() == planpb.OpType_GreaterThan && g.lowerInc) {
|
||||
g.lower = ue.GetValue()
|
||||
g.lowerInc = ue.GetOp() == planpb.OpType_GreaterEqual
|
||||
}
|
||||
} else {
|
||||
if g.upper == nil || cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.upper) < 0 || (cmpGeneric(effectiveDataType(g.col), ue.GetValue(), g.upper) == 0 && ue.GetOp() == planpb.OpType_LessThan && g.upperInc) {
|
||||
g.upper = ue.GetValue()
|
||||
g.upperInc = ue.GetOp() == planpb.OpType_LessEqual
|
||||
}
|
||||
}
|
||||
g.rangeIdxs = append(g.rangeIdxs, idx)
|
||||
continue
|
||||
}
|
||||
others = append(others, idx)
|
||||
}
|
||||
used := make([]bool, len(parts))
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
for _, idx := range others {
|
||||
out = append(out, parts[idx])
|
||||
used[idx] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if g.term == nil || (g.lower == nil && g.upper == nil) {
|
||||
continue
|
||||
}
|
||||
// Skip optimization if any term value is not comparable with the provided bounds
|
||||
termVals := g.term.GetValues()
|
||||
comparable := true
|
||||
for _, tv := range termVals {
|
||||
if g.lower != nil {
|
||||
if !(areComparableCases(valueCaseWithNil(tv), valueCaseWithNil(g.lower)) || (isNumericCase(valueCaseWithNil(tv)) && isNumericCase(valueCaseWithNil(g.lower)))) {
|
||||
comparable = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if comparable && g.upper != nil {
|
||||
if !(areComparableCases(valueCaseWithNil(tv), valueCaseWithNil(g.upper)) || (isNumericCase(valueCaseWithNil(tv)) && isNumericCase(valueCaseWithNil(g.upper)))) {
|
||||
comparable = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !comparable {
|
||||
continue
|
||||
}
|
||||
filtered := filterValuesByRange(effectiveDataType(g.col), termVals, g.lower, g.lowerInc, g.upper, g.upperInc)
|
||||
used[g.termIdx] = true
|
||||
for _, ri := range g.rangeIdxs {
|
||||
used[ri] = true
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
// Empty IN list after filtering → AlwaysFalse
|
||||
out = append(out, newAlwaysFalseExpr())
|
||||
} else {
|
||||
out = append(out, newTermExpr(g.col, filtered))
|
||||
}
|
||||
}
|
||||
for i := range parts {
|
||||
if !used[i] {
|
||||
out = append(out, parts[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// OR: (a IN S1) OR (a IN S2) -> a IN union(S1, S2)
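// Illustrative sketch (assumed behavior; duplicates in the union are expected to be
// removed by sortGenericValues):
//
//	Int64Field in [1, 2] or Int64Field in [2, 3]   ->   Int64Field in [1, 2, 3]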
|
||||
func (v *visitor) combineOrInWithIn(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type agg struct {
|
||||
col *planpb.ColumnInfo
|
||||
idxs []int
|
||||
values [][]*planpb.GenericValue
|
||||
}
|
||||
groups := map[string]*agg{}
|
||||
others := []int{}
|
||||
for idx, e := range parts {
|
||||
if te := e.GetTermExpr(); te != nil {
|
||||
k := columnKey(te.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &agg{col: te.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.idxs = append(g.idxs, idx)
|
||||
g.values = append(g.values, te.GetValues())
|
||||
continue
|
||||
}
|
||||
others = append(others, idx)
|
||||
}
|
||||
used := make([]bool, len(parts))
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
for _, idx := range others {
|
||||
out = append(out, parts[idx])
|
||||
used[idx] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if len(g.idxs) <= 1 {
|
||||
continue
|
||||
}
|
||||
union := []*planpb.GenericValue{}
|
||||
for _, vs := range g.values {
|
||||
union = append(union, vs...)
|
||||
}
|
||||
union = sortGenericValues(union)
|
||||
for _, i := range g.idxs {
|
||||
used[i] = true
|
||||
}
|
||||
out = append(out, newTermExpr(g.col, union))
|
||||
}
|
||||
for i := range parts {
|
||||
if !used[i] {
|
||||
out = append(out, parts[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AND: (a IN S1) AND (a IN S2) ... -> a IN intersection(S1, S2, ...)
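// Illustrative sketch (assumed behavior, derived from the intersection loop below):
//
//	Int64Field in [1, 2, 3] and Int64Field in [2, 3, 4]   ->   Int64Field in [2, 3]
//	Int64Field in [1, 2] and Int64Field in [3, 4]         ->   always false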
|
||||
func (v *visitor) combineAndInWithIn(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type agg struct {
|
||||
col *planpb.ColumnInfo
|
||||
idxs []int
|
||||
values [][]*planpb.GenericValue
|
||||
}
|
||||
groups := map[string]*agg{}
|
||||
others := []int{}
|
||||
for idx, e := range parts {
|
||||
if te := e.GetTermExpr(); te != nil {
|
||||
k := columnKey(te.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &agg{col: te.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.idxs = append(g.idxs, idx)
|
||||
g.values = append(g.values, te.GetValues())
|
||||
continue
|
||||
}
|
||||
others = append(others, idx)
|
||||
}
|
||||
used := make([]bool, len(parts))
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
for _, idx := range others {
|
||||
out = append(out, parts[idx])
|
||||
used[idx] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if len(g.idxs) <= 1 {
|
||||
continue
|
||||
}
|
||||
// compute intersection; start from first set
|
||||
inter := make([]*planpb.GenericValue, 0, len(g.values[0]))
|
||||
outer:
|
||||
for _, v := range g.values[0] {
|
||||
// check in every other set
|
||||
ok := true
|
||||
for i := 1; i < len(g.values); i++ {
|
||||
found := false
|
||||
for _, w := range g.values[i] {
|
||||
if equalsGeneric(v, w) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
if ok {
|
||||
inter = append(inter, v)
|
||||
}
|
||||
}
|
||||
for _, i := range g.idxs {
|
||||
used[i] = true
|
||||
}
|
||||
if len(inter) == 0 {
|
||||
out = append(out, newAlwaysFalseExpr())
|
||||
} else {
|
||||
out = append(out, newTermExpr(g.col, inter))
|
||||
}
|
||||
}
|
||||
for i := range parts {
|
||||
if !used[i] {
|
||||
out = append(out, parts[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AND: (a IN S) AND (a != d) -> remove d from S; empty -> false
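// Illustrative sketch (assumed behavior):
//
//	Int64Field in [1, 2, 3] and Int64Field != 2   ->   Int64Field in [1, 3]
//	Int64Field in [2] and Int64Field != 2         ->   always false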
|
||||
func (v *visitor) combineAndInWithNotEqual(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type group struct {
|
||||
col *planpb.ColumnInfo
|
||||
termIdx int
|
||||
term *planpb.TermExpr
|
||||
neqIdxs []int
|
||||
neqVals []*planpb.GenericValue
|
||||
}
|
||||
groups := map[string]*group{}
|
||||
others := []int{}
|
||||
for idx, e := range parts {
|
||||
if te := e.GetTermExpr(); te != nil {
|
||||
k := columnKey(te.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &group{col: te.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.term = te
|
||||
g.termIdx = idx
|
||||
continue
|
||||
}
|
||||
if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_NotEqual && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
|
||||
k := columnKey(ue.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &group{col: ue.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.neqIdxs = append(g.neqIdxs, idx)
|
||||
g.neqVals = append(g.neqVals, ue.GetValue())
|
||||
continue
|
||||
}
|
||||
others = append(others, idx)
|
||||
}
|
||||
used := make([]bool, len(parts))
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
for _, i := range others {
|
||||
out = append(out, parts[i])
|
||||
used[i] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if g.term == nil || len(g.neqIdxs) == 0 {
|
||||
continue
|
||||
}
|
||||
filtered := []*planpb.GenericValue{}
|
||||
for _, tv := range g.term.GetValues() {
|
||||
excluded := false
|
||||
for _, dv := range g.neqVals {
|
||||
if equalsGeneric(tv, dv) {
|
||||
excluded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !excluded {
|
||||
filtered = append(filtered, tv)
|
||||
}
|
||||
}
|
||||
used[g.termIdx] = true
|
||||
for _, ni := range g.neqIdxs {
|
||||
used[ni] = true
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
out = append(out, newAlwaysFalseExpr())
|
||||
} else {
|
||||
out = append(out, newTermExpr(g.col, filtered))
|
||||
}
|
||||
}
|
||||
for i := range parts {
|
||||
if !used[i] {
|
||||
out = append(out, parts[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// OR: (a IN S) OR (a != d) -> if d ∈ S then true else (a != d)
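// Illustrative sketch (assumed behavior, derived from the containsAny branch below):
//
//	Int64Field in [1, 2] or Int64Field != 2   ->   always true
//	Int64Field in [1, 2] or Int64Field != 9   ->   Int64Field != 9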
|
||||
func (v *visitor) combineOrInWithNotEqual(parts []*planpb.Expr) []*planpb.Expr {
|
||||
type group struct {
|
||||
col *planpb.ColumnInfo
|
||||
termIdx int
|
||||
term *planpb.TermExpr
|
||||
neqIdxs []int
|
||||
neqVals []*planpb.GenericValue
|
||||
}
|
||||
groups := map[string]*group{}
|
||||
others := []int{}
|
||||
for idx, e := range parts {
|
||||
if te := e.GetTermExpr(); te != nil {
|
||||
k := columnKey(te.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &group{col: te.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.term = te
|
||||
g.termIdx = idx
|
||||
continue
|
||||
}
|
||||
if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_NotEqual && ue.GetValue() != nil && ue.GetColumnInfo() != nil {
|
||||
k := columnKey(ue.GetColumnInfo())
|
||||
g := groups[k]
|
||||
if g == nil {
|
||||
g = &group{col: ue.GetColumnInfo()}
|
||||
groups[k] = g
|
||||
}
|
||||
g.neqIdxs = append(g.neqIdxs, idx)
|
||||
g.neqVals = append(g.neqVals, ue.GetValue())
|
||||
continue
|
||||
}
|
||||
others = append(others, idx)
|
||||
}
|
||||
used := make([]bool, len(parts))
|
||||
out := make([]*planpb.Expr, 0, len(parts))
|
||||
for _, i := range others {
|
||||
out = append(out, parts[i])
|
||||
used[i] = true
|
||||
}
|
||||
for _, g := range groups {
|
||||
if g.term == nil || len(g.neqIdxs) == 0 {
|
||||
continue
|
||||
}
|
||||
// if any neq value is inside IN set -> true
|
||||
containsAny := false
|
||||
for _, dv := range g.neqVals {
|
||||
for _, tv := range g.term.GetValues() {
|
||||
if equalsGeneric(tv, dv) {
|
||||
containsAny = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if containsAny {
|
||||
break
|
||||
}
|
||||
}
|
||||
if containsAny {
|
||||
used[g.termIdx] = true
|
||||
for _, ni := range g.neqIdxs {
|
||||
used[ni] = true
|
||||
}
|
||||
out = append(out, newBoolConstExpr(true))
|
||||
} else {
|
||||
// drop the IN; keep != as-is
|
||||
used[g.termIdx] = true
|
||||
}
|
||||
}
|
||||
for i := range parts {
|
||||
if !used[i] {
|
||||
out = append(out, parts[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
internal/parser/planparserv2/rewriter/term_in_test.go (new file, 356 lines)
@@ -0,0 +1,356 @@
package rewriter_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2/rewriter"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func buildSchemaHelperForRewriteT(t *testing.T) *typeutil.SchemaHelper {
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{FieldID: 101, Name: "Int64Field", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 102, Name: "VarCharField", DataType: schemapb.DataType_VarChar},
|
||||
{FieldID: 103, Name: "StringField", DataType: schemapb.DataType_String},
|
||||
{FieldID: 104, Name: "FloatField", DataType: schemapb.DataType_Double},
|
||||
{FieldID: 105, Name: "BoolField", DataType: schemapb.DataType_Bool},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "rewrite_test",
|
||||
AutoID: false,
|
||||
Fields: fields,
|
||||
}
|
||||
// enable text_match on string-like fields
|
||||
for _, f := range schema.Fields {
|
||||
if typeutil.IsStringType(f.DataType) {
|
||||
f.TypeParams = append(f.TypeParams, &commonpb.KeyValuePair{
|
||||
Key: "enable_match",
|
||||
Value: "True",
|
||||
})
|
||||
}
|
||||
}
|
||||
helper, err := typeutil.CreateSchemaHelper(schema)
|
||||
require.NoError(t, err)
|
||||
return helper
|
||||
}
|
||||
|
||||
func TestRewrite_OREquals_ToIN_NonNumeric(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `VarCharField == "a" or VarCharField == "b"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term, "expected OR-equals to be rewritten to TermExpr(IN ...)")
|
||||
require.Equal(t, 2, len(term.GetValues()))
|
||||
require.Equal(t, "a", term.GetValues()[0].GetStringVal())
|
||||
require.Equal(t, "b", term.GetValues()[1].GetStringVal())
|
||||
}
|
||||
|
||||
func TestRewrite_OREquals_NotMerged_OnNumericBelowThreshold(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field == 1 or Int64Field == 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.Nil(t, expr.GetTermExpr(), "numeric OR-equals should not merge to IN under threshold")
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be)
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
|
||||
require.NotNil(t, be.GetLeft().GetUnaryRangeExpr())
|
||||
require.NotNil(t, be.GetRight().GetUnaryRangeExpr())
|
||||
require.Equal(t, planpb.OpType_Equal, be.GetLeft().GetUnaryRangeExpr().GetOp())
|
||||
require.Equal(t, planpb.OpType_Equal, be.GetRight().GetUnaryRangeExpr().GetOp())
|
||||
}
|
||||
|
||||
func TestRewrite_Term_SortAndDedup_String(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `VarCharField in ["b","a","b","a"]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, 2, len(term.GetValues()))
|
||||
require.Equal(t, "a", term.GetValues()[0].GetStringVal())
|
||||
require.Equal(t, "b", term.GetValues()[1].GetStringVal())
|
||||
}
|
||||
|
||||
func TestRewrite_Term_SortAndDedup_Int(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [9,4,6,6,7]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, 4, len(term.GetValues()))
|
||||
got := []int64{
|
||||
term.GetValues()[0].GetInt64Val(),
|
||||
term.GetValues()[1].GetInt64Val(),
|
||||
term.GetValues()[2].GetInt64Val(),
|
||||
term.GetValues()[3].GetInt64Val(),
|
||||
}
|
||||
require.ElementsMatch(t, []int64{4, 6, 7, 9}, got)
|
||||
}
|
||||
|
||||
func TestRewrite_NotIn_SortAndDedup_Int(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field not in [4,4,3]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
un := expr.GetUnaryExpr()
|
||||
require.NotNil(t, un)
|
||||
term := un.GetChild().GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, 2, len(term.GetValues()))
|
||||
require.Equal(t, int64(3), term.GetValues()[0].GetInt64Val())
|
||||
require.Equal(t, int64(4), term.GetValues()[1].GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_NotIn_SortAndDedup_Float(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `FloatField not in [4.0,4,3.0]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
un := expr.GetUnaryExpr()
|
||||
require.NotNil(t, un)
|
||||
term := un.GetChild().GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, 2, len(term.GetValues()))
|
||||
require.Equal(t, 3.0, term.GetValues()[0].GetFloatVal())
|
||||
require.Equal(t, 4.0, term.GetValues()[1].GetFloatVal())
|
||||
}
|
||||
|
||||
func TestRewrite_In_SortAndDedup_Bool(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `BoolField in [true,false,false,true]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, 2, len(term.GetValues()))
|
||||
require.Equal(t, false, term.GetValues()[0].GetBoolVal())
|
||||
require.Equal(t, true, term.GetValues()[1].GetBoolVal())
|
||||
}
|
||||
|
||||
func TestRewrite_Flatten_Then_OR_ToIN(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `VarCharField == "a" or (VarCharField == "b" or VarCharField == "c") or VarCharField == "d"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term, "nested OR-equals should flatten and merge to IN")
|
||||
require.Equal(t, 4, len(term.GetValues()))
|
||||
got := []string{
|
||||
term.GetValues()[0].GetStringVal(),
|
||||
term.GetValues()[1].GetStringVal(),
|
||||
term.GetValues()[2].GetStringVal(),
|
||||
term.GetValues()[3].GetStringVal(),
|
||||
}
|
||||
require.ElementsMatch(t, []string{"a", "b", "c", "d"}, got)
|
||||
}
|
||||
|
||||
func TestRewrite_And_In_And_Equal_VInSet_ReducesToEqual(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field == 3`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_Equal, ure.GetOp())
|
||||
require.Equal(t, int64(3), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_And_In_And_Equal_VNotInSet_False(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field == 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
func TestRewrite_Or_In_Or_Equal_Union(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,3] or Int64Field == 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, []int64{1, 2, 3}, []int64{
|
||||
term.GetValues()[0].GetInt64Val(),
|
||||
term.GetValues()[1].GetInt64Val(),
|
||||
term.GetValues()[2].GetInt64Val(),
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewrite_And_In_With_Range_Filter(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,3,5] and Int64Field > 3`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, 1, len(term.GetValues()))
|
||||
require.Equal(t, int64(5), term.GetValues()[0].GetInt64Val())
|
||||
}
|
||||
|
||||
func TestRewrite_Or_In_Union(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,3] or Int64Field in [3,4]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, []int64{1, 3, 4}, []int64{
|
||||
term.GetValues()[0].GetInt64Val(),
|
||||
term.GetValues()[1].GetInt64Val(),
|
||||
term.GetValues()[2].GetInt64Val(),
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewrite_And_In_Intersection(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,2,3] and Int64Field in [2,3,4]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, []int64{2, 3}, []int64{
|
||||
term.GetValues()[0].GetInt64Val(),
|
||||
term.GetValues()[1].GetInt64Val(),
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewrite_And_In_Intersection_Empty_ToFalse(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1] and Int64Field in [2]`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
func TestRewrite_And_In_And_NotEqual_Remove(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,2,3] and Int64Field != 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
term := expr.GetTermExpr()
|
||||
require.NotNil(t, term)
|
||||
require.Equal(t, []int64{1, 3}, []int64{
|
||||
term.GetValues()[0].GetInt64Val(),
|
||||
term.GetValues()[1].GetInt64Val(),
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewrite_And_In_And_NotEqual_AllRemoved_ToFalse(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [2] and Int64Field != 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
require.True(t, rewriter.IsAlwaysFalseExpr(expr))
|
||||
}
|
||||
|
||||
func TestRewrite_Or_In_Or_NotEqual_VInSet_ToTrue(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1,2] or Int64Field != 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
val := expr.GetValueExpr()
|
||||
require.NotNil(t, val)
|
||||
require.Equal(t, true, val.GetValue().GetBoolVal())
|
||||
}
|
||||
|
||||
func TestRewrite_Or_In_Or_NotEqual_VNotInSet_ToNotEqual(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field in [1] or Int64Field != 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
ure := expr.GetUnaryRangeExpr()
|
||||
require.NotNil(t, ure)
|
||||
require.Equal(t, planpb.OpType_NotEqual, ure.GetOp())
|
||||
require.Equal(t, int64(2), ure.GetValue().GetInt64Val())
|
||||
}
|
||||
|
||||
// Test contradictory equals: (a == 1) AND (a == 2) → false
|
||||
// NOTE: This is a known limitation - currently NOT optimized
|
||||
func TestRewrite_And_Equal_And_Equal_Contradiction_CurrentLimitation(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field == 1 AND Int64Field == 2 → false (contradiction)
|
||||
// Currently NOT optimized because equals don't convert to IN in AND context
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field == 1 and Int64Field == 2`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
|
||||
// Current limitation: NOT optimized to false
|
||||
// Remains as BinaryExpr AND with two equal predicates
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be, "should remain as AND (not optimized)")
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
}
|
||||
|
||||
// Test contradictory equals with three values
|
||||
func TestRewrite_And_Equal_ThreeWay_Contradiction_CurrentLimitation(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field == 1 AND Int64Field == 2 AND Int64Field == 3 → false
|
||||
// Currently NOT optimized
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field == 1 and Int64Field == 2 and Int64Field == 3`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
|
||||
// Current limitation: NOT optimized to false
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be, "should remain as AND chain (not optimized)")
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
}
|
||||
|
||||
// Test range + contradictory equal: (a > 10) AND (a == 5) → false
|
||||
// NOTE: Current limitation - NOT optimized
|
||||
func TestRewrite_And_Range_And_Equal_Contradiction_CurrentLimitation(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field > 10 AND Int64Field == 5 → false (5 is not > 10)
|
||||
// This requires combining range and equality checks, which is partially implemented
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field == 5`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
|
||||
// Current behavior: may not optimize to false
|
||||
// If it does optimize via IN+range filtering, it would become false
|
||||
// This test documents current limitation
|
||||
// When fully optimized, should be constant false
|
||||
_ = expr // Test documents that this case exists
|
||||
}
|
||||
|
||||
// Test non-contradictory range + equal: (a > 10) AND (a == 15) → stays as is
|
||||
// NOTE: This requires Equal to be in IN form first, which happens via combineAndInWithEqual
|
||||
// But Equal alone doesn't convert to IN in AND context, so this optimization doesn't happen
|
||||
func TestRewrite_And_Range_And_Equal_NonContradiction_CurrentLimitation(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// Int64Field > 10 AND Int64Field == 15 → should simplify to Int64Field == 15
|
||||
// But currently NOT optimized without IN involved
|
||||
expr, err := parser.ParseExpr(helper, `Int64Field > 10 and Int64Field == 15`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
|
||||
// Current limitation: remains as AND with range and equal
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be, "should remain as AND (not optimized)")
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
}
|
||||
|
||||
// Test string contradictory equals - current limitation
|
||||
func TestRewrite_And_Equal_String_Contradiction_CurrentLimitation(t *testing.T) {
|
||||
helper := buildSchemaHelperForRewriteT(t)
|
||||
// VarCharField == "apple" AND VarCharField == "banana" → false
|
||||
// Currently NOT optimized
|
||||
expr, err := parser.ParseExpr(helper, `VarCharField == "apple" and VarCharField == "banana"`, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expr)
|
||||
|
||||
// Current limitation: NOT optimized to false
|
||||
be := expr.GetBinaryExpr()
|
||||
require.NotNil(t, be, "should remain as AND (not optimized)")
|
||||
require.Equal(t, planpb.BinaryExpr_LogicalAnd, be.GetOp())
|
||||
}
|
||||
78  internal/parser/planparserv2/rewriter/text_match.go  Normal file
@@ -0,0 +1,78 @@
package rewriter

import (
    "strings"

    "github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)

func (v *visitor) combineOrTextMatchToMerged(parts []*planpb.Expr) []*planpb.Expr {
    type group struct {
        col         *planpb.ColumnInfo
        origIndices []int
        literals    []string
    }
    others := make([]*planpb.Expr, 0, len(parts))
    groups := make(map[string]*group)
    indexToExpr := parts
    for idx, e := range parts {
        u := e.GetUnaryRangeExpr()
        if u == nil || u.GetOp() != planpb.OpType_TextMatch || u.GetValue() == nil {
            others = append(others, e)
            continue
        }
        if len(u.GetExtraValues()) > 0 {
            others = append(others, e)
            continue
        }
        col := u.GetColumnInfo()
        if col == nil {
            others = append(others, e)
            continue
        }
        key := columnKey(col)
        g, ok := groups[key]
        if !ok {
            g = &group{col: col}
            groups[key] = g
        }
        literal := u.GetValue().GetStringVal()
        g.literals = append(g.literals, literal)
        g.origIndices = append(g.origIndices, idx)
    }
    out := make([]*planpb.Expr, 0, len(parts))
    out = append(out, others...)
    for _, g := range groups {
        if len(g.origIndices) <= 1 {
            for _, i := range g.origIndices {
                out = append(out, indexToExpr[i])
            }
            continue
        }
        if len(g.literals) == 0 {
            for _, i := range g.origIndices {
                out = append(out, indexToExpr[i])
            }
            continue
        }
        merged := strings.Join(g.literals, " ")
        out = append(out, newTextMatchExpr(g.col, merged))
    }
    return out
}

func newTextMatchExpr(col *planpb.ColumnInfo, literal string) *planpb.Expr {
    return &planpb.Expr{
        Expr: &planpb.Expr_UnaryRangeExpr{
            UnaryRangeExpr: &planpb.UnaryRangeExpr{
                ColumnInfo: col,
                Op:         planpb.OpType_TextMatch,
                Value: &planpb.GenericValue{
                    Val: &planpb.GenericValue_StringVal{
                        StringVal: literal,
                    },
                },
            },
        },
    }
}
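The space-join above leans on text_match's default any-token semantics, so OR-ing two text_match calls on the same column should be equivalent to one call over the concatenated literals; calls that carry extra values (e.g. minimum_should_match) are skipped because merging would change their meaning. A rough sketch of the expected effect, assuming a *visitor and a *planpb.ColumnInfo are in scope (illustration only, not part of this diff):

// Sketch; see text_match_test.go below for end-to-end coverage via ParseExpr.
parts := []*planpb.Expr{
    newTextMatchExpr(col, "A C"),
    newTextMatchExpr(col, "B D"),
}
out := v.combineOrTextMatchToMerged(parts)
// out holds a single text_match over "A C B D" for this column.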
122  internal/parser/planparserv2/rewriter/text_match_test.go  Normal file
@@ -0,0 +1,122 @@
package rewriter_test

import (
    "testing"

    "github.com/stretchr/testify/require"

    parser "github.com/milvus-io/milvus/internal/parser/planparserv2"
    "github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)

// collectTextMatchLiterals walks the provided expr and collects all OpType_TextMatch literals (grouped by fieldId) into the returned map.
func collectTextMatchLiterals(expr *planpb.Expr) map[int64]string {
    colToLiteral := map[int64]string{}
    var collect func(e *planpb.Expr)
    collect = func(e *planpb.Expr) {
        if e == nil {
            return
        }
        if ue := e.GetUnaryRangeExpr(); ue != nil && ue.GetOp() == planpb.OpType_TextMatch {
            col := ue.GetColumnInfo()
            colToLiteral[col.GetFieldId()] = ue.GetValue().GetStringVal()
            return
        }
        if be := e.GetBinaryExpr(); be != nil && be.GetOp() == planpb.BinaryExpr_LogicalOr {
            collect(be.GetLeft())
            collect(be.GetRight())
            return
        }
    }
    collect(expr)
    return colToLiteral
}

func TestRewrite_TextMatch_OR_Merge(t *testing.T) {
    helper := buildSchemaHelperForRewriteT(t)
    expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A C") or text_match(VarCharField, "B D")`, nil)
    require.NoError(t, err)
    require.NotNil(t, expr)
    ure := expr.GetUnaryRangeExpr()
    require.NotNil(t, ure, "should merge to single text_match")
    require.Equal(t, planpb.OpType_TextMatch, ure.GetOp())
    require.Equal(t, "A C B D", ure.GetValue().GetStringVal())
}

func TestRewrite_TextMatch_OR_DifferentField_NoMerge(t *testing.T) {
    helper := buildSchemaHelperForRewriteT(t)
    expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or text_match(StringField, "B")`, nil)
    require.NoError(t, err)
    require.NotNil(t, expr)
    be := expr.GetBinaryExpr()
    require.NotNil(t, be)
    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
    require.NotNil(t, be.GetLeft().GetUnaryRangeExpr())
    require.NotNil(t, be.GetRight().GetUnaryRangeExpr())
}

func TestRewrite_TextMatch_OR_MultiFields_Merge(t *testing.T) {
    helper := buildSchemaHelperForRewriteT(t)
    expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or text_match(StringField, "A") or text_match(VarCharField, "B")`, nil)
    require.NoError(t, err)
    require.NotNil(t, expr)
    // should be text_match(VarCharField, "A B") or text_match(StringField, "A")
    be := expr.GetBinaryExpr()
    require.NotNil(t, be)
    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())

    // collect all text_match literals grouped by column id
    colToLiteral := collectTextMatchLiterals(expr)
    require.Equal(t, 2, len(colToLiteral))
    // one of them must be "A B", and the other is "A"
    hasAB := false
    hasA := false
    for _, lit := range colToLiteral {
        if lit == "A B" {
            hasAB = true
        }
        if lit == "A" {
            hasA = true
        }
    }
    require.True(t, hasAB)
    require.True(t, hasA)
}

func TestRewrite_TextMatch_OR_MoreMultiFields_Merge(t *testing.T) {
    helper := buildSchemaHelperForRewriteT(t)
    expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A") or (text_match(StringField, "C") or text_match(VarCharField, "B")) or text_match(StringField, "D")`, nil)
    require.NoError(t, err)
    require.NotNil(t, expr)
    // should be text_match(VarCharField, "A B") or text_match(StringField, "C D")
    be := expr.GetBinaryExpr()
    require.NotNil(t, be)
    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())

    // collect all text_match literals grouped by column id
    colToLiteral := collectTextMatchLiterals(expr)
    require.Equal(t, 2, len(colToLiteral))
    // one of them must be "A B", and the other is "C D"
    hasAB := false
    hasCD := false
    for _, lit := range colToLiteral {
        if lit == "A B" {
            hasAB = true
        }
        if lit == "C D" {
            hasCD = true
        }
    }
    require.True(t, hasAB)
    require.True(t, hasCD)
}

func TestRewrite_TextMatch_OR_WithOption_NoMerge(t *testing.T) {
    helper := buildSchemaHelperForRewriteT(t)
    expr, err := parser.ParseExpr(helper, `text_match(VarCharField, "A", minimum_should_match=1) or text_match(VarCharField, "B")`, nil)
    require.NoError(t, err)
    require.NotNil(t, expr)
    be := expr.GetBinaryExpr()
    require.NotNil(t, be)
    require.Equal(t, planpb.BinaryExpr_LogicalOr, be.GetOp())
}
290  internal/parser/planparserv2/rewriter/util.go  Normal file
@@ -0,0 +1,290 @@
package rewriter

import (
    "fmt"
    "sort"
    "strings"

    "github.com/samber/lo"

    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/pkg/v2/proto/planpb"
    "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

func columnKey(c *planpb.ColumnInfo) string {
    var b strings.Builder
    b.WriteString(fmt.Sprintf("%d|%d|%d|%t|%t|%t|",
        c.GetFieldId(),
        int32(c.GetDataType()),
        int32(c.GetElementType()),
        c.GetIsPrimaryKey(),
        c.GetIsAutoID(),
        c.GetIsPartitionKey(),
    ))
    for _, p := range c.GetNestedPath() {
        b.WriteString(p)
        b.WriteByte('|')
    }
    return b.String()
}

// effectiveDataType returns the real scalar type to be used for comparisons.
// For JSON/Array columns with a concrete element_type, use element_type;
// otherwise fall back to the column data_type.
func effectiveDataType(c *planpb.ColumnInfo) schemapb.DataType {
    if c == nil {
        return schemapb.DataType_None
    }
    dt := c.GetDataType()
    if dt == schemapb.DataType_JSON || dt == schemapb.DataType_Array {
        et := c.GetElementType()
        // Treat 0 (None/Invalid) as not specified; otherwise use element type.
        if et != schemapb.DataType_None {
            return et
        }
    }
    return dt
}
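In practice this means a JSON or Array column is compared as its declared element type when one exists, and falls back to the container type otherwise. A small sketch (illustration only, not part of this diff):

// Sketch only.
col := &planpb.ColumnInfo{
    DataType:    schemapb.DataType_Array,
    ElementType: schemapb.DataType_Int64,
}
_ = effectiveDataType(col) // schemapb.DataType_Int64

col = &planpb.ColumnInfo{DataType: schemapb.DataType_JSON}
_ = effectiveDataType(col) // schemapb.DataType_JSON, since no element type is declared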
func valueCase(v *planpb.GenericValue) string {
    switch v.GetVal().(type) {
    case *planpb.GenericValue_BoolVal:
        return "bool"
    case *planpb.GenericValue_Int64Val:
        return "int64"
    case *planpb.GenericValue_FloatVal:
        return "float"
    case *planpb.GenericValue_StringVal:
        return "string"
    case *planpb.GenericValue_ArrayVal:
        return "array"
    default:
        return "other"
    }
}

func valueCaseWithNil(v *planpb.GenericValue) string {
    if v == nil || v.GetVal() == nil {
        return "nil"
    }
    return valueCase(v)
}

func isNumericCase(k string) bool {
    return k == "int64" || k == "float"
}

func areComparableCases(a, b string) bool {
    if a == "nil" || b == "nil" {
        return false
    }
    if isNumericCase(a) && isNumericCase(b) {
        return true
    }
    return a == b && (a == "bool" || a == "string")
}

func isNumericType(dt schemapb.DataType) bool {
    if typeutil.IsBoolType(dt) || typeutil.IsStringType(dt) || typeutil.IsJSONType(dt) {
        return false
    }
    return typeutil.IsArithmetic(dt)
}

const defaultConvertOrToInNumericLimit = 150

func shouldMergeToIn(dt schemapb.DataType, count int) bool {
    if isNumericType(dt) {
        return count > defaultConvertOrToInNumericLimit
    }
    return count > 1
}
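The asymmetric threshold means non-numeric OR-equals chains (bool/string) are merged into an IN as soon as there are at least two of them, while numeric chains only merge once they exceed defaultConvertOrToInNumericLimit values, presumably because short numeric IN lists bring no evaluation benefit over a couple of equality predicates (see TestRewrite_OREquals_NotMerged_OnNumericBelowThreshold above). Concretely (sketch only, not part of this diff):

// Sketch only.
_ = shouldMergeToIn(schemapb.DataType_VarChar, 2) // true: non-numeric merges from two terms
_ = shouldMergeToIn(schemapb.DataType_Int64, 2)   // false: at or below the 150-value limit
_ = shouldMergeToIn(schemapb.DataType_Int64, 151) // true: above the limit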
func sortTermValues(term *planpb.TermExpr) {
    if term == nil || len(term.GetValues()) <= 1 {
        return
    }
    term.Values = sortGenericValues(term.Values)
}

// sort and deduplicate a list of generic values.
func sortGenericValues(values []*planpb.GenericValue) []*planpb.GenericValue {
    if len(values) <= 1 {
        return values
    }
    var kind string
    for _, v := range values {
        if v == nil || v.GetVal() == nil {
            continue
        }
        kind = valueCase(v)
        if kind != "" && kind != "other" && kind != "array" {
            break
        }
    }
    switch kind {
    case "bool":
        sort.Slice(values, func(i, j int) bool {
            return !values[i].GetBoolVal() && values[j].GetBoolVal()
        })
        values = lo.UniqBy(values, func(v *planpb.GenericValue) bool { return v.GetBoolVal() })
    case "int64":
        sort.Slice(values, func(i, j int) bool {
            return values[i].GetInt64Val() < values[j].GetInt64Val()
        })
        values = lo.UniqBy(values, func(v *planpb.GenericValue) int64 { return v.GetInt64Val() })
    case "float":
        sort.Slice(values, func(i, j int) bool {
            return values[i].GetFloatVal() < values[j].GetFloatVal()
        })
        values = lo.UniqBy(values, func(v *planpb.GenericValue) float64 { return v.GetFloatVal() })
    case "string":
        sort.Slice(values, func(i, j int) bool {
            return values[i].GetStringVal() < values[j].GetStringVal()
        })
        values = lo.UniqBy(values, func(v *planpb.GenericValue) string { return v.GetStringVal() })
    }
    return values
}
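This sort-and-dedup pass is what gives the IN lists in the tests their canonical, duplicate-free ordering; the kind probe above picks the comparator from the first usable value, and mixed, array, or unknown kinds are returned unchanged. A short sketch (the newInt64 helper is hypothetical, for illustration only):

// Sketch only; newInt64 is a hypothetical helper, not part of this package.
newInt64 := func(v int64) *planpb.GenericValue {
    return &planpb.GenericValue{Val: &planpb.GenericValue_Int64Val{Int64Val: v}}
}
vals := []*planpb.GenericValue{newInt64(9), newInt64(4), newInt64(6), newInt64(6)}
vals = sortGenericValues(vals)
// vals now holds 4, 6, 9 in ascending order, with the duplicate 6 removed.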
func newTermExpr(col *planpb.ColumnInfo, values []*planpb.GenericValue) *planpb.Expr {
    return &planpb.Expr{
        Expr: &planpb.Expr_TermExpr{
            TermExpr: &planpb.TermExpr{
                ColumnInfo: col,
                Values:     values,
            },
        },
    }
}

func newUnaryRangeExpr(col *planpb.ColumnInfo, op planpb.OpType, val *planpb.GenericValue) *planpb.Expr {
    return &planpb.Expr{
        Expr: &planpb.Expr_UnaryRangeExpr{
            UnaryRangeExpr: &planpb.UnaryRangeExpr{
                ColumnInfo: col,
                Op:         op,
                Value:      val,
            },
        },
    }
}

func newBoolConstExpr(v bool) *planpb.Expr {
    return &planpb.Expr{
        Expr: &planpb.Expr_ValueExpr{
            ValueExpr: &planpb.ValueExpr{
                Value: &planpb.GenericValue{
                    Val: &planpb.GenericValue_BoolVal{
                        BoolVal: v,
                    },
                },
            },
        },
    }
}

func newAlwaysTrueExpr() *planpb.Expr {
    return &planpb.Expr{
        Expr: &planpb.Expr_AlwaysTrueExpr{
            AlwaysTrueExpr: &planpb.AlwaysTrueExpr{},
        },
    }
}

func newAlwaysFalseExpr() *planpb.Expr {
    return &planpb.Expr{
        Expr: &planpb.Expr_UnaryExpr{
            UnaryExpr: &planpb.UnaryExpr{
                Op:    planpb.UnaryExpr_Not,
                Child: newAlwaysTrueExpr(),
            },
        },
    }
}

// IsAlwaysTrueExpr checks if the expression is an AlwaysTrueExpr
func IsAlwaysTrueExpr(e *planpb.Expr) bool {
    if e == nil {
        return false
    }
    return e.GetAlwaysTrueExpr() != nil
}

func IsAlwaysFalseExpr(e *planpb.Expr) bool {
    if e == nil {
        return false
    }
    ue := e.GetUnaryExpr()
    if ue == nil || ue.GetOp() != planpb.UnaryExpr_Not {
        return false
    }
    return IsAlwaysTrueExpr(ue.GetChild())
}

// equalsGeneric compares two GenericValue by content (bool/int/float/string).
func equalsGeneric(a, b *planpb.GenericValue) bool {
    if a.GetVal() == nil || b.GetVal() == nil {
        return false
    }

    switch a.GetVal().(type) {
    case *planpb.GenericValue_BoolVal:
        if _, ok := b.GetVal().(*planpb.GenericValue_BoolVal); ok {
            return a.GetBoolVal() == b.GetBoolVal()
        }
    case *planpb.GenericValue_Int64Val:
        if _, ok := b.GetVal().(*planpb.GenericValue_Int64Val); ok {
            return a.GetInt64Val() == b.GetInt64Val()
        }
    case *planpb.GenericValue_FloatVal:
        if _, ok := b.GetVal().(*planpb.GenericValue_FloatVal); ok {
            return a.GetFloatVal() == b.GetFloatVal()
        }
    case *planpb.GenericValue_StringVal:
        if _, ok := b.GetVal().(*planpb.GenericValue_StringVal); ok {
            return a.GetStringVal() == b.GetStringVal()
        }
    }
    return false
}

func satisfiesLower(dt schemapb.DataType, v, lower *planpb.GenericValue, inclusive bool) bool {
    c := cmpGeneric(dt, v, lower)
    if inclusive {
        return c >= 0
    }
    return c > 0
}

func satisfiesUpper(dt schemapb.DataType, v, upper *planpb.GenericValue, inclusive bool) bool {
    c := cmpGeneric(dt, v, upper)
    if inclusive {
        return c <= 0
    }
    return c < 0
}

func filterValuesByRange(dt schemapb.DataType, values []*planpb.GenericValue, lower *planpb.GenericValue, lowerInc bool, upper *planpb.GenericValue, upperInc bool) []*planpb.GenericValue {
    out := make([]*planpb.GenericValue, 0, len(values))
    for _, v := range values {
        pass := true
        if lower != nil && !satisfiesLower(dt, v, lower, lowerInc) {
            pass = false
        }
        if pass && upper != nil && !satisfiesUpper(dt, v, upper, upperInc) {
            pass = false
        }
        if pass {
            out = append(out, v)
        }
    }
    return out
}

func unionValues(valuesA, valuesB []*planpb.GenericValue) []*planpb.GenericValue {
    all := append([]*planpb.GenericValue{}, valuesA...)
    all = append(all, valuesB...)
    return sortGenericValues(all)
}
@@ -4709,7 +4709,7 @@ func (s *MaterializedViewTestSuite) TestMvEnabledPartitionKeyOnVarCharWithIsolat
 		schema := ConstructCollectionSchemaWithPartitionKey(s.colName, s.fieldName2Types, testInt64Field, testVarCharField, false)
 		schemaInfo := newSchemaInfo(schema)
 		s.mockMetaCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schemaInfo, nil)
-		s.ErrorContains(task.PreExecute(s.ctx), "partition key isolation does not support OR")
+		s.ErrorContains(task.PreExecute(s.ctx), "partition key isolation does not support IN")
 	}
 }

@@ -80,9 +80,9 @@ func TestParsePartitionKeys(t *testing.T) {
 		{
 			name: "binary_expr_and with partition key in range",
 			expr: "partition_key_field in [7, 8] && partition_key_field > 9",
-			expected: 2,
-			validPartitionKeys: []int64{7, 8},
-			invalidPartitionKeys: []int64{9},
+			expected: 0,
+			validPartitionKeys: []int64{},
+			invalidPartitionKeys: []int64{},
 		},
 		{
 			name: "binary_expr_and with partition key in range2",

@@ -295,7 +295,7 @@ func TestValidatePartitionKeyIsolation(t *testing.T) {
 		{
 			name: "partition key isolation equal AND with same field term",
 			expr: "key_field == 10 && key_field in [10]",
-			expectedErrorString: "partition key isolation does not support IN",
+			expectedErrorString: "",
 		},
 		{
 			name: "partition key isolation equal OR with same field equal",
@@ -589,33 +589,3 @@ func TestDeleteInvalidExpr(t *testing.T) {
 		common.CheckErr(t, err, _invalidExpr.ErrNil, _invalidExpr.ErrMsg)
 	}
 }
-
-// test delete with duplicated data ids
-func TestDeleteDuplicatedPks(t *testing.T) {
-	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
-	mc := hp.CreateDefaultMilvusClient(ctx, t)
-
-	// create collection and a partition
-	cp := hp.NewCreateCollectionParams(hp.Int64Vec)
-	prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption().TWithIsDynamic(true), hp.TNewSchemaOption())
-
-	// insert [0, 3000) into default
-	prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithMaxCapacity(common.TestCapacity))
-	prepare.FlushData(ctx, t, mc, schema.CollectionName)
-
-	// index and load
-	prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
-	prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
-
-	// delete
-	deleteIDs := []int64{0, 0, 0, 0, 0}
-	delRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithInt64IDs(common.DefaultInt64FieldName, deleteIDs))
-	common.CheckErr(t, err, true)
-	require.Equal(t, 5, int(delRes.DeleteCount))
-
-	// query, verify delete success
-	expr := fmt.Sprintf("%s >= 0 ", common.DefaultInt64FieldName)
-	resQuery, errQuery := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithConsistencyLevel(entity.ClStrong))
-	common.CheckErr(t, errQuery, true)
-	require.Equal(t, common.DefaultNb-1, resQuery.ResultCount)
-}
@@ -565,22 +565,6 @@ class TestDeleteOperation(TestcaseBase):
         # since the search requests arrived query nodes earlier than query nodes consume the delete requests.
         assert len(inter) == 0
-
-    @pytest.mark.tags(CaseLabel.L1)
-    def test_delete_expr_repeated_values(self):
-        """
-        target: test delete with repeated values
-        method: 1.insert data with unique primary keys
-                2.delete with repeated values: 'id in [0, 0]'
-        expected: delete one entity
-        """
-        # init collection with nb default data
-        collection_w = self.init_collection_general(prefix, nb=tmp_nb, insert_data=True)[0]
-        expr = f'{ct.default_int64_field_name} in {[0, 0, 0]}'
-        del_res, _ = collection_w.delete(expr)
-        assert del_res.delete_count == 3
-        collection_w.num_entities
-        collection_w.query(expr, check_task=CheckTasks.check_query_empty)

     @pytest.mark.tags(CaseLabel.L1)
     def test_delete_duplicate_primary_keys(self):
         """

@@ -1433,7 +1417,7 @@ class TestDeleteString(TestcaseBase):
         self.init_collection_general(prefix, nb=tmp_nb, insert_data=True, primary_field=ct.default_string_field_name)[0]
         expr = f'{ct.default_string_field_name} in ["0", "0", "0"]'
         del_res, _ = collection_w.delete(expr)
-        assert del_res.delete_count == 3
+        assert del_res.delete_count == 1
         collection_w.num_entities
         collection_w.query(expr, check_task=CheckTasks.check_query_empty)

@@ -1939,7 +1923,7 @@ class TestDeleteString(TestcaseBase):
         # delete
         string_expr = "varchar in [\"\", \"\"]"
         del_res, _ = collection_w.delete(string_expr)
-        assert del_res.delete_count == 2
+        assert del_res.delete_count == 1

         # load and query with id
         collection_w.load()