fix: BM25 search failed when avgdl was NaN (#41502)

related: https://github.com/milvus-io/milvus/issues/41490

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2025-04-27 17:34:38 +08:00 committed by GitHub
parent 12cde913b5
commit 3892451880
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 37 deletions

View File

@ -954,11 +954,6 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
pkoracle.WithSegmentType(commonpb.SegmentState_Sealed), pkoracle.WithSegmentType(commonpb.SegmentState_Sealed),
pkoracle.WithWorkerID(targetNodeID), pkoracle.WithWorkerID(targetNodeID),
) )
if sd.idfOracle != nil {
for _, segment := range sealed {
sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Sealed)
}
}
} }
if len(growing) > 0 { if len(growing) > 0 {
sd.pkOracle.Remove( sd.pkOracle.Remove(
@ -967,7 +962,7 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
) )
if sd.idfOracle != nil { if sd.idfOracle != nil {
for _, segment := range growing { for _, segment := range growing {
sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Growing) sd.idfOracle.RemoveGrowing(segment.SegmentID)
} }
} }
} }
@ -1126,6 +1121,10 @@ func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64,
return 0, err return 0, err
} }
if avgdl <= 0 {
return 0, nil
}
for _, idf := range idfSparseVector { for _, idf := range idfSparseVector {
metrics.QueryNodeSearchFTSNumTokens.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(sd.collectionID), fmt.Sprint(req.GetFieldId())).Observe(float64(typeutil.SparseFloatRowElementCount(idf))) metrics.QueryNodeSearchFTSNumTokens.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(sd.collectionID), fmt.Sprint(req.GetFieldId())).Observe(float64(typeutil.SparseFloatRowElementCount(idf)))
} }

View File

@ -40,7 +40,7 @@ type IDFOracle interface {
UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)
Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
Remove(segmentID int64, state commonpb.SegmentState) RemoveGrowing(segmentID int64)
BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
} }
@ -49,6 +49,7 @@ type bm25Stats struct {
stats map[int64]*storage.BM25Stats stats map[int64]*storage.BM25Stats
activate bool activate bool
targetVersion int64 targetVersion int64
snapVersion int64
} }
func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) { func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
@ -133,7 +134,7 @@ func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats
o.sealed[segmentID] = &bm25Stats{ o.sealed[segmentID] = &bm25Stats{
stats: stats, stats: stats,
activate: false, activate: false,
targetVersion: initialTargetVersion, targetVersion: unreadableTargetVersion,
} }
default: default:
log.Warn("register segment with unknown state", zap.String("stats", state.String())) log.Warn("register segment with unknown state", zap.String("stats", state.String()))
@ -160,27 +161,15 @@ func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25
} }
} }
func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) { func (o *idfOracle) RemoveGrowing(segmentID int64) {
o.Lock() o.Lock()
defer o.Unlock() defer o.Unlock()
switch state { if stats, ok := o.growing[segmentID]; ok {
case segments.SegmentTypeGrowing: if stats.activate {
if stats, ok := o.growing[segmentID]; ok { o.current.Minus(stats.stats)
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.growing, segmentID)
} }
case segments.SegmentTypeSealed: delete(o.growing, segmentID)
if stats, ok := o.sealed[segmentID]; ok {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.sealed, segmentID)
}
default:
return
} }
} }
@ -207,7 +196,10 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
} }
if stats, ok := o.sealed[segment.SegmentID]; ok { if stats, ok := o.sealed[segment.SegmentID]; ok {
stats.targetVersion = segment.TargetVersion if stats.targetVersion < segment.TargetVersion {
stats.targetVersion = segment.TargetVersion
}
stats.snapVersion = snapshot.version
} else { } else {
log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID)) log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
} }
@ -224,11 +216,14 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
o.targetVersion = snapshot.targetVersion o.targetVersion = snapshot.targetVersion
for _, stats := range o.sealed { for segmentID, stats := range o.sealed {
if !stats.activate && stats.targetVersion == o.targetVersion { if !stats.activate && stats.targetVersion == o.targetVersion {
o.activate(stats) o.activate(stats)
} else if stats.activate && stats.targetVersion != o.targetVersion { } else if (stats.targetVersion < o.targetVersion && stats.targetVersion != unreadableTargetVersion) || stats.snapVersion != snapshot.version {
o.deactivate(stats) if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.sealed, segmentID)
} }
} }

View File

@ -132,10 +132,6 @@ func (suite *IDFOracleSuite) TestSealed() {
suite.idfOracle.SyncDistribution(suite.snapshot) suite.idfOracle.SyncDistribution(suite.snapshot)
suite.Equal(int64(1), suite.idfOracle.current.NumRow()) suite.Equal(int64(1), suite.idfOracle.current.NumRow())
for _, segID := range releasedSeg {
suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed)
}
sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1}) sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1}) bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
suite.NoError(err) suite.NoError(err)
@ -165,10 +161,6 @@ func (suite *IDFOracleSuite) TestGrow() {
suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6)) suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
suite.Equal(int64(2), suite.idfOracle.current.NumRow()) suite.Equal(int64(2), suite.idfOracle.current.NumRow())
for _, segID := range releasedSeg {
suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing)
}
} }
func (suite *IDFOracleSuite) TestStats() { func (suite *IDFOracleSuite) TestStats() {

View File

@ -470,6 +470,9 @@ func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) {
} }
func (m *BM25Stats) GetAvgdl() float64 { func (m *BM25Stats) GetAvgdl() float64 {
if m.numRow == 0 || m.numToken == 0 {
return 0
}
return float64(m.numToken) / float64(m.numRow) return float64(m.numToken) / float64(m.numRow)
} }