From 38c804fb017b397f1e033b4683f0c07447dc5dfe Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Fri, 23 May 2025 17:52:26 +0800 Subject: [PATCH] fix: more stable recovery graceful closing and stable unittest (#42013) issue: #41544 Signed-off-by: chyezh --- .../querynodev2/delegator/delegator_data.go | 2 +- .../querynodev2/delegator/distribution.go | 8 ++ .../wal/recovery/recovery_background_task.go | 51 +++++++++---- .../wal/recovery/recovery_storage_impl.go | 19 +++-- .../wal/recovery/recovery_storage_test.go | 73 +++++++++++-------- 5 files changed, 97 insertions(+), 56 deletions(-) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index a612abec28..6d875acbeb 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -1027,7 +1027,7 @@ func (sd *shardDelegator) SyncTargetVersion( } func (sd *shardDelegator) GetQueryView() *channelQueryView { - return sd.distribution.queryView + return sd.distribution.GetQueryView() } func (sd *shardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) { diff --git a/internal/querynodev2/delegator/distribution.go b/internal/querynodev2/delegator/distribution.go index 1d5bb23181..dc4fd16b90 100644 --- a/internal/querynodev2/delegator/distribution.go +++ b/internal/querynodev2/delegator/distribution.go @@ -363,6 +363,14 @@ func (d *distribution) SyncTargetVersion(newVersion int64, partitions []int64, g ) } +// GetQueryView returns the current query view. +func (d *distribution) GetQueryView() *channelQueryView { + d.mut.RLock() + defer d.mut.RUnlock() + + return d.queryView +} + // RemoveDistributions remove segments distributions and returns the clear signal channel. func (d *distribution) RemoveDistributions(sealedSegments []SegmentEntry, growingSegments []SegmentEntry) chan struct{} { d.mut.Lock() diff --git a/internal/streamingnode/server/wal/recovery/recovery_background_task.go b/internal/streamingnode/server/wal/recovery/recovery_background_task.go index e912bdd81f..9e491320d1 100644 --- a/internal/streamingnode/server/wal/recovery/recovery_background_task.go +++ b/internal/streamingnode/server/wal/recovery/recovery_background_task.go @@ -20,6 +20,17 @@ import ( "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) +// isDirty checks if the recovery storage mem state is not consistent with the persisted recovery storage. +func (rs *recoveryStorageImpl) isDirty() bool { + if rs.pendingPersistSnapshot != nil { + return true + } + + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.dirtyCounter > 0 +} + // TODO: !!! all recovery persist operation should be a compare-and-swap operation to // promise there's only one consumer of wal. // But currently, we don't implement the CAS operation of meta interface. @@ -28,6 +39,10 @@ func (rs *recoveryStorageImpl) backgroundTask() { ticker := time.NewTicker(rs.cfg.persistInterval) defer func() { ticker.Stop() + rs.Logger().Info("recovery storage background task, perform a graceful exit...") + if err := rs.persistDritySnapshotWhenClosing(); err != nil { + rs.Logger().Warn("failed to persist dirty snapshot when closing", zap.Error(err)) + } rs.backgroundTaskNotifier.Finish(struct{}{}) rs.Logger().Info("recovery storage background task exit") }() @@ -35,20 +50,11 @@ func (rs *recoveryStorageImpl) backgroundTask() { for { select { case <-rs.backgroundTaskNotifier.Context().Done(): - // If the background task is exiting when on-operating persist operation, - // We can try to do a graceful exit. - rs.Logger().Info("recovery storage background task, perform a graceful exit...") - if err := rs.persistDritySnapshotWhenClosing(); err != nil { - rs.Logger().Warn("failed to persist dirty snapshot when closing", zap.Error(err)) - return - } - rs.gracefulClosed = true - return // exit the background task + return case <-rs.persistNotifier: case <-ticker.C: } - snapshot := rs.consumeDirtySnapshot() - if err := rs.persistDirtySnapshot(rs.backgroundTaskNotifier.Context(), snapshot, zap.DebugLevel); err != nil { + if err := rs.persistDirtySnapshot(rs.backgroundTaskNotifier.Context(), zap.DebugLevel); err != nil { return } } @@ -59,12 +65,26 @@ func (rs *recoveryStorageImpl) persistDritySnapshotWhenClosing() error { ctx, cancel := context.WithTimeout(context.Background(), rs.cfg.gracefulTimeout) defer cancel() - snapshot := rs.consumeDirtySnapshot() - return rs.persistDirtySnapshot(ctx, snapshot, zap.InfoLevel) + for rs.isDirty() { + if err := rs.persistDirtySnapshot(ctx, zap.InfoLevel); err != nil { + return err + } + } + rs.gracefulClosed = true + return nil } // persistDirtySnapshot persists the dirty snapshot to the catalog. -func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, snapshot *RecoverySnapshot, lvl zapcore.Level) (err error) { +func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, lvl zapcore.Level) (err error) { + if rs.pendingPersistSnapshot == nil { + // if there's no dirty snapshot, generate a new one. + rs.pendingPersistSnapshot = rs.consumeDirtySnapshot() + } + if rs.pendingPersistSnapshot == nil { + return nil + } + + snapshot := rs.pendingPersistSnapshot rs.metrics.ObserveIsOnPersisting(true) logger := rs.Logger().With( zap.String("checkpoint", snapshot.Checkpoint.MessageID.String()), @@ -77,8 +97,9 @@ func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, snapsho logger.Warn("failed to persist dirty snapshot", zap.Error(err)) return } + rs.pendingPersistSnapshot = nil logger.Log(lvl, "persist dirty snapshot") - defer rs.metrics.ObserveIsOnPersisting(false) + rs.metrics.ObserveIsOnPersisting(false) }() if err := rs.dropAllVirtualChannel(ctx, snapshot.VChannels); err != nil { diff --git a/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go b/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go index dbc6cdcfab..3587e15747 100644 --- a/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go +++ b/internal/streamingnode/server/wal/recovery/recovery_storage_impl.go @@ -85,10 +85,11 @@ type recoveryStorageImpl struct { flusherCheckpoint *WALCheckpoint dirtyCounter int // records the message count since last persist snapshot. // used to trigger the recovery persist operation. - persistNotifier chan struct{} - gracefulClosed bool - truncator *samplingTruncator - metrics *recoveryMetrics + persistNotifier chan struct{} + gracefulClosed bool + truncator *samplingTruncator + metrics *recoveryMetrics + pendingPersistSnapshot *RecoverySnapshot } // UpdateFlusherCheckpoint updates the checkpoint of flusher. @@ -134,6 +135,9 @@ func (r *recoveryStorageImpl) notifyPersist() { func (r *recoveryStorageImpl) consumeDirtySnapshot() *RecoverySnapshot { r.mu.Lock() defer r.mu.Unlock() + if r.dirtyCounter == 0 { + return nil + } segments := make(map[int64]*streamingpb.SegmentAssignmentMeta) vchannels := make(map[string]*streamingpb.VChannelMeta) @@ -178,16 +182,11 @@ func (r *recoveryStorageImpl) observeMessage(msg message.ImmutableMessage) { } r.handleMessage(msg) - checkpointUpdates := !r.checkpoint.MessageID.EQ(msg.LastConfirmedMessageID()) r.checkpoint.TimeTick = msg.TimeTick() r.checkpoint.MessageID = msg.LastConfirmedMessageID() r.metrics.ObServeInMemMetrics(r.checkpoint.TimeTick) - if checkpointUpdates { - // only count the dirty if last confirmed message id is updated. - // we always recover from that point, the writeaheadtimetick is just a redundant information. - r.dirtyCounter++ - } + r.dirtyCounter++ if r.dirtyCounter > r.cfg.maxDirtyMessages { r.notifyPersist() } diff --git a/internal/streamingnode/server/wal/recovery/recovery_storage_test.go b/internal/streamingnode/server/wal/recovery/recovery_storage_test.go index 7b7856d47c..75737eacc8 100644 --- a/internal/streamingnode/server/wal/recovery/recovery_storage_test.go +++ b/internal/streamingnode/server/wal/recovery/recovery_storage_test.go @@ -10,6 +10,7 @@ import ( "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/mocks" @@ -45,15 +46,17 @@ func TestRecoveryStorage(t *testing.T) { snCatalog := mock_metastore.NewMockStreamingNodeCataLog(t) snCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.SegmentAssignmentMeta, error) { - return lo.Values(segmentMetas), nil + return lo.MapToSlice(segmentMetas, func(_ int64, v *streamingpb.SegmentAssignmentMeta) *streamingpb.SegmentAssignmentMeta { + return proto.Clone(v).(*streamingpb.SegmentAssignmentMeta) + }), nil }) snCatalog.EXPECT().ListVChannel(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.VChannelMeta, error) { - return lo.Values(vchannelMetas), nil + return lo.MapToSlice(vchannelMetas, func(_ string, v *streamingpb.VChannelMeta) *streamingpb.VChannelMeta { + return proto.Clone(v).(*streamingpb.VChannelMeta) + }), nil }) - segmentSaveFailure := true snCatalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[int64]*streamingpb.SegmentAssignmentMeta) error { - if segmentSaveFailure { - segmentSaveFailure = false + if rand.Int31n(3) == 0 { return errors.New("save failed") } for _, v := range m { @@ -65,10 +68,8 @@ func TestRecoveryStorage(t *testing.T) { } return nil }) - vchannelSaveFailure := true snCatalog.EXPECT().SaveVChannels(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[string]*streamingpb.VChannelMeta) error { - if vchannelSaveFailure { - vchannelSaveFailure = false + if rand.Int31n(3) == 0 { return errors.New("save failed") } for _, v := range m { @@ -81,10 +82,8 @@ func TestRecoveryStorage(t *testing.T) { return nil }) snCatalog.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(cp, nil) - checkpointSaveFailure := true snCatalog.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pchannelName string, checkpoint *streamingpb.WALCheckpoint) error { - if checkpointSaveFailure { - checkpointSaveFailure = false + if rand.Int31n(3) == 0 { return errors.New("save failed") } cp = checkpoint @@ -151,10 +150,21 @@ func TestRecoveryStorage(t *testing.T) { assert.Equal(t, partitionNum, b.partitionNum()) assert.Equal(t, collectionNum, b.collectionNum()) assert.Equal(t, segmentNum, b.segmentNum()) + if b.segmentNum() != segmentNum { + t.Logf("segmentNum: %d, b.segmentNum: %d", segmentNum, b.segmentNum()) + } if rs.gracefulClosed { // only available when graceful closing assert.Equal(t, b.collectionNum(), len(vchannelMetas)) + if b.collectionNum() != len(vchannelMetas) { + for _, v := range vchannelMetas { + t.Logf("vchannel: %s, state: %s", v.Vchannel, v.State) + } + for id := range b.collectionIDs { + t.Logf("collectionID: %d, %s", id, b.vchannels[id]) + } + } partitionNum := 0 for _, v := range vchannelMetas { partitionNum += len(v.CollectionInfo.Partitions) @@ -245,25 +255,28 @@ func (b *streamBuilder) Build(param BuildRecoveryStreamParam) RecoveryStream { } func (b *streamBuilder) generateStreamMessage() []message.ImmutableMessage { - ops := []func() message.ImmutableMessage{ - b.createCollection, - b.createPartition, - b.createSegment, - b.createSegment, - b.dropCollection, - b.dropPartition, - b.flushSegment, - b.flushSegment, - b.createInsert, - b.createInsert, - b.createInsert, - b.createDelete, - b.createDelete, - b.createDelete, - b.createTxn, - b.createTxn, - b.createManualFlush, - b.createSchemaChange, + type opRate struct { + op func() message.ImmutableMessage + rate int + } + opRates := []opRate{ + {op: b.createCollection, rate: 1}, + {op: b.createPartition, rate: 1}, + {op: b.dropCollection, rate: 1}, + {op: b.dropPartition, rate: 1}, + {op: b.createSegment, rate: 2}, + {op: b.flushSegment, rate: 2}, + {op: b.createInsert, rate: 5}, + {op: b.createDelete, rate: 5}, + {op: b.createTxn, rate: 5}, + {op: b.createManualFlush, rate: 2}, + {op: b.createSchemaChange, rate: 1}, + } + ops := make([]func() message.ImmutableMessage, 0) + for _, opRate := range opRates { + for i := 0; i < opRate.rate; i++ { + ops = append(ops, opRate.op) + } } msgs := make([]message.ImmutableMessage, 0) for i := 0; i < int(rand.Int63n(1000)+1000); i++ {