diff --git a/internal/core/src/exec/Driver.cpp b/internal/core/src/exec/Driver.cpp index 301a814026..01cd9e1883 100644 --- a/internal/core/src/exec/Driver.cpp +++ b/internal/core/src/exec/Driver.cpp @@ -26,10 +26,12 @@ #include "exec/operator/IterativeFilterNode.h" #include "exec/operator/MvccNode.h" #include "exec/operator/Operator.h" +#include "exec/operator/RescoresNode.h" #include "exec/operator/VectorSearchNode.h" #include "exec/operator/RandomSampleNode.h" #include "exec/operator/GroupByNode.h" #include "exec/Task.h" +#include "plan/PlanNode.h" namespace milvus { namespace exec { @@ -87,6 +89,11 @@ DriverFactory::CreateDriver(std::unique_ptr ctx, plannode)) { operators.push_back(std::make_unique( id, ctx.get(), samplenode)); + } else if (auto rescoresnode = + std::dynamic_pointer_cast( + plannode)) { + operators.push_back( + std::make_unique(id, ctx.get(), rescoresnode)); } // TODO: add more operators } diff --git a/internal/core/src/exec/operator/RescoresNode.cpp b/internal/core/src/exec/operator/RescoresNode.cpp new file mode 100644 index 0000000000..e01d17fe9b --- /dev/null +++ b/internal/core/src/exec/operator/RescoresNode.cpp @@ -0,0 +1,132 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RescoresNode.h" +#include +#include "exec/operator/Utils.h" +#include "monitor/Monitor.h" + +namespace milvus::exec { + +PhyRescoresNode::PhyRescoresNode( + int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& scorer) + : Operator(ctx, + scorer->output_type(), + operator_id, + scorer->id(), + "PhyRescoresNode") { + scorers_ = scorer->scorers(); +}; + +void +PhyRescoresNode::AddInput(RowVectorPtr& input) { + input_ = std::move(input); +} + +bool +PhyRescoresNode::IsFinished() { + return is_finished_; +} + +RowVectorPtr +PhyRescoresNode::GetOutput() { + if (is_finished_ || !no_more_input_) { + return nullptr; + } + + DeferLambda([&]() { is_finished_ = true; }); + + if (input_ == nullptr) { + return nullptr; + } + ExecContext* exec_context = operator_context_->get_exec_context(); + auto query_context_ = exec_context->get_query_context(); + auto query_info = exec_context->get_query_config(); + milvus::SearchResult search_result = query_context_->get_search_result(); + + // prepare segment offset + FixedVector offsets; + std::vector offset_idx; + + for (size_t i = 0; i < search_result.seg_offsets_.size(); i++) { + // remain offset will be -1 if result count not enough (less than topk) + // skip placeholder offset + if (search_result.seg_offsets_[i] >= 0){ + offsets.push_back(static_cast(search_result.seg_offsets_[i])); + offset_idx.push_back(i); + } + } + + for (auto& scorer : scorers_) { + auto filter = scorer->filter(); + std::vector filters; + filters.emplace_back(filter); + auto expr_set = std::make_unique(filters, exec_context); + std::vector results; + EvalCtx eval_ctx(exec_context, expr_set.get()); + + const auto& exprs = expr_set->exprs(); + bool is_native_supported = true; + for (const auto& expr : exprs) { + is_native_supported = + (is_native_supported && (expr->SupportOffsetInput())); + } + + if (is_native_supported) { + // could set input offset if expr was native supported + eval_ctx.set_offset_input(&offsets); + expr_set->Eval(0, 1, true, eval_ctx, results); + + // filter result for offsets[i] was resut bitset[i] + auto col_vec = std::dynamic_pointer_cast(results[0]); + auto col_vec_size = col_vec->size(); + TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size); + Assert(bitsetview.size() == offsets.size()); + for (auto i = 0; i < offsets.size(); i++) { + if (bitsetview[i] > 0) { + search_result.distances_[offset_idx[i]] = + scorer->rescore(search_result.distances_[offset_idx[i]]); + } + } + } else { + // query all segment if expr not native + expr_set->Eval(0, 1, true, eval_ctx, results); + + // filter result for offsets[i] was bitset[offset[i]] + TargetBitmap bitset; + auto col_vec = std::dynamic_pointer_cast(results[0]); + auto col_vec_size = col_vec->size(); + TargetBitmapView view(col_vec->GetRawData(), col_vec_size); + bitset.append(view); + for (auto i = 0; i < offsets.size(); i++) { + if (bitset[offsets[i]] > 0) { + search_result.distances_[offset_idx[i]] = + scorer->rescore(search_result.distances_[offset_idx[i]]); + } + } + } + } + + knowhere::MetricType metric_type = query_context_->get_metric_type(); + bool large_is_better = PositivelyRelated(metric_type); + sort_search_result(search_result, large_is_better); + query_context_->set_search_result(std::move(search_result)); + return input_; +}; + +} // namespace milvus::exec \ No newline at end of file diff --git a/internal/core/src/exec/operator/RescoresNode.h b/internal/core/src/exec/operator/RescoresNode.h new file mode 100644 index 0000000000..e11a9126f3 --- /dev/null +++ b/internal/core/src/exec/operator/RescoresNode.h @@ -0,0 +1,77 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "exec/Driver.h" +#include "exec/expression/Expr.h" +#include "exec/operator/Operator.h" +#include "exec/QueryContext.h" +#include "pb/plan.pb.h" +#include "plan/PlanNode.h" + +// difference between FilterBitsNode and RescoresNode is that +// FilterBitsNode will go through whole segment and return bitset to indicate which offset is filtered out or not +// RescoresNode will accept offsets array and execute over these and generate result valid offsets +namespace milvus::exec { +class PhyRescoresNode : public Operator { + public: + PhyRescoresNode(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& scorer); + + bool + IsFilter() override { + return true; + } + + bool + NeedInput() const override { + return !is_finished_; + } + + void + AddInput(RowVectorPtr& input) override; + + RowVectorPtr + GetOutput() override; + + bool + IsFinished() override; + + void + Close() override { + Operator::Close(); + } + + BlockingReason + IsBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + virtual std::string + ToString() const override { + return "PhyRescoresNode"; + } + + private: + std::vector> scorers_; + bool is_finished_{false}; +}; +} // namespace milvus::exec diff --git a/internal/core/src/exec/operator/Utils.h b/internal/core/src/exec/operator/Utils.h index 77997a8785..ae5019917c 100644 --- a/internal/core/src/exec/operator/Utils.h +++ b/internal/core/src/exec/operator/Utils.h @@ -16,8 +16,11 @@ #pragma once +#include #include "common/QueryInfo.h" +#include "common/QueryResult.h" #include "knowhere/index/index_node.h" +#include "log/Log.h" #include "segcore/SegmentInterface.h" #include "segcore/SegmentGrowingImpl.h" #include "segcore/ConcurrentVector.h" @@ -96,5 +99,41 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info, } return false; } + +inline void +sort_search_result(milvus::SearchResult& result, bool large_is_better) { + auto nq = result.total_nq_; + auto topk = result.unity_topK_; + auto size = nq * topk; + + std::vector new_distances = std::vector(); + std::vector new_seg_offsets = std::vector(); + new_distances.reserve(size); + new_seg_offsets.reserve(size); + + std::vector idx(topk); + + for (size_t start = 0; start < size; start += topk) { + for (size_t i = 0; i < idx.size(); ++i) idx[i] = start + i; + + if (large_is_better) { + std::sort(idx.begin(), idx.end(), [&](size_t i, size_t j) { + return result.distances_[i] > result.distances_[j] || (result.seg_offsets_[j] >=0 &&result.seg_offsets_[j] < 0); + }); + } else { + std::sort(idx.begin(), idx.end(), [&](size_t i, size_t j) { + return result.distances_[i] < result.distances_[j] || (result.seg_offsets_[j] >=0 &&result.seg_offsets_[j] < 0); + }); + } + for (auto i : idx) { + new_distances.push_back(result.distances_[i]); + new_seg_offsets.push_back(result.seg_offsets_[i]); + } + } + + result.distances_ = new_distances; + result.seg_offsets_ = new_seg_offsets; +} + } // namespace exec } // namespace milvus \ No newline at end of file diff --git a/internal/core/src/plan/PlanNode.h b/internal/core/src/plan/PlanNode.h index 8608af6d60..a6641957ff 100644 --- a/internal/core/src/plan/PlanNode.h +++ b/internal/core/src/plan/PlanNode.h @@ -26,6 +26,7 @@ #include "common/EasyAssert.h" #include "segcore/SegmentInterface.h" #include "plan/PlanNodeIdGenerator.h" +#include "rescores/Scorer.h" namespace milvus { namespace plan { @@ -438,6 +439,46 @@ class CountNode : public PlanNode { const std::vector sources_; }; +class RescoresNode : public PlanNode { + public: + RescoresNode( + const PlanNodeId& id, + const std::vector> scorers, + const std::vector& sources = std::vector{}) + : PlanNode(id), scorers_(std::move(scorers)), sources_{std::move(sources)} { + } + + DataType + output_type() const override { + return DataType::INT64; + } + + std::vector + sources() const override { + return sources_; + } + + const std::vector>& + scorers() const { + return scorers_; + } + + std::string_view + name() const override { + return "RescoresNode"; + } + + std::string + ToString() const override { + return fmt::format("RescoresNode:\n\t[source node:{}]", + SourceToString()); + } + + private: + const std::vector sources_; + const std::vector> scorers_; +}; + enum class ExecutionStrategy { // Process splits as they come in any available driver. kUngrouped, diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index b0979b0920..ae08d72366 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -24,6 +24,7 @@ #include "query/Utils.h" #include "knowhere/comp/materialized_view.h" #include "plan/PlanNode.h" +#include "rescores/Scorer.h" namespace milvus::query { namespace planpb = milvus::proto::plan; @@ -206,6 +207,18 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { sources = std::vector{plannode}; } + // if has score function, run filter and scorer at last + if (plan_node_proto.scorers_size() > 0){ + std::vector> scorers; + for (const auto& function: plan_node_proto.scorers()){ + scorers.push_back(ParseScorer(function)); + } + + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId(), std::move(scorers), sources); + sources = std::vector{plannode}; + } + plan_node->plannodes_ = plannode; return plan_node; @@ -588,4 +601,9 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb, ExprInvalid, "expr type check failed, actual type: {}", result->type()); } +std::shared_ptr ProtoParser::ParseScorer(const proto::plan::ScoreFunction& function){ + auto expr = ParseExprs(function.filter()); + return std::make_shared(expr, function.weight()); +} + } // namespace milvus::query diff --git a/internal/core/src/query/PlanProto.h b/internal/core/src/query/PlanProto.h index 1bde865af5..3af2d68bbf 100644 --- a/internal/core/src/query/PlanProto.h +++ b/internal/core/src/query/PlanProto.h @@ -54,6 +54,9 @@ class ProtoParser { ParseExprs(const proto::plan::Expr& expr_pb, TypeCheckFunction type_check = TypeIsBool); + std::shared_ptr + ParseScorer(const proto::plan::ScoreFunction& function); + private: expr::TypedExprPtr CreateAlwaysTrueExprs(); diff --git a/internal/core/src/rescores/Scorer.h b/internal/core/src/rescores/Scorer.h new file mode 100644 index 0000000000..7cd1782fa4 --- /dev/null +++ b/internal/core/src/rescores/Scorer.h @@ -0,0 +1,58 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "expr/ITypeExpr.h" + +namespace milvus::rescores { +class Scorer { + public: + virtual expr::TypedExprPtr + filter() = 0; + + virtual float + rescore(float old_score) = 0; + + virtual float + weight() = 0; +}; + +class WeightScorer : public Scorer { + public: + WeightScorer(expr::TypedExprPtr filter, float weight) + : filter_(std::move(filter)), weight_(weight){}; + + expr::TypedExprPtr + filter() override { + return filter_; + } + + float + rescore(float old_score) override { + return old_score * weight_; + } + + float + weight() override{ + return weight_; + } + + private: + expr::TypedExprPtr filter_; + float weight_; +}; +} // namespace milvus::rescores \ No newline at end of file diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index 00feb869de..a04a1fee88 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -2,6 +2,7 @@ package planparserv2 import ( "fmt" + "strconv" "time" "github.com/antlr4-go/antlr/v4" @@ -11,8 +12,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated" + "github.com/milvus-io/milvus/internal/util/function/rerank" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -156,7 +159,7 @@ func CreateRetrievePlan(schema *typeutil.SchemaHelper, exprStr string, exprTempl return planNode, nil } -func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, error) { +func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo, exprTemplateValues map[string]*schemapb.TemplateValue, functionScorer *schemapb.FunctionScore) (*planpb.PlanNode, error) { parse := func() (*planpb.Expr, error) { if len(exprStr) <= 0 { return nil, nil @@ -199,6 +202,16 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField log.Error("Invalid dataType", zap.Any("dataType", dataType)) return nil, err } + + scorers, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues) + if err != nil { + return nil, err + } + + if len(scorers) != 0 && (queryInfo.GroupByFieldId != -1 || queryInfo.SearchIteratorV2Info != nil) { + return nil, fmt.Errorf("don't support use segment scorer with group_by or search_iterator") + } + planNode := &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ VectorAnns: &planpb.VectorANNS{ @@ -209,10 +222,62 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField FieldId: fieldID, }, }, + Scorers: scorers, } return planNode, nil } +func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.FunctionSchema, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.ScoreFunction, error) { + rerankerName := rerank.GetRerankName(function) + switch rerankerName { + case rerank.BoostName: + scorer := &planpb.ScoreFunction{} + filter, ok := funcutil.TryGetAttrByKeyFromRepeatedKV(rerank.FilterKey, function.GetParams()) + if ok { + expr, err := ParseExpr(schema, filter, exprTemplateValues) + if err != nil { + return nil, fmt.Errorf("parse expr failed with error: {%v}", err) + } + scorer.Filter = expr + } + + weightStr, ok := funcutil.TryGetAttrByKeyFromRepeatedKV(rerank.WeightKey, function.GetParams()) + if !ok { + return nil, fmt.Errorf("must set weight params for weight scorer") + } + + weight, err := strconv.ParseFloat(weightStr, 32) + if err != nil { + return nil, fmt.Errorf("parse function scorer weight params failed with error: {%v}", err) + } + scorer.Weight = float32(weight) + return scorer, nil + default: + // if not boost scorer, regard as normal function scorer + // will be checked at ranker + // return nil here + return nil, nil + } +} + +func CreateSearchScorers(schema *typeutil.SchemaHelper, functionScore *schemapb.FunctionScore, exprTemplateValues map[string]*schemapb.TemplateValue) ([]*planpb.ScoreFunction, error) { + scorers := []*planpb.ScoreFunction{} + for _, function := range functionScore.GetFunctions() { + // create scorer for search plan + scorer, err := CreateSearchScorer(schema, function, exprTemplateValues) + if err != nil { + return nil, err + } + if scorer != nil { + scorers = append(scorers, scorer) + } + } + if len(scorers) == 0 { + return nil, nil + } + return scorers, nil +} + func CreateRequeryPlan(pkField *schemapb.FieldSchema, ids *schemapb.IDs) *planpb.PlanNode { var values []*planpb.GenericValue switch ids.GetIdField().(type) { diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index e6a424270c..c94030fec9 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -14,6 +14,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function/rerank" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/proto/planpb" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" @@ -230,7 +231,7 @@ func TestExpr_Like(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err, expr) assert.NotNil(t, plan) fmt.Println(plan) @@ -243,7 +244,7 @@ func TestExpr_Like(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err, expr) assert.NotNil(t, plan) fmt.Println(plan) @@ -256,7 +257,7 @@ func TestExpr_Like(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err, expr) assert.NotNil(t, plan) fmt.Println(plan) @@ -653,7 +654,7 @@ func TestCreateSearchPlan(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } @@ -664,7 +665,7 @@ func TestCreateFloat16SearchPlan(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } @@ -675,7 +676,7 @@ func TestCreateBFloat16earchPlan(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } @@ -686,7 +687,7 @@ func TestCreateSparseFloatVectorSearchPlan(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } @@ -848,19 +849,19 @@ func TestCreateRetrievePlan_Invalid(t *testing.T) { func TestCreateSearchPlan_Invalid(t *testing.T) { t.Run("invalid expr", func(t *testing.T) { schema := newTestSchemaHelper(t) - _, err := CreateSearchPlan(schema, "invalid expression", "", nil, nil) + _, err := CreateSearchPlan(schema, "invalid expression", "", nil, nil, nil) assert.Error(t, err) }) t.Run("invalid vector field", func(t *testing.T) { schema := newTestSchemaHelper(t) - _, err := CreateSearchPlan(schema, "Int64Field > 0", "not_exist", nil, nil) + _, err := CreateSearchPlan(schema, "Int64Field > 0", "not_exist", nil, nil, nil) assert.Error(t, err) }) t.Run("not vector type", func(t *testing.T) { schema := newTestSchemaHelper(t) - _, err := CreateSearchPlan(schema, "Int64Field > 0", "VarCharField", nil, nil) + _, err := CreateSearchPlan(schema, "Int64Field > 0", "VarCharField", nil, nil, nil) assert.Error(t, err) }) } @@ -1006,7 +1007,7 @@ func Test_JSONExpr(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } } @@ -1035,7 +1036,7 @@ func Test_InvalidExprOnJSONField(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err, expr) } } @@ -1072,7 +1073,7 @@ func Test_InvalidExprWithoutJSONField(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err) } } @@ -1109,7 +1110,7 @@ func Test_InvalidExprWithMultipleJSONField(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err) } } @@ -1130,7 +1131,7 @@ func Test_exprWithSingleQuotes(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } @@ -1145,7 +1146,7 @@ func Test_exprWithSingleQuotes(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err) } } @@ -1180,7 +1181,7 @@ func Test_JSONContains(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } } @@ -1211,7 +1212,7 @@ func Test_InvalidJSONContains(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err) } } @@ -1274,7 +1275,7 @@ func Test_EscapeString(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) } @@ -1291,7 +1292,7 @@ c'`, MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err) fmt.Println(plan) } @@ -1319,7 +1320,7 @@ func Test_JSONContainsAll(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) assert.Equal(t, planpb.JSONContainsExpr_ContainsAll, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) @@ -1341,7 +1342,7 @@ func Test_JSONContainsAll(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err) } } @@ -1364,7 +1365,7 @@ func Test_JSONContainsAny(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) assert.Equal(t, planpb.JSONContainsExpr_ContainsAny, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) @@ -1386,7 +1387,7 @@ func Test_JSONContainsAny(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err) } } @@ -1434,7 +1435,7 @@ func Test_ArrayExpr(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err, expr) } @@ -1468,7 +1469,7 @@ func Test_ArrayExpr(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err, expr) } } @@ -1496,7 +1497,7 @@ func Test_ArrayLength(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err, expr) } @@ -1522,7 +1523,7 @@ func Test_ArrayLength(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.Error(t, err, expr) } } @@ -1580,6 +1581,106 @@ func TestRandomSampleWithFilter(t *testing.T) { } } +func Test_SegmentScorers(t *testing.T) { + schema := newTestSchemaHelper(t) + + // helper to build a boost segment scorer function + makeBoostRanker := func(filter string, weight string) *schemapb.FunctionSchema { + params := []*commonpb.KeyValuePair{ + {Key: rerank.WeightKey, Value: weight}, + {Key: "reranker", Value: rerank.BoostName}, + } + if filter != "" { + params = append(params, &commonpb.KeyValuePair{Key: rerank.FilterKey, Value: filter}) + } + return &schemapb.FunctionSchema{ + Params: params, + } + } + + t.Run("ok - single boost scorer", func(t *testing.T) { + fs := &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{ + makeBoostRanker("Int64Field > 0", "1.5"), + }, + } + plan, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: -1}, nil, fs) + assert.NoError(t, err) + assert.NotNil(t, plan) + assert.Equal(t, 1, len(plan.Scorers)) + // filter should be parsed into Expr when provided + assert.NotNil(t, plan.Scorers[0]) + }) + + t.Run("ok - multiple boost scorers", func(t *testing.T) { + fs := &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{ + makeBoostRanker("Int64Field > 0", "1.0"), + makeBoostRanker("", "2.0"), + }, + } + plan, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: -1}, nil, fs) + assert.NoError(t, err) + assert.NotNil(t, plan) + assert.Equal(t, 2, len(plan.Scorers)) + }) + + t.Run("error - not segment scorer flag", func(t *testing.T) { + fs := &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{{Params: []*commonpb.KeyValuePair{{Key: "reranker", Value: rerank.WeightedName}}}}, + } + plan, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: -1}, nil, fs) + assert.NoError(t, err) + // not segment scorer means ignored + assert.NotNil(t, plan) + assert.Equal(t, 0, len(plan.Scorers)) + }) + + t.Run("error - missing weight", func(t *testing.T) { + // missing weight should cause CreateSearchScorer to fail + fs := &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{ + {Params: []*commonpb.KeyValuePair{ + {Key: "reranker", Value: rerank.BoostName}, + // no weight + }}, + }, + } + _, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{}, nil, fs) + assert.Error(t, err) + }) + + t.Run("error - invalid weight format", func(t *testing.T) { + fs := &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{ + makeBoostRanker("", "invalid_float"), + }, + } + _, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{}, nil, fs) + assert.Error(t, err) + }) + + t.Run("error - scorer with group_by", func(t *testing.T) { + fs := &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{ + makeBoostRanker("", "1.0"), + }, + } + _, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: 100}, nil, fs) + assert.Error(t, err) + }) + + t.Run("error - scorer with search_iterator_v2", func(t *testing.T) { + fs := &schemapb.FunctionScore{ + Functions: []*schemapb.FunctionSchema{ + makeBoostRanker("", "1.0"), + }, + } + _, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{SearchIteratorV2Info: &planpb.SearchIteratorV2Info{}}, nil, fs) + assert.Error(t, err) + }) +} + func TestConcurrency(t *testing.T) { schemaHelper := newTestSchemaHelper(t) @@ -1653,7 +1754,7 @@ func BenchmarkWithString(b *testing.B) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(b, err) assert.NotNil(b, plan) } @@ -1696,7 +1797,7 @@ func BenchmarkTemplateWithString(b *testing.B) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, mv) + }, mv, nil) assert.NoError(b, err) assert.NotNil(b, plan) } @@ -1711,7 +1812,7 @@ func TestNestedPathWithChinese(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err, expr) paths := plan.GetVectorAnns().GetPredicates().GetUnaryRangeExpr().GetColumnInfo().GetNestedPath() assert.NotNil(t, paths) @@ -1725,7 +1826,7 @@ func TestNestedPathWithChinese(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err, expr) paths = plan.GetVectorAnns().GetPredicates().GetUnaryRangeExpr().GetColumnInfo().GetNestedPath() assert.NotNil(t, paths) @@ -1751,7 +1852,7 @@ func Test_JSONPathNullExpr(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) assert.NotNil(t, plan) @@ -1760,7 +1861,7 @@ func Test_JSONPathNullExpr(t *testing.T) { MetricType: "", SearchParams: "", RoundDecimal: 0, - }, nil) + }, nil, nil) assert.NoError(t, err) assert.NotNil(t, plan2) diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 670bdedd2b..cb570d5dec 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -634,7 +634,7 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string searchInfo.planInfo.QueryFieldId = annField.GetFieldID() start := time.Now() - plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo, exprTemplateValues) + plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo, exprTemplateValues, t.request.GetFunctionScore()) if planErr != nil { log.Ctx(t.ctx).Warn("failed to create query plan", zap.Error(planErr), zap.String("dsl", dsl), // may be very large if large term passed. diff --git a/internal/util/exprutil/expr_checker_test.go b/internal/util/exprutil/expr_checker_test.go index 79400fd76c..12f3c72878 100644 --- a/internal/util/exprutil/expr_checker_test.go +++ b/internal/util/exprutil/expr_checker_test.go @@ -120,7 +120,7 @@ func TestParsePartitionKeys(t *testing.T) { idx++ t.Log(idx, tc.name, tc.expr) // test search plan - searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo, nil) + searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo, nil, nil) assert.NoError(t, err) expr, err := ParseExprFromPlan(searchPlan) assert.NoError(t, err) diff --git a/internal/util/function/rerank/decay_function.go b/internal/util/function/rerank/decay_function.go index 97fb1b4508..e6d2368331 100644 --- a/internal/util/function/rerank/decay_function.go +++ b/internal/util/function/rerank/decay_function.go @@ -55,7 +55,7 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct { } func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { - base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true) + base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true) if err != nil { return nil, err } @@ -154,7 +154,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase case linearFunction: decayFunc.reScorer = linearDecay default: - return nil, fmt.Errorf("Invaild decay function: %s, only support [%s,%s,%s]", decayFunctionName, gaussFunction, linearFunction, expFunction) + return nil, fmt.Errorf("Invaild decay function: %s, only support [%s,%s,%s]", DecayFunctionName, gaussFunction, linearFunction, expFunction) } return decayFunc, nil } diff --git a/internal/util/function/rerank/function_score.go b/internal/util/function/rerank/function_score.go index 9296fa89c6..3511ff61d4 100644 --- a/internal/util/function/rerank/function_score.go +++ b/internal/util/function/rerank/function_score.go @@ -36,10 +36,10 @@ import ( ) const ( - decayFunctionName string = "decay" - modelFunctionName string = "model" - rrfName string = "rrf" - weightedName string = "weighted" + DecayFunctionName string = "decay" + ModelFunctionName string = "model" + RRFName string = "rrf" + WeightedName string = "weighted" ) const ( @@ -63,6 +63,13 @@ const ( weightedRankType // weightedRankType = 2 ) +// segment scorer configs +const ( + BoostName = "boost" + FilterKey = "filter" + WeightKey = "weight" +) + var rankTypeMap = map[string]rankType{ "invalid": invalidRankType, "rrf": rrfRankType, @@ -104,7 +111,7 @@ type Reranker interface { GetRankName() string } -func getRerankName(funcSchema *schemapb.FunctionSchema) string { +func GetRerankName(funcSchema *schemapb.FunctionSchema) string { for _, param := range funcSchema.Params { switch strings.ToLower(param.Key) { case reranker: @@ -128,20 +135,22 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb. return nil, fmt.Errorf("Rerank function should not have output field, but now is %d", len(funcSchema.GetOutputFieldNames())) } - rerankerName := getRerankName(funcSchema) + rerankerName := GetRerankName(funcSchema) var rerankFunc Reranker var newRerankErr error switch rerankerName { - case decayFunctionName: + case DecayFunctionName: rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema) - case modelFunctionName: + case ModelFunctionName: rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema) - case rrfName: + case RRFName: rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema) - case weightedName: + case WeightedName: rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema) + case BoostName: + return nil, nil default: - return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s,%s,%s]", rerankerName, decayFunctionName, modelFunctionName, rrfName, weightedName) + return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s,%s,%s]", rerankerName, DecayFunctionName, ModelFunctionName, RRFName, WeightedName) } if newRerankErr != nil { @@ -151,15 +160,28 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb. } func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore) (*FunctionScore, error) { - if len(funcScoreSchema.Functions) > 1 || len(funcScoreSchema.Functions) == 0 { - return nil, fmt.Errorf("Currently only supports one rerank, but got %d", len(funcScoreSchema.Functions)) - } funcScore := &FunctionScore{} - var err error - if funcScore.reranker, err = createFunction(collSchema, funcScoreSchema.Functions[0]); err != nil { - return nil, err + + for _, function := range funcScoreSchema.Functions { + reranker, err := createFunction(collSchema, function) + if err != nil { + return nil, err + } + + if reranker != nil { + if funcScore.reranker == nil { + funcScore.reranker = reranker + } else { + // now only support only use one proxy rerank + return nil, fmt.Errorf("Currently only supports one rerank") + } + } } - return funcScore, nil + + if funcScore.reranker != nil { + return funcScore, nil + } + return nil, nil } func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair) (*FunctionScore, error) { @@ -189,7 +211,7 @@ func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParam } switch rankTypeMap[rankTypeStr] { case rrfRankType: - fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: rrfName}) + fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: RRFName}) if v, ok := params[RRFParamsKey]; ok { if reflect.ValueOf(params[RRFParamsKey]).CanFloat() { k := reflect.ValueOf(v).Float() @@ -199,7 +221,7 @@ func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParam } } case weightedRankType: - fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: weightedName}) + fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: WeightedName}) if v, ok := params[WeightsParamsKey]; ok { if d, err := json.Marshal(v); err != nil { return nil, fmt.Errorf("The weights param should be an array") diff --git a/internal/util/function/rerank/function_score_test.go b/internal/util/function/rerank/function_score_test.go index 5a01c5b9c6..f688cd1a4d 100644 --- a/internal/util/function/rerank/function_score_test.go +++ b/internal/util/function/rerank/function_score_test.go @@ -61,7 +61,7 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() { Type: schemapb.FunctionType_Rerank, InputFieldNames: []string{"ts"}, Params: []*commonpb.KeyValuePair{ - {Key: reranker, Value: decayFunctionName}, + {Key: reranker, Value: DecayFunctionName}, {Key: originKey, Value: "4"}, {Key: scaleKey, Value: "4"}, {Key: offsetKey, Value: "4"}, @@ -69,6 +69,15 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() { {Key: functionKey, Value: "gauss"}, }, } + + segmentScorer := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + Params: []*commonpb.KeyValuePair{ + {Key: reranker, Value: BoostName}, + {Key: WeightKey, Value: "2"}, + }, + } funcScores := &schemapb.FunctionScore{ Functions: []*schemapb.FunctionSchema{functionSchema}, } @@ -80,6 +89,14 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() { s.Equal(true, f.IsSupportGroup()) s.Equal("decay", f.reranker.GetRankName()) + // two ranker but one was boost scorer + { + funcScores.Functions = append(funcScores.Functions, segmentScorer) + _, err := NewFunctionScore(schema, funcScores) + s.NoError(err) + funcScores.Functions = funcScores.Functions[:1] + } + { schema.Fields[3].Nullable = true _, err := NewFunctionScore(schema, funcScores) @@ -91,13 +108,13 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() { funcScores.Functions[0].Params[0].Value = "NotExist" _, err := NewFunctionScore(schema, funcScores) s.ErrorContains(err, "Unsupported rerank function") - funcScores.Functions[0].Params[0].Value = decayFunctionName + funcScores.Functions[0].Params[0].Value = DecayFunctionName } { funcScores.Functions = append(funcScores.Functions, functionSchema) _, err := NewFunctionScore(schema, funcScores) - s.ErrorContains(err, "Currently only supports one rerank, but got") + s.ErrorContains(err, "Currently only supports one rerank") funcScores.Functions = funcScores.Functions[:1] } @@ -136,7 +153,7 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() { Type: schemapb.FunctionType_Rerank, InputFieldNames: []string{"ts"}, Params: []*commonpb.KeyValuePair{ - {Key: reranker, Value: decayFunctionName}, + {Key: reranker, Value: DecayFunctionName}, {Key: originKey, Value: "4"}, {Key: scaleKey, Value: "4"}, {Key: offsetKey, Value: "4"}, @@ -322,7 +339,7 @@ func (s *FunctionScoreSuite) TestlegacyFunction() { rankParams := []*commonpb.KeyValuePair{} f, err := NewFunctionScoreWithlegacy(schema, rankParams) s.NoError(err) - s.Equal(f.RerankName(), rrfName) + s.Equal(f.RerankName(), RRFName) } { rankParams := []*commonpb.KeyValuePair{ @@ -363,7 +380,7 @@ func (s *FunctionScoreSuite) TestlegacyFunction() { } f, err := NewFunctionScoreWithlegacy(schema, rankParams) s.NoError(err) - s.Equal(f.reranker.GetRankName(), weightedName) + s.Equal(f.reranker.GetRankName(), WeightedName) } { rankParams := []*commonpb.KeyValuePair{ diff --git a/internal/util/function/rerank/model_function.go b/internal/util/function/rerank/model_function.go index 45c320e725..dd1ca37f89 100644 --- a/internal/util/function/rerank/model_function.go +++ b/internal/util/function/rerank/model_function.go @@ -323,7 +323,7 @@ type ModelFunction[T PKType] struct { } func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { - base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true) + base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true) if err != nil { return nil, err } diff --git a/internal/util/function/rerank/rrf_function.go b/internal/util/function/rerank/rrf_function.go index 9df02832be..2036abe732 100644 --- a/internal/util/function/rerank/rrf_function.go +++ b/internal/util/function/rerank/rrf_function.go @@ -40,7 +40,7 @@ type RRFFunction[T PKType] struct { } func newRRFFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { - base, err := newRerankBase(collSchema, funcSchema, rrfName, true) + base, err := newRerankBase(collSchema, funcSchema, RRFName, true) if err != nil { return nil, err } diff --git a/internal/util/function/rerank/weighted_function.go b/internal/util/function/rerank/weighted_function.go index 3aa03c0c64..5f4b155880 100644 --- a/internal/util/function/rerank/weighted_function.go +++ b/internal/util/function/rerank/weighted_function.go @@ -44,7 +44,7 @@ type WeightedFunction[T PKType] struct { } func newWeightedFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { - base, err := newRerankBase(collSchema, funcSchema, weightedName, true) + base, err := newRerankBase(collSchema, funcSchema, WeightedName, true) if err != nil { return nil, err } diff --git a/pkg/proto/plan.proto b/pkg/proto/plan.proto index 0228beb961..8bbecf1568 100644 --- a/pkg/proto/plan.proto +++ b/pkg/proto/plan.proto @@ -253,6 +253,11 @@ message QueryPlanNode { int64 limit = 3; }; +message ScoreFunction { + Expr filter =1; + float weight = 2; +} + message PlanNode { oneof node { VectorANNS vector_anns = 1; @@ -261,4 +266,5 @@ message PlanNode { } repeated int64 output_field_ids = 3; repeated string dynamic_fields = 5; + repeated ScoreFunction scorers = 6; } diff --git a/pkg/proto/planpb/plan.pb.go b/pkg/proto/planpb/plan.pb.go index eba324a557..c238917c50 100644 --- a/pkg/proto/planpb/plan.pb.go +++ b/pkg/proto/planpb/plan.pb.go @@ -2428,6 +2428,61 @@ func (x *QueryPlanNode) GetLimit() int64 { return 0 } +type ScoreFunction struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Filter *Expr `protobuf:"bytes,1,opt,name=filter,proto3" json:"filter,omitempty"` + Weight float32 `protobuf:"fixed32,2,opt,name=weight,proto3" json:"weight,omitempty"` +} + +func (x *ScoreFunction) Reset() { + *x = ScoreFunction{} + if protoimpl.UnsafeEnabled { + mi := &file_plan_proto_msgTypes[25] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ScoreFunction) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ScoreFunction) ProtoMessage() {} + +func (x *ScoreFunction) ProtoReflect() protoreflect.Message { + mi := &file_plan_proto_msgTypes[25] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ScoreFunction.ProtoReflect.Descriptor instead. +func (*ScoreFunction) Descriptor() ([]byte, []int) { + return file_plan_proto_rawDescGZIP(), []int{25} +} + +func (x *ScoreFunction) GetFilter() *Expr { + if x != nil { + return x.Filter + } + return nil +} + +func (x *ScoreFunction) GetWeight() float32 { + if x != nil { + return x.Weight + } + return 0 +} + type PlanNode struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -2438,15 +2493,16 @@ type PlanNode struct { // *PlanNode_VectorAnns // *PlanNode_Predicates // *PlanNode_Query - Node isPlanNode_Node `protobuf_oneof:"node"` - OutputFieldIds []int64 `protobuf:"varint,3,rep,packed,name=output_field_ids,json=outputFieldIds,proto3" json:"output_field_ids,omitempty"` - DynamicFields []string `protobuf:"bytes,5,rep,name=dynamic_fields,json=dynamicFields,proto3" json:"dynamic_fields,omitempty"` + Node isPlanNode_Node `protobuf_oneof:"node"` + OutputFieldIds []int64 `protobuf:"varint,3,rep,packed,name=output_field_ids,json=outputFieldIds,proto3" json:"output_field_ids,omitempty"` + DynamicFields []string `protobuf:"bytes,5,rep,name=dynamic_fields,json=dynamicFields,proto3" json:"dynamic_fields,omitempty"` + Scorers []*ScoreFunction `protobuf:"bytes,6,rep,name=scorers,proto3" json:"scorers,omitempty"` } func (x *PlanNode) Reset() { *x = PlanNode{} if protoimpl.UnsafeEnabled { - mi := &file_plan_proto_msgTypes[25] + mi := &file_plan_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2459,7 +2515,7 @@ func (x *PlanNode) String() string { func (*PlanNode) ProtoMessage() {} func (x *PlanNode) ProtoReflect() protoreflect.Message { - mi := &file_plan_proto_msgTypes[25] + mi := &file_plan_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2472,7 +2528,7 @@ func (x *PlanNode) ProtoReflect() protoreflect.Message { // Deprecated: Use PlanNode.ProtoReflect.Descriptor instead. func (*PlanNode) Descriptor() ([]byte, []int) { - return file_plan_proto_rawDescGZIP(), []int{25} + return file_plan_proto_rawDescGZIP(), []int{26} } func (m *PlanNode) GetNode() isPlanNode_Node { @@ -2517,6 +2573,13 @@ func (x *PlanNode) GetDynamicFields() []string { return nil } +func (x *PlanNode) GetScorers() []*ScoreFunction { + if x != nil { + return x.Scorers + } + return nil +} + type isPlanNode_Node interface { isPlanNode_Node() } @@ -2948,56 +3011,66 @@ var file_plan_proto_rawDesc = []byte{ 0x0a, 0x08, 0x69, 0x73, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x69, 0x73, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x22, - 0x9a, 0x02, 0x0a, 0x08, 0x50, 0x6c, 0x61, 0x6e, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x40, 0x0a, 0x0b, - 0x76, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x5f, 0x61, 0x6e, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x41, 0x4e, 0x4e, 0x53, - 0x48, 0x00, 0x52, 0x0a, 0x76, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x41, 0x6e, 0x6e, 0x73, 0x12, 0x39, - 0x0a, 0x0a, 0x70, 0x72, 0x65, 0x64, 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x45, 0x78, 0x70, 0x72, 0x48, 0x00, 0x52, 0x0a, 0x70, - 0x72, 0x65, 0x64, 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x12, 0x38, 0x0a, 0x05, 0x71, 0x75, 0x65, - 0x72, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, - 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x51, 0x75, 0x65, - 0x72, 0x79, 0x50, 0x6c, 0x61, 0x6e, 0x4e, 0x6f, 0x64, 0x65, 0x48, 0x00, 0x52, 0x05, 0x71, 0x75, - 0x65, 0x72, 0x79, 0x12, 0x28, 0x0a, 0x10, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x66, 0x69, - 0x65, 0x6c, 0x64, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x03, 0x52, 0x0e, 0x6f, - 0x75, 0x74, 0x70, 0x75, 0x74, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x49, 0x64, 0x73, 0x12, 0x25, 0x0a, - 0x0e, 0x64, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18, - 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x46, 0x69, - 0x65, 0x6c, 0x64, 0x73, 0x42, 0x06, 0x0a, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x2a, 0xea, 0x01, 0x0a, - 0x06, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x6e, 0x76, 0x61, 0x6c, - 0x69, 0x64, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x47, 0x72, 0x65, 0x61, 0x74, 0x65, 0x72, 0x54, - 0x68, 0x61, 0x6e, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x47, 0x72, 0x65, 0x61, 0x74, 0x65, 0x72, - 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x02, 0x12, 0x0c, 0x0a, 0x08, 0x4c, 0x65, 0x73, 0x73, 0x54, - 0x68, 0x61, 0x6e, 0x10, 0x03, 0x12, 0x0d, 0x0a, 0x09, 0x4c, 0x65, 0x73, 0x73, 0x45, 0x71, 0x75, - 0x61, 0x6c, 0x10, 0x04, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x05, 0x12, - 0x0c, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x06, 0x12, 0x0f, 0x0a, - 0x0b, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x07, 0x12, 0x10, - 0x0a, 0x0c, 0x50, 0x6f, 0x73, 0x74, 0x66, 0x69, 0x78, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x08, - 0x12, 0x09, 0x0a, 0x05, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x09, 0x12, 0x09, 0x0a, 0x05, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x10, 0x0a, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x6e, 0x10, 0x0b, 0x12, 0x09, - 0x0a, 0x05, 0x4e, 0x6f, 0x74, 0x49, 0x6e, 0x10, 0x0c, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x65, 0x78, - 0x74, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0d, 0x12, 0x0f, 0x0a, 0x0b, 0x50, 0x68, 0x72, 0x61, - 0x73, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0e, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x6e, - 0x65, 0x72, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0f, 0x2a, 0x58, 0x0a, 0x0b, 0x41, 0x72, 0x69, - 0x74, 0x68, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x6e, 0x6b, 0x6e, - 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x64, 0x64, 0x10, 0x01, 0x12, 0x07, - 0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03, - 0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64, - 0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, - 0x68, 0x10, 0x06, 0x2a, 0x7d, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79, 0x70, - 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, - 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, - 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, - 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, 0x42, 0x46, 0x6c, 0x6f, 0x61, - 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x03, 0x12, 0x15, 0x0a, 0x11, 0x53, - 0x70, 0x61, 0x72, 0x73, 0x65, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, - 0x10, 0x04, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x74, 0x38, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, - 0x10, 0x05, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, - 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, - 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x58, 0x0a, 0x0d, 0x53, 0x63, 0x6f, 0x72, 0x65, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x2f, 0x0a, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x17, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x45, 0x78, 0x70, 0x72, 0x52, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, + 0x72, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x02, 0x52, 0x06, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0xd6, 0x02, 0x0a, 0x08, 0x50, 0x6c, + 0x61, 0x6e, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x40, 0x0a, 0x0b, 0x76, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x5f, 0x61, 0x6e, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x69, + 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, + 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x41, 0x4e, 0x4e, 0x53, 0x48, 0x00, 0x52, 0x0a, 0x76, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x41, 0x6e, 0x6e, 0x73, 0x12, 0x39, 0x0a, 0x0a, 0x70, 0x72, 0x65, 0x64, + 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, + 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, + 0x2e, 0x45, 0x78, 0x70, 0x72, 0x48, 0x00, 0x52, 0x0a, 0x70, 0x72, 0x65, 0x64, 0x69, 0x63, 0x61, + 0x74, 0x65, 0x73, 0x12, 0x38, 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x51, 0x75, 0x65, 0x72, 0x79, 0x50, 0x6c, 0x61, 0x6e, + 0x4e, 0x6f, 0x64, 0x65, 0x48, 0x00, 0x52, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x28, 0x0a, + 0x10, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x69, 0x64, + 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x03, 0x52, 0x0e, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x46, + 0x69, 0x65, 0x6c, 0x64, 0x49, 0x64, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x79, 0x6e, 0x61, 0x6d, + 0x69, 0x63, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x0d, 0x64, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x3a, + 0x0a, 0x07, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x20, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, + 0x6c, 0x61, 0x6e, 0x2e, 0x53, 0x63, 0x6f, 0x72, 0x65, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x07, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x72, 0x73, 0x42, 0x06, 0x0a, 0x04, 0x6e, 0x6f, + 0x64, 0x65, 0x2a, 0xea, 0x01, 0x0a, 0x06, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, + 0x07, 0x49, 0x6e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x47, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x72, 0x54, 0x68, 0x61, 0x6e, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x47, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x72, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x02, 0x12, 0x0c, 0x0a, + 0x08, 0x4c, 0x65, 0x73, 0x73, 0x54, 0x68, 0x61, 0x6e, 0x10, 0x03, 0x12, 0x0d, 0x0a, 0x09, 0x4c, + 0x65, 0x73, 0x73, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x04, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x71, + 0x75, 0x61, 0x6c, 0x10, 0x05, 0x12, 0x0c, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x45, 0x71, 0x75, 0x61, + 0x6c, 0x10, 0x06, 0x12, 0x0f, 0x0a, 0x0b, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x4d, 0x61, 0x74, + 0x63, 0x68, 0x10, 0x07, 0x12, 0x10, 0x0a, 0x0c, 0x50, 0x6f, 0x73, 0x74, 0x66, 0x69, 0x78, 0x4d, + 0x61, 0x74, 0x63, 0x68, 0x10, 0x08, 0x12, 0x09, 0x0a, 0x05, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, + 0x09, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x10, 0x0a, 0x12, 0x06, 0x0a, 0x02, + 0x49, 0x6e, 0x10, 0x0b, 0x12, 0x09, 0x0a, 0x05, 0x4e, 0x6f, 0x74, 0x49, 0x6e, 0x10, 0x0c, 0x12, + 0x0d, 0x0a, 0x09, 0x54, 0x65, 0x78, 0x74, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0d, 0x12, 0x0f, + 0x0a, 0x0b, 0x50, 0x68, 0x72, 0x61, 0x73, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0e, 0x12, + 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0f, 0x2a, + 0x58, 0x0a, 0x0b, 0x41, 0x72, 0x69, 0x74, 0x68, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, + 0x0a, 0x07, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, + 0x64, 0x64, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, + 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, + 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64, 0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, + 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x10, 0x06, 0x2a, 0x7d, 0x0a, 0x0a, 0x56, 0x65, 0x63, + 0x74, 0x6f, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, + 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, + 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, + 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x02, 0x12, 0x12, 0x0a, + 0x0e, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, + 0x03, 0x12, 0x15, 0x0a, 0x11, 0x53, 0x70, 0x61, 0x72, 0x73, 0x65, 0x46, 0x6c, 0x6f, 0x61, 0x74, + 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x04, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x74, 0x38, + 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x05, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, + 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, } var ( @@ -3013,7 +3086,7 @@ func file_plan_proto_rawDescGZIP() []byte { } var file_plan_proto_enumTypes = make([]protoimpl.EnumInfo, 7) -var file_plan_proto_msgTypes = make([]protoimpl.MessageInfo, 26) +var file_plan_proto_msgTypes = make([]protoimpl.MessageInfo, 27) var file_plan_proto_goTypes = []interface{}{ (OpType)(0), // 0: milvus.proto.plan.OpType (ArithOpType)(0), // 1: milvus.proto.plan.ArithOpType @@ -3047,16 +3120,17 @@ var file_plan_proto_goTypes = []interface{}{ (*Expr)(nil), // 29: milvus.proto.plan.Expr (*VectorANNS)(nil), // 30: milvus.proto.plan.VectorANNS (*QueryPlanNode)(nil), // 31: milvus.proto.plan.QueryPlanNode - (*PlanNode)(nil), // 32: milvus.proto.plan.PlanNode - (schemapb.DataType)(0), // 33: milvus.proto.schema.DataType + (*ScoreFunction)(nil), // 32: milvus.proto.plan.ScoreFunction + (*PlanNode)(nil), // 33: milvus.proto.plan.PlanNode + (schemapb.DataType)(0), // 34: milvus.proto.schema.DataType } var file_plan_proto_depIdxs = []int32{ 8, // 0: milvus.proto.plan.GenericValue.array_val:type_name -> milvus.proto.plan.Array 7, // 1: milvus.proto.plan.Array.array:type_name -> milvus.proto.plan.GenericValue - 33, // 2: milvus.proto.plan.Array.element_type:type_name -> milvus.proto.schema.DataType + 34, // 2: milvus.proto.plan.Array.element_type:type_name -> milvus.proto.schema.DataType 9, // 3: milvus.proto.plan.QueryInfo.search_iterator_v2_info:type_name -> milvus.proto.plan.SearchIteratorV2Info - 33, // 4: milvus.proto.plan.ColumnInfo.data_type:type_name -> milvus.proto.schema.DataType - 33, // 5: milvus.proto.plan.ColumnInfo.element_type:type_name -> milvus.proto.schema.DataType + 34, // 4: milvus.proto.plan.ColumnInfo.data_type:type_name -> milvus.proto.schema.DataType + 34, // 5: milvus.proto.plan.ColumnInfo.element_type:type_name -> milvus.proto.schema.DataType 11, // 6: milvus.proto.plan.ColumnExpr.info:type_name -> milvus.proto.plan.ColumnInfo 11, // 7: milvus.proto.plan.ExistsExpr.info:type_name -> milvus.proto.plan.ColumnInfo 7, // 8: milvus.proto.plan.ValueExpr.value:type_name -> milvus.proto.plan.GenericValue @@ -3115,14 +3189,16 @@ var file_plan_proto_depIdxs = []int32{ 29, // 61: milvus.proto.plan.VectorANNS.predicates:type_name -> milvus.proto.plan.Expr 10, // 62: milvus.proto.plan.VectorANNS.query_info:type_name -> milvus.proto.plan.QueryInfo 29, // 63: milvus.proto.plan.QueryPlanNode.predicates:type_name -> milvus.proto.plan.Expr - 30, // 64: milvus.proto.plan.PlanNode.vector_anns:type_name -> milvus.proto.plan.VectorANNS - 29, // 65: milvus.proto.plan.PlanNode.predicates:type_name -> milvus.proto.plan.Expr - 31, // 66: milvus.proto.plan.PlanNode.query:type_name -> milvus.proto.plan.QueryPlanNode - 67, // [67:67] is the sub-list for method output_type - 67, // [67:67] is the sub-list for method input_type - 67, // [67:67] is the sub-list for extension type_name - 67, // [67:67] is the sub-list for extension extendee - 0, // [0:67] is the sub-list for field type_name + 29, // 64: milvus.proto.plan.ScoreFunction.filter:type_name -> milvus.proto.plan.Expr + 30, // 65: milvus.proto.plan.PlanNode.vector_anns:type_name -> milvus.proto.plan.VectorANNS + 29, // 66: milvus.proto.plan.PlanNode.predicates:type_name -> milvus.proto.plan.Expr + 31, // 67: milvus.proto.plan.PlanNode.query:type_name -> milvus.proto.plan.QueryPlanNode + 32, // 68: milvus.proto.plan.PlanNode.scorers:type_name -> milvus.proto.plan.ScoreFunction + 69, // [69:69] is the sub-list for method output_type + 69, // [69:69] is the sub-list for method input_type + 69, // [69:69] is the sub-list for extension type_name + 69, // [69:69] is the sub-list for extension extendee + 0, // [0:69] is the sub-list for field type_name } func init() { file_plan_proto_init() } @@ -3432,6 +3508,18 @@ func file_plan_proto_init() { } } file_plan_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ScoreFunction); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_plan_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PlanNode); i { case 0: return &v.state @@ -3471,7 +3559,7 @@ func file_plan_proto_init() { (*Expr_NullExpr)(nil), (*Expr_RandomSampleExpr)(nil), } - file_plan_proto_msgTypes[25].OneofWrappers = []interface{}{ + file_plan_proto_msgTypes[26].OneofWrappers = []interface{}{ (*PlanNode_VectorAnns)(nil), (*PlanNode_Predicates)(nil), (*PlanNode_Query)(nil), @@ -3482,7 +3570,7 @@ func file_plan_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_plan_proto_rawDesc, NumEnums: 7, - NumMessages: 26, + NumMessages: 27, NumExtensions: 0, NumServices: 0, }, diff --git a/pkg/util/funcutil/func.go b/pkg/util/funcutil/func.go index 91030b4be3..47ec34781d 100644 --- a/pkg/util/funcutil/func.go +++ b/pkg/util/funcutil/func.go @@ -190,6 +190,18 @@ func GetAttrByKeyFromRepeatedKV(key string, kvs []*commonpb.KeyValuePair) (strin return "", fmt.Errorf("key %s not found", key) } +// TryGetAttrByKeyFromRepeatedKV return the value corresponding to key in kv pair +// return false if key not exist +func TryGetAttrByKeyFromRepeatedKV(key string, kvs []*commonpb.KeyValuePair) (string, bool) { + for _, kv := range kvs { + if kv.Key == key { + return kv.Value, true + } + } + + return "", false +} + // CheckCtxValid check if the context is valid func CheckCtxValid(ctx context.Context) bool { return ctx.Err() != context.DeadlineExceeded && ctx.Err() != context.Canceled