From 389245188079f5ccac67861c64e7e041c4c26aa9 Mon Sep 17 00:00:00 2001 From: aoiasd <45024769+aoiasd@users.noreply.github.com> Date: Sun, 27 Apr 2025 17:34:38 +0800 Subject: [PATCH] fix: bm25 search failed when avgdl == nan (#41502) relate: https://github.com/milvus-io/milvus/issues/41490 --------- Signed-off-by: aoiasd --- .../querynodev2/delegator/delegator_data.go | 11 +++-- internal/querynodev2/delegator/idf_oracle.go | 41 ++++++++----------- .../querynodev2/delegator/idf_oracle_test.go | 8 ---- internal/storage/stats.go | 3 ++ 4 files changed, 26 insertions(+), 37 deletions(-) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index dd4a856038..b0e75284af 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -954,11 +954,6 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele pkoracle.WithSegmentType(commonpb.SegmentState_Sealed), pkoracle.WithWorkerID(targetNodeID), ) - if sd.idfOracle != nil { - for _, segment := range sealed { - sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Sealed) - } - } } if len(growing) > 0 { sd.pkOracle.Remove( @@ -967,7 +962,7 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele ) if sd.idfOracle != nil { for _, segment := range growing { - sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Growing) + sd.idfOracle.RemoveGrowing(segment.SegmentID) } } } @@ -1126,6 +1121,10 @@ func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64, return 0, err } + if avgdl <= 0 { + return 0, nil + } + for _, idf := range idfSparseVector { metrics.QueryNodeSearchFTSNumTokens.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(sd.collectionID), fmt.Sprint(req.GetFieldId())).Observe(float64(typeutil.SparseFloatRowElementCount(idf))) } diff --git a/internal/querynodev2/delegator/idf_oracle.go b/internal/querynodev2/delegator/idf_oracle.go index f4a6008c0f..7a92c5b555 100644 --- a/internal/querynodev2/delegator/idf_oracle.go +++ b/internal/querynodev2/delegator/idf_oracle.go @@ -40,7 +40,7 @@ type IDFOracle interface { UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) - Remove(segmentID int64, state commonpb.SegmentState) + RemoveGrowing(segmentID int64) BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) } @@ -49,6 +49,7 @@ type bm25Stats struct { stats map[int64]*storage.BM25Stats activate bool targetVersion int64 + snapVersion int64 } func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) { @@ -133,7 +134,7 @@ func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats o.sealed[segmentID] = &bm25Stats{ stats: stats, activate: false, - targetVersion: initialTargetVersion, + targetVersion: unreadableTargetVersion, } default: log.Warn("register segment with unknown state", zap.String("stats", state.String())) @@ -160,27 +161,15 @@ func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25 } } -func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) { +func (o *idfOracle) RemoveGrowing(segmentID int64) { o.Lock() defer o.Unlock() - switch state { - case segments.SegmentTypeGrowing: - if stats, ok := o.growing[segmentID]; ok { - if stats.activate { - o.current.Minus(stats.stats) - } - delete(o.growing, segmentID) + if stats, ok := o.growing[segmentID]; ok { + if stats.activate { + o.current.Minus(stats.stats) } - case segments.SegmentTypeSealed: - if stats, ok := o.sealed[segmentID]; ok { - if stats.activate { - o.current.Minus(stats.stats) - } - delete(o.sealed, segmentID) - } - default: - return + delete(o.growing, segmentID) } } @@ -207,7 +196,10 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) { } if stats, ok := o.sealed[segment.SegmentID]; ok { - stats.targetVersion = segment.TargetVersion + if stats.targetVersion < segment.TargetVersion { + stats.targetVersion = segment.TargetVersion + } + stats.snapVersion = snapshot.version } else { log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID)) } @@ -224,11 +216,14 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) { o.targetVersion = snapshot.targetVersion - for _, stats := range o.sealed { + for segmentID, stats := range o.sealed { if !stats.activate && stats.targetVersion == o.targetVersion { o.activate(stats) - } else if stats.activate && stats.targetVersion != o.targetVersion { - o.deactivate(stats) + } else if (stats.targetVersion < o.targetVersion && stats.targetVersion != unreadableTargetVersion) || stats.snapVersion != snapshot.version { + if stats.activate { + o.current.Minus(stats.stats) + } + delete(o.sealed, segmentID) } } diff --git a/internal/querynodev2/delegator/idf_oracle_test.go b/internal/querynodev2/delegator/idf_oracle_test.go index 52881147f2..390d3af9a9 100644 --- a/internal/querynodev2/delegator/idf_oracle_test.go +++ b/internal/querynodev2/delegator/idf_oracle_test.go @@ -132,10 +132,6 @@ func (suite *IDFOracleSuite) TestSealed() { suite.idfOracle.SyncDistribution(suite.snapshot) suite.Equal(int64(1), suite.idfOracle.current.NumRow()) - for _, segID := range releasedSeg { - suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed) - } - sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1}) bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1}) suite.NoError(err) @@ -165,10 +161,6 @@ func (suite *IDFOracleSuite) TestGrow() { suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6)) suite.Equal(int64(2), suite.idfOracle.current.NumRow()) - - for _, segID := range releasedSeg { - suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing) - } } func (suite *IDFOracleSuite) TestStats() { diff --git a/internal/storage/stats.go b/internal/storage/stats.go index 0a87c153b3..42c365105f 100644 --- a/internal/storage/stats.go +++ b/internal/storage/stats.go @@ -470,6 +470,9 @@ func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) { } func (m *BM25Stats) GetAvgdl() float64 { + if m.numRow == 0 || m.numToken == 0 { + return 0 + } return float64(m.numToken) / float64(m.numRow) }