// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "Expr.h" #include "exec/expression/AlwaysTrueExpr.h" #include "exec/expression/BinaryArithOpEvalRangeExpr.h" #include "exec/expression/BinaryRangeExpr.h" #include "exec/expression/CompareExpr.h" #include "exec/expression/ConjunctExpr.h" #include "exec/expression/ExistsExpr.h" #include "exec/expression/JsonContainsExpr.h" #include "exec/expression/LogicalBinaryExpr.h" #include "exec/expression/LogicalUnaryExpr.h" #include "exec/expression/TermExpr.h" #include "exec/expression/UnaryExpr.h" namespace milvus { namespace exec { void ExprSet::Eval(int32_t begin, int32_t end, bool initialize, EvalCtx& context, std::vector& results) { results.resize(exprs_.size()); for (size_t i = begin; i < end; ++i) { exprs_[i]->Eval(context, results[i]); } } std::vector CompileExpressions(const std::vector& sources, ExecContext* context, const std::unordered_set& flatten_candidate, bool enable_constant_folding) { std::vector> exprs; exprs.reserve(sources.size()); for (auto& source : sources) { exprs.emplace_back(CompileExpression(source, context->get_query_context(), flatten_candidate, enable_constant_folding)); } return exprs; } static std::optional ShouldFlatten(const expr::TypedExprPtr& expr, const std::unordered_set& flat_candidates = {}) { if (auto call = std::dynamic_pointer_cast(expr)) { if (call->op_type_ == expr::LogicalBinaryExpr::OpType::And || call->op_type_ == expr::LogicalBinaryExpr::OpType::Or) { return call->name(); } } return std::nullopt; } static bool IsCall(const expr::TypedExprPtr& expr, const std::string& name) { if (auto call = std::dynamic_pointer_cast(expr)) { return call->name() == name; } return false; } static bool AllInputTypeEqual(const expr::TypedExprPtr& expr) { const auto& inputs = expr->inputs(); for (int i = 1; i < inputs.size(); i++) { if (inputs[0]->type() != inputs[i]->type()) { return false; } } return true; } static void FlattenInput(const expr::TypedExprPtr& input, const std::string& flatten_call, std::vector& flat) { if (IsCall(input, flatten_call) && AllInputTypeEqual(input)) { for (auto& child : input->inputs()) { FlattenInput(child, flatten_call, flat); } } else { flat.emplace_back(input); } } std::vector CompileInputs(const expr::TypedExprPtr& expr, QueryContext* context, const std::unordered_set& flatten_cadidates) { std::vector compiled_inputs; auto flatten = ShouldFlatten(expr); for (auto& input : expr->inputs()) { if (dynamic_cast(input.get())) { AssertInfo( dynamic_cast(expr.get()), "An InputReference can only occur under a FieldReference"); } else { if (flatten.has_value()) { std::vector flat_exprs; FlattenInput(input, flatten.value(), flat_exprs); for (auto& input : flat_exprs) { compiled_inputs.push_back(CompileExpression( input, context, flatten_cadidates, false)); } } else { compiled_inputs.push_back(CompileExpression( input, context, flatten_cadidates, false)); } } } return compiled_inputs; } ExprPtr CompileExpression(const expr::TypedExprPtr& expr, QueryContext* context, const std::unordered_set& flatten_candidates, bool enable_constant_folding) { ExprPtr result; auto result_type = expr->type(); auto compiled_inputs = CompileInputs(expr, context, flatten_candidates); auto GetTypes = [](const std::vector& exprs) { std::vector types; for (auto& expr : exprs) { types.push_back(expr->type()); } return types; }; auto input_types = GetTypes(compiled_inputs); if (auto call = dynamic_cast(expr.get())) { // TODO: support function register and search mode } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::UnaryRangeFilterExpr>(expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyUnaryRangeFilterExpr", context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::LogicalUnaryExpr>(expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyLogicalUnaryExpr"); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::TermFilterExpr>(expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyTermFilterExpr", context->get_segment(), context->get_active_count(), context->get_query_timestamp(), context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::LogicalBinaryExpr>(expr)) { if (casted_expr->op_type_ == milvus::expr::LogicalBinaryExpr::OpType::And || casted_expr->op_type_ == milvus::expr::LogicalBinaryExpr::OpType::Or) { result = std::make_shared( std::move(compiled_inputs), casted_expr->op_type_ == milvus::expr::LogicalBinaryExpr::OpType::And); } else { result = std::make_shared( compiled_inputs, casted_expr, "PhyLogicalBinaryExpr"); } } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::BinaryRangeFilterExpr>(expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyBinaryRangeFilterExpr", context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::AlwaysTrueExpr>(expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyAlwaysTrueExpr", context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::BinaryArithOpEvalRangeExpr>(expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyBinaryArithOpEvalRangeExpr", context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast( expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyCompareFilterExpr", context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast( expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyExistsFilterExpr", context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::JsonContainsExpr>(expr)) { result = std::make_shared( compiled_inputs, casted_expr, "PhyJsonContainsFilterExpr", context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); } return result; } } // namespace exec } // namespace milvus