feat: support load and query with bm25 metric (#36071)
Related issue: https://github.com/milvus-io/milvus/issues/35853

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in: parent 90285830de · commit db34572c56

Makefile (+3)
@@ -532,6 +532,9 @@ generate-mockery-utils: getdeps
 	# proxy_client_manager.go
 	$(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage
 	$(INSTALL_PATH)/mockery --name=ProxyWatcherInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_watcher.go --with-expecter --structname=MockProxyWatcher --inpackage
+	# function
+	$(INSTALL_PATH)/mockery --name=FunctionRunner --dir=$(PWD)/internal/util/function --output=$(PWD)/internal/util/function --filename=mock_function.go --with-expecter --structname=MockFunctionRunner --inpackage
+
 
 generate-mockery-kv: getdeps
 	$(INSTALL_PATH)/mockery --name=TxnKV --dir=$(PWD)/pkg/kv --output=$(PWD)/internal/kv/mocks --filename=txn_kv.go --with-expecter
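Note: the new mockery target generates a MockFunctionRunner for internal/util/function; the BM25 tests later in this diff drive it with testify-style expectations. A minimal sketch of that usage, with anything beyond what this diff itself shows treated as an assumption:

    // Go sketch: exercising the generated mock in a test.
    mockRunner := function.NewMockFunctionRunner(t)
    mockRunner.EXPECT().
        BatchRun(mock.Anything).
        Return([]interface{}{&schemapb.SparseFloatArray{}}, nil)

    // Code under test only sees the function.FunctionRunner interface.
    var runner function.FunctionRunner = mockRunner
    output, err := runner.BatchRun("some query text")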
@@ -410,7 +410,8 @@ inline bool
 IsFloatVectorMetricType(const MetricType& metric_type) {
     return metric_type == knowhere::metric::L2 ||
            metric_type == knowhere::metric::IP ||
-           metric_type == knowhere::metric::COSINE;
+           metric_type == knowhere::metric::COSINE ||
+           metric_type == knowhere::metric::BM25;
 }
 
 inline bool
@@ -160,13 +160,15 @@ inline bool
 IsFloatMetricType(const knowhere::MetricType& metric_type) {
     return IsMetricType(metric_type, knowhere::metric::L2) ||
            IsMetricType(metric_type, knowhere::metric::IP) ||
-           IsMetricType(metric_type, knowhere::metric::COSINE);
+           IsMetricType(metric_type, knowhere::metric::COSINE) ||
+           IsMetricType(metric_type, knowhere::metric::BM25);
 }
 
 inline bool
 PositivelyRelated(const knowhere::MetricType& metric_type) {
     return IsMetricType(metric_type, knowhere::metric::IP) ||
-           IsMetricType(metric_type, knowhere::metric::COSINE);
+           IsMetricType(metric_type, knowhere::metric::COSINE) ||
+           IsMetricType(metric_type, knowhere::metric::BM25);
 }
 
 inline std::string
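Note: with these two helpers, BM25 is classified alongside IP and COSINE as a metric where a larger score means a closer match, so result reducers keep descending order. An illustrative Go mirror of the predicate (this is the rule the C++ encodes, not a function from the repo):

    func positivelyRelated(metricType string) bool {
        switch metricType {
        case "IP", "COSINE", "BM25":
            return true // similarity: larger is better
        default: // e.g. "L2": a distance, smaller is better
            return false
        }
    }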
@@ -409,7 +409,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
     milvus::tracer::AddEvent("finish_knowhere_index_search");
     if (!res.has_value()) {
         PanicInfo(ErrorCode::UnexpectedError,
-                  "failed to search: {}: {}",
+                  "failed to search: config={} {}: {}",
+                  search_conf.dump(),
                   KnowhereStatusString(res.error()),
                   res.what());
     }
@@ -52,6 +52,11 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
         search_info.materialized_view_involved =
             query_info_proto.materialized_view_involved();
 
+        if (query_info_proto.bm25_avgdl() > 0) {
+            search_info.search_params_[knowhere::meta::BM25_AVGDL] =
+                query_info_proto.bm25_avgdl();
+        }
+
         if (query_info_proto.group_by_field_id() > 0) {
             auto group_by_field_id =
                 FieldId(query_info_proto.group_by_field_id());
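Note: this is the hand-off point between the query plan and knowhere: a positive bm25_avgdl in QueryInfo is copied into the segcore search params. A hypothetical rendering of the resulting parameter set (the "bm25_avgdl" key mirrors the knowhere::meta::BM25_AVGDL constant used above; the value is invented and would be whatever the delegator measured at query time):

    // Go sketch of the logical content of search_params_ for a BM25 query.
    searchParams := map[string]interface{}{
        "bm25_avgdl": 42.5, // assumed key string, invented value
    }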
@@ -10,6 +10,7 @@
 // or implied. See the License for the specific language governing permissions and limitations under the License
 
 #include "IndexConfigGenerator.h"
+#include "knowhere/comp/index_param.h"
 #include "log/Log.h"
 
 namespace milvus::segcore {
@@ -49,15 +50,28 @@ VecIndexConfig::VecIndexConfig(const int64_t max_index_row_cout,
         std::to_string(config_.get_nlist());
     build_params_[knowhere::indexparam::SSIZE] = std::to_string(
         std::max((int)(config_.get_chunk_rows() / config_.get_nlist()), 48));
+
+    if (is_sparse && metric_type_ == knowhere::metric::BM25) {
+        build_params_[knowhere::meta::BM25_K1] =
+            index_meta_.GetIndexParams().at(knowhere::meta::BM25_K1);
+        build_params_[knowhere::meta::BM25_B] =
+            index_meta_.GetIndexParams().at(knowhere::meta::BM25_B);
+        build_params_[knowhere::meta::BM25_AVGDL] =
+            index_meta_.GetIndexParams().at(knowhere::meta::BM25_AVGDL);
+    }
+
     search_params_[knowhere::indexparam::NPROBE] =
         std::to_string(config_.get_nprobe());
 
     // note for sparse vector index: drop_ratio_build is not allowed for growing
     // segment index.
     LOG_INFO(
-        "VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}",
+        "VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}, "
+        "config={}",
         origin_index_type_,
         index_type_,
-        metric_type_);
+        metric_type_,
+        build_params_.dump());
 }
 
 int64_t
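Note: for a growing segment's interim sparse index, the BM25 constants cannot be defaulted away; they are pulled from the sealed index meta so growing and sealed segments score consistently. An illustration of the extra build params (key constants as referenced above; the literal values shown are the defaults used by the test helper later in this diff):

    // Hypothetical rendering of build_params_ for a BM25 growing index.
    buildParams := map[string]string{
        "metric_type": "BM25",
        "bm25_k1":     "1.2",  // term-frequency saturation
        "bm25_b":      "0.75", // document-length normalization
        "bm25_avgdl":  "100",  // average document length of the corpus
    }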
@@ -100,6 +114,11 @@ VecIndexConfig::GetSearchConf(const SearchInfo& searchInfo) {
             searchParam.search_params_[key] = searchInfo.search_params_[key];
         }
     }
+
+    if (metric_type_ == knowhere::metric::BM25) {
+        searchParam.search_params_[knowhere::meta::BM25_AVGDL] =
+            searchInfo.search_params_[knowhere::meta::BM25_AVGDL];
+    }
     return searchParam;
 }
@@ -25,6 +25,7 @@
 #include "indexbuilder/ScalarIndexCreator.h"
 #include "indexbuilder/VecIndexCreator.h"
 #include "indexbuilder/index_c.h"
+#include "knowhere/comp/index_param.h"
 #include "pb/index_cgo_msg.pb.h"
 #include "storage/Types.h"
@@ -100,6 +101,14 @@ generate_build_conf(const milvus::IndexType& index_type,
         };
     } else if (index_type == knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX ||
                index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) {
+        if (metric_type == knowhere::metric::BM25) {
+            return knowhere::Json{
+                {knowhere::meta::METRIC_TYPE, metric_type},
+                {knowhere::indexparam::DROP_RATIO_BUILD, "0.1"},
+                {knowhere::meta::BM25_K1, "1.2"},
+                {knowhere::meta::BM25_B, "0.75"},
+                {knowhere::meta::BM25_AVGDL, "100"}};
+        }
         return knowhere::Json{
             {knowhere::meta::METRIC_TYPE, metric_type},
             {knowhere::indexparam::DROP_RATIO_BUILD, "0.1"},
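Note: k1, b, and avgdl are the three knobs of the standard BM25 scoring function. For reference (textbook form, not copied from Milvus source):

    \mathrm{score}(q, d) = \sum_{t \in q} \mathrm{IDF}(t) \cdot
        \frac{f(t, d)\,(k_1 + 1)}{f(t, d) + k_1 \left(1 - b + b \, \frac{|d|}{\mathrm{avgdl}}\right)}

k1 (1.2 above) bounds how quickly repeated terms saturate, b (0.75) controls document-length normalization, and avgdl is the corpus-average document length. avgdl is the only corpus-dependent term, which is why the build-time "100" is effectively a placeholder and the live value is re-stamped at search time by the delegator changes below.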
@@ -652,7 +652,7 @@ func (s *L0CompactionTaskSuite) TestPorcessStateTrans() {
 		s.Equal(datapb.CompactionTaskState_failed, t.GetState())
 	})
 
-	s.Run("test unkonwn task", func() {
+	s.Run("test unknown task", func() {
 		t := s.generateTestL0Task(datapb.CompactionTaskState_unknown)
 
 		got := t.Process()
@@ -73,7 +73,7 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e
 }
 
 func (eNode *embeddingNode) Name() string {
-	return fmt.Sprintf("embeddingNode-%s-%s", "BM25test", eNode.channelName)
+	return fmt.Sprintf("embeddingNode-%s", eNode.channelName)
 }
 
 func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, inputFieldId, outputFieldId int64, data *storage.InsertData, meta map[int64]*storage.BM25Stats) error {
@@ -96,6 +96,7 @@ message SubSearchRequest {
   string metricType = 9;
   int64 group_by_field_id = 10;
   int64 group_size = 11;
+  int64 field_id = 12;
 }
 
 message SearchRequest {
@@ -124,6 +125,7 @@ message SearchRequest {
   common.ConsistencyLevel consistency_level = 22;
   int64 group_by_field_id = 23;
   int64 group_size = 24;
+  int64 field_id = 25;
 }
 
 message SubSearchResults {
@@ -64,6 +64,8 @@ message QueryInfo {
   bool materialized_view_involved = 7;
   int64 group_size = 8;
   bool group_strict_size = 9;
+  double bm25_avgdl = 10;
+  int64 query_field_id = 11;
 }
 
 message ColumnInfo {
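Note: these two fields carry the BM25 context through the plan. The proxy records which ANN field the search targets (query_field_id), and the shard delegator later measures the live corpus and stamps the average document length (bm25_avgdl) back into the serialized plan before fan-out. A fragment assembled from the changes in this diff, not a verbatim excerpt:

    // Proxy side, while generating the plan (see tryGeneratePlan below):
    queryInfo.QueryFieldId = annField.GetFieldID()

    // Delegator side, per request (see buildBM25IDF below): avgdl is
    // derived from current segment statistics, then written back as
    // queryInfo.Bm25Avgdl into the serialized plan.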
@@ -367,6 +367,7 @@ message SegmentLoadInfo {
   int64 storageVersion = 18;
   bool is_sorted = 19;
   map<int64, data.TextIndexStats> textStatsLogs = 20;
+  repeated data.FieldBinlog bm25logs = 21;
 }
 
 message FieldIndexInfo {
@@ -361,8 +361,12 @@ func (cit *createIndexTask) parseIndexParams() error {
 				return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "float vector index does not support metric type: "+metricType)
 			}
 		} else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) {
-			if metricType != metric.IP {
-				return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP is the supported metric type for sparse index")
+			if metricType != metric.IP && metricType != metric.BM25 {
+				return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP&BM25 is the supported metric type for sparse index")
+			}
+
+			if metricType == metric.BM25 && cit.functionSchema.GetType() != schemapb.FunctionType_BM25 {
+				return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only BM25 Function output field support BM25 metric type")
 			}
 		} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) {
 			if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) {
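Note: a hedged illustration of index params that should now pass this validation, assuming the target sparse field is the output of a BM25 function (key/value conventions as used elsewhere in Milvus; the exact client call is out of scope here):

    indexParams := []*commonpb.KeyValuePair{
        {Key: "index_type", Value: "SPARSE_INVERTED_INDEX"},
        {Key: "metric_type", Value: "BM25"},
    }
    // Rejected with the errors above when the field is not a BM25 Function
    // output, or when a sparse field asks for any metric other than IP or BM25.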
@@ -370,6 +370,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
 			GroupSize:      t.rankParams.GetGroupSize(),
 		}
 
+		internalSubReq.FieldId = queryInfo.GetQueryFieldId()
 		// set PartitionIDs for sub search
 		if t.partitionKeyMode {
 			// isolatioin has tighter constraint, check first
@@ -449,6 +450,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
 	}
 
 	t.SearchRequest.Offset = offset
+	t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
 
 	if t.partitionKeyMode {
 		// isolatioin has tighter constraint, check first
@@ -511,6 +513,8 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
 	if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
 		return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
 	}
 
+	queryInfo.QueryFieldId = annField.GetFieldID()
 	plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
 	if planErr != nil {
 		log.Warn("failed to create query plan", zap.Error(planErr),
@@ -81,6 +81,7 @@ func PackSegmentLoadInfo(segment *datapb.SegmentInfo, channelCheckpoint *msgpb.M
 		NumOfRows:     segment.NumOfRows,
 		Statslogs:     segment.Statslogs,
 		Deltalogs:     segment.Deltalogs,
+		Bm25Logs:      segment.Bm25Statslogs,
 		InsertChannel: segment.InsertChannel,
 		IndexInfos:    indexes,
 		StartPosition: segment.GetStartPosition(),
@@ -34,6 +34,7 @@ import (
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@@ -43,6 +44,7 @@ import (
 	"github.com/milvus-io/milvus/internal/querynodev2/segments"
 	"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
 	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/function"
 	"github.com/milvus-io/milvus/internal/util/reduce"
 	"github.com/milvus-io/milvus/internal/util/streamrpc"
 	"github.com/milvus-io/milvus/pkg/common"
@@ -54,6 +56,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/util/lifetime"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/metautil"
+	"github.com/milvus-io/milvus/pkg/util/metric"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/timerecord"
 	"github.com/milvus-io/milvus/pkg/util/tsoutil"
@@ -111,6 +114,8 @@ type shardDelegator struct {
 	lifetime lifetime.Lifetime[lifetime.State]
 
 	distribution *distribution
+	idfOracle    IDFOracle
+
 	segmentManager segments.SegmentManager
 	tsafeManager   tsafe.Manager
 	pkOracle       pkoracle.PkOracle
@@ -135,6 +140,10 @@ type shardDelegator struct {
 	// in order to make add/remove growing be atomic, need lock before modify these meta info
 	growingSegmentLock sync.RWMutex
 	partitionStatsMut  sync.RWMutex
+
+	// fieldId -> functionRunner map for search function field
+	functionRunners map[UniqueID]function.FunctionRunner
+	hasBM25Field    bool
 }
 
 // getLogger returns the zap logger with pre-defined shard attributes.
@@ -235,6 +244,19 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
 		}()
 	}
 
+	// build idf for bm25 search
+	if req.GetReq().GetMetricType() == metric.BM25 {
+		avgdl, err := sd.buildBM25IDF(req.GetReq())
+		if err != nil {
+			return nil, err
+		}
+
+		if avgdl <= 0 {
+			log.Warn("search bm25 from empty data, skip search", zap.String("channel", sd.vchannelName), zap.Float64("avgdl", avgdl))
+			return []*internalpb.SearchResults{}, nil
+		}
+	}
+
 	// get final sealedNum after possible segment prune
 	sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
 	log.Debug("search segments...",
@@ -335,6 +357,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
 			IsAdvanced:     false,
 			GroupByFieldId: subReq.GetGroupByFieldId(),
 			GroupSize:      subReq.GetGroupSize(),
+			FieldId:        subReq.GetFieldId(),
 		}
 		future := conc.Go(func() (*internalpb.SearchResults, error) {
 			searchReq := &querypb.SearchRequest{
@@ -862,6 +885,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
 
 	excludedSegments := NewExcludedSegments(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.GetAsDuration(time.Second))
 
+	idfOracle := NewIDFOracle(collection.Schema().GetFunctions())
 	sd := &shardDelegator{
 		collectionID: collectionID,
 		replicaID:    replicaID,
@@ -871,7 +895,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
 		segmentManager: manager.Segment,
 		workerManager:  workerManager,
 		lifetime:       lifetime.NewLifetime(lifetime.Initializing),
-		distribution:   NewDistribution(),
+		distribution:   NewDistribution(idfOracle),
 		deleteBuffer:   deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock),
 		pkOracle:       pkoracle.NewPkOracle(),
 		tsafeManager:   tsafeManager,
@@ -880,9 +904,25 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
 		factory:          factory,
 		queryHook:        queryHook,
 		chunkManager:     chunkManager,
+		idfOracle:        idfOracle,
 		partitionStats:   make(map[UniqueID]*storage.PartitionStatsSnapshot),
 		excludedSegments: excludedSegments,
+		functionRunners:  make(map[int64]function.FunctionRunner),
 	}
 
+	for _, tf := range collection.Schema().GetFunctions() {
+		if tf.GetType() == schemapb.FunctionType_BM25 {
+			functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf)
+			if err != nil {
+				return nil, err
+			}
+			sd.functionRunners[tf.OutputFieldIds[0]] = functionRunner
+			if tf.GetType() == schemapb.FunctionType_BM25 {
+				sd.hasBM25Field = true
+			}
+		}
+	}
+
 	m := sync.Mutex{}
 	sd.tsCond = sync.NewCond(&m)
 	if sd.lifetime.Add(lifetime.NotStopped) == nil {
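Note: the IDFOracle used throughout the delegator is never shown in this diff; the shape below is inferred from its call sites here, and the concrete interface may differ in detail:

    type IDFOracle interface {
        // Keep corpus statistics aligned with the serviceable segment set.
        SyncDistribution(snapshot *snapshot)

        // Feed per-segment, per-output-field BM25 statistics.
        Register(segmentID int64, stats map[int64]*storage.BM25Stats, segType commonpb.SegmentState)
        UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)

        // Reweight query term frequencies into an IDF sparse vector and
        // report the current average document length (types assumed).
        BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) (*schemapb.SparseFloatArray, float64, error)
    }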
@@ -27,11 +27,14 @@ import (
 	"github.com/samber/lo"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/distributed/streaming"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/proto/segcorepb"
 	"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@@ -67,6 +70,8 @@ type InsertData struct {
 	PrimaryKeys  []storage.PrimaryKey
 	Timestamps   []uint64
 	InsertRecord *segcorepb.InsertRecord
+	BM25Stats    map[int64]*storage.BM25Stats
+
 	StartPosition *msgpb.MsgPosition
 	PartitionID   int64
 }
@@ -149,6 +154,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
 		if !sd.pkOracle.Exists(growing, paramtable.GetNodeID()) {
 			// register created growing segment after insert, avoid to add empty growing to delegator
 			sd.pkOracle.Register(growing, paramtable.GetNodeID())
+			sd.idfOracle.Register(segmentID, insertData.BM25Stats, segments.SegmentTypeGrowing)
 			sd.segmentManager.Put(context.Background(), segments.SegmentTypeGrowing, growing)
 			sd.addGrowing(SegmentEntry{
 				NodeID:        paramtable.GetNodeID(),
@@ -158,10 +164,12 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
 				TargetVersion: initialTargetVersion,
 			})
-		}
-		sd.growingSegmentLock.Unlock()
-
-		log.Debug("insert into growing segment",
+			sd.growingSegmentLock.Unlock()
+		} else {
+			sd.idfOracle.UpdateGrowing(growing.ID(), insertData.BM25Stats)
+		}
+		log.Info("insert into growing segment",
 			zap.Int64("collectionID", growing.Collection()),
 			zap.Int64("segmentID", segmentID),
 			zap.Int("rowCount", len(insertData.RowIDs)),
@@ -375,8 +383,11 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm
 	segmentIDs = lo.Map(loaded, func(segment segments.Segment, _ int) int64 { return segment.ID() })
 	log.Info("load growing segments done", zap.Int64s("segmentIDs", segmentIDs))
 
-	for _, candidate := range loaded {
-		sd.pkOracle.Register(candidate, paramtable.GetNodeID())
+	for _, segment := range loaded {
+		sd.pkOracle.Register(segment, paramtable.GetNodeID())
+		if sd.hasBM25Field {
+			sd.idfOracle.Register(segment.ID(), segment.GetBM25Stats(), segments.SegmentTypeGrowing)
+		}
 	}
 	sd.addGrowing(lo.Map(loaded, func(segment segments.Segment, _ int) SegmentEntry {
 		return SegmentEntry{
@@ -472,6 +483,16 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
 	infos := lo.Filter(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) bool {
 		return !sd.pkOracle.Exists(pkoracle.NewCandidateKey(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed), targetNodeID)
 	})
+
+	var bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
+	if sd.hasBM25Field {
+		bm25Stats, err = sd.loader.LoadBM25Stats(ctx, req.GetCollectionID(), infos...)
+		if err != nil {
+			log.Warn("failed to load bm25 stats for segment", zap.Error(err))
+			return err
+		}
+	}
+
 	candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), req.GetVersion(), infos...)
 	if err != nil {
 		log.Warn("failed to load bloom filter set for segment", zap.Error(err))
@@ -479,7 +500,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
 	}
 
 	log.Debug("load delete...")
-	err = sd.loadStreamDelete(ctx, candidates, infos, req, targetNodeID, worker)
+	err = sd.loadStreamDelete(ctx, candidates, bm25Stats, infos, req, targetNodeID, worker)
 	if err != nil {
 		log.Warn("load stream delete failed", zap.Error(err))
 		return err
@@ -552,6 +573,7 @@ func (sd *shardDelegator) RefreshLevel0DeletionStats() {
 
 func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
 	candidates []*pkoracle.BloomFilterSet,
+	bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats],
 	infos []*querypb.SegmentLoadInfo,
 	req *querypb.LoadSegmentsRequest,
 	targetNodeID int64,
@@ -665,6 +687,14 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
 		)
 		sd.pkOracle.Register(candidate, targetNodeID)
 	}
+
+	if bm25Stats != nil {
+		bm25Stats.Range(func(segmentID int64, stats map[int64]*storage.BM25Stats) bool {
+			sd.idfOracle.Register(segmentID, stats, segments.SegmentTypeSealed)
+			return false
+		})
+	}
+
 	log.Info("load delete done")
 
 	return nil
@@ -963,3 +993,47 @@ func (sd *shardDelegator) TryCleanExcludedSegments(ts uint64) {
 		sd.excludedSegments.CleanInvalid(ts)
 	}
 }
+
+func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64, error) {
+	pb := &commonpb.PlaceholderGroup{}
+	proto.Unmarshal(req.GetPlaceholderGroup(), pb)
+
+	if len(pb.Placeholders) != 1 || len(pb.Placeholders[0].Values) != 1 {
+		return 0, merr.WrapErrParameterInvalidMsg("please provide varchar for bm25")
+	}
+
+	holder := pb.Placeholders[0]
+	if holder.Type != commonpb.PlaceholderType_VarChar {
+		return 0, fmt.Errorf("can't build BM25 IDF for data not varchar")
+	}
+
+	str := funcutil.GetVarCharFromPlaceholder(holder)
+	functionRunner, ok := sd.functionRunners[req.GetFieldId()]
+	if !ok {
+		return 0, fmt.Errorf("functionRunner not found for field: %d", req.GetFieldId())
+	}
+
+	// get search text term frequency
+	output, err := functionRunner.BatchRun(str)
+	if err != nil {
+		return 0, err
+	}
+
+	tfArray, ok := output[0].(*schemapb.SparseFloatArray)
+	if !ok {
+		return 0, fmt.Errorf("functionRunner return unknown data")
+	}
+
+	idfSparseVector, avgdl, err := sd.idfOracle.BuildIDF(req.GetFieldId(), tfArray)
+	if err != nil {
+		return 0, err
+	}
+
+	err = SetBM25Params(req, avgdl)
+	if err != nil {
+		return 0, err
+	}
+
+	req.PlaceholderGroup = funcutil.SparseVectorDataToPlaceholderGroupBytes(idfSparseVector)
+	return avgdl, nil
+}
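Note: a toy, self-contained version of the reweighting step that buildBM25IDF delegates to the oracle, using the common BM25 IDF variant (the formula is an assumption; the real implementation lives in the IDF oracle and may differ):

    package main

    import "math"

    // buildIDF multiplies each query term frequency by an IDF derived from
    // global document frequencies, mirroring the BuildIDF step above.
    func buildIDF(tf map[uint32]float32, docFreq map[uint32]int64, numDocs int64) map[uint32]float32 {
        out := make(map[uint32]float32, len(tf))
        for term, f := range tf {
            n := float64(docFreq[term])
            idf := math.Log(1 + (float64(numDocs)-n+0.5)/(n+0.5))
            out[term] = f * float32(idf)
        }
        return out
    }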
@@ -26,14 +26,19 @@ import (
 	"time"
 
 	"github.com/cockroachdb/errors"
+	"github.com/pingcap/log"
 	"github.com/samber/lo"
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/suite"
+	"go.uber.org/zap"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/planpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/proto/segcorepb"
 	"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@@ -42,10 +47,12 @@ import (
 	"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/bloomfilter"
+	"github.com/milvus-io/milvus/internal/util/function"
 	"github.com/milvus-io/milvus/internal/util/initcore"
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
+	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/metautil"
 	"github.com/milvus-io/milvus/pkg/util/metric"
@@ -95,13 +102,7 @@ func (s *DelegatorDataSuite) TearDownSuite() {
 	paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.Key)
 }
 
-func (s *DelegatorDataSuite) SetupTest() {
-	s.workerManager = &cluster.MockManager{}
-	s.manager = segments.NewManager()
-	s.tsafeManager = tsafe.NewTSafeReplica()
-	s.loader = &segments.MockLoader{}
-
-	// init schema
+func (s *DelegatorDataSuite) genNormalCollection() {
 	s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
 		Name: "TestCollection",
 		Fields: []*schemapb.FieldSchema{
@@ -154,7 +155,59 @@ func (s *DelegatorDataSuite) SetupTest() {
 		LoadType:     querypb.LoadType_LoadCollection,
 		PartitionIDs: []int64{1001, 1002},
 	})
+}
+
+func (s *DelegatorDataSuite) genCollectionWithFunction() {
+	s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
+		Name: "TestCollection",
+		Fields: []*schemapb.FieldSchema{
+			{
+				Name:         "id",
+				FieldID:      100,
+				IsPrimaryKey: true,
+				DataType:     schemapb.DataType_Int64,
+				AutoID:       true,
+			}, {
+				Name:         "vector",
+				FieldID:      101,
+				IsPrimaryKey: false,
+				DataType:     schemapb.DataType_SparseFloatVector,
+			}, {
+				Name:     "text",
+				FieldID:  102,
+				DataType: schemapb.DataType_VarChar,
+				TypeParams: []*commonpb.KeyValuePair{
+					{
+						Key:   common.MaxLengthKey,
+						Value: "256",
+					},
+				},
+			},
+		},
+		Functions: []*schemapb.FunctionSchema{{
+			Type:           schemapb.FunctionType_BM25,
+			InputFieldIds:  []int64{102},
+			OutputFieldIds: []int64{101},
+		}},
+	}, nil, nil)
+
+	delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
+		NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
+			return s.mq, nil
+		},
+	}, 10000, nil, s.chunkManager)
+	s.NoError(err)
+	s.delegator = delegator.(*shardDelegator)
+}
+
+func (s *DelegatorDataSuite) SetupTest() {
+	s.workerManager = &cluster.MockManager{}
+	s.manager = segments.NewManager()
+	s.tsafeManager = tsafe.NewTSafeReplica()
+	s.loader = &segments.MockLoader{}
+
+	// init schema
+	s.genNormalCollection()
 	s.mq = &msgstream.MockMsgStream{}
 	s.rootPath = s.Suite.T().Name()
 	chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
@@ -471,6 +524,127 @@ func (s *DelegatorDataSuite) TestProcessDelete() {
 	s.False(s.delegator.distribution.Serviceable())
 }
 
+func (s *DelegatorDataSuite) TestLoadGrowingWithBM25() {
+	s.genCollectionWithFunction()
+	mockSegment := segments.NewMockSegment(s.T())
+	s.loader.EXPECT().Load(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]segments.Segment{mockSegment}, nil)
+
+	mockSegment.EXPECT().Partition().Return(111)
+	mockSegment.EXPECT().ID().Return(111)
+	mockSegment.EXPECT().Type().Return(commonpb.SegmentState_Growing)
+	mockSegment.EXPECT().GetBM25Stats().Return(map[int64]*storage.BM25Stats{})
+
+	err := s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{{SegmentID: 1}}, 1)
+	s.NoError(err)
+}
+
+func (s *DelegatorDataSuite) TestLoadSegmentsWithBm25() {
+	s.genCollectionWithFunction()
+	s.Run("normal_run", func() {
+		defer func() {
+			s.workerManager.ExpectedCalls = nil
+			s.loader.ExpectedCalls = nil
+		}()
+
+		statsMap := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]()
+		stats := storage.NewBM25Stats()
+		stats.Append(map[uint32]float32{1: 1})
+		statsMap.Insert(1, map[int64]*storage.BM25Stats{101: stats})
+
+		s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(statsMap, nil)
+		s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
+			Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
+			return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
+				return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
+			})
+		}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
+			return nil
+		})
+
+		workers := make(map[int64]*cluster.MockWorker)
+		worker1 := &cluster.MockWorker{}
+		workers[1] = worker1
+
+		worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
+			Return(nil)
+		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
+			return workers[nodeID]
+		}, nil)
+
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
+			Base:         commonpbutil.NewMsgBase(),
+			DstNodeID:    1,
+			CollectionID: s.collectionID,
+			Infos: []*querypb.SegmentLoadInfo{
+				{
+					SegmentID:     100,
+					PartitionID:   500,
+					StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					Level:         datapb.SegmentLevel_L1,
+					InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID),
+				},
+			},
+		})
+
+		s.NoError(err)
+		sealed, _ := s.delegator.GetSegmentInfo(false)
+		s.Require().Equal(1, len(sealed))
+		s.Equal(int64(1), sealed[0].NodeID)
+		s.ElementsMatch([]SegmentEntry{
+			{
+				SegmentID:     100,
+				NodeID:        1,
+				PartitionID:   500,
+				TargetVersion: unreadableTargetVersion,
+				Level:         datapb.SegmentLevel_L1,
+			},
+		}, sealed[0].Segments)
+	})
+
+	s.Run("loadBM25_failed", func() {
+		defer func() {
+			s.workerManager.ExpectedCalls = nil
+			s.loader.ExpectedCalls = nil
+		}()
+
+		s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(nil, fmt.Errorf("mock error"))
+
+		workers := make(map[int64]*cluster.MockWorker)
+		worker1 := &cluster.MockWorker{}
+		workers[1] = worker1
+
+		worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
+			Return(nil)
+		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
+			return workers[nodeID]
+		}, nil)
+
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
+			Base:         commonpbutil.NewMsgBase(),
+			DstNodeID:    1,
+			CollectionID: s.collectionID,
+			Infos: []*querypb.SegmentLoadInfo{
+				{
+					SegmentID:     100,
+					PartitionID:   500,
+					StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					Level:         datapb.SegmentLevel_L1,
+					InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID),
+				},
+			},
+		})
+
+		s.Error(err)
+	})
+}
+
 func (s *DelegatorDataSuite) TestLoadSegments() {
 	s.Run("normal_run", func() {
 		defer func() {
@@ -883,6 +1057,214 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
 	})
 }
 
+func (s *DelegatorDataSuite) TestBuildBM25IDF() {
+	s.genCollectionWithFunction()
+
+	genBM25Stats := func(start uint32, end uint32) map[int64]*storage.BM25Stats {
+		result := make(map[int64]*storage.BM25Stats)
+		result[101] = storage.NewBM25Stats()
+		for i := start; i < end; i++ {
+			row := map[uint32]float32{i: 1}
+			result[101].Append(row)
+		}
+		return result
+	}
+
+	genSnapShot := func(seals, grows []int64, targetVersion int64) *snapshot {
+		snapshot := &snapshot{
+			dist:          []SnapshotItem{{1, make([]SegmentEntry, 0)}},
+			targetVersion: targetVersion,
+		}
+
+		newSeal := []SegmentEntry{}
+		for _, seg := range seals {
+			newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion})
+		}
+
+		newGrow := []SegmentEntry{}
+		for _, seg := range grows {
+			newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion})
+		}
+
+		log.Info("Test-", zap.Any("shanshot", snapshot), zap.Any("seg", newSeal))
+		snapshot.dist[0].Segments = newSeal
+		snapshot.growing = newGrow
+		return snapshot
+	}
+
+	genStringFieldData := func(strs ...string) *schemapb.FieldData {
+		return &schemapb.FieldData{
+			Type:    schemapb.DataType_VarChar,
+			FieldId: 102,
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_StringData{
+						StringData: &schemapb.StringArray{
+							Data: strs,
+						},
+					},
+				},
+			},
+		}
+	}
+
+	s.Run("normal case", func() {
+		// register sealed
+		sealedSegs := []int64{1, 2, 3, 4}
+		for _, segID := range sealedSegs {
+			// every segment stats only has one token, avgdl = 1
+			s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
+		}
+		snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100)
+
+		s.delegator.idfOracle.SyncDistribution(snapshot)
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		plan, err := proto.Marshal(&planpb.PlanNode{
+			Node: &planpb.PlanNode_VectorAnns{
+				VectorAnns: &planpb.VectorANNS{
+					QueryInfo: &planpb.QueryInfo{},
+				},
+			},
+		})
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup:   placeholderGroupBytes,
+			SerializedExprPlan: plan,
+			FieldId:            101,
+		}
+		avgdl, err := s.delegator.buildBM25IDF(req)
+		s.NoError(err)
+		s.Equal(float64(1), avgdl)
+
+		// check avgdl in plan
+		newplan := &planpb.PlanNode{}
+		err = proto.Unmarshal(req.GetSerializedExprPlan(), newplan)
+		s.NoError(err)
+
+		annplan, ok := newplan.GetNode().(*planpb.PlanNode_VectorAnns)
+		s.Require().True(ok)
+		s.Equal(avgdl, annplan.VectorAnns.QueryInfo.Bm25Avgdl)
+
+		// check idf in placeholder
+		placeholder := &commonpb.PlaceholderGroup{}
+		err = proto.Unmarshal(req.GetPlaceholderGroup(), placeholder)
+		s.Require().NoError(err)
+		s.Equal(placeholder.GetPlaceholders()[0].GetType(), commonpb.PlaceholderType_SparseFloatVector)
+	})
+
+	s.Run("invalid place holder type error", func() {
+		placeholderGroupBytes, err := proto.Marshal(&commonpb.PlaceholderGroup{
+			Placeholders: []*commonpb.PlaceholderValue{{Type: commonpb.PlaceholderType_SparseFloatVector}},
+		})
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("no function runner error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          103, // invalid field id
+		}
+
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("function runner run failed error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		oldRunner := s.delegator.functionRunners
+		mockRunner := function.NewMockFunctionRunner(s.T())
+		s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner}
+		mockRunner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock err"))
+		defer func() {
+			s.delegator.functionRunners = oldRunner
+		}()
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("function runner output type error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		oldRunner := s.delegator.functionRunners
+		mockRunner := function.NewMockFunctionRunner(s.T())
+		s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner}
+		mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil)
+		defer func() {
+			s.delegator.functionRunners = oldRunner
+		}()
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("idf oracle build idf error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		oldRunner := s.delegator.functionRunners
+		mockRunner := function.NewMockFunctionRunner(s.T())
+		s.delegator.functionRunners = map[int64]function.FunctionRunner{103: mockRunner}
+		mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{&schemapb.SparseFloatArray{Contents: [][]byte{typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{1: 1})}}}, nil)
+		defer func() {
+			s.delegator.functionRunners = oldRunner
+		}()
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          103, // invalid field
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+		log.Info("test", zap.Error(err))
+	})
+
+	s.Run("set avgdl failed", func() {
+		// register sealed
+		sealedSegs := []int64{1, 2, 3, 4}
+		for _, segID := range sealedSegs {
+			// every segment stats only has one token, avgdl = 1
+			s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
+		}
+		snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100)
+
+		s.delegator.idfOracle.SyncDistribution(snapshot)
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+}
+
 func (s *DelegatorDataSuite) TestReleaseSegment() {
 	s.loader.EXPECT().
 		Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
@ -178,6 +178,84 @@ func (s *DelegatorSuite) TearDownTest() {
|
|||||||
s.delegator = nil
|
s.delegator = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
	s.Run("init function failed", func() {
		manager := segments.NewManager()
		manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
			Name: "TestCollection",
			Fields: []*schemapb.FieldSchema{
				{
					Name:         "id",
					FieldID:      100,
					IsPrimaryKey: true,
					DataType:     schemapb.DataType_Int64,
					AutoID:       true,
				}, {
					Name:         "vector",
					FieldID:      101,
					IsPrimaryKey: false,
					DataType:     schemapb.DataType_SparseFloatVector,
				},
			},
			Functions: []*schemapb.FunctionSchema{{
				Type:           schemapb.FunctionType_BM25,
				InputFieldIds:  []int64{102},
				OutputFieldIds: []int64{101, 103}, // invalid output field
			}},
		}, nil, nil)

		_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
			NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
				return s.mq, nil
			},
		}, 10000, nil, s.chunkManager)
		s.Error(err)
	})

	s.Run("init function success", func() {
		manager := segments.NewManager()
		manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
			Name: "TestCollection",
			Fields: []*schemapb.FieldSchema{
				{
					Name:         "id",
					FieldID:      100,
					IsPrimaryKey: true,
					DataType:     schemapb.DataType_Int64,
					AutoID:       true,
				}, {
					Name:         "vector",
					FieldID:      101,
					IsPrimaryKey: false,
					DataType:     schemapb.DataType_SparseFloatVector,
				}, {
					Name:     "text",
					FieldID:  102,
					DataType: schemapb.DataType_VarChar,
					TypeParams: []*commonpb.KeyValuePair{
						{
							Key:   common.MaxLengthKey,
							Value: "256",
						},
					},
				},
			},
			Functions: []*schemapb.FunctionSchema{{
				Type:           schemapb.FunctionType_BM25,
				InputFieldIds:  []int64{102},
				OutputFieldIds: []int64{101},
			}},
		}, nil, nil)

		_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
			NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
				return s.mq, nil
			},
		}, 10000, nil, s.chunkManager)
		s.NoError(err)
	})
}

func (s *DelegatorSuite) TestBasicInfo() {
	s.Equal(s.collectionID, s.delegator.Collection())
	s.Equal(s.version, s.delegator.Version())
@ -74,6 +74,8 @@ type distribution struct {
	// current is the snapshot for quick usage for search/query
	// generated for each change of distribution
	current *atomic.Pointer[snapshot]

	idfOracle IDFOracle
	// protects current & segments
	mut sync.RWMutex
}
@ -89,7 +91,7 @@ type SegmentEntry struct {
}

// NewDistribution creates a new distribution instance with all fields initialized.
func NewDistribution() *distribution {
func NewDistribution(idfOracle IDFOracle) *distribution {
	dist := &distribution{
		serviceable:     atomic.NewBool(false),
		growingSegments: make(map[UniqueID]SegmentEntry),
@ -98,6 +100,7 @@ func NewDistribution() *distribution {
		current:       atomic.NewPointer[snapshot](nil),
		offlines:      typeutil.NewSet[int64](),
		targetVersion: atomic.NewInt64(initialTargetVersion),
		idfOracle:     idfOracle,
	}

	dist.genSnapshot()
@ -367,6 +370,7 @@ func (d *distribution) genSnapshot() chan struct{} {
	d.current.Store(newSnapShot)
	// shall be a new one
	d.snapshots.GetOrInsert(d.snapshotVersion, newSnapShot)
	d.idfOracle.SyncDistribution(newSnapShot)

	// first snapshot, return closed chan
	if last == nil {
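
Taken together, the distribution now pushes every fresh snapshot into the IDF oracle, which flips each segment's statistics in or out of the global pool by matching target versions. A self-contained toy model of that activation rule (not the actual Milvus types, just the bookkeeping it implements):

package main

import "fmt"

// seg mirrors the activate/targetVersion bookkeeping of bm25Stats.
type seg struct {
	rows          int64
	active        bool
	targetVersion int64
}

// sync applies the same rule as idfOracle.SyncDistribution for sealed
// segments: activate on a version match, deactivate on a mismatch.
func sync(segs []*seg, target int64, totalRows *int64) {
	for _, s := range segs {
		if !s.active && s.targetVersion == target {
			s.active = true
			*totalRows += s.rows
		} else if s.active && s.targetVersion != target {
			s.active = false
			*totalRows -= s.rows
		}
	}
}

func main() {
	var total int64
	segs := []*seg{{rows: 1, targetVersion: 1}, {rows: 1, targetVersion: 1}}
	sync(segs, 1, &total)
	fmt.Println(total) // 2: both segments activated at version 1
	segs[1].targetVersion = 2 // only the second segment survives into version 2
	sync(segs, 2, &total)
	fmt.Println(total) // 1: the segment left behind is deactivated
}
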
@ -21,6 +21,8 @@ import (
	"time"

	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

type DistributionSuite struct {
@ -29,7 +31,7 @@ type DistributionSuite struct {
}

func (s *DistributionSuite) SetupTest() {
	s.dist = NewDistribution()
	s.dist = NewDistribution(NewIDFOracle([]*schemapb.FunctionSchema{}))
	s.Equal(initialTargetVersion, s.dist.getTargetVersion())
}

262
internal/querynodev2/delegator/idf_oracle.go
Normal file
@ -0,0 +1,262 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package delegator

import (
	"fmt"
	"sync"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/querynodev2/segments"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/log"
)

type IDFOracle interface {
	// Activate(segmentID int64, state commonpb.SegmentState) error
	// Deactivate(segmentID int64, state commonpb.SegmentState) error

	SyncDistribution(snapshot *snapshot)

	UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)

	Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
	Remove(segmentID int64, state commonpb.SegmentState)

	BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
}

type bm25Stats struct {
	stats         map[int64]*storage.BM25Stats
	activate      bool
	targetVersion int64
}

func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
	for fieldID, newstats := range stats {
		if stats, ok := s.stats[fieldID]; ok {
			stats.Merge(newstats)
		} else {
			log.Panic("merge failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
		}
	}
}

func (s *bm25Stats) Minus(stats map[int64]*storage.BM25Stats) {
	for fieldID, newstats := range stats {
		if stats, ok := s.stats[fieldID]; ok {
			stats.Minus(newstats)
		} else {
			log.Panic("minus failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
		}
	}
}

func (s *bm25Stats) GetStats(fieldID int64) (*storage.BM25Stats, error) {
	stats, ok := s.stats[fieldID]
	if !ok {
		return nil, fmt.Errorf("field not found in idf oracle BM25 stats")
	}
	return stats, nil
}

// NumRow returns the row count of any one field's stats; all BM25 output
// fields of a collection cover the same set of rows.
func (s *bm25Stats) NumRow() int64 {
	for _, stats := range s.stats {
		return stats.NumRow()
	}
	return 0
}

func newBm25Stats(functions []*schemapb.FunctionSchema) *bm25Stats {
	stats := &bm25Stats{
		stats: make(map[int64]*storage.BM25Stats),
	}

	for _, function := range functions {
		if function.GetType() == schemapb.FunctionType_BM25 {
			stats.stats[function.GetOutputFieldIds()[0]] = storage.NewBM25Stats()
		}
	}
	return stats
}

// idfOracle tracks per-segment BM25 statistics and maintains a merged view
// ("current") of all activated segments, from which query-time IDF vectors
// and avgdl are derived.
type idfOracle struct {
	sync.RWMutex

	current *bm25Stats

	growing map[int64]*bm25Stats
	sealed  map[int64]*bm25Stats

	targetVersion int64
}

func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) {
	o.Lock()
	defer o.Unlock()

	switch state {
	case segments.SegmentTypeGrowing:
		if _, ok := o.growing[segmentID]; ok {
			return
		}
		o.growing[segmentID] = &bm25Stats{
			stats:         stats,
			activate:      true,
			targetVersion: initialTargetVersion,
		}
		o.current.Merge(stats)
	case segments.SegmentTypeSealed:
		if _, ok := o.sealed[segmentID]; ok {
			return
		}
		o.sealed[segmentID] = &bm25Stats{
			stats:         stats,
			activate:      false,
			targetVersion: initialTargetVersion,
		}
	default:
		log.Warn("register segment with unknown state", zap.String("state", state.String()))
		return
	}
}

func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) {
	if len(stats) == 0 {
		return
	}

	o.Lock()
	defer o.Unlock()

	old, ok := o.growing[segmentID]
	if !ok {
		return
	}

	old.Merge(stats)
	if old.activate {
		o.current.Merge(stats)
	}
}

func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) {
	o.Lock()
	defer o.Unlock()

	switch state {
	case segments.SegmentTypeGrowing:
		if stats, ok := o.growing[segmentID]; ok {
			if stats.activate {
				o.current.Minus(stats.stats)
			}
			delete(o.growing, segmentID)
		}
	case segments.SegmentTypeSealed:
		if stats, ok := o.sealed[segmentID]; ok {
			if stats.activate {
				o.current.Minus(stats.stats)
			}
			delete(o.sealed, segmentID)
		}
	default:
		return
	}
}

func (o *idfOracle) activate(stats *bm25Stats) {
	stats.activate = true
	o.current.Merge(stats.stats)
}

func (o *idfOracle) deactivate(stats *bm25Stats) {
	stats.activate = false
	o.current.Minus(stats.stats)
}

// SyncDistribution aligns the oracle with a distribution snapshot: it adopts
// the snapshot's target version, then activates or deactivates each tracked
// segment so that "current" reflects exactly the serviceable segment set.
func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
	o.Lock()
	defer o.Unlock()

	sealed, growing := snapshot.Peek()

	for _, item := range sealed {
		for _, segment := range item.Segments {
			if stats, ok := o.sealed[segment.SegmentID]; ok {
				stats.targetVersion = segment.TargetVersion
			} else {
				log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
			}
		}
	}

	for _, segment := range growing {
		if stats, ok := o.growing[segment.SegmentID]; ok {
			stats.targetVersion = segment.TargetVersion
		} else {
			log.Warn("idf oracle lack some growing segment", zap.Int64("segmentID", segment.SegmentID))
		}
	}

	o.targetVersion = snapshot.targetVersion

	for _, stats := range o.sealed {
		if !stats.activate && stats.targetVersion == o.targetVersion {
			o.activate(stats)
		} else if stats.activate && stats.targetVersion != o.targetVersion {
			o.deactivate(stats)
		}
	}

	for _, stats := range o.growing {
		if !stats.activate && (stats.targetVersion == o.targetVersion || stats.targetVersion == initialTargetVersion) {
			o.activate(stats)
		} else if stats.activate && (stats.targetVersion != o.targetVersion && stats.targetVersion != initialTargetVersion) {
			o.deactivate(stats)
		}
	}

	log.Debug("sync distribution finished", zap.Int64("version", o.targetVersion), zap.Int64("numrow", o.current.NumRow()))
}

func (o *idfOracle) BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) {
	o.RLock()
	defer o.RUnlock()

	stats, err := o.current.GetStats(fieldID)
	if err != nil {
		return nil, 0, err
	}

	idfBytes := make([][]byte, len(tfs.GetContents()))
	for i, tf := range tfs.GetContents() {
		idf := stats.BuildIDF(tf)
		idfBytes[i] = idf
	}
	return idfBytes, stats.GetAvgdl(), nil
}

func NewIDFOracle(functions []*schemapb.FunctionSchema) IDFOracle {
	return &idfOracle{
		current: newBm25Stats(functions),
		growing: make(map[int64]*bm25Stats),
		sealed:  make(map[int64]*bm25Stats),
	}
}
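
For orientation, a sketch of the oracle's lifecycle as the delegator would drive it. This is a fragment as it would appear inside package delegator; `collectionSchema`, `segID`, `segStats`, `currentSnapshot`, `outputFieldID`, and `tfSparseArray` are assumed placeholders, not names from this commit:

// Sealed segments register deactivated and only start counting toward
// query-time IDF after a snapshot carrying their target version is synced.
oracle := NewIDFOracle(collectionSchema.GetFunctions())
oracle.Register(segID, segStats, commonpb.SegmentState_Sealed) // not yet counted
oracle.SyncDistribution(currentSnapshot)                       // activates on version match

// At query time: turn the query's term-frequency vector into IDF weights.
idfBytes, avgdl, err := oracle.BuildIDF(outputFieldID, tfSparseArray)
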
198
internal/querynodev2/delegator/idf_oracle_test.go
Normal file
@ -0,0 +1,198 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package delegator

import (
	"testing"

	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type IDFOracleSuite struct {
	suite.Suite
	collectionID     int64
	collectionSchema *schemapb.CollectionSchema
	idfOracle        *idfOracle

	targetVersion int64
	snapshot      *snapshot
}

func (suite *IDFOracleSuite) SetupSuite() {
	suite.collectionID = 111
	suite.collectionSchema = &schemapb.CollectionSchema{
		Functions: []*schemapb.FunctionSchema{{
			Type:           schemapb.FunctionType_BM25,
			InputFieldIds:  []int64{101},
			OutputFieldIds: []int64{102},
		}},
	}
}

func (suite *IDFOracleSuite) SetupTest() {
	suite.idfOracle = NewIDFOracle(suite.collectionSchema.GetFunctions()).(*idfOracle)
	suite.snapshot = &snapshot{
		dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}},
	}
	suite.targetVersion = 0
}

func (suite *IDFOracleSuite) genStats(start uint32, end uint32) map[int64]*storage.BM25Stats {
	result := make(map[int64]*storage.BM25Stats)
	result[102] = storage.NewBM25Stats()
	for i := start; i < end; i++ {
		row := map[uint32]float32{i: 1}
		result[102].Append(row)
	}
	return result
}

// updateSnapshot updates the test snapshot with newly sealed, newly growing,
// and dropped segments, bumping the target version.
func (suite *IDFOracleSuite) updateSnapshot(seals, grows, drops []int64) *snapshot {
	suite.targetVersion++
	snapshot := &snapshot{
		dist:          []SnapshotItem{{1, make([]SegmentEntry, 0)}},
		targetVersion: suite.targetVersion,
	}

	dropSet := typeutil.NewSet[int64]()
	dropSet.Insert(drops...)

	newSeal := []SegmentEntry{}
	for _, seg := range suite.snapshot.dist[0].Segments {
		if !dropSet.Contain(seg.SegmentID) {
			seg.TargetVersion = suite.targetVersion
		}
		newSeal = append(newSeal, seg)
	}
	for _, seg := range seals {
		newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion})
	}

	newGrow := []SegmentEntry{}
	for _, seg := range suite.snapshot.growing {
		if !dropSet.Contain(seg.SegmentID) {
			seg.TargetVersion = suite.targetVersion
		} else {
			seg.TargetVersion = redundantTargetVersion
		}
		newGrow = append(newGrow, seg)
	}
	for _, seg := range grows {
		newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion})
	}

	snapshot.dist[0].Segments = newSeal
	snapshot.growing = newGrow
	suite.snapshot = snapshot
	return snapshot
}

func (suite *IDFOracleSuite) TestSealed() {
	// register sealed
	sealedSegs := []int64{1, 2, 3, 4}
	for _, segID := range sealedSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
	}

	// duplicate register should be a no-op
	for _, segID := range sealedSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
	}

	// sealed segments are registered but still deactivated
	suite.Zero(suite.idfOracle.current.NumRow())

	// update and sync the snapshot to activate all sealed segments
	suite.updateSnapshot(sealedSegs, []int64{}, []int64{})
	suite.idfOracle.SyncDistribution(suite.snapshot)
	suite.Equal(int64(4), suite.idfOracle.current.NumRow())

	releasedSeg := []int64{1, 2, 3}
	suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
	suite.idfOracle.SyncDistribution(suite.snapshot)
	suite.Equal(int64(1), suite.idfOracle.current.NumRow())

	for _, segID := range releasedSeg {
		suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed)
	}

	sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
	bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
	suite.NoError(err)
	suite.Equal(float64(1), avgdl)
	suite.Equal(map[uint32]float32{4: 0.2876821}, typeutil.SparseFloatBytesToMap(bytes[0]))
}

func (suite *IDFOracleSuite) TestGrow() {
	// register growing
	growSegs := []int64{1, 2, 3, 4}
	for _, segID := range growSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
	}
	// duplicate register should be a no-op
	for _, segID := range growSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
	}

	// growing segments are activated immediately on register
	suite.Equal(int64(4), suite.idfOracle.current.NumRow())
	suite.updateSnapshot([]int64{}, growSegs, []int64{})

	releasedSeg := []int64{1, 2, 3}
	suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
	suite.idfOracle.SyncDistribution(suite.snapshot)
	suite.Equal(int64(1), suite.idfOracle.current.NumRow())

	suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
	suite.Equal(int64(2), suite.idfOracle.current.NumRow())

	for _, segID := range releasedSeg {
		suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing)
	}
}

func (suite *IDFOracleSuite) TestStats() {
	stats := newBm25Stats([]*schemapb.FunctionSchema{{
		Type:           schemapb.FunctionType_BM25,
		InputFieldIds:  []int64{101},
		OutputFieldIds: []int64{102},
	}})

	suite.Panics(func() {
		stats.Merge(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
	})

	suite.Panics(func() {
		stats.Minus(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
	})

	_, err := stats.GetStats(103)
	suite.Error(err)

	_, err = stats.GetStats(102)
	suite.NoError(err)
}

func TestIDFOracle(t *testing.T) {
	suite.Run(t, new(IDFOracleSuite))
}
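
The constant 0.2876821 asserted in TestSealed is simply the smoothed IDF of a token present in the single remaining row: ln(1 + (N - n + 0.5)/(n + 0.5)) with N = n = 1. A standalone check (this assumes BM25Stats uses the standard Robertson-style IDF, which the asserted value is consistent with):

package main

import (
	"fmt"
	"math"
)

func main() {
	N, n := 1.0, 1.0 // one row total, one row containing the token
	idf := math.Log(1 + (N-n+0.5)/(n+0.5))
	fmt.Printf("%.7f\n", idf) // prints 0.2876821
}
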
64
internal/querynodev2/delegator/util.go
Normal file
@ -0,0 +1,64 @@
package delegator

import (
	"fmt"

	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/merr"
)

func BuildSparseFieldData(field *schemapb.FieldSchema, sparseArray *schemapb.SparseFloatArray) *schemapb.FieldData {
	return &schemapb.FieldData{
		Type:      field.GetDataType(),
		FieldName: field.GetName(),
		Field: &schemapb.FieldData_Vectors{
			Vectors: &schemapb.VectorField{
				Dim: sparseArray.GetDim(),
				Data: &schemapb.VectorField_SparseFloatVector{
					SparseFloatVector: sparseArray,
				},
			},
		},
		FieldId: field.GetFieldID(),
	}
}

func SetBM25Params(req *internalpb.SearchRequest, avgdl float64) error {
	log := log.With(zap.Int64("collection", req.GetCollectionID()))

	serializedPlan := req.GetSerializedExprPlan()
	// plan not found
	if serializedPlan == nil {
		log.Warn("serialized plan not found")
		return merr.WrapErrParameterInvalid("serialized search plan", "nil")
	}

	plan := planpb.PlanNode{}
	err := proto.Unmarshal(serializedPlan, &plan)
	if err != nil {
		log.Warn("failed to unmarshal plan", zap.Error(err))
		return merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
	}

	switch plan.GetNode().(type) {
	case *planpb.PlanNode_VectorAnns:
		queryInfo := plan.GetVectorAnns().GetQueryInfo()
		queryInfo.Bm25Avgdl = avgdl
		serializedExprPlan, err := proto.Marshal(&plan)
		if err != nil {
			log.Warn("failed to marshal optimized plan", zap.Error(err))
			return merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
		}
		req.SerializedExprPlan = serializedExprPlan
		log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
	default:
		log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
	}
	return nil
}
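
Putting the two helpers together, the delegator's query-side rewrite (the buildBM25IDF path exercised in the tests above) roughly: build IDF vectors from the placeholder's term frequencies, repack them as a sparse field, and stamp avgdl into the serialized plan. A hedged sketch, not the literal buildBM25IDF body; `oracle`, `outputField`, `tfs`, and `req` are assumed surrounding variables:

// Sketch of the query-side BM25 rewrite inside package delegator.
idfBytes, avgdl, err := oracle.BuildIDF(outputField.GetFieldID(), tfs)
if err != nil {
	return err
}
idfArray := &schemapb.SparseFloatArray{Contents: idfBytes, Dim: tfs.GetDim()}
fieldData := BuildSparseFieldData(outputField, idfArray)
_ = fieldData // would replace the request's placeholder group contents
if err := SetBM25Params(req, avgdl); err != nil {
	return err
}
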
@ -60,6 +60,7 @@ func loadL0Segments(ctx context.Context, delegator delegator.ShardDelegator, req
	NumOfRows:     segmentInfo.NumOfRows,
	Statslogs:     segmentInfo.Statslogs,
	Deltalogs:     segmentInfo.Deltalogs,
	Bm25Logs:      segmentInfo.Bm25Statslogs,
	InsertChannel: segmentInfo.InsertChannel,
	StartPosition: segmentInfo.GetStartPosition(),
	Level:         segmentInfo.GetLevel(),
@ -101,6 +102,7 @@ func loadGrowingSegments(ctx context.Context, delegator delegator.ShardDelegator
	NumOfRows:     segmentInfo.NumOfRows,
	Statslogs:     segmentInfo.Statslogs,
	Deltalogs:     segmentInfo.Deltalogs,
	Bm25Logs:      segmentInfo.Bm25Statslogs,
	InsertChannel: segmentInfo.InsertChannel,
})
} else {
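
The new Bm25Logs field threads per-segment BM25 statistics binlogs into the segment load request, so sealed segments can feed the IDF oracle on load. A rough sketch of the consuming side, assuming a Deserialize counterpart to the Serialize used by SaveBM25Log later in this diff (`payload`, `fieldID`, and `segmentID` are placeholders):

// Hypothetical load-side handling of a BM25 stats binlog payload.
stats := storage.NewBM25Stats()
if err := stats.Deserialize(payload); err != nil { // Deserialize is assumed
	return err
}
idfOracle.Register(segmentID, map[int64]*storage.BM25Stats{fieldID: stats}, commonpb.SegmentState_Sealed)
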
211
internal/querynodev2/pipeline/embedding_node.go
Normal file
@ -0,0 +1,211 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pipeline

import (
	"fmt"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
	"github.com/milvus-io/milvus/internal/querynodev2/segments"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/function"
	base "github.com/milvus-io/milvus/internal/util/pipeline"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type embeddingNode struct {
	*BaseNode

	collectionID int64
	channel      string

	manager *DataManager

	functionRunners []function.FunctionRunner
}

func newEmbeddingNode(collectionID int64, channelName string, manager *DataManager, maxQueueLength int32) (*embeddingNode, error) {
	collection := manager.Collection.Get(collectionID)
	if collection == nil {
		log.Error("embeddingNode init failed with collection not exist", zap.Int64("collection", collectionID))
		return nil, merr.WrapErrCollectionNotFound(collectionID)
	}

	if len(collection.Schema().GetFunctions()) == 0 {
		return nil, nil
	}

	node := &embeddingNode{
		BaseNode:        base.NewBaseNode(fmt.Sprintf("EmbeddingNode-%s", channelName), maxQueueLength),
		collectionID:    collectionID,
		channel:         channelName,
		manager:         manager,
		functionRunners: make([]function.FunctionRunner, 0),
	}

	for _, tf := range collection.Schema().GetFunctions() {
		functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf)
		if err != nil {
			return nil, err
		}
		node.functionRunners = append(node.functionRunners, functionRunner)
	}
	return node, nil
}

func (eNode *embeddingNode) Name() string {
	return fmt.Sprintf("embeddingNode-%s", eNode.channel)
}

func (eNode *embeddingNode) addInsertData(insertDatas map[UniqueID]*delegator.InsertData, msg *InsertMsg, collection *Collection) error {
	iData, ok := insertDatas[msg.SegmentID]
	if !ok {
		iData = &delegator.InsertData{
			PartitionID: msg.PartitionID,
			BM25Stats:   make(map[int64]*storage.BM25Stats),
			StartPosition: &msgpb.MsgPosition{
				Timestamp:   msg.BeginTs(),
				ChannelName: msg.GetShardName(),
			},
		}
		insertDatas[msg.SegmentID] = iData
	}

	err := eNode.embedding(msg, iData.BM25Stats)
	if err != nil {
		log.Error("failed to embed insert data", zap.Error(err))
		return err
	}

	insertRecord, err := storage.TransferInsertMsgToInsertRecord(collection.Schema(), msg)
	if err != nil {
		err = fmt.Errorf("failed to get primary keys, err = %w", err)
		log.Error(err.Error(), zap.String("channel", eNode.channel))
		return err
	}

	if iData.InsertRecord == nil {
		iData.InsertRecord = insertRecord
	} else {
		err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
		if err != nil {
			log.Warn("failed to merge field data", zap.String("channel", eNode.channel), zap.Error(err))
			return err
		}
		iData.InsertRecord.NumRows += insertRecord.NumRows
	}

	pks, err := segments.GetPrimaryKeys(msg, collection.Schema())
	if err != nil {
		log.Warn("failed to get primary keys from insert message", zap.String("channel", eNode.channel), zap.Error(err))
		return err
	}

	iData.PrimaryKeys = append(iData.PrimaryKeys, pks...)
	iData.RowIDs = append(iData.RowIDs, msg.RowIDs...)
	iData.Timestamps = append(iData.Timestamps, msg.Timestamps...)
	log.Debug("pipeline embedding insert msg",
		zap.Int64("collectionID", eNode.collectionID),
		zap.Int64("segmentID", msg.SegmentID),
		zap.Int("insertRowNum", len(pks)),
		zap.Uint64("timestampMin", msg.BeginTimestamp),
		zap.Uint64("timestampMax", msg.EndTimestamp))
	return nil
}

func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
	functionSchema := runner.GetSchema()
	inputFieldID := functionSchema.GetInputFieldIds()[0]
	outputFieldID := functionSchema.GetOutputFieldIds()[0]
	outputField := runner.GetOutputFields()[0]

	data, err := GetEmbeddingFieldData(msg.GetFieldsData(), inputFieldID)
	if data == nil || err != nil {
		return merr.WrapErrFieldNotFound(fmt.Sprint(inputFieldID))
	}

	output, err := runner.BatchRun(data)
	if err != nil {
		return err
	}

	sparseArray, ok := output[0].(*schemapb.SparseFloatArray)
	if !ok {
		return fmt.Errorf("BM25 runner return unknown type output")
	}

	if _, ok := stats[outputFieldID]; !ok {
		stats[outputFieldID] = storage.NewBM25Stats()
	}
	stats[outputFieldID].AppendBytes(sparseArray.GetContents()...)
	msg.FieldsData = append(msg.FieldsData, delegator.BuildSparseFieldData(outputField, sparseArray))
	return nil
}

func (eNode *embeddingNode) embedding(msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
	for _, functionRunner := range eNode.functionRunners {
		functionSchema := functionRunner.GetSchema()
		switch functionSchema.GetType() {
		case schemapb.FunctionType_BM25:
			err := eNode.bm25Embedding(functionRunner, msg, stats)
			if err != nil {
				return err
			}
		default:
			log.Warn("pipeline embedding with unknown function type", zap.Any("type", functionSchema.GetType()))
			return fmt.Errorf("unknown function type")
		}
	}

	return nil
}

func (eNode *embeddingNode) Operate(in Msg) Msg {
	nodeMsg := in.(*insertNodeMsg)
	nodeMsg.insertDatas = make(map[int64]*delegator.InsertData)

	collection := eNode.manager.Collection.Get(eNode.collectionID)
	if collection == nil {
		log.Error("embeddingNode with collection not exist", zap.Int64("collection", eNode.collectionID))
		panic("embeddingNode with collection not exist")
	}

	for _, msg := range nodeMsg.insertMsgs {
		err := eNode.addInsertData(nodeMsg.insertDatas, msg, collection)
		if err != nil {
			panic(err)
		}
	}

	return nodeMsg
}

func GetEmbeddingFieldData(datas []*schemapb.FieldData, fieldID int64) ([]string, error) {
	for _, data := range datas {
		if data.GetFieldId() == fieldID {
			return data.GetScalars().GetStringData().GetData(), nil
		}
	}
	return nil, fmt.Errorf("field %d not found", fieldID)
}
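
The BM25 "embedding" this node computes is not a dense neural embedding: BatchRun turns each text into a bag of term frequencies encoded as a sparse vector. A self-contained toy of that idea (whitespace tokenizing and FNV hashing stand in for the real analyzer in internal/util/function, which this diff does not show):

package main

import (
	"fmt"
	"hash/fnv"
	"strings"
)

// termFrequencies maps each token to its count, keyed by a 32-bit hash,
// mirroring the map[uint32]float32 rows fed to BM25Stats.Append.
func termFrequencies(text string) map[uint32]float32 {
	tf := make(map[uint32]float32)
	for _, tok := range strings.Fields(strings.ToLower(text)) {
		h := fnv.New32a()
		h.Write([]byte(tok))
		tf[h.Sum32()]++
	}
	return tf
}

func main() {
	fmt.Println(termFrequencies("test bm25 data test"))
}
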
281
internal/querynodev2/pipeline/embedding_node_test.go
Normal file
@ -0,0 +1,281 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pipeline

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/suite"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
	"github.com/milvus-io/milvus/internal/querynodev2/segments"
	"github.com/milvus-io/milvus/internal/util/function"
	"github.com/milvus-io/milvus/pkg/common"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

// test of embedding node
type EmbeddingNodeSuite struct {
	suite.Suite
	// data
	collectionID     int64
	collectionSchema *schemapb.CollectionSchema
	channel          string
	msgs             []*InsertMsg

	// mocks
	manager    *segments.Manager
	segManager *segments.MockSegmentManager
	colManager *segments.MockCollectionManager
}

func (suite *EmbeddingNodeSuite) SetupTest() {
	paramtable.Init()
	suite.collectionID = 111
	suite.channel = "test-channel"
	suite.collectionSchema = &schemapb.CollectionSchema{
		Name: "test-collection",
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:  common.TimeStampField,
				Name:     common.TimeStampFieldName,
				DataType: schemapb.DataType_Int64,
			}, {
				Name:         "pk",
				FieldID:      100,
				IsPrimaryKey: true,
				DataType:     schemapb.DataType_Int64,
			}, {
				Name:       "text",
				FieldID:    101,
				DataType:   schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{},
			}, {
				Name:             "sparse",
				FieldID:          102,
				DataType:         schemapb.DataType_SparseFloatVector,
				IsFunctionOutput: true,
			},
		},
		Functions: []*schemapb.FunctionSchema{{
			Name:           "BM25",
			Type:           schemapb.FunctionType_BM25,
			InputFieldIds:  []int64{101},
			OutputFieldIds: []int64{102},
		}},
	}

	suite.msgs = []*msgstream.InsertMsg{{
		BaseMsg: msgstream.BaseMsg{},
		InsertRequest: &msgpb.InsertRequest{
			SegmentID:  1,
			NumRows:    3,
			Version:    msgpb.InsertDataVersion_ColumnBased,
			Timestamps: []uint64{1, 1, 1},
			FieldsData: []*schemapb.FieldData{
				{
					FieldId: 100,
					Type:    schemapb.DataType_Int64,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
					},
				}, {
					FieldId: 101,
					Type:    schemapb.DataType_VarChar,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}}}},
					},
				},
			},
		},
	}}

	suite.segManager = segments.NewMockSegmentManager(suite.T())
	suite.colManager = segments.NewMockCollectionManager(suite.T())

	suite.manager = &segments.Manager{
		Collection: suite.colManager,
		Segment:    suite.segManager,
	}
}

func (suite *EmbeddingNodeSuite) TestCreateEmbeddingNode() {
	suite.Run("collection not found", func() {
		suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once()
		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.Error(err)
	})

	suite.Run("function invalid", func() {
		collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
		collection.Schema().Functions = []*schemapb.FunctionSchema{{}}
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.Error(err)
	})

	suite.Run("normal case", func() {
		collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)
	})
}

func (suite *EmbeddingNodeSuite) TestOperator() {
	suite.Run("collection not found", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once()
		suite.Panics(func() {
			node.Operate(&insertNodeMsg{})
		})
	})

	suite.Run("add insert data failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		suite.Panics(func() {
			node.Operate(&insertNodeMsg{
				insertMsgs: []*msgstream.InsertMsg{{
					BaseMsg: msgstream.BaseMsg{},
					InsertRequest: &msgpb.InsertRequest{
						SegmentID: 1,
						NumRows:   3,
						Version:   msgpb.InsertDataVersion_ColumnBased,
						FieldsData: []*schemapb.FieldData{
							{
								FieldId: 100,
								Type:    schemapb.DataType_Int64,
								Field: &schemapb.FieldData_Scalars{
									Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
								},
							},
						},
					},
				}},
			})
		})
	})

	suite.Run("normal case", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		suite.NotPanics(func() {
			output := node.Operate(&insertNodeMsg{
				insertMsgs: suite.msgs,
			})

			msg, ok := output.(*insertNodeMsg)
			suite.Require().True(ok)
			suite.Require().NotNil(msg.insertDatas)
			suite.Require().Equal(int64(3), msg.insertDatas[1].BM25Stats[102].NumRow())
			suite.Require().Equal(int64(3), msg.insertDatas[1].InsertRecord.GetNumRows())
		})
	})
}

func (suite *EmbeddingNodeSuite) TestAddInsertData() {
	suite.Run("transfer insert msg failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		// transfer insert msg failed because row-based data does not support sparse vector
		insertDatas := make(map[int64]*delegator.InsertData)
		rowBaseReq := proto.Clone(suite.msgs[0].InsertRequest).(*msgpb.InsertRequest)
		rowBaseReq.Version = msgpb.InsertDataVersion_RowBased
		rowBaseMsg := &msgstream.InsertMsg{
			BaseMsg:       msgstream.BaseMsg{},
			InsertRequest: rowBaseReq,
		}
		err = node.addInsertData(insertDatas, rowBaseMsg, collection)
		suite.Error(err)
	})

	suite.Run("merge field data failed", func() {
		// remove pk
		suite.collectionSchema.Fields[1].IsPrimaryKey = false
		defer func() {
			suite.collectionSchema.Fields[1].IsPrimaryKey = true
		}()

		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		insertDatas := make(map[int64]*delegator.InsertData)
		err = node.addInsertData(insertDatas, suite.msgs[0], collection)
		suite.Error(err)
	})
}

func (suite *EmbeddingNodeSuite) TestBM25Embedding() {
	suite.Run("function run failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		runner := function.NewMockFunctionRunner(suite.T())
		runner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock error"))
		runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
		runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})

		err = node.bm25Embedding(runner, suite.msgs[0], nil)
		suite.Error(err)
	})

	suite.Run("output with unknown type failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		runner := function.NewMockFunctionRunner(suite.T())
		runner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil)
		runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
		runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})

		err = node.bm25Embedding(runner, suite.msgs[0], nil)
		suite.Error(err)
	})
}

func TestEmbeddingNode(t *testing.T) {
	suite.Run(t, new(EmbeddingNodeSuite))
}
@ -62,7 +62,7 @@ func (iNode *insertNode) addInsertData(insertDatas map[UniqueID]*delegator.Inser
	} else {
		err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
		if err != nil {
			log.Error("failed to merge field data", zap.Error(err))
			log.Error("failed to merge field data", zap.String("channel", iNode.channel), zap.Error(err))
			panic(err)
		}
		iData.InsertRecord.NumRows += insertRecord.NumRows
@ -95,21 +95,23 @@ func (iNode *insertNode) Operate(in Msg) Msg {
		return nodeMsg.insertMsgs[i].BeginTs() < nodeMsg.insertMsgs[j].BeginTs()
	})

	insertDatas := make(map[UniqueID]*delegator.InsertData)
	// build insert data only if there is no embedding node upstream
	if nodeMsg.insertDatas == nil {
		collection := iNode.manager.Collection.Get(iNode.collectionID)
		if collection == nil {
			log.Error("insertNode with collection not exist", zap.Int64("collection", iNode.collectionID))
			panic("insertNode with collection not exist")
		}

		nodeMsg.insertDatas = make(map[UniqueID]*delegator.InsertData)
		// get InsertData and merge data of the same segment
		for _, msg := range nodeMsg.insertMsgs {
			iNode.addInsertData(insertDatas, msg, collection)
			iNode.addInsertData(nodeMsg.insertDatas, msg, collection)
		}
	}

	iNode.delegator.ProcessInsert(insertDatas)
	iNode.delegator.ProcessInsert(nodeMsg.insertDatas)
}

metrics.QueryNodeWaitProcessingMsgCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.DeleteLabel).Inc()

return &deleteNodeMsg{
@ -19,6 +19,7 @@ package pipeline
import (
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/internal/querynodev2/collector"
	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/metricsinfo"
@ -27,6 +28,7 @@ import (
type insertNodeMsg struct {
	insertMsgs  []*InsertMsg
	deleteMsgs  []*DeleteMsg
	insertDatas map[int64]*delegator.InsertData
	timeRange   TimeRange
}

@ -32,6 +32,7 @@ type pipeline struct {
	base.StreamPipeline

	collectionID  UniqueID
	embeddingNode embeddingNode
}

func (p *pipeline) Close() {
@ -54,8 +55,21 @@ func NewPipeLine(
	}

	filterNode := newFilterNode(collectionID, channel, manager, delegator, pipelineQueueLength)

	embeddingNode, err := newEmbeddingNode(collectionID, channel, manager, pipelineQueueLength)
	if err != nil {
		return nil, err
	}

	insertNode := newInsertNode(collectionID, channel, manager, delegator, pipelineQueueLength)
	deleteNode := newDeleteNode(collectionID, channel, manager, tSafeManager, delegator, pipelineQueueLength)

	// skip adding the embedding node when the collection has no function.
	if embeddingNode != nil {
		p.Add(filterNode, embeddingNode, insertNode, deleteNode)
	} else {
		p.Add(filterNode, insertNode, deleteNode)
	}

	return p, nil
}
@ -299,6 +299,18 @@ func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *
	}
}

// NewCollectionWithoutSegcoreForTest creates a collection without segcore preparation.
// ONLY FOR TEST
func NewCollectionWithoutSegcoreForTest(collectionID int64, schema *schemapb.CollectionSchema) *Collection {
	coll := &Collection{
		id:         collectionID,
		partitions: typeutil.NewConcurrentSet[int64](),
		refCount:   atomic.NewUint32(0),
	}
	coll.schema.Store(schema)
	return coll
}

// deleteCollection delete collection and free the collection memory
func DeleteCollection(collection *Collection) {
	/*
@ -47,6 +47,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/metautil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/testutils"
|
"github.com/milvus-io/milvus/pkg/util/testutils"
|
||||||
@ -220,6 +221,11 @@ func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema {
|
|||||||
DataType: param.dataType,
|
DataType: param.dataType,
|
||||||
ElementType: schemapb.DataType_Int32,
|
ElementType: schemapb.DataType_Int32,
|
||||||
}
|
}
|
||||||
|
if param.dataType == schemapb.DataType_VarChar {
|
||||||
|
field.TypeParams = []*commonpb.KeyValuePair{
|
||||||
|
{Key: common.MaxLengthKey, Value: "128"},
|
||||||
|
}
|
||||||
|
}
|
||||||
return field
|
return field
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,6 +269,35 @@ func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema {
|
|||||||
return fieldVec
|
return fieldVec
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSchema {
|
||||||
|
fieldRowID := genConstantFieldSchema(rowIDField)
|
||||||
|
fieldTimestamp := genConstantFieldSchema(timestampField)
|
||||||
|
pkFieldSchema := genPKFieldSchema(simpleInt64Field)
|
||||||
|
textFieldSchema := genConstantFieldSchema(simpleVarCharField)
|
||||||
|
sparseFieldSchema := genVectorFieldSchema(simpleSparseFloatVectorField)
|
||||||
|
sparseFieldSchema.IsFunctionOutput = true
|
||||||
|
|
||||||
|
schema := &schemapb.CollectionSchema{
|
||||||
|
Name: collectionName,
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
fieldRowID,
|
||||||
|
fieldTimestamp,
|
||||||
|
pkFieldSchema,
|
||||||
|
textFieldSchema,
|
||||||
|
sparseFieldSchema,
|
||||||
|
},
|
||||||
|
Functions: []*schemapb.FunctionSchema{{
|
||||||
|
Name: "BM25",
|
||||||
|
Type: schemapb.FunctionType_BM25,
|
||||||
|
InputFieldNames: []string{textFieldSchema.GetName()},
|
||||||
|
InputFieldIds: []int64{textFieldSchema.GetFieldID()},
|
||||||
|
OutputFieldNames: []string{sparseFieldSchema.GetName()},
|
||||||
|
OutputFieldIds: []int64{sparseFieldSchema.GetFieldID()},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
 
 // some tests do not yet support sparse float vector, see comments of
 // GenSparseFloatVecDataset in indexcgowrapper/dataset.go
 func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema {
@ -671,6 +706,32 @@ func SaveDeltaLog(collectionID int64,
 	return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
 }
 
+func SaveBM25Log(collectionID int64, partitionID int64, segmentID int64, fieldID int64, msgLength int, cm storage.ChunkManager) (*datapb.FieldBinlog, error) {
+	stats := storage.NewBM25Stats()
+
+	for i := 0; i < msgLength; i++ {
+		stats.Append(map[uint32]float32{1: 1})
+	}
+
+	bytes, err := stats.Serialize()
+	if err != nil {
+		return nil, err
+	}
+
+	kvs := make(map[string][]byte, 1)
+	key := path.Join(cm.RootPath(), common.SegmentBm25LogPath, metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID, 1001))
+	kvs[key] = bytes
+	fieldBinlog := &datapb.FieldBinlog{
+		FieldID: fieldID,
+		Binlogs: []*datapb.Binlog{{
+			LogPath:       key,
+			TimestampFrom: 100,
+			TimestampTo:   200,
+		}},
+	}
+	return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
+}
+
 func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
 	fieldSchema *schemapb.FieldSchema,
 	indexInfo *indexpb.IndexInfo,
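SaveBM25Log persists one serialized BM25Stats blob per field under SegmentBm25LogPath; the loader later reads it back with Deserialize. A minimal round-trip sketch against the same storage API (the token counts here are arbitrary test data):

```go
package segments

import "github.com/milvus-io/milvus/internal/storage"

func bm25StatsRoundTrip() error {
	stats := storage.NewBM25Stats()
	// Each Append records one row's term-frequency map (token id -> tf).
	stats.Append(map[uint32]float32{1: 2, 7: 1})

	bytes, err := stats.Serialize()
	if err != nil {
		return err
	}

	restored := storage.NewBM25Stats()
	// Deserialize merges the serialized counts into the receiver.
	return restored.Deserialize(bytes)
}
```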
@ -14,6 +14,10 @@ import (
 	pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
 
 	querypb "github.com/milvus-io/milvus/internal/proto/querypb"
 
+	storage "github.com/milvus-io/milvus/internal/storage"
+
+	typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
+
 )
 
 // MockLoader is an autogenerated mock type for the Loader type
@ -101,6 +105,76 @@ func (_c *MockLoader_Load_Call) RunAndReturn(run func(context.Context, int64, co
 	return _c
 }
+
+// LoadBM25Stats provides a mock function with given fields: ctx, collectionID, infos
+func (_m *MockLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) {
+	_va := make([]interface{}, len(infos))
+	for _i := range infos {
+		_va[_i] = infos[_i]
+	}
+	var _ca []interface{}
+	_ca = append(_ca, ctx, collectionID)
+	_ca = append(_ca, _va...)
+	ret := _m.Called(_ca...)
+
+	var r0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)); ok {
+		return rf(ctx, collectionID, infos...)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]); ok {
+		r0 = rf(ctx, collectionID, infos...)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats])
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) error); ok {
+		r1 = rf(ctx, collectionID, infos...)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockLoader_LoadBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadBM25Stats'
+type MockLoader_LoadBM25Stats_Call struct {
+	*mock.Call
+}
+
+// LoadBM25Stats is a helper method to define mock.On call
+// - ctx context.Context
+// - collectionID int64
+// - infos ...*querypb.SegmentLoadInfo
+func (_e *MockLoader_Expecter) LoadBM25Stats(ctx interface{}, collectionID interface{}, infos ...interface{}) *MockLoader_LoadBM25Stats_Call {
+	return &MockLoader_LoadBM25Stats_Call{Call: _e.mock.On("LoadBM25Stats",
+		append([]interface{}{ctx, collectionID}, infos...)...)}
+}
+
+func (_c *MockLoader_LoadBM25Stats_Call) Run(run func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo)) *MockLoader_LoadBM25Stats_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-2)
+		for i, a := range args[2:] {
+			if a != nil {
+				variadicArgs[i] = a.(*querypb.SegmentLoadInfo)
+			}
+		}
+		run(args[0].(context.Context), args[1].(int64), variadicArgs...)
+	})
+	return _c
+}
+
+func (_c *MockLoader_LoadBM25Stats_Call) Return(_a0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], _a1 error) *MockLoader_LoadBM25Stats_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockLoader_LoadBM25Stats_Call) RunAndReturn(run func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)) *MockLoader_LoadBM25Stats_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // LoadBloomFilterSet provides a mock function with given fields: ctx, collectionID, version, infos
 func (_m *MockLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
 	_va := make([]interface{}, len(infos))
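The generated expecter lets tests stub BM25 stat loading without touching storage. A sketch of typical wiring, assuming the generated constructor NewMockLoader, which mockery emits alongside this mock; the IDs are arbitrary:

```go
package segments

import (
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// exampleStubLoadBM25Stats stubs LoadBM25Stats to return a canned stats map.
func exampleStubLoadBM25Stats(t *testing.T) {
	loader := NewMockLoader(t)

	canned := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]()
	canned.Insert(100 /* segmentID */, map[int64]*storage.BM25Stats{102: storage.NewBM25Stats()})

	// Any context, collection, and load infos are accepted.
	loader.EXPECT().
		LoadBM25Stats(mock.Anything, mock.Anything, mock.Anything).
		Return(canned, nil)
}
```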
@ -290,6 +290,49 @@ func (_c *MockSegment_ExistIndex_Call) RunAndReturn(run func(int64) bool) *MockS
 	return _c
 }
 
+// GetBM25Stats provides a mock function with given fields:
+func (_m *MockSegment) GetBM25Stats() map[int64]*storage.BM25Stats {
+	ret := _m.Called()
+
+	var r0 map[int64]*storage.BM25Stats
+	if rf, ok := ret.Get(0).(func() map[int64]*storage.BM25Stats); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(map[int64]*storage.BM25Stats)
+		}
+	}
+
+	return r0
+}
+
+// MockSegment_GetBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBM25Stats'
+type MockSegment_GetBM25Stats_Call struct {
+	*mock.Call
+}
+
+// GetBM25Stats is a helper method to define mock.On call
+func (_e *MockSegment_Expecter) GetBM25Stats() *MockSegment_GetBM25Stats_Call {
+	return &MockSegment_GetBM25Stats_Call{Call: _e.mock.On("GetBM25Stats")}
+}
+
+func (_c *MockSegment_GetBM25Stats_Call) Run(run func()) *MockSegment_GetBM25Stats_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockSegment_GetBM25Stats_Call) Return(_a0 map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockSegment_GetBM25Stats_Call) RunAndReturn(run func() map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // GetIndex provides a mock function with given fields: fieldID
 func (_m *MockSegment) GetIndex(fieldID int64) *IndexedFieldInfo {
 	ret := _m.Called(fieldID)
@ -1570,6 +1613,39 @@ func (_c *MockSegment_Unpin_Call) RunAndReturn(run func()) *MockSegment_Unpin_Ca
 	return _c
 }
 
+// UpdateBM25Stats provides a mock function with given fields: stats
+func (_m *MockSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) {
+	_m.Called(stats)
+}
+
+// MockSegment_UpdateBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBM25Stats'
+type MockSegment_UpdateBM25Stats_Call struct {
+	*mock.Call
+}
+
+// UpdateBM25Stats is a helper method to define mock.On call
+// - stats map[int64]*storage.BM25Stats
+func (_e *MockSegment_Expecter) UpdateBM25Stats(stats interface{}) *MockSegment_UpdateBM25Stats_Call {
+	return &MockSegment_UpdateBM25Stats_Call{Call: _e.mock.On("UpdateBM25Stats", stats)}
+}
+
+func (_c *MockSegment_UpdateBM25Stats_Call) Run(run func(stats map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(map[int64]*storage.BM25Stats))
+	})
+	return _c
+}
+
+func (_c *MockSegment_UpdateBM25Stats_Call) Return() *MockSegment_UpdateBM25Stats_Call {
+	_c.Call.Return()
+	return _c
+}
+
+func (_c *MockSegment_UpdateBM25Stats_Call) RunAndReturn(run func(map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // UpdateBloomFilter provides a mock function with given fields: pks
 func (_m *MockSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
 	_m.Called(pks)
@ -91,6 +91,8 @@ type baseSegment struct {
 	isLazyLoad bool
 	channel    metautil.Channel
 
+	bm25Stats map[int64]*storage.BM25Stats
+
 	resourceUsageCache *atomic.Pointer[ResourceUsage]
 
 	needUpdatedVersion *atomic.Int64 // only for lazy load mode update index
@ -107,6 +109,7 @@ func newBaseSegment(collection *Collection, segmentType SegmentType, version int
 		version:        atomic.NewInt64(version),
 		segmentType:    segmentType,
 		bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType),
+		bm25Stats:      make(map[int64]*storage.BM25Stats),
 		channel:        channel,
 		isLazyLoad:     isLazyLoad(collection, segmentType),
 
@ -185,6 +188,20 @@ func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
 	s.bloomFilterSet.UpdateBloomFilter(pks)
 }
 
+func (s *baseSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) {
+	for fieldID, new := range stats {
+		if current, ok := s.bm25Stats[fieldID]; ok {
+			current.Merge(new)
+		} else {
+			s.bm25Stats[fieldID] = new
+		}
+	}
+}
+
+func (s *baseSegment) GetBM25Stats() map[int64]*storage.BM25Stats {
+	return s.bm25Stats
+}
+
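UpdateBM25Stats merges per-field stats rather than replacing them, so repeated updates accumulate corpus counts. A sketch of that accumulation, assuming BM25Stats.Merge adds row and token counts, which is what the loader relies on; the field ID is hypothetical:

```go
package segments

import "github.com/milvus-io/milvus/internal/storage"

func exampleAccumulateBM25(seg Segment) {
	first := storage.NewBM25Stats()
	first.Append(map[uint32]float32{1: 1})

	second := storage.NewBM25Stats()
	second.Append(map[uint32]float32{1: 2, 5: 1})

	// Field 102 stands in for the sparse-vector field ID.
	seg.UpdateBM25Stats(map[int64]*storage.BM25Stats{102: first})
	seg.UpdateBM25Stats(map[int64]*storage.BM25Stats{102: second})

	// After both updates the field's stats cover rows from both batches.
	_ = seg.GetBM25Stats()[102].NumRow() // expected: 2
}
```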
 // MayPkExist returns true if the given PK may exist in the PK range and the bloom filter reports a positive;
 // it may return true even if the PK does not actually exist.
@ -87,6 +87,10 @@ type Segment interface {
 	MayPkExist(lc *storage.LocationsCache) bool
 	BatchPkExist(lc *storage.BatchLocationsCache) []bool
 
+	// BM25 stats
+	UpdateBM25Stats(stats map[int64]*storage.BM25Stats)
+	GetBM25Stats() map[int64]*storage.BM25Stats
+
 	// Read operations
 	Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error)
 	Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error)
@ -77,6 +77,9 @@ type Loader interface {
 	// LoadBloomFilterSet loads the needed statslog for RemoteSegment.
 	LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)
 
+	// LoadBM25Stats loads the BM25 statslog for RemoteSegment.
+	LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)
+
 	// LoadIndex appends the index for a segment and removes its vector binlogs.
 	LoadIndex(ctx context.Context,
 		segment Segment,
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (loader *segmentLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) {
|
||||||
|
segmentNum := len(infos)
|
||||||
|
if segmentNum == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
segments := lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) int64 {
|
||||||
|
return info.GetSegmentID()
|
||||||
|
})
|
||||||
|
log.Info("start loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Int("segmentNum", segmentNum))
|
||||||
|
|
||||||
|
loadedStats := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]()
|
||||||
|
loadRemoteBM25Func := func(idx int) error {
|
||||||
|
loadInfo := infos[idx]
|
||||||
|
segmentID := loadInfo.SegmentID
|
||||||
|
stats := make(map[int64]*storage.BM25Stats)
|
||||||
|
|
||||||
|
log.Info("loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64("segment", segmentID))
|
||||||
|
logpaths := loader.filterBM25Stats(loadInfo.Bm25Logs)
|
||||||
|
err := loader.loadBm25Stats(ctx, segmentID, stats, logpaths)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("load remote segment bm25 stats failed",
|
||||||
|
zap.Int64("segmentID", segmentID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
loadedStats.Insert(segmentID, stats)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := funcutil.ProcessFuncParallel(segmentNum, segmentNum, loadRemoteBM25Func, "loadRemoteBM25Func")
|
||||||
|
if err != nil {
|
||||||
|
// no partial success here
|
||||||
|
log.Warn("failed to load bm25 stats for remote segment", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Error(err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return loadedStats, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
|
func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
|
||||||
log := log.Ctx(ctx).With(
|
log := log.Ctx(ctx).With(
|
||||||
zap.Int64("collectionID", collectionID),
|
zap.Int64("collectionID", collectionID),
|
||||||
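LoadBM25Stats fans the per-segment reads out with ProcessFuncParallel and fails the whole batch on any error. A caller-side sketch of consuming the returned concurrent map, assuming typeutil.ConcurrentMap exposes a Range(func(key, value) bool) iterator as elsewhere in the codebase:

```go
package segments

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/storage"
)

func exampleConsumeBM25Stats(ctx context.Context, loader Loader, collectionID int64, infos []*querypb.SegmentLoadInfo) error {
	statsMap, err := loader.LoadBM25Stats(ctx, collectionID, infos...)
	if err != nil {
		return err
	}
	if statsMap == nil { // no load infos were passed
		return nil
	}
	// Walk each segment's per-field stats, e.g. to fold them into a delegator-level accumulator.
	statsMap.Range(func(segmentID int64, fieldStats map[int64]*storage.BM25Stats) bool {
		for fieldID, stats := range fieldStats {
			_, _ = fieldID, stats.NumRow()
		}
		return true
	})
	return nil
}
```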
@ -826,6 +870,16 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
 		if err != nil {
 			return err
 		}
+
+		if len(loadInfo.Bm25Logs) > 0 {
+			log.Info("loading bm25 stats...")
+			bm25StatsLogs := loader.filterBM25Stats(loadInfo.Bm25Logs)
+
+			err = loader.loadBm25Stats(ctx, segment.ID(), segment.bm25Stats, bm25StatsLogs)
+			if err != nil {
+				return err
+			}
+		}
 	}
 	return nil
 }
@ -898,6 +952,26 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi
 	return result, storage.DefaultStatsType
 }
 
+func (loader *segmentLoader) filterBM25Stats(fieldBinlogs []*datapb.FieldBinlog) map[int64][]string {
+	result := make(map[int64][]string, 0)
+	for _, fieldBinlog := range fieldBinlogs {
+		logpaths := []string{}
+		for _, binlog := range fieldBinlog.GetBinlogs() {
+			_, logidx := path.Split(binlog.GetLogPath())
+			// if a compound stats log exists, load only that one file
+			if logidx == storage.CompoundStatsType.LogIdx() {
+				logpaths = []string{binlog.GetLogPath()}
+				break
+			} else {
+				logpaths = append(logpaths, binlog.GetLogPath())
+			}
+		}
+		result[fieldBinlog.FieldID] = logpaths
+	}
+	return result
+}
+
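The filter keys log paths by field and collapses to a single file when a compound stats log is present; the discriminator is just the last path segment. The same selection rule in isolation (paths and the compound marker value are illustrative):

```go
package segments

import "path"

// pickStatsPaths mirrors filterBM25Stats' selection rule on plain strings.
// compoundIdx stands in for storage.CompoundStatsType.LogIdx().
func pickStatsPaths(paths []string, compoundIdx string) []string {
	picked := []string{}
	for _, p := range paths {
		_, logidx := path.Split(p)
		if logidx == compoundIdx {
			return []string{p} // one merged file supersedes the rest
		}
		picked = append(picked, p)
	}
	return picked
}
```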
 func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64) error {
 	runningGroup, _ := errgroup.WithContext(ctx)
 	for _, field := range fields {
@ -989,6 +1063,51 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
 	return segment.LoadIndex(ctx, indexInfo, fieldType)
 }
 
+func (loader *segmentLoader) loadBm25Stats(ctx context.Context, segmentID int64, stats map[int64]*storage.BM25Stats, binlogPaths map[int64][]string) error {
+	log := log.Ctx(ctx).With(
+		zap.Int64("segmentID", segmentID),
+	)
+	if len(binlogPaths) == 0 {
+		log.Info("there are no bm25 stats logs saved with segment")
+		return nil
+	}
+
+	pathList := []string{}
+	fieldList := []int64{}
+	fieldOffset := []int{}
+	for fieldId, logpaths := range binlogPaths {
+		pathList = append(pathList, logpaths...)
+		fieldList = append(fieldList, fieldId)
+		fieldOffset = append(fieldOffset, len(logpaths))
+	}
+
+	startTs := time.Now()
+	values, err := loader.cm.MultiRead(ctx, pathList)
+	if err != nil {
+		return err
+	}
+
+	cnt := 0
+	for i, fieldID := range fieldList {
+		newStats, ok := stats[fieldID]
+		if !ok {
+			newStats = storage.NewBM25Stats()
+			stats[fieldID] = newStats
+		}
+
+		for j := 0; j < fieldOffset[i]; j++ {
+			err := newStats.Deserialize(values[cnt+j])
+			if err != nil {
+				return err
+			}
+		}
+		cnt += fieldOffset[i]
+		log.Info("Successfully load bm25 stats", zap.Duration("time", time.Since(startTs)), zap.Int64("numRow", newStats.NumRow()), zap.Int64("fieldID", fieldID))
+	}
+
+	return nil
+}
+
 func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet,
 	binlogPaths []string, logType storage.StatsLogType,
 ) error {
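loadBm25Stats flattens the per-field path map into one MultiRead call, then walks the results back using the parallel fieldList/fieldOffset slices. The same bookkeeping in isolation, with the batched fetch abstracted behind a callback:

```go
package segments

// regroup flattens a keyed list of paths, simulates one batched read,
// then reassociates results with their keys via running offsets.
func regroup(byField map[int64][]string, read func([]string) [][]byte) map[int64][][]byte {
	pathList := []string{}
	fieldList := []int64{}
	fieldOffset := []int{}
	for fieldID, paths := range byField {
		pathList = append(pathList, paths...)
		fieldList = append(fieldList, fieldID)
		fieldOffset = append(fieldOffset, len(paths))
	}

	values := read(pathList) // one batched fetch, same order as pathList

	out := make(map[int64][][]byte, len(fieldList))
	cnt := 0
	for i, fieldID := range fieldList {
		out[fieldID] = values[cnt : cnt+fieldOffset[i]]
		cnt += fieldOffset[i]
	}
	return out
}
```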
@ -70,14 +70,16 @@ func (suite *SegmentLoaderSuite) SetupSuite() {
 }
 
 func (suite *SegmentLoaderSuite) SetupTest() {
-	// Dependencies
-	suite.manager = NewManager()
 	ctx := context.Background()
+
 	// TODO:: cpp chunk manager not support local chunk manager
 	// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
 	// 	fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
 	chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
 	suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
+
+	// Dependencies
+	suite.manager = NewManager()
 	suite.loader = NewLoader(suite.manager, suite.chunkManager)
 	initcore.InitRemoteChunkManager(paramtable.Get())
 
@ -92,6 +94,22 @@ func (suite *SegmentLoaderSuite) SetupTest() {
 	suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta)
 }
 
+func (suite *SegmentLoaderSuite) SetupBM25() {
+	// Dependencies
+	suite.manager = NewManager()
+	suite.loader = NewLoader(suite.manager, suite.chunkManager)
+	initcore.InitRemoteChunkManager(paramtable.Get())
+
+	suite.schema = GenTestBM25CollectionSchema("test")
+	indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema)
+	loadMeta := &querypb.LoadMetaInfo{
+		LoadType:     querypb.LoadType_LoadCollection,
+		CollectionID: suite.collectionID,
+		PartitionIDs: []int64{suite.partitionID},
+	}
+	suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta)
+}
+
 func (suite *SegmentLoaderSuite) TearDownTest() {
 	ctx := context.Background()
 	for i := 0; i < suite.segmentNum; i++ {
@ -407,6 +425,41 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() {
 	}
 }
 
+func (suite *SegmentLoaderSuite) TestLoadBm25Stats() {
+	suite.SetupBM25()
+	msgLength := 1
+	sparseFieldID := simpleSparseFloatVectorField.id
+	loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum)
+
+	for i := 0; i < suite.segmentNum; i++ {
+		segmentID := suite.segmentID + int64(i)
+
+		bm25logs, err := SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager)
+		suite.NoError(err)
+
+		loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{
+			SegmentID:     segmentID,
+			PartitionID:   suite.partitionID,
+			CollectionID:  suite.collectionID,
+			Bm25Logs:      []*datapb.FieldBinlog{bm25logs},
+			NumOfRows:     int64(msgLength),
+			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
+		})
+	}
+
+	statsMap, err := suite.loader.LoadBM25Stats(context.Background(), suite.collectionID, loadInfos...)
+	suite.NoError(err)
+
+	for i := 0; i < suite.segmentNum; i++ {
+		segmentID := suite.segmentID + int64(i)
+		stats, ok := statsMap.Get(segmentID)
+		suite.True(ok)
+		fieldStats, ok := stats[sparseFieldID]
+		suite.True(ok)
+		suite.Equal(int64(msgLength), fieldStats.NumRow())
+	}
+}
+
 func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
 	ctx := context.Background()
 	loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum)
@ -21,10 +21,10 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"maps"
 	"math"
 
 	"go.uber.org/zap"
-	"golang.org/x/exp/maps"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/util/bloomfilter"
@ -347,7 +347,7 @@ func (m *BM25Stats) AppendFieldData(datas ...*SparseFloatVectorFieldData) {
 // Update BM25Stats by sparse vector bytes
 func (m *BM25Stats) AppendBytes(datas ...[]byte) {
 	for _, data := range datas {
-		dim := len(data) / 8
+		dim := typeutil.SparseFloatRowElementCount(data)
 		for i := 0; i < dim; i++ {
 			index := typeutil.SparseFloatRowIndexAt(data, i)
 			value := typeutil.SparseFloatRowValueAt(data, i)
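The old `len(data) / 8` exposed the sparse-row wire layout directly: each element occupies 8 bytes, an index followed by a value, which the typeutil accessors now abstract away. A standalone decoder sketch of that layout, assuming the uint32 index and float32 value are little-endian as in Milvus' sparse row encoding:

```go
package storage

import (
	"encoding/binary"
	"math"
)

// decodeSparseRow decodes a sparse float row: repeated 8-byte elements
// of (uint32 index, float32 value), assumed little-endian.
func decodeSparseRow(data []byte) map[uint32]float32 {
	out := make(map[uint32]float32, len(data)/8)
	for off := 0; off+8 <= len(data); off += 8 {
		idx := binary.LittleEndian.Uint32(data[off : off+4])
		val := math.Float32frombits(binary.LittleEndian.Uint32(data[off+4 : off+8]))
		out[idx] = val
	}
	return out
}
```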
@ -454,17 +454,19 @@ func (m *BM25Stats) Deserialize(bs []byte) error {
 		m.rowsWithToken[keys[i]] += values[i]
 	}
 
-	log.Info("test-- deserialize", zap.Int64("numrow", m.numRow), zap.Int64("tokenNum", m.numToken))
 	return nil
 }
 
-func (m *BM25Stats) BuildIDF(tf map[uint32]float32) map[uint32]float32 {
-	vector := make(map[uint32]float32)
-	for key, value := range tf {
+func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) {
+	dim := typeutil.SparseFloatRowElementCount(tf)
+	idf = make([]byte, len(tf))
+	for idx := 0; idx < dim; idx++ {
+		key := typeutil.SparseFloatRowIndexAt(tf, idx)
+		value := typeutil.SparseFloatRowValueAt(tf, idx)
 		nq := m.rowsWithToken[key]
-		vector[key] = value * float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5)))
+		typeutil.SparseFloatRowSetAt(idf, idx, key, value*float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5))))
 	}
-	return vector
+	return
 }
 
 func (m *BM25Stats) GetAvgdl() float64 {
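BuildIDF scales each query term's frequency by the standard BM25 inverted document frequency, computed from the corpus-wide counts this struct tracks, with $N$ = `m.numRow` and $n(q)$ = `m.rowsWithToken[q]`:

$$
\mathrm{idf}(q) = \ln\!\left(1 + \frac{N - n(q) + 0.5}{n(q) + 0.5}\right),
\qquad
\mathrm{BuildIDF}(tf)[q] = tf[q] \cdot \mathrm{idf}(q)
$$

Switching the signature to raw sparse-row bytes keeps the output aligned with the input: the IDF row has the same element count as the term-frequency row, so `idf` can be allocated as `len(tf)` up front and written in place.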
@ -358,7 +358,7 @@ func readDoubleArray(blobReaders []io.Reader) []float64 {
 	return ret
 }
 
-func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema) (idata *InsertData, err error) {
+func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema, skipFunction bool) (idata *InsertData, err error) {
 	blobReaders := make([]io.Reader, 0)
 	for _, blob := range msg.RowData {
 		blobReaders = append(blobReaders, bytes.NewReader(blob.GetValue()))
@ -371,7 +371,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap
 	}
 
 	for _, field := range collSchema.Fields {
-		if field.GetIsFunctionOutput() {
+		if skipFunction && field.GetIsFunctionOutput() {
 			continue
 		}
 
@ -696,7 +696,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
 
 func InsertMsgToInsertData(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) (idata *InsertData, err error) {
 	if msg.IsRowBased() {
-		return RowBasedInsertMsgToInsertData(msg, schema)
+		return RowBasedInsertMsgToInsertData(msg, schema, true)
 	}
 	return ColumnBasedInsertMsgToInsertData(msg, schema)
 }
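The new skipFunction flag decides whether function-output fields (such as the BM25 sparse vector) are dropped while decoding a row-based insert: the generic path skips them because they are produced server-side, while TransferInsertMsgToInsertRecord below keeps them since that message already carries the computed values. A condensed sketch of the two policies; the wrapper names are illustrative:

```go
package storage

import (
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
)

func decodeForIngestion(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) (*InsertData, error) {
	// Function outputs are regenerated downstream, so skip them here.
	return RowBasedInsertMsgToInsertData(msg, schema, true)
}

func decodeForSegcoreRecord(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) (*InsertData, error) {
	// Segcore needs every field, including the computed sparse vectors.
	return RowBasedInsertMsgToInsertData(msg, schema, false)
}
```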
@ -1272,7 +1272,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
 
 func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msgstream.InsertMsg) (*segcorepb.InsertRecord, error) {
 	if msg.IsRowBased() {
-		insertData, err := RowBasedInsertMsgToInsertData(msg, schema)
+		insertData, err := RowBasedInsertMsgToInsertData(msg, schema, false)
 		if err != nil {
 			return nil, err
 		}
@ -1282,6 +1282,7 @@ func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msg
 	// column base insert msg
 	insertRecord := &segcorepb.InsertRecord{
 		NumRows:    int64(msg.NumRows),
+		FieldsData: make([]*schemapb.FieldData, 0),
 	}
 
 	insertRecord.FieldsData = append(insertRecord.FieldsData, msg.FieldsData...)
@ -1035,7 +1035,7 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) {
 	fieldIDs = fieldIDs[:len(fieldIDs)-2]
 	msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim)
 
-	idata, err := RowBasedInsertMsgToInsertData(msg, schema)
+	idata, err := RowBasedInsertMsgToInsertData(msg, schema, false)
 	assert.NoError(t, err)
 	for idx, fID := range fieldIDs {
 		column := columns[idx]
@ -1096,7 +1096,7 @@ func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) {
 		},
 	}
-	_, err := RowBasedInsertMsgToInsertData(msg, schema)
+	_, err := RowBasedInsertMsgToInsertData(msg, schema, false)
 	assert.Error(t, err)
 }
@ -1139,7 +1139,7 @@ func TestRowBasedInsertMsgToInsertBFloat16VectorDataError(t *testing.T) {
 		},
 	}
-	_, err := RowBasedInsertMsgToInsertData(msg, schema)
+	_, err := RowBasedInsertMsgToInsertData(msg, schema, false)
 	assert.Error(t, err)
 }
184 internal/util/function/mock_function.go Normal file
@ -0,0 +1,184 @@
+// Code generated by mockery v2.32.4. DO NOT EDIT.
+
+package function
+
+import (
+	schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	mock "github.com/stretchr/testify/mock"
+)
+
+// MockFunctionRunner is an autogenerated mock type for the FunctionRunner type
+type MockFunctionRunner struct {
+	mock.Mock
+}
+
+type MockFunctionRunner_Expecter struct {
+	mock *mock.Mock
+}
+
+func (_m *MockFunctionRunner) EXPECT() *MockFunctionRunner_Expecter {
+	return &MockFunctionRunner_Expecter{mock: &_m.Mock}
+}
+
+// BatchRun provides a mock function with given fields: inputs
+func (_m *MockFunctionRunner) BatchRun(inputs ...interface{}) ([]interface{}, error) {
+	var _ca []interface{}
+	_ca = append(_ca, inputs...)
+	ret := _m.Called(_ca...)
+
+	var r0 []interface{}
+	var r1 error
+	if rf, ok := ret.Get(0).(func(...interface{}) ([]interface{}, error)); ok {
+		return rf(inputs...)
+	}
+	if rf, ok := ret.Get(0).(func(...interface{}) []interface{}); ok {
+		r0 = rf(inputs...)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]interface{})
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(...interface{}) error); ok {
+		r1 = rf(inputs...)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockFunctionRunner_BatchRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BatchRun'
+type MockFunctionRunner_BatchRun_Call struct {
+	*mock.Call
+}
+
+// BatchRun is a helper method to define mock.On call
+// - inputs ...interface{}
+func (_e *MockFunctionRunner_Expecter) BatchRun(inputs ...interface{}) *MockFunctionRunner_BatchRun_Call {
+	return &MockFunctionRunner_BatchRun_Call{Call: _e.mock.On("BatchRun",
+		append([]interface{}{}, inputs...)...)}
+}
+
+func (_c *MockFunctionRunner_BatchRun_Call) Run(run func(inputs ...interface{})) *MockFunctionRunner_BatchRun_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		variadicArgs := make([]interface{}, len(args)-0)
+		for i, a := range args[0:] {
+			if a != nil {
+				variadicArgs[i] = a.(interface{})
+			}
+		}
+		run(variadicArgs...)
+	})
+	return _c
+}
+
+func (_c *MockFunctionRunner_BatchRun_Call) Return(_a0 []interface{}, _a1 error) *MockFunctionRunner_BatchRun_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockFunctionRunner_BatchRun_Call) RunAndReturn(run func(...interface{}) ([]interface{}, error)) *MockFunctionRunner_BatchRun_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// GetOutputFields provides a mock function with given fields:
+func (_m *MockFunctionRunner) GetOutputFields() []*schemapb.FieldSchema {
+	ret := _m.Called()
+
+	var r0 []*schemapb.FieldSchema
+	if rf, ok := ret.Get(0).(func() []*schemapb.FieldSchema); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]*schemapb.FieldSchema)
+		}
+	}
+
+	return r0
+}
+
+// MockFunctionRunner_GetOutputFields_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOutputFields'
+type MockFunctionRunner_GetOutputFields_Call struct {
+	*mock.Call
+}
+
+// GetOutputFields is a helper method to define mock.On call
+func (_e *MockFunctionRunner_Expecter) GetOutputFields() *MockFunctionRunner_GetOutputFields_Call {
+	return &MockFunctionRunner_GetOutputFields_Call{Call: _e.mock.On("GetOutputFields")}
+}
+
+func (_c *MockFunctionRunner_GetOutputFields_Call) Run(run func()) *MockFunctionRunner_GetOutputFields_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockFunctionRunner_GetOutputFields_Call) Return(_a0 []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockFunctionRunner_GetOutputFields_Call) RunAndReturn(run func() []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// GetSchema provides a mock function with given fields:
+func (_m *MockFunctionRunner) GetSchema() *schemapb.FunctionSchema {
+	ret := _m.Called()
+
+	var r0 *schemapb.FunctionSchema
+	if rf, ok := ret.Get(0).(func() *schemapb.FunctionSchema); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*schemapb.FunctionSchema)
+		}
+	}
+
+	return r0
+}
+
+// MockFunctionRunner_GetSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSchema'
+type MockFunctionRunner_GetSchema_Call struct {
+	*mock.Call
+}
+
+// GetSchema is a helper method to define mock.On call
+func (_e *MockFunctionRunner_Expecter) GetSchema() *MockFunctionRunner_GetSchema_Call {
+	return &MockFunctionRunner_GetSchema_Call{Call: _e.mock.On("GetSchema")}
+}
+
+func (_c *MockFunctionRunner_GetSchema_Call) Run(run func()) *MockFunctionRunner_GetSchema_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockFunctionRunner_GetSchema_Call) Return(_a0 *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockFunctionRunner_GetSchema_Call) RunAndReturn(run func() *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// NewMockFunctionRunner creates a new instance of MockFunctionRunner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewMockFunctionRunner(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *MockFunctionRunner {
+	mock := &MockFunctionRunner{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
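A FunctionRunner stub is what lets pipeline tests inject fake BM25 outputs without running a tokenizer. A sketch of wiring the generated expecter; the field ID, name, and canned return are placeholders:

```go
package function

import (
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

func exampleStubFunctionRunner(t *testing.T) {
	runner := NewMockFunctionRunner(t)

	// Declare the output field the runner claims to produce.
	runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{
		{FieldID: 102, Name: "sparse", DataType: schemapb.DataType_SparseFloatVector},
	})

	// Any batch of inputs yields one canned output column.
	runner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{nil}, nil)
}
```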
@ -6,12 +6,26 @@ import (
 	"math"
 
 	"github.com/cockroachdb/errors"
+	"github.com/samber/lo"
 	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 )
 
+func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte {
+	placeholderGroup := &commonpb.PlaceholderGroup{
+		Placeholders: []*commonpb.PlaceholderValue{{
+			Tag:    "$0",
+			Type:   commonpb.PlaceholderType_SparseFloatVector,
+			Values: contents,
+		}},
+	}
+
+	bytes, _ := proto.Marshal(placeholderGroup)
+	return bytes
+}
+
 func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
 	placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
 	if err != nil {
@ -93,6 +107,14 @@ func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.Place
 			Values: [][]byte{bytes},
 		}
 		return placeholderValue, nil
+	case schemapb.DataType_VarChar:
+		strs := fieldData.GetScalars().GetStringData().GetData()
+		placeholderValue := &commonpb.PlaceholderValue{
+			Tag:    "$0",
+			Type:   commonpb.PlaceholderType_VarChar,
+			Values: lo.Map(strs, func(str string, _ int) []byte { return []byte(str) }),
+		}
+		return placeholderValue, nil
 	default:
 		return nil, errors.New("field is not a vector field")
 	}
@ -157,3 +179,7 @@ func flattenedBFloat16VectorsToByteVectors(flattenedVectors []byte, dimension in
 
 	return result
 }
+
+func GetVarCharFromPlaceholder(holder *commonpb.PlaceholderValue) []string {
+	return lo.Map(holder.Values, func(bytes []byte, _ int) string { return string(bytes) })
+}
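VarChar placeholders carry raw query text to the server, where the BM25 function turns it into a sparse query vector; GetVarCharFromPlaceholder is the inverse used server-side. A round-trip sketch over the two helpers added here (it calls the package-private fieldDataToPlaceholderValue, so it lives in the same package):

```go
package funcutil

import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

func exampleVarCharPlaceholderRoundTrip() ([]string, error) {
	field := &schemapb.FieldData{
		Type: schemapb.DataType_VarChar,
		Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{
			Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{
				Data: []string{"hello bm25"},
			}},
		}},
	}

	holder, err := fieldDataToPlaceholderValue(field)
	if err != nil {
		return nil, err
	}
	// Recovers []string{"hello bm25"}.
	return GetVarCharFromPlaceholder(holder), nil
}
```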
@ -36,4 +36,6 @@ const (
 
 	// SUPERSTRUCTURE represents superstructure distance
 	SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
+
+	BM25 MetricType = "BM25"
 )
@ -21,5 +21,5 @@ import "strings"
 // PositivelyRelated returns whether a larger score means a closer match ("IP", "COSINE", or "BM25")
 func PositivelyRelated(metricType string) bool {
 	mUpper := strings.ToUpper(metricType)
-	return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE)
+	return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE) || mUpper == strings.ToUpper(BM25)
 }
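Because BM25 scores grow with relevance, it joins IP and COSINE as a positively related metric, which flips how results are ordered and merged. A trivial check of the helper:

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/metric"
)

func main() {
	fmt.Println(metric.PositivelyRelated("BM25")) // true: larger scores rank first
	fmt.Println(metric.PositivelyRelated("L2"))   // false: smaller distances rank first
}
```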
@ -212,7 +212,7 @@ func TestIndexAutoSparseVector(t *testing.T) {
 	for _, unsupportedMt := range hp.UnsupportedSparseVecMetricsType {
 		idx := index.NewAutoIndex(unsupportedMt)
 		_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idx))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
 	}
 
 	// auto index with different metric type on sparse vec
@ -829,11 +829,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
 	for _, mt := range hp.UnsupportedSparseVecMetricsType {
 		idxInverted := index.NewSparseInvertedIndex(mt, 0.2)
 		_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
 
 		idxWand := index.NewSparseWANDIndex(mt, 0.2)
 		_, err = mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
 	}
 
 	// create index with invalid drop_ratio_build