diff --git a/internal/querynodev2/delegator/idf_oracle.go b/internal/querynodev2/delegator/idf_oracle.go index 17a028bfac..367d6f640c 100644 --- a/internal/querynodev2/delegator/idf_oracle.go +++ b/internal/querynodev2/delegator/idf_oracle.go @@ -17,15 +17,17 @@ package delegator import ( + "bufio" "context" - "encoding/json" "fmt" + "io/fs" "os" "path" "sync" "time" "github.com/cockroachdb/errors" + "github.com/samber/lo" "go.uber.org/atomic" "go.uber.org/zap" @@ -106,40 +108,47 @@ type sealedBm25Stats struct { removed bool segmentID int64 ts time.Time // Time of segemnt register, all segment resgister after target generate will don't remove - localPath string + localDir string + fieldList []int64 // bm25 field list } -func (s *sealedBm25Stats) writeFile(localPath string) (error, bool) { +func (s *sealedBm25Stats) writeFile(localDir string) (error, bool) { s.RLock() - if s.removed { + if s.removed || !s.inmemory { return nil, true } - m := make(map[int64][]byte, len(s.bm25Stats)) stats := s.bm25Stats s.RUnlock() + err := os.MkdirAll(localDir, fs.ModePerm) + if err != nil { + return err, false + } + // RUnlock when stats serialize and write to file // to avoid block remove stats too long when sync distribution for fieldID, stats := range stats { - bytes, err := stats.Serialize() + file, err := os.Create(path.Join(localDir, fmt.Sprintf("%d.data", fieldID))) if err != nil { return err, false } - m[fieldID] = bytes + defer file.Close() + writer := bufio.NewWriter(file) + + err = stats.SerializeToWriter(writer) + if err != nil { + return err, false + } + + err = writer.Flush() + if err != nil { + return err, false + } } - b, err := json.Marshal(m) - if err != nil { - return err, false - } - - err = os.WriteFile(localPath, b, 0o600) - if err != nil { - return err, false - } return nil, false } @@ -153,22 +162,25 @@ func (s *sealedBm25Stats) ShouldOffLoadToDisk() bool { // so that later when the segment is removed from target, we can Minus its stats. To reduce memory usage, // idfOracle store such per segment stats to disk, and load them when removing the segment. func (s *sealedBm25Stats) ToLocal(dirPath string) error { - localpath := path.Join(dirPath, fmt.Sprintf("%d.data", s.segmentID)) - - if err, skip := s.writeFile(localpath); err != nil || skip { + dir := path.Join(dirPath, fmt.Sprint(s.segmentID)) + if err, skip := s.writeFile(dir); err != nil { + os.RemoveAll(dir) return err + } else if skip { + return nil } s.Lock() defer s.Unlock() + s.fieldList = lo.Keys(s.bm25Stats) s.inmemory = false s.bm25Stats = nil - s.localPath = localpath + s.localDir = dir if s.removed { - err := os.Remove(s.localPath) + err := os.RemoveAll(s.localDir) if err != nil { - log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localPath)) + log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localDir)) } } return nil @@ -180,9 +192,9 @@ func (s *sealedBm25Stats) Remove() { s.removed = true if !s.inmemory { - err := os.Remove(s.localPath) + err := os.RemoveAll(s.localDir) if err != nil { - log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localPath)) + log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localDir)) } } } @@ -197,25 +209,21 @@ func (s *sealedBm25Stats) FetchStats() (map[int64]*storage.BM25Stats, error) { return s.bm25Stats, nil } - b, err := os.ReadFile(s.localPath) - if err != nil { - return nil, err - } - - m := make(map[int64][]byte) - err = json.Unmarshal(b, &m) - if err != nil { - return nil, err - } - stats := make(map[int64]*storage.BM25Stats) - for fieldID, bytes := range m { - stats[fieldID] = storage.NewBM25Stats() - err = stats[fieldID].Deserialize(bytes) + for _, fieldID := range s.fieldList { + path := path.Join(s.localDir, fmt.Sprintf("%d.data", fieldID)) + b, err := os.ReadFile(path) if err != nil { - return nil, err + return nil, errors.Newf("read local file %s: failed: %v", path, err) + } + + stats[fieldID] = storage.NewBM25Stats() + err = stats[fieldID].Deserialize(b) + if err != nil { + return nil, errors.Newf("deserialize local file : %s failed: %v", path, err) } } + return stats, nil } @@ -278,10 +286,29 @@ type idfOracle struct { wg sync.WaitGroup } +// now only used for test func (o *idfOracle) TargetVersion() int64 { return o.targetVersion.Load() } +func (o *idfOracle) preloadSealed(segmentID int64, stats bm25Stats) { + o.Lock() + defer o.Unlock() + + // skip preload if first target was loaded. + if o.targetVersion.Load() != 0 { + return + } + o.sealed.Insert(segmentID, &sealedBm25Stats{ + bm25Stats: stats, + ts: time.Now(), + activate: atomic.NewBool(true), + inmemory: true, + segmentID: segmentID, + }) + o.current.Merge(stats) +} + func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.SegmentState) { switch state { case segments.SegmentTypeGrowing: @@ -300,13 +327,19 @@ func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.Se if ok := o.sealed.Contain(segmentID); ok { return } - o.sealed.Insert(segmentID, &sealedBm25Stats{ - bm25Stats: stats, - ts: time.Now(), - activate: atomic.NewBool(false), - inmemory: true, - segmentID: segmentID, - }) + + // preload sealed segment to channel before first target + if o.targetVersion.Load() == 0 { + o.preloadSealed(segmentID, stats) + } else { + o.sealed.Insert(segmentID, &sealedBm25Stats{ + bm25Stats: stats, + ts: time.Now(), + activate: atomic.NewBool(false), + inmemory: true, + segmentID: segmentID, + }) + } default: log.Warn("register segment with unknown state", zap.String("stats", state.String())) return @@ -447,6 +480,7 @@ func (o *idfOracle) localloop() { } // WARN: SyncDistribution not concurrent safe. +// SyncDistribution sync current target to idf oracle. func (o *idfOracle) SyncDistribution() error { snapshot, snapshotTs := o.next.GetSnapshot() if snapshot.targetVersion <= o.targetVersion.Load() { @@ -457,18 +491,20 @@ func (o *idfOracle) SyncDistribution() error { sealedMap := map[int64]bool{} // sealed diff map, activate segment stats if true, and remove if not in map + // only remain current target segment and unknown version segment in snapshot. for _, item := range sealed { for _, segment := range item.Segments { if segment.Level == datapb.SegmentLevel_L0 { continue } - if segment.TargetVersion == snapshot.targetVersion { + switch segment.TargetVersion { + case snapshot.targetVersion: sealedMap[segment.SegmentID] = true if !o.sealed.Contain(segment.SegmentID) { log.Warn("idf oracle lack some sealed segment", zap.Int64("segment", segment.SegmentID)) } - } else if segment.TargetVersion == unreadableTargetVersion { + case unreadableTargetVersion: sealedMap[segment.SegmentID] = false } } @@ -478,19 +514,24 @@ func (o *idfOracle) SyncDistribution() error { var rangeErr error o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool { - activate, ok := sealedMap[segmentID] - statsActivate := stats.activate.Load() - if ok && activate && !statsActivate { + // segment was unreadable if in snapshot but not in target. + intarget, insnap := sealedMap[segmentID] + activate := stats.activate.Load() + // activate segment if segment in target + if insnap && intarget && !activate { stats, err := stats.FetchStats() if err != nil { - rangeErr = err + rangeErr = fmt.Errorf("fetch stats failed with error: %v", err) return false } diff.Merge(stats) - } else if !ok && statsActivate { + } else + // deactivate segment if segment not in snapshot + // or deactivate segment if segment unreadable (only exist at preload segment) + if (!insnap || (insnap && !intarget)) && activate { stats, err := stats.FetchStats() if err != nil { - rangeErr = err + rangeErr = fmt.Errorf("fetch stats failed with error: %v", err) return false } diff.Minus(stats) @@ -518,22 +559,30 @@ func (o *idfOracle) SyncDistribution() error { // remove sealed segment not in target o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool { - activate, ok := sealedMap[segmentID] - statsActivate := stats.activate.Load() - if !ok && stats.ts.Before(snapshotTs) { + intarget, insnap := sealedMap[segmentID] + activate := stats.activate.Load() + // remove if segment not in snapshot + // and add before snapshot + if !insnap && stats.ts.Before(snapshotTs) { stats.Remove() o.sealed.Remove(segmentID) } - if ok && activate && !statsActivate { + // save activate if segment in target. + if insnap && intarget && !activate { stats.activate.Store(true) } + + // deactivate if segment unreadable. + if insnap && !intarget && activate { + stats.activate.Store(false) + } return true }) o.targetVersion.Store(snapshot.targetVersion) o.NotifyLocal() - log.Ctx(context.TODO()).Info("sync distribution finished", zap.Int64("version", snapshot.targetVersion), zap.Int64("numrow", o.current.NumRow()), zap.Int("growing", len(o.growing)), zap.Int("sealed", o.sealed.Len())) + log.Ctx(context.TODO()).Info("sync idf distribution finished", zap.Int64("version", snapshot.targetVersion), zap.Int64("numrow", o.current.NumRow()), zap.Int("growing", len(o.growing)), zap.Int("sealed", o.sealed.Len())) return nil } diff --git a/internal/querynodev2/delegator/idf_oracle_test.go b/internal/querynodev2/delegator/idf_oracle_test.go index b5e6bd7b75..ee19d99973 100644 --- a/internal/querynodev2/delegator/idf_oracle_test.go +++ b/internal/querynodev2/delegator/idf_oracle_test.go @@ -136,14 +136,21 @@ func (suite *IDFOracleSuite) TestSealed() { suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) } - // register sealed segment but all deactvate - suite.Zero(suite.idfOracle.current.NumRow()) + // some sealed not in target + invalidSealedSegs := []int64{5, 6} + for _, segID := range invalidSealedSegs { + suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) + } - // update and sync snapshot make all sealed activate + // register sealed segment and all preload to current + suite.Equal(int64(len(sealedSegs)+len(invalidSealedSegs)), suite.idfOracle.current.NumRow()) + + // update and sync snapshot make all sealed in target activate + // and invalid sealed segemnt deactivate suite.updateSnapshot(sealedSegs, []int64{}, []int64{}) suite.idfOracle.SetNext(suite.snapshot) suite.waitTargetVersion(suite.targetVersion) - suite.Equal(int64(4), suite.idfOracle.current.NumRow()) + suite.Equal(int64(len(sealedSegs)), suite.idfOracle.current.NumRow()) releasedSeg := []int64{1, 2, 3} suite.updateSnapshot([]int64{}, []int64{}, releasedSeg) @@ -213,11 +220,20 @@ func (suite *IDFOracleSuite) TestLocalCache() { suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) } - // update and sync snapshot make all sealed activate + // some sealed not in target + invalidSealedSegs := []int64{5, 6} + for _, segID := range invalidSealedSegs { + suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed) + } + + // register sealed segment and all preload to current + suite.Equal(int64(len(sealedSegs)+len(invalidSealedSegs)), suite.idfOracle.current.NumRow()) + + // update and sync snapshot make all sealed in target activate suite.updateSnapshot(sealedSegs, []int64{}, []int64{}) suite.idfOracle.SetNext(suite.snapshot) suite.waitTargetVersion(suite.targetVersion) - suite.Equal(int64(4), suite.idfOracle.current.NumRow()) + suite.Equal(int64(len(sealedSegs)), suite.idfOracle.current.NumRow()) suite.Require().Eventually(func() bool { allInLocal := true diff --git a/internal/storage/stats.go b/internal/storage/stats.go index 42c365105f..d02a8f0c09 100644 --- a/internal/storage/stats.go +++ b/internal/storage/stats.go @@ -19,6 +19,7 @@ package storage import ( "bytes" "encoding/binary" + "io" "maps" "math" @@ -419,6 +420,32 @@ func (m *BM25Stats) Serialize() ([]byte, error) { return buffer.Bytes(), nil } +func (m *BM25Stats) SerializeToWriter(w io.Writer) error { + if err := binary.Write(w, common.Endian, BM25VERSION); err != nil { + return err + } + + if err := binary.Write(w, common.Endian, m.numRow); err != nil { + return err + } + + if err := binary.Write(w, common.Endian, m.numToken); err != nil { + return err + } + + for key, value := range m.rowsWithToken { + if err := binary.Write(w, common.Endian, key); err != nil { + return err + } + + if err := binary.Write(w, common.Endian, value); err != nil { + return err + } + } + + return nil +} + func (m *BM25Stats) Deserialize(bs []byte) error { buffer := bytes.NewBuffer(bs) dim := (len(bs) - 20) / 8