mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
enhance: support random score for boost function score (#44214)
And support set function mode and boost mode when run search with boost. RandomScore support get random function score between [0, weight). FunctionMode decide how to calculate boost score for multiple boost function scores. BoostMode decide how to calculate final score for origin score and boost score. relate: https://github.com/milvus-io/milvus/issues/43867 --------- Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
13c3b0b909
commit
1b20e956be
@ -50,6 +50,7 @@ add_subdirectory( clustering )
|
||||
add_subdirectory( exec )
|
||||
add_subdirectory( bitset )
|
||||
add_subdirectory( futures )
|
||||
add_subdirectory( rescores )
|
||||
|
||||
milvus_add_pkg_config("milvus_core")
|
||||
|
||||
@ -67,6 +68,7 @@ add_library(milvus_core SHARED
|
||||
$<TARGET_OBJECTS:milvus_exec>
|
||||
$<TARGET_OBJECTS:milvus_bitset>
|
||||
$<TARGET_OBJECTS:milvus_futures>
|
||||
$<TARGET_OBJECTS:milvus_rescores>
|
||||
)
|
||||
|
||||
set(LINK_TARGETS
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <google/protobuf/repeated_field.h>
|
||||
|
||||
#include "pb/schema.pb.h"
|
||||
#include "common/EasyAssert.h"
|
||||
@ -26,6 +27,11 @@
|
||||
using std::string;
|
||||
|
||||
namespace milvus {
|
||||
|
||||
template <typename T>
|
||||
using ProtoRepeated = google::protobuf::RepeatedPtrField<T>;
|
||||
using ProtoParams = ProtoRepeated<proto::common::KeyValuePair>;
|
||||
|
||||
static std::map<string, string>
|
||||
RepeatedKeyValToMap(
|
||||
const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>&
|
||||
|
||||
@ -17,7 +17,9 @@
|
||||
#include "RescoresNode.h"
|
||||
#include <cstddef>
|
||||
#include "exec/operator/Utils.h"
|
||||
#include "log/Log.h"
|
||||
#include "monitor/Monitor.h"
|
||||
#include "pb/plan.pb.h"
|
||||
|
||||
namespace milvus::exec {
|
||||
|
||||
@ -31,6 +33,7 @@ PhyRescoresNode::PhyRescoresNode(
|
||||
scorer->id(),
|
||||
"PhyRescoresNode") {
|
||||
scorers_ = scorer->scorers();
|
||||
option_ = scorer->option();
|
||||
};
|
||||
|
||||
void
|
||||
@ -62,13 +65,15 @@ PhyRescoresNode::GetOutput() {
|
||||
auto query_context_ = exec_context->get_query_context();
|
||||
auto query_info = exec_context->get_query_config();
|
||||
milvus::SearchResult search_result = query_context_->get_search_result();
|
||||
auto segment = query_context_->get_segment();
|
||||
auto op_ctx = query_context_->get_op_context();
|
||||
|
||||
// prepare segment offset
|
||||
FixedVector<int32_t> offsets;
|
||||
std::vector<size_t> offset_idx;
|
||||
|
||||
for (size_t i = 0; i < search_result.seg_offsets_.size(); i++) {
|
||||
// remain offset will be -1 if result count not enough (less than topk)
|
||||
// remain offset will be placeholder(-1) if result count not enough (less than topk)
|
||||
// skip placeholder offset
|
||||
if (search_result.seg_offsets_[i] >= 0) {
|
||||
offsets.push_back(
|
||||
@ -83,14 +88,15 @@ PhyRescoresNode::GetOutput() {
|
||||
return input_;
|
||||
}
|
||||
|
||||
std::vector<std::optional<float>> boost_scores(offsets.size());
|
||||
auto function_mode = option_->function_mode();
|
||||
|
||||
for (auto& scorer : scorers_) {
|
||||
auto filter = scorer->filter();
|
||||
// rescore for all result if no filter
|
||||
// boost for all result if no filter
|
||||
if (!filter) {
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
search_result.distances_[offset_idx[i]] =
|
||||
scorer->rescore(search_result.distances_[offset_idx[i]]);
|
||||
}
|
||||
scorer->batch_score(
|
||||
op_ctx, segment, function_mode, offsets, boost_scores);
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -116,13 +122,12 @@ PhyRescoresNode::GetOutput() {
|
||||
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
|
||||
auto col_vec_size = col_vec->size();
|
||||
TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size);
|
||||
Assert(bitsetview.size() == offsets.size());
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitsetview[i] > 0) {
|
||||
search_result.distances_[offset_idx[i]] = scorer->rescore(
|
||||
search_result.distances_[offset_idx[i]]);
|
||||
}
|
||||
}
|
||||
scorer->batch_score(op_ctx,
|
||||
segment,
|
||||
function_mode,
|
||||
offsets,
|
||||
bitsetview,
|
||||
boost_scores);
|
||||
} else {
|
||||
// query all segment if expr not native
|
||||
expr_set->Eval(0, 1, true, eval_ctx, results);
|
||||
@ -133,13 +138,34 @@ PhyRescoresNode::GetOutput() {
|
||||
auto col_vec_size = col_vec->size();
|
||||
TargetBitmapView view(col_vec->GetRawData(), col_vec_size);
|
||||
bitset.append(view);
|
||||
scorer->batch_score(
|
||||
op_ctx, segment, function_mode, offsets, bitset, boost_scores);
|
||||
}
|
||||
}
|
||||
|
||||
// calculate final score
|
||||
auto boost_mode = option_->boost_mode();
|
||||
switch (boost_mode) {
|
||||
case proto::plan::BoostModeMultiply:
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitset[offsets[i]] > 0) {
|
||||
search_result.distances_[offset_idx[i]] = scorer->rescore(
|
||||
search_result.distances_[offset_idx[i]]);
|
||||
if (boost_scores[i].has_value()) {
|
||||
search_result.distances_[offset_idx[i]] *=
|
||||
boost_scores[i].value();
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case proto::plan::BoostModeSum:
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (boost_scores[i].has_value()) {
|
||||
search_result.distances_[offset_idx[i]] +=
|
||||
boost_scores[i].value();
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
default:
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
fmt::format("unknown boost boost mode: {}", boost_mode));
|
||||
}
|
||||
|
||||
knowhere::MetricType metric_type = query_context_->get_metric_type();
|
||||
|
||||
@ -72,6 +72,7 @@ class PhyRescoresNode : public Operator {
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<rescores::Scorer>> scorers_;
|
||||
const proto::plan::ScoreOption* option_;
|
||||
bool is_finished_{false};
|
||||
};
|
||||
} // namespace milvus::exec
|
||||
|
||||
@ -32,7 +32,7 @@ TEST(Rescorer, Normal) {
|
||||
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
|
||||
auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
|
||||
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
|
||||
auto str_fid = schema->AddDebugField("string1", DataType::VARCHAR);
|
||||
auto str_fid = schema->AddDebugField("string", DataType::VARCHAR);
|
||||
auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
|
||||
schema->set_primary_field_id(str_fid);
|
||||
size_t N = 50;
|
||||
@ -142,4 +142,155 @@ TEST(Rescorer, Normal) {
|
||||
auto search_result =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
}
|
||||
|
||||
// random function with seed
|
||||
{
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Int8
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
int64_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
int64_val: 100
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
metric_type: "L2"
|
||||
search_params: "{\"ef\": 50}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
scorers: <
|
||||
type: 1
|
||||
weight: 1
|
||||
params: <key: "seed", value: "123">
|
||||
>)";
|
||||
|
||||
proto::plan::PlanNode plan_node;
|
||||
auto ok =
|
||||
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
|
||||
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
|
||||
auto num_queries = 1;
|
||||
auto seed = 1024;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
auto search_result =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
}
|
||||
|
||||
// random function with field as random seed
|
||||
{
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Int8
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
int64_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
int64_val: 100
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
metric_type: "L2"
|
||||
search_params: "{\"ef\": 50}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
scorers: <
|
||||
type: 1
|
||||
weight: 1
|
||||
params: <key: "field", value: "int64">
|
||||
>)";
|
||||
|
||||
proto::plan::PlanNode plan_node;
|
||||
auto ok =
|
||||
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
|
||||
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
|
||||
auto num_queries = 1;
|
||||
auto seed = 1024;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
auto search_result =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
}
|
||||
|
||||
// random function with field and seed
|
||||
{
|
||||
const char* raw_plan = R"(vector_anns: <
|
||||
field_id: 100
|
||||
predicates: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: 101
|
||||
data_type: Int8
|
||||
>
|
||||
lower_inclusive: true,
|
||||
upper_inclusive: false,
|
||||
lower_value: <
|
||||
int64_val: -1
|
||||
>
|
||||
upper_value: <
|
||||
int64_val: 100
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
metric_type: "L2"
|
||||
search_params: "{\"ef\": 50}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
scorers: <
|
||||
type: 1
|
||||
weight: 1
|
||||
params: <key: "seed", value: "123">
|
||||
params: <key: "field", value: "int64">
|
||||
>)";
|
||||
|
||||
proto::plan::PlanNode plan_node;
|
||||
auto ok =
|
||||
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
|
||||
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
|
||||
auto num_queries = 1;
|
||||
auto seed = 1024;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
auto search_result =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
|
||||
auto search_result_same_seed =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
|
||||
// should return same score when use same seed
|
||||
for (auto i = 0; i < 10; i++) {
|
||||
AssertInfo(search_result->distances_[i] ==
|
||||
search_result_same_seed->distances_[i],
|
||||
"distance not equal %f:%f",
|
||||
search_result->distances_[i],
|
||||
search_result_same_seed->distances_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -24,6 +24,7 @@
|
||||
#include "common/Vector.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "plan/PlanNodeIdGenerator.h"
|
||||
#include "rescores/Scorer.h"
|
||||
@ -442,10 +443,12 @@ class RescoresNode : public PlanNode {
|
||||
public:
|
||||
RescoresNode(
|
||||
const PlanNodeId& id,
|
||||
const std::vector<std::shared_ptr<rescores::Scorer>> scorers,
|
||||
const std::vector<std::shared_ptr<rescores::Scorer>>& scorers,
|
||||
const proto::plan::ScoreOption& option,
|
||||
const std::vector<PlanNodePtr>& sources = std::vector<PlanNodePtr>{})
|
||||
: PlanNode(id),
|
||||
scorers_(std::move(scorers)),
|
||||
option_(std::move(option)),
|
||||
sources_{std::move(sources)} {
|
||||
}
|
||||
|
||||
@ -459,6 +462,11 @@ class RescoresNode : public PlanNode {
|
||||
return sources_;
|
||||
}
|
||||
|
||||
const proto::plan::ScoreOption*
|
||||
option() const {
|
||||
return &option_;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<rescores::Scorer>>&
|
||||
scorers() const {
|
||||
return scorers_;
|
||||
@ -476,6 +484,7 @@ class RescoresNode : public PlanNode {
|
||||
}
|
||||
|
||||
private:
|
||||
const proto::plan::ScoreOption option_;
|
||||
const std::vector<PlanNodePtr> sources_;
|
||||
const std::vector<std::shared_ptr<rescores::Scorer>> scorers_;
|
||||
};
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "common/EasyAssert.h"
|
||||
#include "exec/expression/function/FunctionFactory.h"
|
||||
#include "log/Log.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "query/Utils.h"
|
||||
#include "knowhere/comp/materialized_view.h"
|
||||
@ -215,7 +216,10 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
||||
}
|
||||
|
||||
plannode = std::make_shared<milvus::plan::RescoresNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), std::move(scorers), sources);
|
||||
milvus::plan::GetNextPlanNodeId(),
|
||||
std::move(scorers),
|
||||
plan_node_proto.score_option(),
|
||||
sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
}
|
||||
|
||||
@ -629,12 +633,21 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
|
||||
|
||||
std::shared_ptr<rescores::Scorer>
|
||||
ProtoParser::ParseScorer(const proto::plan::ScoreFunction& function) {
|
||||
expr::TypedExprPtr expr = nullptr;
|
||||
if (function.has_filter()) {
|
||||
auto expr = ParseExprs(function.filter());
|
||||
return std::make_shared<rescores::WeightScorer>(expr,
|
||||
function.weight());
|
||||
expr = ParseExprs(function.filter());
|
||||
}
|
||||
|
||||
switch (function.type()) {
|
||||
case proto::plan::FunctionTypeWeight:
|
||||
return std::make_shared<rescores::WeightScorer>(expr,
|
||||
function.weight());
|
||||
case proto::plan::FunctionTypeRandom:
|
||||
return std::make_shared<rescores::RandomScorer>(
|
||||
expr, function.weight(), function.params());
|
||||
default:
|
||||
ThrowInfo(UnexpectedError, "unknown function type");
|
||||
}
|
||||
return std::make_shared<rescores::WeightScorer>(nullptr, function.weight());
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
||||
13
internal/core/src/rescores/CMakeLists.txt
Normal file
13
internal/core/src/rescores/CMakeLists.txt
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
add_source_at_current_directory_recursively()
|
||||
add_library(milvus_rescores OBJECT ${SOURCE_FILES})
|
||||
466
internal/core/src/rescores/Murmur3.c
Normal file
466
internal/core/src/rescores/Murmur3.c
Normal file
@ -0,0 +1,466 @@
|
||||
//-----------------------------------------------------------------------------
|
||||
// MurmurHash3 was written by Austin Appleby, and is placed in the public
|
||||
// domain. The author hereby disclaims copyright to this source code.
|
||||
|
||||
// Note - The x86 and x64 versions do _not_ produce the same results, as the
|
||||
// algorithms are optimized for their respective platforms. You can still
|
||||
// compile and run any of them on any platform, but your performance with the
|
||||
// non-native version will be less than optimal.
|
||||
|
||||
#include "Murmur3.h"
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Platform-specific functions and macros
|
||||
|
||||
#ifdef __GNUC__
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
#else
|
||||
#define FORCE_INLINE inline
|
||||
#endif
|
||||
|
||||
static FORCE_INLINE uint32_t
|
||||
rotl32(uint32_t x, int8_t r) {
|
||||
return (x << r) | (x >> (32 - r));
|
||||
}
|
||||
|
||||
static FORCE_INLINE uint64_t
|
||||
rotl64(uint64_t x, int8_t r) {
|
||||
return (x << r) | (x >> (64 - r));
|
||||
}
|
||||
|
||||
#define ROTL32(x, y) rotl32(x, y)
|
||||
#define ROTL64(x, y) rotl64(x, y)
|
||||
|
||||
#define BIG_CONSTANT(x) (x##LLU)
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Block read - if your platform needs to do endian-swapping or can only
|
||||
// handle aligned reads, do the conversion here
|
||||
|
||||
#define getblock(p, i) (p[i])
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Finalization mix - force all bits of a hash block to avalanche
|
||||
|
||||
static FORCE_INLINE uint32_t
|
||||
fmix32(uint32_t h) {
|
||||
h ^= h >> 16;
|
||||
h *= 0x85ebca6b;
|
||||
h ^= h >> 13;
|
||||
h *= 0xc2b2ae35;
|
||||
h ^= h >> 16;
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
//----------
|
||||
|
||||
static FORCE_INLINE uint64_t
|
||||
fmix64(uint64_t k) {
|
||||
k ^= k >> 33;
|
||||
k *= BIG_CONSTANT(0xff51afd7ed558ccd);
|
||||
k ^= k >> 33;
|
||||
k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53);
|
||||
k ^= k >> 33;
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
void
|
||||
MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out) {
|
||||
const uint8_t* data = (const uint8_t*)key;
|
||||
const int nblocks = len / 4;
|
||||
int i;
|
||||
|
||||
uint32_t h1 = seed;
|
||||
|
||||
uint32_t c1 = 0xcc9e2d51;
|
||||
uint32_t c2 = 0x1b873593;
|
||||
|
||||
//----------
|
||||
// body
|
||||
|
||||
const uint32_t* blocks = (const uint32_t*)(data + nblocks * 4);
|
||||
|
||||
for (i = -nblocks; i; i++) {
|
||||
uint32_t k1 = getblock(blocks, i);
|
||||
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
|
||||
h1 ^= k1;
|
||||
h1 = ROTL32(h1, 13);
|
||||
h1 = h1 * 5 + 0xe6546b64;
|
||||
}
|
||||
|
||||
//----------
|
||||
// tail
|
||||
|
||||
const uint8_t* tail = (const uint8_t*)(data + nblocks * 4);
|
||||
|
||||
uint32_t k1 = 0;
|
||||
|
||||
switch (len & 3) {
|
||||
case 3:
|
||||
k1 ^= tail[2] << 16;
|
||||
case 2:
|
||||
k1 ^= tail[1] << 8;
|
||||
case 1:
|
||||
k1 ^= tail[0];
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
};
|
||||
|
||||
//----------
|
||||
// finalization
|
||||
|
||||
h1 ^= len;
|
||||
|
||||
h1 = fmix32(h1);
|
||||
|
||||
*(uint32_t*)out = h1;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
void
|
||||
MurmurHash3_x86_128(const void* key, const int len, uint32_t seed, void* out) {
|
||||
const uint8_t* data = (const uint8_t*)key;
|
||||
const int nblocks = len / 16;
|
||||
int i;
|
||||
|
||||
uint32_t h1 = seed;
|
||||
uint32_t h2 = seed;
|
||||
uint32_t h3 = seed;
|
||||
uint32_t h4 = seed;
|
||||
|
||||
uint32_t c1 = 0x239b961b;
|
||||
uint32_t c2 = 0xab0e9789;
|
||||
uint32_t c3 = 0x38b34ae5;
|
||||
uint32_t c4 = 0xa1e38b93;
|
||||
|
||||
//----------
|
||||
// body
|
||||
|
||||
const uint32_t* blocks = (const uint32_t*)(data + nblocks * 16);
|
||||
|
||||
for (i = -nblocks; i; i++) {
|
||||
uint32_t k1 = getblock(blocks, i * 4 + 0);
|
||||
uint32_t k2 = getblock(blocks, i * 4 + 1);
|
||||
uint32_t k3 = getblock(blocks, i * 4 + 2);
|
||||
uint32_t k4 = getblock(blocks, i * 4 + 3);
|
||||
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
|
||||
h1 = ROTL32(h1, 19);
|
||||
h1 += h2;
|
||||
h1 = h1 * 5 + 0x561ccd1b;
|
||||
|
||||
k2 *= c2;
|
||||
k2 = ROTL32(k2, 16);
|
||||
k2 *= c3;
|
||||
h2 ^= k2;
|
||||
|
||||
h2 = ROTL32(h2, 17);
|
||||
h2 += h3;
|
||||
h2 = h2 * 5 + 0x0bcaa747;
|
||||
|
||||
k3 *= c3;
|
||||
k3 = ROTL32(k3, 17);
|
||||
k3 *= c4;
|
||||
h3 ^= k3;
|
||||
|
||||
h3 = ROTL32(h3, 15);
|
||||
h3 += h4;
|
||||
h3 = h3 * 5 + 0x96cd1c35;
|
||||
|
||||
k4 *= c4;
|
||||
k4 = ROTL32(k4, 18);
|
||||
k4 *= c1;
|
||||
h4 ^= k4;
|
||||
|
||||
h4 = ROTL32(h4, 13);
|
||||
h4 += h1;
|
||||
h4 = h4 * 5 + 0x32ac3b17;
|
||||
}
|
||||
|
||||
//----------
|
||||
// tail
|
||||
|
||||
const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
|
||||
|
||||
uint32_t k1 = 0;
|
||||
uint32_t k2 = 0;
|
||||
uint32_t k3 = 0;
|
||||
uint32_t k4 = 0;
|
||||
|
||||
switch (len & 15) {
|
||||
case 15:
|
||||
k4 ^= tail[14] << 16;
|
||||
case 14:
|
||||
k4 ^= tail[13] << 8;
|
||||
case 13:
|
||||
k4 ^= tail[12] << 0;
|
||||
k4 *= c4;
|
||||
k4 = ROTL32(k4, 18);
|
||||
k4 *= c1;
|
||||
h4 ^= k4;
|
||||
|
||||
case 12:
|
||||
k3 ^= tail[11] << 24;
|
||||
case 11:
|
||||
k3 ^= tail[10] << 16;
|
||||
case 10:
|
||||
k3 ^= tail[9] << 8;
|
||||
case 9:
|
||||
k3 ^= tail[8] << 0;
|
||||
k3 *= c3;
|
||||
k3 = ROTL32(k3, 17);
|
||||
k3 *= c4;
|
||||
h3 ^= k3;
|
||||
|
||||
case 8:
|
||||
k2 ^= tail[7] << 24;
|
||||
case 7:
|
||||
k2 ^= tail[6] << 16;
|
||||
case 6:
|
||||
k2 ^= tail[5] << 8;
|
||||
case 5:
|
||||
k2 ^= tail[4] << 0;
|
||||
k2 *= c2;
|
||||
k2 = ROTL32(k2, 16);
|
||||
k2 *= c3;
|
||||
h2 ^= k2;
|
||||
|
||||
case 4:
|
||||
k1 ^= tail[3] << 24;
|
||||
case 3:
|
||||
k1 ^= tail[2] << 16;
|
||||
case 2:
|
||||
k1 ^= tail[1] << 8;
|
||||
case 1:
|
||||
k1 ^= tail[0] << 0;
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
};
|
||||
|
||||
//----------
|
||||
// finalization
|
||||
|
||||
h1 ^= len;
|
||||
h2 ^= len;
|
||||
h3 ^= len;
|
||||
h4 ^= len;
|
||||
|
||||
h1 += h2;
|
||||
h1 += h3;
|
||||
h1 += h4;
|
||||
h2 += h1;
|
||||
h3 += h1;
|
||||
h4 += h1;
|
||||
|
||||
h1 = fmix32(h1);
|
||||
h2 = fmix32(h2);
|
||||
h3 = fmix32(h3);
|
||||
h4 = fmix32(h4);
|
||||
|
||||
h1 += h2;
|
||||
h1 += h3;
|
||||
h1 += h4;
|
||||
h2 += h1;
|
||||
h3 += h1;
|
||||
h4 += h1;
|
||||
|
||||
((uint32_t*)out)[0] = h1;
|
||||
((uint32_t*)out)[1] = h2;
|
||||
((uint32_t*)out)[2] = h3;
|
||||
((uint32_t*)out)[3] = h4;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
void
|
||||
MurmurHash3_x64_128(const void* key,
|
||||
const int len,
|
||||
const uint32_t seed,
|
||||
void* out) {
|
||||
const uint8_t* data = (const uint8_t*)key;
|
||||
const int nblocks = len / 16;
|
||||
int i;
|
||||
|
||||
uint64_t h1 = seed;
|
||||
uint64_t h2 = seed;
|
||||
|
||||
uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5);
|
||||
uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f);
|
||||
|
||||
//----------
|
||||
// body
|
||||
|
||||
const uint64_t* blocks = (const uint64_t*)(data);
|
||||
|
||||
for (i = 0; i < nblocks; i++) {
|
||||
uint64_t k1 = getblock(blocks, i * 2 + 0);
|
||||
uint64_t k2 = getblock(blocks, i * 2 + 1);
|
||||
|
||||
k1 *= c1;
|
||||
k1 = ROTL64(k1, 31);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
|
||||
h1 = ROTL64(h1, 27);
|
||||
h1 += h2;
|
||||
h1 = h1 * 5 + 0x52dce729;
|
||||
|
||||
k2 *= c2;
|
||||
k2 = ROTL64(k2, 33);
|
||||
k2 *= c1;
|
||||
h2 ^= k2;
|
||||
|
||||
h2 = ROTL64(h2, 31);
|
||||
h2 += h1;
|
||||
h2 = h2 * 5 + 0x38495ab5;
|
||||
}
|
||||
|
||||
//----------
|
||||
// tail
|
||||
|
||||
const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
|
||||
|
||||
uint64_t k1 = 0;
|
||||
uint64_t k2 = 0;
|
||||
|
||||
switch (len & 15) {
|
||||
case 15:
|
||||
k2 ^= (uint64_t)(tail[14]) << 48;
|
||||
case 14:
|
||||
k2 ^= (uint64_t)(tail[13]) << 40;
|
||||
case 13:
|
||||
k2 ^= (uint64_t)(tail[12]) << 32;
|
||||
case 12:
|
||||
k2 ^= (uint64_t)(tail[11]) << 24;
|
||||
case 11:
|
||||
k2 ^= (uint64_t)(tail[10]) << 16;
|
||||
case 10:
|
||||
k2 ^= (uint64_t)(tail[9]) << 8;
|
||||
case 9:
|
||||
k2 ^= (uint64_t)(tail[8]) << 0;
|
||||
k2 *= c2;
|
||||
k2 = ROTL64(k2, 33);
|
||||
k2 *= c1;
|
||||
h2 ^= k2;
|
||||
|
||||
case 8:
|
||||
k1 ^= (uint64_t)(tail[7]) << 56;
|
||||
case 7:
|
||||
k1 ^= (uint64_t)(tail[6]) << 48;
|
||||
case 6:
|
||||
k1 ^= (uint64_t)(tail[5]) << 40;
|
||||
case 5:
|
||||
k1 ^= (uint64_t)(tail[4]) << 32;
|
||||
case 4:
|
||||
k1 ^= (uint64_t)(tail[3]) << 24;
|
||||
case 3:
|
||||
k1 ^= (uint64_t)(tail[2]) << 16;
|
||||
case 2:
|
||||
k1 ^= (uint64_t)(tail[1]) << 8;
|
||||
case 1:
|
||||
k1 ^= (uint64_t)(tail[0]) << 0;
|
||||
k1 *= c1;
|
||||
k1 = ROTL64(k1, 31);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
};
|
||||
|
||||
//----------
|
||||
// finalization
|
||||
|
||||
h1 ^= len;
|
||||
h2 ^= len;
|
||||
|
||||
h1 += h2;
|
||||
h2 += h1;
|
||||
|
||||
h1 = fmix64(h1);
|
||||
h2 = fmix64(h2);
|
||||
|
||||
h1 += h2;
|
||||
h2 += h1;
|
||||
|
||||
((uint64_t*)out)[0] = h1;
|
||||
((uint64_t*)out)[1] = h2;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
// 128 bit key = (64bit key + 64bit seed)
|
||||
// block num =1
|
||||
// len = 16
|
||||
uint64_t
|
||||
MurmurHash3_x64_64_Special(const uint64_t key, const uint64_t seed) {
|
||||
const uint8_t* data = (const uint8_t*)key;
|
||||
int i;
|
||||
|
||||
uint64_t h1 = key ^ seed;
|
||||
uint64_t h2 = seed;
|
||||
|
||||
uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5);
|
||||
uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f);
|
||||
|
||||
//----------
|
||||
// body
|
||||
|
||||
// 1 block
|
||||
{
|
||||
uint64_t k1 = key;
|
||||
uint64_t k2 = seed;
|
||||
|
||||
k1 *= c1;
|
||||
k1 = ROTL64(k1, 31);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
|
||||
h1 = ROTL64(h1, 27);
|
||||
h1 += h2;
|
||||
h1 = h1 * 5 + 0x52dce729;
|
||||
|
||||
k2 *= c2;
|
||||
k2 = ROTL64(k2, 33);
|
||||
k2 *= c1;
|
||||
h2 ^= k2;
|
||||
|
||||
h2 = ROTL64(h2, 31);
|
||||
h2 += h1;
|
||||
h2 = h2 * 5 + 0x38495ab5;
|
||||
}
|
||||
|
||||
//----------
|
||||
// No tail
|
||||
|
||||
//----------
|
||||
// finalization
|
||||
|
||||
h1 ^= 16;
|
||||
h2 ^= 16;
|
||||
|
||||
h1 += h2;
|
||||
h2 += h1;
|
||||
|
||||
h1 = fmix64(h1);
|
||||
h2 = fmix64(h2);
|
||||
|
||||
h1 += h2;
|
||||
h2 += h1;
|
||||
|
||||
return h1;
|
||||
}
|
||||
46
internal/core/src/rescores/Murmur3.h
Normal file
46
internal/core/src/rescores/Murmur3.h
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// MurmurHash3 was written by Austin Appleby, and is placed in the
|
||||
// public domain. The author hereby disclaims copyright to this source
|
||||
// code.
|
||||
|
||||
#ifndef _MURMURHASH3_H_
|
||||
#define _MURMURHASH3_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
void
|
||||
MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out);
|
||||
|
||||
void
|
||||
MurmurHash3_x86_128(const void* key, int len, uint32_t seed, void* out);
|
||||
|
||||
void
|
||||
MurmurHash3_x64_128(const void* key, int len, uint32_t seed, void* out);
|
||||
|
||||
uint64_t
|
||||
MurmurHash3_x64_64_Special(const uint64_t key, const uint64_t seed);
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // _MURMURHASH3_H_
|
||||
202
internal/core/src/rescores/Scorer.cpp
Normal file
202
internal/core/src/rescores/Scorer.cpp
Normal file
@ -0,0 +1,202 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include "common/Types.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
#include "Scorer.h"
|
||||
#include "Utils.h"
|
||||
#include "log/Log.h"
|
||||
#include "rescores/Murmur3.h"
|
||||
|
||||
namespace milvus::rescores {
|
||||
|
||||
void
|
||||
WeightScorer::batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmapView& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) {
|
||||
Assert(bitmap.size() == offsets.size());
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitmap[i] > 0) {
|
||||
set_score(boost_scores[i], mode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
WeightScorer::batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmap& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) {
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitmap[offsets[i]] > 0) {
|
||||
set_score(boost_scores[i], mode);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
WeightScorer::batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
std::vector<std::optional<float>>& boost_scores) {
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
set_score(boost_scores[i], mode);
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
WeightScorer::set_score(std::optional<float>& score,
|
||||
const proto::plan::FunctionMode& mode) {
|
||||
if (!score.has_value()) {
|
||||
score = std::make_optional(weight_);
|
||||
} else {
|
||||
score = std::make_optional(
|
||||
function_score_merge(score.value(), weight_, mode));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
RandomScorer::batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmapView& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) {
|
||||
Assert(bitmap.size() == offsets.size());
|
||||
FixedVector<int64_t> target_offsets;
|
||||
FixedVector<int> idx;
|
||||
target_offsets.reserve(offsets.size());
|
||||
idx.reserve(offsets.size());
|
||||
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitmap[i] > 0) {
|
||||
target_offsets.push_back(static_cast<int64_t>(offsets[i]));
|
||||
idx.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// skip if empty
|
||||
if (target_offsets.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
random_score(op_ctx, segment, mode, target_offsets, &idx, boost_scores);
|
||||
}
|
||||
|
||||
void
|
||||
RandomScorer::batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmap& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) {
|
||||
FixedVector<int64_t> target_offsets;
|
||||
FixedVector<int> idx;
|
||||
target_offsets.reserve(offsets.size());
|
||||
idx.reserve(offsets.size());
|
||||
|
||||
for (auto i = 0; i < offsets.size(); i++) {
|
||||
if (bitmap[offsets[i]] > 0) {
|
||||
target_offsets.push_back(static_cast<int64_t>(offsets[i]));
|
||||
idx.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// skip if empty
|
||||
if (target_offsets.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
random_score(op_ctx, segment, mode, target_offsets, &idx, boost_scores);
|
||||
}
|
||||
|
||||
void
|
||||
RandomScorer::batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
std::vector<std::optional<float>>& boost_scores) {
|
||||
FixedVector<int64_t> target_offsets;
|
||||
target_offsets.reserve(offsets.size());
|
||||
|
||||
for (int offset : offsets) {
|
||||
target_offsets.push_back(static_cast<int64_t>(offset));
|
||||
}
|
||||
|
||||
random_score(op_ctx, segment, mode, target_offsets, nullptr, boost_scores);
|
||||
}
|
||||
|
||||
void
|
||||
RandomScorer::random_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int64_t>& target_offsets,
|
||||
const FixedVector<int>* idx,
|
||||
std::vector<std::optional<float>>& boost_scores) {
|
||||
if (field_.get() != -1) {
|
||||
auto array = segment->bulk_subscript(
|
||||
op_ctx, field_, target_offsets.data(), target_offsets.size());
|
||||
AssertInfo(array->has_scalars(), "seed field must be scalar");
|
||||
AssertInfo(array->scalars().has_long_data(),
|
||||
"now only support int64 field as seed");
|
||||
// TODO: Support varchar and int32 field as random field.
|
||||
|
||||
auto datas = array->scalars().long_data();
|
||||
for (int i = 0; i < datas.data_size(); i++) {
|
||||
auto a = datas.data()[i];
|
||||
auto random_score =
|
||||
hash_to_double(MurmurHash3_x64_64_Special(a, seed_));
|
||||
if (idx == nullptr) {
|
||||
set_score(random_score, boost_scores[i], mode);
|
||||
} else {
|
||||
set_score(random_score, boost_scores[idx->at(i)], mode);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// if not set field, use offset and seed to hash.
|
||||
for (int i = 0; i < target_offsets.size(); i++) {
|
||||
double random_score = hash_to_double(MurmurHash3_x64_64_Special(
|
||||
target_offsets[i] + segment->get_segment_id(), seed_));
|
||||
if (idx == nullptr) {
|
||||
set_score(random_score, boost_scores[i], mode);
|
||||
} else {
|
||||
set_score(random_score, boost_scores[idx->at(i)], mode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
RandomScorer::set_score(float random_value,
|
||||
std::optional<float>& score,
|
||||
const proto::plan::FunctionMode& mode) {
|
||||
if (!score.has_value()) {
|
||||
score = std::make_optional(random_value * weight_);
|
||||
} else {
|
||||
score = std::make_optional(function_score_merge(
|
||||
score.value(), (random_value * weight_), mode));
|
||||
}
|
||||
}
|
||||
} // namespace milvus::rescores
|
||||
@ -16,16 +16,52 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
#include "pb/common.pb.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "common/protobuf_utils.h"
|
||||
|
||||
namespace milvus::rescores {
|
||||
class Scorer {
|
||||
public:
|
||||
virtual ~Scorer() = default;
|
||||
|
||||
virtual expr::TypedExprPtr
|
||||
filter() = 0;
|
||||
|
||||
virtual float
|
||||
rescore(float old_score) = 0;
|
||||
// filter result of offset[i] was bitmapview[i]
|
||||
// add boost score for idx[i] if bitmap[i] was true
|
||||
virtual void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmapView& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) = 0;
|
||||
|
||||
// score by bitmap
|
||||
// filter result of offset[i] was bitmap[offset[i]]
|
||||
// add boost score for idx[i] if bitmap[i] was true
|
||||
virtual void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmap& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) = 0;
|
||||
|
||||
// score for all offset
|
||||
// used when no filter
|
||||
virtual void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
std::vector<std::optional<float>>& boost_scores) = 0;
|
||||
|
||||
virtual float
|
||||
weight() = 0;
|
||||
@ -41,10 +77,28 @@ class WeightScorer : public Scorer {
|
||||
return filter_;
|
||||
}
|
||||
|
||||
float
|
||||
rescore(float old_score) override {
|
||||
return old_score * weight_;
|
||||
}
|
||||
void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmapView& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) override;
|
||||
|
||||
void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmap& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) override;
|
||||
|
||||
void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
std::vector<std::optional<float>>& boost_scores) override;
|
||||
|
||||
float
|
||||
weight() override {
|
||||
@ -52,7 +106,96 @@ class WeightScorer : public Scorer {
|
||||
}
|
||||
|
||||
private:
|
||||
void
|
||||
set_score(std::optional<float>& old_score,
|
||||
const proto::plan::FunctionMode& mode);
|
||||
|
||||
expr::TypedExprPtr filter_;
|
||||
float weight_;
|
||||
};
|
||||
|
||||
class RandomScorer : public Scorer {
|
||||
public:
|
||||
RandomScorer(expr::TypedExprPtr& filter,
|
||||
float weight,
|
||||
const ProtoParams& params) {
|
||||
auto param_map = RepeatedKeyValToMap(params);
|
||||
if (auto it = param_map.find("seed"); it != param_map.end()) {
|
||||
try {
|
||||
seed_ = std::stoll(it->second);
|
||||
} catch (const std::exception& e) {
|
||||
ThrowInfo(ErrorCode::InvalidParameter,
|
||||
"parse boost random seed params failed: {}",
|
||||
e.what());
|
||||
}
|
||||
}
|
||||
|
||||
if (auto it = param_map.find("field_id"); it != param_map.end()) {
|
||||
try {
|
||||
field_ = FieldId(std::stoll(it->second));
|
||||
} catch (const std::exception& e) {
|
||||
ThrowInfo(ErrorCode::InvalidParameter,
|
||||
"parse boost random seed field ID failed: {}",
|
||||
e.what());
|
||||
}
|
||||
} else {
|
||||
field_ = FieldId(-1);
|
||||
}
|
||||
|
||||
weight_ = weight;
|
||||
filter_ = filter;
|
||||
}
|
||||
|
||||
expr::TypedExprPtr
|
||||
filter() override {
|
||||
return filter_;
|
||||
}
|
||||
|
||||
void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmapView& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) override;
|
||||
|
||||
void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const TargetBitmap& bitmap,
|
||||
std::vector<std::optional<float>>& boost_scores) override;
|
||||
|
||||
void
|
||||
batch_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
std::vector<std::optional<float>>& boost_scores) override;
|
||||
|
||||
float
|
||||
weight() override {
|
||||
return weight_;
|
||||
}
|
||||
|
||||
private:
|
||||
void
|
||||
set_score(float random_value,
|
||||
std::optional<float>& old_score,
|
||||
const proto::plan::FunctionMode& mode);
|
||||
|
||||
void
|
||||
random_score(milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
const proto::plan::FunctionMode& mode,
|
||||
const FixedVector<int64_t>& target_offsets,
|
||||
const FixedVector<int>* idx,
|
||||
std::vector<std::optional<float>>& boost_scores);
|
||||
|
||||
expr::TypedExprPtr filter_;
|
||||
float weight_;
|
||||
int64_t seed_;
|
||||
FieldId field_;
|
||||
};
|
||||
} // namespace milvus::rescores
|
||||
56
internal/core/src/rescores/Utils.h
Normal file
56
internal/core/src/rescores/Utils.h
Normal file
@ -0,0 +1,56 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include "pb/plan.pb.h"
|
||||
#include "common/EasyAssert.h"
|
||||
|
||||
extern "C" {
|
||||
#include "Murmur3.h"
|
||||
}
|
||||
|
||||
namespace milvus::rescores {
|
||||
|
||||
inline float
|
||||
function_score_merge(const float& a,
|
||||
const float& b,
|
||||
const proto::plan::FunctionMode& mode) {
|
||||
switch (mode) {
|
||||
case proto::plan::FunctionModeMultiply:
|
||||
return a * b;
|
||||
case proto::plan::FunctionModeSum:
|
||||
return a + b;
|
||||
default:
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
fmt::format("unknown boost function mode: {}:{}",
|
||||
proto::plan::FunctionMode_Name(mode),
|
||||
mode));
|
||||
}
|
||||
}
|
||||
|
||||
#define MAGIC_BITS (0x3FFL << 52)
|
||||
|
||||
double
|
||||
hash_to_double(const uint64_t& h) {
|
||||
auto double_bytes = (MAGIC_BITS | (h & 0xFFFFFFFFFFFFF));
|
||||
double result;
|
||||
memcpy(&result, &double_bytes, sizeof(result));
|
||||
return result - 1.0;
|
||||
}
|
||||
} // namespace milvus::rescores
|
||||
@ -3,6 +3,7 @@ package planparserv2
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/antlr4-go/antlr/v4"
|
||||
@ -10,13 +11,16 @@ import (
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
@ -255,7 +259,7 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scorers, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)
|
||||
scorers, options, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -276,7 +280,8 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
|
||||
FieldId: fieldID,
|
||||
},
|
||||
},
|
||||
Scorers: scorers,
|
||||
Scorers: scorers,
|
||||
ScoreOption: options,
|
||||
PlanOptions: &planpb.PlanOption{
|
||||
ExprUseJsonStats: exprParams.UseJSONStats,
|
||||
},
|
||||
@ -284,6 +289,67 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
|
||||
return planNode, nil
|
||||
}
|
||||
|
||||
func prepareBoostRandomParams(schema *typeutil.SchemaHelper, bytes string) ([]*commonpb.KeyValuePair, error) {
|
||||
paramsMap := make(map[string]any)
|
||||
|
||||
dec := json.NewDecoder(strings.NewReader(bytes))
|
||||
dec.UseNumber()
|
||||
|
||||
err := dec.Decode(¶msMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*commonpb.KeyValuePair, 0)
|
||||
for key, value := range paramsMap {
|
||||
switch key {
|
||||
// parse field name to field ID
|
||||
case RandomScoreFileNameKey:
|
||||
name, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("random seed field name must be string")
|
||||
}
|
||||
|
||||
field, err := schema.GetFieldFromName(name)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrFieldNotFound(value, "random seed field not found")
|
||||
}
|
||||
|
||||
if field.DataType != schemapb.DataType_Int64 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("only support int64 field as random seed, but got %s", field.DataType.String())
|
||||
}
|
||||
result = append(result, &commonpb.KeyValuePair{Key: RandomScoreFileIdKey, Value: fmt.Sprint(field.FieldID)})
|
||||
case RandomScoreSeedKey:
|
||||
number, ok := value.(json.Number)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("random seed must be number")
|
||||
}
|
||||
|
||||
result = append(result, &commonpb.KeyValuePair{Key: key, Value: number.String()})
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func setBoostType(schema *typeutil.SchemaHelper, scorer *planpb.ScoreFunction, params []*commonpb.KeyValuePair) error {
|
||||
scorer.Type = planpb.FunctionType_FunctionTypeWeight
|
||||
for _, param := range params {
|
||||
switch param.GetKey() {
|
||||
case BoostRandomScoreKey:
|
||||
{
|
||||
scorer.Type = planpb.FunctionType_FunctionTypeRandom
|
||||
params, err := prepareBoostRandomParams(schema, param.GetValue())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
scorer.Params = params
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.FunctionSchema, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.ScoreFunction, error) {
|
||||
rerankerName := rerank.GetRerankName(function)
|
||||
switch rerankerName {
|
||||
@ -308,6 +374,12 @@ func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.Functi
|
||||
return nil, fmt.Errorf("parse function scorer weight params failed with error: {%v}", err)
|
||||
}
|
||||
scorer.Weight = float32(weight)
|
||||
|
||||
err = setBoostType(schema, scorer, function.GetParams())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return scorer, nil
|
||||
default:
|
||||
// if not boost scorer, regard as normal function scorer
|
||||
@ -317,22 +389,67 @@ func CreateSearchScorer(schema *typeutil.SchemaHelper, function *schemapb.Functi
|
||||
}
|
||||
}
|
||||
|
||||
func CreateSearchScorers(schema *typeutil.SchemaHelper, functionScore *schemapb.FunctionScore, exprTemplateValues map[string]*schemapb.TemplateValue) ([]*planpb.ScoreFunction, error) {
|
||||
func ParseBoostMode(s string) (planpb.BoostMode, error) {
|
||||
s = strings.ToLower(s)
|
||||
switch s {
|
||||
case "multiply":
|
||||
return planpb.BoostMode_BoostModeMultiply, nil
|
||||
case "sum":
|
||||
return planpb.BoostMode_BoostModeSum, nil
|
||||
default:
|
||||
return 0, merr.WrapErrParameterInvalidMsg("unknown boost mode: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func ParseFunctionMode(s string) (planpb.FunctionMode, error) {
|
||||
s = strings.ToLower(s)
|
||||
switch s {
|
||||
case "multiply":
|
||||
return planpb.FunctionMode_FunctionModeMultiply, nil
|
||||
case "sum":
|
||||
return planpb.FunctionMode_FunctionModeSum, nil
|
||||
default:
|
||||
return 0, merr.WrapErrParameterInvalidMsg("unknown function mode: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateSearchScorers(schema *typeutil.SchemaHelper, functionScore *schemapb.FunctionScore, exprTemplateValues map[string]*schemapb.TemplateValue) ([]*planpb.ScoreFunction, *planpb.ScoreOption, error) {
|
||||
scorers := []*planpb.ScoreFunction{}
|
||||
for _, function := range functionScore.GetFunctions() {
|
||||
// create scorer for search plan
|
||||
scorer, err := CreateSearchScorer(schema, function, exprTemplateValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if scorer != nil {
|
||||
scorers = append(scorers, scorer)
|
||||
}
|
||||
}
|
||||
if len(scorers) == 0 {
|
||||
return nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
return scorers, nil
|
||||
|
||||
option := &planpb.ScoreOption{}
|
||||
|
||||
s, ok := funcutil.TryGetAttrByKeyFromRepeatedKV(BoostModeKey, functionScore.GetParams())
|
||||
if ok {
|
||||
boostMode, err := ParseBoostMode(s)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
option.BoostMode = boostMode
|
||||
}
|
||||
|
||||
s, ok = funcutil.TryGetAttrByKeyFromRepeatedKV(BoostFunctionModeKey, functionScore.GetParams())
|
||||
if ok {
|
||||
functionMode, err := ParseFunctionMode(s)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
option.FunctionMode = functionMode
|
||||
}
|
||||
|
||||
return scorers, option, nil
|
||||
}
|
||||
|
||||
func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo, exprTemplateValues map[string]*schemapb.TemplateValue, functionScorer *schemapb.FunctionScore) (*planpb.PlanNode, error) {
|
||||
|
||||
@ -16,6 +16,15 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
BoostRandomScoreKey = "random_score"
|
||||
BoostModeKey = "boost_mode"
|
||||
BoostFunctionModeKey = "function_mode"
|
||||
RandomScoreSeedKey = "seed"
|
||||
RandomScoreFileNameKey = "field"
|
||||
RandomScoreFileIdKey = "field_id"
|
||||
)
|
||||
|
||||
func IsBool(n *planpb.GenericValue) bool {
|
||||
switch n.GetVal().(type) {
|
||||
case *planpb.GenericValue_BoolVal:
|
||||
|
||||
@ -3,6 +3,7 @@ package milvus.proto.plan;
|
||||
|
||||
option go_package = "github.com/milvus-io/milvus/pkg/v2/proto/planpb";
|
||||
import "schema.proto";
|
||||
import "common.proto";
|
||||
|
||||
enum OpType {
|
||||
Invalid = 0;
|
||||
@ -279,9 +280,35 @@ message QueryPlanNode {
|
||||
int64 limit = 3;
|
||||
};
|
||||
|
||||
enum FunctionType{
|
||||
FunctionTypeWeight = 0;
|
||||
FunctionTypeRandom = 1;
|
||||
}
|
||||
|
||||
// FunctionMode decide how to calculate boost score
|
||||
// for multiple boost function scores
|
||||
enum FunctionMode{
|
||||
FunctionModeMultiply = 0;
|
||||
FunctionModeSum = 1;
|
||||
};
|
||||
|
||||
// BoostMode decide how to calculate final score
|
||||
// for origin score and boost score.
|
||||
enum BoostMode{
|
||||
BoostModeMultiply = 0;
|
||||
BoostModeSum = 1;
|
||||
};
|
||||
|
||||
message ScoreFunction {
|
||||
Expr filter =1;
|
||||
float weight = 2;
|
||||
Expr filter = 1;
|
||||
float weight = 2;
|
||||
FunctionType type = 3;
|
||||
repeated common.KeyValuePair params = 4;
|
||||
}
|
||||
|
||||
message ScoreOption{
|
||||
BoostMode boost_mode = 1;
|
||||
FunctionMode function_mode = 2;
|
||||
}
|
||||
|
||||
message PlanOption {
|
||||
@ -298,4 +325,5 @@ message PlanNode {
|
||||
repeated string dynamic_fields = 5;
|
||||
repeated ScoreFunction scorers = 6;
|
||||
PlanOption plan_options = 7;
|
||||
ScoreOption score_option = 8;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user