enhance: specialized immutable and mutable message (#34951)

issue: #33285

- add specialized mutable and immutable message, make type safe.
- add version based constructor and type.

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
chyezh 2024-07-25 11:57:45 +08:00 committed by GitHub
parent b843c91bad
commit 4f6cbfd520
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 159 additions and 129 deletions

View File

@ -12,8 +12,8 @@ func newTimeTickMsg(ts uint64, sourceID int64) (message.MutableMessage, error) {
// Common message's time tick is set on interceptor. // Common message's time tick is set on interceptor.
// TimeTickMsg's time tick should be set here. // TimeTickMsg's time tick should be set here.
msg, err := message.NewTimeTickMessageBuilderV1(). msg, err := message.NewTimeTickMessageBuilderV1().
WithMessageHeader(&message.TimeTickMessageHeader{}). WithHeader(&message.TimeTickMessageHeader{}).
WithPayload(&msgpb.TimeTickMsg{ WithBody(&msgpb.TimeTickMsg{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_TimeTick), commonpbutil.WithMsgType(commonpb.MsgType_TimeTick),
commonpbutil.WithMsgID(0), commonpbutil.WithMsgID(0),

View File

@ -83,13 +83,13 @@ func fromMessageToTsMsgV1(msg message.ImmutableMessage) (msgstream.TsMsg, error)
func recoverMessageFromHeader(tsMsg msgstream.TsMsg, msg message.ImmutableMessage) (msgstream.TsMsg, error) { func recoverMessageFromHeader(tsMsg msgstream.TsMsg, msg message.ImmutableMessage) (msgstream.TsMsg, error) {
switch msg.MessageType() { switch msg.MessageType() {
case message.MessageTypeInsert: case message.MessageTypeInsert:
insertMessage, err := message.AsImmutableInsertMessage(msg) insertMessage, err := message.AsImmutableInsertMessageV1(msg)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Failed to convert message to insert message") return nil, errors.Wrap(err, "Failed to convert message to insert message")
} }
// insertMsg has multiple partition and segment assignment is done by insert message header. // insertMsg has multiple partition and segment assignment is done by insert message header.
// so recover insert message from header before send it. // so recover insert message from header before send it.
return recoverInsertMsgFromHeader(tsMsg.(*msgstream.InsertMsg), insertMessage.MessageHeader(), msg.TimeTick()) return recoverInsertMsgFromHeader(tsMsg.(*msgstream.InsertMsg), insertMessage.Header(), msg.TimeTick())
default: default:
return tsMsg, nil return tsMsg, nil
} }

View File

@ -46,47 +46,47 @@ var (
) )
// createNewMessageBuilderV1 creates a new message builder with v1 marker. // createNewMessageBuilderV1 creates a new message builder with v1 marker.
func createNewMessageBuilderV1[H proto.Message, P proto.Message]() func() *mutableMesasgeBuilder[H, P] { func createNewMessageBuilderV1[H proto.Message, B proto.Message]() func() *mutableMesasgeBuilder[H, B] {
return func() *mutableMesasgeBuilder[H, P] { return func() *mutableMesasgeBuilder[H, B] {
return newMutableMessageBuilder[H, P](VersionV1) return newMutableMessageBuilder[H, B](VersionV1)
} }
} }
// newMutableMessageBuilder creates a new builder. // newMutableMessageBuilder creates a new builder.
// Should only used at client side. // Should only used at client side.
func newMutableMessageBuilder[H proto.Message, P proto.Message](v Version) *mutableMesasgeBuilder[H, P] { func newMutableMessageBuilder[H proto.Message, B proto.Message](v Version) *mutableMesasgeBuilder[H, B] {
var h H var h H
messageType := mustGetMessageTypeFromMessageHeader(h) messageType := mustGetMessageTypeFromHeader(h)
properties := make(propertiesImpl) properties := make(propertiesImpl)
properties.Set(messageTypeKey, messageType.marshal()) properties.Set(messageTypeKey, messageType.marshal())
properties.Set(messageVersion, v.String()) properties.Set(messageVersion, v.String())
return &mutableMesasgeBuilder[H, P]{ return &mutableMesasgeBuilder[H, B]{
properties: properties, properties: properties,
} }
} }
// mutableMesasgeBuilder is the builder for message. // mutableMesasgeBuilder is the builder for message.
type mutableMesasgeBuilder[H proto.Message, P proto.Message] struct { type mutableMesasgeBuilder[H proto.Message, B proto.Message] struct {
header H header H
payload P body B
properties propertiesImpl properties propertiesImpl
} }
// WithMessageHeader creates a new builder with determined message type. // WithMessageHeader creates a new builder with determined message type.
func (b *mutableMesasgeBuilder[H, P]) WithMessageHeader(h H) *mutableMesasgeBuilder[H, P] { func (b *mutableMesasgeBuilder[H, B]) WithHeader(h H) *mutableMesasgeBuilder[H, B] {
b.header = h b.header = h
return b return b
} }
// WithPayload creates a new builder with message payload. // WithBody creates a new builder with message body.
func (b *mutableMesasgeBuilder[H, P]) WithPayload(p P) *mutableMesasgeBuilder[H, P] { func (b *mutableMesasgeBuilder[H, B]) WithBody(body B) *mutableMesasgeBuilder[H, B] {
b.payload = p b.body = body
return b return b
} }
// WithProperty creates a new builder with message property. // WithProperty creates a new builder with message property.
// A key started with '_' is reserved for streaming system, should never used at user of client. // A key started with '_' is reserved for streaming system, should never used at user of client.
func (b *mutableMesasgeBuilder[H, P]) WithProperty(key string, val string) *mutableMesasgeBuilder[H, P] { func (b *mutableMesasgeBuilder[H, B]) WithProperty(key string, val string) *mutableMesasgeBuilder[H, B] {
if b.properties.Exist(key) { if b.properties.Exist(key) {
panic(fmt.Sprintf("message builder already set property field, key = %s", key)) panic(fmt.Sprintf("message builder already set property field, key = %s", key))
} }
@ -96,7 +96,7 @@ func (b *mutableMesasgeBuilder[H, P]) WithProperty(key string, val string) *muta
// WithProperties creates a new builder with message properties. // WithProperties creates a new builder with message properties.
// A key started with '_' is reserved for streaming system, should never used at user of client. // A key started with '_' is reserved for streaming system, should never used at user of client.
func (b *mutableMesasgeBuilder[H, P]) WithProperties(kvs map[string]string) *mutableMesasgeBuilder[H, P] { func (b *mutableMesasgeBuilder[H, B]) WithProperties(kvs map[string]string) *mutableMesasgeBuilder[H, B] {
for key, val := range kvs { for key, val := range kvs {
b.properties.Set(key, val) b.properties.Set(key, val)
} }
@ -106,13 +106,13 @@ func (b *mutableMesasgeBuilder[H, P]) WithProperties(kvs map[string]string) *mut
// BuildMutable builds a mutable message. // BuildMutable builds a mutable message.
// Panic if not set payload and message type. // Panic if not set payload and message type.
// should only used at client side. // should only used at client side.
func (b *mutableMesasgeBuilder[H, P]) BuildMutable() (MutableMessage, error) { func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) {
// payload and header must be a pointer // payload and header must be a pointer
if reflect.ValueOf(b.header).IsNil() { if reflect.ValueOf(b.header).IsNil() {
panic("message builder not ready for header field") panic("message builder not ready for header field")
} }
if reflect.ValueOf(b.payload).IsNil() { if reflect.ValueOf(b.body).IsNil() {
panic("message builder not ready for payload field") panic("message builder not ready for body field")
} }
// setup header. // setup header.
@ -122,9 +122,9 @@ func (b *mutableMesasgeBuilder[H, P]) BuildMutable() (MutableMessage, error) {
} }
b.properties.Set(messageSpecialiedHeader, sp) b.properties.Set(messageSpecialiedHeader, sp)
payload, err := proto.Marshal(b.payload) payload, err := proto.Marshal(b.body)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to marshal payload") return nil, errors.Wrap(err, "failed to marshal body")
} }
return &messageImpl{ return &messageImpl{
payload: payload, payload: payload,

View File

@ -27,6 +27,16 @@ type BasicMessage interface {
// Properties returns the message properties. // Properties returns the message properties.
// Should be used with read-only promise. // Should be used with read-only promise.
Properties() RProperties Properties() RProperties
// VChannel returns the virtual channel of current message.
// Available only when the message's version greater than 0.
// Otherwise, it will panic.
VChannel() string
// TimeTick returns the time tick of current message.
// Available only when the message's version greater than 0.
// Otherwise, it will panic.
TimeTick() uint64
} }
// MutableMessage is the mutable message interface. // MutableMessage is the mutable message interface.
@ -58,15 +68,8 @@ type ImmutableMessage interface {
// WALName returns the name of message related wal. // WALName returns the name of message related wal.
WALName() string WALName() string
// VChannel returns the virtual channel of current message. // MessageID returns the message id of current message.
// Available only when the message's version greater than 0. MessageID() MessageID
// Otherwise, it will panic.
VChannel() string
// TimeTick returns the time tick of current message.
// Available only when the message's version greater than 0.
// Otherwise, it will panic.
TimeTick() uint64
// LastConfirmedMessageID returns the last confirmed message id of current message. // LastConfirmedMessageID returns the last confirmed message id of current message.
// last confirmed message is always a timetick message. // last confirmed message is always a timetick message.
@ -74,13 +77,10 @@ type ImmutableMessage interface {
// Available only when the message's version greater than 0. // Available only when the message's version greater than 0.
// Otherwise, it will panic. // Otherwise, it will panic.
LastConfirmedMessageID() MessageID LastConfirmedMessageID() MessageID
// MessageID returns the message id of current message.
MessageID() MessageID
} }
// specializedMutableMessage is the specialized mutable message interface. // specializedMutableMessage is the specialized mutable message interface.
type specializedMutableMessage[H proto.Message] interface { type specializedMutableMessage[H proto.Message, B proto.Message] interface {
BasicMessage BasicMessage
// VChannel returns the vchannel of the message. // VChannel returns the vchannel of the message.
@ -89,19 +89,27 @@ type specializedMutableMessage[H proto.Message] interface {
// TimeTick returns the time tick of the message. // TimeTick returns the time tick of the message.
TimeTick() uint64 TimeTick() uint64
// MessageHeader returns the message header. // Header returns the message header.
// Modifications to the returned header will be reflected in the message. // Modifications to the returned header will be reflected in the message.
MessageHeader() H Header() H
// OverwriteMessageHeader overwrites the message header. // Body returns the message body.
OverwriteMessageHeader(header H) // !!! Do these will trigger a unmarshal operation, so it should be used with caution.
Body() (B, error)
// OverwriteHeader overwrites the message header.
OverwriteHeader(header H)
} }
// specializedImmutableMessage is the specialized immutable message interface. // specializedImmutableMessage is the specialized immutable message interface.
type specializedImmutableMessage[H proto.Message] interface { type specializedImmutableMessage[H proto.Message, B proto.Message] interface {
ImmutableMessage ImmutableMessage
// MessageHeader returns the message header. // Header returns the message header.
// Modifications to the returned header will be reflected in the message. // Modifications to the returned header will be reflected in the message.
MessageHeader() H Header() H
// Body returns the message body.
// !!! Do these will trigger a unmarshal operation, so it should be used with caution.
Body() (B, error)
} }

View File

@ -15,9 +15,10 @@ import (
func TestMessage(t *testing.T) { func TestMessage(t *testing.T) {
b := message.NewTimeTickMessageBuilderV1() b := message.NewTimeTickMessageBuilderV1()
mutableMessage, err := b.WithMessageHeader(&message.TimeTickMessageHeader{}). mutableMessage, err := b.WithHeader(&message.TimeTickMessageHeader{}).
WithProperties(map[string]string{"key": "value"}). WithProperties(map[string]string{"key": "value"}).
WithPayload(&msgpb.TimeTickMsg{}).BuildMutable() WithProperty("key2", "value2").
WithBody(&msgpb.TimeTickMsg{}).BuildMutable()
assert.NoError(t, err) assert.NoError(t, err)
payload, err := proto.Marshal(&message.TimeTickMessageHeader{}) payload, err := proto.Marshal(&message.TimeTickMessageHeader{})
@ -26,16 +27,18 @@ func TestMessage(t *testing.T) {
assert.True(t, bytes.Equal(payload, mutableMessage.Payload())) assert.True(t, bytes.Equal(payload, mutableMessage.Payload()))
assert.True(t, mutableMessage.Properties().Exist("key")) assert.True(t, mutableMessage.Properties().Exist("key"))
v, ok := mutableMessage.Properties().Get("key") v, ok := mutableMessage.Properties().Get("key")
assert.True(t, mutableMessage.Properties().Exist("key2"))
assert.Equal(t, "value", v) assert.Equal(t, "value", v)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, message.MessageTypeTimeTick, mutableMessage.MessageType()) assert.Equal(t, message.MessageTypeTimeTick, mutableMessage.MessageType())
assert.Equal(t, 20, mutableMessage.EstimateSize()) assert.Equal(t, 30, mutableMessage.EstimateSize())
mutableMessage.WithTimeTick(123) mutableMessage.WithTimeTick(123)
v, ok = mutableMessage.Properties().Get("_tt") v, ok = mutableMessage.Properties().Get("_tt")
assert.True(t, ok) assert.True(t, ok)
tt, err := message.DecodeUint64(v) tt, err := message.DecodeUint64(v)
assert.Equal(t, uint64(123), tt) assert.Equal(t, uint64(123), tt)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, uint64(123), mutableMessage.TimeTick())
lcMsgID := mock_message.NewMockMessageID(t) lcMsgID := mock_message.NewMockMessageID(t)
lcMsgID.EXPECT().Marshal().Return("lcMsgID") lcMsgID.EXPECT().Marshal().Return("lcMsgID")
@ -44,6 +47,12 @@ func TestMessage(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, v, "lcMsgID") assert.Equal(t, v, "lcMsgID")
mutableMessage.WithVChannel("v1")
v, ok = mutableMessage.Properties().Get("_vc")
assert.True(t, ok)
assert.Equal(t, "v1", v)
assert.Equal(t, "v1", mutableMessage.VChannel())
msgID := mock_message.NewMockMessageID(t) msgID := mock_message.NewMockMessageID(t)
msgID.EXPECT().EQ(msgID).Return(true) msgID.EXPECT().EQ(msgID).Return(true)
msgID.EXPECT().WALName().Return("testMsgID") msgID.EXPECT().WALName().Return("testMsgID")

View File

@ -101,25 +101,9 @@ func (m *immutableMessageImpl) WALName() string {
return m.id.WALName() return m.id.WALName()
} }
// TimeTick returns the time tick of current message. // MessageID returns the message id.
func (m *immutableMessageImpl) TimeTick() uint64 { func (m *immutableMessageImpl) MessageID() MessageID {
value, ok := m.properties.Get(messageTimeTick) return m.id
if !ok {
panic(fmt.Sprintf("there's a bug in the message codes, timetick lost in properties of message, id: %+v", m.id))
}
tt, err := DecodeUint64(value)
if err != nil {
panic(fmt.Sprintf("there's a bug in the message codes, dirty timetick %s in properties of message, id: %+v", value, m.id))
}
return tt
}
func (m *immutableMessageImpl) VChannel() string {
value, ok := m.properties.Get(messageVChannel)
if !ok {
panic(fmt.Sprintf("there's a bug in the message codes, vchannel lost in properties of message, id: %+v", m.id))
}
return value
} }
func (m *immutableMessageImpl) LastConfirmedMessageID() MessageID { func (m *immutableMessageImpl) LastConfirmedMessageID() MessageID {
@ -133,8 +117,3 @@ func (m *immutableMessageImpl) LastConfirmedMessageID() MessageID {
} }
return id return id
} }
// MessageID returns the message id.
func (m *immutableMessageImpl) MessageID() MessageID {
return m.id
}

View File

@ -7,6 +7,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/streaming/util/message/messagepb" "github.com/milvus-io/milvus/pkg/streaming/util/message/messagepb"
) )
@ -35,51 +36,51 @@ var messageTypeMap = map[reflect.Type]MessageType{
// List all specialized message types. // List all specialized message types.
type ( type (
MutableTimeTickMessage = specializedMutableMessage[*TimeTickMessageHeader] MutableTimeTickMessageV1 = specializedMutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg]
MutableInsertMessage = specializedMutableMessage[*InsertMessageHeader] MutableInsertMessageV1 = specializedMutableMessage[*InsertMessageHeader, *msgpb.InsertRequest]
MutableDeleteMessage = specializedMutableMessage[*DeleteMessageHeader] MutableDeleteMessageV1 = specializedMutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest]
MutableCreateCollection = specializedMutableMessage[*CreateCollectionMessageHeader] MutableCreateCollectionMessageV1 = specializedMutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest]
MutableDropCollection = specializedMutableMessage[*DropCollectionMessageHeader] MutableDropCollectionMessageV1 = specializedMutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest]
MutableCreatePartition = specializedMutableMessage[*CreatePartitionMessageHeader] MutableCreatePartitionMessageV1 = specializedMutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest]
MutableDropPartition = specializedMutableMessage[*DropPartitionMessageHeader] MutableDropPartitionMessageV1 = specializedMutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest]
ImmutableTimeTickMessage = specializedImmutableMessage[*TimeTickMessageHeader] ImmutableTimeTickMessageV1 = specializedImmutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg]
ImmutableInsertMessage = specializedImmutableMessage[*InsertMessageHeader] ImmutableInsertMessageV1 = specializedImmutableMessage[*InsertMessageHeader, *msgpb.InsertRequest]
ImmutableDeleteMessage = specializedImmutableMessage[*DeleteMessageHeader] ImmutableDeleteMessageV1 = specializedImmutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest]
ImmutableCreateCollection = specializedImmutableMessage[*CreateCollectionMessageHeader] ImmutableCreateCollectionMessageV1 = specializedImmutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest]
ImmutableDropCollection = specializedImmutableMessage[*DropCollectionMessageHeader] ImmutableDropCollectionMessageV1 = specializedImmutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest]
ImmutableCreatePartition = specializedImmutableMessage[*CreatePartitionMessageHeader] ImmutableCreatePartitionMessageV1 = specializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest]
ImmutableDropPartition = specializedImmutableMessage[*DropPartitionMessageHeader] ImmutableDropPartitionMessageV1 = specializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest]
) )
// List all as functions for specialized messages. // List all as functions for specialized messages.
var ( var (
AsMutableTimeTickMessage = asSpecializedMutableMessage[*TimeTickMessageHeader] AsMutableTimeTickMessageV1 = asSpecializedMutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg]
AsMutableInsertMessage = asSpecializedMutableMessage[*InsertMessageHeader] AsMutableInsertMessageV1 = asSpecializedMutableMessage[*InsertMessageHeader, *msgpb.InsertRequest]
AsMutableDeleteMessage = asSpecializedMutableMessage[*DeleteMessageHeader] AsMutableDeleteMessageV1 = asSpecializedMutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest]
AsMutableCreateCollection = asSpecializedMutableMessage[*CreateCollectionMessageHeader] AsMutableCreateCollectionMessageV1 = asSpecializedMutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest]
AsMutableDropCollection = asSpecializedMutableMessage[*DropCollectionMessageHeader] AsMutableDropCollectionMessageV1 = asSpecializedMutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest]
AsMutableCreatePartition = asSpecializedMutableMessage[*CreatePartitionMessageHeader] AsMutableCreatePartitionMessageV1 = asSpecializedMutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest]
AsMutableDropPartition = asSpecializedMutableMessage[*DropPartitionMessageHeader] AsMutableDropPartitionMessageV1 = asSpecializedMutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest]
AsImmutableTimeTickMessage = asSpecializedImmutableMessage[*TimeTickMessageHeader] AsImmutableTimeTickMessageV1 = asSpecializedImmutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg]
AsImmutableInsertMessage = asSpecializedImmutableMessage[*InsertMessageHeader] AsImmutableInsertMessageV1 = asSpecializedImmutableMessage[*InsertMessageHeader, *msgpb.InsertRequest]
AsImmutableDeleteMessage = asSpecializedImmutableMessage[*DeleteMessageHeader] AsImmutableDeleteMessageV1 = asSpecializedImmutableMessage[*DeleteMessageHeader, *msgpb.DeleteRequest]
AsImmutableCreateCollection = asSpecializedImmutableMessage[*CreateCollectionMessageHeader] AsImmutableCreateCollectionMessageV1 = asSpecializedImmutableMessage[*CreateCollectionMessageHeader, *msgpb.CreateCollectionRequest]
AsImmutableDropCollection = asSpecializedImmutableMessage[*DropCollectionMessageHeader] AsImmutableDropCollectionMessageV1 = asSpecializedImmutableMessage[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest]
AsImmutableCreatePartition = asSpecializedImmutableMessage[*CreatePartitionMessageHeader] AsImmutableCreatePartitionMessageV1 = asSpecializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest]
AsImmutableDropPartition = asSpecializedImmutableMessage[*DropPartitionMessageHeader] AsImmutableDropPartitionMessageV1 = asSpecializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest]
) )
// asSpecializedMutableMessage converts a MutableMessage to a specialized MutableMessage. // asSpecializedMutableMessage converts a MutableMessage to a specialized MutableMessage.
// Return nil, nil if the message is not the target specialized message. // Return nil, nil if the message is not the target specialized message.
// Return nil, error if the message is the target specialized message but failed to decode the specialized header. // Return nil, error if the message is the target specialized message but failed to decode the specialized header.
// Return specializedMutableMessage, nil if the message is the target specialized message and successfully decoded the specialized header. // Return specializedMutableMessage, nil if the message is the target specialized message and successfully decoded the specialized header.
func asSpecializedMutableMessage[H proto.Message](msg MutableMessage) (specializedMutableMessage[H], error) { func asSpecializedMutableMessage[H proto.Message, B proto.Message](msg MutableMessage) (specializedMutableMessage[H, B], error) {
underlying := msg.(*messageImpl) underlying := msg.(*messageImpl)
var header H var header H
msgType := mustGetMessageTypeFromMessageHeader(header) msgType := mustGetMessageTypeFromHeader(header)
if underlying.MessageType() != msgType { if underlying.MessageType() != msgType {
// The message type do not match the specialized header. // The message type do not match the specialized header.
return nil, nil return nil, nil
@ -101,7 +102,7 @@ func asSpecializedMutableMessage[H proto.Message](msg MutableMessage) (specializ
if err := DecodeProto(val, header); err != nil { if err := DecodeProto(val, header); err != nil {
return nil, errors.Wrap(err, "failed to decode specialized header") return nil, errors.Wrap(err, "failed to decode specialized header")
} }
return &specializedMutableMessageImpl[H]{ return &specializedMutableMessageImpl[H, B]{
header: header, header: header,
messageImpl: underlying, messageImpl: underlying,
}, nil }, nil
@ -111,11 +112,11 @@ func asSpecializedMutableMessage[H proto.Message](msg MutableMessage) (specializ
// Return nil, nil if the message is not the target specialized message. // Return nil, nil if the message is not the target specialized message.
// Return nil, error if the message is the target specialized message but failed to decode the specialized header. // Return nil, error if the message is the target specialized message but failed to decode the specialized header.
// Return asSpecializedImmutableMessage, nil if the message is the target specialized message and successfully decoded the specialized header. // Return asSpecializedImmutableMessage, nil if the message is the target specialized message and successfully decoded the specialized header.
func asSpecializedImmutableMessage[H proto.Message](msg ImmutableMessage) (specializedImmutableMessage[H], error) { func asSpecializedImmutableMessage[H proto.Message, B proto.Message](msg ImmutableMessage) (specializedImmutableMessage[H, B], error) {
underlying := msg.(*immutableMessageImpl) underlying := msg.(*immutableMessageImpl)
var header H var header H
msgType := mustGetMessageTypeFromMessageHeader(header) msgType := mustGetMessageTypeFromHeader(header)
if underlying.MessageType() != msgType { if underlying.MessageType() != msgType {
// The message type do not match the specialized header. // The message type do not match the specialized header.
return nil, nil return nil, nil
@ -137,14 +138,14 @@ func asSpecializedImmutableMessage[H proto.Message](msg ImmutableMessage) (speci
if err := DecodeProto(val, header); err != nil { if err := DecodeProto(val, header); err != nil {
return nil, errors.Wrap(err, "failed to decode specialized header") return nil, errors.Wrap(err, "failed to decode specialized header")
} }
return &specializedImmutableMessageImpl[H]{ return &specializedImmutableMessageImpl[H, B]{
header: header, header: header,
immutableMessageImpl: underlying, immutableMessageImpl: underlying,
}, nil }, nil
} }
// mustGetMessageTypeFromMessageHeader returns the message type of the given message header. // mustGetMessageTypeFromMessageHeader returns the message type of the given message header.
func mustGetMessageTypeFromMessageHeader(msg proto.Message) MessageType { func mustGetMessageTypeFromHeader(msg proto.Message) MessageType {
t := reflect.TypeOf(msg) t := reflect.TypeOf(msg)
mt, ok := messageTypeMap[t] mt, ok := messageTypeMap[t]
if !ok { if !ok {
@ -154,18 +155,23 @@ func mustGetMessageTypeFromMessageHeader(msg proto.Message) MessageType {
} }
// specializedMutableMessageImpl is the specialized mutable message implementation. // specializedMutableMessageImpl is the specialized mutable message implementation.
type specializedMutableMessageImpl[H proto.Message] struct { type specializedMutableMessageImpl[H proto.Message, B proto.Message] struct {
header H header H
*messageImpl *messageImpl
} }
// MessageHeader returns the message header. // MessageHeader returns the message header.
func (m *specializedMutableMessageImpl[H]) MessageHeader() H { func (m *specializedMutableMessageImpl[H, B]) Header() H {
return m.header return m.header
} }
// Body returns the message body.
func (m *specializedMutableMessageImpl[H, B]) Body() (B, error) {
return unmarshalProtoB[B](m.payload)
}
// OverwriteMessageHeader overwrites the message header. // OverwriteMessageHeader overwrites the message header.
func (m *specializedMutableMessageImpl[H]) OverwriteMessageHeader(header H) { func (m *specializedMutableMessageImpl[H, B]) OverwriteHeader(header H) {
m.header = header m.header = header
newHeader, err := EncodeProto(m.header) newHeader, err := EncodeProto(m.header)
if err != nil { if err != nil {
@ -175,12 +181,32 @@ func (m *specializedMutableMessageImpl[H]) OverwriteMessageHeader(header H) {
} }
// specializedImmutableMessageImpl is the specialized immmutable message implementation. // specializedImmutableMessageImpl is the specialized immmutable message implementation.
type specializedImmutableMessageImpl[H proto.Message] struct { type specializedImmutableMessageImpl[H proto.Message, B proto.Message] struct {
header H header H
*immutableMessageImpl *immutableMessageImpl
} }
// MessageHeader returns the message header. // Header returns the message header.
func (m *specializedImmutableMessageImpl[H]) MessageHeader() H { func (m *specializedImmutableMessageImpl[H, B]) Header() H {
return m.header return m.header
} }
// Body returns the message body.
func (m *specializedImmutableMessageImpl[H, B]) Body() (B, error) {
return unmarshalProtoB[B](m.payload)
}
func unmarshalProtoB[B proto.Message](data []byte) (B, error) {
var nilBody B
// Decode the specialized header.
// Must be pointer type.
t := reflect.TypeOf(nilBody)
t.Elem()
body := reflect.New(t.Elem()).Interface().(B)
err := proto.Unmarshal(data, body)
if err != nil {
return nilBody, err
}
return body, nil
}

View File

@ -12,7 +12,7 @@ import (
func TestAsSpecializedMessage(t *testing.T) { func TestAsSpecializedMessage(t *testing.T) {
m, err := message.NewInsertMessageBuilderV1(). m, err := message.NewInsertMessageBuilderV1().
WithMessageHeader(&message.InsertMessageHeader{ WithHeader(&message.InsertMessageHeader{
CollectionId: 1, CollectionId: 1,
Partitions: []*message.PartitionSegmentAssignment{ Partitions: []*message.PartitionSegmentAssignment{
{ {
@ -22,33 +22,41 @@ func TestAsSpecializedMessage(t *testing.T) {
}, },
}, },
}). }).
WithPayload(&msgpb.InsertRequest{}).BuildMutable() WithBody(&msgpb.InsertRequest{
CollectionID: 1,
}).BuildMutable()
assert.NoError(t, err) assert.NoError(t, err)
insertMsg, err := message.AsMutableInsertMessage(m) insertMsg, err := message.AsMutableInsertMessageV1(m)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, insertMsg) assert.NotNil(t, insertMsg)
assert.Equal(t, int64(1), insertMsg.MessageHeader().CollectionId) assert.Equal(t, int64(1), insertMsg.Header().CollectionId)
body, err := insertMsg.Body()
assert.NoError(t, err)
assert.Equal(t, int64(1), body.CollectionID)
h := insertMsg.MessageHeader() h := insertMsg.Header()
h.Partitions[0].SegmentAssignment = &message.SegmentAssignment{ h.Partitions[0].SegmentAssignment = &message.SegmentAssignment{
SegmentId: 1, SegmentId: 1,
} }
insertMsg.OverwriteMessageHeader(h) insertMsg.OverwriteHeader(h)
createColMsg, err := message.AsMutableCreateCollection(m) createColMsg, err := message.AsMutableCreateCollectionMessageV1(m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, createColMsg) assert.Nil(t, createColMsg)
m2 := m.IntoImmutableMessage(mock_message.NewMockMessageID(t)) m2 := m.IntoImmutableMessage(mock_message.NewMockMessageID(t))
insertMsg2, err := message.AsImmutableInsertMessage(m2) insertMsg2, err := message.AsImmutableInsertMessageV1(m2)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, insertMsg2) assert.NotNil(t, insertMsg2)
assert.Equal(t, int64(1), insertMsg2.MessageHeader().CollectionId) assert.Equal(t, int64(1), insertMsg2.Header().CollectionId)
assert.Equal(t, insertMsg2.MessageHeader().Partitions[0].SegmentAssignment.SegmentId, int64(1)) assert.Equal(t, insertMsg2.Header().Partitions[0].SegmentAssignment.SegmentId, int64(1))
body, err = insertMsg2.Body()
assert.NoError(t, err)
assert.Equal(t, int64(1), body.CollectionID)
createColMsg2, err := message.AsMutableCreateCollection(m) createColMsg2, err := message.AsMutableCreateCollectionMessageV1(m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, createColMsg2) assert.Nil(t, createColMsg2)
} }

View File

@ -61,7 +61,7 @@ func CreateTestInsertMessage(t *testing.T, segmentID int64, totalRows int, timet
}, },
} }
msg, err := NewInsertMessageBuilderV1(). msg, err := NewInsertMessageBuilderV1().
WithMessageHeader(&InsertMessageHeader{ WithHeader(&InsertMessageHeader{
CollectionId: 1, CollectionId: 1,
Partitions: []*PartitionSegmentAssignment{ Partitions: []*PartitionSegmentAssignment{
{ {
@ -72,7 +72,7 @@ func CreateTestInsertMessage(t *testing.T, segmentID int64, totalRows int, timet
}, },
}, },
}). }).
WithPayload(&msgpb.InsertRequest{ WithBody(&msgpb.InsertRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert, MsgType: commonpb.MsgType_Insert,
Timestamp: 100, Timestamp: 100,
@ -119,8 +119,8 @@ func CreateTestCreateCollectionMessage(t *testing.T, collectionID int64, timetic
} }
msg, err := NewCreateCollectionMessageBuilderV1(). msg, err := NewCreateCollectionMessageBuilderV1().
WithMessageHeader(header). WithHeader(header).
WithPayload(payload). WithBody(payload).
BuildMutable() BuildMutable()
assert.NoError(t, err) assert.NoError(t, err)
msg.WithVChannel("v1") msg.WithVChannel("v1")
@ -132,8 +132,8 @@ func CreateTestCreateCollectionMessage(t *testing.T, collectionID int64, timetic
// CreateTestEmptyInsertMesage creates an empty insert message for testing // CreateTestEmptyInsertMesage creates an empty insert message for testing
func CreateTestEmptyInsertMesage(msgID int64, extraProperties map[string]string) MutableMessage { func CreateTestEmptyInsertMesage(msgID int64, extraProperties map[string]string) MutableMessage {
msg, err := NewInsertMessageBuilderV1(). msg, err := NewInsertMessageBuilderV1().
WithMessageHeader(&InsertMessageHeader{}). WithHeader(&InsertMessageHeader{}).
WithPayload(&msgpb.InsertRequest{ WithBody(&msgpb.InsertRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert, MsgType: commonpb.MsgType_Insert,
MsgID: msgID, MsgID: msgID,

View File

@ -248,7 +248,7 @@ func (f *testOneWALImplsFramework) testAppend(ctx context.Context, w WALImpls) (
"const": "t", "const": "t",
"term": strconv.FormatInt(int64(f.term), 10), "term": strconv.FormatInt(int64(f.term), 10),
} }
msg, err := message.NewTimeTickMessageBuilderV1().WithMessageHeader(&message.TimeTickMessageHeader{}).WithPayload(&msgpb.TimeTickMsg{ msg, err := message.NewTimeTickMessageBuilderV1().WithHeader(&message.TimeTickMessageHeader{}).WithBody(&msgpb.TimeTickMsg{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_TimeTick, MsgType: commonpb.MsgType_TimeTick,
MsgID: int64(f.messageCount - 1), MsgID: int64(f.messageCount - 1),