mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
feat: support use score function on segment search and use filter (#43868)
relate: https://github.com/milvus-io/milvus/issues/43867 Support boost function score, multiply by the weight if match filter. Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
71dc135289
commit
dcf04a58b8
@ -26,10 +26,12 @@
|
||||
#include "exec/operator/IterativeFilterNode.h"
|
||||
#include "exec/operator/MvccNode.h"
|
||||
#include "exec/operator/Operator.h"
|
||||
#include "exec/operator/RescoresNode.h"
|
||||
#include "exec/operator/VectorSearchNode.h"
|
||||
#include "exec/operator/RandomSampleNode.h"
|
||||
#include "exec/operator/GroupByNode.h"
|
||||
#include "exec/Task.h"
|
||||
#include "plan/PlanNode.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
@ -87,6 +89,11 @@ DriverFactory::CreateDriver(std::unique_ptr<DriverContext> ctx,
|
||||
plannode)) {
|
||||
operators.push_back(std::make_unique<PhyRandomSampleNode>(
|
||||
id, ctx.get(), samplenode));
|
||||
} else if (auto rescoresnode =
|
||||
std::dynamic_pointer_cast<const plan::RescoresNode>(
|
||||
plannode)) {
|
||||
operators.push_back(
|
||||
std::make_unique<PhyRescoresNode>(id, ctx.get(), rescoresnode));
|
||||
}
|
||||
// TODO: add more operators
|
||||
}
|
||||
|
||||
132
internal/core/src/exec/operator/RescoresNode.cpp
Normal file
132
internal/core/src/exec/operator/RescoresNode.cpp
Normal file
@ -0,0 +1,132 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "RescoresNode.h"
|
||||
#include <cstddef>
|
||||
#include "exec/operator/Utils.h"
|
||||
#include "monitor/Monitor.h"
|
||||
|
||||
namespace milvus::exec {
|
||||
|
||||
PhyRescoresNode::PhyRescoresNode(
|
||||
int32_t operator_id,
|
||||
DriverContext* ctx,
|
||||
const std::shared_ptr<const plan::RescoresNode>& scorer)
|
||||
: Operator(ctx,
|
||||
scorer->output_type(),
|
||||
operator_id,
|
||||
scorer->id(),
|
||||
"PhyRescoresNode") {
|
||||
scorers_ = scorer->scorers();
|
||||
};
|
||||
|
||||
void
|
||||
PhyRescoresNode::AddInput(RowVectorPtr& input) {
|
||||
input_ = std::move(input);
|
||||
}
|
||||
|
||||
bool
|
||||
PhyRescoresNode::IsFinished() {
|
||||
return is_finished_;
|
||||
}
|
||||
|
||||
RowVectorPtr
|
||||
PhyRescoresNode::GetOutput() {
|
||||
if (is_finished_ || !no_more_input_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DeferLambda([&]() { is_finished_ = true; });
|
||||
|
||||
if (input_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
ExecContext* exec_context = operator_context_->get_exec_context();
|
||||
auto query_context_ = exec_context->get_query_context();
|
||||
auto query_info = exec_context->get_query_config();
|
||||
milvus::SearchResult search_result = query_context_->get_search_result();
|
||||
|
||||
// prepare segment offset
|
||||
FixedVector<int32_t> offsets;
|
||||
std::vector<size_t> offset_idx;
|
||||
|
||||
for (size_t i = 0; i < search_result.seg_offsets_.size(); i++) {
|
||||
// remain offset will be -1 if result count not enough (less than topk)
|
||||
// skip placeholder offset
|
||||
if (search_result.seg_offsets_[i] >= 0){
|
||||
offsets.push_back(static_cast<int32_t>(search_result.seg_offsets_[i]));
|
||||
offset_idx.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& scorer : scorers_) {
|
||||
auto filter = scorer->filter();
|
||||
std::vector<expr::TypedExprPtr> filters;
|
||||
filters.emplace_back(filter);
|
||||
auto expr_set = std::make_unique<ExprSet>(filters, exec_context);
|
||||
std::vector<VectorPtr> results;
|
||||
EvalCtx eval_ctx(exec_context, expr_set.get());
|
||||
|
||||
const auto& exprs = expr_set->exprs();
|
||||
bool is_native_supported = true;
|
||||
for (const auto& expr : exprs) {
|
||||
is_native_supported =
|
||||
(is_native_supported && (expr->SupportOffsetInput()));
|
||||
}
|
||||
|
||||
if (is_native_supported) {
|
||||
// could set input offset if expr was native supported
|
||||
eval_ctx.set_offset_input(&offsets);
|
||||
expr_set->Eval(0, 1, true, eval_ctx, results);
|
||||
|
||||
// filter result for offsets[i] was resut bitset[i]
|
||||
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
|
||||
auto col_vec_size = col_vec->size();
|
||||
TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size);
|
||||
Assert(bitsetview.size() == offsets.size());
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitsetview[i] > 0) {
|
||||
search_result.distances_[offset_idx[i]] =
|
||||
scorer->rescore(search_result.distances_[offset_idx[i]]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// query all segment if expr not native
|
||||
expr_set->Eval(0, 1, true, eval_ctx, results);
|
||||
|
||||
// filter result for offsets[i] was bitset[offset[i]]
|
||||
TargetBitmap bitset;
|
||||
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
|
||||
auto col_vec_size = col_vec->size();
|
||||
TargetBitmapView view(col_vec->GetRawData(), col_vec_size);
|
||||
bitset.append(view);
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitset[offsets[i]] > 0) {
|
||||
search_result.distances_[offset_idx[i]] =
|
||||
scorer->rescore(search_result.distances_[offset_idx[i]]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
knowhere::MetricType metric_type = query_context_->get_metric_type();
|
||||
bool large_is_better = PositivelyRelated(metric_type);
|
||||
sort_search_result(search_result, large_is_better);
|
||||
query_context_->set_search_result(std::move(search_result));
|
||||
return input_;
|
||||
};
|
||||
|
||||
} // namespace milvus::exec
|
||||
77
internal/core/src/exec/operator/RescoresNode.h
Normal file
77
internal/core/src/exec/operator/RescoresNode.h
Normal file
@ -0,0 +1,77 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "exec/Driver.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "exec/operator/Operator.h"
|
||||
#include "exec/QueryContext.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "plan/PlanNode.h"
|
||||
|
||||
// difference between FilterBitsNode and RescoresNode is that
|
||||
// FilterBitsNode will go through whole segment and return bitset to indicate which offset is filtered out or not
|
||||
// RescoresNode will accept offsets array and execute over these and generate result valid offsets
|
||||
namespace milvus::exec {
|
||||
class PhyRescoresNode : public Operator {
|
||||
public:
|
||||
PhyRescoresNode(int32_t operator_id,
|
||||
DriverContext* ctx,
|
||||
const std::shared_ptr<const plan::RescoresNode>& scorer);
|
||||
|
||||
bool
|
||||
IsFilter() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
NeedInput() const override {
|
||||
return !is_finished_;
|
||||
}
|
||||
|
||||
void
|
||||
AddInput(RowVectorPtr& input) override;
|
||||
|
||||
RowVectorPtr
|
||||
GetOutput() override;
|
||||
|
||||
bool
|
||||
IsFinished() override;
|
||||
|
||||
void
|
||||
Close() override {
|
||||
Operator::Close();
|
||||
}
|
||||
|
||||
BlockingReason
|
||||
IsBlocked(ContinueFuture* /* unused */) override {
|
||||
return BlockingReason::kNotBlocked;
|
||||
}
|
||||
|
||||
virtual std::string
|
||||
ToString() const override {
|
||||
return "PhyRescoresNode";
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<rescores::Scorer>> scorers_;
|
||||
bool is_finished_{false};
|
||||
};
|
||||
} // namespace milvus::exec
|
||||
@ -16,8 +16,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include "common/QueryInfo.h"
|
||||
#include "common/QueryResult.h"
|
||||
#include "knowhere/index/index_node.h"
|
||||
#include "log/Log.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "segcore/SegmentGrowingImpl.h"
|
||||
#include "segcore/ConcurrentVector.h"
|
||||
@ -96,5 +99,41 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info,
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline void
|
||||
sort_search_result(milvus::SearchResult& result, bool large_is_better) {
|
||||
auto nq = result.total_nq_;
|
||||
auto topk = result.unity_topK_;
|
||||
auto size = nq * topk;
|
||||
|
||||
std::vector<float> new_distances = std::vector<float>();
|
||||
std::vector<int64_t> new_seg_offsets = std::vector<int64_t>();
|
||||
new_distances.reserve(size);
|
||||
new_seg_offsets.reserve(size);
|
||||
|
||||
std::vector<size_t> idx(topk);
|
||||
|
||||
for (size_t start = 0; start < size; start += topk) {
|
||||
for (size_t i = 0; i < idx.size(); ++i) idx[i] = start + i;
|
||||
|
||||
if (large_is_better) {
|
||||
std::sort(idx.begin(), idx.end(), [&](size_t i, size_t j) {
|
||||
return result.distances_[i] > result.distances_[j] || (result.seg_offsets_[j] >=0 &&result.seg_offsets_[j] < 0);
|
||||
});
|
||||
} else {
|
||||
std::sort(idx.begin(), idx.end(), [&](size_t i, size_t j) {
|
||||
return result.distances_[i] < result.distances_[j] || (result.seg_offsets_[j] >=0 &&result.seg_offsets_[j] < 0);
|
||||
});
|
||||
}
|
||||
for (auto i : idx) {
|
||||
new_distances.push_back(result.distances_[i]);
|
||||
new_seg_offsets.push_back(result.seg_offsets_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
result.distances_ = new_distances;
|
||||
result.seg_offsets_ = new_seg_offsets;
|
||||
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
||||
@ -26,6 +26,7 @@
|
||||
#include "common/EasyAssert.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "plan/PlanNodeIdGenerator.h"
|
||||
#include "rescores/Scorer.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace plan {
|
||||
@ -438,6 +439,46 @@ class CountNode : public PlanNode {
|
||||
const std::vector<PlanNodePtr> sources_;
|
||||
};
|
||||
|
||||
class RescoresNode : public PlanNode {
|
||||
public:
|
||||
RescoresNode(
|
||||
const PlanNodeId& id,
|
||||
const std::vector<std::shared_ptr<rescores::Scorer>> scorers,
|
||||
const std::vector<PlanNodePtr>& sources = std::vector<PlanNodePtr>{})
|
||||
: PlanNode(id), scorers_(std::move(scorers)), sources_{std::move(sources)} {
|
||||
}
|
||||
|
||||
DataType
|
||||
output_type() const override {
|
||||
return DataType::INT64;
|
||||
}
|
||||
|
||||
std::vector<PlanNodePtr>
|
||||
sources() const override {
|
||||
return sources_;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<rescores::Scorer>>&
|
||||
scorers() const {
|
||||
return scorers_;
|
||||
}
|
||||
|
||||
std::string_view
|
||||
name() const override {
|
||||
return "RescoresNode";
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return fmt::format("RescoresNode:\n\t[source node:{}]",
|
||||
SourceToString());
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<PlanNodePtr> sources_;
|
||||
const std::vector<std::shared_ptr<rescores::Scorer>> scorers_;
|
||||
};
|
||||
|
||||
enum class ExecutionStrategy {
|
||||
// Process splits as they come in any available driver.
|
||||
kUngrouped,
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
#include "query/Utils.h"
|
||||
#include "knowhere/comp/materialized_view.h"
|
||||
#include "plan/PlanNode.h"
|
||||
#include "rescores/Scorer.h"
|
||||
|
||||
namespace milvus::query {
|
||||
namespace planpb = milvus::proto::plan;
|
||||
@ -206,6 +207,18 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
}
|
||||
|
||||
// if has score function, run filter and scorer at last
|
||||
if (plan_node_proto.scorers_size() > 0){
|
||||
std::vector<std::shared_ptr<rescores::Scorer>> scorers;
|
||||
for (const auto& function: plan_node_proto.scorers()){
|
||||
scorers.push_back(ParseScorer(function));
|
||||
}
|
||||
|
||||
plannode = std::make_shared<milvus::plan::RescoresNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), std::move(scorers), sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
}
|
||||
|
||||
plan_node->plannodes_ = plannode;
|
||||
|
||||
return plan_node;
|
||||
@ -588,4 +601,9 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
|
||||
ExprInvalid, "expr type check failed, actual type: {}", result->type());
|
||||
}
|
||||
|
||||
std::shared_ptr<rescores::Scorer> ProtoParser::ParseScorer(const proto::plan::ScoreFunction& function){
|
||||
auto expr = ParseExprs(function.filter());
|
||||
return std::make_shared<rescores::WeightScorer>(expr, function.weight());
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
||||
@ -54,6 +54,9 @@ class ProtoParser {
|
||||
ParseExprs(const proto::plan::Expr& expr_pb,
|
||||
TypeCheckFunction type_check = TypeIsBool);
|
||||
|
||||
std::shared_ptr<rescores::Scorer>
|
||||
ParseScorer(const proto::plan::ScoreFunction& function);
|
||||
|
||||
private:
|
||||
expr::TypedExprPtr
|
||||
CreateAlwaysTrueExprs();
|
||||
|
||||
58
internal/core/src/rescores/Scorer.h
Normal file
58
internal/core/src/rescores/Scorer.h
Normal file
@ -0,0 +1,58 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "expr/ITypeExpr.h"
|
||||
|
||||
namespace milvus::rescores {
|
||||
class Scorer {
|
||||
public:
|
||||
virtual expr::TypedExprPtr
|
||||
filter() = 0;
|
||||
|
||||
virtual float
|
||||
rescore(float old_score) = 0;
|
||||
|
||||
virtual float
|
||||
weight() = 0;
|
||||
};
|
||||
|
||||
class WeightScorer : public Scorer {
|
||||
public:
|
||||
WeightScorer(expr::TypedExprPtr filter, float weight)
|
||||
: filter_(std::move(filter)), weight_(weight){};
|
||||
|
||||
expr::TypedExprPtr
|
||||
filter() override {
|
||||
return filter_;
|
||||
}
|
||||
|
||||
float
|
||||
rescore(float old_score) override {
|
||||
return old_score * weight_;
|
||||
}
|
||||
|
||||
float
|
||||
weight() override{
|
||||
return weight_;
|
||||
}
|
||||
|
||||
private:
|
||||
expr::TypedExprPtr filter_;
|
||||
float weight_;
|
||||
};
|
||||
} // namespace milvus::rescores
|
||||
@ -2,6 +2,7 @@ package planparserv2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/antlr4-go/antlr/v4"
|
||||
@ -11,8 +12,10 @@ import (
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
@ -156,7 +159,7 @@ func CreateRetrievePlan(schema *typeutil.SchemaHelper, exprStr string, exprTempl
|
||||
return planNode, nil
|
||||
}
|
||||
|
||||
func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, error) {
|
||||
func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo, exprTemplateValues map[string]*schemapb.TemplateValue, functionScorer *schemapb.FunctionScore) (*planpb.PlanNode, error) {
|
||||
parse := func() (*planpb.Expr, error) {
|
||||
if len(exprStr) <= 0 {
|
||||
return nil, nil
|
||||
@ -199,6 +202,16 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
|
||||
log.Error("Invalid dataType", zap.Any("dataType", dataType))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scorers, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(scorers) != 0 && (queryInfo.GroupByFieldId != -1 || queryInfo.SearchIteratorV2Info != nil) {
|
||||
return nil, fmt.Errorf("don't support use segment scorer with group_by or search_iterator")
|
||||
}
|
||||
|
||||
planNode := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
@ -209,10 +222,62 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
|
||||
FieldId: fieldID,
|
||||
},
|
||||
},
|
||||
Scorers: scorers,
|
||||
}
|
||||
return planNode, nil
|
||||
}
|
||||
|
||||
func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.FunctionSchema, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.ScoreFunction, error) {
|
||||
rerankerName := rerank.GetRerankName(function)
|
||||
switch rerankerName {
|
||||
case rerank.BoostName:
|
||||
scorer := &planpb.ScoreFunction{}
|
||||
filter, ok := funcutil.TryGetAttrByKeyFromRepeatedKV(rerank.FilterKey, function.GetParams())
|
||||
if ok {
|
||||
expr, err := ParseExpr(schema, filter, exprTemplateValues)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse expr failed with error: {%v}", err)
|
||||
}
|
||||
scorer.Filter = expr
|
||||
}
|
||||
|
||||
weightStr, ok := funcutil.TryGetAttrByKeyFromRepeatedKV(rerank.WeightKey, function.GetParams())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("must set weight params for weight scorer")
|
||||
}
|
||||
|
||||
weight, err := strconv.ParseFloat(weightStr, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse function scorer weight params failed with error: {%v}", err)
|
||||
}
|
||||
scorer.Weight = float32(weight)
|
||||
return scorer, nil
|
||||
default:
|
||||
// if not boost scorer, regard as normal function scorer
|
||||
// will be checked at ranker
|
||||
// return nil here
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func CreateSearchScorers(schema *typeutil.SchemaHelper, functionScore *schemapb.FunctionScore, exprTemplateValues map[string]*schemapb.TemplateValue) ([]*planpb.ScoreFunction, error) {
|
||||
scorers := []*planpb.ScoreFunction{}
|
||||
for _, function := range functionScore.GetFunctions() {
|
||||
// create scorer for search plan
|
||||
scorer, err := CreateSearchScorer(schema, function, exprTemplateValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if scorer != nil {
|
||||
scorers = append(scorers, scorer)
|
||||
}
|
||||
}
|
||||
if len(scorers) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return scorers, nil
|
||||
}
|
||||
|
||||
func CreateRequeryPlan(pkField *schemapb.FieldSchema, ids *schemapb.IDs) *planpb.PlanNode {
|
||||
var values []*planpb.GenericValue
|
||||
switch ids.GetIdField().(type) {
|
||||
|
||||
@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
@ -230,7 +231,7 @@ func TestExpr_Like(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err, expr)
|
||||
assert.NotNil(t, plan)
|
||||
fmt.Println(plan)
|
||||
@ -243,7 +244,7 @@ func TestExpr_Like(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err, expr)
|
||||
assert.NotNil(t, plan)
|
||||
fmt.Println(plan)
|
||||
@ -256,7 +257,7 @@ func TestExpr_Like(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err, expr)
|
||||
assert.NotNil(t, plan)
|
||||
fmt.Println(plan)
|
||||
@ -653,7 +654,7 @@ func TestCreateSearchPlan(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -664,7 +665,7 @@ func TestCreateFloat16SearchPlan(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -675,7 +676,7 @@ func TestCreateBFloat16earchPlan(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -686,7 +687,7 @@ func TestCreateSparseFloatVectorSearchPlan(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -848,19 +849,19 @@ func TestCreateRetrievePlan_Invalid(t *testing.T) {
|
||||
func TestCreateSearchPlan_Invalid(t *testing.T) {
|
||||
t.Run("invalid expr", func(t *testing.T) {
|
||||
schema := newTestSchemaHelper(t)
|
||||
_, err := CreateSearchPlan(schema, "invalid expression", "", nil, nil)
|
||||
_, err := CreateSearchPlan(schema, "invalid expression", "", nil, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid vector field", func(t *testing.T) {
|
||||
schema := newTestSchemaHelper(t)
|
||||
_, err := CreateSearchPlan(schema, "Int64Field > 0", "not_exist", nil, nil)
|
||||
_, err := CreateSearchPlan(schema, "Int64Field > 0", "not_exist", nil, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not vector type", func(t *testing.T) {
|
||||
schema := newTestSchemaHelper(t)
|
||||
_, err := CreateSearchPlan(schema, "Int64Field > 0", "VarCharField", nil, nil)
|
||||
_, err := CreateSearchPlan(schema, "Int64Field > 0", "VarCharField", nil, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@ -1006,7 +1007,7 @@ func Test_JSONExpr(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@ -1035,7 +1036,7 @@ func Test_InvalidExprOnJSONField(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err, expr)
|
||||
}
|
||||
}
|
||||
@ -1072,7 +1073,7 @@ func Test_InvalidExprWithoutJSONField(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@ -1109,7 +1110,7 @@ func Test_InvalidExprWithMultipleJSONField(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@ -1130,7 +1131,7 @@ func Test_exprWithSingleQuotes(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -1145,7 +1146,7 @@ func Test_exprWithSingleQuotes(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@ -1180,7 +1181,7 @@ func Test_JSONContains(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@ -1211,7 +1212,7 @@ func Test_InvalidJSONContains(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@ -1274,7 +1275,7 @@ func Test_EscapeString(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -1291,7 +1292,7 @@ c'`,
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err)
|
||||
fmt.Println(plan)
|
||||
}
|
||||
@ -1319,7 +1320,7 @@ func Test_JSONContainsAll(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr())
|
||||
assert.Equal(t, planpb.JSONContainsExpr_ContainsAll, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp())
|
||||
@ -1341,7 +1342,7 @@ func Test_JSONContainsAll(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@ -1364,7 +1365,7 @@ func Test_JSONContainsAny(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr())
|
||||
assert.Equal(t, planpb.JSONContainsExpr_ContainsAny, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp())
|
||||
@ -1386,7 +1387,7 @@ func Test_JSONContainsAny(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@ -1434,7 +1435,7 @@ func Test_ArrayExpr(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err, expr)
|
||||
}
|
||||
|
||||
@ -1468,7 +1469,7 @@ func Test_ArrayExpr(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err, expr)
|
||||
}
|
||||
}
|
||||
@ -1496,7 +1497,7 @@ func Test_ArrayLength(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err, expr)
|
||||
}
|
||||
|
||||
@ -1522,7 +1523,7 @@ func Test_ArrayLength(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.Error(t, err, expr)
|
||||
}
|
||||
}
|
||||
@ -1580,6 +1581,106 @@ func TestRandomSampleWithFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SegmentScorers(t *testing.T) {
|
||||
schema := newTestSchemaHelper(t)
|
||||
|
||||
// helper to build a boost segment scorer function
|
||||
makeBoostRanker := func(filter string, weight string) *schemapb.FunctionSchema {
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: rerank.WeightKey, Value: weight},
|
||||
{Key: "reranker", Value: rerank.BoostName},
|
||||
}
|
||||
if filter != "" {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: rerank.FilterKey, Value: filter})
|
||||
}
|
||||
return &schemapb.FunctionSchema{
|
||||
Params: params,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("ok - single boost scorer", func(t *testing.T) {
|
||||
fs := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
makeBoostRanker("Int64Field > 0", "1.5"),
|
||||
},
|
||||
}
|
||||
plan, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: -1}, nil, fs)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, plan)
|
||||
assert.Equal(t, 1, len(plan.Scorers))
|
||||
// filter should be parsed into Expr when provided
|
||||
assert.NotNil(t, plan.Scorers[0])
|
||||
})
|
||||
|
||||
t.Run("ok - multiple boost scorers", func(t *testing.T) {
|
||||
fs := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
makeBoostRanker("Int64Field > 0", "1.0"),
|
||||
makeBoostRanker("", "2.0"),
|
||||
},
|
||||
}
|
||||
plan, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: -1}, nil, fs)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, plan)
|
||||
assert.Equal(t, 2, len(plan.Scorers))
|
||||
})
|
||||
|
||||
t.Run("error - not segment scorer flag", func(t *testing.T) {
|
||||
fs := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{{Params: []*commonpb.KeyValuePair{{Key: "reranker", Value: rerank.WeightedName}}}},
|
||||
}
|
||||
plan, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: -1}, nil, fs)
|
||||
assert.NoError(t, err)
|
||||
// not segment scorer means ignored
|
||||
assert.NotNil(t, plan)
|
||||
assert.Equal(t, 0, len(plan.Scorers))
|
||||
})
|
||||
|
||||
t.Run("error - missing weight", func(t *testing.T) {
|
||||
// missing weight should cause CreateSearchScorer to fail
|
||||
fs := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{Params: []*commonpb.KeyValuePair{
|
||||
{Key: "reranker", Value: rerank.BoostName},
|
||||
// no weight
|
||||
}},
|
||||
},
|
||||
}
|
||||
_, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{}, nil, fs)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("error - invalid weight format", func(t *testing.T) {
|
||||
fs := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
makeBoostRanker("", "invalid_float"),
|
||||
},
|
||||
}
|
||||
_, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{}, nil, fs)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("error - scorer with group_by", func(t *testing.T) {
|
||||
fs := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
makeBoostRanker("", "1.0"),
|
||||
},
|
||||
}
|
||||
_, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{GroupByFieldId: 100}, nil, fs)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("error - scorer with search_iterator_v2", func(t *testing.T) {
|
||||
fs := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
makeBoostRanker("", "1.0"),
|
||||
},
|
||||
}
|
||||
_, err := CreateSearchPlan(schema, "", "FloatVectorField", &planpb.QueryInfo{SearchIteratorV2Info: &planpb.SearchIteratorV2Info{}}, nil, fs)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
schemaHelper := newTestSchemaHelper(t)
|
||||
|
||||
@ -1653,7 +1754,7 @@ func BenchmarkWithString(b *testing.B) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(b, err)
|
||||
assert.NotNil(b, plan)
|
||||
}
|
||||
@ -1696,7 +1797,7 @@ func BenchmarkTemplateWithString(b *testing.B) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, mv)
|
||||
}, mv, nil)
|
||||
assert.NoError(b, err)
|
||||
assert.NotNil(b, plan)
|
||||
}
|
||||
@ -1711,7 +1812,7 @@ func TestNestedPathWithChinese(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err, expr)
|
||||
paths := plan.GetVectorAnns().GetPredicates().GetUnaryRangeExpr().GetColumnInfo().GetNestedPath()
|
||||
assert.NotNil(t, paths)
|
||||
@ -1725,7 +1826,7 @@ func TestNestedPathWithChinese(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err, expr)
|
||||
paths = plan.GetVectorAnns().GetPredicates().GetUnaryRangeExpr().GetColumnInfo().GetNestedPath()
|
||||
assert.NotNil(t, paths)
|
||||
@ -1751,7 +1852,7 @@ func Test_JSONPathNullExpr(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, plan)
|
||||
|
||||
@ -1760,7 +1861,7 @@ func Test_JSONPathNullExpr(t *testing.T) {
|
||||
MetricType: "",
|
||||
SearchParams: "",
|
||||
RoundDecimal: 0,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, plan2)
|
||||
|
||||
|
||||
@ -634,7 +634,7 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
|
||||
|
||||
searchInfo.planInfo.QueryFieldId = annField.GetFieldID()
|
||||
start := time.Now()
|
||||
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo, exprTemplateValues)
|
||||
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo, exprTemplateValues, t.request.GetFunctionScore())
|
||||
if planErr != nil {
|
||||
log.Ctx(t.ctx).Warn("failed to create query plan", zap.Error(planErr),
|
||||
zap.String("dsl", dsl), // may be very large if large term passed.
|
||||
|
||||
@ -120,7 +120,7 @@ func TestParsePartitionKeys(t *testing.T) {
|
||||
idx++
|
||||
t.Log(idx, tc.name, tc.expr)
|
||||
// test search plan
|
||||
searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo, nil)
|
||||
searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
expr, err := ParseExprFromPlan(searchPlan)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -55,7 +55,7 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
|
||||
}
|
||||
|
||||
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true)
|
||||
base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -154,7 +154,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
|
||||
case linearFunction:
|
||||
decayFunc.reScorer = linearDecay
|
||||
default:
|
||||
return nil, fmt.Errorf("Invaild decay function: %s, only support [%s,%s,%s]", decayFunctionName, gaussFunction, linearFunction, expFunction)
|
||||
return nil, fmt.Errorf("Invaild decay function: %s, only support [%s,%s,%s]", DecayFunctionName, gaussFunction, linearFunction, expFunction)
|
||||
}
|
||||
return decayFunc, nil
|
||||
}
|
||||
|
||||
@ -36,10 +36,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
decayFunctionName string = "decay"
|
||||
modelFunctionName string = "model"
|
||||
rrfName string = "rrf"
|
||||
weightedName string = "weighted"
|
||||
DecayFunctionName string = "decay"
|
||||
ModelFunctionName string = "model"
|
||||
RRFName string = "rrf"
|
||||
WeightedName string = "weighted"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -63,6 +63,13 @@ const (
|
||||
weightedRankType // weightedRankType = 2
|
||||
)
|
||||
|
||||
// segment scorer configs
|
||||
const (
|
||||
BoostName = "boost"
|
||||
FilterKey = "filter"
|
||||
WeightKey = "weight"
|
||||
)
|
||||
|
||||
var rankTypeMap = map[string]rankType{
|
||||
"invalid": invalidRankType,
|
||||
"rrf": rrfRankType,
|
||||
@ -104,7 +111,7 @@ type Reranker interface {
|
||||
GetRankName() string
|
||||
}
|
||||
|
||||
func getRerankName(funcSchema *schemapb.FunctionSchema) string {
|
||||
func GetRerankName(funcSchema *schemapb.FunctionSchema) string {
|
||||
for _, param := range funcSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case reranker:
|
||||
@ -128,20 +135,22 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
|
||||
return nil, fmt.Errorf("Rerank function should not have output field, but now is %d", len(funcSchema.GetOutputFieldNames()))
|
||||
}
|
||||
|
||||
rerankerName := getRerankName(funcSchema)
|
||||
rerankerName := GetRerankName(funcSchema)
|
||||
var rerankFunc Reranker
|
||||
var newRerankErr error
|
||||
switch rerankerName {
|
||||
case decayFunctionName:
|
||||
case DecayFunctionName:
|
||||
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
|
||||
case modelFunctionName:
|
||||
case ModelFunctionName:
|
||||
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
|
||||
case rrfName:
|
||||
case RRFName:
|
||||
rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema)
|
||||
case weightedName:
|
||||
case WeightedName:
|
||||
rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema)
|
||||
case BoostName:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s,%s,%s]", rerankerName, decayFunctionName, modelFunctionName, rrfName, weightedName)
|
||||
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s,%s,%s]", rerankerName, DecayFunctionName, ModelFunctionName, RRFName, WeightedName)
|
||||
}
|
||||
|
||||
if newRerankErr != nil {
|
||||
@ -151,15 +160,28 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
|
||||
}
|
||||
|
||||
func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore) (*FunctionScore, error) {
|
||||
if len(funcScoreSchema.Functions) > 1 || len(funcScoreSchema.Functions) == 0 {
|
||||
return nil, fmt.Errorf("Currently only supports one rerank, but got %d", len(funcScoreSchema.Functions))
|
||||
}
|
||||
funcScore := &FunctionScore{}
|
||||
var err error
|
||||
if funcScore.reranker, err = createFunction(collSchema, funcScoreSchema.Functions[0]); err != nil {
|
||||
return nil, err
|
||||
|
||||
for _, function := range funcScoreSchema.Functions {
|
||||
reranker, err := createFunction(collSchema, function)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if reranker != nil {
|
||||
if funcScore.reranker == nil {
|
||||
funcScore.reranker = reranker
|
||||
} else {
|
||||
// now only support only use one proxy rerank
|
||||
return nil, fmt.Errorf("Currently only supports one rerank")
|
||||
}
|
||||
}
|
||||
}
|
||||
return funcScore, nil
|
||||
|
||||
if funcScore.reranker != nil {
|
||||
return funcScore, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair) (*FunctionScore, error) {
|
||||
@ -189,7 +211,7 @@ func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParam
|
||||
}
|
||||
switch rankTypeMap[rankTypeStr] {
|
||||
case rrfRankType:
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: rrfName})
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: RRFName})
|
||||
if v, ok := params[RRFParamsKey]; ok {
|
||||
if reflect.ValueOf(params[RRFParamsKey]).CanFloat() {
|
||||
k := reflect.ValueOf(v).Float()
|
||||
@ -199,7 +221,7 @@ func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParam
|
||||
}
|
||||
}
|
||||
case weightedRankType:
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: weightedName})
|
||||
fSchema.Params = append(fSchema.Params, &commonpb.KeyValuePair{Key: reranker, Value: WeightedName})
|
||||
if v, ok := params[WeightsParamsKey]; ok {
|
||||
if d, err := json.Marshal(v); err != nil {
|
||||
return nil, fmt.Errorf("The weights param should be an array")
|
||||
|
||||
@ -61,7 +61,7 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() {
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: reranker, Value: decayFunctionName},
|
||||
{Key: reranker, Value: DecayFunctionName},
|
||||
{Key: originKey, Value: "4"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "4"},
|
||||
@ -69,6 +69,15 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() {
|
||||
{Key: functionKey, Value: "gauss"},
|
||||
},
|
||||
}
|
||||
|
||||
segmentScorer := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: reranker, Value: BoostName},
|
||||
{Key: WeightKey, Value: "2"},
|
||||
},
|
||||
}
|
||||
funcScores := &schemapb.FunctionScore{
|
||||
Functions: []*schemapb.FunctionSchema{functionSchema},
|
||||
}
|
||||
@ -80,6 +89,14 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() {
|
||||
s.Equal(true, f.IsSupportGroup())
|
||||
s.Equal("decay", f.reranker.GetRankName())
|
||||
|
||||
// two ranker but one was boost scorer
|
||||
{
|
||||
funcScores.Functions = append(funcScores.Functions, segmentScorer)
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.NoError(err)
|
||||
funcScores.Functions = funcScores.Functions[:1]
|
||||
}
|
||||
|
||||
{
|
||||
schema.Fields[3].Nullable = true
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
@ -91,13 +108,13 @@ func (s *FunctionScoreSuite) TestNewFunctionScore() {
|
||||
funcScores.Functions[0].Params[0].Value = "NotExist"
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.ErrorContains(err, "Unsupported rerank function")
|
||||
funcScores.Functions[0].Params[0].Value = decayFunctionName
|
||||
funcScores.Functions[0].Params[0].Value = DecayFunctionName
|
||||
}
|
||||
|
||||
{
|
||||
funcScores.Functions = append(funcScores.Functions, functionSchema)
|
||||
_, err := NewFunctionScore(schema, funcScores)
|
||||
s.ErrorContains(err, "Currently only supports one rerank, but got")
|
||||
s.ErrorContains(err, "Currently only supports one rerank")
|
||||
funcScores.Functions = funcScores.Functions[:1]
|
||||
}
|
||||
|
||||
@ -136,7 +153,7 @@ func (s *FunctionScoreSuite) TestFunctionScoreProcess() {
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"ts"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: reranker, Value: decayFunctionName},
|
||||
{Key: reranker, Value: DecayFunctionName},
|
||||
{Key: originKey, Value: "4"},
|
||||
{Key: scaleKey, Value: "4"},
|
||||
{Key: offsetKey, Value: "4"},
|
||||
@ -322,7 +339,7 @@ func (s *FunctionScoreSuite) TestlegacyFunction() {
|
||||
rankParams := []*commonpb.KeyValuePair{}
|
||||
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.NoError(err)
|
||||
s.Equal(f.RerankName(), rrfName)
|
||||
s.Equal(f.RerankName(), RRFName)
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
@ -363,7 +380,7 @@ func (s *FunctionScoreSuite) TestlegacyFunction() {
|
||||
}
|
||||
f, err := NewFunctionScoreWithlegacy(schema, rankParams)
|
||||
s.NoError(err)
|
||||
s.Equal(f.reranker.GetRankName(), weightedName)
|
||||
s.Equal(f.reranker.GetRankName(), WeightedName)
|
||||
}
|
||||
{
|
||||
rankParams := []*commonpb.KeyValuePair{
|
||||
|
||||
@ -323,7 +323,7 @@ type ModelFunction[T PKType] struct {
|
||||
}
|
||||
|
||||
func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, true)
|
||||
base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ type RRFFunction[T PKType] struct {
|
||||
}
|
||||
|
||||
func newRRFFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
base, err := newRerankBase(collSchema, funcSchema, rrfName, true)
|
||||
base, err := newRerankBase(collSchema, funcSchema, RRFName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ type WeightedFunction[T PKType] struct {
|
||||
}
|
||||
|
||||
func newWeightedFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
base, err := newRerankBase(collSchema, funcSchema, weightedName, true)
|
||||
base, err := newRerankBase(collSchema, funcSchema, WeightedName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -253,6 +253,11 @@ message QueryPlanNode {
|
||||
int64 limit = 3;
|
||||
};
|
||||
|
||||
message ScoreFunction {
|
||||
Expr filter =1;
|
||||
float weight = 2;
|
||||
}
|
||||
|
||||
message PlanNode {
|
||||
oneof node {
|
||||
VectorANNS vector_anns = 1;
|
||||
@ -261,4 +266,5 @@ message PlanNode {
|
||||
}
|
||||
repeated int64 output_field_ids = 3;
|
||||
repeated string dynamic_fields = 5;
|
||||
repeated ScoreFunction scorers = 6;
|
||||
}
|
||||
|
||||
@ -2428,6 +2428,61 @@ func (x *QueryPlanNode) GetLimit() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
type ScoreFunction struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Filter *Expr `protobuf:"bytes,1,opt,name=filter,proto3" json:"filter,omitempty"`
|
||||
Weight float32 `protobuf:"fixed32,2,opt,name=weight,proto3" json:"weight,omitempty"`
|
||||
}
|
||||
|
||||
func (x *ScoreFunction) Reset() {
|
||||
*x = ScoreFunction{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_plan_proto_msgTypes[25]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *ScoreFunction) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ScoreFunction) ProtoMessage() {}
|
||||
|
||||
func (x *ScoreFunction) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_plan_proto_msgTypes[25]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ScoreFunction.ProtoReflect.Descriptor instead.
|
||||
func (*ScoreFunction) Descriptor() ([]byte, []int) {
|
||||
return file_plan_proto_rawDescGZIP(), []int{25}
|
||||
}
|
||||
|
||||
func (x *ScoreFunction) GetFilter() *Expr {
|
||||
if x != nil {
|
||||
return x.Filter
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *ScoreFunction) GetWeight() float32 {
|
||||
if x != nil {
|
||||
return x.Weight
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type PlanNode struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
@ -2438,15 +2493,16 @@ type PlanNode struct {
|
||||
// *PlanNode_VectorAnns
|
||||
// *PlanNode_Predicates
|
||||
// *PlanNode_Query
|
||||
Node isPlanNode_Node `protobuf_oneof:"node"`
|
||||
OutputFieldIds []int64 `protobuf:"varint,3,rep,packed,name=output_field_ids,json=outputFieldIds,proto3" json:"output_field_ids,omitempty"`
|
||||
DynamicFields []string `protobuf:"bytes,5,rep,name=dynamic_fields,json=dynamicFields,proto3" json:"dynamic_fields,omitempty"`
|
||||
Node isPlanNode_Node `protobuf_oneof:"node"`
|
||||
OutputFieldIds []int64 `protobuf:"varint,3,rep,packed,name=output_field_ids,json=outputFieldIds,proto3" json:"output_field_ids,omitempty"`
|
||||
DynamicFields []string `protobuf:"bytes,5,rep,name=dynamic_fields,json=dynamicFields,proto3" json:"dynamic_fields,omitempty"`
|
||||
Scorers []*ScoreFunction `protobuf:"bytes,6,rep,name=scorers,proto3" json:"scorers,omitempty"`
|
||||
}
|
||||
|
||||
func (x *PlanNode) Reset() {
|
||||
*x = PlanNode{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_plan_proto_msgTypes[25]
|
||||
mi := &file_plan_proto_msgTypes[26]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@ -2459,7 +2515,7 @@ func (x *PlanNode) String() string {
|
||||
func (*PlanNode) ProtoMessage() {}
|
||||
|
||||
func (x *PlanNode) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_plan_proto_msgTypes[25]
|
||||
mi := &file_plan_proto_msgTypes[26]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@ -2472,7 +2528,7 @@ func (x *PlanNode) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use PlanNode.ProtoReflect.Descriptor instead.
|
||||
func (*PlanNode) Descriptor() ([]byte, []int) {
|
||||
return file_plan_proto_rawDescGZIP(), []int{25}
|
||||
return file_plan_proto_rawDescGZIP(), []int{26}
|
||||
}
|
||||
|
||||
func (m *PlanNode) GetNode() isPlanNode_Node {
|
||||
@ -2517,6 +2573,13 @@ func (x *PlanNode) GetDynamicFields() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *PlanNode) GetScorers() []*ScoreFunction {
|
||||
if x != nil {
|
||||
return x.Scorers
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isPlanNode_Node interface {
|
||||
isPlanNode_Node()
|
||||
}
|
||||
@ -2948,56 +3011,66 @@ var file_plan_proto_rawDesc = []byte{
|
||||
0x0a, 0x08, 0x69, 0x73, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08,
|
||||
0x52, 0x07, 0x69, 0x73, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d,
|
||||
0x69, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x22,
|
||||
0x9a, 0x02, 0x0a, 0x08, 0x50, 0x6c, 0x61, 0x6e, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x40, 0x0a, 0x0b,
|
||||
0x76, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x5f, 0x61, 0x6e, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28,
|
||||
0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x41, 0x4e, 0x4e, 0x53,
|
||||
0x48, 0x00, 0x52, 0x0a, 0x76, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x41, 0x6e, 0x6e, 0x73, 0x12, 0x39,
|
||||
0x0a, 0x0a, 0x70, 0x72, 0x65, 0x64, 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01,
|
||||
0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x45, 0x78, 0x70, 0x72, 0x48, 0x00, 0x52, 0x0a, 0x70,
|
||||
0x72, 0x65, 0x64, 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x12, 0x38, 0x0a, 0x05, 0x71, 0x75, 0x65,
|
||||
0x72, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75,
|
||||
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x51, 0x75, 0x65,
|
||||
0x72, 0x79, 0x50, 0x6c, 0x61, 0x6e, 0x4e, 0x6f, 0x64, 0x65, 0x48, 0x00, 0x52, 0x05, 0x71, 0x75,
|
||||
0x65, 0x72, 0x79, 0x12, 0x28, 0x0a, 0x10, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x66, 0x69,
|
||||
0x65, 0x6c, 0x64, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x03, 0x52, 0x0e, 0x6f,
|
||||
0x75, 0x74, 0x70, 0x75, 0x74, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x49, 0x64, 0x73, 0x12, 0x25, 0x0a,
|
||||
0x0e, 0x64, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18,
|
||||
0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x46, 0x69,
|
||||
0x65, 0x6c, 0x64, 0x73, 0x42, 0x06, 0x0a, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x2a, 0xea, 0x01, 0x0a,
|
||||
0x06, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x6e, 0x76, 0x61, 0x6c,
|
||||
0x69, 0x64, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x47, 0x72, 0x65, 0x61, 0x74, 0x65, 0x72, 0x54,
|
||||
0x68, 0x61, 0x6e, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x47, 0x72, 0x65, 0x61, 0x74, 0x65, 0x72,
|
||||
0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x02, 0x12, 0x0c, 0x0a, 0x08, 0x4c, 0x65, 0x73, 0x73, 0x54,
|
||||
0x68, 0x61, 0x6e, 0x10, 0x03, 0x12, 0x0d, 0x0a, 0x09, 0x4c, 0x65, 0x73, 0x73, 0x45, 0x71, 0x75,
|
||||
0x61, 0x6c, 0x10, 0x04, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x05, 0x12,
|
||||
0x0c, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x06, 0x12, 0x0f, 0x0a,
|
||||
0x0b, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x07, 0x12, 0x10,
|
||||
0x0a, 0x0c, 0x50, 0x6f, 0x73, 0x74, 0x66, 0x69, 0x78, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x08,
|
||||
0x12, 0x09, 0x0a, 0x05, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x09, 0x12, 0x09, 0x0a, 0x05, 0x52,
|
||||
0x61, 0x6e, 0x67, 0x65, 0x10, 0x0a, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x6e, 0x10, 0x0b, 0x12, 0x09,
|
||||
0x0a, 0x05, 0x4e, 0x6f, 0x74, 0x49, 0x6e, 0x10, 0x0c, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x65, 0x78,
|
||||
0x74, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0d, 0x12, 0x0f, 0x0a, 0x0b, 0x50, 0x68, 0x72, 0x61,
|
||||
0x73, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0e, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x6e,
|
||||
0x65, 0x72, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0f, 0x2a, 0x58, 0x0a, 0x0b, 0x41, 0x72, 0x69,
|
||||
0x74, 0x68, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x6e, 0x6b, 0x6e,
|
||||
0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x64, 0x64, 0x10, 0x01, 0x12, 0x07,
|
||||
0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03,
|
||||
0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64,
|
||||
0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74,
|
||||
0x68, 0x10, 0x06, 0x2a, 0x7d, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79, 0x70,
|
||||
0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f,
|
||||
0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74,
|
||||
0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56,
|
||||
0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, 0x42, 0x46, 0x6c, 0x6f, 0x61,
|
||||
0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x03, 0x12, 0x15, 0x0a, 0x11, 0x53,
|
||||
0x70, 0x61, 0x72, 0x73, 0x65, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72,
|
||||
0x10, 0x04, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x74, 0x38, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72,
|
||||
0x10, 0x05, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75,
|
||||
0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70,
|
||||
0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x58, 0x0a, 0x0d, 0x53, 0x63, 0x6f, 0x72, 0x65, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x12, 0x2f, 0x0a, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b,
|
||||
0x32, 0x17, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
|
||||
0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x45, 0x78, 0x70, 0x72, 0x52, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65,
|
||||
0x72, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28,
|
||||
0x02, 0x52, 0x06, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0xd6, 0x02, 0x0a, 0x08, 0x50, 0x6c,
|
||||
0x61, 0x6e, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x40, 0x0a, 0x0b, 0x76, 0x65, 0x63, 0x74, 0x6f, 0x72,
|
||||
0x5f, 0x61, 0x6e, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x69,
|
||||
0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e,
|
||||
0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x41, 0x4e, 0x4e, 0x53, 0x48, 0x00, 0x52, 0x0a, 0x76, 0x65,
|
||||
0x63, 0x74, 0x6f, 0x72, 0x41, 0x6e, 0x6e, 0x73, 0x12, 0x39, 0x0a, 0x0a, 0x70, 0x72, 0x65, 0x64,
|
||||
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d,
|
||||
0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e,
|
||||
0x2e, 0x45, 0x78, 0x70, 0x72, 0x48, 0x00, 0x52, 0x0a, 0x70, 0x72, 0x65, 0x64, 0x69, 0x63, 0x61,
|
||||
0x74, 0x65, 0x73, 0x12, 0x38, 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x04, 0x20, 0x01,
|
||||
0x28, 0x0b, 0x32, 0x20, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x2e, 0x70, 0x6c, 0x61, 0x6e, 0x2e, 0x51, 0x75, 0x65, 0x72, 0x79, 0x50, 0x6c, 0x61, 0x6e,
|
||||
0x4e, 0x6f, 0x64, 0x65, 0x48, 0x00, 0x52, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x28, 0x0a,
|
||||
0x10, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x69, 0x64,
|
||||
0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x03, 0x52, 0x0e, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x46,
|
||||
0x69, 0x65, 0x6c, 0x64, 0x49, 0x64, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x79, 0x6e, 0x61, 0x6d,
|
||||
0x69, 0x63, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52,
|
||||
0x0d, 0x64, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x3a,
|
||||
0x0a, 0x07, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32,
|
||||
0x20, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70,
|
||||
0x6c, 0x61, 0x6e, 0x2e, 0x53, 0x63, 0x6f, 0x72, 0x65, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x52, 0x07, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x72, 0x73, 0x42, 0x06, 0x0a, 0x04, 0x6e, 0x6f,
|
||||
0x64, 0x65, 0x2a, 0xea, 0x01, 0x0a, 0x06, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a,
|
||||
0x07, 0x49, 0x6e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x47, 0x72,
|
||||
0x65, 0x61, 0x74, 0x65, 0x72, 0x54, 0x68, 0x61, 0x6e, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x47,
|
||||
0x72, 0x65, 0x61, 0x74, 0x65, 0x72, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x02, 0x12, 0x0c, 0x0a,
|
||||
0x08, 0x4c, 0x65, 0x73, 0x73, 0x54, 0x68, 0x61, 0x6e, 0x10, 0x03, 0x12, 0x0d, 0x0a, 0x09, 0x4c,
|
||||
0x65, 0x73, 0x73, 0x45, 0x71, 0x75, 0x61, 0x6c, 0x10, 0x04, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x71,
|
||||
0x75, 0x61, 0x6c, 0x10, 0x05, 0x12, 0x0c, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x45, 0x71, 0x75, 0x61,
|
||||
0x6c, 0x10, 0x06, 0x12, 0x0f, 0x0a, 0x0b, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x4d, 0x61, 0x74,
|
||||
0x63, 0x68, 0x10, 0x07, 0x12, 0x10, 0x0a, 0x0c, 0x50, 0x6f, 0x73, 0x74, 0x66, 0x69, 0x78, 0x4d,
|
||||
0x61, 0x74, 0x63, 0x68, 0x10, 0x08, 0x12, 0x09, 0x0a, 0x05, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10,
|
||||
0x09, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x10, 0x0a, 0x12, 0x06, 0x0a, 0x02,
|
||||
0x49, 0x6e, 0x10, 0x0b, 0x12, 0x09, 0x0a, 0x05, 0x4e, 0x6f, 0x74, 0x49, 0x6e, 0x10, 0x0c, 0x12,
|
||||
0x0d, 0x0a, 0x09, 0x54, 0x65, 0x78, 0x74, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0d, 0x12, 0x0f,
|
||||
0x0a, 0x0b, 0x50, 0x68, 0x72, 0x61, 0x73, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0e, 0x12,
|
||||
0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x0f, 0x2a,
|
||||
0x58, 0x0a, 0x0b, 0x41, 0x72, 0x69, 0x74, 0x68, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b,
|
||||
0x0a, 0x07, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41,
|
||||
0x64, 0x64, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a,
|
||||
0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12,
|
||||
0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64, 0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61,
|
||||
0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x10, 0x06, 0x2a, 0x7d, 0x0a, 0x0a, 0x56, 0x65, 0x63,
|
||||
0x74, 0x6f, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72,
|
||||
0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f,
|
||||
0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c,
|
||||
0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x02, 0x12, 0x12, 0x0a,
|
||||
0x0e, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10,
|
||||
0x03, 0x12, 0x15, 0x0a, 0x11, 0x53, 0x70, 0x61, 0x72, 0x73, 0x65, 0x46, 0x6c, 0x6f, 0x61, 0x74,
|
||||
0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x04, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x74, 0x38,
|
||||
0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x05, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68,
|
||||
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f,
|
||||
0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@ -3013,7 +3086,7 @@ func file_plan_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_plan_proto_enumTypes = make([]protoimpl.EnumInfo, 7)
|
||||
var file_plan_proto_msgTypes = make([]protoimpl.MessageInfo, 26)
|
||||
var file_plan_proto_msgTypes = make([]protoimpl.MessageInfo, 27)
|
||||
var file_plan_proto_goTypes = []interface{}{
|
||||
(OpType)(0), // 0: milvus.proto.plan.OpType
|
||||
(ArithOpType)(0), // 1: milvus.proto.plan.ArithOpType
|
||||
@ -3047,16 +3120,17 @@ var file_plan_proto_goTypes = []interface{}{
|
||||
(*Expr)(nil), // 29: milvus.proto.plan.Expr
|
||||
(*VectorANNS)(nil), // 30: milvus.proto.plan.VectorANNS
|
||||
(*QueryPlanNode)(nil), // 31: milvus.proto.plan.QueryPlanNode
|
||||
(*PlanNode)(nil), // 32: milvus.proto.plan.PlanNode
|
||||
(schemapb.DataType)(0), // 33: milvus.proto.schema.DataType
|
||||
(*ScoreFunction)(nil), // 32: milvus.proto.plan.ScoreFunction
|
||||
(*PlanNode)(nil), // 33: milvus.proto.plan.PlanNode
|
||||
(schemapb.DataType)(0), // 34: milvus.proto.schema.DataType
|
||||
}
|
||||
var file_plan_proto_depIdxs = []int32{
|
||||
8, // 0: milvus.proto.plan.GenericValue.array_val:type_name -> milvus.proto.plan.Array
|
||||
7, // 1: milvus.proto.plan.Array.array:type_name -> milvus.proto.plan.GenericValue
|
||||
33, // 2: milvus.proto.plan.Array.element_type:type_name -> milvus.proto.schema.DataType
|
||||
34, // 2: milvus.proto.plan.Array.element_type:type_name -> milvus.proto.schema.DataType
|
||||
9, // 3: milvus.proto.plan.QueryInfo.search_iterator_v2_info:type_name -> milvus.proto.plan.SearchIteratorV2Info
|
||||
33, // 4: milvus.proto.plan.ColumnInfo.data_type:type_name -> milvus.proto.schema.DataType
|
||||
33, // 5: milvus.proto.plan.ColumnInfo.element_type:type_name -> milvus.proto.schema.DataType
|
||||
34, // 4: milvus.proto.plan.ColumnInfo.data_type:type_name -> milvus.proto.schema.DataType
|
||||
34, // 5: milvus.proto.plan.ColumnInfo.element_type:type_name -> milvus.proto.schema.DataType
|
||||
11, // 6: milvus.proto.plan.ColumnExpr.info:type_name -> milvus.proto.plan.ColumnInfo
|
||||
11, // 7: milvus.proto.plan.ExistsExpr.info:type_name -> milvus.proto.plan.ColumnInfo
|
||||
7, // 8: milvus.proto.plan.ValueExpr.value:type_name -> milvus.proto.plan.GenericValue
|
||||
@ -3115,14 +3189,16 @@ var file_plan_proto_depIdxs = []int32{
|
||||
29, // 61: milvus.proto.plan.VectorANNS.predicates:type_name -> milvus.proto.plan.Expr
|
||||
10, // 62: milvus.proto.plan.VectorANNS.query_info:type_name -> milvus.proto.plan.QueryInfo
|
||||
29, // 63: milvus.proto.plan.QueryPlanNode.predicates:type_name -> milvus.proto.plan.Expr
|
||||
30, // 64: milvus.proto.plan.PlanNode.vector_anns:type_name -> milvus.proto.plan.VectorANNS
|
||||
29, // 65: milvus.proto.plan.PlanNode.predicates:type_name -> milvus.proto.plan.Expr
|
||||
31, // 66: milvus.proto.plan.PlanNode.query:type_name -> milvus.proto.plan.QueryPlanNode
|
||||
67, // [67:67] is the sub-list for method output_type
|
||||
67, // [67:67] is the sub-list for method input_type
|
||||
67, // [67:67] is the sub-list for extension type_name
|
||||
67, // [67:67] is the sub-list for extension extendee
|
||||
0, // [0:67] is the sub-list for field type_name
|
||||
29, // 64: milvus.proto.plan.ScoreFunction.filter:type_name -> milvus.proto.plan.Expr
|
||||
30, // 65: milvus.proto.plan.PlanNode.vector_anns:type_name -> milvus.proto.plan.VectorANNS
|
||||
29, // 66: milvus.proto.plan.PlanNode.predicates:type_name -> milvus.proto.plan.Expr
|
||||
31, // 67: milvus.proto.plan.PlanNode.query:type_name -> milvus.proto.plan.QueryPlanNode
|
||||
32, // 68: milvus.proto.plan.PlanNode.scorers:type_name -> milvus.proto.plan.ScoreFunction
|
||||
69, // [69:69] is the sub-list for method output_type
|
||||
69, // [69:69] is the sub-list for method input_type
|
||||
69, // [69:69] is the sub-list for extension type_name
|
||||
69, // [69:69] is the sub-list for extension extendee
|
||||
0, // [0:69] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_plan_proto_init() }
|
||||
@ -3432,6 +3508,18 @@ func file_plan_proto_init() {
|
||||
}
|
||||
}
|
||||
file_plan_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*ScoreFunction); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_plan_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*PlanNode); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -3471,7 +3559,7 @@ func file_plan_proto_init() {
|
||||
(*Expr_NullExpr)(nil),
|
||||
(*Expr_RandomSampleExpr)(nil),
|
||||
}
|
||||
file_plan_proto_msgTypes[25].OneofWrappers = []interface{}{
|
||||
file_plan_proto_msgTypes[26].OneofWrappers = []interface{}{
|
||||
(*PlanNode_VectorAnns)(nil),
|
||||
(*PlanNode_Predicates)(nil),
|
||||
(*PlanNode_Query)(nil),
|
||||
@ -3482,7 +3570,7 @@ func file_plan_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_plan_proto_rawDesc,
|
||||
NumEnums: 7,
|
||||
NumMessages: 26,
|
||||
NumMessages: 27,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
|
||||
@ -190,6 +190,18 @@ func GetAttrByKeyFromRepeatedKV(key string, kvs []*commonpb.KeyValuePair) (strin
|
||||
return "", fmt.Errorf("key %s not found", key)
|
||||
}
|
||||
|
||||
// TryGetAttrByKeyFromRepeatedKV return the value corresponding to key in kv pair
|
||||
// return false if key not exist
|
||||
func TryGetAttrByKeyFromRepeatedKV(key string, kvs []*commonpb.KeyValuePair) (string, bool) {
|
||||
for _, kv := range kvs {
|
||||
if kv.Key == key {
|
||||
return kv.Value, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// CheckCtxValid check if the context is valid
|
||||
func CheckCtxValid(ctx context.Context) bool {
|
||||
return ctx.Err() != context.DeadlineExceeded && ctx.Err() != context.Canceled
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user