diff --git a/internal/core/src/common/Consts.h b/internal/core/src/common/Consts.h index f51bd01c93..794fb616cd 100644 --- a/internal/core/src/common/Consts.h +++ b/internal/core/src/common/Consts.h @@ -81,3 +81,4 @@ const size_t MARISA_NULL_KEY_ID = -1; const std::string JSON_CAST_TYPE = "json_cast_type"; const std::string JSON_PATH = "json_path"; const bool DEFAULT_OPTIMIZE_EXPR_ENABLED = true; +const int64_t DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT = 150; diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 9ee8ddc953..baa5b9a138 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -233,6 +233,7 @@ IsPrimaryKeyDataType(DataType data_type) { inline bool IsIntegerDataType(DataType data_type) { switch (data_type) { + case DataType::BOOL: case DataType::INT8: case DataType::INT16: case DataType::INT32: diff --git a/internal/core/src/exec/expression/ConjunctExpr.h b/internal/core/src/exec/expression/ConjunctExpr.h index 31afa80483..0d266bf17f 100644 --- a/internal/core/src/exec/expression/ConjunctExpr.h +++ b/internal/core/src/exec/expression/ConjunctExpr.h @@ -158,6 +158,16 @@ class PhyConjunctFilterExpr : public Expr { context.clear_bitmap_input(); } + bool + IsAnd() { + return is_and_; + } + + bool + IsOr() { + return !is_and_; + } + private: int64_t UpdateResult(ColumnVectorPtr& input_result, diff --git a/internal/core/src/exec/expression/Expr.cpp b/internal/core/src/exec/expression/Expr.cpp index 28fa0839ba..4cd2095b72 100644 --- a/internal/core/src/exec/expression/Expr.cpp +++ b/internal/core/src/exec/expression/Expr.cpp @@ -454,6 +454,85 @@ ReorderConjunctExpr(std::shared_ptr& expr, expr->Reorder(reorder); } +inline std::shared_ptr +ConvertMultiOrToInExpr(std::vector>& exprs, + std::vector indices, + ExecContext* context) { + std::vector values; + auto type = proto::plan::GenericValue::ValCase::VAL_NOT_SET; + for (auto& i : indices) { + auto expr = std::static_pointer_cast(exprs[i]) + ->GetLogicalExpr(); + if (type == proto::plan::GenericValue::ValCase::VAL_NOT_SET) { + type = expr->val_.val_case(); + } + if (type != expr->val_.val_case()) { + return nullptr; + } + values.push_back(expr->val_); + } + auto logical_expr = std::make_shared( + exprs[indices[0]]->GetColumnInfo().value(), values); + auto query_context = context->get_query_context(); + return std::make_shared( + std::vector>{}, + logical_expr, + "PhyTermFilterExpr", + query_context->get_segment(), + query_context->get_active_count(), + query_context->get_query_timestamp(), + query_context->query_config()->get_expr_batch_size()); +} + +inline void +RewriteConjunctExpr(std::shared_ptr& expr, + ExecContext* context) { + if (expr->IsOr()) { + auto& inputs = expr->GetInputsRef(); + std::map> expr_indices; + for (size_t i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; + if (input->name() == "PhyUnaryRangeFilterExpr") { + auto phy_expr = + std::static_pointer_cast(input); + if (phy_expr->GetOpType() == proto::plan::OpType::Equal) { + auto column = phy_expr->GetColumnInfo().value(); + if (expr_indices.find(column) != expr_indices.end()) { + expr_indices[column].push_back(i); + } else { + expr_indices[column] = {i}; + } + } + } + + if (input->name() == "PhyConjunctFilterExpr") { + auto expr = std::static_pointer_cast< + milvus::exec::PhyConjunctFilterExpr>(input); + RewriteConjunctExpr(expr, context); + } + } + + for (auto& [column, indices] : expr_indices) { + // For numeric type, if or column greater than 150, then using in expr replace. + // For other type, all convert to in expr. + if ((IsNumericDataType(column.data_type_) && + indices.size() > DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT) || + (!IsNumericDataType(column.data_type_) && indices.size() > 1)) { + auto new_expr = + ConvertMultiOrToInExpr(inputs, indices, context); + if (new_expr) { + inputs[indices[0]] = new_expr; + for (size_t j = 1; j < indices.size(); j++) { + inputs[indices[j]] = nullptr; + } + } + } + } + inputs.erase(std::remove(inputs.begin(), inputs.end(), nullptr), + inputs.end()); + } +} + inline void OptimizeCompiledExprs(ExecContext* context, const std::vector& exprs) { std::chrono::high_resolution_clock::time_point start = @@ -463,6 +542,7 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector& exprs) { LOG_DEBUG("before reoder filter expression: {}", expr->ToString()); auto conjunct_expr = std::static_pointer_cast(expr); + RewriteConjunctExpr(conjunct_expr, context); bool has_heavy_operation = false; ReorderConjunctExpr(conjunct_expr, context, has_heavy_operation); LOG_DEBUG("after reorder filter expression: {}", expr->ToString()); diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 0068d17385..3c138d2902 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -103,7 +103,7 @@ class Expr { PanicInfo(ErrorCode::NotImplemented, "not implemented"); } - const std::vector>& + std::vector>& GetInputsRef() { return inputs_; } diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index 318724864d..803cd84c48 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -379,6 +379,21 @@ class PhyUnaryRangeFilterExpr : public SegmentExpr { return expr_; } + proto::plan::OpType + GetOpType() { + return expr_->op_type_; + } + + FieldId + GetFieldId() { + return expr_->column_.field_id_; + } + + DataType + GetFieldType() { + return expr_->column_.data_type_; + } + private: template VectorPtr diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h index fed5a20b3b..3c78fe4be7 100644 --- a/internal/core/src/expr/ITypeExpr.h +++ b/internal/core/src/expr/ITypeExpr.h @@ -161,6 +161,19 @@ struct ColumnInfo { return true; } + bool + operator<(const ColumnInfo& other) const { + return std::tie(field_id_, + data_type_, + element_type_, + nested_path_, + nullable_) < std::tie(other.field_id_, + other.data_type_, + other.element_type_, + other.nested_path_, + other.nullable_); + } + std::string ToString() const { return fmt::format( @@ -353,7 +366,8 @@ class UnaryRangeFilterExpr : public ITypeFilterExpr { const ColumnInfo& column, proto::plan::OpType op_type, const proto::plan::GenericValue& val, - const std::vector& extra_values) + const std::vector& extra_values = + std::vector{}) : ITypeFilterExpr(), column_(column), op_type_(op_type), diff --git a/internal/core/unittest/test_exec.cpp b/internal/core/unittest/test_exec.cpp index 6d8539e863..c373669a8f 100644 --- a/internal/core/unittest/test_exec.cpp +++ b/internal/core/unittest/test_exec.cpp @@ -702,3 +702,301 @@ TEST_P(TaskTest, Test_reorder) { OPTIMIZE_EXPR_ENABLED = true; } } + +TEST_P(TaskTest, Test_MultiInConvert) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + using namespace milvus::exec; + + { + // expr: string2 == '111' or string2 == '222' or string2 == "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val1, + std::vector{}); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val2, + std::vector{}); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_string_val("333"); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val3, + std::vector{}); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 1); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); + } + + { + // expr: string2 == '111' or string2 == '222' or (int64 > 10 && int64 < 100) or string2 == "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val1, + std::vector{}); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val2, + std::vector{}); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + + proto::plan::GenericValue val3; + val3.set_int64_val(10); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::GreaterThan, + val3, + std::vector{}); + proto::plan::GenericValue val4; + val4.set_int64_val(100); + auto expr5 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::LessThan, + val4, + std::vector{}); + auto expr6 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr4, expr5); + + auto expr7 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr6, expr3); + + proto::plan::GenericValue val5; + val5.set_string_val("333"); + auto expr8 = std::make_shared( + expr::ColumnInfo(field_map_["string2"], DataType::VARCHAR), + proto::plan::OpType::Equal, + val5, + std::vector{}); + auto expr9 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr7, expr8); + + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr9}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 2); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyConjunctFilterExpr"); + EXPECT_STREQ(inputs[1]->name().c_str(), "PhyTermFilterExpr"); + } + { + // expr: json['a'] == "111" or json['a'] == "222" or json['3'] = "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val1, + std::vector{}); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val2, + std::vector{}); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_string_val("333"); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val3, + std::vector{}); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 1); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); + } + + { + // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == "333" + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val1, + std::vector{}); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'b'}), + proto::plan::OpType::Equal, + val2, + std::vector{}); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_string_val("333"); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val3, + std::vector{}); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 2); + EXPECT_STREQ(inputs[0]->name().c_str(), "PhyTermFilterExpr"); + EXPECT_STREQ(inputs[1]->name().c_str(), "PhyUnaryRangeFilterExpr"); + } + + { + // expr: json['a'] == "111" or json['b'] == "222" or json['a'] == 1 + proto::plan::GenericValue val1; + val1.set_string_val("111"); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val1, + std::vector{}); + proto::plan::GenericValue val2; + val2.set_string_val("222"); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'b'}), + proto::plan::OpType::Equal, + val2, + std::vector{}); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_int64_val(1); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["json"], + DataType::JSON, + std::vector{'a'}), + proto::plan::OpType::Equal, + val3, + std::vector{}); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 3); + } + + { + // expr: int1 == 11 or int1 == 22 or int3 == 33 + proto::plan::GenericValue val1; + val1.set_int64_val(11); + auto expr1 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::Equal, + val1, + std::vector{}); + proto::plan::GenericValue val2; + val2.set_int64_val(222); + auto expr2 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::Equal, + val2, + std::vector{}); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_int64_val(1); + auto expr4 = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::Equal, + val3, + std::vector{}); + auto expr5 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment_.get(), 100000, MAX_TIMESTAMP); + ExecContext context(query_context.get()); + auto exprs = + milvus::exec::CompileExpressions({expr5}, &context, {}, false); + EXPECT_EQ(exprs.size(), 1); + EXPECT_STREQ(exprs[0]->name().c_str(), "PhyConjunctFilterExpr"); + auto phy_expr = + std::static_pointer_cast( + exprs[0]); + auto inputs = phy_expr->GetInputsRef(); + EXPECT_EQ(inputs.size(), 3); + } +} diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index fa05a34616..e2c794c416 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -3852,6 +3852,91 @@ TEST_P(ExprTest, TestCompareExprNullable2) { std::cout << "end compare test" << std::endl; } +TEST_P(ExprTest, TestMutiInConvert) { + auto schema = std::make_shared(); + auto pk = schema->AddDebugField("id", DataType::INT64); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto json_fid = schema->AddDebugField("json", DataType::JSON, false); + auto str_array_fid = + schema->AddDebugField("str_array", DataType::ARRAY, DataType::VARCHAR); + schema->set_primary_field_id(pk); + + auto seg = CreateSealedSegment(schema); + size_t N = 1000; + auto raw_data = DataGen(schema, N); + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + auto build_expr = [&](int index) -> expr::TypedExprPtr { + switch (index) { + case 0: { + proto::plan::GenericValue val1; + val1.set_int64_val(100); + auto expr1 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::Equal, + val1); + proto::plan::GenericValue val2; + val2.set_int64_val(200); + auto expr2 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::Equal, + val2); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + proto::plan::GenericValue val3; + val3.set_int64_val(300); + auto expr4 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::Equal, + val3); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr4); + }; + default: + PanicInfo(ErrorCode::UnexpectedError, "not implement"); + } + }; + + auto expr = build_expr(0); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto final1 = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + OPTIMIZE_EXPR_ENABLED = false; + auto final2 = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + EXPECT_EQ(final1.size(), final2.size()); + for (auto i = 0; i < final1.size(); i++) { + EXPECT_EQ(final1[i], final2[i]); + } +} + TEST(Expr, TestExprPerformance) { GTEST_SKIP() << "Skip performance test, open it when test performance"; auto schema = std::make_shared();