diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go index e654a521aa..def09310e4 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go @@ -12,8 +12,8 @@ func newTimeTickMsg(ts uint64, sourceID int64) (message.MutableMessage, error) { // Common message's time tick is set on interceptor. // TimeTickMsg's time tick should be set here. msg, err := message.NewTimeTickMessageBuilderV1(). - WithMessageHeader(&message.TimeTickMessageHeader{}). - WithPayload(&msgpb.TimeTickMsg{ + WithHeader(&message.TimeTickMessageHeader{}). + WithBody(&msgpb.TimeTickMsg{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_TimeTick), commonpbutil.WithMsgID(0), diff --git a/pkg/streaming/util/message/adaptor/message.go b/pkg/streaming/util/message/adaptor/message.go index e9d012acd1..22f571daf1 100644 --- a/pkg/streaming/util/message/adaptor/message.go +++ b/pkg/streaming/util/message/adaptor/message.go @@ -83,13 +83,13 @@ func fromMessageToTsMsgV1(msg message.ImmutableMessage) (msgstream.TsMsg, error) func recoverMessageFromHeader(tsMsg msgstream.TsMsg, msg message.ImmutableMessage) (msgstream.TsMsg, error) { switch msg.MessageType() { case message.MessageTypeInsert: - insertMessage, err := message.AsImmutableInsertMessage(msg) + insertMessage, err := message.AsImmutableInsertMessageV1(msg) if err != nil { return nil, errors.Wrap(err, "Failed to convert message to insert message") } // insertMsg has multiple partition and segment assignment is done by insert message header. // so recover insert message from header before send it. - return recoverInsertMsgFromHeader(tsMsg.(*msgstream.InsertMsg), insertMessage.MessageHeader(), msg.TimeTick()) + return recoverInsertMsgFromHeader(tsMsg.(*msgstream.InsertMsg), insertMessage.Header(), msg.TimeTick()) default: return tsMsg, nil } diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go index 44abba0604..672dc122e8 100644 --- a/pkg/streaming/util/message/builder.go +++ b/pkg/streaming/util/message/builder.go @@ -46,47 +46,47 @@ var ( ) // createNewMessageBuilderV1 creates a new message builder with v1 marker. -func createNewMessageBuilderV1[H proto.Message, P proto.Message]() func() *mutableMesasgeBuilder[H, P] { - return func() *mutableMesasgeBuilder[H, P] { - return newMutableMessageBuilder[H, P](VersionV1) +func createNewMessageBuilderV1[H proto.Message, B proto.Message]() func() *mutableMesasgeBuilder[H, B] { + return func() *mutableMesasgeBuilder[H, B] { + return newMutableMessageBuilder[H, B](VersionV1) } } // newMutableMessageBuilder creates a new builder. // Should only used at client side. -func newMutableMessageBuilder[H proto.Message, P proto.Message](v Version) *mutableMesasgeBuilder[H, P] { +func newMutableMessageBuilder[H proto.Message, B proto.Message](v Version) *mutableMesasgeBuilder[H, B] { var h H - messageType := mustGetMessageTypeFromMessageHeader(h) + messageType := mustGetMessageTypeFromHeader(h) properties := make(propertiesImpl) properties.Set(messageTypeKey, messageType.marshal()) properties.Set(messageVersion, v.String()) - return &mutableMesasgeBuilder[H, P]{ + return &mutableMesasgeBuilder[H, B]{ properties: properties, } } // mutableMesasgeBuilder is the builder for message. -type mutableMesasgeBuilder[H proto.Message, P proto.Message] struct { +type mutableMesasgeBuilder[H proto.Message, B proto.Message] struct { header H - payload P + body B properties propertiesImpl } // WithMessageHeader creates a new builder with determined message type. -func (b *mutableMesasgeBuilder[H, P]) WithMessageHeader(h H) *mutableMesasgeBuilder[H, P] { +func (b *mutableMesasgeBuilder[H, B]) WithHeader(h H) *mutableMesasgeBuilder[H, B] { b.header = h return b } -// WithPayload creates a new builder with message payload. -func (b *mutableMesasgeBuilder[H, P]) WithPayload(p P) *mutableMesasgeBuilder[H, P] { - b.payload = p +// WithBody creates a new builder with message body. +func (b *mutableMesasgeBuilder[H, B]) WithBody(body B) *mutableMesasgeBuilder[H, B] { + b.body = body return b } // WithProperty creates a new builder with message property. // A key started with '_' is reserved for streaming system, should never used at user of client. -func (b *mutableMesasgeBuilder[H, P]) WithProperty(key string, val string) *mutableMesasgeBuilder[H, P] { +func (b *mutableMesasgeBuilder[H, B]) WithProperty(key string, val string) *mutableMesasgeBuilder[H, B] { if b.properties.Exist(key) { panic(fmt.Sprintf("message builder already set property field, key = %s", key)) } @@ -96,7 +96,7 @@ func (b *mutableMesasgeBuilder[H, P]) WithProperty(key string, val string) *muta // WithProperties creates a new builder with message properties. // A key started with '_' is reserved for streaming system, should never used at user of client. -func (b *mutableMesasgeBuilder[H, P]) WithProperties(kvs map[string]string) *mutableMesasgeBuilder[H, P] { +func (b *mutableMesasgeBuilder[H, B]) WithProperties(kvs map[string]string) *mutableMesasgeBuilder[H, B] { for key, val := range kvs { b.properties.Set(key, val) } @@ -106,13 +106,13 @@ func (b *mutableMesasgeBuilder[H, P]) WithProperties(kvs map[string]string) *mut // BuildMutable builds a mutable message. // Panic if not set payload and message type. // should only used at client side. -func (b *mutableMesasgeBuilder[H, P]) BuildMutable() (MutableMessage, error) { +func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) { // payload and header must be a pointer if reflect.ValueOf(b.header).IsNil() { panic("message builder not ready for header field") } - if reflect.ValueOf(b.payload).IsNil() { - panic("message builder not ready for payload field") + if reflect.ValueOf(b.body).IsNil() { + panic("message builder not ready for body field") } // setup header. @@ -122,9 +122,9 @@ func (b *mutableMesasgeBuilder[H, P]) BuildMutable() (MutableMessage, error) { } b.properties.Set(messageSpecialiedHeader, sp) - payload, err := proto.Marshal(b.payload) + payload, err := proto.Marshal(b.body) if err != nil { - return nil, errors.Wrap(err, "failed to marshal payload") + return nil, errors.Wrap(err, "failed to marshal body") } return &messageImpl{ payload: payload, diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go index c4c5f1e043..71d369c4fa 100644 --- a/pkg/streaming/util/message/message.go +++ b/pkg/streaming/util/message/message.go @@ -27,6 +27,16 @@ type BasicMessage interface { // Properties returns the message properties. // Should be used with read-only promise. Properties() RProperties + + // VChannel returns the virtual channel of current message. + // Available only when the message's version greater than 0. + // Otherwise, it will panic. + VChannel() string + + // TimeTick returns the time tick of current message. + // Available only when the message's version greater than 0. + // Otherwise, it will panic. + TimeTick() uint64 } // MutableMessage is the mutable message interface. @@ -58,15 +68,8 @@ type ImmutableMessage interface { // WALName returns the name of message related wal. WALName() string - // VChannel returns the virtual channel of current message. - // Available only when the message's version greater than 0. - // Otherwise, it will panic. - VChannel() string - - // TimeTick returns the time tick of current message. - // Available only when the message's version greater than 0. - // Otherwise, it will panic. - TimeTick() uint64 + // MessageID returns the message id of current message. + MessageID() MessageID // LastConfirmedMessageID returns the last confirmed message id of current message. // last confirmed message is always a timetick message. @@ -74,13 +77,10 @@ type ImmutableMessage interface { // Available only when the message's version greater than 0. // Otherwise, it will panic. LastConfirmedMessageID() MessageID - - // MessageID returns the message id of current message. - MessageID() MessageID } // specializedMutableMessage is the specialized mutable message interface. -type specializedMutableMessage[H proto.Message] interface { +type specializedMutableMessage[H proto.Message, B proto.Message] interface { BasicMessage // VChannel returns the vchannel of the message. @@ -89,19 +89,27 @@ type specializedMutableMessage[H proto.Message] interface { // TimeTick returns the time tick of the message. TimeTick() uint64 - // MessageHeader returns the message header. + // Header returns the message header. // Modifications to the returned header will be reflected in the message. - MessageHeader() H + Header() H - // OverwriteMessageHeader overwrites the message header. - OverwriteMessageHeader(header H) + // Body returns the message body. + // !!! Do these will trigger a unmarshal operation, so it should be used with caution. + Body() (B, error) + + // OverwriteHeader overwrites the message header. + OverwriteHeader(header H) } // specializedImmutableMessage is the specialized immutable message interface. -type specializedImmutableMessage[H proto.Message] interface { +type specializedImmutableMessage[H proto.Message, B proto.Message] interface { ImmutableMessage - // MessageHeader returns the message header. + // Header returns the message header. // Modifications to the returned header will be reflected in the message. - MessageHeader() H + Header() H + + // Body returns the message body. + // !!! Do these will trigger a unmarshal operation, so it should be used with caution. + Body() (B, error) } diff --git a/pkg/streaming/util/message/message_builder_test.go b/pkg/streaming/util/message/message_builder_test.go index 00f26fd6b3..f425a18391 100644 --- a/pkg/streaming/util/message/message_builder_test.go +++ b/pkg/streaming/util/message/message_builder_test.go @@ -15,9 +15,10 @@ import ( func TestMessage(t *testing.T) { b := message.NewTimeTickMessageBuilderV1() - mutableMessage, err := b.WithMessageHeader(&message.TimeTickMessageHeader{}). + mutableMessage, err := b.WithHeader(&message.TimeTickMessageHeader{}). WithProperties(map[string]string{"key": "value"}). - WithPayload(&msgpb.TimeTickMsg{}).BuildMutable() + WithProperty("key2", "value2"). + WithBody(&msgpb.TimeTickMsg{}).BuildMutable() assert.NoError(t, err) payload, err := proto.Marshal(&message.TimeTickMessageHeader{}) @@ -26,16 +27,18 @@ func TestMessage(t *testing.T) { assert.True(t, bytes.Equal(payload, mutableMessage.Payload())) assert.True(t, mutableMessage.Properties().Exist("key")) v, ok := mutableMessage.Properties().Get("key") + assert.True(t, mutableMessage.Properties().Exist("key2")) assert.Equal(t, "value", v) assert.True(t, ok) assert.Equal(t, message.MessageTypeTimeTick, mutableMessage.MessageType()) - assert.Equal(t, 20, mutableMessage.EstimateSize()) + assert.Equal(t, 30, mutableMessage.EstimateSize()) mutableMessage.WithTimeTick(123) v, ok = mutableMessage.Properties().Get("_tt") assert.True(t, ok) tt, err := message.DecodeUint64(v) assert.Equal(t, uint64(123), tt) assert.NoError(t, err) + assert.Equal(t, uint64(123), mutableMessage.TimeTick()) lcMsgID := mock_message.NewMockMessageID(t) lcMsgID.EXPECT().Marshal().Return("lcMsgID") @@ -44,6 +47,12 @@ func TestMessage(t *testing.T) { assert.True(t, ok) assert.Equal(t, v, "lcMsgID") + mutableMessage.WithVChannel("v1") + v, ok = mutableMessage.Properties().Get("_vc") + assert.True(t, ok) + assert.Equal(t, "v1", v) + assert.Equal(t, "v1", mutableMessage.VChannel()) + msgID := mock_message.NewMockMessageID(t) msgID.EXPECT().EQ(msgID).Return(true) msgID.EXPECT().WALName().Return("testMsgID") diff --git a/pkg/streaming/util/message/message_impl.go b/pkg/streaming/util/message/message_impl.go index 0dc0c7da48..b974455781 100644 --- a/pkg/streaming/util/message/message_impl.go +++ b/pkg/streaming/util/message/message_impl.go @@ -101,25 +101,9 @@ func (m *immutableMessageImpl) WALName() string { return m.id.WALName() } -// TimeTick returns the time tick of current message. -func (m *immutableMessageImpl) TimeTick() uint64 { - value, ok := m.properties.Get(messageTimeTick) - if !ok { - panic(fmt.Sprintf("there's a bug in the message codes, timetick lost in properties of message, id: %+v", m.id)) - } - tt, err := DecodeUint64(value) - if err != nil { - panic(fmt.Sprintf("there's a bug in the message codes, dirty timetick %s in properties of message, id: %+v", value, m.id)) - } - return tt -} - -func (m *immutableMessageImpl) VChannel() string { - value, ok := m.properties.Get(messageVChannel) - if !ok { - panic(fmt.Sprintf("there's a bug in the message codes, vchannel lost in properties of message, id: %+v", m.id)) - } - return value +// MessageID returns the message id. +func (m *immutableMessageImpl) MessageID() MessageID { + return m.id } func (m *immutableMessageImpl) LastConfirmedMessageID() MessageID { @@ -133,8 +117,3 @@ func (m *immutableMessageImpl) LastConfirmedMessageID() MessageID { } return id } - -// MessageID returns the message id. -func (m *immutableMessageImpl) MessageID() MessageID { - return m.id -} diff --git a/pkg/streaming/util/message/specialized_message.go b/pkg/streaming/util/message/specialized_message.go index 2864dca985..f3f2faa4a3 100644 --- a/pkg/streaming/util/message/specialized_message.go +++ b/pkg/streaming/util/message/specialized_message.go @@ -7,6 +7,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/pkg/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/streaming/util/message/messagepb" ) @@ -35,51 +36,51 @@ var messageTypeMap = map[reflect.Type]MessageType{ // List all specialized message types. type ( - MutableTimeTickMessage = specializedMutableMessage[*TimeTickMessageHeader] - MutableInsertMessage = specializedMutableMessage[*InsertMessageHeader] - MutableDeleteMessage = specializedMutableMessage[*DeleteMessageHeader] - MutableCreateCollection = specializedMutableMessage[*CreateCollectionMessageHeader] - MutableDropCollection = specializedMutableMessage[*DropCollectionMessageHeader] - MutableCreatePartition = specializedMutableMessage[*CreatePartitionMessageHeader] - MutableDropPartition = specializedMutableMessage[*DropPartitionMessageHeader] + MutableTimeTickMessageV1 = specializedMutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg] + MutableInsertMessageV1 = specializedMutableMessage[*InsertMessageHeader, *msgpb.InsertRequest] + MutableDeleteMessageV1 = specializedMutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest] + MutableCreateCollectionMessageV1 = specializedMutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest] + MutableDropCollectionMessageV1 = specializedMutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest] + MutableCreatePartitionMessageV1 = specializedMutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] + MutableDropPartitionMessageV1 = specializedMutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] - ImmutableTimeTickMessage = specializedImmutableMessage[*TimeTickMessageHeader] - ImmutableInsertMessage = specializedImmutableMessage[*InsertMessageHeader] - ImmutableDeleteMessage = specializedImmutableMessage[*DeleteMessageHeader] - ImmutableCreateCollection = specializedImmutableMessage[*CreateCollectionMessageHeader] - ImmutableDropCollection = specializedImmutableMessage[*DropCollectionMessageHeader] - ImmutableCreatePartition = specializedImmutableMessage[*CreatePartitionMessageHeader] - ImmutableDropPartition = specializedImmutableMessage[*DropPartitionMessageHeader] + ImmutableTimeTickMessageV1 = specializedImmutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg] + ImmutableInsertMessageV1 = specializedImmutableMessage[*InsertMessageHeader, *msgpb.InsertRequest] + ImmutableDeleteMessageV1 = specializedImmutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest] + ImmutableCreateCollectionMessageV1 = specializedImmutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest] + ImmutableDropCollectionMessageV1 = specializedImmutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest] + ImmutableCreatePartitionMessageV1 = specializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] + ImmutableDropPartitionMessageV1 = specializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] ) // List all as functions for specialized messages. var ( - AsMutableTimeTickMessage = asSpecializedMutableMessage[*TimeTickMessageHeader] - AsMutableInsertMessage = asSpecializedMutableMessage[*InsertMessageHeader] - AsMutableDeleteMessage = asSpecializedMutableMessage[*DeleteMessageHeader] - AsMutableCreateCollection = asSpecializedMutableMessage[*CreateCollectionMessageHeader] - AsMutableDropCollection = asSpecializedMutableMessage[*DropCollectionMessageHeader] - AsMutableCreatePartition = asSpecializedMutableMessage[*CreatePartitionMessageHeader] - AsMutableDropPartition = asSpecializedMutableMessage[*DropPartitionMessageHeader] + AsMutableTimeTickMessageV1 = asSpecializedMutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg] + AsMutableInsertMessageV1 = asSpecializedMutableMessage[*InsertMessageHeader, *msgpb.InsertRequest] + AsMutableDeleteMessageV1 = asSpecializedMutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest] + AsMutableCreateCollectionMessageV1 = asSpecializedMutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest] + AsMutableDropCollectionMessageV1 = asSpecializedMutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest] + AsMutableCreatePartitionMessageV1 = asSpecializedMutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] + AsMutableDropPartitionMessageV1 = asSpecializedMutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] - AsImmutableTimeTickMessage = asSpecializedImmutableMessage[*TimeTickMessageHeader] - AsImmutableInsertMessage = asSpecializedImmutableMessage[*InsertMessageHeader] - AsImmutableDeleteMessage = asSpecializedImmutableMessage[*DeleteMessageHeader] - AsImmutableCreateCollection = asSpecializedImmutableMessage[*CreateCollectionMessageHeader] - AsImmutableDropCollection = asSpecializedImmutableMessage[*DropCollectionMessageHeader] - AsImmutableCreatePartition = asSpecializedImmutableMessage[*CreatePartitionMessageHeader] - AsImmutableDropPartition = asSpecializedImmutableMessage[*DropPartitionMessageHeader] + AsImmutableTimeTickMessageV1 = asSpecializedImmutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg] + AsImmutableInsertMessageV1 = asSpecializedImmutableMessage[*InsertMessageHeader, *msgpb.InsertRequest] + AsImmutableDeleteMessageV1 = asSpecializedImmutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest] + AsImmutableCreateCollectionMessageV1 = asSpecializedImmutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest] + AsImmutableDropCollectionMessageV1 = asSpecializedImmutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest] + AsImmutableCreatePartitionMessageV1 = asSpecializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] + AsImmutableDropPartitionMessageV1 = asSpecializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] ) // asSpecializedMutableMessage converts a MutableMessage to a specialized MutableMessage. // Return nil, nil if the message is not the target specialized message. // Return nil, error if the message is the target specialized message but failed to decode the specialized header. // Return specializedMutableMessage, nil if the message is the target specialized message and successfully decoded the specialized header. -func asSpecializedMutableMessage[H proto.Message](msg MutableMessage) (specializedMutableMessage[H], error) { +func asSpecializedMutableMessage[H proto.Message, B proto.Message](msg MutableMessage) (specializedMutableMessage[H, B], error) { underlying := msg.(*messageImpl) var header H - msgType := mustGetMessageTypeFromMessageHeader(header) + msgType := mustGetMessageTypeFromHeader(header) if underlying.MessageType() != msgType { // The message type do not match the specialized header. return nil, nil @@ -101,7 +102,7 @@ func asSpecializedMutableMessage[H proto.Message](msg MutableMessage) (specializ if err := DecodeProto(val, header); err != nil { return nil, errors.Wrap(err, "failed to decode specialized header") } - return &specializedMutableMessageImpl[H]{ + return &specializedMutableMessageImpl[H, B]{ header: header, messageImpl: underlying, }, nil @@ -111,11 +112,11 @@ func asSpecializedMutableMessage[H proto.Message](msg MutableMessage) (specializ // Return nil, nil if the message is not the target specialized message. // Return nil, error if the message is the target specialized message but failed to decode the specialized header. // Return asSpecializedImmutableMessage, nil if the message is the target specialized message and successfully decoded the specialized header. -func asSpecializedImmutableMessage[H proto.Message](msg ImmutableMessage) (specializedImmutableMessage[H], error) { +func asSpecializedImmutableMessage[H proto.Message, B proto.Message](msg ImmutableMessage) (specializedImmutableMessage[H, B], error) { underlying := msg.(*immutableMessageImpl) var header H - msgType := mustGetMessageTypeFromMessageHeader(header) + msgType := mustGetMessageTypeFromHeader(header) if underlying.MessageType() != msgType { // The message type do not match the specialized header. return nil, nil @@ -137,14 +138,14 @@ func asSpecializedImmutableMessage[H proto.Message](msg ImmutableMessage) (speci if err := DecodeProto(val, header); err != nil { return nil, errors.Wrap(err, "failed to decode specialized header") } - return &specializedImmutableMessageImpl[H]{ + return &specializedImmutableMessageImpl[H, B]{ header: header, immutableMessageImpl: underlying, }, nil } // mustGetMessageTypeFromMessageHeader returns the message type of the given message header. -func mustGetMessageTypeFromMessageHeader(msg proto.Message) MessageType { +func mustGetMessageTypeFromHeader(msg proto.Message) MessageType { t := reflect.TypeOf(msg) mt, ok := messageTypeMap[t] if !ok { @@ -154,18 +155,23 @@ func mustGetMessageTypeFromMessageHeader(msg proto.Message) MessageType { } // specializedMutableMessageImpl is the specialized mutable message implementation. -type specializedMutableMessageImpl[H proto.Message] struct { +type specializedMutableMessageImpl[H proto.Message, B proto.Message] struct { header H *messageImpl } // MessageHeader returns the message header. -func (m *specializedMutableMessageImpl[H]) MessageHeader() H { +func (m *specializedMutableMessageImpl[H, B]) Header() H { return m.header } +// Body returns the message body. +func (m *specializedMutableMessageImpl[H, B]) Body() (B, error) { + return unmarshalProtoB[B](m.payload) +} + // OverwriteMessageHeader overwrites the message header. -func (m *specializedMutableMessageImpl[H]) OverwriteMessageHeader(header H) { +func (m *specializedMutableMessageImpl[H, B]) OverwriteHeader(header H) { m.header = header newHeader, err := EncodeProto(m.header) if err != nil { @@ -175,12 +181,32 @@ func (m *specializedMutableMessageImpl[H]) OverwriteMessageHeader(header H) { } // specializedImmutableMessageImpl is the specialized immmutable message implementation. -type specializedImmutableMessageImpl[H proto.Message] struct { +type specializedImmutableMessageImpl[H proto.Message, B proto.Message] struct { header H *immutableMessageImpl } -// MessageHeader returns the message header. -func (m *specializedImmutableMessageImpl[H]) MessageHeader() H { +// Header returns the message header. +func (m *specializedImmutableMessageImpl[H, B]) Header() H { return m.header } + +// Body returns the message body. +func (m *specializedImmutableMessageImpl[H, B]) Body() (B, error) { + return unmarshalProtoB[B](m.payload) +} + +func unmarshalProtoB[B proto.Message](data []byte) (B, error) { + var nilBody B + // Decode the specialized header. + // Must be pointer type. + t := reflect.TypeOf(nilBody) + t.Elem() + body := reflect.New(t.Elem()).Interface().(B) + + err := proto.Unmarshal(data, body) + if err != nil { + return nilBody, err + } + return body, nil +} diff --git a/pkg/streaming/util/message/specialized_message_test.go b/pkg/streaming/util/message/specialized_message_test.go index ac7f017b57..c76ea333e8 100644 --- a/pkg/streaming/util/message/specialized_message_test.go +++ b/pkg/streaming/util/message/specialized_message_test.go @@ -12,7 +12,7 @@ import ( func TestAsSpecializedMessage(t *testing.T) { m, err := message.NewInsertMessageBuilderV1(). - WithMessageHeader(&message.InsertMessageHeader{ + WithHeader(&message.InsertMessageHeader{ CollectionId: 1, Partitions: []*message.PartitionSegmentAssignment{ { @@ -22,33 +22,41 @@ func TestAsSpecializedMessage(t *testing.T) { }, }, }). - WithPayload(&msgpb.InsertRequest{}).BuildMutable() + WithBody(&msgpb.InsertRequest{ + CollectionID: 1, + }).BuildMutable() assert.NoError(t, err) - insertMsg, err := message.AsMutableInsertMessage(m) + insertMsg, err := message.AsMutableInsertMessageV1(m) assert.NoError(t, err) assert.NotNil(t, insertMsg) - assert.Equal(t, int64(1), insertMsg.MessageHeader().CollectionId) + assert.Equal(t, int64(1), insertMsg.Header().CollectionId) + body, err := insertMsg.Body() + assert.NoError(t, err) + assert.Equal(t, int64(1), body.CollectionID) - h := insertMsg.MessageHeader() + h := insertMsg.Header() h.Partitions[0].SegmentAssignment = &message.SegmentAssignment{ SegmentId: 1, } - insertMsg.OverwriteMessageHeader(h) + insertMsg.OverwriteHeader(h) - createColMsg, err := message.AsMutableCreateCollection(m) + createColMsg, err := message.AsMutableCreateCollectionMessageV1(m) assert.NoError(t, err) assert.Nil(t, createColMsg) m2 := m.IntoImmutableMessage(mock_message.NewMockMessageID(t)) - insertMsg2, err := message.AsImmutableInsertMessage(m2) + insertMsg2, err := message.AsImmutableInsertMessageV1(m2) assert.NoError(t, err) assert.NotNil(t, insertMsg2) - assert.Equal(t, int64(1), insertMsg2.MessageHeader().CollectionId) - assert.Equal(t, insertMsg2.MessageHeader().Partitions[0].SegmentAssignment.SegmentId, int64(1)) + assert.Equal(t, int64(1), insertMsg2.Header().CollectionId) + assert.Equal(t, insertMsg2.Header().Partitions[0].SegmentAssignment.SegmentId, int64(1)) + body, err = insertMsg2.Body() + assert.NoError(t, err) + assert.Equal(t, int64(1), body.CollectionID) - createColMsg2, err := message.AsMutableCreateCollection(m) + createColMsg2, err := message.AsMutableCreateCollectionMessageV1(m) assert.NoError(t, err) assert.Nil(t, createColMsg2) } diff --git a/pkg/streaming/util/message/test_case.go b/pkg/streaming/util/message/test_case.go index 3ffb7707b4..25ce304e58 100644 --- a/pkg/streaming/util/message/test_case.go +++ b/pkg/streaming/util/message/test_case.go @@ -61,7 +61,7 @@ func CreateTestInsertMessage(t *testing.T, segmentID int64, totalRows int, timet }, } msg, err := NewInsertMessageBuilderV1(). - WithMessageHeader(&InsertMessageHeader{ + WithHeader(&InsertMessageHeader{ CollectionId: 1, Partitions: []*PartitionSegmentAssignment{ { @@ -72,7 +72,7 @@ func CreateTestInsertMessage(t *testing.T, segmentID int64, totalRows int, timet }, }, }). - WithPayload(&msgpb.InsertRequest{ + WithBody(&msgpb.InsertRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Insert, Timestamp: 100, @@ -119,8 +119,8 @@ func CreateTestCreateCollectionMessage(t *testing.T, collectionID int64, timetic } msg, err := NewCreateCollectionMessageBuilderV1(). - WithMessageHeader(header). - WithPayload(payload). + WithHeader(header). + WithBody(payload). BuildMutable() assert.NoError(t, err) msg.WithVChannel("v1") @@ -132,8 +132,8 @@ func CreateTestCreateCollectionMessage(t *testing.T, collectionID int64, timetic // CreateTestEmptyInsertMesage creates an empty insert message for testing func CreateTestEmptyInsertMesage(msgID int64, extraProperties map[string]string) MutableMessage { msg, err := NewInsertMessageBuilderV1(). - WithMessageHeader(&InsertMessageHeader{}). - WithPayload(&msgpb.InsertRequest{ + WithHeader(&InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Insert, MsgID: msgID, diff --git a/pkg/streaming/walimpls/test_framework.go b/pkg/streaming/walimpls/test_framework.go index 3ba8c95ebd..f1faa70049 100644 --- a/pkg/streaming/walimpls/test_framework.go +++ b/pkg/streaming/walimpls/test_framework.go @@ -248,7 +248,7 @@ func (f *testOneWALImplsFramework) testAppend(ctx context.Context, w WALImpls) ( "const": "t", "term": strconv.FormatInt(int64(f.term), 10), } - msg, err := message.NewTimeTickMessageBuilderV1().WithMessageHeader(&message.TimeTickMessageHeader{}).WithPayload(&msgpb.TimeTickMsg{ + msg, err := message.NewTimeTickMessageBuilderV1().WithHeader(&message.TimeTickMessageHeader{}).WithBody(&msgpb.TimeTickMsg{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_TimeTick, MsgID: int64(f.messageCount - 1),