enhance: support preload sealed segment bm25 stats and optimize bm25 stats serialize (#44279)

relate: https://github.com/milvus-io/milvus/issues/41424

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2025-09-29 16:35:05 +08:00 committed by GitHub
parent ca1cc7c9f3
commit 78ee76f018
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 157 additions and 65 deletions

View File

@ -17,15 +17,17 @@
package delegator package delegator
import ( import (
"bufio"
"context" "context"
"encoding/json"
"fmt" "fmt"
"io/fs"
"os" "os"
"path" "path"
"sync" "sync"
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
@ -106,40 +108,47 @@ type sealedBm25Stats struct {
removed bool removed bool
segmentID int64 segmentID int64
ts time.Time // Time of segemnt register, all segment resgister after target generate will don't remove ts time.Time // Time of segemnt register, all segment resgister after target generate will don't remove
localPath string localDir string
fieldList []int64 // bm25 field list
} }
func (s *sealedBm25Stats) writeFile(localPath string) (error, bool) { func (s *sealedBm25Stats) writeFile(localDir string) (error, bool) {
s.RLock() s.RLock()
if s.removed { if s.removed || !s.inmemory {
return nil, true return nil, true
} }
m := make(map[int64][]byte, len(s.bm25Stats))
stats := s.bm25Stats stats := s.bm25Stats
s.RUnlock() s.RUnlock()
err := os.MkdirAll(localDir, fs.ModePerm)
if err != nil {
return err, false
}
// RUnlock when stats serialize and write to file // RUnlock when stats serialize and write to file
// to avoid block remove stats too long when sync distribution // to avoid block remove stats too long when sync distribution
for fieldID, stats := range stats { for fieldID, stats := range stats {
bytes, err := stats.Serialize() file, err := os.Create(path.Join(localDir, fmt.Sprintf("%d.data", fieldID)))
if err != nil { if err != nil {
return err, false return err, false
} }
m[fieldID] = bytes defer file.Close()
} writer := bufio.NewWriter(file)
b, err := json.Marshal(m) err = stats.SerializeToWriter(writer)
if err != nil { if err != nil {
return err, false return err, false
} }
err = os.WriteFile(localPath, b, 0o600) err = writer.Flush()
if err != nil { if err != nil {
return err, false return err, false
} }
}
return nil, false return nil, false
} }
@ -153,22 +162,25 @@ func (s *sealedBm25Stats) ShouldOffLoadToDisk() bool {
// so that later when the segment is removed from target, we can Minus its stats. To reduce memory usage, // so that later when the segment is removed from target, we can Minus its stats. To reduce memory usage,
// idfOracle store such per segment stats to disk, and load them when removing the segment. // idfOracle store such per segment stats to disk, and load them when removing the segment.
func (s *sealedBm25Stats) ToLocal(dirPath string) error { func (s *sealedBm25Stats) ToLocal(dirPath string) error {
localpath := path.Join(dirPath, fmt.Sprintf("%d.data", s.segmentID)) dir := path.Join(dirPath, fmt.Sprint(s.segmentID))
if err, skip := s.writeFile(dir); err != nil {
if err, skip := s.writeFile(localpath); err != nil || skip { os.RemoveAll(dir)
return err return err
} else if skip {
return nil
} }
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
s.fieldList = lo.Keys(s.bm25Stats)
s.inmemory = false s.inmemory = false
s.bm25Stats = nil s.bm25Stats = nil
s.localPath = localpath s.localDir = dir
if s.removed { if s.removed {
err := os.Remove(s.localPath) err := os.RemoveAll(s.localDir)
if err != nil { if err != nil {
log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localPath)) log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localDir))
} }
} }
return nil return nil
@ -180,9 +192,9 @@ func (s *sealedBm25Stats) Remove() {
s.removed = true s.removed = true
if !s.inmemory { if !s.inmemory {
err := os.Remove(s.localPath) err := os.RemoveAll(s.localDir)
if err != nil { if err != nil {
log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localPath)) log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localDir))
} }
} }
} }
@ -197,25 +209,21 @@ func (s *sealedBm25Stats) FetchStats() (map[int64]*storage.BM25Stats, error) {
return s.bm25Stats, nil return s.bm25Stats, nil
} }
b, err := os.ReadFile(s.localPath)
if err != nil {
return nil, err
}
m := make(map[int64][]byte)
err = json.Unmarshal(b, &m)
if err != nil {
return nil, err
}
stats := make(map[int64]*storage.BM25Stats) stats := make(map[int64]*storage.BM25Stats)
for fieldID, bytes := range m { for _, fieldID := range s.fieldList {
stats[fieldID] = storage.NewBM25Stats() path := path.Join(s.localDir, fmt.Sprintf("%d.data", fieldID))
err = stats[fieldID].Deserialize(bytes) b, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, errors.Newf("read local file %s: failed: %v", path, err)
}
stats[fieldID] = storage.NewBM25Stats()
err = stats[fieldID].Deserialize(b)
if err != nil {
return nil, errors.Newf("deserialize local file : %s failed: %v", path, err)
} }
} }
return stats, nil return stats, nil
} }
@ -278,10 +286,29 @@ type idfOracle struct {
wg sync.WaitGroup wg sync.WaitGroup
} }
// now only used for test
func (o *idfOracle) TargetVersion() int64 { func (o *idfOracle) TargetVersion() int64 {
return o.targetVersion.Load() return o.targetVersion.Load()
} }
func (o *idfOracle) preloadSealed(segmentID int64, stats bm25Stats) {
o.Lock()
defer o.Unlock()
// skip preload if first target was loaded.
if o.targetVersion.Load() != 0 {
return
}
o.sealed.Insert(segmentID, &sealedBm25Stats{
bm25Stats: stats,
ts: time.Now(),
activate: atomic.NewBool(true),
inmemory: true,
segmentID: segmentID,
})
o.current.Merge(stats)
}
func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.SegmentState) { func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.SegmentState) {
switch state { switch state {
case segments.SegmentTypeGrowing: case segments.SegmentTypeGrowing:
@ -300,6 +327,11 @@ func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.Se
if ok := o.sealed.Contain(segmentID); ok { if ok := o.sealed.Contain(segmentID); ok {
return return
} }
// preload sealed segment to channel before first target
if o.targetVersion.Load() == 0 {
o.preloadSealed(segmentID, stats)
} else {
o.sealed.Insert(segmentID, &sealedBm25Stats{ o.sealed.Insert(segmentID, &sealedBm25Stats{
bm25Stats: stats, bm25Stats: stats,
ts: time.Now(), ts: time.Now(),
@ -307,6 +339,7 @@ func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.Se
inmemory: true, inmemory: true,
segmentID: segmentID, segmentID: segmentID,
}) })
}
default: default:
log.Warn("register segment with unknown state", zap.String("stats", state.String())) log.Warn("register segment with unknown state", zap.String("stats", state.String()))
return return
@ -447,6 +480,7 @@ func (o *idfOracle) localloop() {
} }
// WARN: SyncDistribution not concurrent safe. // WARN: SyncDistribution not concurrent safe.
// SyncDistribution sync current target to idf oracle.
func (o *idfOracle) SyncDistribution() error { func (o *idfOracle) SyncDistribution() error {
snapshot, snapshotTs := o.next.GetSnapshot() snapshot, snapshotTs := o.next.GetSnapshot()
if snapshot.targetVersion <= o.targetVersion.Load() { if snapshot.targetVersion <= o.targetVersion.Load() {
@ -457,18 +491,20 @@ func (o *idfOracle) SyncDistribution() error {
sealedMap := map[int64]bool{} // sealed diff map, activate segment stats if true, and remove if not in map sealedMap := map[int64]bool{} // sealed diff map, activate segment stats if true, and remove if not in map
// only remain current target segment and unknown version segment in snapshot.
for _, item := range sealed { for _, item := range sealed {
for _, segment := range item.Segments { for _, segment := range item.Segments {
if segment.Level == datapb.SegmentLevel_L0 { if segment.Level == datapb.SegmentLevel_L0 {
continue continue
} }
if segment.TargetVersion == snapshot.targetVersion { switch segment.TargetVersion {
case snapshot.targetVersion:
sealedMap[segment.SegmentID] = true sealedMap[segment.SegmentID] = true
if !o.sealed.Contain(segment.SegmentID) { if !o.sealed.Contain(segment.SegmentID) {
log.Warn("idf oracle lack some sealed segment", zap.Int64("segment", segment.SegmentID)) log.Warn("idf oracle lack some sealed segment", zap.Int64("segment", segment.SegmentID))
} }
} else if segment.TargetVersion == unreadableTargetVersion { case unreadableTargetVersion:
sealedMap[segment.SegmentID] = false sealedMap[segment.SegmentID] = false
} }
} }
@ -478,19 +514,24 @@ func (o *idfOracle) SyncDistribution() error {
var rangeErr error var rangeErr error
o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool { o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool {
activate, ok := sealedMap[segmentID] // segment was unreadable if in snapshot but not in target.
statsActivate := stats.activate.Load() intarget, insnap := sealedMap[segmentID]
if ok && activate && !statsActivate { activate := stats.activate.Load()
// activate segment if segment in target
if insnap && intarget && !activate {
stats, err := stats.FetchStats() stats, err := stats.FetchStats()
if err != nil { if err != nil {
rangeErr = err rangeErr = fmt.Errorf("fetch stats failed with error: %v", err)
return false return false
} }
diff.Merge(stats) diff.Merge(stats)
} else if !ok && statsActivate { } else
// deactivate segment if segment not in snapshot
// or deactivate segment if segment unreadable (only exist at preload segment)
if (!insnap || (insnap && !intarget)) && activate {
stats, err := stats.FetchStats() stats, err := stats.FetchStats()
if err != nil { if err != nil {
rangeErr = err rangeErr = fmt.Errorf("fetch stats failed with error: %v", err)
return false return false
} }
diff.Minus(stats) diff.Minus(stats)
@ -518,22 +559,30 @@ func (o *idfOracle) SyncDistribution() error {
// remove sealed segment not in target // remove sealed segment not in target
o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool { o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool {
activate, ok := sealedMap[segmentID] intarget, insnap := sealedMap[segmentID]
statsActivate := stats.activate.Load() activate := stats.activate.Load()
if !ok && stats.ts.Before(snapshotTs) { // remove if segment not in snapshot
// and add before snapshot
if !insnap && stats.ts.Before(snapshotTs) {
stats.Remove() stats.Remove()
o.sealed.Remove(segmentID) o.sealed.Remove(segmentID)
} }
if ok && activate && !statsActivate { // save activate if segment in target.
if insnap && intarget && !activate {
stats.activate.Store(true) stats.activate.Store(true)
} }
// deactivate if segment unreadable.
if insnap && !intarget && activate {
stats.activate.Store(false)
}
return true return true
}) })
o.targetVersion.Store(snapshot.targetVersion) o.targetVersion.Store(snapshot.targetVersion)
o.NotifyLocal() o.NotifyLocal()
log.Ctx(context.TODO()).Info("sync distribution finished", zap.Int64("version", snapshot.targetVersion), zap.Int64("numrow", o.current.NumRow()), zap.Int("growing", len(o.growing)), zap.Int("sealed", o.sealed.Len())) log.Ctx(context.TODO()).Info("sync idf distribution finished", zap.Int64("version", snapshot.targetVersion), zap.Int64("numrow", o.current.NumRow()), zap.Int("growing", len(o.growing)), zap.Int("sealed", o.sealed.Len()))
return nil return nil
} }

View File

@ -136,14 +136,21 @@ func (suite *IDFOracleSuite) TestSealed() {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
} }
// register sealed segment but all deactvate // some sealed not in target
suite.Zero(suite.idfOracle.current.NumRow()) invalidSealedSegs := []int64{5, 6}
for _, segID := range invalidSealedSegs {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
}
// update and sync snapshot make all sealed activate // register sealed segment and all preload to current
suite.Equal(int64(len(sealedSegs)+len(invalidSealedSegs)), suite.idfOracle.current.NumRow())
// update and sync snapshot make all sealed in target activate
// and invalid sealed segemnt deactivate
suite.updateSnapshot(sealedSegs, []int64{}, []int64{}) suite.updateSnapshot(sealedSegs, []int64{}, []int64{})
suite.idfOracle.SetNext(suite.snapshot) suite.idfOracle.SetNext(suite.snapshot)
suite.waitTargetVersion(suite.targetVersion) suite.waitTargetVersion(suite.targetVersion)
suite.Equal(int64(4), suite.idfOracle.current.NumRow()) suite.Equal(int64(len(sealedSegs)), suite.idfOracle.current.NumRow())
releasedSeg := []int64{1, 2, 3} releasedSeg := []int64{1, 2, 3}
suite.updateSnapshot([]int64{}, []int64{}, releasedSeg) suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
@ -213,11 +220,20 @@ func (suite *IDFOracleSuite) TestLocalCache() {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
} }
// update and sync snapshot make all sealed activate // some sealed not in target
invalidSealedSegs := []int64{5, 6}
for _, segID := range invalidSealedSegs {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
}
// register sealed segment and all preload to current
suite.Equal(int64(len(sealedSegs)+len(invalidSealedSegs)), suite.idfOracle.current.NumRow())
// update and sync snapshot make all sealed in target activate
suite.updateSnapshot(sealedSegs, []int64{}, []int64{}) suite.updateSnapshot(sealedSegs, []int64{}, []int64{})
suite.idfOracle.SetNext(suite.snapshot) suite.idfOracle.SetNext(suite.snapshot)
suite.waitTargetVersion(suite.targetVersion) suite.waitTargetVersion(suite.targetVersion)
suite.Equal(int64(4), suite.idfOracle.current.NumRow()) suite.Equal(int64(len(sealedSegs)), suite.idfOracle.current.NumRow())
suite.Require().Eventually(func() bool { suite.Require().Eventually(func() bool {
allInLocal := true allInLocal := true

View File

@ -19,6 +19,7 @@ package storage
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"io"
"maps" "maps"
"math" "math"
@ -419,6 +420,32 @@ func (m *BM25Stats) Serialize() ([]byte, error) {
return buffer.Bytes(), nil return buffer.Bytes(), nil
} }
func (m *BM25Stats) SerializeToWriter(w io.Writer) error {
if err := binary.Write(w, common.Endian, BM25VERSION); err != nil {
return err
}
if err := binary.Write(w, common.Endian, m.numRow); err != nil {
return err
}
if err := binary.Write(w, common.Endian, m.numToken); err != nil {
return err
}
for key, value := range m.rowsWithToken {
if err := binary.Write(w, common.Endian, key); err != nil {
return err
}
if err := binary.Write(w, common.Endian, value); err != nil {
return err
}
}
return nil
}
func (m *BM25Stats) Deserialize(bs []byte) error { func (m *BM25Stats) Deserialize(bs []byte) error {
buffer := bytes.NewBuffer(bs) buffer := bytes.NewBuffer(bs)
dim := (len(bs) - 20) / 8 dim := (len(bs) - 20) / 8