fix: BM25 search failed when avgdl is NaN (#41502)

related: https://github.com/milvus-io/milvus/issues/41490

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2025-04-27 17:34:38 +08:00 committed by GitHub
parent 12cde913b5
commit 3892451880
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 37 deletions

View File

@ -954,11 +954,6 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
pkoracle.WithSegmentType(commonpb.SegmentState_Sealed),
pkoracle.WithWorkerID(targetNodeID),
)
if sd.idfOracle != nil {
for _, segment := range sealed {
sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Sealed)
}
}
}
if len(growing) > 0 {
sd.pkOracle.Remove(
@ -967,7 +962,7 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
)
if sd.idfOracle != nil {
for _, segment := range growing {
sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Growing)
sd.idfOracle.RemoveGrowing(segment.SegmentID)
}
}
}
@ -1126,6 +1121,10 @@ func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64,
return 0, err
}
if avgdl <= 0 {
return 0, nil
}
for _, idf := range idfSparseVector {
metrics.QueryNodeSearchFTSNumTokens.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(sd.collectionID), fmt.Sprint(req.GetFieldId())).Observe(float64(typeutil.SparseFloatRowElementCount(idf)))
}

View File

@ -40,7 +40,7 @@ type IDFOracle interface {
UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)
Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
Remove(segmentID int64, state commonpb.SegmentState)
RemoveGrowing(segmentID int64)
BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
}
@ -49,6 +49,7 @@ type bm25Stats struct {
stats map[int64]*storage.BM25Stats
activate bool
targetVersion int64
snapVersion int64
}
func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
@ -133,7 +134,7 @@ func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats
o.sealed[segmentID] = &bm25Stats{
stats: stats,
activate: false,
targetVersion: initialTargetVersion,
targetVersion: unreadableTargetVersion,
}
default:
log.Warn("register segment with unknown state", zap.String("stats", state.String()))
@ -160,28 +161,16 @@ func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25
}
}
func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) {
func (o *idfOracle) RemoveGrowing(segmentID int64) {
o.Lock()
defer o.Unlock()
switch state {
case segments.SegmentTypeGrowing:
if stats, ok := o.growing[segmentID]; ok {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.growing, segmentID)
}
case segments.SegmentTypeSealed:
if stats, ok := o.sealed[segmentID]; ok {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.sealed, segmentID)
}
default:
return
}
}
func (o *idfOracle) activate(stats *bm25Stats) {
@ -207,7 +196,10 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
}
if stats, ok := o.sealed[segment.SegmentID]; ok {
if stats.targetVersion < segment.TargetVersion {
stats.targetVersion = segment.TargetVersion
}
stats.snapVersion = snapshot.version
} else {
log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
}
@ -224,11 +216,14 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
o.targetVersion = snapshot.targetVersion
for _, stats := range o.sealed {
for segmentID, stats := range o.sealed {
if !stats.activate && stats.targetVersion == o.targetVersion {
o.activate(stats)
} else if stats.activate && stats.targetVersion != o.targetVersion {
o.deactivate(stats)
} else if (stats.targetVersion < o.targetVersion && stats.targetVersion != unreadableTargetVersion) || stats.snapVersion != snapshot.version {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.sealed, segmentID)
}
}

View File

@ -132,10 +132,6 @@ func (suite *IDFOracleSuite) TestSealed() {
suite.idfOracle.SyncDistribution(suite.snapshot)
suite.Equal(int64(1), suite.idfOracle.current.NumRow())
for _, segID := range releasedSeg {
suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed)
}
sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
suite.NoError(err)
@ -165,10 +161,6 @@ func (suite *IDFOracleSuite) TestGrow() {
suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
suite.Equal(int64(2), suite.idfOracle.current.NumRow())
for _, segID := range releasedSeg {
suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing)
}
}
func (suite *IDFOracleSuite) TestStats() {

View File

@ -470,6 +470,9 @@ func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) {
}
// GetAvgdl returns the average document length, i.e. numToken divided by
// numRow, as a float64.
//
// When either counter is zero the result is 0. This keeps the empty case
// (0 rows, 0 tokens) from evaluating 0/0, which would yield NaN.
func (m *BM25Stats) GetAvgdl() float64 {
	rows, tokens := m.numRow, m.numToken
	if rows != 0 && tokens != 0 {
		return float64(tokens) / float64(rows)
	}
	return 0
}