fix: more stable recovery graceful closing and stable unittest (#42013)

issue: #41544

Signed-off-by: chyezh <chyezh@outlook.com>
Author: Zhen Ye, committed by GitHub on 2025-05-23 17:52:26 +08:00
Commit: 38c804fb01 (parent 252d49d01e)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
5 changed files with 97 additions and 56 deletions


@@ -1027,7 +1027,7 @@ func (sd *shardDelegator) SyncTargetVersion(
 }

 func (sd *shardDelegator) GetQueryView() *channelQueryView {
-	return sd.distribution.queryView
+	return sd.distribution.GetQueryView()
 }

 func (sd *shardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {


@@ -363,6 +363,14 @@ func (d *distribution) SyncTargetVersion(newVersion int64, partitions []int64, g
 	)
 }

+// GetQueryView returns the current query view.
+func (d *distribution) GetQueryView() *channelQueryView {
+	d.mut.RLock()
+	defer d.mut.RUnlock()
+	return d.queryView
+}
+
 // RemoveDistributions remove segments distributions and returns the clear signal channel.
 func (d *distribution) RemoveDistributions(sealedSegments []SegmentEntry, growingSegments []SegmentEntry) chan struct{} {
 	d.mut.Lock()
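Note on the two hunks above: the delegator previously returned sd.distribution.queryView without taking the distribution's lock, while the view is swapped under d.mut; routing the read through the new RLock-guarded GetQueryView is what makes this "more stable". A minimal sketch of that pattern, using simplified stand-in types rather than the real Milvus ones:

```go
package main

import (
	"fmt"
	"sync"
)

// Illustrative stand-ins for channelQueryView and distribution; the names
// mirror the diff, the types do not.
type channelQueryView struct{ version int64 }

type distribution struct {
	mut       sync.RWMutex
	queryView *channelQueryView
}

// GetQueryView reads the view under the read lock, so callers never observe
// a pointer that a concurrent writer is in the middle of replacing.
func (d *distribution) GetQueryView() *channelQueryView {
	d.mut.RLock()
	defer d.mut.RUnlock()
	return d.queryView
}

// SyncTargetVersion stands in for the writer side that replaces the view.
func (d *distribution) SyncTargetVersion(v int64) {
	d.mut.Lock()
	defer d.mut.Unlock()
	d.queryView = &channelQueryView{version: v}
}

func main() {
	d := &distribution{queryView: &channelQueryView{version: 1}}
	d.SyncTargetVersion(2)
	fmt.Println(d.GetQueryView().version) // 2
}
```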


@@ -20,6 +20,17 @@ import (
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
 )

+// isDirty checks if the recovery storage mem state is not consistent with the persisted recovery storage.
+func (rs *recoveryStorageImpl) isDirty() bool {
+	if rs.pendingPersistSnapshot != nil {
+		return true
+	}
+	rs.mu.Lock()
+	defer rs.mu.Unlock()
+	return rs.dirtyCounter > 0
+}
+
 // TODO: !!! all recovery persist operation should be a compare-and-swap operation to
 // promise there's only one consumer of wal.
 // But currently, we don't implement the CAS operation of meta interface.
@@ -28,6 +39,10 @@ func (rs *recoveryStorageImpl) backgroundTask() {
 	ticker := time.NewTicker(rs.cfg.persistInterval)
 	defer func() {
 		ticker.Stop()
+		rs.Logger().Info("recovery storage background task, perform a graceful exit...")
+		if err := rs.persistDritySnapshotWhenClosing(); err != nil {
+			rs.Logger().Warn("failed to persist dirty snapshot when closing", zap.Error(err))
+		}
 		rs.backgroundTaskNotifier.Finish(struct{}{})
 		rs.Logger().Info("recovery storage background task exit")
 	}()
@@ -35,20 +50,11 @@ func (rs *recoveryStorageImpl) backgroundTask() {
 	for {
 		select {
 		case <-rs.backgroundTaskNotifier.Context().Done():
-			// If the background task is exiting when on-operating persist operation,
-			// We can try to do a graceful exit.
-			rs.Logger().Info("recovery storage background task, perform a graceful exit...")
-			if err := rs.persistDritySnapshotWhenClosing(); err != nil {
-				rs.Logger().Warn("failed to persist dirty snapshot when closing", zap.Error(err))
-				return
-			}
-			rs.gracefulClosed = true
-			return
+			return // exit the background task
 		case <-rs.persistNotifier:
 		case <-ticker.C:
 		}
-		snapshot := rs.consumeDirtySnapshot()
-		if err := rs.persistDirtySnapshot(rs.backgroundTaskNotifier.Context(), snapshot, zap.DebugLevel); err != nil {
+		if err := rs.persistDirtySnapshot(rs.backgroundTaskNotifier.Context(), zap.DebugLevel); err != nil {
			return
 		}
 	}
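The two hunks above move the graceful-close work out of the ctx.Done() branch and into the deferred cleanup, so it also runs when the loop exits early because a routine persist failed. A runnable sketch of that structure, with illustrative stand-in types rather than the actual recoveryStorageImpl:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// task is a stand-in for the recovery storage; only the loop shape matters here.
type task struct {
	persistErr error // simulated failure of the periodic persist
}

func (t *task) persistDirty(ctx context.Context) error { return t.persistErr }

func (t *task) persistWhenClosing() error {
	fmt.Println("graceful exit: flushing dirty state")
	return nil
}

func (t *task) run(ctx context.Context) {
	ticker := time.NewTicker(10 * time.Millisecond)
	defer func() {
		ticker.Stop()
		// Runs on every exit path, unlike cleanup placed only inside the
		// ctx.Done() case.
		if err := t.persistWhenClosing(); err != nil {
			fmt.Println("failed to persist on close:", err)
		}
	}()
	for {
		select {
		case <-ctx.Done():
			return // exit; cleanup happens in the defer
		case <-ticker.C:
		}
		if err := t.persistDirty(ctx); err != nil {
			return // early error exit still reaches the defer
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
	defer cancel()
	(&task{persistErr: errors.New("catalog unavailable")}).run(ctx)
}
```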
@@ -59,12 +65,26 @@ func (rs *recoveryStorageImpl) persistDritySnapshotWhenClosing() error {
 	ctx, cancel := context.WithTimeout(context.Background(), rs.cfg.gracefulTimeout)
 	defer cancel()

-	snapshot := rs.consumeDirtySnapshot()
-	return rs.persistDirtySnapshot(ctx, snapshot, zap.InfoLevel)
+	for rs.isDirty() {
+		if err := rs.persistDirtySnapshot(ctx, zap.InfoLevel); err != nil {
+			return err
+		}
+	}
+	rs.gracefulClosed = true
+	return nil
 }

 // persistDirtySnapshot persists the dirty snapshot to the catalog.
-func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, snapshot *RecoverySnapshot, lvl zapcore.Level) (err error) {
+func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, lvl zapcore.Level) (err error) {
+	if rs.pendingPersistSnapshot == nil {
+		// if there's no dirty snapshot, generate a new one.
+		rs.pendingPersistSnapshot = rs.consumeDirtySnapshot()
+	}
+	if rs.pendingPersistSnapshot == nil {
+		return nil
+	}
+	snapshot := rs.pendingPersistSnapshot
+
 	rs.metrics.ObserveIsOnPersisting(true)
 	logger := rs.Logger().With(
 		zap.String("checkpoint", snapshot.Checkpoint.MessageID.String()),
@@ -77,8 +97,9 @@ func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, snapsho
 			logger.Warn("failed to persist dirty snapshot", zap.Error(err))
 			return
 		}
+		rs.pendingPersistSnapshot = nil
 		logger.Log(lvl, "persist dirty snapshot")
-		defer rs.metrics.ObserveIsOnPersisting(false)
+		rs.metrics.ObserveIsOnPersisting(false)
 	}()

 	if err := rs.dropAllVirtualChannel(ctx, snapshot.VChannels); err != nil {
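These hunks are the core of the fix: persistDirtySnapshot now consumes a new snapshot only when no earlier one is still pending, and clears pendingPersistSnapshot only after the save succeeds, so a failed save cannot drop dirty state; the closing path then loops until isDirty reports clean before marking gracefulClosed. A condensed, runnable sketch under those assumptions (simplified types, a fake failure flag in place of the real catalog):

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// snapshot and recovery are stand-ins; field and method names follow the diff.
type snapshot struct{ dirtyMessages int }

type recovery struct {
	mu                     sync.Mutex
	dirtyCounter           int
	pendingPersistSnapshot *snapshot
	failNextSave           bool // hook to simulate a catalog error
	gracefulClosed         bool
}

func (r *recovery) isDirty() bool {
	if r.pendingPersistSnapshot != nil {
		return true
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.dirtyCounter > 0
}

// consumeDirtySnapshot returns nil when there is nothing to persist.
func (r *recovery) consumeDirtySnapshot() *snapshot {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.dirtyCounter == 0 {
		return nil
	}
	s := &snapshot{dirtyMessages: r.dirtyCounter}
	r.dirtyCounter = 0
	return s
}

func (r *recovery) persistDirtySnapshot() error {
	if r.pendingPersistSnapshot == nil {
		r.pendingPersistSnapshot = r.consumeDirtySnapshot()
	}
	if r.pendingPersistSnapshot == nil {
		return nil // clean, nothing to do
	}
	if r.failNextSave {
		r.failNextSave = false
		return errors.New("save failed") // snapshot stays pending, not lost
	}
	fmt.Println("persisted snapshot covering", r.pendingPersistSnapshot.dirtyMessages, "messages")
	r.pendingPersistSnapshot = nil // cleared only after a successful save
	return nil
}

// persistWhenClosing mirrors the closing loop in the diff: keep persisting
// until the in-memory state matches the catalog, then mark graceful.
func (r *recovery) persistWhenClosing() error {
	for r.isDirty() {
		if err := r.persistDirtySnapshot(); err != nil {
			return err
		}
	}
	r.gracefulClosed = true
	return nil
}

func main() {
	r := &recovery{dirtyCounter: 3, failNextSave: true}
	fmt.Println("first attempt:", r.persistDirtySnapshot()) // fails, state kept
	fmt.Println("still dirty:", r.isDirty())
	fmt.Println("close:", r.persistWhenClosing(), "graceful:", r.gracefulClosed)
}
```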


@@ -89,6 +89,7 @@ type recoveryStorageImpl struct {
 	gracefulClosed bool
 	truncator *samplingTruncator
 	metrics *recoveryMetrics
+	pendingPersistSnapshot *RecoverySnapshot
 }

 // UpdateFlusherCheckpoint updates the checkpoint of flusher.
@@ -134,6 +135,9 @@ func (r *recoveryStorageImpl) notifyPersist() {
 func (r *recoveryStorageImpl) consumeDirtySnapshot() *RecoverySnapshot {
 	r.mu.Lock()
 	defer r.mu.Unlock()
+	if r.dirtyCounter == 0 {
+		return nil
+	}

 	segments := make(map[int64]*streamingpb.SegmentAssignmentMeta)
 	vchannels := make(map[string]*streamingpb.VChannelMeta)
@@ -178,16 +182,11 @@ func (r *recoveryStorageImpl) observeMessage(msg message.ImmutableMessage) {
 	}

 	r.handleMessage(msg)
-	checkpointUpdates := !r.checkpoint.MessageID.EQ(msg.LastConfirmedMessageID())
 	r.checkpoint.TimeTick = msg.TimeTick()
 	r.checkpoint.MessageID = msg.LastConfirmedMessageID()
 	r.metrics.ObServeInMemMetrics(r.checkpoint.TimeTick)
-	if checkpointUpdates {
-		// only count the dirty if last confirmed message id is updated.
-		// we always recover from that point, the writeaheadtimetick is just a redundant information.
-		r.dirtyCounter++
-	}
+	r.dirtyCounter++
 	if r.dirtyCounter > r.cfg.maxDirtyMessages {
 		r.notifyPersist()
 	}
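With the checkpointUpdates guard removed, every observed message now counts toward dirtyCounter and a persist is requested once the configured threshold is crossed; combined with consumeDirtySnapshot returning nil at zero, idle persist ticks become no-ops. A small illustrative sketch of that bookkeeping (names mirror the diff, types are stand-ins):

```go
package main

import "fmt"

// tracker is a stand-in for the dirty-count bookkeeping in the recovery storage.
type tracker struct {
	dirtyCounter     int
	maxDirtyMessages int
	persistNotifier  chan struct{}
}

// notifyPersist requests a persist without blocking if one is already queued.
func (t *tracker) notifyPersist() {
	select {
	case t.persistNotifier <- struct{}{}:
	default:
	}
}

func (t *tracker) observeMessage() {
	t.dirtyCounter++ // the diff drops the "only if the checkpoint moved" guard
	if t.dirtyCounter > t.maxDirtyMessages {
		t.notifyPersist()
	}
}

func main() {
	t := &tracker{maxDirtyMessages: 2, persistNotifier: make(chan struct{}, 1)}
	for i := 0; i < 5; i++ {
		t.observeMessage()
	}
	fmt.Println("dirty:", t.dirtyCounter, "persist requested:", len(t.persistNotifier) == 1)
}
```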


@@ -10,6 +10,7 @@ import (
 	"github.com/samber/lo"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
+	"google.golang.org/protobuf/proto"

 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus/internal/mocks"
@@ -45,15 +46,17 @@ func TestRecoveryStorage(t *testing.T) {
 	snCatalog := mock_metastore.NewMockStreamingNodeCataLog(t)
 	snCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.SegmentAssignmentMeta, error) {
-		return lo.Values(segmentMetas), nil
+		return lo.MapToSlice(segmentMetas, func(_ int64, v *streamingpb.SegmentAssignmentMeta) *streamingpb.SegmentAssignmentMeta {
+			return proto.Clone(v).(*streamingpb.SegmentAssignmentMeta)
+		}), nil
 	})
 	snCatalog.EXPECT().ListVChannel(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.VChannelMeta, error) {
-		return lo.Values(vchannelMetas), nil
+		return lo.MapToSlice(vchannelMetas, func(_ string, v *streamingpb.VChannelMeta) *streamingpb.VChannelMeta {
+			return proto.Clone(v).(*streamingpb.VChannelMeta)
+		}), nil
 	})
-	segmentSaveFailure := true
 	snCatalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[int64]*streamingpb.SegmentAssignmentMeta) error {
-		if segmentSaveFailure {
-			segmentSaveFailure = false
+		if rand.Int31n(3) == 0 {
 			return errors.New("save failed")
 		}
 		for _, v := range m {
@@ -65,10 +68,8 @@ func TestRecoveryStorage(t *testing.T) {
 		}
 		return nil
 	})
-	vchannelSaveFailure := true
 	snCatalog.EXPECT().SaveVChannels(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[string]*streamingpb.VChannelMeta) error {
-		if vchannelSaveFailure {
-			vchannelSaveFailure = false
+		if rand.Int31n(3) == 0 {
 			return errors.New("save failed")
 		}
 		for _, v := range m {
@@ -81,10 +82,8 @@ func TestRecoveryStorage(t *testing.T) {
 		return nil
 	})
 	snCatalog.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(cp, nil)
-	checkpointSaveFailure := true
 	snCatalog.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pchannelName string, checkpoint *streamingpb.WALCheckpoint) error {
-		if checkpointSaveFailure {
-			checkpointSaveFailure = false
+		if rand.Int31n(3) == 0 {
 			return errors.New("save failed")
 		}
 		cp = checkpoint
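The mock-catalog changes above make the test sturdier in two ways: list calls hand back proto.Clone copies so the storage under test can never mutate the test's reference maps, and saves now fail at a random one-in-three rate instead of exactly once, so retries get exercised on arbitrary schedules. A plain-struct sketch of the same two ideas (no mockery or streamingpb types, so it stays self-contained):

```go
package main

import (
	"errors"
	"fmt"
	"math/rand"
)

// segmentMeta stands in for streamingpb.SegmentAssignmentMeta.
type segmentMeta struct {
	ID    int64
	State string
}

type fakeCatalog struct {
	segments map[int64]*segmentMeta // the test's reference copy
}

// ListSegments returns deep copies, mirroring the proto.Clone calls in the diff.
func (c *fakeCatalog) ListSegments() []*segmentMeta {
	out := make([]*segmentMeta, 0, len(c.segments))
	for _, v := range c.segments {
		clone := *v
		out = append(out, &clone)
	}
	return out
}

// SaveSegments fails roughly one time in three, like rand.Int31n(3) == 0.
func (c *fakeCatalog) SaveSegments(m map[int64]*segmentMeta) error {
	if rand.Int31n(3) == 0 {
		return errors.New("save failed")
	}
	for id, v := range m {
		clone := *v
		c.segments[id] = &clone
	}
	return nil
}

func main() {
	c := &fakeCatalog{segments: map[int64]*segmentMeta{1: {ID: 1, State: "growing"}}}
	listed := c.ListSegments()
	listed[0].State = "mutated"      // does not touch the catalog's copy
	fmt.Println(c.segments[1].State) // still "growing"
	fmt.Println(c.SaveSegments(map[int64]*segmentMeta{2: {ID: 2, State: "flushed"}}))
}
```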
@@ -151,10 +150,21 @@ func TestRecoveryStorage(t *testing.T) {
 	assert.Equal(t, partitionNum, b.partitionNum())
 	assert.Equal(t, collectionNum, b.collectionNum())
 	assert.Equal(t, segmentNum, b.segmentNum())
+	if b.segmentNum() != segmentNum {
+		t.Logf("segmentNum: %d, b.segmentNum: %d", segmentNum, b.segmentNum())
+	}
 	if rs.gracefulClosed {
 		// only available when graceful closing
 		assert.Equal(t, b.collectionNum(), len(vchannelMetas))
+		if b.collectionNum() != len(vchannelMetas) {
+			for _, v := range vchannelMetas {
+				t.Logf("vchannel: %s, state: %s", v.Vchannel, v.State)
+			}
+			for id := range b.collectionIDs {
+				t.Logf("collectionID: %d, %s", id, b.vchannels[id])
+			}
+		}
 		partitionNum := 0
 		for _, v := range vchannelMetas {
 			partitionNum += len(v.CollectionInfo.Partitions)
@@ -245,25 +255,28 @@ func (b *streamBuilder) Build(param BuildRecoveryStreamParam) RecoveryStream {
 }

 func (b *streamBuilder) generateStreamMessage() []message.ImmutableMessage {
-	ops := []func() message.ImmutableMessage{
-		b.createCollection,
-		b.createPartition,
-		b.createSegment,
-		b.createSegment,
-		b.dropCollection,
-		b.dropPartition,
-		b.flushSegment,
-		b.flushSegment,
-		b.createInsert,
-		b.createInsert,
-		b.createInsert,
-		b.createDelete,
-		b.createDelete,
-		b.createDelete,
-		b.createTxn,
-		b.createTxn,
-		b.createManualFlush,
-		b.createSchemaChange,
+	type opRate struct {
+		op   func() message.ImmutableMessage
+		rate int
+	}
+	opRates := []opRate{
+		{op: b.createCollection, rate: 1},
+		{op: b.createPartition, rate: 1},
+		{op: b.dropCollection, rate: 1},
+		{op: b.dropPartition, rate: 1},
+		{op: b.createSegment, rate: 2},
+		{op: b.flushSegment, rate: 2},
+		{op: b.createInsert, rate: 5},
+		{op: b.createDelete, rate: 5},
+		{op: b.createTxn, rate: 5},
+		{op: b.createManualFlush, rate: 2},
+		{op: b.createSchemaChange, rate: 1},
+	}
+	ops := make([]func() message.ImmutableMessage, 0)
+	for _, opRate := range opRates {
+		for i := 0; i < opRate.rate; i++ {
+			ops = append(ops, opRate.op)
+		}
 	}
 	msgs := make([]message.ImmutableMessage, 0)
 	for i := 0; i < int(rand.Int63n(1000)+1000); i++ {
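The generator rewrite above replaces a hand-duplicated op slice with explicit per-op rates that get expanded into the sampling pool. A stripped-down sketch of that weighting scheme (string names instead of the real message constructors):

```go
package main

import (
	"fmt"
	"math/rand"
)

// opRate pairs an operation with its relative weight, as in the diff.
type opRate struct {
	name string
	rate int
}

// expand repeats each op according to its rate, building the pool that the
// generator samples from uniformly.
func expand(opRates []opRate) []string {
	ops := make([]string, 0)
	for _, or := range opRates {
		for i := 0; i < or.rate; i++ {
			ops = append(ops, or.name)
		}
	}
	return ops
}

func main() {
	ops := expand([]opRate{
		{name: "createCollection", rate: 1},
		{name: "createInsert", rate: 5},
		{name: "flushSegment", rate: 2},
	})
	// Uniform picks from the expanded slice make "createInsert" roughly five
	// times as frequent as "createCollection".
	fmt.Println(ops[rand.Intn(len(ops))], "out of", len(ops), "weighted entries")
}
```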