fix: more stable recovery graceful closing and stable unittest (#42013)

issue: #41544

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-05-23 17:52:26 +08:00 committed by GitHub
parent 252d49d01e
commit 38c804fb01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 97 additions and 56 deletions

View File

@ -1027,7 +1027,7 @@ func (sd *shardDelegator) SyncTargetVersion(
}
func (sd *shardDelegator) GetQueryView() *channelQueryView {
return sd.distribution.queryView
return sd.distribution.GetQueryView()
}
func (sd *shardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {

View File

@ -363,6 +363,14 @@ func (d *distribution) SyncTargetVersion(newVersion int64, partitions []int64, g
)
}
// GetQueryView returns the current query view.
func (d *distribution) GetQueryView() *channelQueryView {
d.mut.RLock()
defer d.mut.RUnlock()
return d.queryView
}
// RemoveDistributions remove segments distributions and returns the clear signal channel.
func (d *distribution) RemoveDistributions(sealedSegments []SegmentEntry, growingSegments []SegmentEntry) chan struct{} {
d.mut.Lock()

View File

@ -20,6 +20,17 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// isDirty checks if the recovery storage mem state is not consistent with the persisted recovery storage.
func (rs *recoveryStorageImpl) isDirty() bool {
if rs.pendingPersistSnapshot != nil {
return true
}
rs.mu.Lock()
defer rs.mu.Unlock()
return rs.dirtyCounter > 0
}
// TODO: !!! all recovery persist operation should be a compare-and-swap operation to
// promise there's only one consumer of wal.
// But currently, we don't implement the CAS operation of meta interface.
@ -28,6 +39,10 @@ func (rs *recoveryStorageImpl) backgroundTask() {
ticker := time.NewTicker(rs.cfg.persistInterval)
defer func() {
ticker.Stop()
rs.Logger().Info("recovery storage background task, perform a graceful exit...")
if err := rs.persistDritySnapshotWhenClosing(); err != nil {
rs.Logger().Warn("failed to persist dirty snapshot when closing", zap.Error(err))
}
rs.backgroundTaskNotifier.Finish(struct{}{})
rs.Logger().Info("recovery storage background task exit")
}()
@ -35,20 +50,11 @@ func (rs *recoveryStorageImpl) backgroundTask() {
for {
select {
case <-rs.backgroundTaskNotifier.Context().Done():
// If the background task is exiting when on-operating persist operation,
// We can try to do a graceful exit.
rs.Logger().Info("recovery storage background task, perform a graceful exit...")
if err := rs.persistDritySnapshotWhenClosing(); err != nil {
rs.Logger().Warn("failed to persist dirty snapshot when closing", zap.Error(err))
return
}
rs.gracefulClosed = true
return // exit the background task
return
case <-rs.persistNotifier:
case <-ticker.C:
}
snapshot := rs.consumeDirtySnapshot()
if err := rs.persistDirtySnapshot(rs.backgroundTaskNotifier.Context(), snapshot, zap.DebugLevel); err != nil {
if err := rs.persistDirtySnapshot(rs.backgroundTaskNotifier.Context(), zap.DebugLevel); err != nil {
return
}
}
@ -59,12 +65,26 @@ func (rs *recoveryStorageImpl) persistDritySnapshotWhenClosing() error {
ctx, cancel := context.WithTimeout(context.Background(), rs.cfg.gracefulTimeout)
defer cancel()
snapshot := rs.consumeDirtySnapshot()
return rs.persistDirtySnapshot(ctx, snapshot, zap.InfoLevel)
for rs.isDirty() {
if err := rs.persistDirtySnapshot(ctx, zap.InfoLevel); err != nil {
return err
}
}
rs.gracefulClosed = true
return nil
}
// persistDirtySnapshot persists the dirty snapshot to the catalog.
func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, snapshot *RecoverySnapshot, lvl zapcore.Level) (err error) {
func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, lvl zapcore.Level) (err error) {
if rs.pendingPersistSnapshot == nil {
// if there's no dirty snapshot, generate a new one.
rs.pendingPersistSnapshot = rs.consumeDirtySnapshot()
}
if rs.pendingPersistSnapshot == nil {
return nil
}
snapshot := rs.pendingPersistSnapshot
rs.metrics.ObserveIsOnPersisting(true)
logger := rs.Logger().With(
zap.String("checkpoint", snapshot.Checkpoint.MessageID.String()),
@ -77,8 +97,9 @@ func (rs *recoveryStorageImpl) persistDirtySnapshot(ctx context.Context, snapsho
logger.Warn("failed to persist dirty snapshot", zap.Error(err))
return
}
rs.pendingPersistSnapshot = nil
logger.Log(lvl, "persist dirty snapshot")
defer rs.metrics.ObserveIsOnPersisting(false)
rs.metrics.ObserveIsOnPersisting(false)
}()
if err := rs.dropAllVirtualChannel(ctx, snapshot.VChannels); err != nil {

View File

@ -85,10 +85,11 @@ type recoveryStorageImpl struct {
flusherCheckpoint *WALCheckpoint
dirtyCounter int // records the message count since last persist snapshot.
// used to trigger the recovery persist operation.
persistNotifier chan struct{}
gracefulClosed bool
truncator *samplingTruncator
metrics *recoveryMetrics
persistNotifier chan struct{}
gracefulClosed bool
truncator *samplingTruncator
metrics *recoveryMetrics
pendingPersistSnapshot *RecoverySnapshot
}
// UpdateFlusherCheckpoint updates the checkpoint of flusher.
@ -134,6 +135,9 @@ func (r *recoveryStorageImpl) notifyPersist() {
func (r *recoveryStorageImpl) consumeDirtySnapshot() *RecoverySnapshot {
r.mu.Lock()
defer r.mu.Unlock()
if r.dirtyCounter == 0 {
return nil
}
segments := make(map[int64]*streamingpb.SegmentAssignmentMeta)
vchannels := make(map[string]*streamingpb.VChannelMeta)
@ -178,16 +182,11 @@ func (r *recoveryStorageImpl) observeMessage(msg message.ImmutableMessage) {
}
r.handleMessage(msg)
checkpointUpdates := !r.checkpoint.MessageID.EQ(msg.LastConfirmedMessageID())
r.checkpoint.TimeTick = msg.TimeTick()
r.checkpoint.MessageID = msg.LastConfirmedMessageID()
r.metrics.ObServeInMemMetrics(r.checkpoint.TimeTick)
if checkpointUpdates {
// only count the dirty if last confirmed message id is updated.
// we always recover from that point, the writeaheadtimetick is just a redundant information.
r.dirtyCounter++
}
r.dirtyCounter++
if r.dirtyCounter > r.cfg.maxDirtyMessages {
r.notifyPersist()
}

View File

@ -10,6 +10,7 @@ import (
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/mocks"
@ -45,15 +46,17 @@ func TestRecoveryStorage(t *testing.T) {
snCatalog := mock_metastore.NewMockStreamingNodeCataLog(t)
snCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.SegmentAssignmentMeta, error) {
return lo.Values(segmentMetas), nil
return lo.MapToSlice(segmentMetas, func(_ int64, v *streamingpb.SegmentAssignmentMeta) *streamingpb.SegmentAssignmentMeta {
return proto.Clone(v).(*streamingpb.SegmentAssignmentMeta)
}), nil
})
snCatalog.EXPECT().ListVChannel(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.VChannelMeta, error) {
return lo.Values(vchannelMetas), nil
return lo.MapToSlice(vchannelMetas, func(_ string, v *streamingpb.VChannelMeta) *streamingpb.VChannelMeta {
return proto.Clone(v).(*streamingpb.VChannelMeta)
}), nil
})
segmentSaveFailure := true
snCatalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[int64]*streamingpb.SegmentAssignmentMeta) error {
if segmentSaveFailure {
segmentSaveFailure = false
if rand.Int31n(3) == 0 {
return errors.New("save failed")
}
for _, v := range m {
@ -65,10 +68,8 @@ func TestRecoveryStorage(t *testing.T) {
}
return nil
})
vchannelSaveFailure := true
snCatalog.EXPECT().SaveVChannels(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[string]*streamingpb.VChannelMeta) error {
if vchannelSaveFailure {
vchannelSaveFailure = false
if rand.Int31n(3) == 0 {
return errors.New("save failed")
}
for _, v := range m {
@ -81,10 +82,8 @@ func TestRecoveryStorage(t *testing.T) {
return nil
})
snCatalog.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(cp, nil)
checkpointSaveFailure := true
snCatalog.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pchannelName string, checkpoint *streamingpb.WALCheckpoint) error {
if checkpointSaveFailure {
checkpointSaveFailure = false
if rand.Int31n(3) == 0 {
return errors.New("save failed")
}
cp = checkpoint
@ -151,10 +150,21 @@ func TestRecoveryStorage(t *testing.T) {
assert.Equal(t, partitionNum, b.partitionNum())
assert.Equal(t, collectionNum, b.collectionNum())
assert.Equal(t, segmentNum, b.segmentNum())
if b.segmentNum() != segmentNum {
t.Logf("segmentNum: %d, b.segmentNum: %d", segmentNum, b.segmentNum())
}
if rs.gracefulClosed {
// only available when graceful closing
assert.Equal(t, b.collectionNum(), len(vchannelMetas))
if b.collectionNum() != len(vchannelMetas) {
for _, v := range vchannelMetas {
t.Logf("vchannel: %s, state: %s", v.Vchannel, v.State)
}
for id := range b.collectionIDs {
t.Logf("collectionID: %d, %s", id, b.vchannels[id])
}
}
partitionNum := 0
for _, v := range vchannelMetas {
partitionNum += len(v.CollectionInfo.Partitions)
@ -245,25 +255,28 @@ func (b *streamBuilder) Build(param BuildRecoveryStreamParam) RecoveryStream {
}
func (b *streamBuilder) generateStreamMessage() []message.ImmutableMessage {
ops := []func() message.ImmutableMessage{
b.createCollection,
b.createPartition,
b.createSegment,
b.createSegment,
b.dropCollection,
b.dropPartition,
b.flushSegment,
b.flushSegment,
b.createInsert,
b.createInsert,
b.createInsert,
b.createDelete,
b.createDelete,
b.createDelete,
b.createTxn,
b.createTxn,
b.createManualFlush,
b.createSchemaChange,
type opRate struct {
op func() message.ImmutableMessage
rate int
}
opRates := []opRate{
{op: b.createCollection, rate: 1},
{op: b.createPartition, rate: 1},
{op: b.dropCollection, rate: 1},
{op: b.dropPartition, rate: 1},
{op: b.createSegment, rate: 2},
{op: b.flushSegment, rate: 2},
{op: b.createInsert, rate: 5},
{op: b.createDelete, rate: 5},
{op: b.createTxn, rate: 5},
{op: b.createManualFlush, rate: 2},
{op: b.createSchemaChange, rate: 1},
}
ops := make([]func() message.ImmutableMessage, 0)
for _, opRate := range opRates {
for i := 0; i < opRate.rate; i++ {
ops = append(ops, opRate.op)
}
}
msgs := make([]message.ImmutableMessage, 0)
for i := 0; i < int(rand.Int63n(1000)+1000); i++ {