mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 01:58:34 +08:00
relate: https://github.com/milvus-io/milvus/issues/43867 Support boost function score, multiply by the weight if match filter. Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
610 lines
24 KiB
C++
610 lines
24 KiB
C++
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
// or implied. See the License for the specific language governing permissions and limitations under the License
|
|
|
|
#include "PlanProto.h"
|
|
|
|
#include <google/protobuf/text_format.h>
|
|
|
|
#include <cstdint>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "common/VectorTrait.h"
|
|
#include "common/EasyAssert.h"
|
|
#include "exec/expression/function/FunctionFactory.h"
|
|
#include "pb/plan.pb.h"
|
|
#include "query/Utils.h"
|
|
#include "knowhere/comp/materialized_view.h"
|
|
#include "plan/PlanNode.h"
|
|
#include "rescores/Scorer.h"
|
|
|
|
namespace milvus::query {
|
|
namespace planpb = milvus::proto::plan;
|
|
|
|
std::unique_ptr<VectorPlanNode>
|
|
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|
// TODO: add more buffs
|
|
Assert(plan_node_proto.has_vector_anns());
|
|
auto& anns_proto = plan_node_proto.vector_anns();
|
|
|
|
auto expr_parser = [&]() -> plan::PlanNodePtr {
|
|
auto expr = ParseExprs(anns_proto.predicates());
|
|
return std::make_shared<plan::FilterBitsNode>(
|
|
milvus::plan::GetNextPlanNodeId(), expr);
|
|
};
|
|
|
|
auto search_info_parser = [&]() -> SearchInfo {
|
|
SearchInfo search_info;
|
|
auto& query_info_proto = anns_proto.query_info();
|
|
auto field_id = FieldId(anns_proto.field_id());
|
|
search_info.field_id_ = field_id;
|
|
|
|
search_info.metric_type_ = query_info_proto.metric_type();
|
|
search_info.topk_ = query_info_proto.topk();
|
|
search_info.round_decimal_ = query_info_proto.round_decimal();
|
|
search_info.search_params_ =
|
|
nlohmann::json::parse(query_info_proto.search_params());
|
|
search_info.materialized_view_involved =
|
|
query_info_proto.materialized_view_involved();
|
|
// currently, iterative filter does not support range search
|
|
if (!search_info.search_params_.contains(RADIUS)) {
|
|
if (query_info_proto.hints() != "") {
|
|
if (query_info_proto.hints() == "disable") {
|
|
search_info.iterative_filter_execution = false;
|
|
} else if (query_info_proto.hints() == ITERATIVE_FILTER) {
|
|
search_info.iterative_filter_execution = true;
|
|
} else {
|
|
// check if hints is valid
|
|
ThrowInfo(ConfigInvalid,
|
|
"hints: {} not supported",
|
|
query_info_proto.hints());
|
|
}
|
|
} else if (search_info.search_params_.contains(HINTS)) {
|
|
if (search_info.search_params_[HINTS] == ITERATIVE_FILTER) {
|
|
search_info.iterative_filter_execution = true;
|
|
} else {
|
|
// check if hints is valid
|
|
ThrowInfo(ConfigInvalid,
|
|
"hints: {} not supported",
|
|
search_info.search_params_[HINTS]);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (query_info_proto.bm25_avgdl() > 0) {
|
|
search_info.search_params_[knowhere::meta::BM25_AVGDL] =
|
|
query_info_proto.bm25_avgdl();
|
|
}
|
|
|
|
if (query_info_proto.group_by_field_id() > 0) {
|
|
auto group_by_field_id =
|
|
FieldId(query_info_proto.group_by_field_id());
|
|
search_info.group_by_field_id_ = group_by_field_id;
|
|
search_info.group_size_ = query_info_proto.group_size() > 0
|
|
? query_info_proto.group_size()
|
|
: 1;
|
|
search_info.strict_group_size_ =
|
|
query_info_proto.strict_group_size();
|
|
}
|
|
|
|
if (query_info_proto.has_search_iterator_v2_info()) {
|
|
auto& iterator_v2_info_proto =
|
|
query_info_proto.search_iterator_v2_info();
|
|
search_info.iterator_v2_info_ = SearchIteratorV2Info{
|
|
.token = iterator_v2_info_proto.token(),
|
|
.batch_size = iterator_v2_info_proto.batch_size(),
|
|
};
|
|
if (iterator_v2_info_proto.has_last_bound()) {
|
|
search_info.iterator_v2_info_->last_bound =
|
|
iterator_v2_info_proto.last_bound();
|
|
}
|
|
}
|
|
|
|
return search_info;
|
|
};
|
|
|
|
auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
|
|
if (anns_proto.vector_type() ==
|
|
milvus::proto::plan::VectorType::BinaryVector) {
|
|
return std::make_unique<BinaryVectorANNS>();
|
|
} else if (anns_proto.vector_type() ==
|
|
milvus::proto::plan::VectorType::Float16Vector) {
|
|
return std::make_unique<Float16VectorANNS>();
|
|
} else if (anns_proto.vector_type() ==
|
|
milvus::proto::plan::VectorType::BFloat16Vector) {
|
|
return std::make_unique<BFloat16VectorANNS>();
|
|
} else if (anns_proto.vector_type() ==
|
|
milvus::proto::plan::VectorType::SparseFloatVector) {
|
|
return std::make_unique<SparseFloatVectorANNS>();
|
|
} else if (anns_proto.vector_type() ==
|
|
milvus::proto::plan::VectorType::Int8Vector) {
|
|
return std::make_unique<Int8VectorANNS>();
|
|
} else {
|
|
return std::make_unique<FloatVectorANNS>();
|
|
}
|
|
}();
|
|
plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
|
|
plan_node->search_info_ = std::move(search_info_parser());
|
|
|
|
milvus::plan::PlanNodePtr plannode;
|
|
std::vector<milvus::plan::PlanNodePtr> sources;
|
|
|
|
// mvcc node -> vector search node -> iterative filter node
|
|
auto iterative_filter_plan = [&]() {
|
|
plannode = std::make_shared<milvus::plan::MvccNode>(
|
|
milvus::plan::GetNextPlanNodeId());
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
|
|
auto expr = ParseExprs(anns_proto.predicates());
|
|
plannode = std::make_shared<plan::FilterNode>(
|
|
milvus::plan::GetNextPlanNodeId(), expr, sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
};
|
|
|
|
// pre filter node -> mvcc node -> vector search node
|
|
auto pre_filter_plan = [&]() {
|
|
plannode = std::move(expr_parser());
|
|
if (plan_node->search_info_.materialized_view_involved) {
|
|
const auto expr_info = plannode->GatherInfo();
|
|
knowhere::MaterializedViewSearchInfo materialized_view_search_info;
|
|
for (const auto& [expr_field_id, vals] :
|
|
expr_info.field_id_to_values) {
|
|
materialized_view_search_info
|
|
.field_id_to_touched_categories_cnt[expr_field_id] =
|
|
vals.size();
|
|
}
|
|
materialized_view_search_info.is_pure_and = expr_info.is_pure_and;
|
|
materialized_view_search_info.has_not = expr_info.has_not;
|
|
|
|
plan_node->search_info_
|
|
.search_params_[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] =
|
|
materialized_view_search_info;
|
|
}
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
plannode = std::make_shared<milvus::plan::MvccNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
|
|
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
};
|
|
|
|
if (anns_proto.has_predicates()) {
|
|
// currently limit iterative filter scope to search only
|
|
if (plan_node->search_info_.iterative_filter_execution &&
|
|
plan_node->search_info_.group_by_field_id_ == std::nullopt) {
|
|
iterative_filter_plan();
|
|
} else {
|
|
pre_filter_plan();
|
|
}
|
|
} else {
|
|
// no filter, force set iterative filter hint to false, go with normal vector search path
|
|
plan_node->search_info_.iterative_filter_execution = false;
|
|
plannode = std::make_shared<milvus::plan::MvccNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
|
|
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
}
|
|
|
|
if (plan_node->search_info_.group_by_field_id_ != std::nullopt) {
|
|
plannode = std::make_shared<milvus::plan::GroupByNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
}
|
|
|
|
// if has score function, run filter and scorer at last
|
|
if (plan_node_proto.scorers_size() > 0){
|
|
std::vector<std::shared_ptr<rescores::Scorer>> scorers;
|
|
for (const auto& function: plan_node_proto.scorers()){
|
|
scorers.push_back(ParseScorer(function));
|
|
}
|
|
|
|
plannode = std::make_shared<milvus::plan::RescoresNode>(
|
|
milvus::plan::GetNextPlanNodeId(), std::move(scorers), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
}
|
|
|
|
plan_node->plannodes_ = plannode;
|
|
|
|
return plan_node;
|
|
}
|
|
|
|
std::unique_ptr<RetrievePlanNode>
|
|
ProtoParser::RetrievePlanNodeFromProto(
|
|
const planpb::PlanNode& plan_node_proto) {
|
|
Assert(plan_node_proto.has_predicates() || plan_node_proto.has_query());
|
|
|
|
milvus::plan::PlanNodePtr plannode;
|
|
std::vector<milvus::plan::PlanNodePtr> sources;
|
|
|
|
auto plan_node = [&]() -> std::unique_ptr<RetrievePlanNode> {
|
|
auto node = std::make_unique<RetrievePlanNode>();
|
|
if (plan_node_proto.has_predicates()) { // version before 2023.03.30.
|
|
node->is_count_ = false;
|
|
auto& predicate_proto = plan_node_proto.predicates();
|
|
auto expr_parser = [&]() -> plan::PlanNodePtr {
|
|
auto expr = ParseExprs(predicate_proto);
|
|
return std::make_shared<plan::FilterBitsNode>(
|
|
milvus::plan::GetNextPlanNodeId(), expr);
|
|
}();
|
|
plannode = std::move(expr_parser);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
plannode = std::make_shared<milvus::plan::MvccNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
node->plannodes_ = std::move(plannode);
|
|
} else {
|
|
auto& query = plan_node_proto.query();
|
|
if (query.has_predicates()) {
|
|
auto parse_expr_to_filter_node =
|
|
[&](const proto::plan::Expr& predicate_proto)
|
|
-> plan::PlanNodePtr {
|
|
auto expr = ParseExprs(predicate_proto);
|
|
return std::make_shared<plan::FilterBitsNode>(
|
|
milvus::plan::GetNextPlanNodeId(), expr, sources);
|
|
};
|
|
|
|
auto* predicate_proto = &query.predicates();
|
|
if (predicate_proto->expr_case() ==
|
|
proto::plan::Expr::kRandomSampleExpr) {
|
|
// Predicate exists in random_sample_expr means we encounter expression
|
|
// like "`predicate expression` && random_sample(...)". Extract it to construct
|
|
// FilterBitsNode and make it be executed before RandomSampleNode.
|
|
auto& sample_expr = predicate_proto->random_sample_expr();
|
|
if (sample_expr.has_predicate()) {
|
|
auto expr_parser =
|
|
parse_expr_to_filter_node(sample_expr.predicate());
|
|
plannode = std::move(expr_parser);
|
|
sources =
|
|
std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
}
|
|
|
|
plannode = std::move(
|
|
std::make_shared<milvus::plan::RandomSampleNode>(
|
|
milvus::plan::GetNextPlanNodeId(),
|
|
sample_expr.sample_factor(),
|
|
sources));
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
} else {
|
|
auto expr_parser =
|
|
parse_expr_to_filter_node(query.predicates());
|
|
plannode = std::move(expr_parser);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
}
|
|
}
|
|
|
|
plannode = std::make_shared<milvus::plan::MvccNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
|
|
node->is_count_ = query.is_count();
|
|
node->limit_ = query.limit();
|
|
if (node->is_count_) {
|
|
plannode = std::make_shared<milvus::plan::CountNode>(
|
|
milvus::plan::GetNextPlanNodeId(), sources);
|
|
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
|
}
|
|
node->plannodes_ = plannode;
|
|
}
|
|
return node;
|
|
}();
|
|
|
|
return plan_node;
|
|
}
|
|
|
|
std::unique_ptr<Plan>
|
|
ProtoParser::CreatePlan(const proto::plan::PlanNode& plan_node_proto) {
|
|
LOG_DEBUG("create search plan from proto: {}",
|
|
plan_node_proto.DebugString());
|
|
auto plan = std::make_unique<Plan>(schema);
|
|
|
|
auto plan_node = PlanNodeFromProto(plan_node_proto);
|
|
plan->tag2field_["$0"] = plan_node->search_info_.field_id_;
|
|
plan->plan_node_ = std::move(plan_node);
|
|
ExtractedPlanInfo extra_info(schema->size());
|
|
extra_info.add_involved_field(plan->plan_node_->search_info_.field_id_);
|
|
plan->extra_info_opt_ = std::move(extra_info);
|
|
|
|
for (auto field_id_raw : plan_node_proto.output_field_ids()) {
|
|
auto field_id = FieldId(field_id_raw);
|
|
plan->target_entries_.push_back(field_id);
|
|
}
|
|
for (auto dynamic_field : plan_node_proto.dynamic_fields()) {
|
|
plan->target_dynamic_fields_.push_back(dynamic_field);
|
|
}
|
|
|
|
return plan;
|
|
}
|
|
|
|
std::unique_ptr<RetrievePlan>
|
|
ProtoParser::CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto) {
|
|
LOG_DEBUG("create retrieve plan from proto: {}",
|
|
plan_node_proto.DebugString());
|
|
auto retrieve_plan = std::make_unique<RetrievePlan>(schema);
|
|
|
|
auto plan_node = RetrievePlanNodeFromProto(plan_node_proto);
|
|
|
|
retrieve_plan->plan_node_ = std::move(plan_node);
|
|
for (auto field_id_raw : plan_node_proto.output_field_ids()) {
|
|
auto field_id = FieldId(field_id_raw);
|
|
retrieve_plan->field_ids_.push_back(field_id);
|
|
}
|
|
for (auto dynamic_field : plan_node_proto.dynamic_fields()) {
|
|
retrieve_plan->target_dynamic_fields_.push_back(dynamic_field);
|
|
}
|
|
|
|
return retrieve_plan;
|
|
}
|
|
|
|
// Converts a UnaryRangeExpr proto into a UnaryRangeFilterExpr.
//
// Asserts that the column's data type in the proto matches the type recorded
// in the schema for that field.
expr::TypedExprPtr
ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema->operator[](field_id).get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    // Build the vector straight from the repeated field so each message is
    // copied exactly once (a by-value loop variable plus emplace_back copies
    // every message twice).
    const auto& extra_values_pb = expr_pb.extra_values();
    std::vector<::milvus::proto::plan::GenericValue> extra_values(
        extra_values_pb.begin(), extra_values_pb.end());

    return std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
        expr::ColumnInfo(column_info),
        expr_pb.op(),
        expr_pb.value(),
        extra_values);
}
|
|
|
|
// Converts a NullExpr proto into a NullExpr over the referenced column.
//
// Asserts that the column's data type in the proto matches the schema.
expr::TypedExprPtr
ProtoParser::ParseNullExprs(const proto::plan::NullExpr& expr_pb) {
    const auto& column_info = expr_pb.column_info();
    const auto data_type =
        (*schema)[FieldId(column_info.field_id())].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    const auto null_op = expr_pb.op();
    return std::make_shared<milvus::expr::NullExpr>(
        expr::ColumnInfo(column_info), null_op);
}
|
|
|
|
// Converts a BinaryRangeExpr proto into a BinaryRangeFilterExpr carrying the
// lower/upper bound values and their inclusiveness flags.
//
// Asserts that the column's data type in the proto matches the schema.
expr::TypedExprPtr
ProtoParser::ParseBinaryRangeExprs(
    const proto::plan::BinaryRangeExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema->operator[](field_id).get_data_type();
    // Named cast, matching the other parsers in this file.
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    return std::make_shared<expr::BinaryRangeFilterExpr>(
        column_info,
        expr_pb.lower_value(),
        expr_pb.upper_value(),
        expr_pb.lower_inclusive(),
        expr_pb.upper_inclusive());
}
|
|
|
|
// Converts a CallExpr proto into a CallExpr bound to a registered filter
// function.
//
// Each parameter expression is parsed with no type restriction; the function
// is then looked up in the FunctionFactory by name plus parameter types. A
// missing function is reported through ThrowInfo(ExprInvalid, ...).
expr::TypedExprPtr
ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) {
    std::vector<expr::TypedExprPtr> call_args;
    std::vector<DataType> arg_types;
    for (const auto& arg_pb : expr_pb.function_parameters()) {
        // function parameter can be any type
        auto arg = ParseExprs(arg_pb, TypeIsAny);
        arg_types.push_back(arg->type());
        call_args.push_back(arg);
    }

    exec::expression::FilterFunctionRegisterKey signature{
        expr_pb.function_name(), std::move(arg_types)};
    auto function =
        exec::expression::FunctionFactory::Instance().GetFilterFunction(
            signature);
    if (function == nullptr) {
        ThrowInfo(ExprInvalid,
                  "function " + signature.ToString() + " not found. ");
    }

    return std::make_shared<expr::CallExpr>(
        expr_pb.function_name(), call_args, function);
}
|
|
|
|
// Converts a CompareExpr proto (column <op> column) into a CompareExpr.
//
// For each side, asserts that the column's data type in the proto matches the
// schema; the left side is checked before the right side is looked up.
expr::TypedExprPtr
ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
    // Left operand.
    const auto& left_column_info = expr_pb.left_column_info();
    const auto left_field_id = FieldId(left_column_info.field_id());
    const auto left_data_type = (*schema)[left_field_id].get_data_type();
    Assert(left_data_type ==
           static_cast<DataType>(left_column_info.data_type()));

    // Right operand.
    const auto& right_column_info = expr_pb.right_column_info();
    const auto right_field_id = FieldId(right_column_info.field_id());
    const auto right_data_type = (*schema)[right_field_id].get_data_type();
    Assert(right_data_type ==
           static_cast<DataType>(right_column_info.data_type()));

    return std::make_shared<expr::CompareExpr>(left_field_id,
                                               right_field_id,
                                               left_data_type,
                                               right_data_type,
                                               expr_pb.op());
}
|
|
|
|
// Converts a TermExpr proto into a TermFilterExpr holding a copy of the
// proto's value list.
//
// Asserts that the column's data type in the proto matches the schema.
expr::TypedExprPtr
ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema->operator[](field_id).get_data_type();
    // Named cast, matching the other parsers in this file.
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    // Range-construct from the repeated field: one allocation, and no
    // size_t-vs-int index comparison against values_size().
    const auto& values_pb = expr_pb.values();
    std::vector<::milvus::proto::plan::GenericValue> values(values_pb.begin(),
                                                            values_pb.end());

    return std::make_shared<expr::TermFilterExpr>(
        column_info, values, expr_pb.is_in_field());
}
|
|
|
|
// Converts a UnaryExpr proto into a LogicalUnaryExpr wrapping the parsed
// child expression. Asserts that the operator is LogicalNot.
expr::TypedExprPtr
ProtoParser::ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb) {
    const auto op = static_cast<expr::LogicalUnaryExpr::OpType>(expr_pb.op());
    Assert(op == expr::LogicalUnaryExpr::OpType::LogicalNot);

    auto operand = ParseExprs(expr_pb.child());
    return std::make_shared<expr::LogicalUnaryExpr>(op, operand);
}
|
|
|
|
// Converts a BinaryExpr proto into a LogicalBinaryExpr over its two parsed
// operands. The left operand is parsed before the right one.
expr::TypedExprPtr
ProtoParser::ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb) {
    auto lhs = ParseExprs(expr_pb.left());
    auto rhs = ParseExprs(expr_pb.right());

    const auto logical_op =
        static_cast<expr::LogicalBinaryExpr::OpType>(expr_pb.op());
    return std::make_shared<expr::LogicalBinaryExpr>(logical_op, lhs, rhs);
}
|
|
|
|
// Converts a BinaryArithOpEvalRangeExpr proto into the matching expression,
// forwarding the comparison op, arithmetic op, value and right operand.
//
// Asserts that the column's data type in the proto matches the schema.
expr::TypedExprPtr
ProtoParser::ParseBinaryArithOpEvalRangeExprs(
    const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
    const auto& column_info = expr_pb.column_info();
    const auto data_type =
        (*schema)[FieldId(column_info.field_id())].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    return std::make_shared<expr::BinaryArithOpEvalRangeExpr>(
        column_info,
        expr_pb.op(),
        expr_pb.arith_op(),
        expr_pb.value(),
        expr_pb.right_operand());
}
|
|
|
|
// Converts an ExistsExpr proto into an ExistsExpr over the referenced column.
//
// Asserts that the column's data type in the proto matches the schema.
expr::TypedExprPtr
ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
    const auto& column_info = expr_pb.info();
    const auto data_type =
        (*schema)[FieldId(column_info.field_id())].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    return std::make_shared<expr::ExistsExpr>(column_info);
}
|
|
|
|
// Converts a JSONContainsExpr proto into a JsonContainsExpr holding a copy of
// the proto's element list.
//
// Asserts that the column's data type in the proto matches the schema.
expr::TypedExprPtr
ProtoParser::ParseJsonContainsExprs(
    const proto::plan::JSONContainsExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema->operator[](field_id).get_data_type();
    // Named cast, matching the other parsers in this file.
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    // Range-construct from the repeated field: one allocation, and no
    // size_t-vs-int index comparison against elements_size().
    const auto& elements_pb = expr_pb.elements();
    std::vector<::milvus::proto::plan::GenericValue> values(
        elements_pb.begin(), elements_pb.end());

    return std::make_shared<expr::JsonContainsExpr>(
        column_info,
        expr_pb.op(),
        expr_pb.elements_same_type(),
        std::move(values));
}
|
|
|
|
// Converts a ColumnExpr proto into a ColumnExpr built from its column info.
expr::TypedExprPtr
ProtoParser::ParseColumnExprs(const proto::plan::ColumnExpr& expr_pb) {
    const auto& column = expr_pb.info();
    return std::make_shared<expr::ColumnExpr>(column);
}
|
|
|
|
// Converts a ValueExpr proto into a ValueExpr built from its value.
expr::TypedExprPtr
ProtoParser::ParseValueExprs(const proto::plan::ValueExpr& expr_pb) {
    const auto& literal = expr_pb.value();
    return std::make_shared<expr::ValueExpr>(literal);
}
|
|
|
|
// Creates a fresh AlwaysTrueExpr.
expr::TypedExprPtr
ProtoParser::CreateAlwaysTrueExprs() {
    expr::TypedExprPtr always_true = std::make_shared<expr::AlwaysTrueExpr>();
    return always_true;
}
|
|
|
|
// Parses any supported Expr proto into a typed expression tree.
//
// The proto's expr case selects the dedicated Parse*/Create* routine. The
// parsed expression's type is then passed to `type_check`; the expression is
// returned only when the check accepts it. An unsupported expr case or a
// rejected type is reported through ThrowInfo(ExprInvalid, ...).
expr::TypedExprPtr
ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
                        TypeCheckFunction type_check) {
    using ppe = proto::plan::Expr;

    // Dispatch on the populated oneof case.
    auto parsed = [&]() -> expr::TypedExprPtr {
        switch (expr_pb.expr_case()) {
            case ppe::kUnaryRangeExpr:
                return ParseUnaryRangeExprs(expr_pb.unary_range_expr());
            case ppe::kBinaryExpr:
                return ParseBinaryExprs(expr_pb.binary_expr());
            case ppe::kUnaryExpr:
                return ParseUnaryExprs(expr_pb.unary_expr());
            case ppe::kTermExpr:
                return ParseTermExprs(expr_pb.term_expr());
            case ppe::kBinaryRangeExpr:
                return ParseBinaryRangeExprs(expr_pb.binary_range_expr());
            case ppe::kCompareExpr:
                return ParseCompareExprs(expr_pb.compare_expr());
            case ppe::kBinaryArithOpEvalRangeExpr:
                return ParseBinaryArithOpEvalRangeExprs(
                    expr_pb.binary_arith_op_eval_range_expr());
            case ppe::kExistsExpr:
                return ParseExistExprs(expr_pb.exists_expr());
            case ppe::kAlwaysTrueExpr:
                return CreateAlwaysTrueExprs();
            case ppe::kJsonContainsExpr:
                return ParseJsonContainsExprs(expr_pb.json_contains_expr());
            case ppe::kCallExpr:
                return ParseCallExprs(expr_pb.call_expr());
            // may emit various types
            case ppe::kColumnExpr:
                return ParseColumnExprs(expr_pb.column_expr());
            case ppe::kValueExpr:
                return ParseValueExprs(expr_pb.value_expr());
            case ppe::kNullExpr:
                return ParseNullExprs(expr_pb.null_expr());
            default: {
                // Include the text form of the offending proto in the error.
                std::string s;
                google::protobuf::TextFormat::PrintToString(expr_pb, &s);
                ThrowInfo(ExprInvalid,
                          std::string("unsupported expr proto node: ") + s);
            }
        }
    }();

    if (!type_check(parsed->type())) {
        ThrowInfo(ExprInvalid,
                  "expr type check failed, actual type: {}",
                  parsed->type());
    }
    return parsed;
}
|
|
|
|
// Converts a ScoreFunction proto into a WeightScorer built from the parsed
// filter expression and the function's weight.
std::shared_ptr<rescores::Scorer>
ProtoParser::ParseScorer(const proto::plan::ScoreFunction& function) {
    auto filter_expr = ParseExprs(function.filter());
    const auto weight = function.weight();
    return std::make_shared<rescores::WeightScorer>(filter_expr, weight);
}
|
|
|
|
} // namespace milvus::query
|