milvus/internal/streamingnode/server/wal/recovery/recovery_storage_test.go
Zhen Ye 38c804fb01
fix: more stable recovery graceful closing and stable unittest (#42013)
issue: #41544

Signed-off-by: chyezh <chyezh@outlook.com>
2025-05-23 17:52:26 +08:00

641 lines
19 KiB
Go

package recovery
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
internaltypes "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
func TestRecoveryStorage(t *testing.T) {
paramtable.Get().Save(paramtable.Get().StreamingCfg.WALRecoveryPersistInterval.Key, "1ms")
paramtable.Get().Save(paramtable.Get().StreamingCfg.WALRecoveryGracefulCloseTimeout.Key, "10ms")
vchannelMetas := make(map[string]*streamingpb.VChannelMeta)
segmentMetas := make(map[int64]*streamingpb.SegmentAssignmentMeta)
cp := &streamingpb.WALCheckpoint{
MessageId: &messagespb.MessageID{
Id: rmq.NewRmqID(1).Marshal(),
},
TimeTick: 1,
RecoveryMagic: 0,
}
snCatalog := mock_metastore.NewMockStreamingNodeCataLog(t)
snCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.SegmentAssignmentMeta, error) {
return lo.MapToSlice(segmentMetas, func(_ int64, v *streamingpb.SegmentAssignmentMeta) *streamingpb.SegmentAssignmentMeta {
return proto.Clone(v).(*streamingpb.SegmentAssignmentMeta)
}), nil
})
snCatalog.EXPECT().ListVChannel(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.VChannelMeta, error) {
return lo.MapToSlice(vchannelMetas, func(_ string, v *streamingpb.VChannelMeta) *streamingpb.VChannelMeta {
return proto.Clone(v).(*streamingpb.VChannelMeta)
}), nil
})
snCatalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[int64]*streamingpb.SegmentAssignmentMeta) error {
if rand.Int31n(3) == 0 {
return errors.New("save failed")
}
for _, v := range m {
if v.State != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED {
segmentMetas[v.SegmentId] = v
} else {
delete(segmentMetas, v.SegmentId)
}
}
return nil
})
snCatalog.EXPECT().SaveVChannels(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[string]*streamingpb.VChannelMeta) error {
if rand.Int31n(3) == 0 {
return errors.New("save failed")
}
for _, v := range m {
if v.State != streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
vchannelMetas[v.Vchannel] = v
} else {
delete(vchannelMetas, v.Vchannel)
}
}
return nil
})
snCatalog.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(cp, nil)
snCatalog.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pchannelName string, checkpoint *streamingpb.WALCheckpoint) error {
if rand.Int31n(3) == 0 {
return errors.New("save failed")
}
cp = checkpoint
return nil
})
mixCoord := mocks.NewMockMixCoordClient(t)
mixCoord.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(&datapb.DropVirtualChannelResponse{
Status: merr.Success(),
}, nil)
f := syncutil.NewFuture[internaltypes.MixCoordClient]()
f.Set(mixCoord)
resource.InitForTest(t, resource.OptStreamingNodeCatalog(snCatalog), resource.OptMixCoordClient(f))
b := &streamBuilder{
channel: types.PChannelInfo{Name: "test_channel"},
lastConfirmedMessageID: 1,
messageID: 1,
timetick: 1,
collectionIDs: make(map[int64]map[int64]map[int64]struct{}),
vchannels: make(map[int64]string),
idAlloc: 1,
}
msg := message.NewTimeTickMessageBuilderV1().
WithAllVChannel().
WithHeader(&message.TimeTickMessageHeader{}).
WithBody(&msgpb.TimeTickMsg{}).
MustBuildMutable().
WithTimeTick(1).
WithLastConfirmed(rmq.NewRmqID(1)).
IntoImmutableMessage(rmq.NewRmqID(1))
b.generateStreamMessage()
for i := 0; i < 3; i++ {
if i == 2 {
// make sure the checkpoint is saved.
paramtable.Get().Save(paramtable.Get().StreamingCfg.WALRecoveryGracefulCloseTimeout.Key, "1000s")
}
rsInterface, snapshot, err := RecoverRecoveryStorage(context.Background(), b, msg)
rs := rsInterface.(*recoveryStorageImpl)
assert.NoError(t, err)
assert.NotNil(t, rs)
assert.NotNil(t, snapshot)
msgs := b.generateStreamMessage()
for _, msg := range msgs {
rs.ObserveMessage(msg)
}
rs.Close()
var partitionNum int
var collectionNum int
var segmentNum int
for _, v := range rs.vchannels {
if v.meta.State != streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
collectionNum += 1
partitionNum += len(v.meta.CollectionInfo.Partitions)
}
}
for _, v := range rs.segments {
if v.meta.State != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED {
segmentNum += 1
}
}
assert.Equal(t, partitionNum, b.partitionNum())
assert.Equal(t, collectionNum, b.collectionNum())
assert.Equal(t, segmentNum, b.segmentNum())
if b.segmentNum() != segmentNum {
t.Logf("segmentNum: %d, b.segmentNum: %d", segmentNum, b.segmentNum())
}
if rs.gracefulClosed {
// only available when graceful closing
assert.Equal(t, b.collectionNum(), len(vchannelMetas))
if b.collectionNum() != len(vchannelMetas) {
for _, v := range vchannelMetas {
t.Logf("vchannel: %s, state: %s", v.Vchannel, v.State)
}
for id := range b.collectionIDs {
t.Logf("collectionID: %d, %s", id, b.vchannels[id])
}
}
partitionNum := 0
for _, v := range vchannelMetas {
partitionNum += len(v.CollectionInfo.Partitions)
}
assert.Equal(t, b.partitionNum(), partitionNum)
assert.Equal(t, b.segmentNum(), len(segmentMetas))
}
}
}
type streamBuilder struct {
channel types.PChannelInfo
lastConfirmedMessageID int64
messageID int64
timetick uint64
collectionIDs map[int64]map[int64]map[int64]struct{}
vchannels map[int64]string
idAlloc int64
histories []message.ImmutableMessage
}
func (b *streamBuilder) collectionNum() int {
return len(b.collectionIDs)
}
func (b *streamBuilder) partitionNum() int {
partitionNum := 0
for _, partitions := range b.collectionIDs {
partitionNum += len(partitions)
}
return partitionNum
}
func (b *streamBuilder) segmentNum() int {
segmentNum := 0
for _, partitions := range b.collectionIDs {
for _, segments := range partitions {
segmentNum += len(segments)
}
}
return segmentNum
}
func (b *streamBuilder) RWWALImpls() walimpls.WALImpls {
return nil
}
type testRecoveryStream struct {
ch chan message.ImmutableMessage
}
func (ts *testRecoveryStream) Chan() <-chan message.ImmutableMessage {
return ts.ch
}
func (ts *testRecoveryStream) Error() error {
return nil
}
func (ts *testRecoveryStream) TxnBuffer() *utility.TxnBuffer {
return nil
}
func (ts *testRecoveryStream) Close() error {
return nil
}
func (b *streamBuilder) WALName() string {
return "rocksmq"
}
func (b *streamBuilder) Channel() types.PChannelInfo {
return b.channel
}
func (b *streamBuilder) Build(param BuildRecoveryStreamParam) RecoveryStream {
rs := &testRecoveryStream{
ch: make(chan message.ImmutableMessage, len(b.histories)),
}
cp := param.StartCheckpoint
for _, msg := range b.histories {
if cp.LTE(msg.MessageID()) {
rs.ch <- msg
}
}
close(rs.ch)
return rs
}
func (b *streamBuilder) generateStreamMessage() []message.ImmutableMessage {
type opRate struct {
op func() message.ImmutableMessage
rate int
}
opRates := []opRate{
{op: b.createCollection, rate: 1},
{op: b.createPartition, rate: 1},
{op: b.dropCollection, rate: 1},
{op: b.dropPartition, rate: 1},
{op: b.createSegment, rate: 2},
{op: b.flushSegment, rate: 2},
{op: b.createInsert, rate: 5},
{op: b.createDelete, rate: 5},
{op: b.createTxn, rate: 5},
{op: b.createManualFlush, rate: 2},
{op: b.createSchemaChange, rate: 1},
}
ops := make([]func() message.ImmutableMessage, 0)
for _, opRate := range opRates {
for i := 0; i < opRate.rate; i++ {
ops = append(ops, opRate.op)
}
}
msgs := make([]message.ImmutableMessage, 0)
for i := 0; i < int(rand.Int63n(1000)+1000); i++ {
op := rand.Int31n(int32(len(ops)))
if msg := ops[op](); msg != nil {
msgs = append(msgs, msg)
}
}
b.histories = append(b.histories, msgs...)
return msgs
}
// createCollection creates a collection message with the given vchannel.
func (b *streamBuilder) createCollection() message.ImmutableMessage {
vchannel := fmt.Sprintf("vchannel_%d", b.allocID())
collectionID := b.allocID()
partitions := rand.Int31n(1023) + 1
partitionIDs := make(map[int64]map[int64]struct{}, partitions)
for i := int32(0); i < partitions; i++ {
partitionIDs[b.allocID()] = make(map[int64]struct{})
}
b.nextMessage()
b.collectionIDs[collectionID] = partitionIDs
b.vchannels[collectionID] = vchannel
return message.NewCreateCollectionMessageBuilderV1().
WithVChannel(vchannel).
WithHeader(&message.CreateCollectionMessageHeader{
CollectionId: collectionID,
PartitionIds: lo.Keys(partitionIDs),
}).
WithBody(&msgpb.CreateCollectionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
func (b *streamBuilder) createPartition() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
partitionID := b.allocID()
collection[partitionID] = make(map[int64]struct{})
b.nextMessage()
return message.NewCreatePartitionMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.CreatePartitionMessageHeader{
CollectionId: collectionID,
PartitionId: partitionID,
}).
WithBody(&msgpb.CreatePartitionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) createSegment() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID, partition := range collection {
if rand.Int31n(3) < 1 {
continue
}
segmentID := b.allocID()
partition[segmentID] = struct{}{}
b.nextMessage()
return message.NewCreateSegmentMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.CreateSegmentMessageHeader{
CollectionId: collectionID,
SegmentId: segmentID,
PartitionId: partitionID,
StorageVersion: 1,
MaxSegmentSize: 1024,
}).
WithBody(&message.CreateSegmentMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
return nil
}
func (b *streamBuilder) dropCollection() message.ImmutableMessage {
for collectionID := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
delete(b.collectionIDs, collectionID)
return message.NewDropCollectionMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.DropCollectionMessageHeader{
CollectionId: collectionID,
}).
WithBody(&msgpb.DropCollectionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) dropPartition() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID := range collection {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
delete(collection, partitionID)
return message.NewDropPartitionMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: collectionID,
PartitionId: partitionID,
}).
WithBody(&msgpb.DropPartitionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
return nil
}
func (b *streamBuilder) flushSegment() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID := range collection {
if rand.Int31n(3) < 1 {
continue
}
for segmentID := range collection[partitionID] {
if rand.Int31n(4) < 1 {
continue
}
delete(collection[partitionID], segmentID)
b.nextMessage()
return message.NewFlushMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.FlushMessageHeader{
CollectionId: collectionID,
SegmentId: segmentID,
}).
WithBody(&message.FlushMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
}
return nil
}
func (b *streamBuilder) createTxn() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
txnSession := &message.TxnContext{
TxnID: message.TxnID(b.allocID()),
}
begin := message.NewBeginTxnMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.BeginTxnMessageHeader{}).
WithBody(&message.BeginTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithTxnContext(*txnSession).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
builder := message.NewImmutableTxnMessageBuilder(message.MustAsImmutableBeginTxnMessageV2(begin))
for partitionID := range collection {
for segmentID := range collection[partitionID] {
b.nextMessage()
builder.Add(message.NewInsertMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.InsertMessageHeader{
CollectionId: collectionID,
Partitions: []*messagespb.PartitionSegmentAssignment{
{
PartitionId: partitionID,
Rows: uint64(rand.Int31n(100)),
BinarySize: uint64(rand.Int31n(100)),
SegmentAssignment: &messagespb.SegmentAssignment{SegmentId: segmentID},
},
},
}).
WithBody(&msgpb.InsertRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithTxnContext(*txnSession).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID)))
}
}
b.nextMessage()
commit := message.NewCommitTxnMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.CommitTxnMessageHeader{}).
WithBody(&message.CommitTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithTxnContext(*txnSession).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
txnMsg, _ := builder.Build(message.MustAsImmutableCommitTxnMessageV2(commit))
return txnMsg
}
return nil
}
func (b *streamBuilder) createDelete() message.ImmutableMessage {
for collectionID := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
return message.NewDeleteMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.DeleteMessageHeader{
CollectionId: collectionID,
}).
WithBody(&msgpb.DeleteRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) createManualFlush() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
segmentIDs := make([]int64, 0)
for partitionID := range collection {
for segmentID := range collection[partitionID] {
segmentIDs = append(segmentIDs, segmentID)
delete(collection[partitionID], segmentID)
}
}
if len(segmentIDs) == 0 {
continue
}
b.nextMessage()
return message.NewManualFlushMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.ManualFlushMessageHeader{
CollectionId: collectionID,
SegmentIds: segmentIDs,
}).
WithBody(&message.ManualFlushMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) createSchemaChange() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
segmentIDs := make([]int64, 0)
for partitionID := range collection {
for segmentID := range collection[partitionID] {
segmentIDs = append(segmentIDs, segmentID)
delete(collection[partitionID], segmentID)
}
}
b.nextMessage()
return message.NewSchemaChangeMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.SchemaChangeMessageHeader{
CollectionId: collectionID,
FlushedSegmentIds: segmentIDs,
}).
WithBody(&message.SchemaChangeMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) createInsert() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID := range collection {
if rand.Int31n(3) < 1 {
continue
}
for segmentID := range collection[partitionID] {
if rand.Int31n(4) < 2 {
continue
}
b.nextMessage()
return message.NewInsertMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.InsertMessageHeader{
CollectionId: collectionID,
Partitions: []*messagespb.PartitionSegmentAssignment{
{
PartitionId: partitionID,
Rows: uint64(rand.Int31n(100)),
BinarySize: uint64(rand.Int31n(100)),
SegmentAssignment: &messagespb.SegmentAssignment{SegmentId: segmentID},
},
},
}).
WithBody(&msgpb.InsertRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
}
return nil
}
func (b *streamBuilder) nextMessage() {
b.messageID++
if rand.Int31n(3) < 2 {
b.lastConfirmedMessageID = b.messageID + 1
}
b.timetick++
}
func (b *streamBuilder) allocID() int64 {
b.idAlloc++
return b.idAlloc
}