enhance: support random score for boost function score (#44214)

Also support setting the function mode and boost mode when running a search with boost.

RandomScore supports producing a random function score in the range [0, weight).
FunctionMode decides how the boost scores of multiple boost functions are
combined into a single boost score.
BoostMode decides how the final score is calculated from the original score
and the boost score.
relate: https://github.com/milvus-io/milvus/issues/43867

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2025-09-24 17:50:04 +08:00 committed by GitHub
parent 13c3b0b909
commit 1b20e956be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 2251 additions and 687 deletions

View File

@ -50,6 +50,7 @@ add_subdirectory( clustering )
add_subdirectory( exec )
add_subdirectory( bitset )
add_subdirectory( futures )
add_subdirectory( rescores )
milvus_add_pkg_config("milvus_core")
@ -67,6 +68,7 @@ add_library(milvus_core SHARED
$<TARGET_OBJECTS:milvus_exec>
$<TARGET_OBJECTS:milvus_bitset>
$<TARGET_OBJECTS:milvus_futures>
$<TARGET_OBJECTS:milvus_rescores>
)
set(LINK_TARGETS

View File

@ -19,6 +19,7 @@
#include <string>
#include <map>
#include <google/protobuf/text_format.h>
#include <google/protobuf/repeated_field.h>
#include "pb/schema.pb.h"
#include "common/EasyAssert.h"
@ -26,6 +27,11 @@
using std::string;
namespace milvus {
template <typename T>
using ProtoRepeated = google::protobuf::RepeatedPtrField<T>;
using ProtoParams = ProtoRepeated<proto::common::KeyValuePair>;
static std::map<string, string>
RepeatedKeyValToMap(
const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>&

View File

@ -17,7 +17,9 @@
#include "RescoresNode.h"
#include <cstddef>
#include "exec/operator/Utils.h"
#include "log/Log.h"
#include "monitor/Monitor.h"
#include "pb/plan.pb.h"
namespace milvus::exec {
@ -31,6 +33,7 @@ PhyRescoresNode::PhyRescoresNode(
scorer->id(),
"PhyRescoresNode") {
scorers_ = scorer->scorers();
option_ = scorer->option();
};
void
@ -62,13 +65,15 @@ PhyRescoresNode::GetOutput() {
auto query_context_ = exec_context->get_query_context();
auto query_info = exec_context->get_query_config();
milvus::SearchResult search_result = query_context_->get_search_result();
auto segment = query_context_->get_segment();
auto op_ctx = query_context_->get_op_context();
// prepare segment offset
FixedVector<int32_t> offsets;
std::vector<size_t> offset_idx;
for (size_t i = 0; i < search_result.seg_offsets_.size(); i++) {
// remain offset will be -1 if result count not enough (less than topk)
// remain offset will be placeholder(-1) if result count not enough (less than topk)
// skip placeholder offset
if (search_result.seg_offsets_[i] >= 0) {
offsets.push_back(
@ -83,14 +88,15 @@ PhyRescoresNode::GetOutput() {
return input_;
}
std::vector<std::optional<float>> boost_scores(offsets.size());
auto function_mode = option_->function_mode();
for (auto& scorer : scorers_) {
auto filter = scorer->filter();
// rescore for all result if no filter
// boost for all result if no filter
if (!filter) {
for (auto i = 0; i < offsets.size(); i++) {
search_result.distances_[offset_idx[i]] =
scorer->rescore(search_result.distances_[offset_idx[i]]);
}
scorer->batch_score(
op_ctx, segment, function_mode, offsets, boost_scores);
continue;
}
@ -116,13 +122,12 @@ PhyRescoresNode::GetOutput() {
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
auto col_vec_size = col_vec->size();
TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size);
Assert(bitsetview.size() == offsets.size());
for (auto i = 0; i < offsets.size(); i++) {
if (bitsetview[i] > 0) {
search_result.distances_[offset_idx[i]] = scorer->rescore(
search_result.distances_[offset_idx[i]]);
}
}
scorer->batch_score(op_ctx,
segment,
function_mode,
offsets,
bitsetview,
boost_scores);
} else {
// query all segment if expr not native
expr_set->Eval(0, 1, true, eval_ctx, results);
@ -133,13 +138,34 @@ PhyRescoresNode::GetOutput() {
auto col_vec_size = col_vec->size();
TargetBitmapView view(col_vec->GetRawData(), col_vec_size);
bitset.append(view);
scorer->batch_score(
op_ctx, segment, function_mode, offsets, bitset, boost_scores);
}
}
// calculate final score
auto boost_mode = option_->boost_mode();
switch (boost_mode) {
case proto::plan::BoostModeMultiply:
for (auto i = 0; i < offsets.size(); i++) {
if (bitset[offsets[i]] > 0) {
search_result.distances_[offset_idx[i]] = scorer->rescore(
search_result.distances_[offset_idx[i]]);
if (boost_scores[i].has_value()) {
search_result.distances_[offset_idx[i]] *=
boost_scores[i].value();
}
}
}
break;
case proto::plan::BoostModeSum:
for (auto i = 0; i < offsets.size(); i++) {
if (boost_scores[i].has_value()) {
search_result.distances_[offset_idx[i]] +=
boost_scores[i].value();
}
}
break;
default:
ThrowInfo(ErrorCode::UnexpectedError,
fmt::format("unknown boost boost mode: {}", boost_mode));
}
knowhere::MetricType metric_type = query_context_->get_metric_type();

View File

@ -72,6 +72,7 @@ class PhyRescoresNode : public Operator {
private:
std::vector<std::shared_ptr<rescores::Scorer>> scorers_;
const proto::plan::ScoreOption* option_;
bool is_finished_{false};
};
} // namespace milvus::exec

View File

@ -32,7 +32,7 @@ TEST(Rescorer, Normal) {
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
auto str_fid = schema->AddDebugField("string1", DataType::VARCHAR);
auto str_fid = schema->AddDebugField("string", DataType::VARCHAR);
auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
schema->set_primary_field_id(str_fid);
size_t N = 50;
@ -142,4 +142,155 @@ TEST(Rescorer, Normal) {
auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
}
// random function with seed
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Int8
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
int64_val: -1
>
upper_value: <
int64_val: 100
>
>
>
query_info: <
topk: 10
metric_type: "L2"
search_params: "{\"ef\": 50}"
>
placeholder_tag: "$0"
>
scorers: <
type: 1
weight: 1
params: <key: "seed", value: "123">
>)";
proto::plan::PlanNode plan_node;
auto ok =
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
}
// random function with field as random seed
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Int8
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
int64_val: -1
>
upper_value: <
int64_val: 100
>
>
>
query_info: <
topk: 10
metric_type: "L2"
search_params: "{\"ef\": 50}"
>
placeholder_tag: "$0"
>
scorers: <
type: 1
weight: 1
params: <key: "field", value: "int64">
>)";
proto::plan::PlanNode plan_node;
auto ok =
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
}
// random function with field and seed
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_range_expr: <
column_info: <
field_id: 101
data_type: Int8
>
lower_inclusive: true,
upper_inclusive: false,
lower_value: <
int64_val: -1
>
upper_value: <
int64_val: 100
>
>
>
query_info: <
topk: 10
metric_type: "L2"
search_params: "{\"ef\": 50}"
>
placeholder_tag: "$0"
>
scorers: <
type: 1
weight: 1
params: <key: "seed", value: "123">
params: <key: "field", value: "int64">
>)";
proto::plan::PlanNode plan_node;
auto ok =
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
auto search_result_same_seed =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
// should return same score when use same seed
for (auto i = 0; i < 10; i++) {
AssertInfo(search_result->distances_[i] ==
search_result_same_seed->distances_[i],
"distance not equal %f:%f",
search_result->distances_[i],
search_result_same_seed->distances_[i]);
}
}
}

View File

@ -24,6 +24,7 @@
#include "common/Vector.h"
#include "expr/ITypeExpr.h"
#include "common/EasyAssert.h"
#include "pb/plan.pb.h"
#include "segcore/SegmentInterface.h"
#include "plan/PlanNodeIdGenerator.h"
#include "rescores/Scorer.h"
@ -442,10 +443,12 @@ class RescoresNode : public PlanNode {
public:
RescoresNode(
const PlanNodeId& id,
const std::vector<std::shared_ptr<rescores::Scorer>> scorers,
const std::vector<std::shared_ptr<rescores::Scorer>>& scorers,
const proto::plan::ScoreOption& option,
const std::vector<PlanNodePtr>& sources = std::vector<PlanNodePtr>{})
: PlanNode(id),
scorers_(std::move(scorers)),
option_(std::move(option)),
sources_{std::move(sources)} {
}
@ -459,6 +462,11 @@ class RescoresNode : public PlanNode {
return sources_;
}
const proto::plan::ScoreOption*
option() const {
return &option_;
}
const std::vector<std::shared_ptr<rescores::Scorer>>&
scorers() const {
return scorers_;
@ -476,6 +484,7 @@ class RescoresNode : public PlanNode {
}
private:
const proto::plan::ScoreOption option_;
const std::vector<PlanNodePtr> sources_;
const std::vector<std::shared_ptr<rescores::Scorer>> scorers_;
};

View File

@ -21,6 +21,7 @@
#include "common/EasyAssert.h"
#include "exec/expression/function/FunctionFactory.h"
#include "log/Log.h"
#include "expr/ITypeExpr.h"
#include "pb/plan.pb.h"
#include "query/Utils.h"
#include "knowhere/comp/materialized_view.h"
@ -215,7 +216,10 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
}
plannode = std::make_shared<milvus::plan::RescoresNode>(
milvus::plan::GetNextPlanNodeId(), std::move(scorers), sources);
milvus::plan::GetNextPlanNodeId(),
std::move(scorers),
plan_node_proto.score_option(),
sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}
@ -629,12 +633,21 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
std::shared_ptr<rescores::Scorer>
ProtoParser::ParseScorer(const proto::plan::ScoreFunction& function) {
expr::TypedExprPtr expr = nullptr;
if (function.has_filter()) {
auto expr = ParseExprs(function.filter());
return std::make_shared<rescores::WeightScorer>(expr,
function.weight());
expr = ParseExprs(function.filter());
}
switch (function.type()) {
case proto::plan::FunctionTypeWeight:
return std::make_shared<rescores::WeightScorer>(expr,
function.weight());
case proto::plan::FunctionTypeRandom:
return std::make_shared<rescores::RandomScorer>(
expr, function.weight(), function.params());
default:
ThrowInfo(UnexpectedError, "unknown function type");
}
return std::make_shared<rescores::WeightScorer>(nullptr, function.weight());
}
} // namespace milvus::query

View File

@ -0,0 +1,13 @@
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License
# Project helper; presumably gathers the sources under this directory into
# SOURCE_FILES (the variable consumed on the next line) — confirm against the
# helper's definition.
add_source_at_current_directory_recursively()
# Build the rescores sources as an OBJECT library named milvus_rescores.
add_library(milvus_rescores OBJECT ${SOURCE_FILES})

View File

@ -0,0 +1,466 @@
//-----------------------------------------------------------------------------
// MurmurHash3 was written by Austin Appleby, and is placed in the public
// domain. The author hereby disclaims copyright to this source code.
// Note - The x86 and x64 versions do _not_ produce the same results, as the
// algorithms are optimized for their respective platforms. You can still
// compile and run any of them on any platform, but your performance with the
// non-native version will be less than optimal.
#include "Murmur3.h"
//-----------------------------------------------------------------------------
// Platform-specific functions and macros
#ifdef __GNUC__
#define FORCE_INLINE __attribute__((always_inline)) inline
#else
#define FORCE_INLINE inline
#endif

// Rotates the 32-bit value `x` left by `r` bits.
// The rotate count is reduced modulo 32 so that neither shift ever reaches
// the full width of the operand; a count of 0 therefore returns `x`
// unchanged instead of evaluating the undefined `x >> 32`.
static FORCE_INLINE uint32_t
rotl32(uint32_t x, int8_t r) {
    const unsigned shift = static_cast<unsigned>(r) & 31u;
    return (x << shift) | (x >> ((32u - shift) & 31u));
}

// Rotates the 64-bit value `x` left by `r` bits.
// Same masking as rotl32, modulo 64.
static FORCE_INLINE uint64_t
rotl64(uint64_t x, int8_t r) {
    const unsigned shift = static_cast<unsigned>(r) & 63u;
    return (x << shift) | (x >> ((64u - shift) & 63u));
}
#define ROTL32(x, y) rotl32(x, y)
#define ROTL64(x, y) rotl64(x, y)
#define BIG_CONSTANT(x) (x##LLU)
//-----------------------------------------------------------------------------
// Block read - if your platform needs to do endian-swapping or can only
// handle aligned reads, do the conversion here
#define getblock(p, i) (p[i])
//-----------------------------------------------------------------------------
// Finalization mix - force all bits of a hash block to avalanche
// 32-bit finalizer: alternates xor-with-right-shift and multiplication so
// that every input bit influences every output bit.
static FORCE_INLINE uint32_t
fmix32(uint32_t h) {
    h = (h ^ (h >> 16)) * 0x85ebca6b;
    h = (h ^ (h >> 13)) * 0xc2b2ae35;
    return h ^ (h >> 16);
}
//----------
// 64-bit finalizer: alternates xor-with-right-shift and multiplication so
// that every input bit influences every output bit.
static FORCE_INLINE uint64_t
fmix64(uint64_t k) {
    k = (k ^ (k >> 33)) * BIG_CONSTANT(0xff51afd7ed558ccd);
    k = (k ^ (k >> 33)) * BIG_CONSTANT(0xc4ceb9fe1a85ec53);
    return k ^ (k >> 33);
}
//-----------------------------------------------------------------------------
// 32-bit MurmurHash3 (x86 variant).
//   key  - start of the input bytes
//   len  - number of input bytes
//   seed - initial value of the hash state
//   out  - receives the result as a single uint32_t
void
MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out) {
const uint8_t* data = (const uint8_t*)key;
const int nblocks = len / 4;
int i;
uint32_t h1 = seed;
uint32_t c1 = 0xcc9e2d51;
uint32_t c2 = 0x1b873593;
//----------
// body
// `blocks` points just past the last full 4-byte block; the index runs from
// -nblocks up to -1, so the blocks are still consumed front to back.
// NOTE(review): getblock indexes the casted pointer directly; presumably
// relies on the platform tolerating unaligned 32-bit reads - confirm.
const uint32_t* blocks = (const uint32_t*)(data + nblocks * 4);
for (i = -nblocks; i; i++) {
uint32_t k1 = getblock(blocks, i);
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
h1 = ROTL32(h1, 13);
h1 = h1 * 5 + 0xe6546b64;
}
//----------
// tail
// Mixes in the trailing (len & 3) bytes. The cases have no `break`, so a
// higher case falls through into the lower ones.
const uint8_t* tail = (const uint8_t*)(data + nblocks * 4);
uint32_t k1 = 0;
switch (len & 3) {
case 3:
k1 ^= tail[2] << 16;
case 2:
k1 ^= tail[1] << 8;
case 1:
k1 ^= tail[0];
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
};
//----------
// finalization
// Fold the length into the state, then run the final avalanche mix.
h1 ^= len;
h1 = fmix32(h1);
*(uint32_t*)out = h1;
}
//-----------------------------------------------------------------------------
// 128-bit MurmurHash3 (x86 variant, four 32-bit lanes).
//   key  - start of the input bytes
//   len  - number of input bytes
//   seed - initial value of all four lanes
//   out  - receives the result as four consecutive uint32_t values
void
MurmurHash3_x86_128(const void* key, const int len, uint32_t seed, void* out) {
const uint8_t* data = (const uint8_t*)key;
const int nblocks = len / 16;
int i;
uint32_t h1 = seed;
uint32_t h2 = seed;
uint32_t h3 = seed;
uint32_t h4 = seed;
uint32_t c1 = 0x239b961b;
uint32_t c2 = 0xab0e9789;
uint32_t c3 = 0x38b34ae5;
uint32_t c4 = 0xa1e38b93;
//----------
// body
// `blocks` points just past the last full 16-byte block; the index runs from
// -nblocks up to -1, so the blocks are still consumed front to back.
const uint32_t* blocks = (const uint32_t*)(data + nblocks * 16);
for (i = -nblocks; i; i++) {
uint32_t k1 = getblock(blocks, i * 4 + 0);
uint32_t k2 = getblock(blocks, i * 4 + 1);
uint32_t k3 = getblock(blocks, i * 4 + 2);
uint32_t k4 = getblock(blocks, i * 4 + 3);
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
h1 = ROTL32(h1, 19);
h1 += h2;
h1 = h1 * 5 + 0x561ccd1b;
k2 *= c2;
k2 = ROTL32(k2, 16);
k2 *= c3;
h2 ^= k2;
h2 = ROTL32(h2, 17);
h2 += h3;
h2 = h2 * 5 + 0x0bcaa747;
k3 *= c3;
k3 = ROTL32(k3, 17);
k3 *= c4;
h3 ^= k3;
h3 = ROTL32(h3, 15);
h3 += h4;
h3 = h3 * 5 + 0x96cd1c35;
k4 *= c4;
k4 = ROTL32(k4, 18);
k4 *= c1;
h4 ^= k4;
h4 = ROTL32(h4, 13);
h4 += h1;
h4 = h4 * 5 + 0x32ac3b17;
}
//----------
// tail
// Mixes in the trailing (len & 15) bytes, four bytes per lane. The cases
// have no `break`, so a higher case falls through into the lower ones.
const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
uint32_t k1 = 0;
uint32_t k2 = 0;
uint32_t k3 = 0;
uint32_t k4 = 0;
switch (len & 15) {
case 15:
k4 ^= tail[14] << 16;
case 14:
k4 ^= tail[13] << 8;
case 13:
k4 ^= tail[12] << 0;
k4 *= c4;
k4 = ROTL32(k4, 18);
k4 *= c1;
h4 ^= k4;
case 12:
k3 ^= tail[11] << 24;
case 11:
k3 ^= tail[10] << 16;
case 10:
k3 ^= tail[9] << 8;
case 9:
k3 ^= tail[8] << 0;
k3 *= c3;
k3 = ROTL32(k3, 17);
k3 *= c4;
h3 ^= k3;
case 8:
k2 ^= tail[7] << 24;
case 7:
k2 ^= tail[6] << 16;
case 6:
k2 ^= tail[5] << 8;
case 5:
k2 ^= tail[4] << 0;
k2 *= c2;
k2 = ROTL32(k2, 16);
k2 *= c3;
h2 ^= k2;
case 4:
k1 ^= tail[3] << 24;
case 3:
k1 ^= tail[2] << 16;
case 2:
k1 ^= tail[1] << 8;
case 1:
k1 ^= tail[0] << 0;
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
};
//----------
// finalization
// Fold the length into every lane, cross-add the lanes, avalanche each one
// with fmix32, then cross-add again before writing the four words out.
h1 ^= len;
h2 ^= len;
h3 ^= len;
h4 ^= len;
h1 += h2;
h1 += h3;
h1 += h4;
h2 += h1;
h3 += h1;
h4 += h1;
h1 = fmix32(h1);
h2 = fmix32(h2);
h3 = fmix32(h3);
h4 = fmix32(h4);
h1 += h2;
h1 += h3;
h1 += h4;
h2 += h1;
h3 += h1;
h4 += h1;
((uint32_t*)out)[0] = h1;
((uint32_t*)out)[1] = h2;
((uint32_t*)out)[2] = h3;
((uint32_t*)out)[3] = h4;
}
//-----------------------------------------------------------------------------
// 128-bit MurmurHash3 (x64 variant, two 64-bit lanes).
//   key  - start of the input bytes
//   len  - number of input bytes
//   seed - initial value of both lanes
//   out  - receives the result as two consecutive uint64_t values
void
MurmurHash3_x64_128(const void* key,
const int len,
const uint32_t seed,
void* out) {
const uint8_t* data = (const uint8_t*)key;
const int nblocks = len / 16;
int i;
uint64_t h1 = seed;
uint64_t h2 = seed;
uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5);
uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f);
//----------
// body
// Unlike the x86 variants, `blocks` starts at the beginning of the input and
// is indexed forward; each 16-byte block supplies two 64-bit words.
const uint64_t* blocks = (const uint64_t*)(data);
for (i = 0; i < nblocks; i++) {
uint64_t k1 = getblock(blocks, i * 2 + 0);
uint64_t k2 = getblock(blocks, i * 2 + 1);
k1 *= c1;
k1 = ROTL64(k1, 31);
k1 *= c2;
h1 ^= k1;
h1 = ROTL64(h1, 27);
h1 += h2;
h1 = h1 * 5 + 0x52dce729;
k2 *= c2;
k2 = ROTL64(k2, 33);
k2 *= c1;
h2 ^= k2;
h2 = ROTL64(h2, 31);
h2 += h1;
h2 = h2 * 5 + 0x38495ab5;
}
//----------
// tail
// Mixes in the trailing (len & 15) bytes, eight bytes per lane. The cases
// have no `break`, so a higher case falls through into the lower ones.
const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
uint64_t k1 = 0;
uint64_t k2 = 0;
switch (len & 15) {
case 15:
k2 ^= (uint64_t)(tail[14]) << 48;
case 14:
k2 ^= (uint64_t)(tail[13]) << 40;
case 13:
k2 ^= (uint64_t)(tail[12]) << 32;
case 12:
k2 ^= (uint64_t)(tail[11]) << 24;
case 11:
k2 ^= (uint64_t)(tail[10]) << 16;
case 10:
k2 ^= (uint64_t)(tail[9]) << 8;
case 9:
k2 ^= (uint64_t)(tail[8]) << 0;
k2 *= c2;
k2 = ROTL64(k2, 33);
k2 *= c1;
h2 ^= k2;
case 8:
k1 ^= (uint64_t)(tail[7]) << 56;
case 7:
k1 ^= (uint64_t)(tail[6]) << 48;
case 6:
k1 ^= (uint64_t)(tail[5]) << 40;
case 5:
k1 ^= (uint64_t)(tail[4]) << 32;
case 4:
k1 ^= (uint64_t)(tail[3]) << 24;
case 3:
k1 ^= (uint64_t)(tail[2]) << 16;
case 2:
k1 ^= (uint64_t)(tail[1]) << 8;
case 1:
k1 ^= (uint64_t)(tail[0]) << 0;
k1 *= c1;
k1 = ROTL64(k1, 31);
k1 *= c2;
h1 ^= k1;
};
//----------
// finalization
// Fold the length into both lanes, cross-add, avalanche each lane with
// fmix64, then cross-add again before writing the two words out.
h1 ^= len;
h2 ^= len;
h1 += h2;
h2 += h1;
h1 = fmix64(h1);
h2 = fmix64(h2);
h1 += h2;
h2 += h1;
((uint64_t*)out)[0] = h1;
((uint64_t*)out)[1] = h2;
}
//-----------------------------------------------------------------------------
// Fixed-size variant of the x64 128-bit mix that hashes two 64-bit integers
// without going through a byte buffer.
//
// The input is treated as exactly one 16-byte block whose first word is `key`
// and whose second word is `seed` (block count = 1, length = 16, no tail).
// The lanes start as h1 = key ^ seed and h2 = seed, and only the first 64-bit
// word of the final state is returned.
//
//   key  - value to hash
//   seed - second block word, also used to initialise both lanes
//   returns the first 64-bit word of the mixed state
uint64_t
MurmurHash3_x64_64_Special(const uint64_t key, const uint64_t seed) {
    uint64_t h1 = key ^ seed;
    uint64_t h2 = seed;
    const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5);
    const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f);

    //----------
    // body
    // 1 block
    {
        uint64_t k1 = key;
        uint64_t k2 = seed;
        k1 *= c1;
        k1 = ROTL64(k1, 31);
        k1 *= c2;
        h1 ^= k1;
        h1 = ROTL64(h1, 27);
        h1 += h2;
        h1 = h1 * 5 + 0x52dce729;
        k2 *= c2;
        k2 = ROTL64(k2, 33);
        k2 *= c1;
        h2 ^= k2;
        h2 = ROTL64(h2, 31);
        h2 += h1;
        h2 = h2 * 5 + 0x38495ab5;
    }

    //----------
    // No tail

    //----------
    // finalization
    // 16 is the fixed input length in bytes.
    h1 ^= 16;
    h2 ^= 16;
    h1 += h2;
    h2 += h1;
    h1 = fmix64(h1);
    h2 = fmix64(h2);
    h1 += h2;
    h2 += h1;
    return h1;
}

View File

@ -0,0 +1,46 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
//-----------------------------------------------------------------------------
// MurmurHash3 was written by Austin Appleby, and is placed in the
// public domain. The author hereby disclaims copyright to this source
// code.
#ifndef _MURMURHASH3_H_
#define _MURMURHASH3_H_
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
//-----------------------------------------------------------------------------
// Hashes `len` bytes at `key` with `seed`; writes one uint32_t to `out`.
void
MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out);
// Hashes `len` bytes at `key` with `seed`; writes four uint32_t values
// (16 bytes) to `out`.
void
MurmurHash3_x86_128(const void* key, int len, uint32_t seed, void* out);
// Hashes `len` bytes at `key` with `seed`; writes two uint64_t values
// (16 bytes) to `out`.
void
MurmurHash3_x64_128(const void* key, int len, uint32_t seed, void* out);
// Mixes a 64-bit `key` with a 64-bit `seed` as a single 16-byte block and
// returns the first 64-bit word of the result.
uint64_t
MurmurHash3_x64_64_Special(const uint64_t key, const uint64_t seed);
//-----------------------------------------------------------------------------
#ifdef __cplusplus
}
#endif
#endif // _MURMURHASH3_H_

View File

@ -0,0 +1,202 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstddef>
#include <optional>
#include <random>
#include "common/Types.h"
#include "expr/ITypeExpr.h"
#include "Scorer.h"
#include "Utils.h"
#include "log/Log.h"
#include "rescores/Murmur3.h"
namespace milvus::rescores {
// Applies this scorer's weight to every result whose filter bit is set.
// `bitmap` is positional: bitmap[i] is the filter outcome for offsets[i], so
// both must have the same length. `op_ctx` and `segment` are not used here.
void
WeightScorer::batch_score(milvus::OpContext* op_ctx,
                          const segcore::SegmentInternalInterface* segment,
                          const proto::plan::FunctionMode& mode,
                          const FixedVector<int32_t>& offsets,
                          const TargetBitmapView& bitmap,
                          std::vector<std::optional<float>>& boost_scores) {
    Assert(bitmap.size() == offsets.size());
    const auto total = offsets.size();
    for (size_t pos = 0; pos < total; ++pos) {
        if (!(bitmap[pos] > 0)) {
            continue;
        }
        set_score(boost_scores[pos], mode);
    }
}
// Applies this scorer's weight to every result whose filter bit is set.
// Here `bitmap` is indexed by segment offset: the filter outcome for
// offsets[i] is bitmap[offsets[i]]. `op_ctx` and `segment` are not used.
void
WeightScorer::batch_score(milvus::OpContext* op_ctx,
                          const segcore::SegmentInternalInterface* segment,
                          const proto::plan::FunctionMode& mode,
                          const FixedVector<int32_t>& offsets,
                          const TargetBitmap& bitmap,
                          std::vector<std::optional<float>>& boost_scores) {
    const auto total = offsets.size();
    for (size_t pos = 0; pos < total; ++pos) {
        const auto segment_offset = offsets[pos];
        if (!(bitmap[segment_offset] > 0)) {
            continue;
        }
        set_score(boost_scores[pos], mode);
    }
}
// Unfiltered overload: applies this scorer's weight to the boost score of
// every entry in `offsets`. `op_ctx` and `segment` are not used.
void
WeightScorer::batch_score(milvus::OpContext* op_ctx,
                          const segcore::SegmentInternalInterface* segment,
                          const proto::plan::FunctionMode& mode,
                          const FixedVector<int32_t>& offsets,
                          std::vector<std::optional<float>>& boost_scores) {
    const auto total = offsets.size();
    size_t pos = 0;
    while (pos < total) {
        set_score(boost_scores[pos], mode);
        ++pos;
    }
}
// Folds this scorer's weight into one accumulated boost score.
// An empty slot simply takes the weight; an occupied slot is combined with
// the weight according to `mode`.
void
WeightScorer::set_score(std::optional<float>& score,
                        const proto::plan::FunctionMode& mode) {
    if (score.has_value()) {
        score = function_score_merge(score.value(), weight_, mode);
        return;
    }
    score = weight_;
}
// Random-scores the results whose filter bit is set.
// `bitmap` is positional: bitmap[i] is the filter outcome for offsets[i].
// The selected segment offsets and their positions in `boost_scores` are
// collected first and then handed to random_score in one call.
void
RandomScorer::batch_score(milvus::OpContext* op_ctx,
                          const segcore::SegmentInternalInterface* segment,
                          const proto::plan::FunctionMode& mode,
                          const FixedVector<int32_t>& offsets,
                          const TargetBitmapView& bitmap,
                          std::vector<std::optional<float>>& boost_scores) {
    Assert(bitmap.size() == offsets.size());
    const auto total = offsets.size();

    FixedVector<int64_t> target_offsets;
    FixedVector<int> idx;
    target_offsets.reserve(total);
    idx.reserve(total);

    for (size_t pos = 0; pos < total; ++pos) {
        if (!(bitmap[pos] > 0)) {
            continue;
        }
        target_offsets.push_back(static_cast<int64_t>(offsets[pos]));
        idx.push_back(static_cast<int>(pos));
    }

    // Nothing passed the filter, so there is nothing to score.
    if (!target_offsets.empty()) {
        random_score(op_ctx, segment, mode, target_offsets, &idx, boost_scores);
    }
}
// Random-scores the results whose filter bit is set.
// Here `bitmap` is indexed by segment offset: the filter outcome for
// offsets[i] is bitmap[offsets[i]]. The selected segment offsets and their
// positions in `boost_scores` are collected first and then handed to
// random_score in one call.
void
RandomScorer::batch_score(milvus::OpContext* op_ctx,
                          const segcore::SegmentInternalInterface* segment,
                          const proto::plan::FunctionMode& mode,
                          const FixedVector<int32_t>& offsets,
                          const TargetBitmap& bitmap,
                          std::vector<std::optional<float>>& boost_scores) {
    const auto total = offsets.size();

    FixedVector<int64_t> target_offsets;
    FixedVector<int> idx;
    target_offsets.reserve(total);
    idx.reserve(total);

    for (size_t pos = 0; pos < total; ++pos) {
        if (!(bitmap[offsets[pos]] > 0)) {
            continue;
        }
        target_offsets.push_back(static_cast<int64_t>(offsets[pos]));
        idx.push_back(static_cast<int>(pos));
    }

    // Nothing passed the filter, so there is nothing to score.
    if (!target_offsets.empty()) {
        random_score(op_ctx, segment, mode, target_offsets, &idx, boost_scores);
    }
}
// Unfiltered overload: random-scores every entry in `offsets`.
// The offsets are widened to int64_t and passed to random_score with a null
// index list, so result i maps directly onto boost_scores[i].
void
RandomScorer::batch_score(milvus::OpContext* op_ctx,
                          const segcore::SegmentInternalInterface* segment,
                          const proto::plan::FunctionMode& mode,
                          const FixedVector<int32_t>& offsets,
                          std::vector<std::optional<float>>& boost_scores) {
    const auto total = offsets.size();
    FixedVector<int64_t> target_offsets;
    target_offsets.reserve(total);
    for (size_t pos = 0; pos < total; ++pos) {
        target_offsets.push_back(static_cast<int64_t>(offsets[pos]));
    }
    random_score(op_ctx, segment, mode, target_offsets, nullptr, boost_scores);
}
// Computes a hash-derived pseudo-random value for each target offset and
// merges it (scaled by the weight, see set_score) into `boost_scores`.
//
//   target_offsets - segment offsets of the rows to score
//   idx            - when non-null, idx[i] is the slot in `boost_scores` that
//                    belongs to target_offsets[i]; when null, slot i is used
//   boost_scores   - accumulated boost score per search result
//
// Hash input:
//   * seed field configured (field_ != -1): the int64 value of that field for
//     the row, hashed together with seed_;
//   * otherwise: the row's segment offset plus the segment id, hashed
//     together with seed_.
void
RandomScorer::random_score(milvus::OpContext* op_ctx,
                           const segcore::SegmentInternalInterface* segment,
                           const proto::plan::FunctionMode& mode,
                           const FixedVector<int64_t>& target_offsets,
                           const FixedVector<int>* idx,
                           std::vector<std::optional<float>>& boost_scores) {
    // Resolves the boost_scores slot that belongs to the i-th target.
    auto slot_of = [&](size_t i) -> std::optional<float>& {
        if (idx == nullptr) {
            return boost_scores[i];
        }
        return boost_scores[idx->at(i)];
    };

    if (field_.get() != -1) {
        auto array = segment->bulk_subscript(
            op_ctx, field_, target_offsets.data(), target_offsets.size());
        AssertInfo(array->has_scalars(), "seed field must be scalar");
        AssertInfo(array->scalars().has_long_data(),
                   "now only support int64 field as seed");
        // TODO: Support varchar and int32 field as random field.
        // Bind by reference: `array` owns the data for the whole loop, so
        // there is no need to copy the repeated field.
        const auto& values = array->scalars().long_data().data();
        const int count = values.size();
        for (int i = 0; i < count; i++) {
            const auto random_value =
                hash_to_double(MurmurHash3_x64_64_Special(values[i], seed_));
            set_score(random_value, slot_of(static_cast<size_t>(i)), mode);
        }
    } else {
        // if not set field, use offset and seed to hash.
        // The segment id does not change per row, so fetch it once.
        const auto segment_id = segment->get_segment_id();
        const size_t count = target_offsets.size();
        for (size_t i = 0; i < count; i++) {
            const double random_value = hash_to_double(
                MurmurHash3_x64_64_Special(target_offsets[i] + segment_id,
                                           seed_));
            set_score(random_value, slot_of(i), mode);
        }
    }
}
// Folds one random contribution into an accumulated boost score.
// The contribution is `random_value` scaled by this scorer's weight. An empty
// slot takes it as is; an occupied slot is combined with it according to
// `mode`.
void
RandomScorer::set_score(float random_value,
                        std::optional<float>& score,
                        const proto::plan::FunctionMode& mode) {
    if (score.has_value()) {
        score = function_score_merge(
            score.value(), (random_value * weight_), mode);
        return;
    }
    score = random_value * weight_;
}
} // namespace milvus::rescores

View File

@ -16,16 +16,52 @@
#pragma once
#include <exception>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "expr/ITypeExpr.h"
#include "pb/common.pb.h"
#include "pb/plan.pb.h"
#include "segcore/SegmentInterface.h"
#include "common/protobuf_utils.h"
namespace milvus::rescores {
class Scorer {
public:
virtual ~Scorer() = default;
virtual expr::TypedExprPtr
filter() = 0;
virtual float
rescore(float old_score) = 0;
// filter result of offset[i] was bitmapview[i]
// add boost score for idx[i] if bitmap[i] was true
virtual void
batch_score(milvus::OpContext* op_ctx,
const segcore::SegmentInternalInterface* segment,
const proto::plan::FunctionMode& mode,
const FixedVector<int32_t>& offsets,
const TargetBitmapView& bitmap,
std::vector<std::optional<float>>& boost_scores) = 0;
// score by bitmap
// filter result of offset[i] was bitmap[offset[i]]
// add boost score for idx[i] if bitmap[i] was true
virtual void
batch_score(milvus::OpContext* op_ctx,
const segcore::SegmentInternalInterface* segment,
const proto::plan::FunctionMode& mode,
const FixedVector<int32_t>& offsets,
const TargetBitmap& bitmap,
std::vector<std::optional<float>>& boost_scores) = 0;
// score for all offset
// used when no filter
virtual void
batch_score(milvus::OpContext* op_ctx,
const segcore::SegmentInternalInterface* segment,
const proto::plan::FunctionMode& mode,
const FixedVector<int32_t>& offsets,
std::vector<std::optional<float>>& boost_scores) = 0;
virtual float
weight() = 0;
@ -41,10 +77,28 @@ class WeightScorer : public Scorer {
return filter_;
}
float
rescore(float old_score) override {
return old_score * weight_;
}
void
batch_score(milvus::OpContext* op_ctx,
const segcore::SegmentInternalInterface* segment,
const proto::plan::FunctionMode& mode,
const FixedVector<int32_t>& offsets,
const TargetBitmapView& bitmap,
std::vector<std::optional<float>>& boost_scores) override;
void
batch_score(milvus::OpContext* op_ctx,
const segcore::SegmentInternalInterface* segment,
const proto::plan::FunctionMode& mode,
const FixedVector<int32_t>& offsets,
const TargetBitmap& bitmap,
std::vector<std::optional<float>>& boost_scores) override;
void
batch_score(milvus::OpContext* op_ctx,
const segcore::SegmentInternalInterface* segment,
const proto::plan::FunctionMode& mode,
const FixedVector<int32_t>& offsets,
std::vector<std::optional<float>>& boost_scores) override;
float
weight() override {
@ -52,7 +106,96 @@ class WeightScorer : public Scorer {
}
private:
void
set_score(std::optional<float>& old_score,
const proto::plan::FunctionMode& mode);
expr::TypedExprPtr filter_;
float weight_;
};
// Scorer that contributes a hash-derived pseudo-random boost score, scaled by
// `weight`, for each matching search result.
//
// Recognised entries in `params`:
//   "seed"     - integer seed mixed into the hash; defaults to 0 when absent
//                so that the scorer never reads an uninitialised seed.
//   "field_id" - id of the field whose value is hashed instead of the row's
//                segment offset; when absent the field id is -1, meaning
//                "no seed field".
// Any other key is ignored.
class RandomScorer : public Scorer {
 public:
    // Builds the scorer from an optional filter expression, a weight and the
    // function parameters.
    // Throws (via ThrowInfo, ErrorCode::InvalidParameter) when "seed" or
    // "field_id" is present but cannot be parsed as an integer.
    RandomScorer(expr::TypedExprPtr& filter,
                 float weight,
                 const ProtoParams& params)
        : filter_(filter), weight_(weight) {
        const auto param_map = RepeatedKeyValToMap(params);

        if (auto it = param_map.find("seed"); it != param_map.end()) {
            try {
                seed_ = std::stoll(it->second);
            } catch (const std::exception& e) {
                ThrowInfo(ErrorCode::InvalidParameter,
                          "parse boost random seed params failed: {}",
                          e.what());
            }
        }

        if (auto it = param_map.find("field_id"); it != param_map.end()) {
            try {
                field_ = FieldId(std::stoll(it->second));
            } catch (const std::exception& e) {
                ThrowInfo(ErrorCode::InvalidParameter,
                          "parse boost random seed field ID failed: {}",
                          e.what());
            }
        }
    }

    // Filter expression selecting the rows to boost; may be null, in which
    // case the scorer was created without a filter.
    expr::TypedExprPtr
    filter() override {
        return filter_;
    }

    // Filtered scoring; `bitmap[i]` is the filter outcome for `offsets[i]`.
    void
    batch_score(milvus::OpContext* op_ctx,
                const segcore::SegmentInternalInterface* segment,
                const proto::plan::FunctionMode& mode,
                const FixedVector<int32_t>& offsets,
                const TargetBitmapView& bitmap,
                std::vector<std::optional<float>>& boost_scores) override;

    // Filtered scoring; `bitmap[offsets[i]]` is the filter outcome for
    // `offsets[i]`.
    void
    batch_score(milvus::OpContext* op_ctx,
                const segcore::SegmentInternalInterface* segment,
                const proto::plan::FunctionMode& mode,
                const FixedVector<int32_t>& offsets,
                const TargetBitmap& bitmap,
                std::vector<std::optional<float>>& boost_scores) override;

    // Unfiltered scoring of every entry in `offsets`.
    void
    batch_score(milvus::OpContext* op_ctx,
                const segcore::SegmentInternalInterface* segment,
                const proto::plan::FunctionMode& mode,
                const FixedVector<int32_t>& offsets,
                std::vector<std::optional<float>>& boost_scores) override;

    float
    weight() override {
        return weight_;
    }

 private:
    // Merges `random_value * weight_` into `old_score` according to `mode`.
    void
    set_score(float random_value,
              std::optional<float>& old_score,
              const proto::plan::FunctionMode& mode);

    // Hashes each target offset (or its seed-field value) into a random
    // value and merges it into the matching slot of `boost_scores`.
    void
    random_score(milvus::OpContext* op_ctx,
                 const segcore::SegmentInternalInterface* segment,
                 const proto::plan::FunctionMode& mode,
                 const FixedVector<int64_t>& target_offsets,
                 const FixedVector<int>* idx,
                 std::vector<std::optional<float>>& boost_scores);

    expr::TypedExprPtr filter_;
    float weight_;
    // Default keeps hashing deterministic when no "seed" parameter is given.
    int64_t seed_ = 0;
    // -1 means no seed field is configured.
    FieldId field_ = FieldId(-1);
};
} // namespace milvus::rescores

View File

@ -0,0 +1,56 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <cstring>
#include <iterator>
#include "pb/plan.pb.h"
#include "common/EasyAssert.h"
extern "C" {
#include "Murmur3.h"
}
namespace milvus::rescores {
// Combines two boost function scores according to `mode`:
// FunctionModeMultiply yields a * b, FunctionModeSum yields a + b.
// Any other mode value is reported through ThrowInfo as UnexpectedError.
inline float
function_score_merge(const float& a,
                     const float& b,
                     const proto::plan::FunctionMode& mode) {
    if (mode == proto::plan::FunctionModeMultiply) {
        return a * b;
    }
    if (mode == proto::plan::FunctionModeSum) {
        return a + b;
    }
    // Neither known mode matched.
    ThrowInfo(ErrorCode::UnexpectedError,
              fmt::format("unknown boost function mode: {}:{}",
                          proto::plan::FunctionMode_Name(mode),
                          mode));
}
// Bit pattern of the biased exponent 0x3FF shifted into bits 52..62 of a
// 64-bit word. Combined with a 52-bit mantissa this encodes a double in
// [1.0, 2.0). The literal is unsigned long long so the shift is well defined
// even where `long` is only 32 bits wide.
#define MAGIC_BITS (0x3FFULL << 52)

// Maps a 64-bit hash onto a double in [0, 1).
//
// The low 52 bits of `h` become the mantissa of a double whose exponent is
// fixed by MAGIC_BITS, giving a value in [1, 2); subtracting 1.0 shifts it to
// [0, 1). The upper 12 bits of `h` are discarded.
//
// Declared inline because this definition lives in a header: without it every
// translation unit including the header would emit its own external
// definition and the link would fail with duplicate symbols.
inline double
hash_to_double(const uint64_t& h) {
    static_assert(sizeof(double) == sizeof(uint64_t),
                  "hash_to_double requires a 64-bit double");
    const uint64_t double_bits = MAGIC_BITS | (h & 0xFFFFFFFFFFFFFULL);
    double result;
    // memcpy is the aliasing-safe way to reinterpret the bits in C++17.
    std::memcpy(&result, &double_bits, sizeof(result));
    return result - 1.0;
}
} // namespace milvus::rescores

View File

@ -3,6 +3,7 @@ package planparserv2
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/antlr4-go/antlr/v4"
@ -10,13 +11,16 @@ import (
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated"
"github.com/milvus-io/milvus/internal/util/function/rerank"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -255,7 +259,7 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
return nil, err
}
scorers, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)
scorers, options, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)
if err != nil {
return nil, err
}
@ -276,7 +280,8 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
FieldId: fieldID,
},
},
Scorers: scorers,
Scorers: scorers,
ScoreOption: options,
PlanOptions: &planpb.PlanOption{
ExprUseJsonStats: exprParams.UseJSONStats,
},
@ -284,6 +289,67 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
return planNode, nil
}
// prepareBoostRandomParams decodes the JSON object passed as the value of the
// random_score function param and converts it into the key/value params
// attached to the plan scorer.
//
// Recognized keys:
//   - RandomScoreFileNameKey: name of an Int64 field in schema; emitted under
//     RandomScoreFileIdKey with the resolved field ID as its value.
//   - RandomScoreSeedKey: the seed; must be an integer that fits in int64 and
//     is emitted unchanged.
//
// Any other key in the JSON object is ignored. The order of the returned
// pairs follows Go map iteration and is therefore unspecified.
//
// Returns an error when the JSON cannot be decoded, the field name is not a
// string, the field is missing or not Int64, or the seed is not an int64.
func prepareBoostRandomParams(schema *typeutil.SchemaHelper, bytes string) ([]*commonpb.KeyValuePair, error) {
	paramsMap := make(map[string]any)
	dec := json.NewDecoder(strings.NewReader(bytes))
	// With UseNumber, numeric values arrive as json.Number (asserted below)
	// so the seed keeps its original textual form.
	dec.UseNumber()
	if err := dec.Decode(&paramsMap); err != nil {
		return nil, err
	}

	result := make([]*commonpb.KeyValuePair, 0, len(paramsMap))
	for key, value := range paramsMap {
		switch key {
		// parse field name to field ID
		case RandomScoreFileNameKey:
			name, ok := value.(string)
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("random seed field name must be string")
			}

			field, err := schema.GetFieldFromName(name)
			if err != nil {
				return nil, merr.WrapErrFieldNotFound(name, "random seed field not found")
			}

			if field.DataType != schemapb.DataType_Int64 {
				return nil, merr.WrapErrParameterInvalidMsg("only support int64 field as random seed, but got %s", field.DataType.String())
			}
			result = append(result, &commonpb.KeyValuePair{Key: RandomScoreFileIdKey, Value: fmt.Sprint(field.FieldID)})
		case RandomScoreSeedKey:
			number, ok := value.(json.Number)
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("random seed must be number")
			}

			// The scorer parses this value as a signed 64-bit integer, so
			// reject fractions, exponents and out-of-range values here
			// instead of letting them be truncated or fail later.
			seed := number.String()
			if _, err := strconv.ParseInt(seed, 10, 64); err != nil {
				return nil, merr.WrapErrParameterInvalidMsg("random seed must be an int64 integer, but got %s", seed)
			}
			result = append(result, &commonpb.KeyValuePair{Key: key, Value: seed})
		}
	}
	return result, nil
}
// setBoostType decides which kind of boost function scorer represents.
//
// The scorer starts as FunctionTypeWeight. For every param whose key is
// BoostRandomScoreKey the type is switched to FunctionTypeRandom and the
// param value is converted by prepareBoostRandomParams into scorer.Params.
// Params with any other key are ignored. A conversion error is returned
// as-is; the type has already been switched at that point.
func setBoostType(schema *typeutil.SchemaHelper, scorer *planpb.ScoreFunction, params []*commonpb.KeyValuePair) error {
	scorer.Type = planpb.FunctionType_FunctionTypeWeight
	for _, kv := range params {
		if kv.GetKey() != BoostRandomScoreKey {
			continue
		}

		scorer.Type = planpb.FunctionType_FunctionTypeRandom
		randomParams, err := prepareBoostRandomParams(schema, kv.GetValue())
		if err != nil {
			return err
		}
		scorer.Params = randomParams
	}
	return nil
}
func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.FunctionSchema, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.ScoreFunction, error) {
rerankerName := rerank.GetRerankName(function)
switch rerankerName {
@ -308,6 +374,12 @@ func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.Functi
return nil, fmt.Errorf("parse function scorer weight params failed with error: {%v}", err)
}
scorer.Weight = float32(weight)
err = setBoostType(schema, scorer, function.GetParams())
if err != nil {
return nil, err
}
return scorer, nil
default:
// if not boost scorer, regard as normal function scorer
@ -317,22 +389,67 @@ func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.Functi
}
}
func CreateSearchScorers(schema *typeutil.SchemaHelper, functionScore *schemapb.FunctionScore, exprTemplateValues map[string]*schemapb.TemplateValue) ([]*planpb.ScoreFunction, error) {
// ParseBoostMode maps a case-insensitive mode name to its planpb.BoostMode.
// "multiply" and "sum" are accepted; anything else yields an invalid-parameter
// error that reports the lower-cased input.
func ParseBoostMode(s string) (planpb.BoostMode, error) {
	mode := strings.ToLower(s)
	if mode == "multiply" {
		return planpb.BoostMode_BoostModeMultiply, nil
	}
	if mode == "sum" {
		return planpb.BoostMode_BoostModeSum, nil
	}
	return 0, merr.WrapErrParameterInvalidMsg("unknown boost mode: %s", mode)
}
// ParseFunctionMode maps a case-insensitive mode name to its
// planpb.FunctionMode. "multiply" and "sum" are accepted; anything else yields
// an invalid-parameter error that reports the lower-cased input.
func ParseFunctionMode(s string) (planpb.FunctionMode, error) {
	mode := strings.ToLower(s)
	if mode == "multiply" {
		return planpb.FunctionMode_FunctionModeMultiply, nil
	}
	if mode == "sum" {
		return planpb.FunctionMode_FunctionModeSum, nil
	}
	return 0, merr.WrapErrParameterInvalidMsg("unknown function mode: %s", mode)
}
// CreateSearchScorers builds one plan scorer per function in functionScore
// (via CreateSearchScorer, dropping nil results) and, when at least one scorer
// was produced, a ScoreOption filled from the optional BoostModeKey and
// BoostFunctionModeKey entries of functionScore's params.
//
// NOTE(review): this span shows back-to-back return statements with two and
// three values, and a return followed by further code. That looks like the
// removed and added sides of a diff shown together — confirm which lines are
// current before relying on this text.
func CreateSearchScorers(schema *typeutil.SchemaHelper, functionScore *schemapb.FunctionScore, exprTemplateValues map[string]*schemapb.TemplateValue) ([]*planpb.ScoreFunction, *planpb.ScoreOption, error) {
scorers := []*planpb.ScoreFunction{}
for _, function := range functionScore.GetFunctions() {
// create scorer for search plan
scorer, err := CreateSearchScorer(schema, function, exprTemplateValues)
if err != nil {
return nil, err
return nil, nil, err
}
// a nil scorer without an error is skipped rather than appended
if scorer != nil {
scorers = append(scorers, scorer)
}
}
// no scorer produced: nothing is returned, including no options
if len(scorers) == 0 {
return nil, nil
return nil, nil, nil
}
return scorers, nil
// option fields keep their zero value unless the matching param is present
option := &planpb.ScoreOption{}
s, ok := funcutil.TryGetAttrByKeyFromRepeatedKV(BoostModeKey, functionScore.GetParams())
if ok {
boostMode, err := ParseBoostMode(s)
if err != nil {
return nil, nil, err
}
option.BoostMode = boostMode
}
s, ok = funcutil.TryGetAttrByKeyFromRepeatedKV(BoostFunctionModeKey, functionScore.GetParams())
if ok {
functionMode, err := ParseFunctionMode(s)
if err != nil {
return nil, nil, err
}
option.FunctionMode = functionMode
}
return scorers, option, nil
}
func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo, exprTemplateValues map[string]*schemapb.TemplateValue, functionScorer *schemapb.FunctionScore) (*planpb.PlanNode, error) {

View File

@ -16,6 +16,15 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// Param keys used when parsing boost function scores.
const (
// BoostRandomScoreKey is the function param that switches a boost scorer to
// the random type; setBoostType passes its value to prepareBoostRandomParams,
// which decodes it as a JSON object.
BoostRandomScoreKey = "random_score"
// BoostModeKey is the function-score param parsed by ParseBoostMode.
BoostModeKey = "boost_mode"
// BoostFunctionModeKey is the function-score param parsed by ParseFunctionMode.
BoostFunctionModeKey = "function_mode"
// RandomScoreSeedKey is the key of the seed inside the random_score JSON
// object; it is forwarded to the scorer params under the same key.
RandomScoreSeedKey = "seed"
// RandomScoreFileNameKey is the key inside the random_score JSON object
// holding the name of the field used for random scoring.
RandomScoreFileNameKey = "field"
// RandomScoreFileIdKey is the scorer-param key under which the resolved
// field ID is forwarded in place of the field name.
RandomScoreFileIdKey = "field_id"
)
func IsBool(n *planpb.GenericValue) bool {
switch n.GetVal().(type) {
case *planpb.GenericValue_BoolVal:

View File

@ -3,6 +3,7 @@ package milvus.proto.plan;
option go_package = "github.com/milvus-io/milvus/pkg/v2/proto/planpb";
import "schema.proto";
import "common.proto";
enum OpType {
Invalid = 0;
@ -279,9 +280,35 @@ message QueryPlanNode {
int64 limit = 3;
};
// FunctionType selects the kind of boost score function.
enum FunctionType{
// Plain weight boost; the plan parser uses this unless a random_score
// param is present.
FunctionTypeWeight = 0;
// Random boost; the plan parser selects this when a random_score param
// is present.
FunctionTypeRandom = 1;
}
// FunctionMode decides how the scores of multiple boost functions
// are combined into a single boost score.
enum FunctionMode{
// Scores are multiplied together.
FunctionModeMultiply = 0;
// Scores are added together.
FunctionModeSum = 1;
};
// BoostMode decides how the final score is calculated
// from the original score and the boost score.
enum BoostMode{
// Presumably original score * boost score — TODO confirm against the
// rescoring implementation.
BoostModeMultiply = 0;
// Presumably original score + boost score — TODO confirm against the
// rescoring implementation.
BoostModeSum = 1;
};
// ScoreFunction describes one boost function of a search plan.
// NOTE(review): filter and weight each appear twice below; this looks like
// the removed and added sides of a diff shown together — confirm.
message ScoreFunction {
Expr filter =1;
float weight = 2;
// Filter expression attached to this function.
Expr filter = 1;
// Weight of this function, parsed from the function's weight param.
float weight = 2;
// Kind of boost function (weight or random).
FunctionType type = 3;
// Extra key/value params; filled by the plan parser for random scorers.
repeated common.KeyValuePair params = 4;
}
// ScoreOption carries the plan-wide settings for boost scoring.
message ScoreOption{
// How the original score and the boost score form the final score.
BoostMode boost_mode = 1;
// How the scores of multiple boost functions are combined.
FunctionMode function_mode = 2;
}
message PlanOption {
@ -298,4 +325,5 @@ message PlanNode {
repeated string dynamic_fields = 5;
repeated ScoreFunction scorers = 6;
PlanOption plan_options = 7;
ScoreOption score_option = 8;
}

File diff suppressed because it is too large Load Diff