diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 6b2ddb0597..8af2266060 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -612,6 +612,13 @@ func (s *Server) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtual // validate nodeID := req.GetBase().GetSourceID() if !s.channelManager.Match(nodeID, channel) { + if streamingutil.IsStreamingServiceEnabled() { + // If streaming service is enabled, the channel manager will always return true if channel exist. + // once the channel is not exist, the drop virtual channel has been done. + return &datapb.DropVirtualChannelResponse{ + Status: merr.Success(), + }, nil + } err := merr.WrapErrChannelNotFound(channel, fmt.Sprintf("for node %d", nodeID)) resp.Status = merr.Status(err) log.Warn("node is not matched with channel", zap.String("channel", channel), zap.Int64("nodeID", nodeID)) diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index 4d8d99be7e..41f7703573 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -9,6 +9,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/vchantempstore" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/log" @@ -43,6 +44,7 @@ func OptRootCoordClient(rootCoordClient *syncutil.Future[types.RootCoordClient]) r.rootCoordClient = rootCoordClient r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) + r.vchannelTempStorage = vchantempstore.NewVChannelTempStorage(r.rootCoordClient) } } @@ -99,6 +101,7 @@ type resourceImpl struct { streamingNodeCatalog metastore.StreamingNodeCataLog segmentAssignStatsManager *stats.StatsManager timeTickInspector tinspector.TimeTickSyncInspector + vchannelTempStorage *vchantempstore.VChannelTempStorage } // TSOAllocator returns the timestamp allocator to allocate timestamp. @@ -145,6 +148,11 @@ func (r *resourceImpl) TimeTickInspector() tinspector.TimeTickSyncInspector { return r.timeTickInspector } +// VChannelTempStorage returns the vchannel temp storage. +func (r *resourceImpl) VChannelTempStorage() *vchantempstore.VChannelTempStorage { + return r.vchannelTempStorage +} + func (r *resourceImpl) Logger() *log.MLogger { return r.logger } diff --git a/internal/streamingnode/server/resource/resource_test.go b/internal/streamingnode/server/resource/resource_test.go index 8d5feb6817..9d0af63890 100644 --- a/internal/streamingnode/server/resource/resource_test.go +++ b/internal/streamingnode/server/resource/resource_test.go @@ -1,6 +1,7 @@ package resource import ( + "os" "testing" "github.com/stretchr/testify/assert" @@ -13,9 +14,12 @@ import ( "github.com/milvus-io/milvus/pkg/util/syncutil" ) -func TestApply(t *testing.T) { +func TestMain(m *testing.M) { paramtable.Init() + os.Exit(m.Run()) +} +func TestApply(t *testing.T) { Apply() Apply(OptETCD(&clientv3.Client{})) Apply(OptRootCoordClient(syncutil.NewFuture[types.RootCoordClient]())) diff --git a/internal/streamingnode/server/service/handler/producer/produce_server.go b/internal/streamingnode/server/service/handler/producer/produce_server.go index 5e0b862d27..0889e00b79 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server.go @@ -204,9 +204,6 @@ func (p *ProduceServer) handleProduce(req *streamingpb.ProduceMessageRequest) { // validateMessage validates the message. func (p *ProduceServer) validateMessage(msg message.MutableMessage) error { // validate the msg. - if !msg.Version().GT(message.VersionOld) { - return status.NewInvaildArgument("unsupported message version") - } if !msg.MessageType().Valid() { return status.NewInvaildArgument("unsupported message type") } diff --git a/internal/streamingnode/server/wal/adaptor/old_version_message.go b/internal/streamingnode/server/wal/adaptor/old_version_message.go new file mode 100644 index 0000000000..c30ce5179a --- /dev/null +++ b/internal/streamingnode/server/wal/adaptor/old_version_message.go @@ -0,0 +1,217 @@ +package adaptor + +import ( + "context" + "fmt" + + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/pkg/mq/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" +) + +// newOldVersionImmutableMessage creates a new immutable message from the old version message. +// Because some old version message didn't have vchannel, so we need to recognize it from the pchnnel and some data field. +func newOldVersionImmutableMessage( + ctx context.Context, + pchannel string, + lastConfirmedMessageID message.MessageID, + msg message.ImmutableMessage, +) (message.ImmutableMessage, error) { + if msg.Version() != message.VersionOld { + panic("invalid message version") + } + msgType, err := common.GetMsgTypeFromRaw(msg.Payload(), msg.Properties().ToRawMap()) + if err != nil { + panic(fmt.Sprintf("failed to get message type: %v", err)) + } + tsMsg, err := adaptor.UnmashalerDispatcher.Unmarshal(msg.Payload(), msgType) + if err != nil { + panic(fmt.Sprintf("failed to unmarshal message: %v", err)) + } + + // We will transfer it from v0 into v1 here to make it can be consumed by streaming service. + // It will lose some performance, but there should always a little amount of old version message, so it should be ok. + var mutableMessage message.MutableMessage + switch underlyingMsg := tsMsg.(type) { + case *msgstream.CreateCollectionMsg: + mutableMessage = newV1CreateCollectionMsgFromV0(pchannel, underlyingMsg) + case *msgstream.DropCollectionMsg: + mutableMessage, err = newV1DropCollectionMsgFromV0(ctx, pchannel, underlyingMsg) + case *msgstream.InsertMsg: + mutableMessage = newV1InsertMsgFromV0(underlyingMsg, uint64(len(msg.Payload()))) + case *msgstream.DeleteMsg: + mutableMessage = newV1DeleteMsgFromV0(underlyingMsg) + case *msgstream.TimeTickMsg: + mutableMessage = newV1TimeTickMsgFromV0(underlyingMsg) + case *msgstream.CreatePartitionMsg: + mutableMessage, err = newV1CreatePartitionMessageV0(ctx, pchannel, underlyingMsg) + case *msgstream.DropPartitionMsg: + mutableMessage, err = newV1DropPartitionMessageV0(ctx, pchannel, underlyingMsg) + case *msgstream.ImportMsg: + mutableMessage, err = newV1ImportMsgFromV0(ctx, pchannel, underlyingMsg) + default: + panic("unsupported message type") + } + if err != nil { + return nil, err + } + return mutableMessage.WithLastConfirmed(lastConfirmedMessageID).IntoImmutableMessage(msg.MessageID()), nil +} + +// newV1CreateCollectionMsgFromV0 creates a new create collection message from the old version create collection message. +func newV1CreateCollectionMsgFromV0(pchannel string, msg *msgstream.CreateCollectionMsg) message.MutableMessage { + var vchannel string + for idx, v := range msg.PhysicalChannelNames { + if v == pchannel { + vchannel = msg.VirtualChannelNames[idx] + break + } + } + if vchannel == "" { + panic(fmt.Sprintf("vchannel not found at create collection message, collection id: %d, pchannel: %s", msg.CollectionID, pchannel)) + } + + mutableMessage, err := message.NewCreateCollectionMessageBuilderV1(). + WithVChannel(vchannel). + WithHeader(&message.CreateCollectionMessageHeader{ + CollectionId: msg.CollectionID, + PartitionIds: msg.PartitionIDs, + }). + WithBody(msg.CreateCollectionRequest). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()) +} + +// newV1DropCollectionMsgFromV0 creates a new drop collection message from the old version drop collection message. +func newV1DropCollectionMsgFromV0(ctx context.Context, pchannel string, msg *msgstream.DropCollectionMsg) (message.MutableMessage, error) { + vchannel, err := resource.Resource().VChannelTempStorage().GetVChannelByPChannelOfCollection(ctx, msg.CollectionID, pchannel) + if err != nil { + return nil, err + } + + mutableMessage, err := message.NewDropCollectionMessageBuilderV1(). + WithVChannel(vchannel). + WithHeader(&message.DropCollectionMessageHeader{ + CollectionId: msg.CollectionID, + }). + WithBody(msg.DropCollectionRequest). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()), nil +} + +// newV1InsertMsgFromV0 creates a new insert message from the old version insert message. +func newV1InsertMsgFromV0(msg *msgstream.InsertMsg, binarySize uint64) message.MutableMessage { + mutableMessage, err := message.NewInsertMessageBuilderV1(). + WithVChannel(msg.ShardName). + WithHeader(&message.InsertMessageHeader{ + CollectionId: msg.CollectionID, + Partitions: []*message.PartitionSegmentAssignment{{ + PartitionId: msg.PartitionID, + Rows: msg.NumRows, + BinarySize: binarySize, + SegmentAssignment: &message.SegmentAssignment{ + SegmentId: msg.SegmentID, + }, + }}, + }). + WithBody(msg.InsertRequest). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()) +} + +// newV1DeleteMsgFromV0 creates a new delete message from the old version delete message. +func newV1DeleteMsgFromV0(msg *msgstream.DeleteMsg) message.MutableMessage { + mutableMessage, err := message.NewDeleteMessageBuilderV1(). + WithVChannel(msg.ShardName). + WithHeader(&message.DeleteMessageHeader{ + CollectionId: msg.CollectionID, + }). + WithBody(msg.DeleteRequest). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()) +} + +// newV1TimeTickMsgFromV0 creates a new time tick message from the old version time tick message. +func newV1TimeTickMsgFromV0(msg *msgstream.TimeTickMsg) message.MutableMessage { + mutableMessage, err := message.NewTimeTickMessageBuilderV1(). + WithAllVChannel(). + WithHeader(&message.TimeTickMessageHeader{}). + WithBody(msg.TimeTickMsg). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()) +} + +// newV1CreatePartitionMessageV0 creates a new create partition message from the old version create partition message. +func newV1CreatePartitionMessageV0(ctx context.Context, pchannel string, msg *msgstream.CreatePartitionMsg) (message.MutableMessage, error) { + vchannel, err := resource.Resource().VChannelTempStorage().GetVChannelByPChannelOfCollection(ctx, msg.CollectionID, pchannel) + if err != nil { + return nil, err + } + + mutableMessage, err := message.NewCreatePartitionMessageBuilderV1(). + WithVChannel(vchannel). + WithHeader(&message.CreatePartitionMessageHeader{ + CollectionId: msg.CollectionID, + PartitionId: msg.PartitionID, + }). + WithBody(msg.CreatePartitionRequest). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()), nil +} + +// newV1DropPartitionMessageV0 creates a new drop partition message from the old version drop partition message. +func newV1DropPartitionMessageV0(ctx context.Context, pchannel string, msg *msgstream.DropPartitionMsg) (message.MutableMessage, error) { + vchannel, err := resource.Resource().VChannelTempStorage().GetVChannelByPChannelOfCollection(ctx, msg.CollectionID, pchannel) + if err != nil { + return nil, err + } + mutableMessage, err := message.NewDropPartitionMessageBuilderV1(). + WithVChannel(vchannel). + WithHeader(&message.DropPartitionMessageHeader{ + CollectionId: msg.CollectionID, + PartitionId: msg.PartitionID, + }). + WithBody(msg.DropPartitionRequest). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()), nil +} + +// newV1ImportMsgFromV0 creates a new import message from the old version import message. +func newV1ImportMsgFromV0(ctx context.Context, pchannel string, msg *msgstream.ImportMsg) (message.MutableMessage, error) { + vchannel, err := resource.Resource().VChannelTempStorage().GetVChannelByPChannelOfCollection(ctx, msg.CollectionID, pchannel) + if err != nil { + return nil, err + } + mutableMessage, err := message.NewImportMessageBuilderV1(). + WithVChannel(vchannel). + WithHeader(&message.ImportMessageHeader{}). + WithBody(msg.ImportMsg). + BuildMutable() + if err != nil { + panic(err) + } + return mutableMessage.WithTimeTick(msg.BeginTs()), nil +} diff --git a/internal/streamingnode/server/wal/adaptor/old_version_message_test.go b/internal/streamingnode/server/wal/adaptor/old_version_message_test.go new file mode 100644 index 0000000000..66c247e94c --- /dev/null +++ b/internal/streamingnode/server/wal/adaptor/old_version_message_test.go @@ -0,0 +1,217 @@ +package adaptor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +func TestNewOldVersionImmutableMessage(t *testing.T) { + rc := mocks.NewMockRootCoordClient(t) + rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + PhysicalChannelNames: []string{"test1", "test2"}, + VirtualChannelNames: []string{"test1-v0", "test2-v0"}, + }, nil) + rcf := syncutil.NewFuture[types.RootCoordClient]() + rcf.Set(rc) + resource.InitForTest(t, resource.OptRootCoordClient(rcf)) + + ctx := context.Background() + pchannel := "test1" + lastConfirmedMessageID := walimplstest.NewTestMessageID(1) + messageID := walimplstest.NewTestMessageID(2) + tt := uint64(10086) + + // createCollectionMsg + createCollectionMsgV0 := msgpb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateCollection, + Timestamp: tt, + }, + CollectionID: 1, + PhysicalChannelNames: []string{"test1", "test2"}, + VirtualChannelNames: []string{"test1-v0", "test2-v0"}, + PartitionIDs: []int64{1}, + } + payload, _ := proto.Marshal(&createCollectionMsgV0) + + msg, err := newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.NotNil(t, msg.LastConfirmedMessageID()) + assert.Equal(t, msg.VChannel(), "test1-v0") + assert.Equal(t, msg.TimeTick(), tt) + createCollectionMsgV1, err := message.AsImmutableCreateCollectionMessageV1(msg) + assert.NoError(t, err) + assert.Equal(t, createCollectionMsgV1.Header().CollectionId, int64(1)) + assert.Equal(t, createCollectionMsgV1.Header().PartitionIds, []int64{1}) + + // dropCollectionMsg + dropCollectionMsgV0 := msgpb.DropCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropCollection, + Timestamp: tt, + }, + CollectionID: 1, + } + payload, _ = proto.Marshal(&dropCollectionMsgV0) + msg, err = newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.True(t, msg.MessageID().EQ(messageID)) + assert.True(t, msg.LastConfirmedMessageID().EQ(lastConfirmedMessageID)) + assert.Equal(t, msg.VChannel(), "test1-v0") + assert.Equal(t, msg.TimeTick(), tt) + dropCollectionMsgV1, err := message.AsImmutableDropCollectionMessageV1(msg) + assert.NoError(t, err) + assert.Equal(t, dropCollectionMsgV1.Header().CollectionId, int64(1)) + + // insertMsg + insertMsgV0 := msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + Timestamp: tt, + }, + Timestamps: []uint64{10086}, + CollectionID: 1, + PartitionID: 2, + NumRows: 102, + SegmentID: 100, + ShardName: "test1-v0", + } + payload, _ = proto.Marshal(&insertMsgV0) + msg, err = newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.True(t, msg.MessageID().EQ(messageID)) + assert.True(t, msg.LastConfirmedMessageID().EQ(lastConfirmedMessageID)) + assert.Equal(t, msg.VChannel(), "test1-v0") + assert.Equal(t, msg.TimeTick(), tt) + insertMsgV1, err := message.AsImmutableInsertMessageV1(msg) + assert.NoError(t, err) + assert.Equal(t, insertMsgV1.Header().CollectionId, int64(1)) + assert.Equal(t, insertMsgV1.Header().Partitions[0].PartitionId, int64(2)) + assert.Equal(t, insertMsgV1.Header().Partitions[0].SegmentAssignment.SegmentId, int64(100)) + assert.NotZero(t, insertMsgV1.Header().Partitions[0].BinarySize) + assert.Equal(t, insertMsgV1.Header().Partitions[0].Rows, uint64(102)) + + // deleteMsg + deleteMsgV0 := msgpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + Timestamp: tt, + }, + Timestamps: []uint64{10086}, + CollectionID: 1, + PartitionID: 2, + NumRows: 102, + ShardName: "test1-v0", + } + payload, _ = proto.Marshal(&deleteMsgV0) + msg, err = newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.True(t, msg.MessageID().EQ(messageID)) + assert.True(t, msg.LastConfirmedMessageID().EQ(lastConfirmedMessageID)) + assert.Equal(t, msg.VChannel(), "test1-v0") + assert.Equal(t, msg.TimeTick(), tt) + deleteMsgV1, err := message.AsImmutableDeleteMessageV1(msg) + assert.NoError(t, err) + assert.Equal(t, deleteMsgV1.Header().CollectionId, int64(1)) + + // timetickSyncMsg + timetickSyncMsgV0 := msgpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_TimeTick, + Timestamp: tt, + }, + Timestamps: []uint64{10086}, + CollectionID: 1, + PartitionID: 2, + NumRows: 102, + ShardName: "test1-v0", + } + payload, _ = proto.Marshal(&timetickSyncMsgV0) + msg, err = newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.True(t, msg.MessageID().EQ(messageID)) + assert.True(t, msg.LastConfirmedMessageID().EQ(lastConfirmedMessageID)) + assert.Equal(t, msg.VChannel(), "") + assert.Equal(t, msg.TimeTick(), tt) + _, err = message.AsImmutableTimeTickMessageV1(msg) + assert.NoError(t, err) + + // createPartitionMsg + createPartitionMsgV0 := msgpb.CreatePartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreatePartition, + Timestamp: tt, + }, + CollectionID: 1, + PartitionID: 2, + } + payload, _ = proto.Marshal(&createPartitionMsgV0) + + msg, err = newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.True(t, msg.MessageID().EQ(messageID)) + assert.True(t, msg.LastConfirmedMessageID().EQ(lastConfirmedMessageID)) + assert.Equal(t, msg.VChannel(), "test1-v0") + assert.Equal(t, msg.TimeTick(), tt) + createPartitionMsgV1, err := message.AsImmutableCreatePartitionMessageV1(msg) + assert.NoError(t, err) + assert.Equal(t, createPartitionMsgV1.Header().CollectionId, int64(1)) + assert.Equal(t, createPartitionMsgV1.Header().PartitionId, int64(2)) + + // dropPartitionMsg + dropPartitionMsgV0 := msgpb.DropPartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropPartition, + Timestamp: tt, + }, + CollectionID: 1, + PartitionID: 2, + } + payload, _ = proto.Marshal(&dropPartitionMsgV0) + msg, err = newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.True(t, msg.MessageID().EQ(messageID)) + assert.True(t, msg.LastConfirmedMessageID().EQ(lastConfirmedMessageID)) + assert.Equal(t, msg.VChannel(), "test1-v0") + assert.Equal(t, msg.TimeTick(), tt) + dropPartitionMsgV1, err := message.AsImmutableDropPartitionMessageV1(msg) + assert.NoError(t, err) + assert.Equal(t, createPartitionMsgV1.Header().CollectionId, int64(1)) + assert.Equal(t, dropPartitionMsgV1.Header().PartitionId, int64(2)) + + // ImportMsg + ImportMsgV0 := msgpb.ImportMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Import, + Timestamp: tt, + }, + CollectionID: 1, + } + payload, _ = proto.Marshal(&ImportMsgV0) + msg, err = newOldVersionImmutableMessage(ctx, pchannel, lastConfirmedMessageID, message.NewImmutableMesasge(messageID, payload, map[string]string{})) + assert.NoError(t, err) + assert.True(t, msg.MessageID().EQ(messageID)) + assert.True(t, msg.LastConfirmedMessageID().EQ(lastConfirmedMessageID)) + assert.Equal(t, msg.VChannel(), "test1-v0") + assert.Equal(t, msg.TimeTick(), tt) + ImportMsgV1, err := message.AsImmutableImportMessageV1(msg) + assert.NoError(t, err) + assert.NotNil(t, ImportMsgV1) +} diff --git a/internal/streamingnode/server/wal/adaptor/scanner_switchable.go b/internal/streamingnode/server/wal/adaptor/scanner_switchable.go index 701bad275a..2a1ddd1789 100644 --- a/internal/streamingnode/server/wal/adaptor/scanner_switchable.go +++ b/internal/streamingnode/server/wal/adaptor/scanner_switchable.go @@ -8,6 +8,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/wab" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/vchantempstore" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/options" @@ -73,8 +74,9 @@ func (s *switchableScannerImpl) HandleMessage(ctx context.Context, msg message.I // catchupScanner is a scanner that make a read at underlying wal, and try to catchup the writeahead buffer then switch to tailing mode. type catchupScanner struct { switchableScannerImpl - deliverPolicy options.DeliverPolicy - exclusiveStartTimeTick uint64 // scanner should filter out the message that less than or equal to this time tick. + deliverPolicy options.DeliverPolicy + exclusiveStartTimeTick uint64 // scanner should filter out the message that less than or equal to this time tick. + lastConfirmedMessageIDForOldVersion message.MessageID } func (s *catchupScanner) Mode() string { @@ -110,6 +112,36 @@ func (s *catchupScanner) consumeWithScanner(ctx context.Context, scanner walimpl if !ok { return nil, scanner.Error() } + + if msg.Version() == message.VersionOld { + if s.lastConfirmedMessageIDForOldVersion == nil { + s.logger.Info( + "scanner find a old version message, set it as the last confirmed message id for all old version message", + zap.Stringer("messageID", msg.MessageID()), + ) + s.lastConfirmedMessageIDForOldVersion = msg.MessageID() + } + // We always use first consumed message as the last confirmed message id for old version message. + // After upgrading from old milvus: + // The wal will be read at consuming side as following: + // msgv0, msgv0 ..., msgv0, msgv1, msgv1, msgv1, ... + // the msgv1 will be read after all msgv0 is consumed as soon as possible. + // so the last confirm is set to the first msgv0 message for all old version message is ok. + var err error + msg, err = newOldVersionImmutableMessage(ctx, s.innerWAL.Channel().Name, s.lastConfirmedMessageIDForOldVersion, msg) + if errors.Is(err, vchantempstore.ErrNotFound) { + // Skip the message's vchannel is not found in the vchannel temp store. + s.logger.Info("skip the old version message because vchannel not found", zap.Stringer("messageID", msg.MessageID())) + continue + } + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { + return nil, err + } + if err != nil { + panic("unrechable: unexpected error found: " + err.Error()) + } + } + if msg.TimeTick() <= s.exclusiveStartTimeTick { // we should filter out the message that less than or equal to this time tick to remove duplicate message // when we switch from tailing mode to catchup mode. diff --git a/internal/streamingnode/server/wal/interceptors/wab/pending_queue.go b/internal/streamingnode/server/wal/interceptors/wab/pending_queue.go index 26a98ff544..6d4f5d656a 100644 --- a/internal/streamingnode/server/wal/interceptors/wab/pending_queue.go +++ b/internal/streamingnode/server/wal/interceptors/wab/pending_queue.go @@ -65,9 +65,6 @@ func (q *pendingQueue) CurrentOffset() int { // push adds a message to the buffer. func (q *pendingQueue) pushOne(msg message.ImmutableMessage, now time.Time) { - if msg.Version().EQ(message.VersionOld) { - panic("old message version is not supported") - } if (msg.MessageType() == message.MessageTypeTimeTick && msg.TimeTick() < q.lastTimeTick) || (msg.MessageType() != message.MessageTypeTimeTick && msg.TimeTick() <= q.lastTimeTick) { // only timetick message can be repeated with the last time tick. diff --git a/internal/streamingnode/server/wal/vchantempstore/vchannel_temp_storage.go b/internal/streamingnode/server/wal/vchantempstore/vchannel_temp_storage.go new file mode 100644 index 0000000000..9b0c5a68ff --- /dev/null +++ b/internal/streamingnode/server/wal/vchantempstore/vchannel_temp_storage.go @@ -0,0 +1,96 @@ +package vchantempstore + +import ( + "context" + "fmt" + "sync" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +// ErrNotFound is returned when the vchannel is not found. +var ErrNotFound = errors.New("not found") + +// NewVChannelTempStorage creates a new VChannelTempStorage. +func NewVChannelTempStorage(rc *syncutil.Future[types.RootCoordClient]) *VChannelTempStorage { + return &VChannelTempStorage{ + rc: rc, + vchannels: make(map[int64]map[string]string), + } +} + +// VChannelTempStorage is a temporary storage for vchannel messages. +// It's used to make compatibility between old version and new version message. +// TODO: removed in 3.0. +type VChannelTempStorage struct { + rc *syncutil.Future[types.RootCoordClient] + + mu sync.Mutex + vchannels map[int64]map[string]string +} + +func (ts *VChannelTempStorage) GetVChannelByPChannelOfCollection(ctx context.Context, collectionID int64, pchannel string) (string, error) { + if err := ts.updateVChannelByPChannelOfCollectionIfNotExist(ctx, collectionID); err != nil { + return "", err + } + + ts.mu.Lock() + defer ts.mu.Unlock() + item, ok := ts.vchannels[collectionID] + if !ok { + return "", errors.Wrapf(ErrNotFound, "collection %d at pchannel %s", collectionID, pchannel) + } + v, ok := item[pchannel] + if !ok { + panic(fmt.Sprintf("pchannel not found for collection %d at pchannel %s", collectionID, pchannel)) + } + return v, nil +} + +func (ts *VChannelTempStorage) updateVChannelByPChannelOfCollectionIfNotExist(ctx context.Context, collectionID int64) error { + ts.mu.Lock() + if _, ok := ts.vchannels[collectionID]; ok { + ts.mu.Unlock() + return nil + } + ts.mu.Unlock() + + rc, err := ts.rc.GetWithContext(ctx) + if err != nil { + return err + } + + return retry.Do(ctx, func() error { + resp, err := rc.DescribeCollectionInternal(ctx, &milvuspb.DescribeCollectionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + CollectionID: collectionID, + }) + err = merr.CheckRPCCall(resp, err) + if errors.Is(err, merr.ErrCollectionNotFound) { + return nil + } + if err == nil { + ts.mu.Lock() + if _, ok := ts.vchannels[collectionID]; !ok { + ts.vchannels[collectionID] = make(map[string]string, len(resp.PhysicalChannelNames)) + } + for idx, pchannel := range resp.PhysicalChannelNames { + ts.vchannels[collectionID][pchannel] = resp.VirtualChannelNames[idx] + } + ts.mu.Unlock() + } + return err + }) +} diff --git a/internal/streamingnode/server/wal/vchantempstore/vchannel_temp_storage_test.go b/internal/streamingnode/server/wal/vchantempstore/vchannel_temp_storage_test.go new file mode 100644 index 0000000000..20c406420d --- /dev/null +++ b/internal/streamingnode/server/wal/vchantempstore/vchannel_temp_storage_test.go @@ -0,0 +1,63 @@ +package vchantempstore + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +func TestVChannelTempStorage(t *testing.T) { + rcf := syncutil.NewFuture[types.RootCoordClient]() + ts := NewVChannelTempStorage(rcf) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + _, err := ts.GetVChannelByPChannelOfCollection(ctx, 1, "test") + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + ctx = context.Background() + rc := mocks.NewMockRootCoordClient(t) + rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + PhysicalChannelNames: []string{"test1", "test2"}, + VirtualChannelNames: []string{"test1-v0", "test2-v0"}, + }, nil) + rcf.Set(rc) + + v, err := ts.GetVChannelByPChannelOfCollection(ctx, 1, "test1") + assert.NoError(t, err) + assert.Equal(t, "test1-v0", v) + + v, err = ts.GetVChannelByPChannelOfCollection(ctx, 1, "test2") + assert.NoError(t, err) + assert.Equal(t, "test2-v0", v) + + assert.Panics(t, func() { + ts.GetVChannelByPChannelOfCollection(ctx, 1, "test3") + }) + + rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Unset() + rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound) + + v, err = ts.GetVChannelByPChannelOfCollection(ctx, 1, "test1") + assert.NoError(t, err) + assert.Equal(t, "test1-v0", v) + v, err = ts.GetVChannelByPChannelOfCollection(ctx, 1, "test2") + assert.NoError(t, err) + assert.Equal(t, "test2-v0", v) + + v, err = ts.GetVChannelByPChannelOfCollection(ctx, 2, "test2") + assert.ErrorIs(t, err, ErrNotFound) + assert.Equal(t, "", v) +} diff --git a/pkg/mq/common/message.go b/pkg/mq/common/message.go index 4658fb8f53..3625516d09 100644 --- a/pkg/mq/common/message.go +++ b/pkg/mq/common/message.go @@ -83,9 +83,14 @@ const ( ReplicateIDTypeKey = "replicate_id" ) +// GetMsgType gets the message type from message. func GetMsgType(msg Message) (commonpb.MsgType, error) { + return GetMsgTypeFromRaw(msg.Payload(), msg.Properties()) +} + +// GetMsgTypeFromRaw gets the message type from payload and properties. +func GetMsgTypeFromRaw(payload []byte, properties map[string]string) (commonpb.MsgType, error) { msgType := commonpb.MsgType_Undefined - properties := msg.Properties() if properties != nil { if val, ok := properties[MsgTypeKey]; ok { msgType = commonpb.MsgType(commonpb.MsgType_value[val]) @@ -93,10 +98,10 @@ func GetMsgType(msg Message) (commonpb.MsgType, error) { } if msgType == commonpb.MsgType_Undefined { header := commonpb.MsgHeader{} - if msg.Payload() == nil { + if payload == nil { return msgType, fmt.Errorf("failed to unmarshal message header, payload is empty") } - err := proto.Unmarshal(msg.Payload(), &header) + err := proto.Unmarshal(payload, &header) if err != nil { return msgType, fmt.Errorf("failed to unmarshal message header, err %s", err.Error()) } diff --git a/pkg/mq/common/message_test.go b/pkg/mq/common/message_test.go new file mode 100644 index 0000000000..7fbe521ee7 --- /dev/null +++ b/pkg/mq/common/message_test.go @@ -0,0 +1,71 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +type mockMessage struct { + topic string + properties map[string]string + payload []byte + id MessageID +} + +func (m *mockMessage) Topic() string { + return m.topic +} + +func (m *mockMessage) Properties() map[string]string { + return m.properties +} + +func (m *mockMessage) Payload() []byte { + return m.payload +} + +func (m *mockMessage) ID() MessageID { + return m.id +} + +func TestGetMsgType(t *testing.T) { + t.Run("Test with properties", func(t *testing.T) { + properties := map[string]string{ + MsgTypeKey: "Insert", + } + msg := &mockMessage{ + properties: properties, + } + msgType, err := GetMsgType(msg) + assert.NoError(t, err) + assert.Equal(t, commonpb.MsgType_Insert, msgType) + }) + + t.Run("Test with payload", func(t *testing.T) { + header := &commonpb.MsgHeader{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + } + payload, err := proto.Marshal(header) + assert.NoError(t, err) + + msg := &mockMessage{ + payload: payload, + } + msgType, err := GetMsgType(msg) + assert.NoError(t, err) + assert.Equal(t, commonpb.MsgType_Insert, msgType) + }) + + t.Run("Test with empty payload and properties", func(t *testing.T) { + msg := &mockMessage{} + msgType, err := GetMsgType(msg) + assert.Error(t, err) + assert.Equal(t, commonpb.MsgType_Undefined, msgType) + }) +} diff --git a/pkg/streaming/util/message/adaptor/broadcast_message.go b/pkg/streaming/util/message/adaptor/broadcast_message.go index 69bcde5dad..338da8ac00 100644 --- a/pkg/streaming/util/message/adaptor/broadcast_message.go +++ b/pkg/streaming/util/message/adaptor/broadcast_message.go @@ -12,7 +12,7 @@ func NewMsgPackFromMutableMessageV1(msg message.MutableMessage) (msgstream.TsMsg return nil, errors.New("Invalid message version") } - tsMsg, err := unmashalerDispatcher.Unmarshal(msg.Payload(), MustGetCommonpbMsgTypeFromMessageType(msg.MessageType())) + tsMsg, err := UnmashalerDispatcher.Unmarshal(msg.Payload(), MustGetCommonpbMsgTypeFromMessageType(msg.MessageType())) if err != nil { return nil, errors.Wrap(err, "Failed to unmarshal message") } diff --git a/pkg/streaming/util/message/adaptor/handler.go b/pkg/streaming/util/message/adaptor/handler.go index a85faf7a9a..57bcb56edb 100644 --- a/pkg/streaming/util/message/adaptor/handler.go +++ b/pkg/streaming/util/message/adaptor/handler.go @@ -116,6 +116,8 @@ func (m *BaseMsgPackAdaptorHandler) GenerateMsgPack(msg message.ImmutableMessage switch msg.Version() { case message.VersionOld: if len(m.Pendings) != 0 { + // multiple message from old version may share the same time tick. + // should be packed into one msgPack. if msg.TimeTick() > m.Pendings[0].TimeTick() { m.addMsgPackIntoPending(m.Pendings...) m.Pendings = nil diff --git a/pkg/streaming/util/message/adaptor/message.go b/pkg/streaming/util/message/adaptor/message.go index 12a958bbe8..70064f7640 100644 --- a/pkg/streaming/util/message/adaptor/message.go +++ b/pkg/streaming/util/message/adaptor/message.go @@ -9,7 +9,7 @@ import ( "github.com/milvus-io/milvus/pkg/streaming/util/message" ) -var unmashalerDispatcher = (&msgstream.ProtoUDFactory{}).NewUnmarshalDispatcher() +var UnmashalerDispatcher = (&msgstream.ProtoUDFactory{}).NewUnmarshalDispatcher() // FromMessageToMsgPack converts message to msgpack. // Same TimeTick must be sent with same msgpack. @@ -97,8 +97,6 @@ func parseTxnMsg(msg message.ImmutableMessage) ([]msgstream.TsMsg, error) { // parseSingleMsg converts message to ts message. func parseSingleMsg(msg message.ImmutableMessage) (msgstream.TsMsg, error) { switch msg.Version() { - case message.VersionOld: - return fromMessageToTsMsgVOld(msg) case message.VersionV1: return fromMessageToTsMsgV1(msg) case message.VersionV2: @@ -108,13 +106,9 @@ func parseSingleMsg(msg message.ImmutableMessage) (msgstream.TsMsg, error) { } } -func fromMessageToTsMsgVOld(msg message.ImmutableMessage) (msgstream.TsMsg, error) { - panic("Not implemented") -} - // fromMessageToTsMsgV1 converts message to ts message. func fromMessageToTsMsgV1(msg message.ImmutableMessage) (msgstream.TsMsg, error) { - tsMsg, err := unmashalerDispatcher.Unmarshal(msg.Payload(), MustGetCommonpbMsgTypeFromMessageType(msg.MessageType())) + tsMsg, err := UnmashalerDispatcher.Unmarshal(msg.Payload(), MustGetCommonpbMsgTypeFromMessageType(msg.MessageType())) if err != nil { return nil, errors.Wrap(err, "Failed to unmarshal message") } @@ -227,7 +221,7 @@ func recoverDeleteMsgFromHeader(deleteMsg *msgstream.DeleteMsg, header *message. return deleteMsg, nil } -func recoverImportMsgFromHeader(importMsg *msgstream.ImportMsg, header *message.ImportMessageHeader, timetick uint64) (msgstream.TsMsg, error) { +func recoverImportMsgFromHeader(importMsg *msgstream.ImportMsg, _ *message.ImportMessageHeader, timetick uint64) (msgstream.TsMsg, error) { importMsg.Base.Timestamp = timetick return importMsg, nil } diff --git a/pkg/streaming/util/message/version.go b/pkg/streaming/util/message/version.go index bed4625966..6b260ea12b 100644 --- a/pkg/streaming/util/message/version.go +++ b/pkg/streaming/util/message/version.go @@ -3,7 +3,7 @@ package message import "strconv" var ( - VersionOld Version = 0 // old version before streamingnode. + VersionOld Version = 0 // old version before streamingnode, keep in 2.6 and will be removed from 3.0. VersionV1 Version = 1 // The message marshal unmarshal still use msgstream. VersionV2 Version = 2 // The message marshal unmarshal never rely on msgstream. )