mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
fix: bm25 search failed when avgdl == nan (#41502)
Related to: https://github.com/milvus-io/milvus/issues/41490
---------
Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
12cde913b5
commit
3892451880
@ -954,11 +954,6 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
|
|||||||
pkoracle.WithSegmentType(commonpb.SegmentState_Sealed),
|
pkoracle.WithSegmentType(commonpb.SegmentState_Sealed),
|
||||||
pkoracle.WithWorkerID(targetNodeID),
|
pkoracle.WithWorkerID(targetNodeID),
|
||||||
)
|
)
|
||||||
if sd.idfOracle != nil {
|
|
||||||
for _, segment := range sealed {
|
|
||||||
sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Sealed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if len(growing) > 0 {
|
if len(growing) > 0 {
|
||||||
sd.pkOracle.Remove(
|
sd.pkOracle.Remove(
|
||||||
@ -967,7 +962,7 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
|
|||||||
)
|
)
|
||||||
if sd.idfOracle != nil {
|
if sd.idfOracle != nil {
|
||||||
for _, segment := range growing {
|
for _, segment := range growing {
|
||||||
sd.idfOracle.Remove(segment.SegmentID, commonpb.SegmentState_Growing)
|
sd.idfOracle.RemoveGrowing(segment.SegmentID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1126,6 +1121,10 @@ func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64,
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if avgdl <= 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, idf := range idfSparseVector {
|
for _, idf := range idfSparseVector {
|
||||||
metrics.QueryNodeSearchFTSNumTokens.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(sd.collectionID), fmt.Sprint(req.GetFieldId())).Observe(float64(typeutil.SparseFloatRowElementCount(idf)))
|
metrics.QueryNodeSearchFTSNumTokens.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(sd.collectionID), fmt.Sprint(req.GetFieldId())).Observe(float64(typeutil.SparseFloatRowElementCount(idf)))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,7 +40,7 @@ type IDFOracle interface {
|
|||||||
UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)
|
UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)
|
||||||
|
|
||||||
Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
|
Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
|
||||||
Remove(segmentID int64, state commonpb.SegmentState)
|
RemoveGrowing(segmentID int64)
|
||||||
|
|
||||||
BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
|
BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
|
||||||
}
|
}
|
||||||
@ -49,6 +49,7 @@ type bm25Stats struct {
|
|||||||
stats map[int64]*storage.BM25Stats
|
stats map[int64]*storage.BM25Stats
|
||||||
activate bool
|
activate bool
|
||||||
targetVersion int64
|
targetVersion int64
|
||||||
|
snapVersion int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
|
func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
|
||||||
@ -133,7 +134,7 @@ func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats
|
|||||||
o.sealed[segmentID] = &bm25Stats{
|
o.sealed[segmentID] = &bm25Stats{
|
||||||
stats: stats,
|
stats: stats,
|
||||||
activate: false,
|
activate: false,
|
||||||
targetVersion: initialTargetVersion,
|
targetVersion: unreadableTargetVersion,
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
log.Warn("register segment with unknown state", zap.String("stats", state.String()))
|
log.Warn("register segment with unknown state", zap.String("stats", state.String()))
|
||||||
@ -160,27 +161,15 @@ func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) {
|
func (o *idfOracle) RemoveGrowing(segmentID int64) {
|
||||||
o.Lock()
|
o.Lock()
|
||||||
defer o.Unlock()
|
defer o.Unlock()
|
||||||
|
|
||||||
switch state {
|
if stats, ok := o.growing[segmentID]; ok {
|
||||||
case segments.SegmentTypeGrowing:
|
if stats.activate {
|
||||||
if stats, ok := o.growing[segmentID]; ok {
|
o.current.Minus(stats.stats)
|
||||||
if stats.activate {
|
|
||||||
o.current.Minus(stats.stats)
|
|
||||||
}
|
|
||||||
delete(o.growing, segmentID)
|
|
||||||
}
|
}
|
||||||
case segments.SegmentTypeSealed:
|
delete(o.growing, segmentID)
|
||||||
if stats, ok := o.sealed[segmentID]; ok {
|
|
||||||
if stats.activate {
|
|
||||||
o.current.Minus(stats.stats)
|
|
||||||
}
|
|
||||||
delete(o.sealed, segmentID)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,7 +196,10 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if stats, ok := o.sealed[segment.SegmentID]; ok {
|
if stats, ok := o.sealed[segment.SegmentID]; ok {
|
||||||
stats.targetVersion = segment.TargetVersion
|
if stats.targetVersion < segment.TargetVersion {
|
||||||
|
stats.targetVersion = segment.TargetVersion
|
||||||
|
}
|
||||||
|
stats.snapVersion = snapshot.version
|
||||||
} else {
|
} else {
|
||||||
log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
|
log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
|
||||||
}
|
}
|
||||||
@ -224,11 +216,14 @@ func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
|
|||||||
|
|
||||||
o.targetVersion = snapshot.targetVersion
|
o.targetVersion = snapshot.targetVersion
|
||||||
|
|
||||||
for _, stats := range o.sealed {
|
for segmentID, stats := range o.sealed {
|
||||||
if !stats.activate && stats.targetVersion == o.targetVersion {
|
if !stats.activate && stats.targetVersion == o.targetVersion {
|
||||||
o.activate(stats)
|
o.activate(stats)
|
||||||
} else if stats.activate && stats.targetVersion != o.targetVersion {
|
} else if (stats.targetVersion < o.targetVersion && stats.targetVersion != unreadableTargetVersion) || stats.snapVersion != snapshot.version {
|
||||||
o.deactivate(stats)
|
if stats.activate {
|
||||||
|
o.current.Minus(stats.stats)
|
||||||
|
}
|
||||||
|
delete(o.sealed, segmentID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -132,10 +132,6 @@ func (suite *IDFOracleSuite) TestSealed() {
|
|||||||
suite.idfOracle.SyncDistribution(suite.snapshot)
|
suite.idfOracle.SyncDistribution(suite.snapshot)
|
||||||
suite.Equal(int64(1), suite.idfOracle.current.NumRow())
|
suite.Equal(int64(1), suite.idfOracle.current.NumRow())
|
||||||
|
|
||||||
for _, segID := range releasedSeg {
|
|
||||||
suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed)
|
|
||||||
}
|
|
||||||
|
|
||||||
sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
|
sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
|
||||||
bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
|
bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
|
||||||
suite.NoError(err)
|
suite.NoError(err)
|
||||||
@ -165,10 +161,6 @@ func (suite *IDFOracleSuite) TestGrow() {
|
|||||||
|
|
||||||
suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
|
suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
|
||||||
suite.Equal(int64(2), suite.idfOracle.current.NumRow())
|
suite.Equal(int64(2), suite.idfOracle.current.NumRow())
|
||||||
|
|
||||||
for _, segID := range releasedSeg {
|
|
||||||
suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *IDFOracleSuite) TestStats() {
|
func (suite *IDFOracleSuite) TestStats() {
|
||||||
|
|||||||
@ -470,6 +470,9 @@ func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *BM25Stats) GetAvgdl() float64 {
|
func (m *BM25Stats) GetAvgdl() float64 {
|
||||||
|
if m.numRow == 0 || m.numToken == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
return float64(m.numToken) / float64(m.numRow)
|
return float64(m.numToken) / float64(m.numRow)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user