From db34572c564b7e012154c1af4ffb2dece4945121 Mon Sep 17 00:00:00 2001 From: aoiasd <45024769+aoiasd@users.noreply.github.com> Date: Fri, 11 Oct 2024 10:23:20 +0800 Subject: [PATCH] feat: support load and query with bm25 metric (#36071) relate: https://github.com/milvus-io/milvus/issues/35853 --------- Signed-off-by: aoiasd --- Makefile | 3 + internal/core/src/common/Types.h | 3 +- internal/core/src/common/Utils.h | 6 +- internal/core/src/index/VectorMemIndex.cpp | 3 +- internal/core/src/query/PlanProto.cpp | 5 + .../core/src/segcore/IndexConfigGenerator.cpp | 23 +- .../test_utils/indexbuilder_test_utils.h | 9 + internal/datacoord/compaction_task_l0_test.go | 2 +- .../pipeline/flow_graph_embedding_node.go | 2 +- internal/proto/internal.proto | 2 + internal/proto/plan.proto | 2 + internal/proto/query_coord.proto | 1 + internal/proxy/task_index.go | 8 +- internal/proxy/task_search.go | 4 + internal/querycoordv2/utils/types.go | 1 + internal/querynodev2/delegator/delegator.go | 44 +- .../querynodev2/delegator/delegator_data.go | 94 ++++- .../delegator/delegator_data_test.go | 396 +++++++++++++++++- .../querynodev2/delegator/delegator_test.go | 78 ++++ .../querynodev2/delegator/distribution.go | 6 +- .../delegator/distribution_test.go | 4 +- internal/querynodev2/delegator/idf_oracle.go | 262 ++++++++++++ .../querynodev2/delegator/idf_oracle_test.go | 198 +++++++++ internal/querynodev2/delegator/util.go | 64 +++ internal/querynodev2/handlers.go | 2 + .../querynodev2/pipeline/embedding_node.go | 211 ++++++++++ .../pipeline/embedding_node_test.go | 281 +++++++++++++ internal/querynodev2/pipeline/insert_node.go | 28 +- internal/querynodev2/pipeline/message.go | 8 +- internal/querynodev2/pipeline/pipeline.go | 18 +- internal/querynodev2/segments/collection.go | 12 + internal/querynodev2/segments/mock_data.go | 61 +++ internal/querynodev2/segments/mock_loader.go | 74 ++++ internal/querynodev2/segments/mock_segment.go | 76 ++++ internal/querynodev2/segments/segment.go | 17 + .../querynodev2/segments/segment_interface.go | 4 + .../querynodev2/segments/segment_loader.go | 119 ++++++ .../segments/segment_loader_test.go | 57 ++- internal/storage/stats.go | 18 +- internal/storage/utils.go | 11 +- internal/storage/utils_test.go | 6 +- internal/util/function/mock_function.go | 184 ++++++++ pkg/util/funcutil/placeholdergroup.go | 26 ++ pkg/util/metric/metric_type.go | 2 + pkg/util/metric/similarity_corelation.go | 2 +- tests/go_client/testcases/index_test.go | 6 +- 46 files changed, 2372 insertions(+), 71 deletions(-) create mode 100644 internal/querynodev2/delegator/idf_oracle.go create mode 100644 internal/querynodev2/delegator/idf_oracle_test.go create mode 100644 internal/querynodev2/delegator/util.go create mode 100644 internal/querynodev2/pipeline/embedding_node.go create mode 100644 internal/querynodev2/pipeline/embedding_node_test.go create mode 100644 internal/util/function/mock_function.go diff --git a/Makefile b/Makefile index 4fe53268a1..c8a87670a9 100644 --- a/Makefile +++ b/Makefile @@ -532,6 +532,9 @@ generate-mockery-utils: getdeps # proxy_client_manager.go $(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage $(INSTALL_PATH)/mockery --name=ProxyWatcherInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_watcher.go --with-expecter 
--structname=MockProxyWatcher --inpackage + # function + $(INSTALL_PATH)/mockery --name=FunctionRunner --dir=$(PWD)/internal/util/function --output=$(PWD)/internal/util/function --filename=mock_function.go --with-expecter --structname=MockFunctionRunner --inpackage + generate-mockery-kv: getdeps $(INSTALL_PATH)/mockery --name=TxnKV --dir=$(PWD)/pkg/kv --output=$(PWD)/internal/kv/mocks --filename=txn_kv.go --with-expecter diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 477927567c..2473b21a88 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -410,7 +410,8 @@ inline bool IsFloatVectorMetricType(const MetricType& metric_type) { return metric_type == knowhere::metric::L2 || metric_type == knowhere::metric::IP || - metric_type == knowhere::metric::COSINE; + metric_type == knowhere::metric::COSINE || + metric_type == knowhere::metric::BM25; } inline bool diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index 3af4de80e8..c22f3d93ea 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -160,13 +160,15 @@ inline bool IsFloatMetricType(const knowhere::MetricType& metric_type) { return IsMetricType(metric_type, knowhere::metric::L2) || IsMetricType(metric_type, knowhere::metric::IP) || - IsMetricType(metric_type, knowhere::metric::COSINE); + IsMetricType(metric_type, knowhere::metric::COSINE) || + IsMetricType(metric_type, knowhere::metric::BM25); } inline bool PositivelyRelated(const knowhere::MetricType& metric_type) { return IsMetricType(metric_type, knowhere::metric::IP) || - IsMetricType(metric_type, knowhere::metric::COSINE); + IsMetricType(metric_type, knowhere::metric::COSINE) || + IsMetricType(metric_type, knowhere::metric::BM25); } inline std::string diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 6d7767fcf4..2f77e8ffbb 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -409,7 +409,8 @@ VectorMemIndex::Query(const DatasetPtr dataset, milvus::tracer::AddEvent("finish_knowhere_index_search"); if (!res.has_value()) { PanicInfo(ErrorCode::UnexpectedError, - "failed to search: {}: {}", + "failed to search: config={} {}: {}", + search_conf.dump(), KnowhereStatusString(res.error()), res.what()); } diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index e07da98fb3..d61ad31ce9 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -52,6 +52,11 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { search_info.materialized_view_involved = query_info_proto.materialized_view_involved(); + if (query_info_proto.bm25_avgdl() > 0) { + search_info.search_params_[knowhere::meta::BM25_AVGDL] = + query_info_proto.bm25_avgdl(); + } + if (query_info_proto.group_by_field_id() > 0) { auto group_by_field_id = FieldId(query_info_proto.group_by_field_id()); diff --git a/internal/core/src/segcore/IndexConfigGenerator.cpp b/internal/core/src/segcore/IndexConfigGenerator.cpp index 0c0d041359..3b3fbed642 100644 --- a/internal/core/src/segcore/IndexConfigGenerator.cpp +++ b/internal/core/src/segcore/IndexConfigGenerator.cpp @@ -10,6 +10,7 @@ // or implied. 
See the License for the specific language governing permissions and limitations under the License #include "IndexConfigGenerator.h" +#include "knowhere/comp/index_param.h" #include "log/Log.h" namespace milvus::segcore { @@ -49,15 +50,28 @@ VecIndexConfig::VecIndexConfig(const int64_t max_index_row_cout, std::to_string(config_.get_nlist()); build_params_[knowhere::indexparam::SSIZE] = std::to_string( std::max((int)(config_.get_chunk_rows() / config_.get_nlist()), 48)); + + if (is_sparse && metric_type_ == knowhere::metric::BM25) { + build_params_[knowhere::meta::BM25_K1] = + index_meta_.GetIndexParams().at(knowhere::meta::BM25_K1); + build_params_[knowhere::meta::BM25_B] = + index_meta_.GetIndexParams().at(knowhere::meta::BM25_B); + build_params_[knowhere::meta::BM25_AVGDL] = + index_meta_.GetIndexParams().at(knowhere::meta::BM25_AVGDL); + } + search_params_[knowhere::indexparam::NPROBE] = std::to_string(config_.get_nprobe()); + // note for sparse vector index: drop_ratio_build is not allowed for growing // segment index. LOG_INFO( - "VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}", + "VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}, " + "config={}", origin_index_type_, index_type_, - metric_type_); + metric_type_, + build_params_.dump()); } int64_t @@ -100,6 +114,11 @@ VecIndexConfig::GetSearchConf(const SearchInfo& searchInfo) { searchParam.search_params_[key] = searchInfo.search_params_[key]; } } + + if (metric_type_ == knowhere::metric::BM25) { + searchParam.search_params_[knowhere::meta::BM25_AVGDL] = + searchInfo.search_params_[knowhere::meta::BM25_AVGDL]; + } return searchParam; } diff --git a/internal/core/unittest/test_utils/indexbuilder_test_utils.h b/internal/core/unittest/test_utils/indexbuilder_test_utils.h index a02c5cfe3b..527d11bd25 100644 --- a/internal/core/unittest/test_utils/indexbuilder_test_utils.h +++ b/internal/core/unittest/test_utils/indexbuilder_test_utils.h @@ -25,6 +25,7 @@ #include "indexbuilder/ScalarIndexCreator.h" #include "indexbuilder/VecIndexCreator.h" #include "indexbuilder/index_c.h" +#include "knowhere/comp/index_param.h" #include "pb/index_cgo_msg.pb.h" #include "storage/Types.h" @@ -100,6 +101,14 @@ generate_build_conf(const milvus::IndexType& index_type, }; } else if (index_type == knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX || index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) { + if (metric_type == knowhere::metric::BM25) { + return knowhere::Json{ + {knowhere::meta::METRIC_TYPE, metric_type}, + {knowhere::indexparam::DROP_RATIO_BUILD, "0.1"}, + {knowhere::meta::BM25_K1, "1.2"}, + {knowhere::meta::BM25_B, "0.75"}, + {knowhere::meta::BM25_AVGDL, "100"}}; + } return knowhere::Json{ {knowhere::meta::METRIC_TYPE, metric_type}, {knowhere::indexparam::DROP_RATIO_BUILD, "0.1"}, diff --git a/internal/datacoord/compaction_task_l0_test.go b/internal/datacoord/compaction_task_l0_test.go index 4a9cdc2197..bf87f343bb 100644 --- a/internal/datacoord/compaction_task_l0_test.go +++ b/internal/datacoord/compaction_task_l0_test.go @@ -652,7 +652,7 @@ func (s *L0CompactionTaskSuite) TestPorcessStateTrans() { s.Equal(datapb.CompactionTaskState_failed, t.GetState()) }) - s.Run("test unkonwn task", func() { + s.Run("test unknown task", func() { t := s.generateTestL0Task(datapb.CompactionTaskState_unknown) got := t.Process() diff --git a/internal/flushcommon/pipeline/flow_graph_embedding_node.go b/internal/flushcommon/pipeline/flow_graph_embedding_node.go index 80de77d7be..bf809c49ae 100644 --- 
a/internal/flushcommon/pipeline/flow_graph_embedding_node.go +++ b/internal/flushcommon/pipeline/flow_graph_embedding_node.go @@ -73,7 +73,7 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e } func (eNode *embeddingNode) Name() string { - return fmt.Sprintf("embeddingNode-%s-%s", "BM25test", eNode.channelName) + return fmt.Sprintf("embeddingNode-%s", eNode.channelName) } func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, inputFieldId, outputFieldId int64, data *storage.InsertData, meta map[int64]*storage.BM25Stats) error { diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index b83a7d75a1..154191d2db 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -96,6 +96,7 @@ message SubSearchRequest { string metricType = 9; int64 group_by_field_id = 10; int64 group_size = 11; + int64 field_id = 12; } message SearchRequest { @@ -124,6 +125,7 @@ message SearchRequest { common.ConsistencyLevel consistency_level = 22; int64 group_by_field_id = 23; int64 group_size = 24; + int64 field_id = 25; } message SubSearchResults { diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index d4818ac338..16ed9aee2b 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -64,6 +64,8 @@ message QueryInfo { bool materialized_view_involved = 7; int64 group_size = 8; bool group_strict_size = 9; + double bm25_avgdl = 10; + int64 query_field_id = 11; } message ColumnInfo { diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index 2de85755c7..ef5c8d8d1b 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -367,6 +367,7 @@ message SegmentLoadInfo { int64 storageVersion = 18; bool is_sorted = 19; map<int64, data.TextIndexStats> textStatsLogs = 20; + repeated data.FieldBinlog bm25logs = 21; } message FieldIndexInfo { diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index ff68a38e42..186618acd1 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -361,8 +361,12 @@ func (cit *createIndexTask) parseIndexParams() error { return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "float vector index does not support metric type: "+metricType) } } else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) { - if metricType != metric.IP { - return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP is the supported metric type for sparse index") + if metricType != metric.IP && metricType != metric.BM25 { + return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP and BM25 are supported metric types for sparse index") + } + + if metricType == metric.BM25 && cit.functionSchema.GetType() != schemapb.FunctionType_BM25 { + return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only BM25 function output fields support the BM25 metric type") + } } else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) { if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index ccd4e591ca..9bf90be527 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -370,6 +370,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { GroupSize: t.rankParams.GetGroupSize(), } + internalSubReq.FieldId = queryInfo.GetQueryFieldId() // set PartitionIDs for sub search if
t.partitionKeyMode { // isolatioin has tighter constraint, check first @@ -449,6 +450,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { } t.SearchRequest.Offset = offset + t.SearchRequest.FieldId = queryInfo.GetQueryFieldId() if t.partitionKeyMode { // isolatioin has tighter constraint, check first @@ -511,6 +513,8 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector { return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column") } + + queryInfo.QueryFieldId = annField.GetFieldID() plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo) if planErr != nil { log.Warn("failed to create query plan", zap.Error(planErr), diff --git a/internal/querycoordv2/utils/types.go b/internal/querycoordv2/utils/types.go index 697ee46bc5..511081d737 100644 --- a/internal/querycoordv2/utils/types.go +++ b/internal/querycoordv2/utils/types.go @@ -81,6 +81,7 @@ func PackSegmentLoadInfo(segment *datapb.SegmentInfo, channelCheckpoint *msgpb.M NumOfRows: segment.NumOfRows, Statslogs: segment.Statslogs, Deltalogs: segment.Deltalogs, + Bm25Logs: segment.Bm25Statslogs, InsertChannel: segment.InsertChannel, IndexInfos: indexes, StartPosition: segment.GetStartPosition(), diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index b04003ac12..417f65bd75 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/cluster" @@ -43,6 +44,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" @@ -54,6 +56,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" + "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -110,7 +113,9 @@ type shardDelegator struct { lifetime lifetime.Lifetime[lifetime.State] - distribution *distribution + distribution *distribution + idfOracle IDFOracle + segmentManager segments.SegmentManager tsafeManager tsafe.Manager pkOracle pkoracle.PkOracle @@ -135,6 +140,10 @@ type shardDelegator struct { // in order to make add/remove growing be atomic, need lock before modify these meta info growingSegmentLock sync.RWMutex partitionStatsMut sync.RWMutex + + // fieldId -> functionRunner map for search function field + functionRunners map[UniqueID]function.FunctionRunner + hasBM25Field bool } // getLogger returns the zap logger with pre-defined shard attributes. 
@@ -235,6 +244,19 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest }() } + // build idf for bm25 search + if req.GetReq().GetMetricType() == metric.BM25 { + avgdl, err := sd.buildBM25IDF(req.GetReq()) + if err != nil { + return nil, err + } + + if avgdl <= 0 { + log.Warn("search bm25 from empty data, skip search", zap.String("channel", sd.vchannelName), zap.Float64("avgdl", avgdl)) + return []*internalpb.SearchResults{}, nil + } + } + // get final sealedNum after possible segment prune sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) }) log.Debug("search segments...", @@ -335,6 +357,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest IsAdvanced: false, GroupByFieldId: subReq.GetGroupByFieldId(), GroupSize: subReq.GetGroupSize(), + FieldId: subReq.GetFieldId(), } future := conc.Go(func() (*internalpb.SearchResults, error) { searchReq := &querypb.SearchRequest{ @@ -862,6 +885,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni excludedSegments := NewExcludedSegments(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.GetAsDuration(time.Second)) + idfOracle := NewIDFOracle(collection.Schema().GetFunctions()) sd := &shardDelegator{ collectionID: collectionID, replicaID: replicaID, @@ -871,7 +895,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni segmentManager: manager.Segment, workerManager: workerManager, lifetime: lifetime.NewLifetime(lifetime.Initializing), - distribution: NewDistribution(), + distribution: NewDistribution(idfOracle), deleteBuffer: deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock), pkOracle: pkoracle.NewPkOracle(), tsafeManager: tsafeManager, @@ -880,9 +904,23 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni factory: factory, queryHook: queryHook, chunkManager: chunkManager, + idfOracle: idfOracle, partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot), excludedSegments: excludedSegments, + functionRunners: make(map[int64]function.FunctionRunner), } + + for _, tf := range collection.Schema().GetFunctions() { + if tf.GetType() == schemapb.FunctionType_BM25 { + functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf) + if err != nil { + return nil, err + } + sd.functionRunners[tf.OutputFieldIds[0]] = functionRunner + sd.hasBM25Field = true + } + } + m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) if sd.lifetime.Add(lifetime.NotStopped) == nil { diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index 3d6db44629..71331f5d3a 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -27,11 +27,14 @@ import ( "github.com/samber/lo" "go.uber.org/zap" "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/querynodev2/cluster" @@ -63,10 +66,12 @@ import ( //
InsertData type InsertData struct { - RowIDs []int64 - PrimaryKeys []storage.PrimaryKey - Timestamps []uint64 - InsertRecord *segcorepb.InsertRecord + RowIDs []int64 + PrimaryKeys []storage.PrimaryKey + Timestamps []uint64 + InsertRecord *segcorepb.InsertRecord + BM25Stats map[int64]*storage.BM25Stats + StartPosition *msgpb.MsgPosition PartitionID int64 } @@ -149,6 +154,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) { if !sd.pkOracle.Exists(growing, paramtable.GetNodeID()) { // register created growing segment after insert, avoid to add empty growing to delegator sd.pkOracle.Register(growing, paramtable.GetNodeID()) + sd.idfOracle.Register(segmentID, insertData.BM25Stats, segments.SegmentTypeGrowing) sd.segmentManager.Put(context.Background(), segments.SegmentTypeGrowing, growing) sd.addGrowing(SegmentEntry{ NodeID: paramtable.GetNodeID(), @@ -158,10 +164,12 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) { TargetVersion: initialTargetVersion, }) } - sd.growingSegmentLock.Unlock() - } - log.Debug("insert into growing segment", + sd.growingSegmentLock.Unlock() + } else { + sd.idfOracle.UpdateGrowing(growing.ID(), insertData.BM25Stats) + } + log.Info("insert into growing segment", zap.Int64("collectionID", growing.Collection()), zap.Int64("segmentID", segmentID), zap.Int("rowCount", len(insertData.RowIDs)), @@ -375,8 +383,11 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm segmentIDs = lo.Map(loaded, func(segment segments.Segment, _ int) int64 { return segment.ID() }) log.Info("load growing segments done", zap.Int64s("segmentIDs", segmentIDs)) - for _, candidate := range loaded { - sd.pkOracle.Register(candidate, paramtable.GetNodeID()) + for _, segment := range loaded { + sd.pkOracle.Register(segment, paramtable.GetNodeID()) + if sd.hasBM25Field { + sd.idfOracle.Register(segment.ID(), segment.GetBM25Stats(), segments.SegmentTypeGrowing) + } } sd.addGrowing(lo.Map(loaded, func(segment segments.Segment, _ int) SegmentEntry { return SegmentEntry{ @@ -472,6 +483,16 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg infos := lo.Filter(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) bool { return !sd.pkOracle.Exists(pkoracle.NewCandidateKey(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed), targetNodeID) }) + + var bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats] + if sd.hasBM25Field { + bm25Stats, err = sd.loader.LoadBM25Stats(ctx, req.GetCollectionID(), infos...) + if err != nil { + log.Warn("failed to load bm25 stats for segment", zap.Error(err)) + return err + } + } + candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), req.GetVersion(), infos...) 
if err != nil { log.Warn("failed to load bloom filter set for segment", zap.Error(err)) @@ -479,7 +500,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg } log.Debug("load delete...") - err = sd.loadStreamDelete(ctx, candidates, infos, req, targetNodeID, worker) + err = sd.loadStreamDelete(ctx, candidates, bm25Stats, infos, req, targetNodeID, worker) if err != nil { log.Warn("load stream delete failed", zap.Error(err)) return err } @@ -552,6 +573,7 @@ func (sd *shardDelegator) RefreshLevel0DeletionStats() { func (sd *shardDelegator) loadStreamDelete(ctx context.Context, candidates []*pkoracle.BloomFilterSet, + bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], infos []*querypb.SegmentLoadInfo, req *querypb.LoadSegmentsRequest, targetNodeID int64, @@ -665,6 +687,14 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context, ) sd.pkOracle.Register(candidate, targetNodeID) } + + if bm25Stats != nil { + bm25Stats.Range(func(segmentID int64, stats map[int64]*storage.BM25Stats) bool { + sd.idfOracle.Register(segmentID, stats, segments.SegmentTypeSealed) + return true + }) + } + log.Info("load delete done") return nil @@ -963,3 +993,49 @@ func (sd *shardDelegator) TryCleanExcludedSegments(ts uint64) { sd.excludedSegments.CleanInvalid(ts) } } + +func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64, error) { + pb := &commonpb.PlaceholderGroup{} + if err := proto.Unmarshal(req.GetPlaceholderGroup(), pb); err != nil { + return 0, err + } + + if len(pb.Placeholders) != 1 || len(pb.Placeholders[0].Values) != 1 { + return 0, merr.WrapErrParameterInvalidMsg("please provide varchar for bm25") + } + + holder := pb.Placeholders[0] + if holder.Type != commonpb.PlaceholderType_VarChar { + return 0, fmt.Errorf("cannot build BM25 IDF for non-varchar data") + } + + str := funcutil.GetVarCharFromPlaceholder(holder) + functionRunner, ok := sd.functionRunners[req.GetFieldId()] + if !ok { + return 0, fmt.Errorf("functionRunner not found for field: %d", req.GetFieldId()) + } + + // get search text term frequency + output, err := functionRunner.BatchRun(str) + if err != nil { + return 0, err + } + + tfArray, ok := output[0].(*schemapb.SparseFloatArray) + if !ok { + return 0, fmt.Errorf("functionRunner returned unexpected data type") + } + + idfSparseVector, avgdl, err := sd.idfOracle.BuildIDF(req.GetFieldId(), tfArray) + if err != nil { + return 0, err + } + + err = SetBM25Params(req, avgdl) + if err != nil { + return 0, err + } + + req.PlaceholderGroup = funcutil.SparseVectorDataToPlaceholderGroupBytes(idfSparseVector) + return avgdl, nil +} diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index 173b17851e..8d9839b694 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -26,14 +26,19 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/pingcap/log" "github.com/samber/lo" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster" @@ -42,10 +47,12 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/bloomfilter" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/metric" @@ -95,13 +102,7 @@ func (s *DelegatorDataSuite) TearDownSuite() { paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.Key) } -func (s *DelegatorDataSuite) SetupTest() { - s.workerManager = &cluster.MockManager{} - s.manager = segments.NewManager() - s.tsafeManager = tsafe.NewTSafeReplica() - s.loader = &segments.MockLoader{} - - // init schema +func (s *DelegatorDataSuite) genNormalCollection() { s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{ Name: "TestCollection", Fields: []*schemapb.FieldSchema{ @@ -154,7 +155,59 @@ func (s *DelegatorDataSuite) SetupTest() { LoadType: querypb.LoadType_LoadCollection, PartitionIDs: []int64{1001, 1002}, }) +} +func (s *DelegatorDataSuite) genCollectionWithFunction() { + s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{ + Name: "TestCollection", + Fields: []*schemapb.FieldSchema{ + { + Name: "id", + FieldID: 100, + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + AutoID: true, + }, { + Name: "vector", + FieldID: 101, + IsPrimaryKey: false, + DataType: schemapb.DataType_SparseFloatVector, + }, { + Name: "text", + FieldID: 102, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "256", + }, + }, + }, + }, + Functions: []*schemapb.FunctionSchema{{ + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{102}, + OutputFieldIds: []int64{101}, + }}, + }, nil, nil) + + delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{ + NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { + return s.mq, nil + }, + }, 10000, nil, s.chunkManager) + s.NoError(err) + s.delegator = delegator.(*shardDelegator) +} + +func (s *DelegatorDataSuite) SetupTest() { + s.workerManager = &cluster.MockManager{} + s.manager = segments.NewManager() + s.tsafeManager = tsafe.NewTSafeReplica() + s.loader = &segments.MockLoader{} + + // init schema + s.genNormalCollection() s.mq = &msgstream.MockMsgStream{} s.rootPath = s.Suite.T().Name() chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath) @@ -471,6 +524,127 @@ func (s *DelegatorDataSuite) TestProcessDelete() { s.False(s.delegator.distribution.Serviceable()) } +func (s *DelegatorDataSuite) TestLoadGrowingWithBM25() { + s.genCollectionWithFunction() + mockSegment := segments.NewMockSegment(s.T()) + s.loader.EXPECT().Load(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]segments.Segment{mockSegment}, nil) + + mockSegment.EXPECT().Partition().Return(111) + mockSegment.EXPECT().ID().Return(111) + mockSegment.EXPECT().Type().Return(commonpb.SegmentState_Growing) + 
mockSegment.EXPECT().GetBM25Stats().Return(map[int64]*storage.BM25Stats{}) + + err := s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{{SegmentID: 1}}, 1) + s.NoError(err) +} + +func (s *DelegatorDataSuite) TestLoadSegmentsWithBm25() { + s.genCollectionWithFunction() + s.Run("normal_run", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + s.loader.ExpectedCalls = nil + }() + + statsMap := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]() + stats := storage.NewBM25Stats() + stats.Append(map[uint32]float32{1: 1}) + + statsMap.Insert(1, map[int64]*storage.BM25Stats{101: stats}) + + s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(statsMap, nil) + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). + Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet { + return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed) + }) + }, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error { + return nil + }) + + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + workers[1] = worker1 + + worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")). + Return(nil) + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{ + Base: commonpbutil.NewMsgBase(), + DstNodeID: 1, + CollectionID: s.collectionID, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: 100, + PartitionID: 500, + StartPosition: &msgpb.MsgPosition{Timestamp: 20000}, + DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000}, + Level: datapb.SegmentLevel_L1, + InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID), + }, + }, + }) + + s.NoError(err) + sealed, _ := s.delegator.GetSegmentInfo(false) + s.Require().Equal(1, len(sealed)) + s.Equal(int64(1), sealed[0].NodeID) + s.ElementsMatch([]SegmentEntry{ + { + SegmentID: 100, + NodeID: 1, + PartitionID: 500, + TargetVersion: unreadableTargetVersion, + Level: datapb.SegmentLevel_L1, + }, + }, sealed[0].Segments) + }) + + s.Run("loadBM25_failed", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + s.loader.ExpectedCalls = nil + }() + + s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(nil, fmt.Errorf("mock error")) + + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + workers[1] = worker1 + + worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")). 
+ Return(nil) + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{ + Base: commonpbutil.NewMsgBase(), + DstNodeID: 1, + CollectionID: s.collectionID, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: 100, + PartitionID: 500, + StartPosition: &msgpb.MsgPosition{Timestamp: 20000}, + DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000}, + Level: datapb.SegmentLevel_L1, + InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID), + }, + }, + }) + + s.Error(err) + }) +} + func (s *DelegatorDataSuite) TestLoadSegments() { s.Run("normal_run", func() { defer func() { @@ -883,6 +1057,214 @@ func (s *DelegatorDataSuite) TestLoadSegments() { }) } +func (s *DelegatorDataSuite) TestBuildBM25IDF() { + s.genCollectionWithFunction() + + genBM25Stats := func(start uint32, end uint32) map[int64]*storage.BM25Stats { + result := make(map[int64]*storage.BM25Stats) + result[101] = storage.NewBM25Stats() + for i := start; i < end; i++ { + row := map[uint32]float32{i: 1} + result[101].Append(row) + } + return result + } + + genSnapShot := func(seals, grows []int64, targetVersion int64) *snapshot { + snapshot := &snapshot{ + dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}}, + targetVersion: targetVersion, + } + + newSeal := []SegmentEntry{} + for _, seg := range seals { + newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion}) + } + + newGrow := []SegmentEntry{} + for _, seg := range grows { + newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion}) + } + + log.Info("Test-", zap.Any("snapshot", snapshot), zap.Any("seg", newSeal)) + snapshot.dist[0].Segments = newSeal + snapshot.growing = newGrow + return snapshot + } + + genStringFieldData := func(strs ...string) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 102, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: strs, + }, + }, + }, + }, + } + } + + s.Run("normal case", func() { + // register sealed + sealedSegs := []int64{1, 2, 3, 4} + for _, segID := range sealedSegs { + // every segment's stats has only one token, so avgdl = 1 + s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) + } + snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100) + + s.delegator.idfOracle.SyncDistribution(snapshot) + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data")) + s.NoError(err) + + plan, err := proto.Marshal(&planpb.PlanNode{ + Node: &planpb.PlanNode_VectorAnns{ + VectorAnns: &planpb.VectorANNS{ + QueryInfo: &planpb.QueryInfo{}, + }, + }, + }) + s.NoError(err) + + req := &internalpb.SearchRequest{ + PlaceholderGroup: placeholderGroupBytes, + SerializedExprPlan: plan, + FieldId: 101, + } + avgdl, err := s.delegator.buildBM25IDF(req) + s.NoError(err) + s.Equal(float64(1), avgdl) + + // check avgdl in plan + newplan := &planpb.PlanNode{} + err = proto.Unmarshal(req.GetSerializedExprPlan(), newplan) + s.NoError(err) + + annplan, ok := newplan.GetNode().(*planpb.PlanNode_VectorAnns) + s.Require().True(ok) + s.Equal(avgdl,
annplan.VectorAnns.QueryInfo.Bm25Avgdl) + + // check idf in placeholder + placeholder := &commonpb.PlaceholderGroup{} + err = proto.Unmarshal(req.GetPlaceholderGroup(), placeholder) + s.Require().NoError(err) + s.Equal(commonpb.PlaceholderType_SparseFloatVector, placeholder.GetPlaceholders()[0].GetType()) + }) + + s.Run("invalid placeholder type error", func() { + placeholderGroupBytes, err := proto.Marshal(&commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{{Type: commonpb.PlaceholderType_SparseFloatVector}}, + }) + s.NoError(err) + + req := &internalpb.SearchRequest{ + PlaceholderGroup: placeholderGroupBytes, + FieldId: 101, + } + _, err = s.delegator.buildBM25IDF(req) + s.Error(err) + }) + + s.Run("no function runner error", func() { + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data")) + s.NoError(err) + + req := &internalpb.SearchRequest{ + PlaceholderGroup: placeholderGroupBytes, + FieldId: 103, // invalid field id + } + + _, err = s.delegator.buildBM25IDF(req) + s.Error(err) + }) + + s.Run("function runner run failed error", func() { + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data")) + s.NoError(err) + + oldRunner := s.delegator.functionRunners + mockRunner := function.NewMockFunctionRunner(s.T()) + s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner} + mockRunner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock err")) + defer func() { + s.delegator.functionRunners = oldRunner + }() + + req := &internalpb.SearchRequest{ + PlaceholderGroup: placeholderGroupBytes, + FieldId: 101, + } + _, err = s.delegator.buildBM25IDF(req) + s.Error(err) + }) + + s.Run("function runner output type error", func() { + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data")) + s.NoError(err) + + oldRunner := s.delegator.functionRunners + mockRunner := function.NewMockFunctionRunner(s.T()) + s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner} + mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil) + defer func() { + s.delegator.functionRunners = oldRunner + }() + + req := &internalpb.SearchRequest{ + PlaceholderGroup: placeholderGroupBytes, + FieldId: 101, + } + _, err = s.delegator.buildBM25IDF(req) + s.Error(err) + }) + + s.Run("idf oracle build idf error", func() { + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data")) + s.NoError(err) + + oldRunner := s.delegator.functionRunners + mockRunner := function.NewMockFunctionRunner(s.T()) + s.delegator.functionRunners = map[int64]function.FunctionRunner{103: mockRunner} + mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{&schemapb.SparseFloatArray{Contents: [][]byte{typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{1: 1})}}}, nil) + defer func() { + s.delegator.functionRunners = oldRunner + }() + + req := &internalpb.SearchRequest{ + PlaceholderGroup: placeholderGroupBytes, + FieldId: 103, // invalid field + } + _, err = s.delegator.buildBM25IDF(req) + s.Error(err) + log.Info("test", zap.Error(err)) + }) + + s.Run("set avgdl failed", func() { + // register sealed + sealedSegs := []int64{1, 2, 3, 4} + for _, segID := range sealedSegs { + // every segment's stats has only one token, so avgdl = 1 + s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
+ } + snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100) + + s.delegator.idfOracle.SyncDistribution(snapshot) + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data")) + s.NoError(err) + + req := &internalpb.SearchRequest{ + PlaceholderGroup: placeholderGroupBytes, + FieldId: 101, + } + _, err = s.delegator.buildBM25IDF(req) + s.Error(err) + }) +} + func (s *DelegatorDataSuite) TestReleaseSegment() { s.loader.EXPECT(). Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything). diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index 2dcd9ac5e0..cb541aa6e9 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -178,6 +178,84 @@ func (s *DelegatorSuite) TearDownTest() { s.delegator = nil } +func (s *DelegatorSuite) TestCreateDelegatorWithFunction() { + s.Run("init function failed", func() { + manager := segments.NewManager() + manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{ + Name: "TestCollection", + Fields: []*schemapb.FieldSchema{ + { + Name: "id", + FieldID: 100, + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + AutoID: true, + }, { + Name: "vector", + FieldID: 101, + IsPrimaryKey: false, + DataType: schemapb.DataType_SparseFloatVector, + }, + }, + Functions: []*schemapb.FunctionSchema{{ + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{102}, + OutputFieldIds: []int64{101, 103}, // invalid output field + }}, + }, nil, nil) + + _, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{ + NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { + return s.mq, nil + }, + }, 10000, nil, s.chunkManager) + s.Error(err) + }) + + s.Run("init function success", func() { + manager := segments.NewManager() + manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{ + Name: "TestCollection", + Fields: []*schemapb.FieldSchema{ + { + Name: "id", + FieldID: 100, + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + AutoID: true, + }, { + Name: "vector", + FieldID: 101, + IsPrimaryKey: false, + DataType: schemapb.DataType_SparseFloatVector, + }, { + Name: "text", + FieldID: 102, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "256", + }, + }, + }, + }, + Functions: []*schemapb.FunctionSchema{{ + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{102}, + OutputFieldIds: []int64{101}, + }}, + }, nil, nil) + + _, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{ + NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { + return s.mq, nil + }, + }, 10000, nil, s.chunkManager) + s.NoError(err) + }) +} + func (s *DelegatorSuite) TestBasicInfo() { s.Equal(s.collectionID, s.delegator.Collection()) s.Equal(s.version, s.delegator.Version()) diff --git a/internal/querynodev2/delegator/distribution.go b/internal/querynodev2/delegator/distribution.go index ae57b0fe04..47c69a60af 100644 --- a/internal/querynodev2/delegator/distribution.go +++ b/internal/querynodev2/delegator/distribution.go @@ -74,6 +74,8 @@ type distribution struct { // current is the snapshot for quick
usage for search/query // generated for each change of distribution current *atomic.Pointer[snapshot] + + idfOracle IDFOracle // protects current & segments mut sync.RWMutex } @@ -89,7 +91,7 @@ type SegmentEntry struct { } // NewDistribution creates a new distribution instance with all field initialized. -func NewDistribution() *distribution { +func NewDistribution(idfOracle IDFOracle) *distribution { dist := &distribution{ serviceable: atomic.NewBool(false), growingSegments: make(map[UniqueID]SegmentEntry), @@ -98,6 +100,7 @@ func NewDistribution() *distribution { current: atomic.NewPointer[snapshot](nil), offlines: typeutil.NewSet[int64](), targetVersion: atomic.NewInt64(initialTargetVersion), + idfOracle: idfOracle, } dist.genSnapshot() @@ -367,6 +370,7 @@ func (d *distribution) genSnapshot() chan struct{} { d.current.Store(newSnapShot) // shall be a new one d.snapshots.GetOrInsert(d.snapshotVersion, newSnapShot) + d.idfOracle.SyncDistribution(newSnapShot) // first snapshot, return closed chan if last == nil { diff --git a/internal/querynodev2/delegator/distribution_test.go b/internal/querynodev2/delegator/distribution_test.go index aa8c534e7d..d5fa335384 100644 --- a/internal/querynodev2/delegator/distribution_test.go +++ b/internal/querynodev2/delegator/distribution_test.go @@ -21,6 +21,8 @@ import ( "time" "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) type DistributionSuite struct { @@ -29,7 +31,7 @@ type DistributionSuite struct { } func (s *DistributionSuite) SetupTest() { - s.dist = NewDistribution() + s.dist = NewDistribution(NewIDFOracle([]*schemapb.FunctionSchema{})) s.Equal(initialTargetVersion, s.dist.getTargetVersion()) } diff --git a/internal/querynodev2/delegator/idf_oracle.go b/internal/querynodev2/delegator/idf_oracle.go new file mode 100644 index 0000000000..76989e39a5 --- /dev/null +++ b/internal/querynodev2/delegator/idf_oracle.go @@ -0,0 +1,262 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package delegator + +import ( + "fmt" + "sync" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/querynodev2/segments" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" +) + +type IDFOracle interface { + // Activate(segmentID int64, state commonpb.SegmentState) error + // Deactivate(segmentID int64, state commonpb.SegmentState) error + + SyncDistribution(snapshot *snapshot) + + UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) + + Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) + Remove(segmentID int64, state commonpb.SegmentState) + + BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) +} + +type bm25Stats struct { + stats map[int64]*storage.BM25Stats + activate bool + targetVersion int64 +} + +func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) { + for fieldID, newstats := range stats { + if stats, ok := s.stats[fieldID]; ok { + stats.Merge(newstats) + } else { + log.Panic("merge failed, BM25 stats not exist", zap.Int64("fieldID", fieldID)) + } + } +} + +func (s *bm25Stats) Minus(stats map[int64]*storage.BM25Stats) { + for fieldID, newstats := range stats { + if stats, ok := s.stats[fieldID]; ok { + stats.Minus(newstats) + } else { + log.Panic("minus failed, BM25 stats not exist", zap.Int64("fieldID", fieldID)) + } + } +} + +func (s *bm25Stats) GetStats(fieldID int64) (*storage.BM25Stats, error) { + stats, ok := s.stats[fieldID] + if !ok { + return nil, fmt.Errorf("field not found in idf oracle BM25 stats") + } + return stats, nil +} + +func (s *bm25Stats) NumRow() int64 { + for _, stats := range s.stats { + return stats.NumRow() + } + return 0 +} + +func newBm25Stats(functions []*schemapb.FunctionSchema) *bm25Stats { + stats := &bm25Stats{ + stats: make(map[int64]*storage.BM25Stats), + } + + for _, function := range functions { + if function.GetType() == schemapb.FunctionType_BM25 { + stats.stats[function.GetOutputFieldIds()[0]] = storage.NewBM25Stats() + } + } + return stats +} + +type idfOracle struct { + sync.RWMutex + + current *bm25Stats + + growing map[int64]*bm25Stats + sealed map[int64]*bm25Stats + + targetVersion int64 +} + +func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) { + o.Lock() + defer o.Unlock() + + switch state { + case segments.SegmentTypeGrowing: + if _, ok := o.growing[segmentID]; ok { + return + } + o.growing[segmentID] = &bm25Stats{ + stats: stats, + activate: true, + targetVersion: initialTargetVersion, + } + o.current.Merge(stats) + case segments.SegmentTypeSealed: + if _, ok := o.sealed[segmentID]; ok { + return + } + o.sealed[segmentID] = &bm25Stats{ + stats: stats, + activate: false, + targetVersion: initialTargetVersion, + } + default: + log.Warn("register segment with unknown state", zap.String("state", state.String())) + return + } +} + +func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) { + if len(stats) == 0 { + return + } + + o.Lock() + defer o.Unlock() + + old, ok := o.growing[segmentID] + if !ok { + return + } + + old.Merge(stats) + if old.activate { + o.current.Merge(stats) + } +} + +func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) { + o.Lock() + defer o.Unlock() + + switch state { + case segments.SegmentTypeGrowing: + if stats, ok := o.growing[segmentID]; ok {
if stats.activate { + o.current.Minus(stats.stats) + } + delete(o.growing, segmentID) + } + case segments.SegmentTypeSealed: + if stats, ok := o.sealed[segmentID]; ok { + if stats.activate { + o.current.Minus(stats.stats) + } + delete(o.sealed, segmentID) + } + default: + return + } +} + +func (o *idfOracle) activate(stats *bm25Stats) { + stats.activate = true + o.current.Merge(stats.stats) +} + +func (o *idfOracle) deactivate(stats *bm25Stats) { + stats.activate = false + o.current.Minus(stats.stats) +} + +func (o *idfOracle) SyncDistribution(snapshot *snapshot) { + o.Lock() + defer o.Unlock() + + sealed, growing := snapshot.Peek() + + for _, item := range sealed { + for _, segment := range item.Segments { + if stats, ok := o.sealed[segment.SegmentID]; ok { + stats.targetVersion = segment.TargetVersion + } else { + log.Warn("idf oracle lacks stats for sealed segment", zap.Int64("segmentID", segment.SegmentID)) + } + } + } + + for _, segment := range growing { + if stats, ok := o.growing[segment.SegmentID]; ok { + stats.targetVersion = segment.TargetVersion + } else { + log.Warn("idf oracle lacks stats for growing segment", zap.Int64("segmentID", segment.SegmentID)) + } + } + + o.targetVersion = snapshot.targetVersion + + for _, stats := range o.sealed { + if !stats.activate && stats.targetVersion == o.targetVersion { + o.activate(stats) + } else if stats.activate && stats.targetVersion != o.targetVersion { + o.deactivate(stats) + } + } + + for _, stats := range o.growing { + if !stats.activate && (stats.targetVersion == o.targetVersion || stats.targetVersion == initialTargetVersion) { + o.activate(stats) + } else if stats.activate && (stats.targetVersion != o.targetVersion && stats.targetVersion != initialTargetVersion) { + o.deactivate(stats) + } + } + + log.Debug("sync distribution finished", zap.Int64("version", o.targetVersion), zap.Int64("numrow", o.current.NumRow())) +} + +func (o *idfOracle) BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) { + o.RLock() + defer o.RUnlock() + + stats, err := o.current.GetStats(fieldID) + if err != nil { + return nil, 0, err + } + + idfBytes := make([][]byte, len(tfs.GetContents())) + for i, tf := range tfs.GetContents() { + idf := stats.BuildIDF(tf) + idfBytes[i] = idf + } + return idfBytes, stats.GetAvgdl(), nil +} + +func NewIDFOracle(functions []*schemapb.FunctionSchema) IDFOracle { + return &idfOracle{ + current: newBm25Stats(functions), + growing: make(map[int64]*bm25Stats), + sealed: make(map[int64]*bm25Stats), + } +} diff --git a/internal/querynodev2/delegator/idf_oracle_test.go b/internal/querynodev2/delegator/idf_oracle_test.go new file mode 100644 index 0000000000..a2d36f2e36 --- /dev/null +++ b/internal/querynodev2/delegator/idf_oracle_test.go @@ -0,0 +1,198 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package delegator + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type IDFOracleSuite struct { + suite.Suite + collectionID int64 + collectionSchema *schemapb.CollectionSchema + idfOracle *idfOracle + + targetVersion int64 + snapshot *snapshot +} + +func (suite *IDFOracleSuite) SetupSuite() { + suite.collectionID = 111 + suite.collectionSchema = &schemapb.CollectionSchema{ + Functions: []*schemapb.FunctionSchema{{ + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + }}, + } +} + +func (suite *IDFOracleSuite) SetupTest() { + suite.idfOracle = NewIDFOracle(suite.collectionSchema.GetFunctions()).(*idfOracle) + suite.snapshot = &snapshot{ + dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}}, + } + suite.targetVersion = 0 +} + +func (suite *IDFOracleSuite) genStats(start uint32, end uint32) map[int64]*storage.BM25Stats { + result := make(map[int64]*storage.BM25Stats) + result[102] = storage.NewBM25Stats() + for i := start; i < end; i++ { + row := map[uint32]float32{i: 1} + result[102].Append(row) + } + return result +} + +// update test snapshot +func (suite *IDFOracleSuite) updateSnapshot(seals, grows, drops []int64) *snapshot { + suite.targetVersion++ + snapshot := &snapshot{ + dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}}, + targetVersion: suite.targetVersion, + } + + dropSet := typeutil.NewSet[int64]() + dropSet.Insert(drops...) + + newSeal := []SegmentEntry{} + for _, seg := range suite.snapshot.dist[0].Segments { + if !dropSet.Contain(seg.SegmentID) { + seg.TargetVersion = suite.targetVersion + } + newSeal = append(newSeal, seg) + } + for _, seg := range seals { + newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion}) + } + + newGrow := []SegmentEntry{} + for _, seg := range suite.snapshot.growing { + if !dropSet.Contain(seg.SegmentID) { + seg.TargetVersion = suite.targetVersion + } else { + seg.TargetVersion = redundantTargetVersion + } + newGrow = append(newGrow, seg) + } + for _, seg := range grows { + newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion}) + } + + snapshot.dist[0].Segments = newSeal + snapshot.growing = newGrow + suite.snapshot = snapshot + return snapshot +} + +func (suite *IDFOracleSuite) TestSealed() { + // register sealed + sealedSegs := []int64{1, 2, 3, 4} + for _, segID := range sealedSegs { + suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) + } + + // duplicate register should be a no-op + for _, segID := range sealedSegs { + suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) + } + + // sealed segments stay deactivated until the distribution is synced + suite.Zero(suite.idfOracle.current.NumRow()) + + // updating and syncing the snapshot activates all sealed segments + suite.updateSnapshot(sealedSegs, []int64{}, []int64{}) + suite.idfOracle.SyncDistribution(suite.snapshot) + suite.Equal(int64(4), suite.idfOracle.current.NumRow()) + + releasedSeg := []int64{1, 2, 3} + suite.updateSnapshot([]int64{}, []int64{}, releasedSeg) + suite.idfOracle.SyncDistribution(suite.snapshot) + suite.Equal(int64(1),
suite.idfOracle.current.NumRow())
+
+	for _, segID := range releasedSeg {
+		suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed)
+	}
+
+	sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
+	bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
+	suite.NoError(err)
+	suite.Equal(float64(1), avgdl)
+	suite.Equal(map[uint32]float32{4: 0.2876821}, typeutil.SparseFloatBytesToMap(bytes[0]))
+}
+
+func (suite *IDFOracleSuite) TestGrow() {
+	// register growing segments
+	growSegs := []int64{1, 2, 3, 4}
+	for _, segID := range growSegs {
+		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
+	}
+	// duplicate registration
+	for _, segID := range growSegs {
+		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
+	}
+
+	// growing segments are activated as soon as they are registered
+	suite.Equal(int64(4), suite.idfOracle.current.NumRow())
+	suite.updateSnapshot([]int64{}, growSegs, []int64{})
+
+	releasedSeg := []int64{1, 2, 3}
+	suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
+	suite.idfOracle.SyncDistribution(suite.snapshot)
+	suite.Equal(int64(1), suite.idfOracle.current.NumRow())
+
+	suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
+	suite.Equal(int64(2), suite.idfOracle.current.NumRow())
+
+	for _, segID := range releasedSeg {
+		suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing)
+	}
+}
+
+func (suite *IDFOracleSuite) TestStats() {
+	stats := newBm25Stats([]*schemapb.FunctionSchema{{
+		Type:           schemapb.FunctionType_BM25,
+		InputFieldIds:  []int64{101},
+		OutputFieldIds: []int64{102},
+	}})
+
+	suite.Panics(func() {
+		stats.Merge(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
+	})
+
+	suite.Panics(func() {
+		stats.Minus(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
+	})
+
+	_, err := stats.GetStats(103)
+	suite.Error(err)
+
+	_, err = stats.GetStats(102)
+	suite.NoError(err)
+}
+
+func TestIDFOracle(t *testing.T) {
+	suite.Run(t, new(IDFOracleSuite))
+}
diff --git a/internal/querynodev2/delegator/util.go b/internal/querynodev2/delegator/util.go
new file mode 100644
index 0000000000..8c9f74b3ab
--- /dev/null
+++ b/internal/querynodev2/delegator/util.go
@@ -0,0 +1,64 @@
+package delegator
+
+import (
+	"fmt"
+
+	"go.uber.org/zap"
+	"google.golang.org/protobuf/proto"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/planpb"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+)
+
+func BuildSparseFieldData(field *schemapb.FieldSchema, sparseArray *schemapb.SparseFloatArray) *schemapb.FieldData {
+	return &schemapb.FieldData{
+		Type:      field.GetDataType(),
+		FieldName: field.GetName(),
+		Field: &schemapb.FieldData_Vectors{
+			Vectors: &schemapb.VectorField{
+				Dim: sparseArray.GetDim(),
+				Data: &schemapb.VectorField_SparseFloatVector{
+					SparseFloatVector: sparseArray,
+				},
+			},
+		},
+		FieldId: field.GetFieldID(),
+	}
+}
+
+func SetBM25Params(req *internalpb.SearchRequest, avgdl float64) error {
+	log := log.With(zap.Int64("collection", req.GetCollectionID()))
+
+	serializedPlan := req.GetSerializedExprPlan()
+	// plan not found
+	if serializedPlan == nil {
+		log.Warn("serialized plan not found")
+		return merr.WrapErrParameterInvalid("serialized search plan", "nil")
+	}
+
+	plan := planpb.PlanNode{}
+	err := proto.Unmarshal(serializedPlan,
&plan)
+	if err != nil {
+		log.Warn("failed to unmarshal plan", zap.Error(err))
+		return merr.WrapErrParameterInvalid("valid serialized search plan", "an unmarshalable one", err.Error())
+	}
+
+	switch plan.GetNode().(type) {
+	case *planpb.PlanNode_VectorAnns:
+		queryInfo := plan.GetVectorAnns().GetQueryInfo()
+		queryInfo.Bm25Avgdl = avgdl
+		serializedExprPlan, err := proto.Marshal(&plan)
+		if err != nil {
+			log.Warn("failed to marshal optimized plan", zap.Error(err))
+			return merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
+		}
+		req.SerializedExprPlan = serializedExprPlan
+		log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
+	default:
+		log.Warn("unsupported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
+	}
+	return nil
+}
diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go
index 16b2056c2c..d3898679d3 100644
--- a/internal/querynodev2/handlers.go
+++ b/internal/querynodev2/handlers.go
@@ -60,6 +60,7 @@ func loadL0Segments(ctx context.Context, delegator delegator.ShardDelegator, req
 			NumOfRows:     segmentInfo.NumOfRows,
 			Statslogs:     segmentInfo.Statslogs,
 			Deltalogs:     segmentInfo.Deltalogs,
+			Bm25Logs:      segmentInfo.Bm25Statslogs,
 			InsertChannel: segmentInfo.InsertChannel,
 			StartPosition: segmentInfo.GetStartPosition(),
 			Level:         segmentInfo.GetLevel(),
@@ -101,6 +102,7 @@ func loadGrowingSegments(ctx context.Context, delegator delegator.ShardDelegator
 				NumOfRows:     segmentInfo.NumOfRows,
 				Statslogs:     segmentInfo.Statslogs,
 				Deltalogs:     segmentInfo.Deltalogs,
+				Bm25Logs:      segmentInfo.Bm25Statslogs,
 				InsertChannel: segmentInfo.InsertChannel,
 			})
 		} else {
diff --git a/internal/querynodev2/pipeline/embedding_node.go b/internal/querynodev2/pipeline/embedding_node.go
new file mode 100644
index 0000000000..d3a8298214
--- /dev/null
+++ b/internal/querynodev2/pipeline/embedding_node.go
@@ -0,0 +1,211 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
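As an aside on SetBM25Params above: the helper rewrites the serialized plan in place, so a caller just marshals a plan, wraps it in a search request, and lets the delegator stamp in the average document length. A minimal sketch, compilable inside the delegator package; the plan values and the 5.5 avgdl are illustrative only:

    func exampleSetAvgdl() error {
        // Build a toy ANNS plan the way the proxy would.
        plan := &planpb.PlanNode{Node: &planpb.PlanNode_VectorAnns{VectorAnns: &planpb.VectorANNS{
            QueryInfo: &planpb.QueryInfo{Topk: 10, MetricType: "BM25"},
        }}}
        serialized, err := proto.Marshal(plan)
        if err != nil {
            return err
        }
        req := &internalpb.SearchRequest{SerializedExprPlan: serialized}
        // Stamp the collection's current average document length into the plan;
        // the segcore side reads it back as knowhere's BM25_AVGDL search param.
        return SetBM25Params(req, 5.5)
    }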
+
+package pipeline
+
+import (
+	"fmt"
+
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
+	"github.com/milvus-io/milvus/internal/querynodev2/segments"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/function"
+	base "github.com/milvus-io/milvus/internal/util/pipeline"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/mq/msgstream"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+type embeddingNode struct {
+	*BaseNode
+
+	collectionID int64
+	channel      string
+
+	manager *DataManager
+
+	functionRunners []function.FunctionRunner
+}
+
+func newEmbeddingNode(collectionID int64, channelName string, manager *DataManager, maxQueueLength int32) (*embeddingNode, error) {
+	collection := manager.Collection.Get(collectionID)
+	if collection == nil {
+		log.Error("embeddingNode init failed with collection not exist", zap.Int64("collection", collectionID))
+		return nil, merr.WrapErrCollectionNotFound(collectionID)
+	}
+
+	if len(collection.Schema().GetFunctions()) == 0 {
+		return nil, nil
+	}
+
+	node := &embeddingNode{
+		BaseNode:        base.NewBaseNode(fmt.Sprintf("EmbeddingNode-%s", channelName), maxQueueLength),
+		collectionID:    collectionID,
+		channel:         channelName,
+		manager:         manager,
+		functionRunners: make([]function.FunctionRunner, 0),
+	}
+
+	for _, tf := range collection.Schema().GetFunctions() {
+		functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf)
+		if err != nil {
+			return nil, err
+		}
+		node.functionRunners = append(node.functionRunners, functionRunner)
+	}
+	return node, nil
+}
+
+func (eNode *embeddingNode) Name() string {
+	return fmt.Sprintf("embeddingNode-%s", eNode.channel)
+}
+
+func (eNode *embeddingNode) addInsertData(insertDatas map[UniqueID]*delegator.InsertData, msg *InsertMsg, collection *Collection) error {
+	iData, ok := insertDatas[msg.SegmentID]
+	if !ok {
+		iData = &delegator.InsertData{
+			PartitionID: msg.PartitionID,
+			BM25Stats:   make(map[int64]*storage.BM25Stats),
+			StartPosition: &msgpb.MsgPosition{
+				Timestamp:   msg.BeginTs(),
+				ChannelName: msg.GetShardName(),
+			},
+		}
+		insertDatas[msg.SegmentID] = iData
+	}
+
+	err := eNode.embedding(msg, iData.BM25Stats)
+	if err != nil {
+		log.Error("failed to run embedding function on insert data", zap.Error(err))
+		return err
+	}
+
+	insertRecord, err := storage.TransferInsertMsgToInsertRecord(collection.Schema(), msg)
+	if err != nil {
+		err = fmt.Errorf("failed to transfer insert msg to insert record, err = %w", err)
+		log.Error(err.Error(), zap.String("channel", eNode.channel))
+		return err
+	}
+
+	if iData.InsertRecord == nil {
+		iData.InsertRecord = insertRecord
+	} else {
+		err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
+		if err != nil {
+			log.Warn("failed to merge field data", zap.String("channel", eNode.channel), zap.Error(err))
+			return err
+		}
+		iData.InsertRecord.NumRows += insertRecord.NumRows
+	}
+
+	pks, err := segments.GetPrimaryKeys(msg, collection.Schema())
+	if err != nil {
+		log.Warn("failed to get primary keys from insert message", zap.String("channel", eNode.channel), zap.Error(err))
+		return err
+	}
+
+	iData.PrimaryKeys = append(iData.PrimaryKeys, pks...)
+	iData.RowIDs = append(iData.RowIDs, msg.RowIDs...)
+	iData.Timestamps = append(iData.Timestamps, msg.Timestamps...)
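+
+	// Note: eNode.embedding ran before the record transfer above, appending
+	// the generated sparse vectors to msg.FieldsData and folding their term
+	// frequencies into iData.BM25Stats, so the merged record already carries
+	// the BM25 output field alongside the user-written fields.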
+	log.Debug("pipeline embedding insert msg",
+		zap.Int64("collectionID", eNode.collectionID),
+		zap.Int64("segmentID", msg.SegmentID),
+		zap.Int("insertRowNum", len(pks)),
+		zap.Uint64("timestampMin", msg.BeginTimestamp),
+		zap.Uint64("timestampMax", msg.EndTimestamp))
+	return nil
+}
+
+func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
+	functionSchema := runner.GetSchema()
+	inputFieldID := functionSchema.GetInputFieldIds()[0]
+	outputFieldID := functionSchema.GetOutputFieldIds()[0]
+	outputField := runner.GetOutputFields()[0]
+
+	data, err := GetEmbeddingFieldData(msg.GetFieldsData(), inputFieldID)
+	if data == nil || err != nil {
+		return merr.WrapErrFieldNotFound(fmt.Sprint(inputFieldID))
+	}
+
+	output, err := runner.BatchRun(data)
+	if err != nil {
+		return err
+	}
+
+	sparseArray, ok := output[0].(*schemapb.SparseFloatArray)
+	if !ok {
+		return fmt.Errorf("BM25 runner returned output of unknown type")
+	}
+
+	if _, ok := stats[outputFieldID]; !ok {
+		stats[outputFieldID] = storage.NewBM25Stats()
+	}
+	stats[outputFieldID].AppendBytes(sparseArray.GetContents()...)
+	msg.FieldsData = append(msg.FieldsData, delegator.BuildSparseFieldData(outputField, sparseArray))
+	return nil
+}
+
+func (eNode *embeddingNode) embedding(msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
+	for _, functionRunner := range eNode.functionRunners {
+		functionSchema := functionRunner.GetSchema()
+		switch functionSchema.GetType() {
+		case schemapb.FunctionType_BM25:
+			err := eNode.bm25Embedding(functionRunner, msg, stats)
+			if err != nil {
+				return err
+			}
+		default:
+			log.Warn("pipeline embedding with unknown function type", zap.Any("type", functionSchema.GetType()))
+			return fmt.Errorf("unknown function type")
+		}
+	}
+
+	return nil
+}
+
+func (eNode *embeddingNode) Operate(in Msg) Msg {
+	nodeMsg := in.(*insertNodeMsg)
+	nodeMsg.insertDatas = make(map[int64]*delegator.InsertData)
+
+	collection := eNode.manager.Collection.Get(eNode.collectionID)
+	if collection == nil {
+		log.Error("embeddingNode with collection not exist", zap.Int64("collection", eNode.collectionID))
+		panic("embeddingNode with collection not exist")
+	}
+
+	for _, msg := range nodeMsg.insertMsgs {
+		err := eNode.addInsertData(nodeMsg.insertDatas, msg, collection)
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	return nodeMsg
+}
+
+func GetEmbeddingFieldData(datas []*schemapb.FieldData, fieldID int64) ([]string, error) {
+	for _, data := range datas {
+		if data.GetFieldId() == fieldID {
+			return data.GetScalars().GetStringData().GetData(), nil
+		}
+	}
+	return nil, fmt.Errorf("field %d not found", fieldID)
+}
diff --git a/internal/querynodev2/pipeline/embedding_node_test.go b/internal/querynodev2/pipeline/embedding_node_test.go
new file mode 100644
index 0000000000..19ab5c5139
--- /dev/null
+++ b/internal/querynodev2/pipeline/embedding_node_test.go
@@ -0,0 +1,281 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipeline + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/querynodev2/delegator" + "github.com/milvus-io/milvus/internal/querynodev2/segments" + "github.com/milvus-io/milvus/internal/util/function" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// test of embedding node +type EmbeddingNodeSuite struct { + suite.Suite + // datas + collectionID int64 + collectionSchema *schemapb.CollectionSchema + channel string + msgs []*InsertMsg + + // mocks + manager *segments.Manager + segManager *segments.MockSegmentManager + colManager *segments.MockCollectionManager +} + +func (suite *EmbeddingNodeSuite) SetupTest() { + paramtable.Init() + suite.collectionID = 111 + suite.channel = "test-channel" + suite.collectionSchema = &schemapb.CollectionSchema{ + Name: "test-collection", + Fields: []*schemapb.FieldSchema{ + { + FieldID: common.TimeStampField, + Name: common.TimeStampFieldName, + DataType: schemapb.DataType_Int64, + }, { + Name: "pk", + FieldID: 100, + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, { + Name: "text", + FieldID: 101, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{}, + }, { + Name: "sparse", + FieldID: 102, + DataType: schemapb.DataType_SparseFloatVector, + IsFunctionOutput: true, + }, + }, + Functions: []*schemapb.FunctionSchema{{ + Name: "BM25", + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + }}, + } + + suite.msgs = []*msgstream.InsertMsg{{ + BaseMsg: msgstream.BaseMsg{}, + InsertRequest: &msgpb.InsertRequest{ + SegmentID: 1, + NumRows: 3, + Version: msgpb.InsertDataVersion_ColumnBased, + Timestamps: []uint64{1, 1, 1}, + FieldsData: []*schemapb.FieldData{ + { + FieldId: 100, + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}, + }, + }, { + FieldId: 101, + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}}}}, + }, + }, + }, + }, + }} + + suite.segManager = segments.NewMockSegmentManager(suite.T()) + suite.colManager = segments.NewMockCollectionManager(suite.T()) + + suite.manager = &segments.Manager{ + Collection: suite.colManager, + Segment: suite.segManager, + } +} + +func (suite *EmbeddingNodeSuite) TestCreateEmbeddingNode() { + suite.Run("collection not found", func() { + suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once() + _, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128) + suite.Error(err) + 
})
+
+	suite.Run("function invalid", func() {
+		collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
+		collection.Schema().Functions = []*schemapb.FunctionSchema{{}}
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
+		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.Error(err)
+	})
+
+	suite.Run("normal case", func() {
+		collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
+		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+	})
+}
+
+func (suite *EmbeddingNodeSuite) TestOperator() {
+	suite.Run("collection not found", func() {
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
+		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once()
+		suite.Panics(func() {
+			node.Operate(&insertNodeMsg{})
+		})
+	})
+
+	suite.Run("add insert data failed", func() {
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
+		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+
+		suite.Panics(func() {
+			node.Operate(&insertNodeMsg{
+				insertMsgs: []*msgstream.InsertMsg{{
+					BaseMsg: msgstream.BaseMsg{},
+					InsertRequest: &msgpb.InsertRequest{
+						SegmentID: 1,
+						NumRows:   3,
+						Version:   msgpb.InsertDataVersion_ColumnBased,
+						FieldsData: []*schemapb.FieldData{
+							{
+								FieldId: 100,
+								Type:    schemapb.DataType_Int64,
+								Field: &schemapb.FieldData_Scalars{
+									Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
+								},
+							},
+						},
+					},
+				}},
+			})
+		})
+	})
+
+	suite.Run("normal case", func() {
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
+		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+
+		suite.NotPanics(func() {
+			output := node.Operate(&insertNodeMsg{
+				insertMsgs: suite.msgs,
+			})
+
+			msg, ok := output.(*insertNodeMsg)
+			suite.Require().True(ok)
+			suite.Require().NotNil(msg.insertDatas)
+			suite.Require().Equal(int64(3), msg.insertDatas[1].BM25Stats[102].NumRow())
+			suite.Require().Equal(int64(3), msg.insertDatas[1].InsertRecord.GetNumRows())
+		})
+	})
+}
+
+func (suite *EmbeddingNodeSuite) TestAddInsertData() {
+	suite.Run("transfer insert msg failed", func() {
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
+		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+
+		// transferring the insert msg fails because row-based data does not support sparse vectors
+		insertDatas := make(map[int64]*delegator.InsertData)
+		rowBaseReq :=
proto.Clone(suite.msgs[0].InsertRequest).(*msgpb.InsertRequest)
+		rowBaseReq.Version = msgpb.InsertDataVersion_RowBased
+		rowBaseMsg := &msgstream.InsertMsg{
+			BaseMsg:       msgstream.BaseMsg{},
+			InsertRequest: rowBaseReq,
+		}
+		err = node.addInsertData(insertDatas, rowBaseMsg, collection)
+		suite.Error(err)
+	})
+
+	suite.Run("merge field data failed", func() {
+		// remove pk
+		suite.collectionSchema.Fields[1].IsPrimaryKey = false
+		defer func() {
+			suite.collectionSchema.Fields[1].IsPrimaryKey = true
+		}()
+
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
+		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+
+		insertDatas := make(map[int64]*delegator.InsertData)
+		err = node.addInsertData(insertDatas, suite.msgs[0], collection)
+		suite.Error(err)
+	})
+}
+
+func (suite *EmbeddingNodeSuite) TestBM25Embedding() {
+	suite.Run("function run failed", func() {
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
+		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+
+		runner := function.NewMockFunctionRunner(suite.T())
+		runner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock error"))
+		runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
+		runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})
+
+		err = node.bm25Embedding(runner, suite.msgs[0], nil)
+		suite.Error(err)
+	})
+
+	suite.Run("output with unknown type failed", func() {
+		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
+		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
+		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
+		suite.NoError(err)
+
+		runner := function.NewMockFunctionRunner(suite.T())
+		runner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil)
+		runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
+		runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})
+
+		err = node.bm25Embedding(runner, suite.msgs[0], nil)
+		suite.Error(err)
+	})
+}
+
+func TestEmbeddingNode(t *testing.T) {
+	suite.Run(t, new(EmbeddingNodeSuite))
+}
diff --git a/internal/querynodev2/pipeline/insert_node.go b/internal/querynodev2/pipeline/insert_node.go
index 6ae9368501..f3fd1bb264 100644
--- a/internal/querynodev2/pipeline/insert_node.go
+++ b/internal/querynodev2/pipeline/insert_node.go
@@ -62,7 +62,7 @@ func (iNode *insertNode) addInsertData(insertDatas map[UniqueID]*delegator.Inser
 	} else {
 		err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
 		if err != nil {
-			log.Error("failed to merge field data", zap.Error(err))
+			log.Error("failed to merge field data", zap.String("channel", iNode.channel), zap.Error(err))
 			panic(err)
 		}
 		iData.InsertRecord.NumRows += insertRecord.NumRows
@@ -95,21 +95,23 @@ func (iNode *insertNode) Operate(in Msg) Msg {
 			return nodeMsg.insertMsgs[i].BeginTs() < nodeMsg.insertMsgs[j].BeginTs()
 		})
 
-		insertDatas := make(map[UniqueID]*delegator.InsertData)
-		collection := iNode.manager.Collection.Get(iNode.collectionID)
-		if collection == nil {
-			log.Error("insertNode with collection not exist", zap.Int64("collection",
iNode.collectionID))
-			panic("insertNode with collection not exist")
+		// build insert data only if no embedding node has prepared it
+		if nodeMsg.insertDatas == nil {
+			collection := iNode.manager.Collection.Get(iNode.collectionID)
+			if collection == nil {
+				log.Error("insertNode with collection not exist", zap.Int64("collection", iNode.collectionID))
+				panic("insertNode with collection not exist")
+			}
+
+			nodeMsg.insertDatas = make(map[UniqueID]*delegator.InsertData)
+			// get InsertData and merge data of the same segment
+			for _, msg := range nodeMsg.insertMsgs {
+				iNode.addInsertData(nodeMsg.insertDatas, msg, collection)
+			}
 		}
 
-		// get InsertData and merge datas of same segment
-		for _, msg := range nodeMsg.insertMsgs {
-			iNode.addInsertData(insertDatas, msg, collection)
-		}
-
-		iNode.delegator.ProcessInsert(insertDatas)
+		iNode.delegator.ProcessInsert(nodeMsg.insertDatas)
 	}
-
 	metrics.QueryNodeWaitProcessingMsgCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.DeleteLabel).Inc()
 
 	return &deleteNodeMsg{
diff --git a/internal/querynodev2/pipeline/message.go b/internal/querynodev2/pipeline/message.go
index fd5f3acda7..0afae129ea 100644
--- a/internal/querynodev2/pipeline/message.go
+++ b/internal/querynodev2/pipeline/message.go
@@ -19,15 +19,17 @@ package pipeline
 import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus/internal/querynodev2/collector"
+	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/metricsinfo"
 )
 
 type insertNodeMsg struct {
-	insertMsgs []*InsertMsg
-	deleteMsgs []*DeleteMsg
-	timeRange  TimeRange
+	insertMsgs  []*InsertMsg
+	deleteMsgs  []*DeleteMsg
+	insertDatas map[int64]*delegator.InsertData
+	timeRange   TimeRange
 }
 
 type deleteNodeMsg struct {
diff --git a/internal/querynodev2/pipeline/pipeline.go b/internal/querynodev2/pipeline/pipeline.go
index 16b4fb02c3..fb8d72f4d9 100644
--- a/internal/querynodev2/pipeline/pipeline.go
+++ b/internal/querynodev2/pipeline/pipeline.go
@@ -31,7 +31,8 @@ type Pipeline interface {
 type pipeline struct {
 	base.StreamPipeline
 
-	collectionID UniqueID
+	collectionID  UniqueID
+	embeddingNode embeddingNode
 }
 
 func (p *pipeline) Close() {
@@ -54,8 +55,21 @@ func NewPipeLine(
 	}
 
 	filterNode := newFilterNode(collectionID, channel, manager, delegator, pipelineQueueLength)
+
+	embeddingNode, err := newEmbeddingNode(collectionID, channel, manager, pipelineQueueLength)
+	if err != nil {
+		return nil, err
+	}
+
 	insertNode := newInsertNode(collectionID, channel, manager, delegator, pipelineQueueLength)
 	deleteNode := newDeleteNode(collectionID, channel, manager, tSafeManager, delegator, pipelineQueueLength)
-	p.Add(filterNode, insertNode, deleteNode)
+
+	// skip adding the embedding node when the collection has no functions.
+ if embeddingNode != nil { + p.Add(filterNode, embeddingNode, insertNode, deleteNode) + } else { + p.Add(filterNode, insertNode, deleteNode) + } + return p, nil } diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 5a5679c0c9..69f7aed604 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -299,6 +299,18 @@ func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) * } } +// new collection without segcore prepare +// ONLY FOR TEST +func NewCollectionWithoutSegcoreForTest(collectionID int64, schema *schemapb.CollectionSchema) *Collection { + coll := &Collection{ + id: collectionID, + partitions: typeutil.NewConcurrentSet[int64](), + refCount: atomic.NewUint32(0), + } + coll.schema.Store(schema) + return coll +} + // deleteCollection delete collection and free the collection memory func DeleteCollection(collection *Collection) { /* diff --git a/internal/querynodev2/segments/mock_data.go b/internal/querynodev2/segments/mock_data.go index a8ba0ac692..1f734c6fcc 100644 --- a/internal/querynodev2/segments/mock_data.go +++ b/internal/querynodev2/segments/mock_data.go @@ -47,6 +47,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/testutils" @@ -220,6 +221,11 @@ func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema { DataType: param.dataType, ElementType: schemapb.DataType_Int32, } + if param.dataType == schemapb.DataType_VarChar { + field.TypeParams = []*commonpb.KeyValuePair{ + {Key: common.MaxLengthKey, Value: "128"}, + } + } return field } @@ -263,6 +269,35 @@ func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema { return fieldVec } +func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSchema { + fieldRowID := genConstantFieldSchema(rowIDField) + fieldTimestamp := genConstantFieldSchema(timestampField) + pkFieldSchema := genPKFieldSchema(simpleInt64Field) + textFieldSchema := genConstantFieldSchema(simpleVarCharField) + sparseFieldSchema := genVectorFieldSchema(simpleSparseFloatVectorField) + sparseFieldSchema.IsFunctionOutput = true + + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Fields: []*schemapb.FieldSchema{ + fieldRowID, + fieldTimestamp, + pkFieldSchema, + textFieldSchema, + sparseFieldSchema, + }, + Functions: []*schemapb.FunctionSchema{{ + Name: "BM25", + Type: schemapb.FunctionType_BM25, + InputFieldNames: []string{textFieldSchema.GetName()}, + InputFieldIds: []int64{textFieldSchema.GetFieldID()}, + OutputFieldNames: []string{sparseFieldSchema.GetName()}, + OutputFieldIds: []int64{sparseFieldSchema.GetFieldID()}, + }}, + } + return schema +} + // some tests do not yet support sparse float vector, see comments of // GenSparseFloatVecDataset in indexcgowrapper/dataset.go func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema { @@ -671,6 +706,32 @@ func SaveDeltaLog(collectionID int64, return fieldBinlog, cm.MultiWrite(context.Background(), kvs) } +func SaveBM25Log(collectionID int64, partitionID int64, segmentID int64, fieldID int64, msgLength int, cm storage.ChunkManager) (*datapb.FieldBinlog, error) { + stats := storage.NewBM25Stats() 
+ + for i := 0; i < msgLength; i++ { + stats.Append(map[uint32]float32{1: 1}) + } + + bytes, err := stats.Serialize() + if err != nil { + return nil, err + } + + kvs := make(map[string][]byte, 1) + key := path.Join(cm.RootPath(), common.SegmentBm25LogPath, metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID, 1001)) + kvs[key] = bytes + fieldBinlog := &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []*datapb.Binlog{{ + LogPath: key, + TimestampFrom: 100, + TimestampTo: 200, + }}, + } + return fieldBinlog, cm.MultiWrite(context.Background(), kvs) +} + func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64, fieldSchema *schemapb.FieldSchema, indexInfo *indexpb.IndexInfo, diff --git a/internal/querynodev2/segments/mock_loader.go b/internal/querynodev2/segments/mock_loader.go index 6d906d74bb..7eb0a369db 100644 --- a/internal/querynodev2/segments/mock_loader.go +++ b/internal/querynodev2/segments/mock_loader.go @@ -14,6 +14,10 @@ import ( pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" querypb "github.com/milvus-io/milvus/internal/proto/querypb" + + storage "github.com/milvus-io/milvus/internal/storage" + + typeutil "github.com/milvus-io/milvus/pkg/util/typeutil" ) // MockLoader is an autogenerated mock type for the Loader type @@ -101,6 +105,76 @@ func (_c *MockLoader_Load_Call) RunAndReturn(run func(context.Context, int64, co return _c } +// LoadBM25Stats provides a mock function with given fields: ctx, collectionID, infos +func (_m *MockLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) { + _va := make([]interface{}, len(infos)) + for _i := range infos { + _va[_i] = infos[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, collectionID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats] + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)); ok { + return rf(ctx, collectionID, infos...) + } + if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]); ok { + r0 = rf(ctx, collectionID, infos...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) error); ok { + r1 = rf(ctx, collectionID, infos...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockLoader_LoadBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadBM25Stats' +type MockLoader_LoadBM25Stats_Call struct { + *mock.Call +} + +// LoadBM25Stats is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +// - infos ...*querypb.SegmentLoadInfo +func (_e *MockLoader_Expecter) LoadBM25Stats(ctx interface{}, collectionID interface{}, infos ...interface{}) *MockLoader_LoadBM25Stats_Call { + return &MockLoader_LoadBM25Stats_Call{Call: _e.mock.On("LoadBM25Stats", + append([]interface{}{ctx, collectionID}, infos...)...)} +} + +func (_c *MockLoader_LoadBM25Stats_Call) Run(run func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo)) *MockLoader_LoadBM25Stats_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(*querypb.SegmentLoadInfo) + } + } + run(args[0].(context.Context), args[1].(int64), variadicArgs...) + }) + return _c +} + +func (_c *MockLoader_LoadBM25Stats_Call) Return(_a0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], _a1 error) *MockLoader_LoadBM25Stats_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockLoader_LoadBM25Stats_Call) RunAndReturn(run func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)) *MockLoader_LoadBM25Stats_Call { + _c.Call.Return(run) + return _c +} + // LoadBloomFilterSet provides a mock function with given fields: ctx, collectionID, version, infos func (_m *MockLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) { _va := make([]interface{}, len(infos)) diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index a0d3df97ba..6253da14e7 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -290,6 +290,49 @@ func (_c *MockSegment_ExistIndex_Call) RunAndReturn(run func(int64) bool) *MockS return _c } +// GetBM25Stats provides a mock function with given fields: +func (_m *MockSegment) GetBM25Stats() map[int64]*storage.BM25Stats { + ret := _m.Called() + + var r0 map[int64]*storage.BM25Stats + if rf, ok := ret.Get(0).(func() map[int64]*storage.BM25Stats); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*storage.BM25Stats) + } + } + + return r0 +} + +// MockSegment_GetBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBM25Stats' +type MockSegment_GetBM25Stats_Call struct { + *mock.Call +} + +// GetBM25Stats is a helper method to define mock.On call +func (_e *MockSegment_Expecter) GetBM25Stats() *MockSegment_GetBM25Stats_Call { + return &MockSegment_GetBM25Stats_Call{Call: _e.mock.On("GetBM25Stats")} +} + +func (_c *MockSegment_GetBM25Stats_Call) Run(run func()) *MockSegment_GetBM25Stats_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSegment_GetBM25Stats_Call) Return(_a0 map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSegment_GetBM25Stats_Call) RunAndReturn(run func() map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call { + 
_c.Call.Return(run) + return _c +} + // GetIndex provides a mock function with given fields: fieldID func (_m *MockSegment) GetIndex(fieldID int64) *IndexedFieldInfo { ret := _m.Called(fieldID) @@ -1570,6 +1613,39 @@ func (_c *MockSegment_Unpin_Call) RunAndReturn(run func()) *MockSegment_Unpin_Ca return _c } +// UpdateBM25Stats provides a mock function with given fields: stats +func (_m *MockSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) { + _m.Called(stats) +} + +// MockSegment_UpdateBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBM25Stats' +type MockSegment_UpdateBM25Stats_Call struct { + *mock.Call +} + +// UpdateBM25Stats is a helper method to define mock.On call +// - stats map[int64]*storage.BM25Stats +func (_e *MockSegment_Expecter) UpdateBM25Stats(stats interface{}) *MockSegment_UpdateBM25Stats_Call { + return &MockSegment_UpdateBM25Stats_Call{Call: _e.mock.On("UpdateBM25Stats", stats)} +} + +func (_c *MockSegment_UpdateBM25Stats_Call) Run(run func(stats map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(map[int64]*storage.BM25Stats)) + }) + return _c +} + +func (_c *MockSegment_UpdateBM25Stats_Call) Return() *MockSegment_UpdateBM25Stats_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSegment_UpdateBM25Stats_Call) RunAndReturn(run func(map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call { + _c.Call.Return(run) + return _c +} + // UpdateBloomFilter provides a mock function with given fields: pks func (_m *MockSegment) UpdateBloomFilter(pks []storage.PrimaryKey) { _m.Called(pks) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index cce217f644..50a9a51ddc 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -91,6 +91,8 @@ type baseSegment struct { isLazyLoad bool channel metautil.Channel + bm25Stats map[int64]*storage.BM25Stats + resourceUsageCache *atomic.Pointer[ResourceUsage] needUpdatedVersion *atomic.Int64 // only for lazy load mode update index @@ -107,6 +109,7 @@ func newBaseSegment(collection *Collection, segmentType SegmentType, version int version: atomic.NewInt64(version), segmentType: segmentType, bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType), + bm25Stats: make(map[int64]*storage.BM25Stats), channel: channel, isLazyLoad: isLazyLoad(collection, segmentType), @@ -185,6 +188,20 @@ func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) { s.bloomFilterSet.UpdateBloomFilter(pks) } +func (s *baseSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) { + for fieldID, new := range stats { + if current, ok := s.bm25Stats[fieldID]; ok { + current.Merge(new) + } else { + s.bm25Stats[fieldID] = new + } + } +} + +func (s *baseSegment) GetBM25Stats() map[int64]*storage.BM25Stats { + return s.bm25Stats +} + // MayPkExist returns true if the given PK exists in the PK range and being positive through the bloom filter, // false otherwise, // may returns true even the PK doesn't exist actually diff --git a/internal/querynodev2/segments/segment_interface.go b/internal/querynodev2/segments/segment_interface.go index 9ad7ef219a..2bab889348 100644 --- a/internal/querynodev2/segments/segment_interface.go +++ b/internal/querynodev2/segments/segment_interface.go @@ -87,6 +87,10 @@ type Segment interface { MayPkExist(lc *storage.LocationsCache) bool 
BatchPkExist(lc *storage.BatchLocationsCache) []bool + // BM25 stats + UpdateBM25Stats(stats map[int64]*storage.BM25Stats) + GetBM25Stats() map[int64]*storage.BM25Stats + // Read operations Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 16b7f7548a..f498df722a 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -77,6 +77,9 @@ type Loader interface { // LoadBloomFilterSet loads needed statslog for RemoteSegment. LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) + // LoadBM25Stats loads BM25 statslog for RemoteSegment + LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) + // LoadIndex append index for segment and remove vector binlogs. LoadIndex(ctx context.Context, segment Segment, @@ -543,6 +546,47 @@ func (loader *segmentLoader) waitSegmentLoadDone(ctx context.Context, segmentTyp return nil } +func (loader *segmentLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) { + segmentNum := len(infos) + if segmentNum == 0 { + return nil, nil + } + + segments := lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) int64 { + return info.GetSegmentID() + }) + log.Info("start loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Int("segmentNum", segmentNum)) + + loadedStats := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]() + loadRemoteBM25Func := func(idx int) error { + loadInfo := infos[idx] + segmentID := loadInfo.SegmentID + stats := make(map[int64]*storage.BM25Stats) + + log.Info("loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64("segment", segmentID)) + logpaths := loader.filterBM25Stats(loadInfo.Bm25Logs) + err := loader.loadBm25Stats(ctx, segmentID, stats, logpaths) + if err != nil { + log.Warn("load remote segment bm25 stats failed", + zap.Int64("segmentID", segmentID), + zap.Error(err), + ) + return err + } + loadedStats.Insert(segmentID, stats) + return nil + } + + err := funcutil.ProcessFuncParallel(segmentNum, segmentNum, loadRemoteBM25Func, "loadRemoteBM25Func") + if err != nil { + // no partial success here + log.Warn("failed to load bm25 stats for remote segment", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Error(err)) + return nil, err + } + + return loadedStats, nil +} + func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) { log := log.Ctx(ctx).With( zap.Int64("collectionID", collectionID), @@ -826,6 +870,16 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, if err != nil { return err } + + if len(loadInfo.Bm25Logs) > 0 { + log.Info("loading bm25 stats...") + bm25StatsLogs := loader.filterBM25Stats(loadInfo.Bm25Logs) + + err = loader.loadBm25Stats(ctx, segment.ID(), segment.bm25Stats, bm25StatsLogs) + if err != nil { + return err + } + } } return nil } @@ -898,6 +952,26 @@ func (loader 
*segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi
 	return result, storage.DefaultStatsType
 }
 
+func (loader *segmentLoader) filterBM25Stats(fieldBinlogs []*datapb.FieldBinlog) map[int64][]string {
+	result := make(map[int64][]string, 0)
+	for _, fieldBinlog := range fieldBinlogs {
+		logpaths := []string{}
+		for _, binlog := range fieldBinlog.GetBinlogs() {
+			_, logidx := path.Split(binlog.GetLogPath())
+			// if a compound stats log exists,
+			// load only that one file
+			if logidx == storage.CompoundStatsType.LogIdx() {
+				logpaths = []string{binlog.GetLogPath()}
+				break
+			} else {
+				logpaths = append(logpaths, binlog.GetLogPath())
+			}
+		}
+		result[fieldBinlog.FieldID] = logpaths
+	}
+	return result
+}
+
 func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64) error {
 	runningGroup, _ := errgroup.WithContext(ctx)
 	for _, field := range fields {
@@ -989,6 +1063,51 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
 	return segment.LoadIndex(ctx, indexInfo, fieldType)
 }
 
+func (loader *segmentLoader) loadBm25Stats(ctx context.Context, segmentID int64, stats map[int64]*storage.BM25Stats, binlogPaths map[int64][]string) error {
+	log := log.Ctx(ctx).With(
+		zap.Int64("segmentID", segmentID),
+	)
+	if len(binlogPaths) == 0 {
+		log.Info("there are no bm25 stats logs saved with segment")
+		return nil
+	}
+
+	pathList := []string{}
+	fieldList := []int64{}
+	fieldOffset := []int{}
+	for fieldId, logpaths := range binlogPaths {
+		pathList = append(pathList, logpaths...)
+		fieldList = append(fieldList, fieldId)
+		fieldOffset = append(fieldOffset, len(logpaths))
+	}
+
+	startTs := time.Now()
+	values, err := loader.cm.MultiRead(ctx, pathList)
+	if err != nil {
+		return err
+	}
+
+	cnt := 0
+	for i, fieldID := range fieldList {
+		newStats, ok := stats[fieldID]
+		if !ok {
+			newStats = storage.NewBM25Stats()
+			stats[fieldID] = newStats
+		}
+
+		for j := 0; j < fieldOffset[i]; j++ {
+			err := newStats.Deserialize(values[cnt+j])
+			if err != nil {
+				return err
+			}
+		}
+		cnt += fieldOffset[i]
+		log.Info("successfully loaded bm25 stats", zap.Duration("time", time.Since(startTs)), zap.Int64("numRow", newStats.NumRow()), zap.Int64("fieldID", fieldID))
+	}
+
+	return nil
+}
+
 func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet,
 	binlogPaths []string, logType storage.StatsLogType,
 ) error {
diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go
index ebb282d50e..bf36e0d012 100644
--- a/internal/querynodev2/segments/segment_loader_test.go
+++ b/internal/querynodev2/segments/segment_loader_test.go
@@ -70,14 +70,16 @@ func (suite *SegmentLoaderSuite) SetupSuite() {
 }
 
 func (suite *SegmentLoaderSuite) SetupTest() {
-	// Dependencies
-	suite.manager = NewManager()
 	ctx := context.Background()
+
+	// TODO: cpp chunk manager does not support local chunk manager
 	// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
 	// 	fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
 	chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
 	suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
+
+	// Dependencies
+	suite.manager = NewManager()
 	suite.loader = NewLoader(suite.manager, suite.chunkManager)
 	initcore.InitRemoteChunkManager(paramtable.Get())
 
@@ -92,6 +94,22 @@
suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta) } +func (suite *SegmentLoaderSuite) SetupBM25() { + // Dependencies + suite.manager = NewManager() + suite.loader = NewLoader(suite.manager, suite.chunkManager) + initcore.InitRemoteChunkManager(paramtable.Get()) + + suite.schema = GenTestBM25CollectionSchema("test") + indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema) + loadMeta := &querypb.LoadMetaInfo{ + LoadType: querypb.LoadType_LoadCollection, + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + } + suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta) +} + func (suite *SegmentLoaderSuite) TearDownTest() { ctx := context.Background() for i := 0; i < suite.segmentNum; i++ { @@ -407,6 +425,41 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() { } } +func (suite *SegmentLoaderSuite) TestLoadBm25Stats() { + suite.SetupBM25() + msgLength := 1 + sparseFieldID := simpleSparseFloatVectorField.id + loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum) + + for i := 0; i < suite.segmentNum; i++ { + segmentID := suite.segmentID + int64(i) + + bm25logs, err := SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager) + suite.NoError(err) + + loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{ + SegmentID: segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + Bm25Logs: []*datapb.FieldBinlog{bm25logs}, + NumOfRows: int64(msgLength), + InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), + }) + } + + statsMap, err := suite.loader.LoadBM25Stats(context.Background(), suite.collectionID, loadInfos...) + suite.NoError(err) + + for i := 0; i < suite.segmentNum; i++ { + segmentID := suite.segmentID + int64(i) + stats, ok := statsMap.Get(segmentID) + suite.True(ok) + fieldStats, ok := stats[sparseFieldID] + suite.True(ok) + suite.Equal(int64(msgLength), fieldStats.NumRow()) + } +} + func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() { ctx := context.Background() loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum) diff --git a/internal/storage/stats.go b/internal/storage/stats.go index f71930b935..9f4a2ea4db 100644 --- a/internal/storage/stats.go +++ b/internal/storage/stats.go @@ -21,10 +21,10 @@ import ( "encoding/binary" "encoding/json" "fmt" + "maps" "math" "go.uber.org/zap" - "golang.org/x/exp/maps" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/util/bloomfilter" @@ -347,7 +347,7 @@ func (m *BM25Stats) AppendFieldData(datas ...*SparseFloatVectorFieldData) { // Update BM25Stats by sparse vector bytes func (m *BM25Stats) AppendBytes(datas ...[]byte) { for _, data := range datas { - dim := len(data) / 8 + dim := typeutil.SparseFloatRowElementCount(data) for i := 0; i < dim; i++ { index := typeutil.SparseFloatRowIndexAt(data, i) value := typeutil.SparseFloatRowValueAt(data, i) @@ -454,17 +454,19 @@ func (m *BM25Stats) Deserialize(bs []byte) error { m.rowsWithToken[keys[i]] += values[i] } - log.Info("test-- deserialize", zap.Int64("numrow", m.numRow), zap.Int64("tokenNum", m.numToken)) return nil } -func (m *BM25Stats) BuildIDF(tf map[uint32]float32) map[uint32]float32 { - vector := make(map[uint32]float32) - for key, value := range tf { +func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) { + dim := typeutil.SparseFloatRowElementCount(tf) + idf = make([]byte, len(tf)) + for idx := 0; idx < dim; idx++ { 
+ key := typeutil.SparseFloatRowIndexAt(tf, idx) + value := typeutil.SparseFloatRowValueAt(tf, idx) nq := m.rowsWithToken[key] - vector[key] = value * float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5))) + typeutil.SparseFloatRowSetAt(idf, idx, key, value*float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5)))) } - return vector + return } func (m *BM25Stats) GetAvgdl() float64 { diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 5021043182..5e444027dc 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -358,7 +358,7 @@ func readDoubleArray(blobReaders []io.Reader) []float64 { return ret } -func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema) (idata *InsertData, err error) { +func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema, skipFunction bool) (idata *InsertData, err error) { blobReaders := make([]io.Reader, 0) for _, blob := range msg.RowData { blobReaders = append(blobReaders, bytes.NewReader(blob.GetValue())) @@ -371,7 +371,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap } for _, field := range collSchema.Fields { - if field.GetIsFunctionOutput() { + if skipFunction && field.GetIsFunctionOutput() { continue } @@ -696,7 +696,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche func InsertMsgToInsertData(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) (idata *InsertData, err error) { if msg.IsRowBased() { - return RowBasedInsertMsgToInsertData(msg, schema) + return RowBasedInsertMsgToInsertData(msg, schema, true) } return ColumnBasedInsertMsgToInsertData(msg, schema) } @@ -1272,7 +1272,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msgstream.InsertMsg) (*segcorepb.InsertRecord, error) { if msg.IsRowBased() { - insertData, err := RowBasedInsertMsgToInsertData(msg, schema) + insertData, err := RowBasedInsertMsgToInsertData(msg, schema, false) if err != nil { return nil, err } @@ -1281,7 +1281,8 @@ func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msg // column base insert msg insertRecord := &segcorepb.InsertRecord{ - NumRows: int64(msg.NumRows), + NumRows: int64(msg.NumRows), + FieldsData: make([]*schemapb.FieldData, 0), } insertRecord.FieldsData = append(insertRecord.FieldsData, msg.FieldsData...) 
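The weighting applied by BM25Stats.BuildIDF above is the standard BM25 IDF: a term found in nq of numRow documents gets idf = ln(1 + (numRow - nq + 0.5) / (nq + 0.5)), and each query term frequency is scaled by that weight. Sparse rows are packed as 8-byte index/value pairs (uint32 index plus float32 value), which is why the old len(data) / 8 element count worked and SparseFloatRowElementCount is simply the clearer spelling. A self-contained sketch that reproduces the value asserted in the idf oracle test above:

    package main

    import (
        "fmt"
        "math"
    )

    // idf mirrors BM25Stats.BuildIDF: numRow is the corpus size, nq the
    // number of rows containing the term.
    func idf(numRow, nq int64) float64 {
        return math.Log(1 + (float64(numRow)-float64(nq)+0.5)/(float64(nq)+0.5))
    }

    func main() {
        // After three of the four sealed segments are released in TestSealed,
        // one row remains and token 4 occurs in it, so a query tf of 1 gets:
        fmt.Printf("%.7f\n", idf(1, 1)) // 0.2876821
    }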
diff --git a/internal/storage/utils_test.go b/internal/storage/utils_test.go index 57c7cebd6c..03d6847754 100644 --- a/internal/storage/utils_test.go +++ b/internal/storage/utils_test.go @@ -1035,7 +1035,7 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) { fieldIDs = fieldIDs[:len(fieldIDs)-2] msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim) - idata, err := RowBasedInsertMsgToInsertData(msg, schema) + idata, err := RowBasedInsertMsgToInsertData(msg, schema, false) assert.NoError(t, err) for idx, fID := range fieldIDs { column := columns[idx] @@ -1096,7 +1096,7 @@ func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) { }, }, } - _, err := RowBasedInsertMsgToInsertData(msg, schema) + _, err := RowBasedInsertMsgToInsertData(msg, schema, false) assert.Error(t, err) } @@ -1139,7 +1139,7 @@ func TestRowBasedInsertMsgToInsertBFloat16VectorDataError(t *testing.T) { }, }, } - _, err := RowBasedInsertMsgToInsertData(msg, schema) + _, err := RowBasedInsertMsgToInsertData(msg, schema, false) assert.Error(t, err) } diff --git a/internal/util/function/mock_function.go b/internal/util/function/mock_function.go new file mode 100644 index 0000000000..0ccf8848d2 --- /dev/null +++ b/internal/util/function/mock_function.go @@ -0,0 +1,184 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package function + +import ( + schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + mock "github.com/stretchr/testify/mock" +) + +// MockFunctionRunner is an autogenerated mock type for the FunctionRunner type +type MockFunctionRunner struct { + mock.Mock +} + +type MockFunctionRunner_Expecter struct { + mock *mock.Mock +} + +func (_m *MockFunctionRunner) EXPECT() *MockFunctionRunner_Expecter { + return &MockFunctionRunner_Expecter{mock: &_m.Mock} +} + +// BatchRun provides a mock function with given fields: inputs +func (_m *MockFunctionRunner) BatchRun(inputs ...interface{}) ([]interface{}, error) { + var _ca []interface{} + _ca = append(_ca, inputs...) + ret := _m.Called(_ca...) + + var r0 []interface{} + var r1 error + if rf, ok := ret.Get(0).(func(...interface{}) ([]interface{}, error)); ok { + return rf(inputs...) + } + if rf, ok := ret.Get(0).(func(...interface{}) []interface{}); ok { + r0 = rf(inputs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]interface{}) + } + } + + if rf, ok := ret.Get(1).(func(...interface{}) error); ok { + r1 = rf(inputs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockFunctionRunner_BatchRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BatchRun' +type MockFunctionRunner_BatchRun_Call struct { + *mock.Call +} + +// BatchRun is a helper method to define mock.On call +// - inputs ...interface{} +func (_e *MockFunctionRunner_Expecter) BatchRun(inputs ...interface{}) *MockFunctionRunner_BatchRun_Call { + return &MockFunctionRunner_BatchRun_Call{Call: _e.mock.On("BatchRun", + append([]interface{}{}, inputs...)...)} +} + +func (_c *MockFunctionRunner_BatchRun_Call) Run(run func(inputs ...interface{})) *MockFunctionRunner_BatchRun_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-0) + for i, a := range args[0:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(variadicArgs...) 
+ }) + return _c +} + +func (_c *MockFunctionRunner_BatchRun_Call) Return(_a0 []interface{}, _a1 error) *MockFunctionRunner_BatchRun_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockFunctionRunner_BatchRun_Call) RunAndReturn(run func(...interface{}) ([]interface{}, error)) *MockFunctionRunner_BatchRun_Call { + _c.Call.Return(run) + return _c +} + +// GetOutputFields provides a mock function with given fields: +func (_m *MockFunctionRunner) GetOutputFields() []*schemapb.FieldSchema { + ret := _m.Called() + + var r0 []*schemapb.FieldSchema + if rf, ok := ret.Get(0).(func() []*schemapb.FieldSchema); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*schemapb.FieldSchema) + } + } + + return r0 +} + +// MockFunctionRunner_GetOutputFields_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOutputFields' +type MockFunctionRunner_GetOutputFields_Call struct { + *mock.Call +} + +// GetOutputFields is a helper method to define mock.On call +func (_e *MockFunctionRunner_Expecter) GetOutputFields() *MockFunctionRunner_GetOutputFields_Call { + return &MockFunctionRunner_GetOutputFields_Call{Call: _e.mock.On("GetOutputFields")} +} + +func (_c *MockFunctionRunner_GetOutputFields_Call) Run(run func()) *MockFunctionRunner_GetOutputFields_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockFunctionRunner_GetOutputFields_Call) Return(_a0 []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockFunctionRunner_GetOutputFields_Call) RunAndReturn(run func() []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call { + _c.Call.Return(run) + return _c +} + +// GetSchema provides a mock function with given fields: +func (_m *MockFunctionRunner) GetSchema() *schemapb.FunctionSchema { + ret := _m.Called() + + var r0 *schemapb.FunctionSchema + if rf, ok := ret.Get(0).(func() *schemapb.FunctionSchema); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*schemapb.FunctionSchema) + } + } + + return r0 +} + +// MockFunctionRunner_GetSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSchema' +type MockFunctionRunner_GetSchema_Call struct { + *mock.Call +} + +// GetSchema is a helper method to define mock.On call +func (_e *MockFunctionRunner_Expecter) GetSchema() *MockFunctionRunner_GetSchema_Call { + return &MockFunctionRunner_GetSchema_Call{Call: _e.mock.On("GetSchema")} +} + +func (_c *MockFunctionRunner_GetSchema_Call) Run(run func()) *MockFunctionRunner_GetSchema_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockFunctionRunner_GetSchema_Call) Return(_a0 *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockFunctionRunner_GetSchema_Call) RunAndReturn(run func() *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call { + _c.Call.Return(run) + return _c +} + +// NewMockFunctionRunner creates a new instance of MockFunctionRunner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockFunctionRunner(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *MockFunctionRunner {
+	mock := &MockFunctionRunner{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/pkg/util/funcutil/placeholdergroup.go b/pkg/util/funcutil/placeholdergroup.go
index 2fa66bdaea..910ee84909 100644
--- a/pkg/util/funcutil/placeholdergroup.go
+++ b/pkg/util/funcutil/placeholdergroup.go
@@ -6,12 +6,26 @@ import (
 	"math"
 
 	"github.com/cockroachdb/errors"
+	"github.com/samber/lo"
 	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 )
 
+func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte {
+	placeholderGroup := &commonpb.PlaceholderGroup{
+		Placeholders: []*commonpb.PlaceholderValue{{
+			Tag:    "$0",
+			Type:   commonpb.PlaceholderType_SparseFloatVector,
+			Values: contents,
+		}},
+	}
+
+	bytes, _ := proto.Marshal(placeholderGroup)
+	return bytes
+}
+
 func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
 	placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
 	if err != nil {
@@ -93,6 +107,14 @@ func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.Place
 		Values: [][]byte{bytes},
 	}
 	return placeholderValue, nil
+	case schemapb.DataType_VarChar:
+		strs := fieldData.GetScalars().GetStringData().GetData()
+		placeholderValue := &commonpb.PlaceholderValue{
+			Tag:    "$0",
+			Type:   commonpb.PlaceholderType_VarChar,
+			Values: lo.Map(strs, func(str string, _ int) []byte { return []byte(str) }),
+		}
+		return placeholderValue, nil
 	default:
 		return nil, errors.New("field is not a vector field")
 	}
@@ -157,3 +179,7 @@ func flattenedBFloat16VectorsToByteVectors(flattenedVectors []byte, dimension in
 
 	return result
 }
+
+func GetVarCharFromPlaceholder(holder *commonpb.PlaceholderValue) []string {
+	return lo.Map(holder.Values, func(bytes []byte, _ int) string { return string(bytes) })
+}
diff --git a/pkg/util/metric/metric_type.go b/pkg/util/metric/metric_type.go
index 107e42cf09..37c885af68 100644
--- a/pkg/util/metric/metric_type.go
+++ b/pkg/util/metric/metric_type.go
@@ -36,4 +36,6 @@ const (
 
 	// SUPERSTRUCTURE represents superstructure distance
 	SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
+
+	BM25 MetricType = "BM25"
 )
diff --git a/pkg/util/metric/similarity_corelation.go b/pkg/util/metric/similarity_corelation.go
index 3506a8fcb1..f34a5e3141 100644
--- a/pkg/util/metric/similarity_corelation.go
+++ b/pkg/util/metric/similarity_corelation.go
@@ -21,5 +21,5 @@ import "strings"
 // PositivelyRelated return if metricType are "ip" or "IP"
 func PositivelyRelated(metricType string) bool {
 	mUpper := strings.ToUpper(metricType)
-	return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE)
+	return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE) || mUpper == strings.ToUpper(BM25)
 }
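
To see how the pieces above compose, here is a minimal, self-contained Go sketch. It is not part of the patch; the import paths follow the repository layout, and the query text and sparse payload are placeholders. It shows that BM25 now counts as a positively related metric, that a raw-text query rides in a VarChar placeholder and is recovered with GetVarCharFromPlaceholder, and that embedded sparse rows are re-wrapped with the new SparseVectorDataToPlaceholderGroupBytes helper before search.

```go
package main

import (
	"fmt"

	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/pkg/util/funcutil"
	"github.com/milvus-io/milvus/pkg/util/metric"
)

func main() {
	// BM25 joins IP and COSINE as a "larger score is better" metric.
	fmt.Println(metric.PositivelyRelated(metric.BM25)) // true

	// A raw-text query travels as a VarChar placeholder; the querynode side
	// recovers the strings before running the BM25 embedding function.
	holder := &commonpb.PlaceholderValue{
		Tag:    "$0",
		Type:   commonpb.PlaceholderType_VarChar,
		Values: [][]byte{[]byte("hello bm25")},
	}
	fmt.Println(funcutil.GetVarCharFromPlaceholder(holder)) // [hello bm25]

	// After embedding, the serialized sparse rows (placeholder bytes here)
	// are re-wrapped into a placeholder group for the actual vector search.
	group := funcutil.SparseVectorDataToPlaceholderGroupBytes([][]byte{[]byte("<sparse row bytes>")})
	var pg commonpb.PlaceholderGroup
	if err := proto.Unmarshal(group, &pg); err == nil {
		fmt.Println(pg.GetPlaceholders()[0].GetType()) // SparseFloatVector
	}
}
```

Dropping the proto.Marshal error inside SparseVectorDataToPlaceholderGroupBytes is a deliberate simplification: the group is built entirely from in-memory values, so serialization cannot fail in practice.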
diff --git a/tests/go_client/testcases/index_test.go b/tests/go_client/testcases/index_test.go
index 309bc22578..891af44b54 100644
--- a/tests/go_client/testcases/index_test.go
+++ b/tests/go_client/testcases/index_test.go
@@ -212,7 +212,7 @@ func TestIndexAutoSparseVector(t *testing.T) {
 	for _, unsupportedMt := range hp.UnsupportedSparseVecMetricsType {
 		idx := index.NewAutoIndex(unsupportedMt)
 		_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idx))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
 	}
 
 	// auto index with different metric type on sparse vec
@@ -829,11 +829,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
 	for _, mt := range hp.UnsupportedSparseVecMetricsType {
 		idxInverted := index.NewSparseInvertedIndex(mt, 0.2)
 		_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
 
 		idxWand := index.NewSparseWANDIndex(mt, 0.2)
 		_, err = mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
 	}
 
 	// create index with invalid drop_ratio_build
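
A positive-path companion to the negative cases above could look like the sketch below. It is not part of the patch: the setup helpers (hp.CreateContext, createDefaultMilvusClient, hp.CollPrepare) are assumed to behave as in the surrounding suite, entity.MetricType("BM25") avoids guessing whether the client already exports a BM25 constant, and the sketch only exercises the relaxed metric-type check; a usable BM25 field may additionally require the collection-level BM25 function wiring added elsewhere in this change.

```go
// Sketch only, assumed to live in the same suite file (same imports and
// helpers as the tests above); all helper names mirror the existing suite.
func TestCreateSparseIndexBM25(t *testing.T) {
	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
	mc := createDefaultMilvusClient(ctx, t)
	_, schema := hp.CollPrepare.CreateCollection(ctx, t, mc,
		hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), hp.TNewFieldsOption(), hp.TNewSchemaOption())

	bm25 := entity.MetricType("BM25") // raw string; does not assume an exported constant

	// With this patch, BM25 passes the sparse-index metric validation that
	// previously admitted only IP.
	idxInverted := index.NewSparseInvertedIndex(bm25, 0.2)
	_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
	common.CheckErr(t, err, true)

	idxWand := index.NewSparseWANDIndex(bm25, 0.2)
	_, err = mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
	common.CheckErr(t, err, true)
}
```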