From 0e83a08ffec4c66ecbffe31e1a4bd2abeda1e61a Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Wed, 26 Mar 2025 16:30:20 +0800 Subject: [PATCH] enhance: enable cipher for streaming message (#40659) issue: #40321 - add cipher plugin for streaming message package. - add more unittest for streaming message. - remove redundant code for streaming message. Signed-off-by: chyezh --- pkg/.mockery_pkg.yaml | 5 + .../go-api/v2/mock_hook/mock_Cipher.go | 255 ++++++++++++++++++ .../go-api/v2/mock_hook/mock_Decryptor.go | 90 +++++++ .../go-api/v2/mock_hook/mock_Encryptor.go | 90 +++++++ pkg/proto/messages.proto | 7 + pkg/proto/messagespb/messages.pb.go | 173 ++++++++---- pkg/streaming/util/message/broadcast.go | 7 - pkg/streaming/util/message/builder.go | 41 ++- pkg/streaming/util/message/cipher.go | 34 +++ pkg/streaming/util/message/message_id_test.go | 11 + pkg/streaming/util/message/message_impl.go | 29 ++ pkg/streaming/util/message/message_test.go | 80 ++++++ pkg/streaming/util/message/message_type.go | 6 + pkg/streaming/util/message/properties.go | 1 + .../util/message/specialized_message.go | 9 +- pkg/streaming/util/message/txn_test.go | 2 + 16 files changed, 782 insertions(+), 58 deletions(-) create mode 100644 pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Cipher.go create mode 100644 pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Decryptor.go create mode 100644 pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Encryptor.go create mode 100644 pkg/streaming/util/message/cipher.go diff --git a/pkg/.mockery_pkg.yaml b/pkg/.mockery_pkg.yaml index d5dfa7532f..e559350c2e 100644 --- a/pkg/.mockery_pkg.yaml +++ b/pkg/.mockery_pkg.yaml @@ -42,3 +42,8 @@ packages: StreamingCoordBroadcastService_WatchServer: StreamingCoordBroadcastServiceClient: StreamingCoordBroadcastService_WatchClient: + github.com/milvus-io/milvus-proto/go-api/v2/hook: + interfaces: + Cipher: + Encryptor: + Decryptor: diff --git a/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Cipher.go b/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Cipher.go new file mode 100644 index 0000000000..435fe4b28b --- /dev/null +++ b/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Cipher.go @@ -0,0 +1,255 @@ +// Code generated by mockery v2.46.0. DO NOT EDIT. + +package mock_hook + +import ( + hook "github.com/milvus-io/milvus-proto/go-api/v2/hook" + mock "github.com/stretchr/testify/mock" +) + +// MockCipher is an autogenerated mock type for the Cipher type +type MockCipher struct { + mock.Mock +} + +type MockCipher_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCipher) EXPECT() *MockCipher_Expecter { + return &MockCipher_Expecter{mock: &_m.Mock} +} + +// GetDecryptor provides a mock function with given fields: ezID, safeKey +func (_m *MockCipher) GetDecryptor(ezID int64, safeKey []byte) (hook.Decryptor, error) { + ret := _m.Called(ezID, safeKey) + + if len(ret) == 0 { + panic("no return value specified for GetDecryptor") + } + + var r0 hook.Decryptor + var r1 error + if rf, ok := ret.Get(0).(func(int64, []byte) (hook.Decryptor, error)); ok { + return rf(ezID, safeKey) + } + if rf, ok := ret.Get(0).(func(int64, []byte) hook.Decryptor); ok { + r0 = rf(ezID, safeKey) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(hook.Decryptor) + } + } + + if rf, ok := ret.Get(1).(func(int64, []byte) error); ok { + r1 = rf(ezID, safeKey) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCipher_GetDecryptor_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDecryptor' +type MockCipher_GetDecryptor_Call struct { + *mock.Call +} + +// GetDecryptor is a helper method to define mock.On call +// - ezID int64 +// - safeKey []byte +func (_e *MockCipher_Expecter) GetDecryptor(ezID interface{}, safeKey interface{}) *MockCipher_GetDecryptor_Call { + return &MockCipher_GetDecryptor_Call{Call: _e.mock.On("GetDecryptor", ezID, safeKey)} +} + +func (_c *MockCipher_GetDecryptor_Call) Run(run func(ezID int64, safeKey []byte)) *MockCipher_GetDecryptor_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].([]byte)) + }) + return _c +} + +func (_c *MockCipher_GetDecryptor_Call) Return(_a0 hook.Decryptor, _a1 error) *MockCipher_GetDecryptor_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCipher_GetDecryptor_Call) RunAndReturn(run func(int64, []byte) (hook.Decryptor, error)) *MockCipher_GetDecryptor_Call { + _c.Call.Return(run) + return _c +} + +// GetEncryptor provides a mock function with given fields: ezID +func (_m *MockCipher) GetEncryptor(ezID int64) (hook.Encryptor, []byte, error) { + ret := _m.Called(ezID) + + if len(ret) == 0 { + panic("no return value specified for GetEncryptor") + } + + var r0 hook.Encryptor + var r1 []byte + var r2 error + if rf, ok := ret.Get(0).(func(int64) (hook.Encryptor, []byte, error)); ok { + return rf(ezID) + } + if rf, ok := ret.Get(0).(func(int64) hook.Encryptor); ok { + r0 = rf(ezID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(hook.Encryptor) + } + } + + if rf, ok := ret.Get(1).(func(int64) []byte); ok { + r1 = rf(ezID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]byte) + } + } + + if rf, ok := ret.Get(2).(func(int64) error); ok { + r2 = rf(ezID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockCipher_GetEncryptor_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEncryptor' +type MockCipher_GetEncryptor_Call struct { + *mock.Call +} + +// GetEncryptor is a helper method to define mock.On call +// - ezID int64 +func (_e *MockCipher_Expecter) GetEncryptor(ezID interface{}) *MockCipher_GetEncryptor_Call { + return &MockCipher_GetEncryptor_Call{Call: _e.mock.On("GetEncryptor", ezID)} +} + +func (_c *MockCipher_GetEncryptor_Call) Run(run func(ezID int64)) *MockCipher_GetEncryptor_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockCipher_GetEncryptor_Call) Return(encryptor hook.Encryptor, safeKey []byte, err error) *MockCipher_GetEncryptor_Call { + _c.Call.Return(encryptor, safeKey, err) + return _c +} + +func (_c *MockCipher_GetEncryptor_Call) RunAndReturn(run func(int64) (hook.Encryptor, []byte, error)) *MockCipher_GetEncryptor_Call { + _c.Call.Return(run) + return _c +} + +// GetUnsafeKey provides a mock function with given fields: ezID +func (_m *MockCipher) GetUnsafeKey(ezID int64) []byte { + ret := _m.Called(ezID) + + if len(ret) == 0 { + panic("no return value specified for GetUnsafeKey") + } + + var r0 []byte + if rf, ok := ret.Get(0).(func(int64) []byte); ok { + r0 = rf(ezID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// MockCipher_GetUnsafeKey_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUnsafeKey' +type MockCipher_GetUnsafeKey_Call struct { + *mock.Call +} + +// GetUnsafeKey is a helper method to define mock.On call +// - ezID int64 +func (_e *MockCipher_Expecter) GetUnsafeKey(ezID interface{}) *MockCipher_GetUnsafeKey_Call { + return &MockCipher_GetUnsafeKey_Call{Call: _e.mock.On("GetUnsafeKey", ezID)} +} + +func (_c *MockCipher_GetUnsafeKey_Call) Run(run func(ezID int64)) *MockCipher_GetUnsafeKey_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockCipher_GetUnsafeKey_Call) Return(_a0 []byte) *MockCipher_GetUnsafeKey_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCipher_GetUnsafeKey_Call) RunAndReturn(run func(int64) []byte) *MockCipher_GetUnsafeKey_Call { + _c.Call.Return(run) + return _c +} + +// Init provides a mock function with given fields: params +func (_m *MockCipher) Init(params map[string]string) error { + ret := _m.Called(params) + + if len(ret) == 0 { + panic("no return value specified for Init") + } + + var r0 error + if rf, ok := ret.Get(0).(func(map[string]string) error); ok { + r0 = rf(params) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCipher_Init_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Init' +type MockCipher_Init_Call struct { + *mock.Call +} + +// Init is a helper method to define mock.On call +// - params map[string]string +func (_e *MockCipher_Expecter) Init(params interface{}) *MockCipher_Init_Call { + return &MockCipher_Init_Call{Call: _e.mock.On("Init", params)} +} + +func (_c *MockCipher_Init_Call) Run(run func(params map[string]string)) *MockCipher_Init_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(map[string]string)) + }) + return _c +} + +func (_c *MockCipher_Init_Call) Return(_a0 error) *MockCipher_Init_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCipher_Init_Call) RunAndReturn(run func(map[string]string) error) *MockCipher_Init_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCipher creates a new instance of MockCipher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCipher(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCipher { + mock := &MockCipher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Decryptor.go b/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Decryptor.go new file mode 100644 index 0000000000..a48afdb28e --- /dev/null +++ b/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Decryptor.go @@ -0,0 +1,90 @@ +// Code generated by mockery v2.46.0. DO NOT EDIT. + +package mock_hook + +import mock "github.com/stretchr/testify/mock" + +// MockDecryptor is an autogenerated mock type for the Decryptor type +type MockDecryptor struct { + mock.Mock +} + +type MockDecryptor_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDecryptor) EXPECT() *MockDecryptor_Expecter { + return &MockDecryptor_Expecter{mock: &_m.Mock} +} + +// Decrypt provides a mock function with given fields: cipherText +func (_m *MockDecryptor) Decrypt(cipherText []byte) ([]byte, error) { + ret := _m.Called(cipherText) + + if len(ret) == 0 { + panic("no return value specified for Decrypt") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func([]byte) ([]byte, error)); ok { + return rf(cipherText) + } + if rf, ok := ret.Get(0).(func([]byte) []byte); ok { + r0 = rf(cipherText) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(cipherText) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDecryptor_Decrypt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Decrypt' +type MockDecryptor_Decrypt_Call struct { + *mock.Call +} + +// Decrypt is a helper method to define mock.On call +// - cipherText []byte +func (_e *MockDecryptor_Expecter) Decrypt(cipherText interface{}) *MockDecryptor_Decrypt_Call { + return &MockDecryptor_Decrypt_Call{Call: _e.mock.On("Decrypt", cipherText)} +} + +func (_c *MockDecryptor_Decrypt_Call) Run(run func(cipherText []byte)) *MockDecryptor_Decrypt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockDecryptor_Decrypt_Call) Return(plainText []byte, err error) *MockDecryptor_Decrypt_Call { + _c.Call.Return(plainText, err) + return _c +} + +func (_c *MockDecryptor_Decrypt_Call) RunAndReturn(run func([]byte) ([]byte, error)) *MockDecryptor_Decrypt_Call { + _c.Call.Return(run) + return _c +} + +// NewMockDecryptor creates a new instance of MockDecryptor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockDecryptor(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDecryptor { + mock := &MockDecryptor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Encryptor.go b/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Encryptor.go new file mode 100644 index 0000000000..0350ccc666 --- /dev/null +++ b/pkg/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook/mock_Encryptor.go @@ -0,0 +1,90 @@ +// Code generated by mockery v2.46.0. DO NOT EDIT. + +package mock_hook + +import mock "github.com/stretchr/testify/mock" + +// MockEncryptor is an autogenerated mock type for the Encryptor type +type MockEncryptor struct { + mock.Mock +} + +type MockEncryptor_Expecter struct { + mock *mock.Mock +} + +func (_m *MockEncryptor) EXPECT() *MockEncryptor_Expecter { + return &MockEncryptor_Expecter{mock: &_m.Mock} +} + +// Encrypt provides a mock function with given fields: plainText +func (_m *MockEncryptor) Encrypt(plainText []byte) ([]byte, error) { + ret := _m.Called(plainText) + + if len(ret) == 0 { + panic("no return value specified for Encrypt") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func([]byte) ([]byte, error)); ok { + return rf(plainText) + } + if rf, ok := ret.Get(0).(func([]byte) []byte); ok { + r0 = rf(plainText) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(plainText) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockEncryptor_Encrypt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Encrypt' +type MockEncryptor_Encrypt_Call struct { + *mock.Call +} + +// Encrypt is a helper method to define mock.On call +// - plainText []byte +func (_e *MockEncryptor_Expecter) Encrypt(plainText interface{}) *MockEncryptor_Encrypt_Call { + return &MockEncryptor_Encrypt_Call{Call: _e.mock.On("Encrypt", plainText)} +} + +func (_c *MockEncryptor_Encrypt_Call) Run(run func(plainText []byte)) *MockEncryptor_Encrypt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockEncryptor_Encrypt_Call) Return(cipherText []byte, err error) *MockEncryptor_Encrypt_Call { + _c.Call.Return(cipherText, err) + return _c +} + +func (_c *MockEncryptor_Encrypt_Call) RunAndReturn(run func([]byte) ([]byte, error)) *MockEncryptor_Encrypt_Call { + _c.Call.Return(run) + return _c +} + +// NewMockEncryptor creates a new instance of MockEncryptor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockEncryptor(t interface { + mock.TestingT + Cleanup(func()) +}) *MockEncryptor { + mock := &MockEncryptor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/proto/messages.proto b/pkg/proto/messages.proto index 79f82b0fa5..5a40768c22 100644 --- a/pkg/proto/messages.proto +++ b/pkg/proto/messages.proto @@ -277,3 +277,10 @@ message ResourceKey { ResourceDomain domain = 1; string key = 2; } + +// CipherHeader is the header of a message that is encrypted. +message CipherHeader { + int64 ez_id = 1; // related to the encryption zone id + bytes safe_key = 2; // the safe key + int64 payload_bytes = 3; // the size of the payload before encryption +} diff --git a/pkg/proto/messagespb/messages.pb.go b/pkg/proto/messagespb/messages.pb.go index 9101533990..34e529af5d 100644 --- a/pkg/proto/messagespb/messages.pb.go +++ b/pkg/proto/messagespb/messages.pb.go @@ -1910,6 +1910,70 @@ func (x *ResourceKey) GetKey() string { return "" } +// CipherHeader is the header of a message that is encrypted. +type CipherHeader struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EzId int64 `protobuf:"varint,1,opt,name=ez_id,json=ezId,proto3" json:"ez_id,omitempty"` // related to the encryption zone id + SafeKey []byte `protobuf:"bytes,2,opt,name=safe_key,json=safeKey,proto3" json:"safe_key,omitempty"` // the safe key + PayloadBytes int64 `protobuf:"varint,3,opt,name=payload_bytes,json=payloadBytes,proto3" json:"payload_bytes,omitempty"` // the size of the payload before encryption +} + +func (x *CipherHeader) Reset() { + *x = CipherHeader{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CipherHeader) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CipherHeader) ProtoMessage() {} + +func (x *CipherHeader) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_msgTypes[33] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CipherHeader.ProtoReflect.Descriptor instead. +func (*CipherHeader) Descriptor() ([]byte, []int) { + return file_messages_proto_rawDescGZIP(), []int{33} +} + +func (x *CipherHeader) GetEzId() int64 { + if x != nil { + return x.EzId + } + return 0 +} + +func (x *CipherHeader) GetSafeKey() []byte { + if x != nil { + return x.SafeKey + } + return nil +} + +func (x *CipherHeader) GetPayloadBytes() int64 { + if x != nil { + return x.PayloadBytes + } + return 0 +} + var File_messages_proto protoreflect.FileDescriptor var file_messages_proto_rawDesc = []byte{ @@ -2089,43 +2153,49 @@ var file_messages_proto_rawDesc = []byte{ 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x2a, 0x88, 0x02, 0x0a, 0x0b, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x6e, 0x6b, 0x6e, - 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x69, 0x6d, 0x65, 0x54, 0x69, 0x63, - 0x6b, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x49, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x10, 0x02, 0x12, - 0x0a, 0x0a, 0x06, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x10, 0x03, 0x12, 0x09, 0x0a, 0x05, 0x46, - 0x6c, 0x75, 0x73, 0x68, 0x10, 0x04, 0x12, 0x14, 0x0a, 0x10, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x10, 0x05, 0x12, 0x12, 0x0a, 0x0e, - 0x44, 0x72, 0x6f, 0x70, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x10, 0x06, - 0x12, 0x13, 0x0a, 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x61, 0x72, 0x74, 0x69, 0x74, - 0x69, 0x6f, 0x6e, 0x10, 0x07, 0x12, 0x11, 0x0a, 0x0d, 0x44, 0x72, 0x6f, 0x70, 0x50, 0x61, 0x72, - 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x10, 0x08, 0x12, 0x0f, 0x0a, 0x0b, 0x4d, 0x61, 0x6e, 0x75, - 0x61, 0x6c, 0x46, 0x6c, 0x75, 0x73, 0x68, 0x10, 0x09, 0x12, 0x11, 0x0a, 0x0d, 0x43, 0x72, 0x65, - 0x61, 0x74, 0x65, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x10, 0x0a, 0x12, 0x0a, 0x0a, 0x06, - 0x49, 0x6d, 0x70, 0x6f, 0x72, 0x74, 0x10, 0x0b, 0x12, 0x0d, 0x0a, 0x08, 0x42, 0x65, 0x67, 0x69, - 0x6e, 0x54, 0x78, 0x6e, 0x10, 0x84, 0x07, 0x12, 0x0e, 0x0a, 0x09, 0x43, 0x6f, 0x6d, 0x6d, 0x69, - 0x74, 0x54, 0x78, 0x6e, 0x10, 0x85, 0x07, 0x12, 0x10, 0x0a, 0x0b, 0x52, 0x6f, 0x6c, 0x6c, 0x62, - 0x61, 0x63, 0x6b, 0x54, 0x78, 0x6e, 0x10, 0x86, 0x07, 0x12, 0x08, 0x0a, 0x03, 0x54, 0x78, 0x6e, - 0x10, 0xe7, 0x07, 0x2a, 0x82, 0x01, 0x0a, 0x08, 0x54, 0x78, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x78, 0x6e, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, - 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x78, 0x6e, 0x42, 0x65, 0x67, 0x69, 0x6e, 0x10, 0x01, 0x12, 0x0f, - 0x0a, 0x0b, 0x54, 0x78, 0x6e, 0x49, 0x6e, 0x46, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x10, 0x02, 0x12, - 0x0f, 0x0a, 0x0b, 0x54, 0x78, 0x6e, 0x4f, 0x6e, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x10, 0x03, - 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x78, 0x6e, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x74, 0x65, 0x64, - 0x10, 0x04, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x78, 0x6e, 0x4f, 0x6e, 0x52, 0x6f, 0x6c, 0x6c, 0x62, - 0x61, 0x63, 0x6b, 0x10, 0x05, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x78, 0x6e, 0x52, 0x6f, 0x6c, 0x6c, - 0x62, 0x61, 0x63, 0x6b, 0x65, 0x64, 0x10, 0x06, 0x2a, 0x6c, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x19, 0x0a, 0x15, 0x52, 0x65, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x55, 0x6e, 0x6b, 0x6e, - 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x1d, 0x0a, 0x19, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x49, 0x6d, 0x70, 0x6f, 0x72, 0x74, 0x4a, 0x6f, 0x62, - 0x49, 0x44, 0x10, 0x01, 0x12, 0x20, 0x0a, 0x1c, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x4e, 0x61, 0x6d, 0x65, 0x10, 0x02, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, - 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x70, 0x62, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x22, 0x63, 0x0a, 0x0c, 0x43, 0x69, 0x70, 0x68, + 0x65, 0x72, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x13, 0x0a, 0x05, 0x65, 0x7a, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x65, 0x7a, 0x49, 0x64, 0x12, 0x19, 0x0a, + 0x08, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x07, 0x73, 0x61, 0x66, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x61, 0x79, 0x6c, + 0x6f, 0x61, 0x64, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x0c, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x42, 0x79, 0x74, 0x65, 0x73, 0x2a, 0x88, 0x02, + 0x0a, 0x0b, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, + 0x07, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x69, + 0x6d, 0x65, 0x54, 0x69, 0x63, 0x6b, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x49, 0x6e, 0x73, 0x65, + 0x72, 0x74, 0x10, 0x02, 0x12, 0x0a, 0x0a, 0x06, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x10, 0x03, + 0x12, 0x09, 0x0a, 0x05, 0x46, 0x6c, 0x75, 0x73, 0x68, 0x10, 0x04, 0x12, 0x14, 0x0a, 0x10, 0x43, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x10, + 0x05, 0x12, 0x12, 0x0a, 0x0e, 0x44, 0x72, 0x6f, 0x70, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x10, 0x06, 0x12, 0x13, 0x0a, 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, + 0x61, 0x72, 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x10, 0x07, 0x12, 0x11, 0x0a, 0x0d, 0x44, 0x72, + 0x6f, 0x70, 0x50, 0x61, 0x72, 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x10, 0x08, 0x12, 0x0f, 0x0a, + 0x0b, 0x4d, 0x61, 0x6e, 0x75, 0x61, 0x6c, 0x46, 0x6c, 0x75, 0x73, 0x68, 0x10, 0x09, 0x12, 0x11, + 0x0a, 0x0d, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x10, + 0x0a, 0x12, 0x0a, 0x0a, 0x06, 0x49, 0x6d, 0x70, 0x6f, 0x72, 0x74, 0x10, 0x0b, 0x12, 0x0d, 0x0a, + 0x08, 0x42, 0x65, 0x67, 0x69, 0x6e, 0x54, 0x78, 0x6e, 0x10, 0x84, 0x07, 0x12, 0x0e, 0x0a, 0x09, + 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x54, 0x78, 0x6e, 0x10, 0x85, 0x07, 0x12, 0x10, 0x0a, 0x0b, + 0x52, 0x6f, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x54, 0x78, 0x6e, 0x10, 0x86, 0x07, 0x12, 0x08, + 0x0a, 0x03, 0x54, 0x78, 0x6e, 0x10, 0xe7, 0x07, 0x2a, 0x82, 0x01, 0x0a, 0x08, 0x54, 0x78, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x78, 0x6e, 0x55, 0x6e, 0x6b, 0x6e, + 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x78, 0x6e, 0x42, 0x65, 0x67, 0x69, + 0x6e, 0x10, 0x01, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x78, 0x6e, 0x49, 0x6e, 0x46, 0x6c, 0x69, 0x67, + 0x68, 0x74, 0x10, 0x02, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x78, 0x6e, 0x4f, 0x6e, 0x43, 0x6f, 0x6d, + 0x6d, 0x69, 0x74, 0x10, 0x03, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x78, 0x6e, 0x43, 0x6f, 0x6d, 0x6d, + 0x69, 0x74, 0x74, 0x65, 0x64, 0x10, 0x04, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x78, 0x6e, 0x4f, 0x6e, + 0x52, 0x6f, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x10, 0x05, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x78, + 0x6e, 0x52, 0x6f, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x64, 0x10, 0x06, 0x2a, 0x6c, 0x0a, + 0x0e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, + 0x19, 0x0a, 0x15, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x1d, 0x0a, 0x19, 0x52, 0x65, + 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x49, 0x6d, 0x70, 0x6f, + 0x72, 0x74, 0x4a, 0x6f, 0x62, 0x49, 0x44, 0x10, 0x01, 0x12, 0x20, 0x0a, 0x1c, 0x52, 0x65, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6f, 0x6c, 0x6c, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4e, 0x61, 0x6d, 0x65, 0x10, 0x02, 0x42, 0x35, 0x5a, 0x33, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, + 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, + 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, + 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2141,7 +2211,7 @@ func file_messages_proto_rawDescGZIP() []byte { } var file_messages_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 36) +var file_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 37) var file_messages_proto_goTypes = []interface{}{ (MessageType)(0), // 0: milvus.proto.messages.MessageType (TxnState)(0), // 1: milvus.proto.messages.TxnState @@ -2179,19 +2249,20 @@ var file_messages_proto_goTypes = []interface{}{ (*RMQMessageLayout)(nil), // 33: milvus.proto.messages.RMQMessageLayout (*BroadcastHeader)(nil), // 34: milvus.proto.messages.BroadcastHeader (*ResourceKey)(nil), // 35: milvus.proto.messages.ResourceKey - nil, // 36: milvus.proto.messages.Message.PropertiesEntry - nil, // 37: milvus.proto.messages.ImmutableMessage.PropertiesEntry - nil, // 38: milvus.proto.messages.RMQMessageLayout.PropertiesEntry + (*CipherHeader)(nil), // 36: milvus.proto.messages.CipherHeader + nil, // 37: milvus.proto.messages.Message.PropertiesEntry + nil, // 38: milvus.proto.messages.ImmutableMessage.PropertiesEntry + nil, // 39: milvus.proto.messages.RMQMessageLayout.PropertiesEntry } var file_messages_proto_depIdxs = []int32{ - 36, // 0: milvus.proto.messages.Message.properties:type_name -> milvus.proto.messages.Message.PropertiesEntry + 37, // 0: milvus.proto.messages.Message.properties:type_name -> milvus.proto.messages.Message.PropertiesEntry 3, // 1: milvus.proto.messages.ImmutableMessage.id:type_name -> milvus.proto.messages.MessageID - 37, // 2: milvus.proto.messages.ImmutableMessage.properties:type_name -> milvus.proto.messages.ImmutableMessage.PropertiesEntry + 38, // 2: milvus.proto.messages.ImmutableMessage.properties:type_name -> milvus.proto.messages.ImmutableMessage.PropertiesEntry 9, // 3: milvus.proto.messages.CreateSegmentMessageBody.segments:type_name -> milvus.proto.messages.CreateSegmentInfo 4, // 4: milvus.proto.messages.TxnMessageBody.messages:type_name -> milvus.proto.messages.Message 16, // 5: milvus.proto.messages.InsertMessageHeader.partitions:type_name -> milvus.proto.messages.PartitionSegmentAssignment 17, // 6: milvus.proto.messages.PartitionSegmentAssignment.segment_assignment:type_name -> milvus.proto.messages.SegmentAssignment - 38, // 7: milvus.proto.messages.RMQMessageLayout.properties:type_name -> milvus.proto.messages.RMQMessageLayout.PropertiesEntry + 39, // 7: milvus.proto.messages.RMQMessageLayout.properties:type_name -> milvus.proto.messages.RMQMessageLayout.PropertiesEntry 35, // 8: milvus.proto.messages.BroadcastHeader.Resource_keys:type_name -> milvus.proto.messages.ResourceKey 2, // 9: milvus.proto.messages.ResourceKey.domain:type_name -> milvus.proto.messages.ResourceDomain 10, // [10:10] is the sub-list for method output_type @@ -2603,6 +2674,18 @@ func file_messages_proto_init() { return nil } } + file_messages_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CipherHeader); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -2610,7 +2693,7 @@ func file_messages_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_messages_proto_rawDesc, NumEnums: 3, - NumMessages: 36, + NumMessages: 37, NumExtensions: 0, NumServices: 0, }, diff --git a/pkg/streaming/util/message/broadcast.go b/pkg/streaming/util/message/broadcast.go index b8d19bee8d..0eabcf839e 100644 --- a/pkg/streaming/util/message/broadcast.go +++ b/pkg/streaming/util/message/broadcast.go @@ -52,13 +52,6 @@ type ResourceKey struct { Key string } -func (rk *ResourceKey) IntoResourceKey() *messagespb.ResourceKey { - return &messagespb.ResourceKey{ - Domain: rk.Domain, - Key: rk.Key, - } -} - // NewImportJobIDResourceKey creates a key for import job resource. func NewImportJobIDResourceKey(importJobID int64) ResourceKey { return ResourceKey{ diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go index c27e41df6c..45bda75218 100644 --- a/pkg/streaming/util/message/builder.go +++ b/pkg/streaming/util/message/builder.go @@ -1,6 +1,7 @@ package message import ( + "fmt" "reflect" "github.com/cockroachdb/errors" @@ -99,10 +100,11 @@ func newMutableMessageBuilder[H proto.Message, B proto.Message](v Version) *muta // mutableMesasgeBuilder is the builder for message. type mutableMesasgeBuilder[H proto.Message, B proto.Message] struct { - header H - body B - properties propertiesImpl - allVChannel bool + header H + body B + properties propertiesImpl + cipherConfig *CipherConfig + allVChannel bool } // WithMessageHeader creates a new builder with determined message type. @@ -175,6 +177,12 @@ func (b *mutableMesasgeBuilder[H, B]) WithProperties(kvs map[string]string) *mut return b } +// WithCipher creates a new builder with cipher property. +func (b *mutableMesasgeBuilder[H, B]) WithCipher(cipherConfig *CipherConfig) *mutableMesasgeBuilder[H, B] { + b.cipherConfig = cipherConfig + return b +} + // BuildMutable builds a mutable message. // Panic if not set payload and message type. // should only used at client side. @@ -226,6 +234,31 @@ func (b *mutableMesasgeBuilder[H, B]) build() (*messageImpl, error) { if err != nil { return nil, errors.Wrap(err, "failed to marshal body") } + if b.cipherConfig != nil { + messageType := mustGetMessageTypeFromHeader(b.header) + if !messageType.CanEnableCipher() { + panic(fmt.Sprintf("the message type cannot enable cipher, %s", messageType)) + } + + cipher := mustGetCipher() + encryptor, safeKey, err := cipher.GetEncryptor(b.cipherConfig.EzID) + if err != nil { + return nil, errors.Wrap(err, "failed to get encryptor") + } + payloadBytes := len(payload) + if payload, err = encryptor.Encrypt(payload); err != nil { + return nil, errors.Wrap(err, "failed to encrypt payload") + } + ch, err := EncodeProto(&messagespb.CipherHeader{ + EzId: b.cipherConfig.EzID, + SafeKey: safeKey, + PayloadBytes: int64(payloadBytes), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to encode cipher header") + } + b.properties.Set(messageCipherHeader, ch) + } return &messageImpl{ payload: payload, properties: b.properties, diff --git a/pkg/streaming/util/message/cipher.go b/pkg/streaming/util/message/cipher.go new file mode 100644 index 0000000000..7fb5e025f8 --- /dev/null +++ b/pkg/streaming/util/message/cipher.go @@ -0,0 +1,34 @@ +package message + +import ( + "github.com/milvus-io/milvus-proto/go-api/v2/hook" +) + +// cipher is a global variable that is used to encrypt and decrypt messages. +// It should be initialized at initialization stage. +var ( + cipher hook.Cipher +) + +// RegisterCipher registers a cipher to be used for encrypting and decrypting messages. +// It should be called only once when the program starts and initialization stage. +func RegisterCipher(c hook.Cipher) { + if cipher != nil { + panic("cipher already registered") + } + cipher = c +} + +// mustGetCipher returns the registered cipher. +func mustGetCipher() hook.Cipher { + if cipher == nil { + panic("cipher not registered") + } + return cipher +} + +// CipherConfig is the configuration for cipher that is used to encrypt and decrypt messages. +type CipherConfig struct { + // EzID is the encryption zone ID. + EzID int64 +} diff --git a/pkg/streaming/util/message/message_id_test.go b/pkg/streaming/util/message/message_id_test.go index 2e964f96b9..038f2fffc4 100644 --- a/pkg/streaming/util/message/message_id_test.go +++ b/pkg/streaming/util/message/message_id_test.go @@ -41,3 +41,14 @@ func TestRegisterMessageIDUnmarshaler(t *testing.T) { }) }) } + +func TestCases(t *testing.T) { + msgID := mock_message.NewMockMessageID(t) + msgID.EXPECT().Marshal().Return("123").Maybe() + message.CreateTestInsertMessage(t, 1, 100, 100, msgID) + message.CreateTestCreateCollectionMessage(t, 1, 100, msgID) + message.CreateTestEmptyInsertMesage(1, nil) + message.CreateTestDropCollectionMessage(t, 1, 100, msgID) + message.CreateTestTimeTickSyncMessage(t, 1, 100, msgID) + message.CreateTestCreateSegmentMessage(t, 1, 100, msgID) +} diff --git a/pkg/streaming/util/message/message_impl.go b/pkg/streaming/util/message/message_impl.go index c6243ecc36..10a214adaf 100644 --- a/pkg/streaming/util/message/message_impl.go +++ b/pkg/streaming/util/message/message_impl.go @@ -31,6 +31,18 @@ func (m *messageImpl) Version() Version { // Payload returns payload of current message. func (m *messageImpl) Payload() []byte { + if ch := m.cipherHeader(); ch != nil { + cipher := mustGetCipher() + decryptor, err := cipher.GetDecryptor(ch.EzId, ch.SafeKey) + if err != nil { + panic(fmt.Sprintf("can not get decryptor for message: %s", err)) + } + payload, err := decryptor.Decrypt(m.payload) + if err != nil { + panic(fmt.Sprintf("can not decrypt message: %s", err)) + } + return payload + } return m.payload } @@ -41,6 +53,10 @@ func (m *messageImpl) Properties() RProperties { // EstimateSize returns the estimated size of current message. func (m *messageImpl) EstimateSize() int { + if ch := m.cipherHeader(); ch != nil { + // if it's a cipher message, we need to estimate the size of payload before encryption. + return int(ch.PayloadBytes) + m.properties.EstimateSize() + } // TODO: more accurate size estimation. return len(m.payload) + m.properties.EstimateSize() } @@ -197,6 +213,19 @@ func (m *messageImpl) broadcastHeader() *messagespb.BroadcastHeader { return header } +// cipherHeader returns the cipher header of current message. +func (m *messageImpl) cipherHeader() *messagespb.CipherHeader { + value, ok := m.properties.Get(messageCipherHeader) + if !ok { + return nil + } + header := &messagespb.CipherHeader{} + if err := DecodeProto(value, header); err != nil { + panic("can not decode cipher header") + } + return header +} + // SplitIntoMutableMessage splits the current broadcast message into multiple messages. func (m *messageImpl) SplitIntoMutableMessage() []MutableMessage { bh := m.broadcastHeader() diff --git a/pkg/streaming/util/message/message_test.go b/pkg/streaming/util/message/message_test.go index 5276d8f830..8629bf3c25 100644 --- a/pkg/streaming/util/message/message_test.go +++ b/pkg/streaming/util/message/message_test.go @@ -4,6 +4,10 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/pkg/v2/mocks/github.com/milvus-io/milvus-proto/go-api/v2/mock_hook" ) func TestMessageType(t *testing.T) { @@ -20,6 +24,22 @@ func TestMessageType(t *testing.T) { typ = unmarshalMessageType(s) assert.Equal(t, MessageTypeTimeTick, typ) assert.True(t, MessageTypeTimeTick.Valid()) + + assert.True(t, MessageTypeTimeTick.IsSystem()) + assert.True(t, MessageTypeTxn.IsSystem()) + assert.True(t, MessageTypeBeginTxn.IsSystem()) + assert.True(t, MessageTypeCommitTxn.IsSystem()) + assert.True(t, MessageTypeRollbackTxn.IsSystem()) + assert.False(t, MessageTypeImport.IsSystem()) + assert.False(t, MessageTypeInsert.IsSystem()) + assert.False(t, MessageTypeDelete.IsSystem()) + assert.False(t, MessageTypeCreateSegment.IsSystem()) + assert.False(t, MessageTypeFlush.IsSystem()) + assert.False(t, MessageTypeManualFlush.IsSystem()) + assert.False(t, MessageTypeCreateCollection.IsSystem()) + assert.False(t, MessageTypeDropCollection.IsSystem()) + assert.False(t, MessageTypeCreatePartition.IsSystem()) + assert.False(t, MessageTypeDropPartition.IsSystem()) } func TestVersion(t *testing.T) { @@ -33,4 +53,64 @@ func TestVersion(t *testing.T) { assert.True(t, VersionV1.GT(VersionOld)) assert.True(t, VersionV2.GT(VersionV1)) + assert.True(t, VersionV1.EQ(VersionV1)) + assert.True(t, VersionV2.EQ(VersionV2)) + assert.True(t, VersionOld.EQ(VersionOld)) +} + +func TestBroadcast(t *testing.T) { + msg, err := NewCreateCollectionMessageBuilderV1(). + WithHeader(&CreateCollectionMessageHeader{}). + WithBody(&msgpb.CreateCollectionRequest{}). + WithBroadcast([]string{"v1", "v2"}, NewCollectionNameResourceKey("1"), NewImportJobIDResourceKey(1)). + BuildBroadcast() + assert.NoError(t, err) + assert.NotNil(t, msg) + msg.WithBroadcastID(1) + msgs := msg.SplitIntoMutableMessage() + assert.NotNil(t, msgs) + assert.Len(t, msgs, 2) + assert.Equal(t, *msgs[1].BroadcastHeader(), *msgs[0].BroadcastHeader()) + assert.Equal(t, uint64(1), msgs[1].BroadcastHeader().BroadcastID) + assert.Len(t, msgs[0].BroadcastHeader().ResourceKeys, 2) + assert.ElementsMatch(t, []string{"v1", "v2"}, []string{msgs[0].VChannel(), msgs[1].VChannel()}) +} + +func TestCiper(t *testing.T) { + // Not broadcast. + builder := NewInsertMessageBuilderV1(). + WithHeader(&InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{ + ShardName: "123123", + }). + WithVChannel("v1"). + WithCipher(&CipherConfig{ + EzID: 1, + }) + assert.Panics(t, func() { + builder.BuildMutable() + }) + c := mock_hook.NewMockCipher(t) + e := mock_hook.NewMockEncryptor(t) + e.EXPECT().Encrypt(mock.Anything).RunAndReturn(func(b []byte) ([]byte, error) { + return []byte("123" + string(b)), nil + }) + d := mock_hook.NewMockDecryptor(t) + d.EXPECT().Decrypt(mock.Anything).RunAndReturn(func(b []byte) ([]byte, error) { + return b[3:], nil + }) + c.EXPECT().GetEncryptor(mock.Anything).Return(e, []byte("123"), nil) + c.EXPECT().GetDecryptor(mock.Anything, mock.Anything).Return(d, nil) + RegisterCipher(c) + + msg, _ := builder.WithCipher(&CipherConfig{ + EzID: 1, + }).BuildMutable() + + msg2, err := AsMutableInsertMessageV1(msg) + assert.NoError(t, err) + body, err := msg2.Body() + assert.NoError(t, err) + assert.Equal(t, body.ShardName, "123123") + assert.Equal(t, msg2.EstimateSize(), 36) } diff --git a/pkg/streaming/util/message/message_type.go b/pkg/streaming/util/message/message_type.go index 2e2c09997d..19dbda3f4e 100644 --- a/pkg/streaming/util/message/message_type.go +++ b/pkg/streaming/util/message/message_type.go @@ -62,6 +62,12 @@ func (t MessageType) Valid() bool { return t != MessageTypeUnknown && ok } +// CanEnableCipher checks if the MessageType can enable cipher. +func (t MessageType) CanEnableCipher() bool { + _, ok := cipherMessageType[t] + return ok +} + // IsSysmtem checks if the MessageType is a system type. func (t MessageType) IsSystem() bool { _, ok := systemMessageType[t] diff --git a/pkg/streaming/util/message/properties.go b/pkg/streaming/util/message/properties.go index 193f18eb0a..62f2747e47 100644 --- a/pkg/streaming/util/message/properties.go +++ b/pkg/streaming/util/message/properties.go @@ -13,6 +13,7 @@ const ( messageBroadcastHeader = "_bh" // message broadcast header. messageHeader = "_h" // specialized message header. messageTxnContext = "_tx" // transaction context. + messageCipherHeader = "_ch" // message cipher header. ) var ( diff --git a/pkg/streaming/util/message/specialized_message.go b/pkg/streaming/util/message/specialized_message.go index 4b3d44e5f7..b8c25a6e96 100644 --- a/pkg/streaming/util/message/specialized_message.go +++ b/pkg/streaming/util/message/specialized_message.go @@ -73,6 +73,11 @@ var systemMessageType = map[MessageType]struct{}{ MessageTypeTxn: {}, } +var cipherMessageType = map[MessageType]struct{}{ + MessageTypeInsert: {}, + MessageTypeDelete: {}, +} + // List all specialized message types. type ( MutableTimeTickMessageV1 = specializedMutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg] @@ -244,7 +249,7 @@ func (m *specializedMutableMessageImpl[H, B]) Header() H { // Body returns the message body. func (m *specializedMutableMessageImpl[H, B]) Body() (B, error) { - return unmarshalProtoB[B](m.payload) + return unmarshalProtoB[B](m.Payload()) } // OverwriteMessageHeader overwrites the message header. @@ -270,7 +275,7 @@ func (m *specializedImmutableMessageImpl[H, B]) Header() H { // Body returns the message body. func (m *specializedImmutableMessageImpl[H, B]) Body() (B, error) { - return unmarshalProtoB[B](m.payload) + return unmarshalProtoB[B](m.Payload()) } func unmarshalProtoB[B proto.Message](data []byte) (B, error) { diff --git a/pkg/streaming/util/message/txn_test.go b/pkg/streaming/util/message/txn_test.go index a28a4eaec9..b07afb5b9a 100644 --- a/pkg/streaming/util/message/txn_test.go +++ b/pkg/streaming/util/message/txn_test.go @@ -87,4 +87,6 @@ func TestAsImmutableTxnMessage(t *testing.T) { assert.NotNil(t, txnMsg.Commit()) assert.Equal(t, 1, txnMsg.Size()) assert.NotNil(t, txnMsg.Begin()) + assert.NotNil(t, message.AsImmutableTxnMessage(txnMsg)) + assert.Nil(t, message.AsImmutableTxnMessage(beginMsg)) }