milvus/internal/core/src/query/PlanProto.cpp
zhagnlu 489087d18b
enhance: refactor executor framework V2 (#35251)
#32636

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
2024-09-13 20:57:09 +08:00

397 lines
16 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "PlanProto.h"
#include <google/protobuf/text_format.h>
#include <cstdint>
#include <string>
#include "common/VectorTrait.h"
#include "common/EasyAssert.h"
#include "pb/plan.pb.h"
#include "query/Utils.h"
#include "knowhere/comp/materialized_view.h"
#include "plan/PlanNode.h"
namespace milvus::query {
namespace planpb = milvus::proto::plan;
// Builds the search (vector ANNS) plan node from its protobuf form.
//
// The concrete node type follows `vector_anns.vector_type`: binary,
// float16, bfloat16 and sparse-float get their dedicated node; any other
// value is treated as a float vector.
//
// The executor chain stored in `plannodes_` is, bottom-up:
//   FilterBitsNode (only when predicates are present)
//     -> MvccNode -> VectorSearchNode
//     -> GroupByNode (only when a group-by field is set)
//
// Asserts that `plan_node_proto` carries `vector_anns`.
std::unique_ptr<VectorPlanNode>
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
    // TODO: add more buffs
    Assert(plan_node_proto.has_vector_anns());
    auto& anns_proto = plan_node_proto.vector_anns();

    // Parses the ANNS predicate and wraps it in a FilterBitsNode.
    auto expr_parser = [&]() -> plan::PlanNodePtr {
        auto expr = ParseExprs(anns_proto.predicates());
        return std::make_shared<plan::FilterBitsNode>(
            milvus::plan::GetNextPlanNodeId(), expr);
    };

    // Translates query_info into SearchInfo. Grouping is enabled only for a
    // positive group_by_field_id, and a non-positive group_size becomes 1.
    auto search_info_parser = [&]() -> SearchInfo {
        SearchInfo search_info;
        auto& query_info_proto = anns_proto.query_info();
        auto field_id = FieldId(anns_proto.field_id());
        search_info.field_id_ = field_id;
        search_info.metric_type_ = query_info_proto.metric_type();
        search_info.topk_ = query_info_proto.topk();
        search_info.round_decimal_ = query_info_proto.round_decimal();
        search_info.search_params_ =
            nlohmann::json::parse(query_info_proto.search_params());
        search_info.materialized_view_involved =
            query_info_proto.materialized_view_involved();
        if (query_info_proto.group_by_field_id() > 0) {
            auto group_by_field_id =
                FieldId(query_info_proto.group_by_field_id());
            search_info.group_by_field_id_ = group_by_field_id;
            search_info.group_size_ = query_info_proto.group_size() > 0
                                          ? query_info_proto.group_size()
                                          : 1;
            search_info.group_strict_size_ =
                query_info_proto.group_strict_size();
        }
        return search_info;
    };

    // Pick the typed ANNS node; unlisted vector types fall back to float.
    auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
        switch (anns_proto.vector_type()) {
            case milvus::proto::plan::VectorType::BinaryVector:
                return std::make_unique<BinaryVectorANNS>();
            case milvus::proto::plan::VectorType::Float16Vector:
                return std::make_unique<Float16VectorANNS>();
            case milvus::proto::plan::VectorType::BFloat16Vector:
                return std::make_unique<BFloat16VectorANNS>();
            case milvus::proto::plan::VectorType::SparseFloatVector:
                return std::make_unique<SparseFloatVectorANNS>();
            default:
                return std::make_unique<FloatVectorANNS>();
        }
    }();
    plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
    // The parser returns a prvalue, so plain assignment already moves it.
    plan_node->search_info_ = search_info_parser();

    // Nodes are created bottom-up. GetNextPlanNodeId() hands out ids in
    // construction order, so the order of the statements below matters.
    milvus::plan::PlanNodePtr plannode;
    std::vector<milvus::plan::PlanNodePtr> sources;
    if (anns_proto.has_predicates()) {
        plannode = expr_parser();
        if (plan_node->search_info_.materialized_view_involved) {
            // Summarize what the filter touches and record it in
            // search_params_ under MATERIALIZED_VIEW_SEARCH_INFO.
            const auto expr_info = plannode->GatherInfo();
            knowhere::MaterializedViewSearchInfo materialized_view_search_info;
            for (const auto& [expr_field_id, vals] :
                 expr_info.field_id_to_values) {
                materialized_view_search_info
                    .field_id_to_touched_categories_cnt[expr_field_id] =
                    vals.size();
            }
            materialized_view_search_info.is_pure_and = expr_info.is_pure_and;
            materialized_view_search_info.has_not = expr_info.has_not;
            plan_node->search_info_
                .search_params_[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] =
                materialized_view_search_info;
        }
        sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
    }
    plannode = std::make_shared<milvus::plan::MvccNode>(
        milvus::plan::GetNextPlanNodeId(), sources);
    sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
    plannode = std::make_shared<milvus::plan::VectorSearchNode>(
        milvus::plan::GetNextPlanNodeId(), sources);
    if (plan_node->search_info_.group_by_field_id_ != std::nullopt) {
        sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
        plannode = std::make_shared<milvus::plan::GroupByNode>(
            milvus::plan::GetNextPlanNodeId(), sources);
    }
    plan_node->plannodes_ = std::move(plannode);
    return plan_node;
}
// Builds the retrieve (query) plan node from its protobuf form.
//
// Two proto layouts are accepted (at least one must be present, asserted):
//  * legacy, with `predicates` set directly on the PlanNode (version before
//    2023.03.30): FilterBitsNode -> MvccNode, never a count query;
//  * current, with `query` set:
//      [FilterBitsNode if query.predicates] -> MvccNode
//      -> [CountNode if query.is_count]
//    and `limit_` copied from the query.
std::unique_ptr<RetrievePlanNode>
ProtoParser::RetrievePlanNodeFromProto(
    const planpb::PlanNode& plan_node_proto) {
    Assert(plan_node_proto.has_predicates() || plan_node_proto.has_query());

    // Parses a predicate and wraps it in a FilterBitsNode. GetNextPlanNodeId()
    // hands out ids in call order, so nodes are built bottom-up.
    auto make_filter_node =
        [&](const auto& predicate_proto) -> milvus::plan::PlanNodePtr {
        auto expr = ParseExprs(predicate_proto);
        return std::make_shared<plan::FilterBitsNode>(
            milvus::plan::GetNextPlanNodeId(), expr);
    };

    auto node = std::make_unique<RetrievePlanNode>();
    std::vector<milvus::plan::PlanNodePtr> sources;

    if (plan_node_proto.has_predicates()) {  // version before 2023.03.30.
        node->is_count_ = false;
        sources.push_back(make_filter_node(plan_node_proto.predicates()));
        milvus::plan::PlanNodePtr mvcc_node =
            std::make_shared<milvus::plan::MvccNode>(
                milvus::plan::GetNextPlanNodeId(), sources);
        node->plannodes_ = std::move(mvcc_node);
        return node;
    }

    auto& query = plan_node_proto.query();
    if (query.has_predicates()) {
        sources.push_back(make_filter_node(query.predicates()));
    }
    // Without predicates the MvccNode has no source node.
    milvus::plan::PlanNodePtr plannode =
        std::make_shared<milvus::plan::MvccNode>(
            milvus::plan::GetNextPlanNodeId(), sources);
    node->is_count_ = query.is_count();
    node->limit_ = query.limit();
    if (node->is_count_) {
        sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
        plannode = std::make_shared<milvus::plan::CountNode>(
            milvus::plan::GetNextPlanNodeId(), sources);
    }
    node->plannodes_ = std::move(plannode);
    return node;
}
// Builds a search Plan from its protobuf form.
//
// The placeholder tag "$0" is bound to the ANNS vector field, and that field
// is recorded as involved in the plan's extra info. The requested output
// field ids and dynamic field names are copied onto the plan in proto order.
std::unique_ptr<Plan>
ProtoParser::CreatePlan(const proto::plan::PlanNode& plan_node_proto) {
    LOG_DEBUG("create search plan from proto: {}",
              plan_node_proto.DebugString());
    auto plan = std::make_unique<Plan>(schema);
    auto plan_node = PlanNodeFromProto(plan_node_proto);
    plan->tag2field_["$0"] = plan_node->search_info_.field_id_;
    plan->plan_node_ = std::move(plan_node);
    ExtractedPlanInfo extra_info(schema.size());
    extra_info.add_involved_field(plan->plan_node_->search_info_.field_id_);
    plan->extra_info_opt_ = std::move(extra_info);
    for (auto field_id_raw : plan_node_proto.output_field_ids()) {
        plan->target_entries_.push_back(FieldId(field_id_raw));
    }
    // Bind by const reference: iterating by value copied every name once
    // more before it was copied into the plan.
    for (const auto& dynamic_field : plan_node_proto.dynamic_fields()) {
        plan->target_dynamic_fields_.push_back(dynamic_field);
    }
    return plan;
}
// Builds a retrieve (query) plan from its protobuf form: the retrieve plan
// node plus the requested output field ids and dynamic field names, copied
// in proto order.
std::unique_ptr<RetrievePlan>
ProtoParser::CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto) {
    LOG_DEBUG("create retrieve plan from proto: {}",
              plan_node_proto.DebugString());
    auto retrieve_plan = std::make_unique<RetrievePlan>(schema);
    auto plan_node = RetrievePlanNodeFromProto(plan_node_proto);
    retrieve_plan->plan_node_ = std::move(plan_node);
    for (auto field_id_raw : plan_node_proto.output_field_ids()) {
        retrieve_plan->field_ids_.push_back(FieldId(field_id_raw));
    }
    // Bind by const reference so each name is copied only once, into the plan.
    for (const auto& dynamic_field : plan_node_proto.dynamic_fields()) {
        retrieve_plan->target_dynamic_fields_.push_back(dynamic_field);
    }
    return retrieve_plan;
}
// Converts a single-bound column comparison (`column <op> value`) into a
// UnaryRangeFilterExpr. It first asserts that the data type carried in the
// proto agrees with the schema's type for that field.
expr::TypedExprPtr
ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
    const auto& column_info = expr_pb.column_info();
    const auto data_type =
        schema[FieldId(column_info.field_id())].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    expr::ColumnInfo column(column_info);
    return std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
        column, expr_pb.op(), expr_pb.value());
}
// Converts a two-sided range on one column (lower bound / upper bound, each
// inclusive or exclusive) into a BinaryRangeFilterExpr. It first asserts
// that the data type carried in the proto agrees with the schema.
expr::TypedExprPtr
ProtoParser::ParseBinaryRangeExprs(
    const proto::plan::BinaryRangeExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    // Named cast, consistent with the other Parse* helpers.
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    return std::make_shared<expr::BinaryRangeFilterExpr>(
        column_info,
        expr_pb.lower_value(),
        expr_pb.upper_value(),
        expr_pb.lower_inclusive(),
        expr_pb.upper_inclusive());
}
// Converts a column-vs-column comparison into a CompareExpr. Each side's
// proto data type is asserted against the schema, the left side first.
expr::TypedExprPtr
ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
    // Left operand.
    const auto& left_column_info = expr_pb.left_column_info();
    const FieldId left_field_id(left_column_info.field_id());
    const DataType left_data_type = schema[left_field_id].get_data_type();
    Assert(left_data_type ==
           static_cast<DataType>(left_column_info.data_type()));

    // Right operand.
    const auto& right_column_info = expr_pb.right_column_info();
    const FieldId right_field_id(right_column_info.field_id());
    const DataType right_data_type = schema[right_field_id].get_data_type();
    Assert(right_data_type ==
           static_cast<DataType>(right_column_info.data_type()));

    const auto compare_op = expr_pb.op();
    return std::make_shared<expr::CompareExpr>(left_field_id,
                                               right_field_id,
                                               left_data_type,
                                               right_data_type,
                                               compare_op);
}
// Converts an IN-list predicate on one column into a TermFilterExpr. It
// first asserts that the data type carried in the proto agrees with the
// schema. `is_in_field` is forwarded unchanged.
expr::TypedExprPtr
ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    // Range-construct from the repeated field. This is one allocation, and it
    // avoids comparing a size_t index with the int returned by values_size().
    std::vector<::milvus::proto::plan::GenericValue> values(
        expr_pb.values().begin(), expr_pb.values().end());
    return std::make_shared<expr::TermFilterExpr>(
        column_info, std::move(values), expr_pb.is_in_field());
}
// Converts a logical unary expression into a LogicalUnaryExpr. Only NOT is
// accepted (asserted). The operand is parsed recursively with ParseExprs.
expr::TypedExprPtr
ProtoParser::ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb) {
    const auto op =
        static_cast<expr::LogicalUnaryExpr::OpType>(expr_pb.op());
    Assert(op == expr::LogicalUnaryExpr::OpType::LogicalNot);
    return std::make_shared<expr::LogicalUnaryExpr>(
        op, ParseExprs(expr_pb.child()));
}
// Converts a logical binary expression (e.g. AND/OR) into a
// LogicalBinaryExpr. Both operands are parsed recursively, left before right.
expr::TypedExprPtr
ProtoParser::ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb) {
    using LogicalOp = expr::LogicalBinaryExpr::OpType;
    // Separate statements pin the evaluation order: left, then right.
    auto lhs = ParseExprs(expr_pb.left());
    auto rhs = ParseExprs(expr_pb.right());
    const auto op = static_cast<LogicalOp>(expr_pb.op());
    return std::make_shared<expr::LogicalBinaryExpr>(op, lhs, rhs);
}
// Converts a predicate of the form `(column <arith_op> right_operand) <op>
// value` into a BinaryArithOpEvalRangeExpr. It first asserts that the
// proto's data type agrees with the schema.
expr::TypedExprPtr
ProtoParser::ParseBinaryArithOpEvalRangeExprs(
    const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
    const auto& column_info = expr_pb.column_info();
    const FieldId field_id(column_info.field_id());
    const DataType data_type = schema[field_id].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));

    const auto compare_op = expr_pb.op();
    const auto arith_op = expr_pb.arith_op();
    return std::make_shared<expr::BinaryArithOpEvalRangeExpr>(
        column_info,
        compare_op,
        arith_op,
        expr_pb.value(),
        expr_pb.right_operand());
}
// Converts an existence check on a column into an ExistsExpr. It first
// asserts that the proto's data type agrees with the schema.
expr::TypedExprPtr
ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
    const auto& column_info = expr_pb.info();
    const FieldId field_id(column_info.field_id());
    const DataType data_type = schema[field_id].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    expr::TypedExprPtr exists_expr =
        std::make_shared<expr::ExistsExpr>(column_info);
    return exists_expr;
}
// Converts a JSONContainsExpr proto into a JsonContainsExpr. It first
// asserts that the proto's data type agrees with the schema. The containment
// op, the elements_same_type flag and the element list are forwarded.
expr::TypedExprPtr
ProtoParser::ParseJsonContainsExprs(
    const proto::plan::JSONContainsExpr& expr_pb) {
    auto& column_info = expr_pb.column_info();
    auto field_id = FieldId(column_info.field_id());
    auto data_type = schema[field_id].get_data_type();
    Assert(data_type == static_cast<DataType>(column_info.data_type()));
    // Range-construct from the repeated field. This is one allocation, and it
    // avoids comparing a size_t index with the int returned by
    // elements_size().
    std::vector<::milvus::proto::plan::GenericValue> values(
        expr_pb.elements().begin(), expr_pb.elements().end());
    return std::make_shared<expr::JsonContainsExpr>(
        column_info,
        expr_pb.op(),
        expr_pb.elements_same_type(),
        std::move(values));
}
// Returns a new AlwaysTrueExpr. This is the predicate that matches every row.
expr::TypedExprPtr
ProtoParser::CreateAlwaysTrueExprs() {
    expr::TypedExprPtr always_true = std::make_shared<expr::AlwaysTrueExpr>();
    return always_true;
}
// Dispatches an Expr proto to the Parse* helper for its `expr` oneof case.
// Unsupported or unset cases raise ExprInvalid, with the proto rendered in
// protobuf text format in the message.
expr::TypedExprPtr
ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) {
    using ExprProto = proto::plan::Expr;
    switch (expr_pb.expr_case()) {
        case ExprProto::kUnaryRangeExpr:
            return ParseUnaryRangeExprs(expr_pb.unary_range_expr());
        case ExprProto::kBinaryExpr:
            return ParseBinaryExprs(expr_pb.binary_expr());
        case ExprProto::kUnaryExpr:
            return ParseUnaryExprs(expr_pb.unary_expr());
        case ExprProto::kTermExpr:
            return ParseTermExprs(expr_pb.term_expr());
        case ExprProto::kBinaryRangeExpr:
            return ParseBinaryRangeExprs(expr_pb.binary_range_expr());
        case ExprProto::kCompareExpr:
            return ParseCompareExprs(expr_pb.compare_expr());
        case ExprProto::kBinaryArithOpEvalRangeExpr:
            return ParseBinaryArithOpEvalRangeExprs(
                expr_pb.binary_arith_op_eval_range_expr());
        case ExprProto::kExistsExpr:
            return ParseExistExprs(expr_pb.exists_expr());
        case ExprProto::kAlwaysTrueExpr:
            return CreateAlwaysTrueExprs();
        case ExprProto::kJsonContainsExpr:
            return ParseJsonContainsExprs(expr_pb.json_contains_expr());
        default:
            break;
    }
    // Every supported case returned above; anything else is rejected.
    std::string text;
    google::protobuf::TextFormat::PrintToString(expr_pb, &text);
    PanicInfo(ExprInvalid,
              std::string("unsupported expr proto node: ") + text);
}
} // namespace milvus::query