From deed5b5df480597573ce08fd3766bfe56ca4a800 Mon Sep 17 00:00:00 2001 From: zhagnlu <1542303831@qq.com> Date: Thu, 27 Mar 2025 11:08:20 +0800 Subject: [PATCH] enhance:change multi or expr to in expr (#40751) pr: #40757 Signed-off-by: luzhang Co-authored-by: luzhang --- internal/core/src/common/Consts.h | 1 + internal/core/src/common/Types.h | 1 + .../core/src/exec/expression/ConjunctExpr.h | 14 +- internal/core/src/exec/expression/Expr.cpp | 82 ++++++ internal/core/src/exec/expression/Expr.h | 2 +- internal/core/src/exec/expression/UnaryExpr.h | 15 + internal/core/src/expr/ITypeExpr.h | 13 + internal/core/unittest/test_exec.cpp | 278 ++++++++++++++++++ internal/core/unittest/test_expr.cpp | 85 ++++++ 9 files changed, 487 insertions(+), 4 deletions(-) diff --git a/internal/core/src/common/Consts.h b/internal/core/src/common/Consts.h index 83864cece4..3548e2d3ce 100644 --- a/internal/core/src/common/Consts.h +++ b/internal/core/src/common/Consts.h @@ -86,3 +86,4 @@ const std::string JSON_PATH = "json_path"; const bool DEFAULT_OPTIMIZE_EXPR_ENABLED = true; const bool DEFAULT_GROWING_JSON_KEY_STATS_ENABLED = false; const int64_t DEFAULT_JSON_KEY_STATS_COMMIT_INTERVAL = 200; +const int64_t DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT = 150; diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 12fd9f0b95..7c7bad44cb 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -225,6 +225,7 @@ IsPrimaryKeyDataType(DataType data_type) { inline bool IsIntegerDataType(DataType data_type) { switch (data_type) { + case DataType::BOOL: case DataType::INT8: case DataType::INT16: case DataType::INT32: diff --git a/internal/core/src/exec/expression/ConjunctExpr.h b/internal/core/src/exec/expression/ConjunctExpr.h index 3c655c3cdf..0d266bf17f 100644 --- a/internal/core/src/exec/expression/ConjunctExpr.h +++ b/internal/core/src/exec/expression/ConjunctExpr.h @@ -66,9 +66,7 @@ struct ConjunctElementFunc { class PhyConjunctFilterExpr : public Expr { public: PhyConjunctFilterExpr(std::vector&& inputs, bool is_and) - : Expr(DataType::BOOL, - std::move(inputs), - is_and ? "PhyConjunctFilterExpr" : "PhyConjunctFilterExpr"), + : Expr(DataType::BOOL, std::move(inputs), "PhyConjunctFilterExpr"), is_and_(is_and) { std::vector input_types; input_types.reserve(inputs_.size()); @@ -160,6 +158,16 @@ class PhyConjunctFilterExpr : public Expr { context.clear_bitmap_input(); } + bool + IsAnd() { + return is_and_; + } + + bool + IsOr() { + return !is_and_; + } + private: int64_t UpdateResult(ColumnVectorPtr& input_result, diff --git a/internal/core/src/exec/expression/Expr.cpp b/internal/core/src/exec/expression/Expr.cpp index fb7ef8d1eb..4e78952f83 100644 --- a/internal/core/src/exec/expression/Expr.cpp +++ b/internal/core/src/exec/expression/Expr.cpp @@ -460,6 +460,87 @@ ReorderConjunctExpr(std::shared_ptr& expr, expr->Reorder(reorder); } +inline std::shared_ptr +ConvertMultiOrToInExpr(std::vector>& exprs, + std::vector indices, + ExecContext* context) { + std::vector values; + bool get_value_type = false; + auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET; + for (auto& i : indices) { + auto expr = std::static_pointer_cast(exprs[i]) + ->GetLogicalExpr(); + if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) { + type = expr->val_.val_case(); + } + if (type != expr->val_.val_case()) { + return nullptr; + } + values.push_back(expr->val_); + } + auto logical_expr = std::make_shared( + exprs[indices[0]]->GetColumnInfo().value(), values); + auto query_context = context->get_query_context(); + return std::make_shared( + std::vector>{}, + logical_expr, + "PhyTermFilterExpr", + query_context->get_segment(), + query_context->get_active_count(), + query_context->get_query_timestamp(), + query_context->query_config()->get_expr_batch_size(), + query_context->get_consistency_level()); +} + +inline void +RewriteConjunctExpr(std::shared_ptr& expr, + ExecContext* context) { + if (expr->IsOr()) { + auto& inputs = expr->GetInputsRef(); + std::map> expr_indices; + for (size_t i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; + if (input->name() == "PhyUnaryRangeFilterExpr") { + auto phy_expr = + std::static_pointer_cast(input); + if (phy_expr->GetOpType() == proto::plan::OpType::Equal) { + auto column = phy_expr->GetColumnInfo().value(); + if (expr_indices.find(column) != expr_indices.end()) { + expr_indices[column].push_back(i); + } else { + expr_indices[column] = {i}; + } + } + } + + if (input->name() == "PhyConjunctFilterExpr") { + auto expr = std::static_pointer_cast< + milvus::exec::PhyConjunctFilterExpr>(input); + RewriteConjunctExpr(expr, context); + } + } + + for (auto& [column, indices] : expr_indices) { + // For numeric type, if or column greater than 150, then using in expr replace. + // For other type, all convert to in expr. + if ((IsNumericDataType(column.data_type_) && + indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) || + (!IsNumericDataType(column.data_type_) && indices.size() > 1)) { + auto new_expr = + ConvertMultiOrToInExpr(inputs, indices, context); + if (new_expr) { + inputs[indices[0]] = new_expr; + for (size_t j = 1; j < indices.size(); j++) { + inputs[indices[j]] = nullptr; + } + } + } + } + inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr), + inputs.end()); + } +} + inline void OptimizeCompiledExprs(ExecContext* context, const std::vector& exprs) { std::chrono::high_resolution_clock::time_point start = @@ -469,6 +550,7 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector& exprs) { LOG_DEBUG("before reoder filter expression: {}", expr->ToString()); auto conjunct_expr = std::static_pointer_cast(expr); + RewriteConjunctExpr(conjunct_expr, context); bool has_heavy_operation = false; ReorderConjunctExpr(conjunct_expr, context, has_heavy_operation); LOG_DEBUG("after reorder filter expression: {}", expr->ToString()); diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 3d8c848a5a..29d430cdec 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -105,7 +105,7 @@ class Expr { PanicInfo(ErrorCode::NotImplemented, "not implemented"); } - const std::vector>& + std::vector>& GetInputsRef() { return inputs_; } diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index 7b61807301..4c524e77cc 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -380,6 +380,21 @@ class PhyUnaryRangeFilterExpr : public SegmentExpr { return expr_; } + proto::plan::OpType + GetOpType() { + return expr_->op_type_; + } + + FieldId + GetFieldId() { + return expr_->column_.field_id_; + } + + DataType + GetFieldType() { + return expr_->column_.data_type_; + } + private: template VectorPtr diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h index 254975e3ff..9ac06a8955 100644 --- a/internal/core/src/expr/ITypeExpr.h +++ b/internal/core/src/expr/ITypeExpr.h @@ -161,6 +161,19 @@ struct ColumnInfo { return true; } + bool + operator<(const ColumnInfo& other) const { + return std::tie(field_id_, + data_type_, + element_type_, + nested_path_, + nullable_) < std::tie(other.field_id_, + other.data_type_, + other.element_type_, + other.nested_path_, + other.nullable_); + } + std::string ToString() const { return fmt::format( diff --git a/internal/core/unittest/test_exec.cpp b/internal/core/unittest/test_exec.cpp index a7a198d057..5eb587712f 100644 --- a/internal/core/unittest/test_exec.cpp +++ b/internal/core/unittest/test_exec.cpp @@ -675,3 +675,281 @@ TEST_P(TaskTest, Test_reorder) { OPTIMIZE_EXPR_ENABLED = true; } } + +TEST_P(TaskTest, Test_MultiInConvert) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + using namespace milvus::exec; + + { + // expr: string2 == '111' or string2 == '222' or string2 == "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_string_val("333"); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val3); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 1); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); + } + + { + // expr: string2 == '111' or string2 == '222' or (int64 > 10 && int64 < 100) or string2 == "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + + proto::plan::GenericValue val3; + val3.set_int64_val(10); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::GreaterThan, + val3); + proto::plan::GenericValue val4; + val4.set_int64_val(100); + auto expr5 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::LessThan, + val4); + auto expr6 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr4, expr5); + + auto expr7 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr6, expr3); + + proto::plan::GenericValue val5; + val5.set_string_val("333"); + auto expr8 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val5); + auto expr9 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr7, expr8); + + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr9}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 2); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr"); + EXPECT_STREQ(inputs[1]->name().c_str(), "PhyTermFilterExpr"); + } + { + // expr: json['a'] == "111" or json['a'] == "222" or json['3'] = "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_string_val("333"); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val3); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 1); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); + } + + { + // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'b'}), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_string_val("333"); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val3); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 2); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); + EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr"); + } + + { + // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == 1 + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'b'}), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_int64_val(1); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val3); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 3); + } + + { + // expr: int1 == 11 or int1 == 22 or int3 == 33 + proto::plan::GenericValue val1; + val1.set_int64_val(11); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_int64_val(222); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_int64_val(1); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::Equal, + val3); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 3); + } +} diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 11a58b8577..a6de6da8a2 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -4122,6 +4122,91 @@ TEST_P(ExprTest, TestCompareExprNullable2) { std::cout << "end compare test" << std::endl; } +TEST_P(ExprTest, TestMutiInConvert) { + auto schema = std::make_shared(); + auto pk = schema->AddDebugField("id", DataType::INT64); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto json_fid = schema->AddDebugField("json", DataType::JSON, false); + auto str_array_fid = + schema->AddDebugField("str_array", DataType::ARRAY, DataType::VARCHAR); + schema->set_primary_field_id(pk); + + auto seg = CreateSealedSegment(schema); + size_t N = 1000; + auto raw_data = DataGen(schema, N); + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + auto build_expr = [&](int index) -> expr::TypedExprPtr { + switch (index) { + case 0: { + proto::plan::GenericValue val1; + val1.set_int64_val(100); + auto expr1 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_int64_val(200); + auto expr2 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_int64_val(300); + auto expr4 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::Equal, + val3); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + }; + default: + PanicInfo(ErrorCode::UnexpectedError, "not implement"); + } + }; + + auto expr = build_expr(0); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto final1 = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + OPTIMIZE_EXPR_ENABLED = false; + auto final2 = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + EXPECT_EQ(final1.size(), final2.size()); + for (auto i = 0; i < final1.size(); i++) { + EXPECT_EQ(final1[i], final2[i]); + } +} + TEST(Expr, TestExprPerformance) { GTEST_SKIP() << "Skip performance test, open it when test performance"; auto schema = std::make_shared();