mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
enhance: support preload sealed segment bm25 stats and optimize bm25 stats serialize (#44279)
relate: https://github.com/milvus-io/milvus/issues/41424 --------- Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
ca1cc7c9f3
commit
78ee76f018
@ -17,15 +17,17 @@
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
@ -106,40 +108,47 @@ type sealedBm25Stats struct {
|
||||
removed bool
|
||||
segmentID int64
|
||||
ts time.Time // Time the segment was registered; segments registered after the target was generated are not removed
|
||||
localPath string
|
||||
localDir string
|
||||
fieldList []int64 // bm25 field list
|
||||
}
|
||||
|
||||
func (s *sealedBm25Stats) writeFile(localPath string) (error, bool) {
|
||||
func (s *sealedBm25Stats) writeFile(localDir string) (error, bool) {
|
||||
s.RLock()
|
||||
|
||||
if s.removed {
|
||||
if s.removed || !s.inmemory {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
m := make(map[int64][]byte, len(s.bm25Stats))
|
||||
stats := s.bm25Stats
|
||||
s.RUnlock()
|
||||
|
||||
err := os.MkdirAll(localDir, fs.ModePerm)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
// RUnlock when stats serialize and write to file
|
||||
// to avoid block remove stats too long when sync distribution
|
||||
for fieldID, stats := range stats {
|
||||
bytes, err := stats.Serialize()
|
||||
file, err := os.Create(path.Join(localDir, fmt.Sprintf("%d.data", fieldID)))
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
m[fieldID] = bytes
|
||||
defer file.Close()
|
||||
writer := bufio.NewWriter(file)
|
||||
|
||||
err = stats.SerializeToWriter(writer)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
err = writer.Flush()
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
err = os.WriteFile(localPath, b, 0o600)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@ -153,22 +162,25 @@ func (s *sealedBm25Stats) ShouldOffLoadToDisk() bool {
|
||||
// so that later when the segment is removed from target, we can Minus its stats. To reduce memory usage,
|
||||
// idfOracle store such per segment stats to disk, and load them when removing the segment.
|
||||
func (s *sealedBm25Stats) ToLocal(dirPath string) error {
|
||||
localpath := path.Join(dirPath, fmt.Sprintf("%d.data", s.segmentID))
|
||||
|
||||
if err, skip := s.writeFile(localpath); err != nil || skip {
|
||||
dir := path.Join(dirPath, fmt.Sprint(s.segmentID))
|
||||
if err, skip := s.writeFile(dir); err != nil {
|
||||
os.RemoveAll(dir)
|
||||
return err
|
||||
} else if skip {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.fieldList = lo.Keys(s.bm25Stats)
|
||||
s.inmemory = false
|
||||
s.bm25Stats = nil
|
||||
s.localPath = localpath
|
||||
s.localDir = dir
|
||||
|
||||
if s.removed {
|
||||
err := os.Remove(s.localPath)
|
||||
err := os.RemoveAll(s.localDir)
|
||||
if err != nil {
|
||||
log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localPath))
|
||||
log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localDir))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -180,9 +192,9 @@ func (s *sealedBm25Stats) Remove() {
|
||||
s.removed = true
|
||||
|
||||
if !s.inmemory {
|
||||
err := os.Remove(s.localPath)
|
||||
err := os.RemoveAll(s.localDir)
|
||||
if err != nil {
|
||||
log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localPath))
|
||||
log.Warn("remove local bm25 stats failed", zap.Error(err), zap.String("path", s.localDir))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -197,25 +209,21 @@ func (s *sealedBm25Stats) FetchStats() (map[int64]*storage.BM25Stats, error) {
|
||||
return s.bm25Stats, nil
|
||||
}
|
||||
|
||||
b, err := os.ReadFile(s.localPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := make(map[int64][]byte)
|
||||
err = json.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := make(map[int64]*storage.BM25Stats)
|
||||
for fieldID, bytes := range m {
|
||||
stats[fieldID] = storage.NewBM25Stats()
|
||||
err = stats[fieldID].Deserialize(bytes)
|
||||
for _, fieldID := range s.fieldList {
|
||||
path := path.Join(s.localDir, fmt.Sprintf("%d.data", fieldID))
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Newf("read local file %s: failed: %v", path, err)
|
||||
}
|
||||
|
||||
stats[fieldID] = storage.NewBM25Stats()
|
||||
err = stats[fieldID].Deserialize(b)
|
||||
if err != nil {
|
||||
return nil, errors.Newf("deserialize local file : %s failed: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@ -278,10 +286,29 @@ type idfOracle struct {
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// now only used for test
|
||||
func (o *idfOracle) TargetVersion() int64 {
|
||||
return o.targetVersion.Load()
|
||||
}
|
||||
|
||||
func (o *idfOracle) preloadSealed(segmentID int64, stats bm25Stats) {
|
||||
o.Lock()
|
||||
defer o.Unlock()
|
||||
|
||||
// skip preload if first target was loaded.
|
||||
if o.targetVersion.Load() != 0 {
|
||||
return
|
||||
}
|
||||
o.sealed.Insert(segmentID, &sealedBm25Stats{
|
||||
bm25Stats: stats,
|
||||
ts: time.Now(),
|
||||
activate: atomic.NewBool(true),
|
||||
inmemory: true,
|
||||
segmentID: segmentID,
|
||||
})
|
||||
o.current.Merge(stats)
|
||||
}
|
||||
|
||||
func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.SegmentState) {
|
||||
switch state {
|
||||
case segments.SegmentTypeGrowing:
|
||||
@ -300,13 +327,19 @@ func (o *idfOracle) Register(segmentID int64, stats bm25Stats, state commonpb.Se
|
||||
if ok := o.sealed.Contain(segmentID); ok {
|
||||
return
|
||||
}
|
||||
o.sealed.Insert(segmentID, &sealedBm25Stats{
|
||||
bm25Stats: stats,
|
||||
ts: time.Now(),
|
||||
activate: atomic.NewBool(false),
|
||||
inmemory: true,
|
||||
segmentID: segmentID,
|
||||
})
|
||||
|
||||
// preload sealed segment to channel before first target
|
||||
if o.targetVersion.Load() == 0 {
|
||||
o.preloadSealed(segmentID, stats)
|
||||
} else {
|
||||
o.sealed.Insert(segmentID, &sealedBm25Stats{
|
||||
bm25Stats: stats,
|
||||
ts: time.Now(),
|
||||
activate: atomic.NewBool(false),
|
||||
inmemory: true,
|
||||
segmentID: segmentID,
|
||||
})
|
||||
}
|
||||
default:
|
||||
log.Warn("register segment with unknown state", zap.String("stats", state.String()))
|
||||
return
|
||||
@ -447,6 +480,7 @@ func (o *idfOracle) localloop() {
|
||||
}
|
||||
|
||||
// WARN: SyncDistribution is not safe for concurrent use.
|
||||
// SyncDistribution sync current target to idf oracle.
|
||||
func (o *idfOracle) SyncDistribution() error {
|
||||
snapshot, snapshotTs := o.next.GetSnapshot()
|
||||
if snapshot.targetVersion <= o.targetVersion.Load() {
|
||||
@ -457,18 +491,20 @@ func (o *idfOracle) SyncDistribution() error {
|
||||
|
||||
sealedMap := map[int64]bool{} // sealed diff map, activate segment stats if true, and remove if not in map
|
||||
|
||||
// only remain current target segment and unknown version segment in snapshot.
|
||||
for _, item := range sealed {
|
||||
for _, segment := range item.Segments {
|
||||
if segment.Level == datapb.SegmentLevel_L0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if segment.TargetVersion == snapshot.targetVersion {
|
||||
switch segment.TargetVersion {
|
||||
case snapshot.targetVersion:
|
||||
sealedMap[segment.SegmentID] = true
|
||||
if !o.sealed.Contain(segment.SegmentID) {
|
||||
log.Warn("idf oracle lack some sealed segment", zap.Int64("segment", segment.SegmentID))
|
||||
}
|
||||
} else if segment.TargetVersion == unreadableTargetVersion {
|
||||
case unreadableTargetVersion:
|
||||
sealedMap[segment.SegmentID] = false
|
||||
}
|
||||
}
|
||||
@ -478,19 +514,24 @@ func (o *idfOracle) SyncDistribution() error {
|
||||
|
||||
var rangeErr error
|
||||
o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool {
|
||||
activate, ok := sealedMap[segmentID]
|
||||
statsActivate := stats.activate.Load()
|
||||
if ok && activate && !statsActivate {
|
||||
// segment was unreadable if in snapshot but not in target.
|
||||
intarget, insnap := sealedMap[segmentID]
|
||||
activate := stats.activate.Load()
|
||||
// activate segment if segment in target
|
||||
if insnap && intarget && !activate {
|
||||
stats, err := stats.FetchStats()
|
||||
if err != nil {
|
||||
rangeErr = err
|
||||
rangeErr = fmt.Errorf("fetch stats failed with error: %v", err)
|
||||
return false
|
||||
}
|
||||
diff.Merge(stats)
|
||||
} else if !ok && statsActivate {
|
||||
} else
|
||||
// deactivate segment if segment not in snapshot
|
||||
// or deactivate segment if segment is unreadable (only happens for preloaded segments)
|
||||
if (!insnap || (insnap && !intarget)) && activate {
|
||||
stats, err := stats.FetchStats()
|
||||
if err != nil {
|
||||
rangeErr = err
|
||||
rangeErr = fmt.Errorf("fetch stats failed with error: %v", err)
|
||||
return false
|
||||
}
|
||||
diff.Minus(stats)
|
||||
@ -518,22 +559,30 @@ func (o *idfOracle) SyncDistribution() error {
|
||||
|
||||
// remove sealed segment not in target
|
||||
o.sealed.Range(func(segmentID int64, stats *sealedBm25Stats) bool {
|
||||
activate, ok := sealedMap[segmentID]
|
||||
statsActivate := stats.activate.Load()
|
||||
if !ok && stats.ts.Before(snapshotTs) {
|
||||
intarget, insnap := sealedMap[segmentID]
|
||||
activate := stats.activate.Load()
|
||||
// remove if segment not in snapshot
|
||||
// and add before snapshot
|
||||
if !insnap && stats.ts.Before(snapshotTs) {
|
||||
stats.Remove()
|
||||
o.sealed.Remove(segmentID)
|
||||
}
|
||||
|
||||
if ok && activate && !statsActivate {
|
||||
// save activate if segment in target.
|
||||
if insnap && intarget && !activate {
|
||||
stats.activate.Store(true)
|
||||
}
|
||||
|
||||
// deactivate if segment unreadable.
|
||||
if insnap && !intarget && activate {
|
||||
stats.activate.Store(false)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
o.targetVersion.Store(snapshot.targetVersion)
|
||||
o.NotifyLocal()
|
||||
log.Ctx(context.TODO()).Info("sync distribution finished", zap.Int64("version", snapshot.targetVersion), zap.Int64("numrow", o.current.NumRow()), zap.Int("growing", len(o.growing)), zap.Int("sealed", o.sealed.Len()))
|
||||
log.Ctx(context.TODO()).Info("sync idf distribution finished", zap.Int64("version", snapshot.targetVersion), zap.Int64("numrow", o.current.NumRow()), zap.Int("growing", len(o.growing)), zap.Int("sealed", o.sealed.Len()))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -136,14 +136,21 @@ func (suite *IDFOracleSuite) TestSealed() {
|
||||
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
|
||||
}
|
||||
|
||||
// register sealed segment but all deactvate
|
||||
suite.Zero(suite.idfOracle.current.NumRow())
|
||||
// some sealed not in target
|
||||
invalidSealedSegs := []int64{5, 6}
|
||||
for _, segID := range invalidSealedSegs {
|
||||
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
|
||||
}
|
||||
|
||||
// update and sync snapshot make all sealed activate
|
||||
// register sealed segment and all preload to current
|
||||
suite.Equal(int64(len(sealedSegs)+len(invalidSealedSegs)), suite.idfOracle.current.NumRow())
|
||||
|
||||
// update and sync snapshot make all sealed in target activate
|
||||
// and deactivate the invalid sealed segments
|
||||
suite.updateSnapshot(sealedSegs, []int64{}, []int64{})
|
||||
suite.idfOracle.SetNext(suite.snapshot)
|
||||
suite.waitTargetVersion(suite.targetVersion)
|
||||
suite.Equal(int64(4), suite.idfOracle.current.NumRow())
|
||||
suite.Equal(int64(len(sealedSegs)), suite.idfOracle.current.NumRow())
|
||||
|
||||
releasedSeg := []int64{1, 2, 3}
|
||||
suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
|
||||
@ -213,11 +220,20 @@ func (suite *IDFOracleSuite) TestLocalCache() {
|
||||
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
|
||||
}
|
||||
|
||||
// update and sync snapshot make all sealed activate
|
||||
// some sealed not in target
|
||||
invalidSealedSegs := []int64{5, 6}
|
||||
for _, segID := range invalidSealedSegs {
|
||||
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
|
||||
}
|
||||
|
||||
// register sealed segment and all preload to current
|
||||
suite.Equal(int64(len(sealedSegs)+len(invalidSealedSegs)), suite.idfOracle.current.NumRow())
|
||||
|
||||
// update and sync snapshot make all sealed in target activate
|
||||
suite.updateSnapshot(sealedSegs, []int64{}, []int64{})
|
||||
suite.idfOracle.SetNext(suite.snapshot)
|
||||
suite.waitTargetVersion(suite.targetVersion)
|
||||
suite.Equal(int64(4), suite.idfOracle.current.NumRow())
|
||||
suite.Equal(int64(len(sealedSegs)), suite.idfOracle.current.NumRow())
|
||||
|
||||
suite.Require().Eventually(func() bool {
|
||||
allInLocal := true
|
||||
|
||||
@ -19,6 +19,7 @@ package storage
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"maps"
|
||||
"math"
|
||||
|
||||
@ -419,6 +420,32 @@ func (m *BM25Stats) Serialize() ([]byte, error) {
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *BM25Stats) SerializeToWriter(w io.Writer) error {
|
||||
if err := binary.Write(w, common.Endian, BM25VERSION); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, common.Endian, m.numRow); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, common.Endian, m.numToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for key, value := range m.rowsWithToken {
|
||||
if err := binary.Write(w, common.Endian, key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, common.Endian, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BM25Stats) Deserialize(bs []byte) error {
|
||||
buffer := bytes.NewBuffer(bs)
|
||||
dim := (len(bs) - 20) / 8
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user