fix: avoid panic when loading a segment whose pkoracle and idforacle already exist (#36959)

relate: https://github.com/milvus-io/milvus/issues/36949

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
Authored by aoiasd on 2024-10-18 11:57:24 +08:00, committed by GitHub
parent 50da48a30d
commit fbe177d6e7
4 changed files with 19 additions and 13 deletions


@@ -913,10 +913,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
     excludedSegments := NewExcludedSegments(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.GetAsDuration(time.Second))
-    var idfOracle IDFOracle
-    if len(collection.Schema().GetFunctions()) > 0 {
-        idfOracle = NewIDFOracle(collection.Schema().GetFunctions())
-    }
     sd := &shardDelegator{
         collectionID: collectionID,
         replicaID:    replicaID,
@@ -926,7 +922,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
         segmentManager: manager.Segment,
         workerManager:  workerManager,
         lifetime:       lifetime.NewLifetime(lifetime.Initializing),
-        distribution:   NewDistribution(idfOracle),
+        distribution:   NewDistribution(),
         deleteBuffer:   deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock),
         pkOracle:       pkoracle.NewPkOracle(),
         tsafeManager:   tsafeManager,
@@ -935,7 +931,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
         factory:          factory,
         queryHook:        queryHook,
         chunkManager:     chunkManager,
-        idfOracle:        idfOracle,
         partitionStats:   make(map[UniqueID]*storage.PartitionStatsSnapshot),
         excludedSegments: excludedSegments,
         functionRunners:  make(map[int64]function.FunctionRunner),
@@ -955,6 +950,11 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
         }
     }
+
+    if len(sd.isBM25Field) > 0 {
+        sd.idfOracle = NewIDFOracle(collection.Schema().GetFunctions())
+        sd.distribution.SetIDFOracle(sd.idfOracle)
+    }
     m := sync.Mutex{}
     sd.tsCond = sync.NewCond(&m)
     if sd.lifetime.Add(lifetime.NotStopped) == nil {
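
Note on the hunks above: besides moving the oracle construction after the struct literal, the trigger changes from len(collection.Schema().GetFunctions()) > 0 to len(sd.isBM25Field) > 0, i.e. from "the schema declares any function" to "the schema has BM25 output fields". Below is a rough, self-contained illustration of why those two conditions can disagree; every type in it is a hypothetical stand-in rather than a milvus schema type, and it assumes isBM25Field is populated only from BM25-type functions, as the field name suggests.

package main

import "fmt"

// functionSchema is a hypothetical, simplified stand-in for a schema function entry;
// it only carries what the two conditions need.
type functionSchema struct {
    Name          string
    Type          string // e.g. "BM25", "TextEmbedding"
    OutputFieldID int64
}

func main() {
    // A collection whose schema declares a function that is NOT BM25.
    functions := []functionSchema{{Name: "embed", Type: "TextEmbedding", OutputFieldID: 101}}

    // Old trigger: any function at all -> build the IDF oracle.
    oldCondition := len(functions) > 0

    // New trigger, mirroring the diff: only output fields of BM25 functions count.
    isBM25Field := make(map[int64]bool)
    for _, f := range functions {
        if f.Type == "BM25" {
            isBM25Field[f.OutputFieldID] = true
        }
    }
    newCondition := len(isBM25Field) > 0

    fmt.Println("old condition (any function):", oldCondition) // true
    fmt.Println("new condition (BM25 fields): ", newCondition) // false
}

When the two disagree (functions present, none of them BM25), the old code built an IDF oracle that would never receive stats, which appears to be the situation the nil checks in the next file guard against.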


@@ -487,7 +487,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
     })
     var bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
-    if len(sd.isBM25Field) > 0 {
+    if sd.idfOracle != nil {
         bm25Stats, err = sd.loader.LoadBM25Stats(ctx, req.GetCollectionID(), infos...)
         if err != nil {
             log.Warn("failed to load bm25 stats for segment", zap.Error(err))
@@ -690,8 +690,11 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
         sd.pkOracle.Register(candidate, targetNodeID)
     }
-    if sd.idfOracle != nil {
+    if sd.idfOracle != nil && bm25Stats != nil {
         bm25Stats.Range(func(segmentID int64, stats map[int64]*storage.BM25Stats) bool {
+            log.Info("register sealed segment bm25 stats into idforacle",
+                zap.Int64("segmentID", segmentID),
+            )
             sd.idfOracle.Register(segmentID, stats, segments.SegmentTypeSealed)
             return false
         })
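
The second hunk above tightens the registration guard: a non-nil idfOracle is no longer sufficient, the loaded bm25Stats map must also be non-nil before Range is called, and the new log line records each sealed segment as its stats are registered. A small self-contained sketch of what the extra check prevents, using hypothetical stand-in types rather than the real typeutil.ConcurrentMap and IDFOracle:

package main

import "fmt"

// statsMap is a hypothetical stand-in for the concurrent BM25 stats map in the diff,
// not the real typeutil.ConcurrentMap.
type statsMap struct {
    inner map[int64]string
}

func (m *statsMap) Range(fn func(id int64, v string) bool) {
    // Reading m.inner through a nil *statsMap would panic with a nil pointer dereference.
    for id, v := range m.inner {
        if !fn(id, v) {
            return
        }
    }
}

// oracle is a hypothetical stand-in for the IDF oracle.
type oracle struct{}

func (o *oracle) Register(id int64, v string) { fmt.Println("register", id, v) }

func main() {
    var idfOracle *oracle   // nil unless BM25 is enabled for the collection
    var bm25Stats *statsMap // nil unless stats were actually loaded

    // Mirrors the new condition in loadStreamDelete: both values must be present
    // before ranging, so a present oracle with absent stats no longer crashes.
    if idfOracle != nil && bm25Stats != nil {
        bm25Stats.Range(func(id int64, v string) bool {
            idfOracle.Register(id, v)
            return true
        })
    }
    fmt.Println("registration skipped, nothing to panic on")
}

Calling Range through a nil stats pointer would dereference nil and panic, which is one plausible source of the crash the commit title refers to; checking both values makes the skip explicit.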


@@ -91,7 +91,7 @@ type SegmentEntry struct {
 }
 
 // NewDistribution creates a new distribution instance with all field initialized.
-func NewDistribution(idfOracle IDFOracle) *distribution {
+func NewDistribution() *distribution {
     dist := &distribution{
         serviceable:     atomic.NewBool(false),
         growingSegments: make(map[UniqueID]SegmentEntry),
@@ -100,13 +100,18 @@ func NewDistribution(idfOracle IDFOracle) *distribution {
         current:         atomic.NewPointer[snapshot](nil),
         offlines:        typeutil.NewSet[int64](),
         targetVersion:   atomic.NewInt64(initialTargetVersion),
-        idfOracle:       idfOracle,
     }
 
     dist.genSnapshot()
     return dist
 }
 
+func (d *distribution) SetIDFOracle(idfOracle IDFOracle) {
+    d.mut.Lock()
+    defer d.mut.Unlock()
+    d.idfOracle = idfOracle
+}
+
 func (d *distribution) PinReadableSegments(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64, err error) {
     d.mut.RLock()
     defer d.mut.RUnlock()
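
With the constructor parameter removed, the oracle reaches the distribution only through the SetIDFOracle setter shown above. A minimal self-contained sketch of that injection pattern, again with hypothetical stand-in types rather than the real distribution and IDFOracle:

package main

import (
    "fmt"
    "sync"
)

// Hypothetical stand-ins, not the milvus types: a minimal distribution that accepts
// an optional IDF oracle after construction, mirroring the SetIDFOracle setter above.
type idfOracle interface {
    Register(segmentID int64)
}

type distribution struct {
    mut       sync.RWMutex
    idfOracle idfOracle
}

func newDistribution() *distribution { return &distribution{} }

// SetIDFOracle installs the optional dependency under the writer lock, so readers
// holding the read lock never observe a partially published field.
func (d *distribution) SetIDFOracle(o idfOracle) {
    d.mut.Lock()
    defer d.mut.Unlock()
    d.idfOracle = o
}

type printOracle struct{}

func (printOracle) Register(segmentID int64) { fmt.Println("registered segment", segmentID) }

func main() {
    d := newDistribution()        // usable right away, with no oracle attached
    d.SetIDFOracle(printOracle{}) // injected later, only when BM25 is in play

    d.mut.RLock()
    if d.idfOracle != nil {
        d.idfOracle.Register(1)
    }
    d.mut.RUnlock()
}

Taking the writer lock in the setter pairs with the read-locked access paths such as PinReadableSegments, and keeping NewDistribution oracle-free forces every caller to treat the oracle as optional, which is what the nil checks in the delegator rely on.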


@@ -21,8 +21,6 @@ import (
     "time"
 
     "github.com/stretchr/testify/suite"
-
-    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 )
 
 type DistributionSuite struct {
@@ -31,7 +29,7 @@ type DistributionSuite struct {
 }
 
 func (s *DistributionSuite) SetupTest() {
-    s.dist = NewDistribution(NewIDFOracle([]*schemapb.FunctionSchema{}))
+    s.dist = NewDistribution()
     s.Equal(initialTargetVersion, s.dist.getTargetVersion())
 }