enhance: simplify the proto message, make segment assignment code cleaner (#41671)

issue: #41544

- simplify the proto messages for flush and create segment.
- simplify the msg handler for the flowgraph (see the interface sketch below).
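For reference, a sketch of the simplified handler interface, lifted from the diff below (import paths omitted); the vchannel parameter is dropped from most handlers:

```go
// Simplified MsgHandler, as introduced by this change (copied from the diff below).
type MsgHandler interface {
	HandleCreateSegment(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error
	HandleFlush(flushMsg message.ImmutableFlushMessageV2) error
	HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error
	HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error
	HandleSchemaChange(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2) error
}
```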

---------

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye 2025-05-11 20:49:00 +08:00 committed by GitHub
parent 452d6fb709
commit e675da76e4
66 changed files with 4080 additions and 1692 deletions

View File

@ -1215,6 +1215,19 @@ streaming:
# The threshold of slow log, 1s by default.
# If the wal implementation is woodpecker, the minimum threshold is 3s
appendSlowThreshold: 1s
walRecovery:
# The interval at which recovery info is persisted, 10s by default.
# At every interval, the WAL recovery info is persisted so that the WAL checkpoint can be advanced.
# Currently this only affects WAL recovery; it does not affect the recovery of data flushed into object storage.
persistInterval: 10s
# The max dirty message count for WAL recovery, 100 by default.
# If the WAL recovery info accumulates more than this count of dirty messages, it is persisted immediately
# instead of waiting for the persist interval.
maxDirtyMessage: 100
# The graceful close timeout for WAL recovery, 3s by default.
# When the WAL is closing, the recovery module tries to persist the recovery info so that the next recovery is faster.
# If that persist operation exceeds this timeout, the WAL recovery module closes immediately.
gracefulCloseTimeout: 3s
# Any configuration related to the knowhere vector search engine
knowhere:
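A minimal, hypothetical sketch of how the three walRecovery knobs above could drive a persist loop; the type and method names (recoveryPersister, persistRecoveryInfo) are illustrative assumptions, not Milvus's actual recovery module:

```go
// Package walrecovery is a hypothetical name used only for this sketch.
package walrecovery

import (
	"context"
	"time"
)

// recoveryPersister is an illustrative stand-in for the WAL recovery module.
type recoveryPersister struct {
	persistInterval      time.Duration // walRecovery.persistInterval
	maxDirtyMessage      int           // walRecovery.maxDirtyMessage
	gracefulCloseTimeout time.Duration // walRecovery.gracefulCloseTimeout

	dirty   chan struct{} // one signal per dirty message in the recovery info
	closing chan struct{} // closed when the WAL is shutting down
}

// loop persists the recovery info periodically, early when too many dirty
// messages accumulate, and once more (time-bounded) on graceful close.
func (p *recoveryPersister) loop() {
	ticker := time.NewTicker(p.persistInterval)
	defer ticker.Stop()

	dirtyCount := 0
	for {
		select {
		case <-ticker.C:
			p.persistRecoveryInfo(context.Background()) // lets the WAL checkpoint advance
			dirtyCount = 0
		case <-p.dirty:
			dirtyCount++
			if dirtyCount > p.maxDirtyMessage {
				// Too many dirty messages: persist immediately instead of
				// waiting for the next tick.
				p.persistRecoveryInfo(context.Background())
				dirtyCount = 0
			}
		case <-p.closing:
			// Best-effort persist on close, bounded by gracefulCloseTimeout.
			ctx, cancel := context.WithTimeout(context.Background(), p.gracefulCloseTimeout)
			p.persistRecoveryInfo(ctx)
			cancel()
			return
		}
	}
}

// persistRecoveryInfo is a placeholder; the real persistence logic lives elsewhere.
func (p *recoveryPersister) persistRecoveryInfo(ctx context.Context) { _ = ctx }
```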

View File

@ -32,7 +32,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -43,15 +42,15 @@ type msgHandlerImpl struct {
broker broker.Broker
}
func (m *msgHandlerImpl) HandleCreateSegment(ctx context.Context, vchannel string, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
func (m *msgHandlerImpl) HandleCreateSegment(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
panic("unreachable code")
}
func (m *msgHandlerImpl) HandleFlush(vchannel string, flushMsg message.ImmutableFlushMessageV2) error {
func (m *msgHandlerImpl) HandleFlush(flushMsg message.ImmutableFlushMessageV2) error {
panic("unreachable code")
}
func (m *msgHandlerImpl) HandleManualFlush(vchannel string, flushMsg message.ImmutableManualFlushMessageV2) error {
func (m *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error {
panic("unreachable code")
}
@ -90,11 +89,8 @@ func (m *msgHandlerImpl) HandleImport(ctx context.Context, vchannel string, impo
}, retry.AttemptAlways())
}
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, vchannel string, msg *adaptor.SchemaChangeMessageBody) error {
return streaming.WAL().Broadcast().Ack(ctx, types.BroadcastAckRequest{
BroadcastID: msg.BroadcastID,
VChannel: vchannel,
})
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error {
panic("unreachable code")
}
func NewMsgHandlerImpl(broker broker.Broker) *msgHandlerImpl {

View File

@ -37,13 +37,13 @@ func TestMsgHandlerImpl(t *testing.T) {
b := broker.NewMockBroker(t)
m := NewMsgHandlerImpl(b)
assert.Panics(t, func() {
m.HandleCreateSegment(nil, "", nil)
m.HandleCreateSegment(nil, nil)
})
assert.Panics(t, func() {
m.HandleFlush("", nil)
m.HandleFlush(nil)
})
assert.Panics(t, func() {
m.HandleManualFlush("", nil)
m.HandleManualFlush(nil)
})
t.Run("HandleImport success", func(t *testing.T) {
wal := mock_streaming.NewMockWALAccesser(t)

View File

@ -249,7 +249,7 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
zap.Uint64("timetick", createSegment.CreateSegmentMessage.TimeTick()),
)
logger.Info("receive create segment message")
if err := ddn.msgHandler.HandleCreateSegment(context.Background(), ddn.vChannelName, createSegment.CreateSegmentMessage); err != nil {
if err := ddn.msgHandler.HandleCreateSegment(ddn.ctx, createSegment.CreateSegmentMessage); err != nil {
logger.Warn("handle create segment message failed", zap.Error(err))
} else {
logger.Info("handle create segment message success")
@ -262,7 +262,7 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
zap.Uint64("timetick", flushMsg.FlushMessage.TimeTick()),
)
logger.Info("receive flush message")
if err := ddn.msgHandler.HandleFlush(ddn.vChannelName, flushMsg.FlushMessage); err != nil {
if err := ddn.msgHandler.HandleFlush(flushMsg.FlushMessage); err != nil {
logger.Warn("handle flush message failed", zap.Error(err))
} else {
logger.Info("handle flush message success")
@ -276,7 +276,7 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
zap.Uint64("flushTs", manualFlushMsg.ManualFlushMessage.Header().FlushTs),
)
logger.Info("receive manual flush message")
if err := ddn.msgHandler.HandleManualFlush(ddn.vChannelName, manualFlushMsg.ManualFlushMessage); err != nil {
if err := ddn.msgHandler.HandleManualFlush(manualFlushMsg.ManualFlushMessage); err != nil {
logger.Warn("handle manual flush message failed", zap.Error(err))
} else {
logger.Info("handle manual flush message success")
@ -311,7 +311,7 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
}
fgMsg.updatedSchema = body.GetSchema()
fgMsg.schemaVersion = schemaMsg.BeginTs()
ddn.msgHandler.HandleSchemaChange(ddn.ctx, ddn.vChannelName, schemaMsg)
ddn.msgHandler.HandleSchemaChange(ddn.ctx, schemaMsg.SchemaChangeMessage)
}
}

View File

@ -98,9 +98,9 @@ func TestFlowGraph_DDNode_newDDNode(t *testing.T) {
func TestFlowGraph_DDNode_OperateFlush(t *testing.T) {
h := mock_util.NewMockMsgHandler(t)
h.EXPECT().HandleCreateSegment(mock.Anything, mock.Anything, mock.Anything).Return(nil)
h.EXPECT().HandleFlush(mock.Anything, mock.Anything).Return(nil)
h.EXPECT().HandleManualFlush(mock.Anything, mock.Anything).Return(nil)
h.EXPECT().HandleCreateSegment(mock.Anything, mock.Anything).Return(nil)
h.EXPECT().HandleFlush(mock.Anything).Return(nil)
h.EXPECT().HandleManualFlush(mock.Anything).Return(nil)
ddn := ddNode{
ctx: context.Background(),

View File

@ -22,19 +22,18 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
)
type MsgHandler interface {
HandleCreateSegment(ctx context.Context, vchannel string, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error
HandleCreateSegment(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error
HandleFlush(vchannel string, flushMsg message.ImmutableFlushMessageV2) error
HandleFlush(flushMsg message.ImmutableFlushMessageV2) error
HandleManualFlush(vchannel string, flushMsg message.ImmutableManualFlushMessageV2) error
HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error
HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error
HandleSchemaChange(ctx context.Context, vchannel string, msg *adaptor.SchemaChangeMessageBody) error
HandleSchemaChange(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2) error
}
func ConvertInternalImportFile(file *msgpb.ImportFile, _ int) *internalpb.ImportFile {

View File

@ -226,11 +226,17 @@ type StreamingNodeCataLog interface {
// WAL select the wal related recovery infos.
// Which must give the pchannel name.
// ListVChannel lists all vchannels on the current pchannel.
ListVChannel(ctx context.Context, pchannelName string) ([]*streamingpb.VChannelMeta, error)
// SaveVChannels saves the vchannels of the current pchannel.
SaveVChannels(ctx context.Context, pchannelName string, vchannels map[string]*streamingpb.VChannelMeta) error
// ListSegmentAssignment lists all segment assignments for the wal.
ListSegmentAssignment(ctx context.Context, pChannelName string) ([]*streamingpb.SegmentAssignmentMeta, error)
// SaveSegmentAssignments saves the segment assignments for the wal.
SaveSegmentAssignments(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta) error
SaveSegmentAssignments(ctx context.Context, pChannelName string, infos map[int64]*streamingpb.SegmentAssignmentMeta) error
// GetConsumeCheckpoint gets the consuming checkpoint of the wal.
// Returns nil, nil if the checkpoint does not exist.
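A minimal usage sketch of the extended catalog, assuming this function lives in the same package as the StreamingNodeCataLog interface and that `ctx` is in scope (imports elided; streamingpb types are used exactly as they appear in this diff). It exercises the new map-keyed signatures:

```go
// Sketch: persist vchannel meta keyed by vchannel name, and segment
// assignments keyed by segment ID, matching the map-based signatures above.
func saveRecoveryMeta(ctx context.Context, cat StreamingNodeCataLog) error {
	vchannels := map[string]*streamingpb.VChannelMeta{
		"vchannel-1": {
			Vchannel: "vchannel-1",
			State:    streamingpb.VChannelState_VCHANNEL_STATE_NORMAL,
		},
	}
	if err := cat.SaveVChannels(ctx, "pchannel-1", vchannels); err != nil {
		return err
	}

	assignments := map[int64]*streamingpb.SegmentAssignmentMeta{
		1: {
			SegmentId: 1,
			State:     streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING,
		},
	}
	return cat.SaveSegmentAssignments(ctx, "pchannel-1", assignments)
}
```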

View File

@ -5,6 +5,7 @@ const (
DirectoryWAL = "wal"
DirectorySegmentAssign = "segment-assign"
DirectoryVChannel = "vchannel"
KeyConsumeCheckpoint = "consume-checkpoint"
)

View File

@ -24,12 +24,18 @@ import (
//
// ├── pchannel-1
// │   ├── checkpoint
// │   ├── vchannels
// │   │   ├── vchannel-1
// │   │   └── vchannel-2
// │   └── segment-assign
// │       ├── 456398247934
// │       ├── 456398247936
// │       └── 456398247939
// └── pchannel-2
//     ├── checkpoint
//     ├── vchannels
//     │   ├── vchannel-1
//     │   └── vchannel-2
//     └── segment-assign
//         ├── 456398247934
//         ├── 456398247935
@ -45,6 +51,58 @@ type catalog struct {
metaKV kv.MetaKv
}
// ListVChannel lists the vchannel info of the pchannel.
func (c *catalog) ListVChannel(ctx context.Context, pchannelName string) ([]*streamingpb.VChannelMeta, error) {
prefix := buildVChannelMetaPath(pchannelName)
keys, values, err := c.metaKV.LoadWithPrefix(ctx, prefix)
if err != nil {
return nil, err
}
infos := make([]*streamingpb.VChannelMeta, 0, len(values))
for k, value := range values {
info := &streamingpb.VChannelMeta{}
if err = proto.Unmarshal([]byte(value), info); err != nil {
return nil, errors.Wrapf(err, "unmarshal pchannel %s failed", keys[k])
}
infos = append(infos, info)
}
return infos, nil
}
// SaveVChannels saves the vchannels of the current pchannel.
func (c *catalog) SaveVChannels(ctx context.Context, pchannelName string, vchannels map[string]*streamingpb.VChannelMeta) error {
kvs := make(map[string]string, len(vchannels))
removes := make([]string, 0)
for _, info := range vchannels {
key := buildVChannelMetaPathOfVChannel(pchannelName, info.GetVchannel())
if info.GetState() == streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
// Dropped vchannel should be removed from meta
removes = append(removes, key)
continue
}
data, err := proto.Marshal(info)
if err != nil {
return errors.Wrapf(err, "marshal vchannel %d at pchannel %s failed", info.GetVchannel(), pchannelName)
}
kvs[key] = string(data)
}
if len(removes) > 0 {
if err := etcd.RemoveByBatchWithLimit(removes, util.MaxEtcdTxnNum, func(partialRemoves []string) error {
return c.metaKV.MultiRemove(ctx, partialRemoves)
}); err != nil {
return err
}
}
if len(kvs) > 0 {
return etcd.SaveByBatchWithLimit(kvs, util.MaxEtcdTxnNum, func(partialKvs map[string]string) error {
return c.metaKV.MultiSave(ctx, partialKvs)
})
}
return nil
}
// ListSegmentAssignment lists the segment assignment info of the pchannel.
func (c *catalog) ListSegmentAssignment(ctx context.Context, pChannelName string) ([]*streamingpb.SegmentAssignmentMeta, error) {
prefix := buildSegmentAssignmentMetaPath(pChannelName)
@ -65,7 +123,7 @@ func (c *catalog) ListSegmentAssignment(ctx context.Context, pChannelName string
}
// SaveSegmentAssignments saves the segment assignment info to meta storage.
func (c *catalog) SaveSegmentAssignments(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta) error {
func (c *catalog) SaveSegmentAssignments(ctx context.Context, pChannelName string, infos map[int64]*streamingpb.SegmentAssignmentMeta) error {
kvs := make(map[string]string, len(infos))
removes := make([]string, 0)
for _, info := range infos {
@ -126,6 +184,16 @@ func (c *catalog) SaveConsumeCheckpoint(ctx context.Context, pchannelName string
return c.metaKV.Save(ctx, key, string(value))
}
// buildVChannelMetaPath builds the path for vchannel meta
func buildVChannelMetaPath(pChannelName string) string {
return path.Join(buildWALDirectory(pChannelName), DirectoryVChannel) + "/"
}
// buildVChannelMetaPathOfVChannel builds the path for vchannel meta
func buildVChannelMetaPathOfVChannel(pChannelName string, vchannelName string) string {
return path.Join(buildVChannelMetaPath(pChannelName), vchannelName)
}
// buildSegmentAssignmentMetaPath builds the path for segment assignment
func buildSegmentAssignmentMetaPath(pChannelName string) string {
return path.Join(buildWALDirectory(pChannelName), DirectorySegmentAssign) + "/"
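For illustration, the helpers above compose keys as follows; the prefix comes from buildWALDirectory, which is defined outside this hunk:

```go
// Key for the vchannel meta of "v1" on pchannel "p1":
//   buildVChannelMetaPathOfVChannel("p1", "v1")
//   => <buildWALDirectory("p1")>/vchannel/v1
//
// Prefix under which all segment assignments of "p1" are stored:
//   buildSegmentAssignmentMetaPath("p1")
//   => <buildWALDirectory("p1")>/segment-assign/
```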

View File

@ -66,14 +66,44 @@ func TestCatalogSegmentAssignments(t *testing.T) {
kv.EXPECT().MultiRemove(mock.Anything, mock.Anything).Return(nil)
kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Return(nil)
err = catalog.SaveSegmentAssignments(ctx, "p1", []*streamingpb.SegmentAssignmentMeta{
{
err = catalog.SaveSegmentAssignments(ctx, "p1", map[int64]*streamingpb.SegmentAssignmentMeta{
1: {
SegmentId: 1,
State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED,
},
{
2: {
SegmentId: 2,
State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_PENDING,
State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING,
},
})
assert.NoError(t, err)
}
func TestCatalogVChannel(t *testing.T) {
kv := mocks.NewMetaKv(t)
k := "p1/vchannel-1"
v := streamingpb.VChannelMeta{}
vs, err := proto.Marshal(&v)
assert.NoError(t, err)
kv.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Return([]string{k}, []string{string(vs)}, nil)
catalog := NewCataLog(kv)
ctx := context.Background()
metas, err := catalog.ListVChannel(ctx, "p1")
assert.Len(t, metas, 1)
assert.NoError(t, err)
kv.EXPECT().MultiRemove(mock.Anything, mock.Anything).Return(nil)
kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Return(nil)
err = catalog.SaveVChannels(ctx, "p1", map[string]*streamingpb.VChannelMeta{
"vchannel-1": {
Vchannel: "vchannel-1",
State: streamingpb.VChannelState_VCHANNEL_STATE_DROPPED,
},
"vchannel-2": {
Vchannel: "vchannel-2",
State: streamingpb.VChannelState_VCHANNEL_STATE_NORMAL,
},
})
assert.NoError(t, err)

View File

@ -27,7 +27,7 @@ func (_m *MockWALAccesser) EXPECT() *MockWALAccesser_Expecter {
}
// AppendMessages provides a mock function with given fields: ctx, msgs
func (_m *MockWALAccesser) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
func (_m *MockWALAccesser) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses {
_va := make([]interface{}, len(msgs))
for _i := range msgs {
_va[_i] = msgs[_i]
@ -41,11 +41,11 @@ func (_m *MockWALAccesser) AppendMessages(ctx context.Context, msgs ...message.M
panic("no return value specified for AppendMessages")
}
var r0 types.AppendResponses
if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) types.AppendResponses); ok {
var r0 streaming.AppendResponses
if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) streaming.AppendResponses); ok {
r0 = rf(ctx, msgs...)
} else {
r0 = ret.Get(0).(types.AppendResponses)
r0 = ret.Get(0).(streaming.AppendResponses)
}
return r0
@ -77,18 +77,18 @@ func (_c *MockWALAccesser_AppendMessages_Call) Run(run func(ctx context.Context,
return _c
}
func (_c *MockWALAccesser_AppendMessages_Call) Return(_a0 types.AppendResponses) *MockWALAccesser_AppendMessages_Call {
func (_c *MockWALAccesser_AppendMessages_Call) Return(_a0 streaming.AppendResponses) *MockWALAccesser_AppendMessages_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWALAccesser_AppendMessages_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) types.AppendResponses) *MockWALAccesser_AppendMessages_Call {
func (_c *MockWALAccesser_AppendMessages_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) streaming.AppendResponses) *MockWALAccesser_AppendMessages_Call {
_c.Call.Return(run)
return _c
}
// AppendMessagesWithOption provides a mock function with given fields: ctx, opts, msgs
func (_m *MockWALAccesser) AppendMessagesWithOption(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage) types.AppendResponses {
func (_m *MockWALAccesser) AppendMessagesWithOption(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage) streaming.AppendResponses {
_va := make([]interface{}, len(msgs))
for _i := range msgs {
_va[_i] = msgs[_i]
@ -102,11 +102,11 @@ func (_m *MockWALAccesser) AppendMessagesWithOption(ctx context.Context, opts st
panic("no return value specified for AppendMessagesWithOption")
}
var r0 types.AppendResponses
if rf, ok := ret.Get(0).(func(context.Context, streaming.AppendOption, ...message.MutableMessage) types.AppendResponses); ok {
var r0 streaming.AppendResponses
if rf, ok := ret.Get(0).(func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses); ok {
r0 = rf(ctx, opts, msgs...)
} else {
r0 = ret.Get(0).(types.AppendResponses)
r0 = ret.Get(0).(streaming.AppendResponses)
}
return r0
@ -139,12 +139,12 @@ func (_c *MockWALAccesser_AppendMessagesWithOption_Call) Run(run func(ctx contex
return _c
}
func (_c *MockWALAccesser_AppendMessagesWithOption_Call) Return(_a0 types.AppendResponses) *MockWALAccesser_AppendMessagesWithOption_Call {
func (_c *MockWALAccesser_AppendMessagesWithOption_Call) Return(_a0 streaming.AppendResponses) *MockWALAccesser_AppendMessagesWithOption_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWALAccesser_AppendMessagesWithOption_Call) RunAndReturn(run func(context.Context, streaming.AppendOption, ...message.MutableMessage) types.AppendResponses) *MockWALAccesser_AppendMessagesWithOption_Call {
func (_c *MockWALAccesser_AppendMessagesWithOption_Call) RunAndReturn(run func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses) *MockWALAccesser_AppendMessagesWithOption_Call {
_c.Call.Return(run)
return _c
}

View File

@ -5,10 +5,7 @@ package mock_util
import (
context "context"
adaptor "github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
@ -27,17 +24,17 @@ func (_m *MockMsgHandler) EXPECT() *MockMsgHandler_Expecter {
return &MockMsgHandler_Expecter{mock: &_m.Mock}
}
// HandleCreateSegment provides a mock function with given fields: ctx, vchannel, createSegmentMsg
func (_m *MockMsgHandler) HandleCreateSegment(ctx context.Context, vchannel string, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
ret := _m.Called(ctx, vchannel, createSegmentMsg)
// HandleCreateSegment provides a mock function with given fields: ctx, createSegmentMsg
func (_m *MockMsgHandler) HandleCreateSegment(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
ret := _m.Called(ctx, createSegmentMsg)
if len(ret) == 0 {
panic("no return value specified for HandleCreateSegment")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, message.ImmutableCreateSegmentMessageV2) error); ok {
r0 = rf(ctx, vchannel, createSegmentMsg)
if rf, ok := ret.Get(0).(func(context.Context, message.ImmutableCreateSegmentMessageV2) error); ok {
r0 = rf(ctx, createSegmentMsg)
} else {
r0 = ret.Error(0)
}
@ -52,15 +49,14 @@ type MockMsgHandler_HandleCreateSegment_Call struct {
// HandleCreateSegment is a helper method to define mock.On call
// - ctx context.Context
// - vchannel string
// - createSegmentMsg message.ImmutableCreateSegmentMessageV2
func (_e *MockMsgHandler_Expecter) HandleCreateSegment(ctx interface{}, vchannel interface{}, createSegmentMsg interface{}) *MockMsgHandler_HandleCreateSegment_Call {
return &MockMsgHandler_HandleCreateSegment_Call{Call: _e.mock.On("HandleCreateSegment", ctx, vchannel, createSegmentMsg)}
func (_e *MockMsgHandler_Expecter) HandleCreateSegment(ctx interface{}, createSegmentMsg interface{}) *MockMsgHandler_HandleCreateSegment_Call {
return &MockMsgHandler_HandleCreateSegment_Call{Call: _e.mock.On("HandleCreateSegment", ctx, createSegmentMsg)}
}
func (_c *MockMsgHandler_HandleCreateSegment_Call) Run(run func(ctx context.Context, vchannel string, createSegmentMsg message.ImmutableCreateSegmentMessageV2)) *MockMsgHandler_HandleCreateSegment_Call {
func (_c *MockMsgHandler_HandleCreateSegment_Call) Run(run func(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2)) *MockMsgHandler_HandleCreateSegment_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(message.ImmutableCreateSegmentMessageV2))
run(args[0].(context.Context), args[1].(message.ImmutableCreateSegmentMessageV2))
})
return _c
}
@ -70,22 +66,22 @@ func (_c *MockMsgHandler_HandleCreateSegment_Call) Return(_a0 error) *MockMsgHan
return _c
}
func (_c *MockMsgHandler_HandleCreateSegment_Call) RunAndReturn(run func(context.Context, string, message.ImmutableCreateSegmentMessageV2) error) *MockMsgHandler_HandleCreateSegment_Call {
func (_c *MockMsgHandler_HandleCreateSegment_Call) RunAndReturn(run func(context.Context, message.ImmutableCreateSegmentMessageV2) error) *MockMsgHandler_HandleCreateSegment_Call {
_c.Call.Return(run)
return _c
}
// HandleFlush provides a mock function with given fields: vchannel, flushMsg
func (_m *MockMsgHandler) HandleFlush(vchannel string, flushMsg message.ImmutableFlushMessageV2) error {
ret := _m.Called(vchannel, flushMsg)
// HandleFlush provides a mock function with given fields: flushMsg
func (_m *MockMsgHandler) HandleFlush(flushMsg message.ImmutableFlushMessageV2) error {
ret := _m.Called(flushMsg)
if len(ret) == 0 {
panic("no return value specified for HandleFlush")
}
var r0 error
if rf, ok := ret.Get(0).(func(string, message.ImmutableFlushMessageV2) error); ok {
r0 = rf(vchannel, flushMsg)
if rf, ok := ret.Get(0).(func(message.ImmutableFlushMessageV2) error); ok {
r0 = rf(flushMsg)
} else {
r0 = ret.Error(0)
}
@ -99,15 +95,14 @@ type MockMsgHandler_HandleFlush_Call struct {
}
// HandleFlush is a helper method to define mock.On call
// - vchannel string
// - flushMsg message.ImmutableFlushMessageV2
func (_e *MockMsgHandler_Expecter) HandleFlush(vchannel interface{}, flushMsg interface{}) *MockMsgHandler_HandleFlush_Call {
return &MockMsgHandler_HandleFlush_Call{Call: _e.mock.On("HandleFlush", vchannel, flushMsg)}
func (_e *MockMsgHandler_Expecter) HandleFlush(flushMsg interface{}) *MockMsgHandler_HandleFlush_Call {
return &MockMsgHandler_HandleFlush_Call{Call: _e.mock.On("HandleFlush", flushMsg)}
}
func (_c *MockMsgHandler_HandleFlush_Call) Run(run func(vchannel string, flushMsg message.ImmutableFlushMessageV2)) *MockMsgHandler_HandleFlush_Call {
func (_c *MockMsgHandler_HandleFlush_Call) Run(run func(flushMsg message.ImmutableFlushMessageV2)) *MockMsgHandler_HandleFlush_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(message.ImmutableFlushMessageV2))
run(args[0].(message.ImmutableFlushMessageV2))
})
return _c
}
@ -117,7 +112,7 @@ func (_c *MockMsgHandler_HandleFlush_Call) Return(_a0 error) *MockMsgHandler_Han
return _c
}
func (_c *MockMsgHandler_HandleFlush_Call) RunAndReturn(run func(string, message.ImmutableFlushMessageV2) error) *MockMsgHandler_HandleFlush_Call {
func (_c *MockMsgHandler_HandleFlush_Call) RunAndReturn(run func(message.ImmutableFlushMessageV2) error) *MockMsgHandler_HandleFlush_Call {
_c.Call.Return(run)
return _c
}
@ -170,17 +165,17 @@ func (_c *MockMsgHandler_HandleImport_Call) RunAndReturn(run func(context.Contex
return _c
}
// HandleManualFlush provides a mock function with given fields: vchannel, flushMsg
func (_m *MockMsgHandler) HandleManualFlush(vchannel string, flushMsg message.ImmutableManualFlushMessageV2) error {
ret := _m.Called(vchannel, flushMsg)
// HandleManualFlush provides a mock function with given fields: flushMsg
func (_m *MockMsgHandler) HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error {
ret := _m.Called(flushMsg)
if len(ret) == 0 {
panic("no return value specified for HandleManualFlush")
}
var r0 error
if rf, ok := ret.Get(0).(func(string, message.ImmutableManualFlushMessageV2) error); ok {
r0 = rf(vchannel, flushMsg)
if rf, ok := ret.Get(0).(func(message.ImmutableManualFlushMessageV2) error); ok {
r0 = rf(flushMsg)
} else {
r0 = ret.Error(0)
}
@ -194,15 +189,14 @@ type MockMsgHandler_HandleManualFlush_Call struct {
}
// HandleManualFlush is a helper method to define mock.On call
// - vchannel string
// - flushMsg message.ImmutableManualFlushMessageV2
func (_e *MockMsgHandler_Expecter) HandleManualFlush(vchannel interface{}, flushMsg interface{}) *MockMsgHandler_HandleManualFlush_Call {
return &MockMsgHandler_HandleManualFlush_Call{Call: _e.mock.On("HandleManualFlush", vchannel, flushMsg)}
func (_e *MockMsgHandler_Expecter) HandleManualFlush(flushMsg interface{}) *MockMsgHandler_HandleManualFlush_Call {
return &MockMsgHandler_HandleManualFlush_Call{Call: _e.mock.On("HandleManualFlush", flushMsg)}
}
func (_c *MockMsgHandler_HandleManualFlush_Call) Run(run func(vchannel string, flushMsg message.ImmutableManualFlushMessageV2)) *MockMsgHandler_HandleManualFlush_Call {
func (_c *MockMsgHandler_HandleManualFlush_Call) Run(run func(flushMsg message.ImmutableManualFlushMessageV2)) *MockMsgHandler_HandleManualFlush_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(message.ImmutableManualFlushMessageV2))
run(args[0].(message.ImmutableManualFlushMessageV2))
})
return _c
}
@ -212,22 +206,22 @@ func (_c *MockMsgHandler_HandleManualFlush_Call) Return(_a0 error) *MockMsgHandl
return _c
}
func (_c *MockMsgHandler_HandleManualFlush_Call) RunAndReturn(run func(string, message.ImmutableManualFlushMessageV2) error) *MockMsgHandler_HandleManualFlush_Call {
func (_c *MockMsgHandler_HandleManualFlush_Call) RunAndReturn(run func(message.ImmutableManualFlushMessageV2) error) *MockMsgHandler_HandleManualFlush_Call {
_c.Call.Return(run)
return _c
}
// HandleSchemaChange provides a mock function with given fields: ctx, vchannel, msg
func (_m *MockMsgHandler) HandleSchemaChange(ctx context.Context, vchannel string, msg *adaptor.SchemaChangeMessageBody) error {
ret := _m.Called(ctx, vchannel, msg)
// HandleSchemaChange provides a mock function with given fields: ctx, schemaChangeMsg
func (_m *MockMsgHandler) HandleSchemaChange(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2) error {
ret := _m.Called(ctx, schemaChangeMsg)
if len(ret) == 0 {
panic("no return value specified for HandleSchemaChange")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, *adaptor.SchemaChangeMessageBody) error); ok {
r0 = rf(ctx, vchannel, msg)
if rf, ok := ret.Get(0).(func(context.Context, message.ImmutableSchemaChangeMessageV2) error); ok {
r0 = rf(ctx, schemaChangeMsg)
} else {
r0 = ret.Error(0)
}
@ -242,15 +236,14 @@ type MockMsgHandler_HandleSchemaChange_Call struct {
// HandleSchemaChange is a helper method to define mock.On call
// - ctx context.Context
// - vchannel string
// - msg *adaptor.SchemaChangeMessageBody
func (_e *MockMsgHandler_Expecter) HandleSchemaChange(ctx interface{}, vchannel interface{}, msg interface{}) *MockMsgHandler_HandleSchemaChange_Call {
return &MockMsgHandler_HandleSchemaChange_Call{Call: _e.mock.On("HandleSchemaChange", ctx, vchannel, msg)}
// - schemaChangeMsg message.ImmutableSchemaChangeMessageV2
func (_e *MockMsgHandler_Expecter) HandleSchemaChange(ctx interface{}, schemaChangeMsg interface{}) *MockMsgHandler_HandleSchemaChange_Call {
return &MockMsgHandler_HandleSchemaChange_Call{Call: _e.mock.On("HandleSchemaChange", ctx, schemaChangeMsg)}
}
func (_c *MockMsgHandler_HandleSchemaChange_Call) Run(run func(ctx context.Context, vchannel string, msg *adaptor.SchemaChangeMessageBody)) *MockMsgHandler_HandleSchemaChange_Call {
func (_c *MockMsgHandler_HandleSchemaChange_Call) Run(run func(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2)) *MockMsgHandler_HandleSchemaChange_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(*adaptor.SchemaChangeMessageBody))
run(args[0].(context.Context), args[1].(message.ImmutableSchemaChangeMessageV2))
})
return _c
}
@ -260,7 +253,7 @@ func (_c *MockMsgHandler_HandleSchemaChange_Call) Return(_a0 error) *MockMsgHand
return _c
}
func (_c *MockMsgHandler_HandleSchemaChange_Call) RunAndReturn(run func(context.Context, string, *adaptor.SchemaChangeMessageBody) error) *MockMsgHandler_HandleSchemaChange_Call {
func (_c *MockMsgHandler_HandleSchemaChange_Call) RunAndReturn(run func(context.Context, message.ImmutableSchemaChangeMessageV2) error) *MockMsgHandler_HandleSchemaChange_Call {
_c.Call.Return(run)
return _c
}

View File

@ -173,7 +173,7 @@ func (_c *MockClientStream_Header_Call) RunAndReturn(run func() (metadata.MD, er
}
// RecvMsg provides a mock function with given fields: m
func (_m *MockClientStream) RecvMsg(m interface{}) error {
func (_m *MockClientStream) RecvMsg(m any) error {
ret := _m.Called(m)
if len(ret) == 0 {
@ -181,7 +181,7 @@ func (_m *MockClientStream) RecvMsg(m interface{}) error {
}
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
if rf, ok := ret.Get(0).(func(any) error); ok {
r0 = rf(m)
} else {
r0 = ret.Error(0)
@ -196,14 +196,14 @@ type MockClientStream_RecvMsg_Call struct {
}
// RecvMsg is a helper method to define mock.On call
// - m interface{}
// - m any
func (_e *MockClientStream_Expecter) RecvMsg(m interface{}) *MockClientStream_RecvMsg_Call {
return &MockClientStream_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)}
}
func (_c *MockClientStream_RecvMsg_Call) Run(run func(m interface{})) *MockClientStream_RecvMsg_Call {
func (_c *MockClientStream_RecvMsg_Call) Run(run func(m any)) *MockClientStream_RecvMsg_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(interface{}))
run(args[0].(any))
})
return _c
}
@ -213,13 +213,13 @@ func (_c *MockClientStream_RecvMsg_Call) Return(_a0 error) *MockClientStream_Rec
return _c
}
func (_c *MockClientStream_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockClientStream_RecvMsg_Call {
func (_c *MockClientStream_RecvMsg_Call) RunAndReturn(run func(any) error) *MockClientStream_RecvMsg_Call {
_c.Call.Return(run)
return _c
}
// SendMsg provides a mock function with given fields: m
func (_m *MockClientStream) SendMsg(m interface{}) error {
func (_m *MockClientStream) SendMsg(m any) error {
ret := _m.Called(m)
if len(ret) == 0 {
@ -227,7 +227,7 @@ func (_m *MockClientStream) SendMsg(m interface{}) error {
}
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
if rf, ok := ret.Get(0).(func(any) error); ok {
r0 = rf(m)
} else {
r0 = ret.Error(0)
@ -242,14 +242,14 @@ type MockClientStream_SendMsg_Call struct {
}
// SendMsg is a helper method to define mock.On call
// - m interface{}
// - m any
func (_e *MockClientStream_Expecter) SendMsg(m interface{}) *MockClientStream_SendMsg_Call {
return &MockClientStream_SendMsg_Call{Call: _e.mock.On("SendMsg", m)}
}
func (_c *MockClientStream_SendMsg_Call) Run(run func(m interface{})) *MockClientStream_SendMsg_Call {
func (_c *MockClientStream_SendMsg_Call) Run(run func(m any)) *MockClientStream_SendMsg_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(interface{}))
run(args[0].(any))
})
return _c
}
@ -259,7 +259,7 @@ func (_c *MockClientStream_SendMsg_Call) Return(_a0 error) *MockClientStream_Sen
return _c
}
func (_c *MockClientStream_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockClientStream_SendMsg_Call {
func (_c *MockClientStream_SendMsg_Call) RunAndReturn(run func(any) error) *MockClientStream_SendMsg_Call {
_c.Call.Return(run)
return _c
}

View File

@ -141,6 +141,65 @@ func (_c *MockStreamingNodeCataLog_ListSegmentAssignment_Call) RunAndReturn(run
return _c
}
// ListVChannel provides a mock function with given fields: ctx, pchannelName
func (_m *MockStreamingNodeCataLog) ListVChannel(ctx context.Context, pchannelName string) ([]*streamingpb.VChannelMeta, error) {
ret := _m.Called(ctx, pchannelName)
if len(ret) == 0 {
panic("no return value specified for ListVChannel")
}
var r0 []*streamingpb.VChannelMeta
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) ([]*streamingpb.VChannelMeta, error)); ok {
return rf(ctx, pchannelName)
}
if rf, ok := ret.Get(0).(func(context.Context, string) []*streamingpb.VChannelMeta); ok {
r0 = rf(ctx, pchannelName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*streamingpb.VChannelMeta)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, pchannelName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockStreamingNodeCataLog_ListVChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListVChannel'
type MockStreamingNodeCataLog_ListVChannel_Call struct {
*mock.Call
}
// ListVChannel is a helper method to define mock.On call
// - ctx context.Context
// - pchannelName string
func (_e *MockStreamingNodeCataLog_Expecter) ListVChannel(ctx interface{}, pchannelName interface{}) *MockStreamingNodeCataLog_ListVChannel_Call {
return &MockStreamingNodeCataLog_ListVChannel_Call{Call: _e.mock.On("ListVChannel", ctx, pchannelName)}
}
func (_c *MockStreamingNodeCataLog_ListVChannel_Call) Run(run func(ctx context.Context, pchannelName string)) *MockStreamingNodeCataLog_ListVChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockStreamingNodeCataLog_ListVChannel_Call) Return(_a0 []*streamingpb.VChannelMeta, _a1 error) *MockStreamingNodeCataLog_ListVChannel_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockStreamingNodeCataLog_ListVChannel_Call) RunAndReturn(run func(context.Context, string) ([]*streamingpb.VChannelMeta, error)) *MockStreamingNodeCataLog_ListVChannel_Call {
_c.Call.Return(run)
return _c
}
// SaveConsumeCheckpoint provides a mock function with given fields: ctx, pChannelName, checkpoint
func (_m *MockStreamingNodeCataLog) SaveConsumeCheckpoint(ctx context.Context, pChannelName string, checkpoint *streamingpb.WALCheckpoint) error {
ret := _m.Called(ctx, pChannelName, checkpoint)
@ -190,7 +249,7 @@ func (_c *MockStreamingNodeCataLog_SaveConsumeCheckpoint_Call) RunAndReturn(run
}
// SaveSegmentAssignments provides a mock function with given fields: ctx, pChannelName, infos
func (_m *MockStreamingNodeCataLog) SaveSegmentAssignments(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta) error {
func (_m *MockStreamingNodeCataLog) SaveSegmentAssignments(ctx context.Context, pChannelName string, infos map[int64]*streamingpb.SegmentAssignmentMeta) error {
ret := _m.Called(ctx, pChannelName, infos)
if len(ret) == 0 {
@ -198,7 +257,7 @@ func (_m *MockStreamingNodeCataLog) SaveSegmentAssignments(ctx context.Context,
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, []*streamingpb.SegmentAssignmentMeta) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, string, map[int64]*streamingpb.SegmentAssignmentMeta) error); ok {
r0 = rf(ctx, pChannelName, infos)
} else {
r0 = ret.Error(0)
@ -215,14 +274,14 @@ type MockStreamingNodeCataLog_SaveSegmentAssignments_Call struct {
// SaveSegmentAssignments is a helper method to define mock.On call
// - ctx context.Context
// - pChannelName string
// - infos []*streamingpb.SegmentAssignmentMeta
// - infos map[int64]*streamingpb.SegmentAssignmentMeta
func (_e *MockStreamingNodeCataLog_Expecter) SaveSegmentAssignments(ctx interface{}, pChannelName interface{}, infos interface{}) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call {
return &MockStreamingNodeCataLog_SaveSegmentAssignments_Call{Call: _e.mock.On("SaveSegmentAssignments", ctx, pChannelName, infos)}
}
func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) Run(run func(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta)) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call {
func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) Run(run func(ctx context.Context, pChannelName string, infos map[int64]*streamingpb.SegmentAssignmentMeta)) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].([]*streamingpb.SegmentAssignmentMeta))
run(args[0].(context.Context), args[1].(string), args[2].(map[int64]*streamingpb.SegmentAssignmentMeta))
})
return _c
}
@ -232,7 +291,55 @@ func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) Return(_a0 error
return _c
}
func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) RunAndReturn(run func(context.Context, string, []*streamingpb.SegmentAssignmentMeta) error) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call {
func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) RunAndReturn(run func(context.Context, string, map[int64]*streamingpb.SegmentAssignmentMeta) error) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call {
_c.Call.Return(run)
return _c
}
// SaveVChannels provides a mock function with given fields: ctx, pchannelName, vchannels
func (_m *MockStreamingNodeCataLog) SaveVChannels(ctx context.Context, pchannelName string, vchannels map[string]*streamingpb.VChannelMeta) error {
ret := _m.Called(ctx, pchannelName, vchannels)
if len(ret) == 0 {
panic("no return value specified for SaveVChannels")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, map[string]*streamingpb.VChannelMeta) error); ok {
r0 = rf(ctx, pchannelName, vchannels)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockStreamingNodeCataLog_SaveVChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveVChannels'
type MockStreamingNodeCataLog_SaveVChannels_Call struct {
*mock.Call
}
// SaveVChannels is a helper method to define mock.On call
// - ctx context.Context
// - pchannelName string
// - vchannels map[string]*streamingpb.VChannelMeta
func (_e *MockStreamingNodeCataLog_Expecter) SaveVChannels(ctx interface{}, pchannelName interface{}, vchannels interface{}) *MockStreamingNodeCataLog_SaveVChannels_Call {
return &MockStreamingNodeCataLog_SaveVChannels_Call{Call: _e.mock.On("SaveVChannels", ctx, pchannelName, vchannels)}
}
func (_c *MockStreamingNodeCataLog_SaveVChannels_Call) Run(run func(ctx context.Context, pchannelName string, vchannels map[string]*streamingpb.VChannelMeta)) *MockStreamingNodeCataLog_SaveVChannels_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(map[string]*streamingpb.VChannelMeta))
})
return _c
}
func (_c *MockStreamingNodeCataLog_SaveVChannels_Call) Return(_a0 error) *MockStreamingNodeCataLog_SaveVChannels_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockStreamingNodeCataLog_SaveVChannels_Call) RunAndReturn(run func(context.Context, string, map[string]*streamingpb.VChannelMeta) error) *MockStreamingNodeCataLog_SaveVChannels_Call {
_c.Call.Return(run)
return _c
}

View File

@ -5,13 +5,8 @@ package mock_handler
import (
context "context"
consumer "github.com/milvus-io/milvus/internal/streamingnode/client/handler/consumer"
handler "github.com/milvus-io/milvus/internal/streamingnode/client/handler"
mock "github.com/stretchr/testify/mock"
producer "github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer"
)
// MockHandlerClient is an autogenerated mock type for the HandlerClient type
@ -60,23 +55,23 @@ func (_c *MockHandlerClient_Close_Call) RunAndReturn(run func()) *MockHandlerCli
}
// CreateConsumer provides a mock function with given fields: ctx, opts
func (_m *MockHandlerClient) CreateConsumer(ctx context.Context, opts *handler.ConsumerOptions) (consumer.Consumer, error) {
func (_m *MockHandlerClient) CreateConsumer(ctx context.Context, opts *handler.ConsumerOptions) (handler.Consumer, error) {
ret := _m.Called(ctx, opts)
if len(ret) == 0 {
panic("no return value specified for CreateConsumer")
}
var r0 consumer.Consumer
var r0 handler.Consumer
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *handler.ConsumerOptions) (consumer.Consumer, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, *handler.ConsumerOptions) (handler.Consumer, error)); ok {
return rf(ctx, opts)
}
if rf, ok := ret.Get(0).(func(context.Context, *handler.ConsumerOptions) consumer.Consumer); ok {
if rf, ok := ret.Get(0).(func(context.Context, *handler.ConsumerOptions) handler.Consumer); ok {
r0 = rf(ctx, opts)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(consumer.Consumer)
r0 = ret.Get(0).(handler.Consumer)
}
}
@ -108,34 +103,34 @@ func (_c *MockHandlerClient_CreateConsumer_Call) Run(run func(ctx context.Contex
return _c
}
func (_c *MockHandlerClient_CreateConsumer_Call) Return(_a0 consumer.Consumer, _a1 error) *MockHandlerClient_CreateConsumer_Call {
func (_c *MockHandlerClient_CreateConsumer_Call) Return(_a0 handler.Consumer, _a1 error) *MockHandlerClient_CreateConsumer_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockHandlerClient_CreateConsumer_Call) RunAndReturn(run func(context.Context, *handler.ConsumerOptions) (consumer.Consumer, error)) *MockHandlerClient_CreateConsumer_Call {
func (_c *MockHandlerClient_CreateConsumer_Call) RunAndReturn(run func(context.Context, *handler.ConsumerOptions) (handler.Consumer, error)) *MockHandlerClient_CreateConsumer_Call {
_c.Call.Return(run)
return _c
}
// CreateProducer provides a mock function with given fields: ctx, opts
func (_m *MockHandlerClient) CreateProducer(ctx context.Context, opts *handler.ProducerOptions) (producer.Producer, error) {
func (_m *MockHandlerClient) CreateProducer(ctx context.Context, opts *handler.ProducerOptions) (handler.Producer, error) {
ret := _m.Called(ctx, opts)
if len(ret) == 0 {
panic("no return value specified for CreateProducer")
}
var r0 producer.Producer
var r0 handler.Producer
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *handler.ProducerOptions) (producer.Producer, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, *handler.ProducerOptions) (handler.Producer, error)); ok {
return rf(ctx, opts)
}
if rf, ok := ret.Get(0).(func(context.Context, *handler.ProducerOptions) producer.Producer); ok {
if rf, ok := ret.Get(0).(func(context.Context, *handler.ProducerOptions) handler.Producer); ok {
r0 = rf(ctx, opts)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(producer.Producer)
r0 = ret.Get(0).(handler.Producer)
}
}
@ -167,12 +162,12 @@ func (_c *MockHandlerClient_CreateProducer_Call) Run(run func(ctx context.Contex
return _c
}
func (_c *MockHandlerClient_CreateProducer_Call) Return(_a0 producer.Producer, _a1 error) *MockHandlerClient_CreateProducer_Call {
func (_c *MockHandlerClient_CreateProducer_Call) Return(_a0 handler.Producer, _a1 error) *MockHandlerClient_CreateProducer_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockHandlerClient_CreateProducer_Call) RunAndReturn(run func(context.Context, *handler.ProducerOptions) (producer.Producer, error)) *MockHandlerClient_CreateProducer_Call {
func (_c *MockHandlerClient_CreateProducer_Call) RunAndReturn(run func(context.Context, *handler.ProducerOptions) (handler.Producer, error)) *MockHandlerClient_CreateProducer_Call {
_c.Call.Return(run)
return _c
}

View File

@ -27,23 +27,23 @@ func (_m *MockWAL) EXPECT() *MockWAL_Expecter {
}
// Append provides a mock function with given fields: ctx, msg
func (_m *MockWAL) Append(ctx context.Context, msg message.MutableMessage) (*types.AppendResult, error) {
func (_m *MockWAL) Append(ctx context.Context, msg message.MutableMessage) (*wal.AppendResult, error) {
ret := _m.Called(ctx, msg)
if len(ret) == 0 {
panic("no return value specified for Append")
}
var r0 *types.AppendResult
var r0 *wal.AppendResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) (*types.AppendResult, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) (*wal.AppendResult, error)); ok {
return rf(ctx, msg)
}
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) *types.AppendResult); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) *wal.AppendResult); ok {
r0 = rf(ctx, msg)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.AppendResult)
r0 = ret.Get(0).(*wal.AppendResult)
}
}
@ -75,18 +75,18 @@ func (_c *MockWAL_Append_Call) Run(run func(ctx context.Context, msg message.Mut
return _c
}
func (_c *MockWAL_Append_Call) Return(_a0 *types.AppendResult, _a1 error) *MockWAL_Append_Call {
func (_c *MockWAL_Append_Call) Return(_a0 *wal.AppendResult, _a1 error) *MockWAL_Append_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockWAL_Append_Call) RunAndReturn(run func(context.Context, message.MutableMessage) (*types.AppendResult, error)) *MockWAL_Append_Call {
func (_c *MockWAL_Append_Call) RunAndReturn(run func(context.Context, message.MutableMessage) (*wal.AppendResult, error)) *MockWAL_Append_Call {
_c.Call.Return(run)
return _c
}
// AppendAsync provides a mock function with given fields: ctx, msg, cb
func (_m *MockWAL) AppendAsync(ctx context.Context, msg message.MutableMessage, cb func(*types.AppendResult, error)) {
func (_m *MockWAL) AppendAsync(ctx context.Context, msg message.MutableMessage, cb func(*wal.AppendResult, error)) {
_m.Called(ctx, msg, cb)
}
@ -98,14 +98,14 @@ type MockWAL_AppendAsync_Call struct {
// AppendAsync is a helper method to define mock.On call
// - ctx context.Context
// - msg message.MutableMessage
// - cb func(*types.AppendResult , error)
// - cb func(*wal.AppendResult , error)
func (_e *MockWAL_Expecter) AppendAsync(ctx interface{}, msg interface{}, cb interface{}) *MockWAL_AppendAsync_Call {
return &MockWAL_AppendAsync_Call{Call: _e.mock.On("AppendAsync", ctx, msg, cb)}
}
func (_c *MockWAL_AppendAsync_Call) Run(run func(ctx context.Context, msg message.MutableMessage, cb func(*types.AppendResult, error))) *MockWAL_AppendAsync_Call {
func (_c *MockWAL_AppendAsync_Call) Run(run func(ctx context.Context, msg message.MutableMessage, cb func(*wal.AppendResult, error))) *MockWAL_AppendAsync_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(func(*types.AppendResult, error)))
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(func(*wal.AppendResult, error)))
})
return _c
}
@ -115,7 +115,7 @@ func (_c *MockWAL_AppendAsync_Call) Return() *MockWAL_AppendAsync_Call {
return _c
}
func (_c *MockWAL_AppendAsync_Call) RunAndReturn(run func(context.Context, message.MutableMessage, func(*types.AppendResult, error))) *MockWAL_AppendAsync_Call {
func (_c *MockWAL_AppendAsync_Call) RunAndReturn(run func(context.Context, message.MutableMessage, func(*wal.AppendResult, error))) *MockWAL_AppendAsync_Call {
_c.Run(run)
return _c
}

View File

@ -5,6 +5,7 @@ package mock_interceptors
import (
context "context"
interceptors "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
@ -56,7 +57,7 @@ func (_c *MockInterceptor_Close_Call) RunAndReturn(run func()) *MockInterceptor_
}
// DoAppend provides a mock function with given fields: ctx, msg, append
func (_m *MockInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error) {
func (_m *MockInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (message.MessageID, error) {
ret := _m.Called(ctx, msg, append)
if len(ret) == 0 {
@ -65,10 +66,10 @@ func (_m *MockInterceptor) DoAppend(ctx context.Context, msg message.MutableMess
var r0 message.MessageID
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, interceptors.Append) (message.MessageID, error)); ok {
return rf(ctx, msg, append)
}
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) message.MessageID); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, interceptors.Append) message.MessageID); ok {
r0 = rf(ctx, msg, append)
} else {
if ret.Get(0) != nil {
@ -76,7 +77,7 @@ func (_m *MockInterceptor) DoAppend(ctx context.Context, msg message.MutableMess
}
}
if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage, interceptors.Append) error); ok {
r1 = rf(ctx, msg, append)
} else {
r1 = ret.Error(1)
@ -93,14 +94,14 @@ type MockInterceptor_DoAppend_Call struct {
// DoAppend is a helper method to define mock.On call
// - ctx context.Context
// - msg message.MutableMessage
// - append func(context.Context , message.MutableMessage)(message.MessageID , error)
// - append interceptors.Append
func (_e *MockInterceptor_Expecter) DoAppend(ctx interface{}, msg interface{}, append interface{}) *MockInterceptor_DoAppend_Call {
return &MockInterceptor_DoAppend_Call{Call: _e.mock.On("DoAppend", ctx, msg, append)}
}
func (_c *MockInterceptor_DoAppend_Call) Run(run func(ctx context.Context, msg message.MutableMessage, append func(context.Context, message.MutableMessage) (message.MessageID, error))) *MockInterceptor_DoAppend_Call {
func (_c *MockInterceptor_DoAppend_Call) Run(run func(ctx context.Context, msg message.MutableMessage, append interceptors.Append)) *MockInterceptor_DoAppend_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(func(context.Context, message.MutableMessage) (message.MessageID, error)))
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(interceptors.Append))
})
return _c
}
@ -110,7 +111,7 @@ func (_c *MockInterceptor_DoAppend_Call) Return(_a0 message.MessageID, _a1 error
return _c
}
func (_c *MockInterceptor_DoAppend_Call) RunAndReturn(run func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error)) *MockInterceptor_DoAppend_Call {
func (_c *MockInterceptor_DoAppend_Call) RunAndReturn(run func(context.Context, message.MutableMessage, interceptors.Append) (message.MessageID, error)) *MockInterceptor_DoAppend_Call {
_c.Call.Return(run)
return _c
}

View File

@ -5,6 +5,7 @@ package mock_interceptors
import (
context "context"
interceptors "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
@ -56,7 +57,7 @@ func (_c *MockInterceptorWithMetrics_Close_Call) RunAndReturn(run func()) *MockI
}
// DoAppend provides a mock function with given fields: ctx, msg, append
func (_m *MockInterceptorWithMetrics) DoAppend(ctx context.Context, msg message.MutableMessage, append func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error) {
func (_m *MockInterceptorWithMetrics) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (message.MessageID, error) {
ret := _m.Called(ctx, msg, append)
if len(ret) == 0 {
@ -65,10 +66,10 @@ func (_m *MockInterceptorWithMetrics) DoAppend(ctx context.Context, msg message.
var r0 message.MessageID
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, interceptors.Append) (message.MessageID, error)); ok {
return rf(ctx, msg, append)
}
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) message.MessageID); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, interceptors.Append) message.MessageID); ok {
r0 = rf(ctx, msg, append)
} else {
if ret.Get(0) != nil {
@ -76,7 +77,7 @@ func (_m *MockInterceptorWithMetrics) DoAppend(ctx context.Context, msg message.
}
}
if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage, interceptors.Append) error); ok {
r1 = rf(ctx, msg, append)
} else {
r1 = ret.Error(1)
@ -93,14 +94,14 @@ type MockInterceptorWithMetrics_DoAppend_Call struct {
// DoAppend is a helper method to define mock.On call
// - ctx context.Context
// - msg message.MutableMessage
// - append func(context.Context , message.MutableMessage)(message.MessageID , error)
// - append interceptors.Append
func (_e *MockInterceptorWithMetrics_Expecter) DoAppend(ctx interface{}, msg interface{}, append interface{}) *MockInterceptorWithMetrics_DoAppend_Call {
return &MockInterceptorWithMetrics_DoAppend_Call{Call: _e.mock.On("DoAppend", ctx, msg, append)}
}
func (_c *MockInterceptorWithMetrics_DoAppend_Call) Run(run func(ctx context.Context, msg message.MutableMessage, append func(context.Context, message.MutableMessage) (message.MessageID, error))) *MockInterceptorWithMetrics_DoAppend_Call {
func (_c *MockInterceptorWithMetrics_DoAppend_Call) Run(run func(ctx context.Context, msg message.MutableMessage, append interceptors.Append)) *MockInterceptorWithMetrics_DoAppend_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(func(context.Context, message.MutableMessage) (message.MessageID, error)))
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(interceptors.Append))
})
return _c
}
@ -110,7 +111,7 @@ func (_c *MockInterceptorWithMetrics_DoAppend_Call) Return(_a0 message.MessageID
return _c
}
func (_c *MockInterceptorWithMetrics_DoAppend_Call) RunAndReturn(run func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error)) *MockInterceptorWithMetrics_DoAppend_Call {
func (_c *MockInterceptorWithMetrics_DoAppend_Call) RunAndReturn(run func(context.Context, message.MutableMessage, interceptors.Append) (message.MessageID, error)) *MockInterceptorWithMetrics_DoAppend_Call {
_c.Call.Return(run)
return _c
}

View File

@ -5,6 +5,7 @@ package mock_interceptors
import (
context "context"
interceptors "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
@ -56,7 +57,7 @@ func (_c *MockInterceptorWithReady_Close_Call) RunAndReturn(run func()) *MockInt
}
// DoAppend provides a mock function with given fields: ctx, msg, append
func (_m *MockInterceptorWithReady) DoAppend(ctx context.Context, msg message.MutableMessage, append func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error) {
func (_m *MockInterceptorWithReady) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (message.MessageID, error) {
ret := _m.Called(ctx, msg, append)
if len(ret) == 0 {
@ -65,10 +66,10 @@ func (_m *MockInterceptorWithReady) DoAppend(ctx context.Context, msg message.Mu
var r0 message.MessageID
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, interceptors.Append) (message.MessageID, error)); ok {
return rf(ctx, msg, append)
}
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) message.MessageID); ok {
if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, interceptors.Append) message.MessageID); ok {
r0 = rf(ctx, msg, append)
} else {
if ret.Get(0) != nil {
@ -76,7 +77,7 @@ func (_m *MockInterceptorWithReady) DoAppend(ctx context.Context, msg message.Mu
}
}
if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage, interceptors.Append) error); ok {
r1 = rf(ctx, msg, append)
} else {
r1 = ret.Error(1)
@ -93,14 +94,14 @@ type MockInterceptorWithReady_DoAppend_Call struct {
// DoAppend is a helper method to define mock.On call
// - ctx context.Context
// - msg message.MutableMessage
// - append func(context.Context , message.MutableMessage)(message.MessageID , error)
// - append interceptors.Append
func (_e *MockInterceptorWithReady_Expecter) DoAppend(ctx interface{}, msg interface{}, append interface{}) *MockInterceptorWithReady_DoAppend_Call {
return &MockInterceptorWithReady_DoAppend_Call{Call: _e.mock.On("DoAppend", ctx, msg, append)}
}
func (_c *MockInterceptorWithReady_DoAppend_Call) Run(run func(ctx context.Context, msg message.MutableMessage, append func(context.Context, message.MutableMessage) (message.MessageID, error))) *MockInterceptorWithReady_DoAppend_Call {
func (_c *MockInterceptorWithReady_DoAppend_Call) Run(run func(ctx context.Context, msg message.MutableMessage, append interceptors.Append)) *MockInterceptorWithReady_DoAppend_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(func(context.Context, message.MutableMessage) (message.MessageID, error)))
run(args[0].(context.Context), args[1].(message.MutableMessage), args[2].(interceptors.Append))
})
return _c
}
@ -110,7 +111,7 @@ func (_c *MockInterceptorWithReady_DoAppend_Call) Return(_a0 message.MessageID,
return _c
}
func (_c *MockInterceptorWithReady_DoAppend_Call) RunAndReturn(run func(context.Context, message.MutableMessage, func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error)) *MockInterceptorWithReady_DoAppend_Call {
func (_c *MockInterceptorWithReady_DoAppend_Call) RunAndReturn(run func(context.Context, message.MutableMessage, interceptors.Append) (message.MessageID, error)) *MockInterceptorWithReady_DoAppend_Call {
_c.Call.Return(run)
return _c
}


@ -23,7 +23,7 @@ func (_m *MockCSegment) EXPECT() *MockCSegment_Expecter {
}
// AddFieldDataInfo provides a mock function with given fields: ctx, request
func (_m *MockCSegment) AddFieldDataInfo(ctx context.Context, request *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error) {
func (_m *MockCSegment) AddFieldDataInfo(ctx context.Context, request *segcore.AddFieldDataInfoRequest) (*segcore.AddFieldDataInfoResult, error) {
ret := _m.Called(ctx, request)
if len(ret) == 0 {
@ -32,10 +32,10 @@ func (_m *MockCSegment) AddFieldDataInfo(ctx context.Context, request *segcore.L
var r0 *segcore.AddFieldDataInfoResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, *segcore.AddFieldDataInfoRequest) (*segcore.AddFieldDataInfoResult, error)); ok {
return rf(ctx, request)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) *segcore.AddFieldDataInfoResult); ok {
if rf, ok := ret.Get(0).(func(context.Context, *segcore.AddFieldDataInfoRequest) *segcore.AddFieldDataInfoResult); ok {
r0 = rf(ctx, request)
} else {
if ret.Get(0) != nil {
@ -43,7 +43,7 @@ func (_m *MockCSegment) AddFieldDataInfo(ctx context.Context, request *segcore.L
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.LoadFieldDataRequest) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, *segcore.AddFieldDataInfoRequest) error); ok {
r1 = rf(ctx, request)
} else {
r1 = ret.Error(1)
@ -59,14 +59,14 @@ type MockCSegment_AddFieldDataInfo_Call struct {
// AddFieldDataInfo is a helper method to define mock.On call
// - ctx context.Context
// - request *segcore.LoadFieldDataRequest
// - request *segcore.AddFieldDataInfoRequest
func (_e *MockCSegment_Expecter) AddFieldDataInfo(ctx interface{}, request interface{}) *MockCSegment_AddFieldDataInfo_Call {
return &MockCSegment_AddFieldDataInfo_Call{Call: _e.mock.On("AddFieldDataInfo", ctx, request)}
}
func (_c *MockCSegment_AddFieldDataInfo_Call) Run(run func(ctx context.Context, request *segcore.LoadFieldDataRequest)) *MockCSegment_AddFieldDataInfo_Call {
func (_c *MockCSegment_AddFieldDataInfo_Call) Run(run func(ctx context.Context, request *segcore.AddFieldDataInfoRequest)) *MockCSegment_AddFieldDataInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.LoadFieldDataRequest))
run(args[0].(context.Context), args[1].(*segcore.AddFieldDataInfoRequest))
})
return _c
}
@ -76,7 +76,7 @@ func (_c *MockCSegment_AddFieldDataInfo_Call) Return(_a0 *segcore.AddFieldDataIn
return _c
}
func (_c *MockCSegment_AddFieldDataInfo_Call) RunAndReturn(run func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error)) *MockCSegment_AddFieldDataInfo_Call {
func (_c *MockCSegment_AddFieldDataInfo_Call) RunAndReturn(run func(context.Context, *segcore.AddFieldDataInfoRequest) (*segcore.AddFieldDataInfoResult, error)) *MockCSegment_AddFieldDataInfo_Call {
_c.Call.Return(run)
return _c
}


@ -156,7 +156,7 @@ func (_c *MockQueryHook_InitTuningConfig_Call) RunAndReturn(run func(map[string]
}
// Run provides a mock function with given fields: _a0
func (_m *MockQueryHook) Run(_a0 map[string]interface{}) error {
func (_m *MockQueryHook) Run(_a0 map[string]any) error {
ret := _m.Called(_a0)
if len(ret) == 0 {
@ -164,7 +164,7 @@ func (_m *MockQueryHook) Run(_a0 map[string]interface{}) error {
}
var r0 error
if rf, ok := ret.Get(0).(func(map[string]interface{}) error); ok {
if rf, ok := ret.Get(0).(func(map[string]any) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
@ -179,14 +179,14 @@ type MockQueryHook_Run_Call struct {
}
// Run is a helper method to define mock.On call
// - _a0 map[string]interface{}
// - _a0 map[string]any
func (_e *MockQueryHook_Expecter) Run(_a0 interface{}) *MockQueryHook_Run_Call {
return &MockQueryHook_Run_Call{Call: _e.mock.On("Run", _a0)}
}
func (_c *MockQueryHook_Run_Call) Run(run func(_a0 map[string]interface{})) *MockQueryHook_Run_Call {
func (_c *MockQueryHook_Run_Call) Run(run func(_a0 map[string]any)) *MockQueryHook_Run_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(map[string]interface{}))
run(args[0].(map[string]any))
})
return _c
}
@ -196,7 +196,7 @@ func (_c *MockQueryHook_Run_Call) Return(_a0 error) *MockQueryHook_Run_Call {
return _c
}
func (_c *MockQueryHook_Run_Call) RunAndReturn(run func(map[string]interface{}) error) *MockQueryHook_Run_Call {
func (_c *MockQueryHook_Run_Call) RunAndReturn(run func(map[string]any) error) *MockQueryHook_Run_Call {
_c.Call.Return(run)
return _c
}


@ -11,11 +11,11 @@ import (
)
// MockService is an autogenerated mock type for the Service type
type MockService[T interface{}] struct {
type MockService[T any] struct {
mock.Mock
}
type MockService_Expecter[T interface{}] struct {
type MockService_Expecter[T any] struct {
mock *mock.Mock
}
@ -29,7 +29,7 @@ func (_m *MockService[T]) Close() {
}
// MockService_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockService_Close_Call[T interface{}] struct {
type MockService_Close_Call[T any] struct {
*mock.Call
}
@ -86,7 +86,7 @@ func (_m *MockService[T]) GetConn(ctx context.Context) (*grpc.ClientConn, error)
}
// MockService_GetConn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConn'
type MockService_GetConn_Call[T interface{}] struct {
type MockService_GetConn_Call[T any] struct {
*mock.Call
}
@ -144,7 +144,7 @@ func (_m *MockService[T]) GetService(ctx context.Context) (T, error) {
}
// MockService_GetService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetService'
type MockService_GetService_Call[T interface{}] struct {
type MockService_GetService_Call[T any] struct {
*mock.Call
}
@ -173,7 +173,7 @@ func (_c *MockService_GetService_Call[T]) RunAndReturn(run func(context.Context)
// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockService[T interface{}](t interface {
func NewMockService[T any](t interface {
mock.TestingT
Cleanup(func())
}) *MockService[T] {


@ -5,7 +5,7 @@ package mock_resolver
import (
context "context"
discoverer "github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
resolver "github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
mock "github.com/stretchr/testify/mock"
)
@ -23,22 +23,22 @@ func (_m *MockResolver) EXPECT() *MockResolver_Expecter {
}
// GetLatestState provides a mock function with given fields: ctx
func (_m *MockResolver) GetLatestState(ctx context.Context) (discoverer.VersionedState, error) {
func (_m *MockResolver) GetLatestState(ctx context.Context) (resolver.VersionedState, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetLatestState")
}
var r0 discoverer.VersionedState
var r0 resolver.VersionedState
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (discoverer.VersionedState, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context) (resolver.VersionedState, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) discoverer.VersionedState); ok {
if rf, ok := ret.Get(0).(func(context.Context) resolver.VersionedState); ok {
r0 = rf(ctx)
} else {
r0 = ret.Get(0).(discoverer.VersionedState)
r0 = ret.Get(0).(resolver.VersionedState)
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
@ -68,18 +68,18 @@ func (_c *MockResolver_GetLatestState_Call) Run(run func(ctx context.Context)) *
return _c
}
func (_c *MockResolver_GetLatestState_Call) Return(_a0 discoverer.VersionedState, _a1 error) *MockResolver_GetLatestState_Call {
func (_c *MockResolver_GetLatestState_Call) Return(_a0 resolver.VersionedState, _a1 error) *MockResolver_GetLatestState_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockResolver_GetLatestState_Call) RunAndReturn(run func(context.Context) (discoverer.VersionedState, error)) *MockResolver_GetLatestState_Call {
func (_c *MockResolver_GetLatestState_Call) RunAndReturn(run func(context.Context) (resolver.VersionedState, error)) *MockResolver_GetLatestState_Call {
_c.Call.Return(run)
return _c
}
// Watch provides a mock function with given fields: ctx, cb
func (_m *MockResolver) Watch(ctx context.Context, cb func(discoverer.VersionedState) error) error {
func (_m *MockResolver) Watch(ctx context.Context, cb func(resolver.VersionedState) error) error {
ret := _m.Called(ctx, cb)
if len(ret) == 0 {
@ -87,7 +87,7 @@ func (_m *MockResolver) Watch(ctx context.Context, cb func(discoverer.VersionedS
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, func(discoverer.VersionedState) error) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, func(resolver.VersionedState) error) error); ok {
r0 = rf(ctx, cb)
} else {
r0 = ret.Error(0)
@ -103,14 +103,14 @@ type MockResolver_Watch_Call struct {
// Watch is a helper method to define mock.On call
// - ctx context.Context
// - cb func(discoverer.VersionedState) error
// - cb func(resolver.VersionedState) error
func (_e *MockResolver_Expecter) Watch(ctx interface{}, cb interface{}) *MockResolver_Watch_Call {
return &MockResolver_Watch_Call{Call: _e.mock.On("Watch", ctx, cb)}
}
func (_c *MockResolver_Watch_Call) Run(run func(ctx context.Context, cb func(discoverer.VersionedState) error)) *MockResolver_Watch_Call {
func (_c *MockResolver_Watch_Call) Run(run func(ctx context.Context, cb func(resolver.VersionedState) error)) *MockResolver_Watch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(func(discoverer.VersionedState) error))
run(args[0].(context.Context), args[1].(func(resolver.VersionedState) error))
})
return _c
}
@ -120,7 +120,7 @@ func (_c *MockResolver_Watch_Call) Return(_a0 error) *MockResolver_Watch_Call {
return _c
}
func (_c *MockResolver_Watch_Call) RunAndReturn(run func(context.Context, func(discoverer.VersionedState) error) error) *MockResolver_Watch_Call {
func (_c *MockResolver_Watch_Call) RunAndReturn(run func(context.Context, func(resolver.VersionedState) error) error) *MockResolver_Watch_Call {
_c.Call.Return(run)
return _c
}


@ -28,18 +28,12 @@ import (
// flusherComponents is the components of the flusher.
type flusherComponents struct {
wal wal.WAL
broker broker.Broker
cpUpdater *util.ChannelCheckpointUpdater
chunkManager storage.ChunkManager
dataServices map[string]*dataSyncServiceWrapper
checkpointManager *pchannelCheckpointManager
logger *log.MLogger
}
// StartMessageID returns the start message id of the flusher after recovering.
func (impl *flusherComponents) StartMessageID() message.MessageID {
return impl.checkpointManager.StartMessageID()
wal wal.WAL
broker broker.Broker
cpUpdater *util.ChannelCheckpointUpdater
chunkManager storage.ChunkManager
dataServices map[string]*dataSyncServiceWrapper
logger *log.MLogger
}
// WhenCreateCollection handles the create collection message.
@ -109,7 +103,6 @@ func (impl *flusherComponents) WhenDropCollection(vchannel string) {
delete(impl.dataServices, vchannel)
impl.logger.Info("drop data sync service", zap.String("vchannel", vchannel))
}
impl.checkpointManager.DropVChannel(vchannel)
}
// HandleMessage handles the plain message.
@ -140,7 +133,6 @@ func (impl *flusherComponents) addNewDataSyncService(
input chan<- *msgstream.MsgPack,
ds *pipeline.DataSyncService,
) {
impl.checkpointManager.AddVChannel(createCollectionMsg.VChannel(), createCollectionMsg.LastConfirmedMessageID())
newDS := newDataSyncServiceWrapper(createCollectionMsg.VChannel(), input, ds)
newDS.Start()
impl.dataServices[createCollectionMsg.VChannel()] = newDS
@ -154,7 +146,6 @@ func (impl *flusherComponents) Close() {
impl.logger.Info("data sync service closed for flusher closing", zap.String("vchannel", vchannel))
}
impl.cpUpdater.Close()
impl.checkpointManager.Close()
}
// recover recover the components of the flusher.
@ -199,16 +190,15 @@ func (impl *flusherComponents) buildDataSyncServiceWithRetry(ctx context.Context
// Flush all the growing segment that is not created by streaming.
segmentIDs := make([]int64, 0, len(recoverInfo.GetInfo().UnflushedSegments))
for _, segment := range recoverInfo.GetInfo().UnflushedSegments {
if !segment.IsCreatedByStreaming {
segmentIDs = append(segmentIDs, segment.ID)
if segment.IsCreatedByStreaming {
continue
}
}
if len(segmentIDs) > 0 {
msg := message.NewFlushMessageBuilderV2().
WithVChannel(recoverInfo.GetInfo().GetChannelName()).
WithHeader(&message.FlushMessageHeader{
CollectionId: recoverInfo.GetInfo().GetCollectionID(),
SegmentIds: segmentIDs,
PartitionId: segment.PartitionID,
SegmentId: segment.ID,
}).
WithBody(&message.FlushMessageBody{}).MustBuildMutable()
if err := retry.Do(ctx, func() error {

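Because the +/- markers are not visible above, the reworked recovery flush is hard to follow. Read back into plain Go, the hunk now emits one flush message per unflushed segment that was not created by streaming, instead of a single message carrying a list of segment IDs. The sketch below is non-authoritative: the names are taken from the hunk, and the retry body is cut off above, so only its call shape is noted in a comment.

// Sketch only: recovery-time flush after this change (one message per segment).
for _, segment := range recoverInfo.GetInfo().UnflushedSegments {
	if segment.IsCreatedByStreaming {
		// Segments created by streaming are recovered through the WAL itself; skip them.
		continue
	}
	msg := message.NewFlushMessageBuilderV2().
		WithVChannel(recoverInfo.GetInfo().GetChannelName()).
		WithHeader(&message.FlushMessageHeader{
			CollectionId: recoverInfo.GetInfo().GetCollectionID(),
			PartitionId:  segment.PartitionID,
			SegmentId:    segment.ID,
		}).
		WithBody(&message.FlushMessageBody{}).MustBuildMutable()
	// The message is then appended under retry.Do(ctx, ..., retry.AttemptAlways());
	// the retry body lies outside the hunk and is not reproduced here.
	_ = msg
}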

@ -29,9 +29,9 @@ import (
"github.com/milvus-io/milvus/internal/flushcommon/writebuffer"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -48,39 +48,64 @@ type msgHandlerImpl struct {
wbMgr writebuffer.BufferManager
}
func (impl *msgHandlerImpl) HandleCreateSegment(ctx context.Context, vchannel string, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
body, err := createSegmentMsg.Body()
if err != nil {
return errors.Wrap(err, "failed to get create segment message body")
func (impl *msgHandlerImpl) HandleCreateSegment(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
vchannel := createSegmentMsg.VChannel()
h := createSegmentMsg.Header()
if err := impl.createNewGrowingSegment(ctx, vchannel, h); err != nil {
return err
}
for _, segmentInfo := range body.GetSegments() {
if err := impl.wbMgr.CreateNewGrowingSegment(ctx, vchannel, segmentInfo.GetPartitionId(), segmentInfo.GetSegmentId()); err != nil {
log.Warn("fail to create new growing segment",
zap.String("vchannel", vchannel),
zap.Int64("partition_id", segmentInfo.GetPartitionId()),
zap.Int64("segment_id", segmentInfo.GetSegmentId()))
return err
}
log.Info("create new growing segment",
zap.String("vchannel", vchannel),
zap.Int64("partition_id", segmentInfo.GetPartitionId()),
zap.Int64("segment_id", segmentInfo.GetSegmentId()),
zap.Int64("storage_version", segmentInfo.GetStorageVersion()))
logger := log.With(log.FieldMessage(createSegmentMsg))
if err := impl.wbMgr.CreateNewGrowingSegment(ctx, vchannel, h.PartitionId, h.SegmentId); err != nil {
logger.Warn("fail to create new growing segment")
return err
}
log.Info("create new growing segment")
return nil
}
func (impl *msgHandlerImpl) HandleFlush(vchannel string, flushMsg message.ImmutableFlushMessageV2) error {
if err := impl.wbMgr.SealSegments(context.Background(), vchannel, flushMsg.Header().SegmentIds); err != nil {
func (impl *msgHandlerImpl) createNewGrowingSegment(ctx context.Context, vchannel string, h *message.CreateSegmentMessageHeader) error {
// Transfer the pending segment into growing state.
// Alloc the growing segment at datacoord first.
mix, err := resource.Resource().MixCoordClient().GetWithContext(ctx)
if err != nil {
return err
}
logger := log.With(zap.Int64("collectionID", h.CollectionId), zap.Int64("partitionID", h.PartitionId), zap.Int64("segmentID", h.SegmentId))
return retry.Do(ctx, func() (err error) {
resp, err := mix.AllocSegment(ctx, &datapb.AllocSegmentRequest{
CollectionId: h.CollectionId,
PartitionId: h.PartitionId,
SegmentId: h.SegmentId,
Vchannel: vchannel,
StorageVersion: h.StorageVersion,
IsCreatedByStreaming: true,
})
if err := merr.CheckRPCCall(resp, err); err != nil {
logger.Warn("failed to alloc growing segment at datacoord")
return errors.Wrap(err, "failed to alloc growing segment at datacoord")
}
logger.Info("alloc growing segment at datacoord")
return nil
}, retry.AttemptAlways())
}
func (impl *msgHandlerImpl) HandleFlush(flushMsg message.ImmutableFlushMessageV2) error {
vchannel := flushMsg.VChannel()
if err := impl.wbMgr.SealSegments(context.Background(), vchannel, []int64{flushMsg.Header().SegmentId}); err != nil {
return errors.Wrap(err, "failed to seal segments")
}
return nil
}
func (impl *msgHandlerImpl) HandleManualFlush(vchannel string, flushMsg message.ImmutableManualFlushMessageV2) error {
if err := impl.wbMgr.FlushChannel(context.Background(), vchannel, flushMsg.Header().GetFlushTs()); err != nil {
return errors.Wrap(err, "failed to flush channel")
func (impl *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error {
vchannel := flushMsg.VChannel()
if err := impl.wbMgr.SealSegments(context.Background(), vchannel, flushMsg.Header().SegmentIds); err != nil {
return errors.Wrap(err, "failed to seal segments")
}
if err := impl.wbMgr.FlushChannel(context.Background(), vchannel, flushMsg.Header().FlushTs); err != nil {
return errors.Wrap(err, "failed to flush channel")
} // may be redundant.
broadcastID := flushMsg.BroadcastHeader().BroadcastID
if broadcastID == 0 {
return nil
@ -91,9 +116,11 @@ func (impl *msgHandlerImpl) HandleManualFlush(vchannel string, flushMsg message.
})
}
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, vchannel string, msg *adaptor.SchemaChangeMessageBody) error {
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error {
vchannel := msg.VChannel()
impl.wbMgr.SealSegments(context.Background(), msg.VChannel(), msg.Header().FlushedSegmentIds)
return streaming.WAL().Broadcast().Ack(ctx, types.BroadcastAckRequest{
BroadcastID: msg.BroadcastID,
BroadcastID: msg.BroadcastHeader().BroadcastID,
VChannel: vchannel,
})
}
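Taken together, the handler methods above no longer receive a vchannel argument; the vchannel is now read from the message itself via msg.VChannel(). A sketch of the resulting method set, for orientation only: the interface name here is illustrative, and only the signatures shown above are listed.

// Illustrative only: the simplified handler surface after this change.
type flushMsgHandler interface {
	HandleCreateSegment(ctx context.Context, msg message.ImmutableCreateSegmentMessageV2) error
	HandleFlush(msg message.ImmutableFlushMessageV2) error
	HandleManualFlush(msg message.ImmutableManualFlushMessageV2) error
	HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error
}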


@ -39,7 +39,7 @@ func TestFlushMsgHandler_HandleFlush(t *testing.T) {
WithVChannel(vchannel).
WithHeader(&message.FlushMessageHeader{
CollectionId: 0,
SegmentIds: []int64{1, 2, 3},
SegmentId: 1,
}).
WithBody(&message.FlushMessageBody{}).
BuildMutable()
@ -49,7 +49,7 @@ func TestFlushMsgHandler_HandleFlush(t *testing.T) {
msgID := mock_message.NewMockMessageID(t)
im, err := message.AsImmutableFlushMessageV2(msg.IntoImmutableMessage(msgID))
assert.NoError(t, err)
err = handler.HandleFlush(vchannel, im)
err = handler.HandleFlush(im)
assert.Error(t, err)
// test normal
@ -57,7 +57,7 @@ func TestFlushMsgHandler_HandleFlush(t *testing.T) {
wbMgr.EXPECT().SealSegments(mock.Anything, mock.Anything, mock.Anything).Return(nil)
handler = newMsgHandler(wbMgr)
err = handler.HandleFlush(vchannel, im)
err = handler.HandleFlush(im)
assert.NoError(t, err)
}
@ -66,6 +66,7 @@ func TestFlushMsgHandler_HandleManualFlush(t *testing.T) {
// test failed
wbMgr := writebuffer.NewMockBufferManager(t)
wbMgr.EXPECT().SealSegments(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock err"))
wbMgr.EXPECT().FlushChannel(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock err"))
msg, err := message.NewManualFlushMessageBuilderV2().
@ -82,14 +83,18 @@ func TestFlushMsgHandler_HandleManualFlush(t *testing.T) {
msgID := mock_message.NewMockMessageID(t)
im, err := message.AsImmutableManualFlushMessageV2(msg.IntoImmutableMessage(msgID))
assert.NoError(t, err)
err = handler.HandleManualFlush(vchannel, im)
err = handler.HandleManualFlush(im)
assert.Error(t, err)
wbMgr.EXPECT().SealSegments(mock.Anything, mock.Anything, mock.Anything).Unset()
wbMgr.EXPECT().SealSegments(mock.Anything, mock.Anything, mock.Anything).Return(nil)
err = handler.HandleManualFlush(im)
assert.Error(t, err)
// test normal
wbMgr = writebuffer.NewMockBufferManager(t)
wbMgr.EXPECT().FlushChannel(mock.Anything, mock.Anything, mock.Anything).Unset()
wbMgr.EXPECT().FlushChannel(mock.Anything, mock.Anything, mock.Anything).Return(nil)
handler = newMsgHandler(wbMgr)
err = handler.HandleManualFlush(vchannel, im)
err = handler.HandleManualFlush(im)
assert.NoError(t, err)
}


@ -1,168 +0,0 @@
package flusherimpl
import (
"context"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// recoverPChannelCheckpointManager recovers the pchannel checkpoint manager from the catalog
func recoverPChannelCheckpointManager(
ctx context.Context,
walName string,
pchannel string,
checkpoints map[string]message.MessageID,
) (*pchannelCheckpointManager, error) {
vchannelManager := newVChannelCheckpointManager(checkpoints)
checkpoint, err := resource.Resource().StreamingNodeCatalog().GetConsumeCheckpoint(ctx, pchannel)
if err != nil {
return nil, err
}
var startMessageID message.MessageID
var previous message.MessageID
if checkpoint != nil {
startMessageID = message.MustUnmarshalMessageID(walName, checkpoint.MessageID.Id)
previous = startMessageID
} else {
startMessageID = vchannelManager.MinimumCheckpoint()
}
u := &pchannelCheckpointManager{
notifier: syncutil.NewAsyncTaskNotifier[struct{}](),
cond: syncutil.NewContextCond(&sync.Mutex{}),
pchannel: pchannel,
vchannelManager: vchannelManager,
startMessageID: startMessageID,
logger: resource.Resource().Logger().With(zap.String("pchannel", pchannel), log.FieldComponent("checkpoint-updater")),
}
go u.background(previous)
return u, nil
}
// pchannelCheckpointManager is the struct to update the checkpoint of a pchannel
type pchannelCheckpointManager struct {
notifier *syncutil.AsyncTaskNotifier[struct{}]
cond *syncutil.ContextCond
pchannel string
vchannelManager *vchannelCheckpointManager
startMessageID message.MessageID
logger *log.MLogger
}
// StartMessageID returns the start message checkpoint of current recovery
func (m *pchannelCheckpointManager) StartMessageID() message.MessageID {
return m.startMessageID
}
// Update updates the checkpoint of a vchannel
func (m *pchannelCheckpointManager) Update(vchannel string, checkpoint message.MessageID) {
m.cond.L.Lock()
defer m.cond.L.Unlock()
oldMinimum := m.vchannelManager.MinimumCheckpoint()
err := m.vchannelManager.Update(vchannel, checkpoint)
if err != nil {
m.logger.Warn("failed to update vchannel checkpoint", zap.String("vchannel", vchannel), zap.Error(err))
return
}
if newMinimum := m.vchannelManager.MinimumCheckpoint(); oldMinimum == nil || oldMinimum.LT(newMinimum) {
// if the minimum checkpoint is updated, notify the background goroutine to update the pchannel checkpoint
m.cond.UnsafeBroadcast()
}
}
// AddVChannel adds a vchannel to the pchannel
func (m *pchannelCheckpointManager) AddVChannel(vchannel string, checkpoint message.MessageID) {
m.cond.LockAndBroadcast()
defer m.cond.L.Unlock()
if err := m.vchannelManager.Add(vchannel, checkpoint); err != nil {
m.logger.Warn("failed to add vchannel checkpoint", zap.String("vchannel", vchannel), zap.Error(err))
}
m.logger.Info("add vchannel checkpoint", zap.String("vchannel", vchannel), zap.Stringer("checkpoint", checkpoint))
}
// DropVChannel drops a vchannel from the pchannel
func (m *pchannelCheckpointManager) DropVChannel(vchannel string) {
m.cond.LockAndBroadcast()
defer m.cond.L.Unlock()
if err := m.vchannelManager.Drop(vchannel); err != nil {
m.logger.Warn("failed to drop vchannel checkpoint", zap.String("vchannel", vchannel), zap.Error(err))
return
}
m.logger.Info("drop vchannel checkpoint", zap.String("vchannel", vchannel))
}
func (m *pchannelCheckpointManager) background(previous message.MessageID) {
defer func() {
m.notifier.Finish(struct{}{})
m.logger.Info("pchannel checkpoint updater is closed")
}()
previousStr := "nil"
if previous != nil {
previousStr = previous.String()
}
m.logger.Info("pchannel checkpoint updater started", zap.String("previous", previousStr))
backoff := backoff.NewExponentialBackOff()
backoff.InitialInterval = 100 * time.Millisecond
backoff.MaxInterval = 10 * time.Second
backoff.MaxElapsedTime = 0
for {
current, err := m.blockUntilCheckpointUpdate(previous)
if err != nil {
return
}
if previous == nil || previous.LT(current) {
err := resource.Resource().StreamingNodeCatalog().SaveConsumeCheckpoint(m.notifier.Context(), m.pchannel, &streamingpb.WALCheckpoint{
MessageID: &messagespb.MessageID{Id: current.Marshal()},
})
if err != nil {
nextInterval := backoff.NextBackOff()
m.logger.Warn("failed to update pchannel checkpoint", zap.Stringer("checkpoint", current), zap.Duration("nextRetryInterval", nextInterval), zap.Error(err))
select {
case <-time.After(nextInterval):
continue
case <-m.notifier.Context().Done():
return
}
}
backoff.Reset()
previous = current
m.logger.Debug("update pchannel checkpoint", zap.Stringer("current", current))
}
}
}
// blockUntilCheckpointUpdate blocks until the checkpoint of the pchannel is updated
func (m *pchannelCheckpointManager) blockUntilCheckpointUpdate(previous message.MessageID) (message.MessageID, error) {
m.cond.L.Lock()
// block until following conditions are met:
// there is at least one vchannel, and minimum checkpoint of all vchannels is greater than previous.
// if the previous is nil, block until there is at least one vchannel.
for m.vchannelManager.Len() == 0 || (previous != nil && m.vchannelManager.MinimumCheckpoint().LTE(previous)) {
if err := m.cond.Wait(m.notifier.Context()); err != nil {
return nil, err
}
}
minimum := m.vchannelManager.MinimumCheckpoint()
m.cond.L.Unlock()
return minimum, nil
}
// Close closes the pchannel checkpoint updater
func (m *pchannelCheckpointManager) Close() {
m.notifier.Cancel()
m.notifier.BlockUntilFinish()
}


@ -1,56 +0,0 @@
package flusherimpl
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
)
func TestPChannelCheckpointManager(t *testing.T) {
snMeta := mock_metastore.NewMockStreamingNodeCataLog(t)
resource.InitForTest(t, resource.OptStreamingNodeCatalog(snMeta))
snMeta.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(&streamingpb.WALCheckpoint{
MessageID: &messagespb.MessageID{Id: rmq.NewRmqID(0).Marshal()},
}, nil)
minimumOne := atomic.NewPointer[message.MessageID](nil)
snMeta.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pchannel string, ckpt *streamingpb.WALCheckpoint) error {
id, _ := message.UnmarshalMessageID("rocksmq", ckpt.MessageID.Id)
minimumOne.Store(&id)
return nil
})
exists, vchannel, minimum := generateRandomExistsMessageID()
p, err := recoverPChannelCheckpointManager(context.Background(), "rocksmq", "test", exists)
assert.True(t, p.StartMessageID().EQ(rmq.NewRmqID(0)))
assert.NoError(t, err)
assert.NotNil(t, p)
assert.Eventually(t, func() bool {
newMinimum := minimumOne.Load()
return newMinimum != nil && (*newMinimum).EQ(minimum)
}, 10*time.Second, 10*time.Millisecond)
p.AddVChannel("vchannel-999", rmq.NewRmqID(1000000))
p.DropVChannel("vchannel-1000")
for _, vchannel := range vchannel {
p.Update(vchannel, rmq.NewRmqID(1000001))
}
assert.Eventually(t, func() bool {
newMinimum := minimumOne.Load()
return !(*newMinimum).EQ(minimum)
}, 10*time.Second, 10*time.Millisecond)
p.Close()
}


@ -47,7 +47,7 @@ func (impl *WALFlusherImpl) getVchannels(ctx context.Context, pchannel string) (
}
// getRecoveryInfos gets the recovery info of the vchannels from datacoord
func (impl *WALFlusherImpl) getRecoveryInfos(ctx context.Context, vchannel []string) (map[string]*datapb.GetChannelRecoveryInfoResponse, map[string]message.MessageID, error) {
func (impl *WALFlusherImpl) getRecoveryInfos(ctx context.Context, vchannel []string) (map[string]*datapb.GetChannelRecoveryInfoResponse, message.MessageID, error) {
futures := make([]*conc.Future[interface{}], 0, len(vchannel))
for _, v := range vchannel {
v := v
@ -69,11 +69,15 @@ func (impl *WALFlusherImpl) getRecoveryInfos(ctx context.Context, vchannel []str
}
return nil, nil, errors.Wrapf(err, "when get recovery info of vchannel %s", vchannel[i])
}
messageIDs := make(map[string]message.MessageID, len(recoveryInfos))
for v, info := range recoveryInfos {
messageIDs[v] = adaptor.MustGetMessageIDFromMQWrapperIDBytes(impl.wal.Get().WALName(), info.GetInfo().GetSeekPosition().GetMsgID())
var checkpoint message.MessageID
for _, info := range recoveryInfos {
messageID := adaptor.MustGetMessageIDFromMQWrapperIDBytes(impl.wal.Get().WALName(), info.GetInfo().GetSeekPosition().GetMsgID())
if checkpoint == nil || messageID.LT(checkpoint) {
checkpoint = messageID
}
}
return recoveryInfos, messageIDs, nil
return recoveryInfos, checkpoint, nil
}
// getRecoveryInfo gets the recovery info of the vchannel.


@ -1,138 +0,0 @@
package flusherimpl
import (
"container/heap"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
var (
errVChannelAlreadyExists = errors.New("vchannel already exists")
errVChannelNotFound = errors.New("vchannel not found")
errRollbackCheckpoint = errors.New("rollback a checkpoint is not allow")
)
// newVChannelCheckpointManager creates a new vchannelCheckpointManager
func newVChannelCheckpointManager(exists map[string]message.MessageID) *vchannelCheckpointManager {
index := make(map[string]*vchannelCheckpoint)
checkpointHeap := make(vchannelCheckpointHeap, 0, len(exists))
for vchannel, checkpoint := range exists {
index[vchannel] = &vchannelCheckpoint{
vchannel: vchannel,
checkpoint: checkpoint,
index: len(checkpointHeap),
}
checkpointHeap = append(checkpointHeap, index[vchannel])
}
heap.Init(&checkpointHeap)
return &vchannelCheckpointManager{
checkpointHeap: checkpointHeap,
index: index,
}
}
// vchannelCheckpointManager is the struct to manage the checkpoints of all vchannels at one pchannel
type vchannelCheckpointManager struct {
checkpointHeap vchannelCheckpointHeap
index map[string]*vchannelCheckpoint
}
// Add adds a vchannel with a checkpoint to the manager
func (m *vchannelCheckpointManager) Add(vchannel string, checkpoint message.MessageID) error {
if _, ok := m.index[vchannel]; ok {
return errVChannelAlreadyExists
}
vc := &vchannelCheckpoint{
vchannel: vchannel,
checkpoint: checkpoint,
}
heap.Push(&m.checkpointHeap, vc)
m.index[vchannel] = vc
return nil
}
// Drop removes a vchannel from the manager
func (m *vchannelCheckpointManager) Drop(vchannel string) error {
vc, ok := m.index[vchannel]
if !ok {
return errVChannelNotFound
}
heap.Remove(&m.checkpointHeap, vc.index)
delete(m.index, vchannel)
return nil
}
// Update updates the checkpoint of a vchannel
func (m *vchannelCheckpointManager) Update(vchannel string, checkpoint message.MessageID) error {
previous, ok := m.index[vchannel]
if !ok {
return errVChannelNotFound
}
if checkpoint.LT(previous.checkpoint) {
return errors.Wrapf(errRollbackCheckpoint, "checkpoint: %s, previous: %s", checkpoint, previous.checkpoint)
}
if checkpoint.EQ(previous.checkpoint) {
return nil
}
m.checkpointHeap.Update(previous, checkpoint)
return nil
}
// Len returns the number of vchannels
func (m *vchannelCheckpointManager) Len() int {
return len(m.checkpointHeap)
}
// MinimumCheckpoint returns the minimum checkpoint of all vchannels
func (m *vchannelCheckpointManager) MinimumCheckpoint() message.MessageID {
if len(m.checkpointHeap) == 0 {
return nil
}
return m.checkpointHeap[0].checkpoint
}
// vchannelCheckpoint is the struct to hold the checkpoint of a vchannel
type vchannelCheckpoint struct {
vchannel string
checkpoint message.MessageID
index int
}
// A vchannelCheckpointHeap implements heap.Interface and holds Items.
type vchannelCheckpointHeap []*vchannelCheckpoint
func (pq vchannelCheckpointHeap) Len() int { return len(pq) }
func (pq vchannelCheckpointHeap) Less(i, j int) bool {
return pq[i].checkpoint.LT(pq[j].checkpoint)
}
func (pq vchannelCheckpointHeap) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq *vchannelCheckpointHeap) Push(x any) {
n := len(*pq)
item := x.(*vchannelCheckpoint)
item.index = n
*pq = append(*pq, item)
}
func (pq *vchannelCheckpointHeap) Pop() any {
old := *pq
n := len(old)
item := old[n-1]
old[n-1] = nil // don't stop the GC from reclaiming the item eventually
item.index = -1 // for safety
*pq = old[0 : n-1]
return item
}
func (pq *vchannelCheckpointHeap) Update(item *vchannelCheckpoint, checkpoint message.MessageID) {
item.checkpoint = checkpoint
heap.Fix(pq, item.index)
}


@ -1,84 +0,0 @@
package flusherimpl
import (
"fmt"
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
)
func TestVChannelCheckpointManager(t *testing.T) {
exists, vchannels, minimumX := generateRandomExistsMessageID()
m := newVChannelCheckpointManager(exists)
assert.True(t, m.MinimumCheckpoint().EQ(minimumX))
err := m.Add("vchannel-999", rmq.NewRmqID(1000000))
assert.Error(t, err)
assert.True(t, m.MinimumCheckpoint().EQ(minimumX))
err = m.Drop("vchannel-1000")
assert.Error(t, err)
assert.True(t, m.MinimumCheckpoint().EQ(minimumX))
err = m.Update("vchannel-1000", rmq.NewRmqID(1000001))
assert.Error(t, err)
assert.True(t, m.MinimumCheckpoint().EQ(minimumX))
err = m.Add("vchannel-1000", rmq.NewRmqID(1000001))
assert.NoError(t, err)
assert.True(t, m.MinimumCheckpoint().EQ(minimumX))
for _, vchannel := range vchannels {
err = m.Update(vchannel, rmq.NewRmqID(1000001))
assert.NoError(t, err)
}
assert.False(t, m.MinimumCheckpoint().EQ(minimumX))
err = m.Update(vchannels[0], minimumX)
assert.Error(t, err)
err = m.Drop("vchannel-501")
assert.NoError(t, err)
lastMinimum := m.MinimumCheckpoint()
for i := 0; i < 1001; i++ {
m.Update(fmt.Sprintf("vchannel-%d", i), rmq.NewRmqID(rand.Int63n(9999999)+2))
newMinimum := m.MinimumCheckpoint()
assert.True(t, lastMinimum.LTE(newMinimum))
lastMinimum = newMinimum
}
for i := 0; i < 1001; i++ {
m.Drop(fmt.Sprintf("vchannel-%d", i))
newMinimum := m.MinimumCheckpoint()
if newMinimum != nil {
assert.True(t, lastMinimum.LTE(newMinimum))
lastMinimum = newMinimum
}
}
assert.Len(t, m.index, 0)
assert.Len(t, m.checkpointHeap, 0)
assert.Equal(t, m.Len(), 0)
assert.Nil(t, m.MinimumCheckpoint())
}
func generateRandomExistsMessageID() (map[string]message.MessageID, []string, message.MessageID) {
minimumX := int64(10000000)
var vchannel []string
exists := make(map[string]message.MessageID)
for i := 0; i < 1000; i++ {
x := rand.Int63n(999999) + 2
exists[fmt.Sprintf("vchannel-%d", i)] = rmq.NewRmqID(x)
if x < minimumX {
minimumX = x
vchannel = []string{fmt.Sprintf("vchannel-%d", i)}
} else if x == minimumX {
vchannel = append(vchannel, fmt.Sprintf("vchannel-%d", i))
}
}
vchannel = append(vchannel, "vchannel-1")
exists["vchannel-1"] = rmq.NewRmqID(minimumX)
return exists, vchannel, rmq.NewRmqID(minimumX)
}


@ -6,7 +6,6 @@ import (
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/flushcommon/broker"
"github.com/milvus-io/milvus/internal/flushcommon/util"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
@ -66,13 +65,14 @@ func (impl *WALFlusherImpl) Execute() (err error) {
}
impl.logger.Info("wal ready for flusher recovery")
impl.flusherComponents, err = impl.buildFlusherComponents(impl.notifier.Context(), l)
var checkpoint message.MessageID
impl.flusherComponents, checkpoint, err = impl.buildFlusherComponents(impl.notifier.Context(), l)
if err != nil {
return errors.Wrap(err, "when build flusher components")
}
defer impl.flusherComponents.Close()
scanner, err := impl.generateScanner(impl.notifier.Context(), impl.wal.Get())
scanner, err := impl.generateScanner(impl.notifier.Context(), impl.wal.Get(), checkpoint)
if err != nil {
return errors.Wrap(err, "when generate scanner")
}
@ -108,27 +108,27 @@ func (impl *WALFlusherImpl) Close() {
}
// buildFlusherComponents builds the components of the flusher.
func (impl *WALFlusherImpl) buildFlusherComponents(ctx context.Context, l wal.WAL) (*flusherComponents, error) {
func (impl *WALFlusherImpl) buildFlusherComponents(ctx context.Context, l wal.WAL) (*flusherComponents, message.MessageID, error) {
// Get all existed vchannels of the pchannel.
vchannels, err := impl.getVchannels(ctx, l.Channel().Name)
if err != nil {
impl.logger.Warn("get vchannels failed", zap.Error(err))
return nil, err
return nil, nil, err
}
impl.logger.Info("fetch vchannel done", zap.Int("vchannelNum", len(vchannels)))
// Get all the recovery info of the recoverable vchannels.
recoverInfos, checkpoints, err := impl.getRecoveryInfos(ctx, vchannels)
recoverInfos, checkpoint, err := impl.getRecoveryInfos(ctx, vchannels)
if err != nil {
impl.logger.Warn("get recovery info failed", zap.Error(err))
return nil, err
return nil, nil, err
}
impl.logger.Info("fetch recovery info done", zap.Int("recoveryInfoNum", len(recoverInfos)))
mixc, err := resource.Resource().MixCoordClient().GetWithContext(ctx)
if err != nil {
impl.logger.Warn("flusher recovery is canceled before data coord client ready", zap.Error(err))
return nil, err
return nil, nil, err
}
impl.logger.Info("data coord client ready")
@ -136,52 +136,42 @@ func (impl *WALFlusherImpl) buildFlusherComponents(ctx context.Context, l wal.WA
broker := broker.NewCoordBroker(mixc, paramtable.GetNodeID())
chunkManager := resource.Resource().ChunkManager()
pm, err := recoverPChannelCheckpointManager(ctx, l.WALName(), l.Channel().Name, checkpoints)
if err != nil {
impl.logger.Warn("recover pchannel checkpoint manager failure", zap.Error(err))
return nil, err
}
cpUpdater := util.NewChannelCheckpointUpdaterWithCallback(broker, func(mp *msgpb.MsgPosition) {
// After vchannel checkpoint updated, notify the pchannel checkpoint manager to work.
pm.Update(mp.ChannelName, adaptor.MustGetMessageIDFromMQWrapperIDBytes(l.WALName(), mp.MsgID))
})
cpUpdater := util.NewChannelCheckpointUpdater(broker)
go cpUpdater.Start()
fc := &flusherComponents{
wal: l,
broker: broker,
cpUpdater: cpUpdater,
chunkManager: chunkManager,
dataServices: make(map[string]*dataSyncServiceWrapper),
checkpointManager: pm,
logger: impl.logger,
wal: l,
broker: broker,
cpUpdater: cpUpdater,
chunkManager: chunkManager,
dataServices: make(map[string]*dataSyncServiceWrapper),
logger: impl.logger,
}
impl.logger.Info("flusher components intiailizing done")
if err := fc.recover(ctx, recoverInfos); err != nil {
impl.logger.Warn("flusher recovery is canceled before recovery done, recycle the resource", zap.Error(err))
fc.Close()
impl.logger.Info("flusher recycle the resource done")
return nil, err
return nil, nil, err
}
impl.logger.Info("flusher recovery done")
return fc, nil
return fc, checkpoint, nil
}
// generateScanner create a new scanner for the wal.
func (impl *WALFlusherImpl) generateScanner(ctx context.Context, l wal.WAL) (wal.Scanner, error) {
func (impl *WALFlusherImpl) generateScanner(ctx context.Context, l wal.WAL, checkpoint message.MessageID) (wal.Scanner, error) {
handler := make(adaptor.ChanMessageHandler, 64)
readOpt := wal.ReadOption{
VChannel: "", // We need consume all message from wal.
MesasgeHandler: handler,
DeliverPolicy: options.DeliverPolicyAll(),
}
if startMessageID := impl.flusherComponents.StartMessageID(); startMessageID != nil {
impl.logger.Info("wal start to scan from minimum checkpoint", zap.Stringer("startMessageID", startMessageID))
// !!! we always set the deliver policy to start from the last confirmed message id.
// because the catchup scanner at the streamingnode server must see the last confirmed message id if it's the last timetick.
readOpt.DeliverPolicy = options.DeliverPolicyStartFrom(startMessageID)
if checkpoint != nil {
impl.logger.Info("wal start to scan from minimum checkpoint", zap.Stringer("checkpointMessageID", checkpoint))
readOpt.DeliverPolicy = options.DeliverPolicyStartFrom(checkpoint)
} else {
impl.logger.Info("wal start to scan from the earliest checkpoint")
}
impl.logger.Info("wal start to scan from the beginning")
return l.Read(ctx, readOpt)
}
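For orientation, checkpoint handling in this file now collapses to a single pchannel-level message ID: getRecoveryInfos picks the minimum seek position across the vchannels, buildFlusherComponents passes it through, and generateScanner turns it into the deliver policy. A condensed, non-authoritative sketch of that flow follows; the method name is made up for illustration, and error handling plus the component wiring done by buildFlusherComponents are elided.

// Sketch only: how the recovery checkpoint reaches the scanner after this change.
func (impl *WALFlusherImpl) recoverScannerSketch(ctx context.Context, l wal.WAL) (wal.Scanner, error) {
	vchannels, err := impl.getVchannels(ctx, l.Channel().Name)
	if err != nil {
		return nil, err
	}
	// checkpoint is the minimum seek position over all recoverable vchannels, or nil if none exist.
	_, checkpoint, err := impl.getRecoveryInfos(ctx, vchannels)
	if err != nil {
		return nil, err
	}
	// A nil checkpoint falls back to DeliverPolicyAll inside generateScanner.
	return impl.generateScanner(ctx, l, checkpoint)
}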


@ -14,7 +14,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/mocks/mock_storage"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
@ -60,15 +59,17 @@ func TestWALFlusher(t *testing.T) {
},
},
}, nil)
snMeta := mock_metastore.NewMockStreamingNodeCataLog(t)
snMeta.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(nil, nil)
snMeta.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil)
mixcoord.EXPECT().AllocSegment(mock.Anything, mock.Anything).Return(&datapb.AllocSegmentResponse{
Status: merr.Status(nil),
}, nil)
mixcoord.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(&datapb.DropVirtualChannelResponse{
Status: merr.Status(nil),
}, nil)
fMixcoord := syncutil.NewFuture[internaltypes.MixCoordClient]()
fMixcoord.Set(mixcoord)
resource.InitForTest(
t,
resource.OptMixCoordClient(fMixcoord),
resource.OptStreamingNodeCatalog(snMeta),
resource.OptChunkManager(mock_storage.NewMockChunkManager(t)),
)
l := newMockWAL(t, false)


@ -223,7 +223,7 @@ func (s *scannerAdaptorImpl) handleUpstream(msg message.ImmutableMessage) {
s.metrics.ObserveTimeTickViolation(isTailing, msg.MessageType())
}
s.logger.Warn("failed to push message into reorder buffer",
zap.Object("message", msg),
log.FieldMessage(msg),
zap.Bool("tailing", isTailing),
zap.Error(err))
}


@ -15,7 +15,6 @@ import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/metricsutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
@ -239,23 +238,22 @@ func (m *partitionSegmentManager) allocNewGrowingSegment(ctx context.Context) (*
if err := merr.CheckRPCCall(resp, err); err != nil {
return nil, errors.Wrap(err, "failed to alloc growing segment at datacoord")
}
// Generate growing segment limitation.
limitation := policy.GetSegmentLimitationPolicy().GenerateLimitation()
msg, err := message.NewCreateSegmentMessageBuilderV2().
WithVChannel(pendingSegment.GetVChannel()).
WithHeader(&message.CreateSegmentMessageHeader{
CollectionId: pendingSegment.GetCollectionID(),
SegmentIds: []int64{pendingSegment.GetSegmentID()},
// We only execute one segment creation operation at a time.
// But in future, we need to modify the segment creation operation to support batch creation.
// Because the partition-key based collection may create huge amount of segments at the same time.
PartitionId: pendingSegment.GetPartitionID(),
SegmentId: pendingSegment.GetSegmentID(),
StorageVersion: pendingSegment.GetStorageVersion(),
MaxSegmentSize: limitation.SegmentSize,
}).
WithBody(&message.CreateSegmentMessageBody{
CollectionId: pendingSegment.GetCollectionID(),
Segments: []*messagespb.CreateSegmentInfo{{
// We only execute one segment creation operation at a time.
// But in future, we need to modify the segment creation operation to support batch creation.
// Because the partition-key based collection may create huge amount of segments at the same time.
PartitionId: pendingSegment.GetPartitionID(),
SegmentId: pendingSegment.GetSegmentID(),
StorageVersion: pendingSegment.GetStorageVersion(),
}},
}).BuildMutable()
WithBody(&message.CreateSegmentMessageBody{}).BuildMutable()
if err != nil {
return nil, errors.Wrapf(err, "failed to create new segment message, segmentID: %d", pendingSegment.GetSegmentID())
}
@ -265,9 +263,6 @@ func (m *partitionSegmentManager) allocNewGrowingSegment(ctx context.Context) (*
return nil, errors.Wrapf(err, "failed to send create segment message into wal, segmentID: %d", pendingSegment.GetSegmentID())
}
// Generate growing segment limitation.
limitation := policy.GetSegmentLimitationPolicy().GenerateLimitation()
// Commit it into streaming node meta.
// growing segment can be assigned now.
tx := pendingSegment.BeginModification()


@ -273,7 +273,7 @@ func (m *PChannelSegmentAllocManager) Close(ctx context.Context) {
})
// Try to seal the dirty segments to avoid generating overly large segments.
protoSegments := make([]*streamingpb.SegmentAssignmentMeta, 0, len(segments))
protoSegments := make(map[int64]*streamingpb.SegmentAssignmentMeta, len(segments))
growingCnt := 0
for _, segment := range segments {
if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING {
@ -281,7 +281,7 @@ func (m *PChannelSegmentAllocManager) Close(ctx context.Context) {
}
if segment.IsDirtyEnough() {
// Only persist the dirty segment.
protoSegments = append(protoSegments, segment.Snapshot())
protoSegments[segment.GetSegmentID()] = segment.Snapshot()
}
}
m.logger.Info("segment assignment manager save all dirty segment assignments info",


@ -285,11 +285,11 @@ func TestCreateAndDropCollection(t *testing.T) {
func newStat(insertedBinarySize uint64, maxBinarySize uint64) *streamingpb.SegmentAssignmentStat {
return &streamingpb.SegmentAssignmentStat{
MaxBinarySize: maxBinarySize,
InsertedRows: insertedBinarySize,
InsertedBinarySize: insertedBinarySize,
CreateTimestampNanoseconds: time.Now().UnixNano(),
LastModifiedTimestampNanoseconds: time.Now().UnixNano(),
MaxBinarySize: maxBinarySize,
InsertedRows: insertedBinarySize,
InsertedBinarySize: insertedBinarySize,
CreateTimestamp: time.Now().Unix(),
LastModifiedTimestamp: time.Now().Unix(),
}
}


@ -113,12 +113,13 @@ func (q *sealQueue) tryToSealSegments(ctx context.Context, segments ...*segmentA
// send flush message into wal.
for collectionID, vchannelSegments := range sealedSegments {
for vchannel, segments := range vchannelSegments {
if err := q.sendFlushSegmentsMessageIntoWAL(ctx, collectionID, vchannel, segments); err != nil {
q.logger.Warn("fail to send flush message into wal", zap.String("vchannel", vchannel), zap.Int64("collectionID", collectionID), zap.Error(err))
undone = append(undone, segments...)
continue
}
for _, segment := range segments {
if err := q.sendFlushSegmentsMessageIntoWAL(ctx, collectionID, vchannel, segment); err != nil {
q.logger.Warn("fail to send flush message into wal", zap.String("vchannel", vchannel), zap.Int64("collectionID", collectionID), zap.Error(err))
undone = append(undone, segments...)
continue
}
tx := segment.BeginModification()
tx.IntoFlushed()
if err := tx.Commit(ctx); err != nil {
@ -202,16 +203,13 @@ func (q *sealQueue) transferSegmentStateIntoSealed(ctx context.Context, segments
}
// sendFlushSegmentsMessageIntoWAL sends a flush message into wal.
func (m *sealQueue) sendFlushSegmentsMessageIntoWAL(ctx context.Context, collectionID int64, vchannel string, segments []*segmentAllocManager) error {
segmentIDs := make([]int64, 0, len(segments))
for _, segment := range segments {
segmentIDs = append(segmentIDs, segment.GetSegmentID())
}
func (m *sealQueue) sendFlushSegmentsMessageIntoWAL(ctx context.Context, collectionID int64, vchannel string, segment *segmentAllocManager) error {
msg, err := message.NewFlushMessageBuilderV2().
WithVChannel(vchannel).
WithHeader(&message.FlushMessageHeader{
CollectionId: collectionID,
SegmentIds: segmentIDs,
PartitionId: segment.GetPartitionID(),
SegmentId: segment.GetSegmentID(),
}).
WithBody(&message.FlushMessageBody{}).BuildMutable()
if err != nil {
@ -220,9 +218,9 @@ func (m *sealQueue) sendFlushSegmentsMessageIntoWAL(ctx context.Context, collect
msgID, err := m.wal.Get().Append(ctx, msg)
if err != nil {
m.logger.Warn("send flush message into wal failed", zap.Int64("collectionID", collectionID), zap.String("vchannel", vchannel), zap.Int64s("segmentIDs", segmentIDs), zap.Error(err))
m.logger.Warn("send flush message into wal failed", zap.Int64("collectionID", collectionID), zap.String("vchannel", vchannel), zap.Int64("segmentID", segment.GetSegmentID()), zap.Error(err))
return err
}
m.logger.Info("send flush message into wal", zap.Int64("collectionID", collectionID), zap.String("vchannel", vchannel), zap.Int64s("segmentIDs", segmentIDs), zap.Any("msgID", msgID))
m.logger.Info("send flush message into wal", zap.Int64("collectionID", collectionID), zap.String("vchannel", vchannel), zap.Int64("segmentID", segment.GetSegmentID()), zap.Any("msgID", msgID))
return nil
}
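As with the recovery flush above, the interleaved old and new lines in tryToSealSegments read more easily when spelled out: after this change each sealed segment gets its own flush message and its own state transition. The sketch below is a reconstruction from the hunk, not the verbatim function; the commit-failure branch lies outside the hunk and is only indicated by a comment.

// Sketch only: per-segment seal path after this change.
for collectionID, vchannelSegments := range sealedSegments {
	for vchannel, segments := range vchannelSegments {
		for _, segment := range segments {
			if err := q.sendFlushSegmentsMessageIntoWAL(ctx, collectionID, vchannel, segment); err != nil {
				q.logger.Warn("fail to send flush message into wal",
					zap.String("vchannel", vchannel), zap.Int64("collectionID", collectionID), zap.Error(err))
				undone = append(undone, segments...)
				continue
			}
			tx := segment.BeginModification()
			tx.IntoFlushed()
			if err := tx.Commit(ctx); err != nil {
				// Failure handling for the commit is unchanged and not shown in the hunk.
			}
		}
	}
}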


@ -228,8 +228,8 @@ func (s *segmentAllocManager) persistStatsIfTooDirty(ctx context.Context) {
if s.dirtyBytes < dirtyThreshold {
return
}
if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, s.pchannel.Name, []*streamingpb.SegmentAssignmentMeta{
s.Snapshot(),
if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, s.pchannel.Name, map[int64]*streamingpb.SegmentAssignmentMeta{
s.GetSegmentID(): s.Snapshot(),
}); err != nil {
log.Warn("failed to persist stats of segment", zap.Int64("segmentID", s.GetSegmentID()), zap.Error(err))
}
@ -267,10 +267,10 @@ func (m *mutableSegmentAssignmentMeta) IntoGrowing(limitation *policy.SegmentLim
m.modifiedCopy.State = streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING
now := time.Now().UnixNano()
m.modifiedCopy.Stat = &streamingpb.SegmentAssignmentStat{
MaxBinarySize: limitation.SegmentSize,
CreateTimestampNanoseconds: now,
LastModifiedTimestampNanoseconds: now,
CreateSegmentTimeTick: createSegmentTimeTick,
MaxBinarySize: limitation.SegmentSize,
CreateTimestamp: now,
LastModifiedTimestamp: now,
CreateSegmentTimeTick: createSegmentTimeTick,
}
}
@ -293,8 +293,8 @@ func (m *mutableSegmentAssignmentMeta) IntoFlushed() {
// Commit commits the modification.
func (m *mutableSegmentAssignmentMeta) Commit(ctx context.Context) error {
if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, m.original.pchannel.Name, []*streamingpb.SegmentAssignmentMeta{
m.modifiedCopy,
if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, m.original.pchannel.Name, map[int64]*streamingpb.SegmentAssignmentMeta{
m.modifiedCopy.SegmentId: m.modifiedCopy,
}); err != nil {
return err
}


@ -29,9 +29,9 @@ func NewSegmentStatFromProto(statProto *streamingpb.SegmentAssignmentStat) *Segm
BinarySize: statProto.InsertedBinarySize,
},
MaxBinarySize: statProto.MaxBinarySize,
CreateTime: time.Unix(0, statProto.CreateTimestampNanoseconds),
CreateTime: time.Unix(statProto.CreateTimestamp, 0),
BinLogCounter: statProto.BinlogCounter,
LastModifiedTime: time.Unix(0, statProto.LastModifiedTimestampNanoseconds),
LastModifiedTime: time.Unix(statProto.LastModifiedTimestamp, 0),
}
}
@ -41,12 +41,12 @@ func NewProtoFromSegmentStat(stat *SegmentStats) *streamingpb.SegmentAssignmentS
return nil
}
return &streamingpb.SegmentAssignmentStat{
MaxBinarySize: stat.MaxBinarySize,
InsertedRows: stat.Insert.Rows,
InsertedBinarySize: stat.Insert.BinarySize,
CreateTimestampNanoseconds: stat.CreateTime.UnixNano(),
BinlogCounter: stat.BinLogCounter,
LastModifiedTimestampNanoseconds: stat.LastModifiedTime.UnixNano(),
MaxBinarySize: stat.MaxBinarySize,
InsertedRows: stat.Insert.Rows,
InsertedBinarySize: stat.Insert.BinarySize,
CreateTimestamp: stat.CreateTime.Unix(),
BinlogCounter: stat.BinLogCounter,
LastModifiedTimestamp: stat.LastModifiedTime.Unix(),
}
}


@ -23,16 +23,16 @@ func TestStatsConvention(t *testing.T) {
assert.Equal(t, stat.MaxBinarySize, pb.MaxBinarySize)
assert.Equal(t, stat.Insert.Rows, pb.InsertedRows)
assert.Equal(t, stat.Insert.BinarySize, pb.InsertedBinarySize)
assert.Equal(t, stat.CreateTime.UnixNano(), pb.CreateTimestampNanoseconds)
assert.Equal(t, stat.LastModifiedTime.UnixNano(), pb.LastModifiedTimestampNanoseconds)
assert.Equal(t, stat.CreateTime.Unix(), pb.CreateTimestamp)
assert.Equal(t, stat.LastModifiedTime.Unix(), pb.LastModifiedTimestamp)
assert.Equal(t, stat.BinLogCounter, pb.BinlogCounter)
stat2 := NewSegmentStatFromProto(pb)
assert.Equal(t, stat.MaxBinarySize, stat2.MaxBinarySize)
assert.Equal(t, stat.Insert.Rows, stat2.Insert.Rows)
assert.Equal(t, stat.Insert.BinarySize, stat2.Insert.BinarySize)
assert.Equal(t, stat.CreateTime.UnixNano(), stat2.CreateTime.UnixNano())
assert.Equal(t, stat.LastModifiedTime.UnixNano(), stat2.LastModifiedTime.UnixNano())
assert.Equal(t, stat.CreateTime.Unix(), stat2.CreateTime.Unix())
assert.Equal(t, stat.LastModifiedTime.Unix(), stat2.LastModifiedTime.Unix())
assert.Equal(t, stat.BinLogCounter, stat2.BinLogCounter)
}


@ -6,6 +6,7 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
@ -64,7 +65,7 @@ func (m *AppendMetrics) StartAppendGuard() *AppendMetricsGuard {
// IntoLogFields convert the metrics to log fields.
func (m *AppendMetrics) IntoLogFields() []zap.Field {
fields := []zap.Field{
zap.Object("message", m.msg),
log.FieldMessage(m.msg),
zap.Duration("append_duration", m.appendDuration),
zap.Duration("impl_append_duration", m.implAppendDuration),
}


@ -33,7 +33,7 @@ func NewWriteMetrics(pchannel types.PChannelInfo, walName string) *WriteMetrics
slowLogThreshold = time.Second
}
if walName == wp.WALName && slowLogThreshold < 3*time.Second {
// slow log threshold is not set in woodpecker, so we set it to 0.
// woodpecker wal is always slow, so we need to set a higher threshold by default.
slowLogThreshold = 3 * time.Second
}
return &WriteMetrics{


@ -0,0 +1,49 @@
package recovery
import (
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
const (
recoveryMagicStreamingInitialized int64 = 1 // the vchannel info is set into the catalog.
// the checkpoint is set into the catalog.
)
// newWALCheckpointFromProto creates a new WALCheckpoint from a protobuf message.
func newWALCheckpointFromProto(walName string, cp *streamingpb.WALCheckpoint) *WALCheckpoint {
return &WALCheckpoint{
MessageID: message.MustUnmarshalMessageID(walName, cp.MessageId.Id),
TimeTick: cp.TimeTick,
Magic: cp.RecoveryMagic,
}
}
// WALCheckpoint represents a consume checkpoint in the Write-Ahead Log (WAL).
type WALCheckpoint struct {
MessageID message.MessageID
TimeTick uint64
Magic int64
}
// IntoProto converts the WALCheckpoint to a protobuf message.
func (c *WALCheckpoint) IntoProto() *streamingpb.WALCheckpoint {
cp := &streamingpb.WALCheckpoint{
MessageId: &messagespb.MessageID{
Id: c.MessageID.Marshal(),
},
TimeTick: c.TimeTick,
RecoveryMagic: c.Magic,
}
return cp
}
// Clone creates a new WALCheckpoint with the same values as the original.
func (c *WALCheckpoint) Clone() *WALCheckpoint {
return &WALCheckpoint{
MessageID: c.MessageID,
TimeTick: c.TimeTick,
Magic: c.Magic,
}
}

View File

@ -0,0 +1,39 @@
package recovery
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
)
func TestNewWALCheckpointFromProto(t *testing.T) {
walName := "rocksmq"
messageID := rmq.NewRmqID(1)
timeTick := uint64(12345)
recoveryMagic := int64(1)
protoCheckpoint := &streamingpb.WALCheckpoint{
MessageId: &messagespb.MessageID{Id: messageID.Marshal()},
TimeTick: timeTick,
RecoveryMagic: recoveryMagic,
}
checkpoint := newWALCheckpointFromProto(walName, protoCheckpoint)
assert.True(t, messageID.EQ(checkpoint.MessageID))
assert.Equal(t, timeTick, checkpoint.TimeTick)
assert.Equal(t, recoveryMagic, checkpoint.Magic)
proto := checkpoint.IntoProto()
checkpoint2 := newWALCheckpointFromProto(walName, proto)
assert.True(t, messageID.EQ(checkpoint2.MessageID))
assert.Equal(t, timeTick, checkpoint2.TimeTick)
assert.Equal(t, recoveryMagic, checkpoint2.Magic)
checkpoint3 := checkpoint.Clone()
assert.True(t, messageID.EQ(checkpoint3.MessageID))
assert.Equal(t, timeTick, checkpoint3.TimeTick)
assert.Equal(t, recoveryMagic, checkpoint3.Magic)
}

View File

@ -0,0 +1,46 @@
package recovery
import (
"time"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// newConfig creates a new config for the recovery module.
func newConfig() *config {
params := paramtable.Get()
persistInterval := params.StreamingCfg.WALRecoveryPersistInterval.GetAsDurationByParse()
maxDirtyMessages := params.StreamingCfg.WALRecoveryMaxDirtyMessage.GetAsInt()
gracefulTimeout := params.StreamingCfg.WALRecoveryGracefulCloseTimeout.GetAsDurationByParse()
cfg := &config{
persistInterval: persistInterval,
maxDirtyMessages: maxDirtyMessages,
gracefulTimeout: gracefulTimeout,
}
if err := cfg.validate(); err != nil {
panic(err)
}
return cfg
}
// config is the configuration for the recovery module.
type config struct {
persistInterval time.Duration // persistInterval is the interval to persist the dirty recovery snapshot.
maxDirtyMessages int // maxDirtyMessages is the maximum number of dirty messages to be persisted.
gracefulTimeout time.Duration // gracefulTimeout is the timeout for graceful close of recovery module.
}
func (cfg *config) validate() error {
if cfg.persistInterval <= 0 {
return errors.New("persist interval must be greater than 0")
}
if cfg.maxDirtyMessages <= 0 {
return errors.New("max dirty messages must be greater than 0")
}
if cfg.gracefulTimeout <= 0 {
return errors.New("graceful timeout must be greater than 0")
}
return nil
}

View File

@ -0,0 +1,51 @@
package recovery
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
func TestNewConfig(t *testing.T) {
// Mock paramtable values
paramtable.Init()
cfg := newConfig()
assert.Equal(t, 10*time.Second, cfg.persistInterval)
assert.Equal(t, 100, cfg.maxDirtyMessages)
assert.Equal(t, 3*time.Second, cfg.gracefulTimeout)
}
func TestConfigValidate(t *testing.T) {
tests := []struct {
name string
persistInterval time.Duration
maxDirtyMessages int
gracefulTimeout time.Duration
expectError bool
}{
{"ValidConfig", 10 * time.Second, 100, 5 * time.Second, false},
{"InvalidPersistInterval", 0, 100, 5 * time.Second, true},
{"InvalidMaxDirtyMessages", 10 * time.Second, 0, 5 * time.Second, true},
{"InvalidGracefulTimeout", 10 * time.Second, 100, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config{
persistInterval: tt.persistInterval,
maxDirtyMessages: tt.maxDirtyMessages,
gracefulTimeout: tt.gracefulTimeout,
}
err := cfg.validate()
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -0,0 +1,133 @@
package recovery
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
)
// TODO: !!! All recovery persist operations should be compare-and-swap operations to
// guarantee that there is only one consumer of the wal.
// However, the CAS operation of the meta interface is not implemented yet.
// Should be fixed in the future.
func (rs *RecoveryStorage) backgroundTask() {
ticker := time.NewTicker(rs.cfg.persistInterval)
defer func() {
rs.Logger().Info("recovery storage background task on exit...")
ticker.Stop()
rs.persistDirtySnapshotWhenClosing()
rs.backgroundTaskNotifier.Finish(struct{}{})
rs.Logger().Info("recovery storage background task exit")
}()
for {
select {
case <-rs.backgroundTaskNotifier.Context().Done():
return // exit the background task
case <-rs.persistNotifier:
case <-ticker.C:
}
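// A persist round is triggered either by the periodic ticker or by an explicit
// notification from observeMessage once too many dirty messages have accumulated.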
snapshot := rs.consumeDirtySnapshot()
if err := rs.persistDirtySnapshot(rs.backgroundTaskNotifier.Context(), snapshot, zap.DebugLevel); err != nil {
return
}
}
}
// persistDirtySnapshotWhenClosing persists the dirty snapshot when closing the recovery storage.
func (rs *RecoveryStorage) persistDirtySnapshotWhenClosing() {
ctx, cancel := context.WithTimeout(context.Background(), rs.cfg.gracefulTimeout)
defer cancel()
snapshot := rs.consumeDirtySnapshot()
_ = rs.persistDirtySnapshot(ctx, snapshot, zap.InfoLevel)
}
// persistDirtySnapshot persists the dirty snapshot to the catalog.
func (rs *RecoveryStorage) persistDirtySnapshot(ctx context.Context, snapshot *RecoverySnapshot, lvl zapcore.Level) (err error) {
logger := rs.Logger().With(
zap.String("checkpoint", snapshot.Checkpoint.MessageID.String()),
zap.Uint64("checkpointTimeTick", snapshot.Checkpoint.TimeTick),
zap.Int("vchannelCount", len(snapshot.VChannels)),
zap.Int("segmentCount", len(snapshot.SegmentAssignments)),
)
defer func() {
if err != nil {
logger.Warn("failed to persist dirty snapshot", zap.Error(err))
return
}
logger.Log(lvl, "persist dirty snapshot")
}()
futures := make([]*conc.Future[struct{}], 0, 2)
if len(snapshot.SegmentAssignments) > 0 {
future := conc.Go(func() (struct{}, error) {
err := rs.retryOperationWithBackoff(ctx,
logger.With(zap.String("op", "persistSegmentAssignments"), zap.Int64s("segmentIds", lo.Keys(snapshot.SegmentAssignments))),
func(ctx context.Context) error {
return resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, rs.channel.Name, snapshot.SegmentAssignments)
})
return struct{}{}, err
})
futures = append(futures, future)
}
if len(snapshot.VChannels) > 0 {
future := conc.Go(func() (struct{}, error) {
err := rs.retryOperationWithBackoff(ctx,
logger.With(zap.String("op", "persistVChannels"), zap.Strings("vchannels", lo.Keys(snapshot.VChannels))),
func(ctx context.Context) error {
return resource.Resource().StreamingNodeCatalog().SaveVChannels(ctx, rs.channel.Name, snapshot.VChannels)
})
return struct{}{}, err
})
futures = append(futures, future)
}
if err := conc.BlockOnAll(futures...); err != nil {
return err
}
// checkpoint updates should always be persisted after the other updates succeed.
return rs.retryOperationWithBackoff(ctx, rs.Logger().With(zap.String("op", "persistCheckpoint")), func(ctx context.Context) error {
return resource.Resource().StreamingNodeCatalog().
SaveConsumeCheckpoint(ctx, rs.channel.Name, snapshot.Checkpoint.IntoProto())
})
}
// retryOperationWithBackoff retries the operation with exponential backoff.
func (rs *RecoveryStorage) retryOperationWithBackoff(ctx context.Context, logger *log.MLogger, op func(ctx context.Context) error) error {
backoff := rs.newBackoff()
for {
err := op(ctx)
if err == nil {
return nil
}
if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) {
return err
}
nextInterval := backoff.NextBackOff()
logger.Warn("failed to persist operation, wait for retry...", zap.Duration("nextRetryInterval", nextInterval), zap.Error(err))
select {
case <-time.After(nextInterval):
case <-ctx.Done():
return ctx.Err()
}
}
}
// newBackoff creates a new backoff instance with the default settings.
func (rs *RecoveryStorage) newBackoff() *backoff.ExponentialBackOff {
backoff := backoff.NewExponentialBackOff()
backoff.InitialInterval = 10 * time.Millisecond
backoff.MaxInterval = 1 * time.Second
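// MaxElapsedTime = 0 disables the overall retry deadline; retries continue until the operation succeeds or the context is canceled.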
backoff.MaxElapsedTime = 0
return backoff
}

View File

@ -0,0 +1,133 @@
package recovery
import (
"context"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
// recoverRecoveryInfoFromMeta retrieves the recovery info for the given channel.
func (r *RecoveryStorage) recoverRecoveryInfoFromMeta(ctx context.Context, walName string, channelInfo types.PChannelInfo, lastTimeTickMessage message.ImmutableMessage) error {
r.SetLogger(resource.Resource().Logger().With(
log.FieldComponent(componentRecoveryStorage),
zap.String("channel", channelInfo.String()),
zap.String("state", recoveryStorageStatePersistRecovering),
))
catalog := resource.Resource().StreamingNodeCatalog()
cpProto, err := catalog.GetConsumeCheckpoint(ctx, channelInfo.Name)
if err != nil {
return errors.Wrap(err, "failed to get checkpoint from catalog")
}
if cpProto == nil {
// There's no checkpoint for the current pchannel, so we need to initialize the recovery info.
if cpProto, err = r.initializeRecoverInfo(ctx, channelInfo, lastTimeTickMessage); err != nil {
return errors.Wrap(err, "failed to initialize checkpoint")
}
}
r.checkpoint = newWALCheckpointFromProto(walName, cpProto)
r.Logger().Info("recover checkpoint done",
zap.String("checkpoint", r.checkpoint.MessageID.String()),
zap.Uint64("timetick", r.checkpoint.TimeTick),
zap.Int64("magic", r.checkpoint.Magic),
)
fVChannel := conc.Go(func() (struct{}, error) {
var err error
vchannels, err := catalog.ListVChannel(ctx, channelInfo.Name)
if err != nil {
return struct{}{}, errors.Wrap(err, "failed to get vchannel from catalog")
}
r.vchannels = newVChannelRecoveryInfoFromVChannelMeta(vchannels)
r.Logger().Info("recovery vchannel info done", zap.Int("vchannels", len(r.vchannels)))
return struct{}{}, nil
})
fSegment := conc.Go(func() (struct{}, error) {
var err error
segmentAssign, err := catalog.ListSegmentAssignment(ctx, channelInfo.Name)
if err != nil {
return struct{}{}, errors.Wrap(err, "failed to get segment assignment from catalog")
}
r.segments = newSegmentRecoveryInfoFromSegmentAssignmentMeta(segmentAssign)
r.Logger().Info("recover segment info done", zap.Int("segments", len(r.segments)))
return struct{}{}, nil
})
return conc.BlockOnAll(fVChannel, fSegment)
}
// initializeRecoverInfo initializes the recover info for the given channel.
// Before the streaming service is enabled for the first time, there's no recovery info for the channel,
// so we initialize it here.
// !!! This function is called only once for each channel when the streaming service is enabled.
func (r *RecoveryStorage) initializeRecoverInfo(ctx context.Context, channelInfo types.PChannelInfo, untilMessage message.ImmutableMessage) (*streamingpb.WALCheckpoint, error) {
// Messages that are not generated by the streaming service are not managed by the recovery storage at the streamingnode.
// So we ignore them and just use the global milvus meta info to initialize the recovery storage.
// !!! This is not a strong guarantee of consistency between the old and new architectures.
r.Logger().Info("checkpoint not found in catalog, may be upgrading from old arch, initializing it...", log.FieldMessage(untilMessage))
coord, err := resource.Resource().MixCoordClient().GetWithContext(ctx)
if err != nil {
return nil, errors.Wrap(err, "when wait for rootcoord client ready")
}
resp, err := coord.GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{
Pchannel: channelInfo.Name,
})
if err = merr.CheckRPCCall(resp, err); err != nil {
return nil, errors.Wrap(err, "failed to get pchannel info from rootcoord")
}
// save the vchannel recovery info into the catalog
vchannels := make(map[string]*streamingpb.VChannelMeta, len(resp.GetCollections()))
for _, collection := range resp.GetCollections() {
partitions := make([]*streamingpb.PartitionInfoOfVChannel, 0, len(collection.Partitions))
for _, partition := range collection.Partitions {
partitions = append(partitions, &streamingpb.PartitionInfoOfVChannel{PartitionId: partition.PartitionId})
}
vchannels[collection.Vchannel] = &streamingpb.VChannelMeta{
Vchannel: collection.Vchannel,
State: streamingpb.VChannelState_VCHANNEL_STATE_NORMAL,
CollectionInfo: &streamingpb.CollectionInfoOfVChannel{
CollectionId: collection.CollectionId,
Partitions: partitions,
},
}
}
// SaveVChannels saves the vchannels into the catalog.
if err := resource.Resource().StreamingNodeCatalog().SaveVChannels(ctx, channelInfo.Name, vchannels); err != nil {
return nil, errors.Wrap(err, "failed to save vchannels to catalog")
}
// Use the first timesync message as the initial checkpoint.
checkpoint := &streamingpb.WALCheckpoint{
MessageId: &messagespb.MessageID{
Id: untilMessage.LastConfirmedMessageID().Marshal(),
},
TimeTick: untilMessage.TimeTick(),
RecoveryMagic: recoveryMagicStreamingInitialized,
}
if err := resource.Resource().StreamingNodeCatalog().SaveConsumeCheckpoint(ctx, channelInfo.Name, checkpoint); err != nil {
return nil, errors.Wrap(err, "failed to save checkpoint to catalog")
}
r.Logger().Info("initialize checkpoint done",
zap.Int("vchannels", len(vchannels)),
zap.String("checkpoint", checkpoint.MessageId.String()),
zap.Uint64("timetick", checkpoint.TimeTick),
zap.Int64("magic", checkpoint.RecoveryMagic),
)
return checkpoint, nil
}

View File

@ -0,0 +1,119 @@
package recovery
import (
"context"
"os"
"testing"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
internaltypes "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
func TestMain(m *testing.M) {
// Initialize the paramtable package
paramtable.Init()
// Run the tests and exit with their status code.
os.Exit(m.Run())
}
func TestInitRecoveryInfoFromMeta(t *testing.T) {
snCatalog := mock_metastore.NewMockStreamingNodeCataLog(t)
snCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return([]*streamingpb.SegmentAssignmentMeta{}, nil)
snCatalog.EXPECT().ListVChannel(mock.Anything, mock.Anything).Return([]*streamingpb.VChannelMeta{}, nil)
snCatalog.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(
&streamingpb.WALCheckpoint{
MessageId: &messagespb.MessageID{
Id: rmq.NewRmqID(1).Marshal(),
},
TimeTick: 1,
RecoveryMagic: recoveryMagicStreamingInitialized,
}, nil)
resource.InitForTest(t, resource.OptStreamingNodeCatalog(snCatalog))
walName := "rocksmq"
channel := types.PChannelInfo{Name: "test_channel"}
lastConfirmed := message.CreateTestTimeTickSyncMessage(t, 1, 1, rmq.NewRmqID(1))
rs := newRecoveryStorage(channel)
err := rs.recoverRecoveryInfoFromMeta(context.Background(), walName, channel, lastConfirmed.IntoImmutableMessage(rmq.NewRmqID(1)))
assert.NoError(t, err)
assert.NotNil(t, rs.checkpoint)
assert.Equal(t, recoveryMagicStreamingInitialized, rs.checkpoint.Magic)
assert.True(t, rs.checkpoint.MessageID.EQ(rmq.NewRmqID(1)))
}
func TestInitRecoveryInfoFromCoord(t *testing.T) {
var initialedVChannels map[string]*streamingpb.VChannelMeta
snCatalog := mock_metastore.NewMockStreamingNodeCataLog(t)
snCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.SegmentAssignmentMeta, error) {
return []*streamingpb.SegmentAssignmentMeta{}, nil
})
snCatalog.EXPECT().ListVChannel(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.VChannelMeta, error) {
return lo.Values(initialedVChannels), nil
})
snCatalog.EXPECT().SaveVChannels(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[string]*streamingpb.VChannelMeta) error {
initialedVChannels = m
return nil
})
snCatalog.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(nil, nil)
snCatalog.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil)
fc := syncutil.NewFuture[internaltypes.MixCoordClient]()
c := mocks.NewMockMixCoordClient(t)
c.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything).Return(&rootcoordpb.GetPChannelInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Collections: []*rootcoordpb.CollectionInfoOnPChannel{
{
CollectionId: 1,
Partitions: []*rootcoordpb.PartitionInfoOnPChannel{
{PartitionId: 1},
{PartitionId: 2},
},
Vchannel: "v1",
},
{
CollectionId: 2,
Partitions: []*rootcoordpb.PartitionInfoOnPChannel{
{PartitionId: 3},
{PartitionId: 4},
},
Vchannel: "v2",
},
},
}, nil)
fc.Set(c)
resource.InitForTest(t, resource.OptStreamingNodeCatalog(snCatalog), resource.OptMixCoordClient(fc))
walName := "rocksmq"
channel := types.PChannelInfo{Name: "test_channel"}
lastConfirmed := message.CreateTestTimeTickSyncMessage(t, 1, 1, rmq.NewRmqID(1))
rs := newRecoveryStorage(channel)
err := rs.recoverRecoveryInfoFromMeta(context.Background(), walName, channel, lastConfirmed.IntoImmutableMessage(rmq.NewRmqID(1)))
assert.NoError(t, err)
assert.NotNil(t, rs.checkpoint)
assert.Len(t, rs.vchannels, 2)
assert.Len(t, initialedVChannels, 2)
}

View File

@ -0,0 +1,53 @@
package recovery
import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
// RecoverySnapshot is the snapshot of the recovery info.
type RecoverySnapshot struct {
VChannels map[string]*streamingpb.VChannelMeta
SegmentAssignments map[int64]*streamingpb.SegmentAssignmentMeta
Checkpoint *WALCheckpoint
TxnBuffer *utility.TxnBuffer
}
type BuildRecoveryStreamParam struct {
StartCheckpoint message.MessageID
EndTimeTick uint64
}
// RecoveryStreamBuilder is an interface that is used to build a recovery stream from the WAL.
type RecoveryStreamBuilder interface {
// WALName returns the name of the WAL.
WALName() string
// Channel returns the channel info of wal.
Channel() types.PChannelInfo
// Build builds a recovery stream with the given parameters.
// The recovery stream will return the messages from the start checkpoint to the end time tick.
Build(param BuildRecoveryStreamParam) RecoveryStream
}
// RecoveryStream is an interface that is used to recover the recovery storage from the WAL.
type RecoveryStream interface {
// Chan returns the channel of the recovery stream.
// The channel is closed when the recovery stream is done.
Chan() <-chan message.ImmutableMessage
// Error should be called after the stream `Chan()` is fully consumed.
// It returns an error if the stream did not complete.
// If the stream is fully consumed, it returns nil.
Error() error
// TxnBuffer returns the uncommitted txn buffer after the recovery stream is done.
// It can only be called after the stream is drained and Error() returns nil.
TxnBuffer() *utility.TxnBuffer
// Close closes the recovery stream.
Close() error
}
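// The sketch below is a hypothetical usage example (not part of this change; the function and its
// handle callback are illustrative names): a caller drains Chan() until it is closed, checks Error(),
// and only then reads TxnBuffer().
func consumeRecoveryStreamExample(rs RecoveryStream, handle func(message.ImmutableMessage)) (*utility.TxnBuffer, error) {
// Always release the stream, even on the error path.
defer rs.Close()
for msg := range rs.Chan() {
handle(msg)
}
// The stream channel is closed; Error reports whether the stream completed successfully.
if err := rs.Error(); err != nil {
return nil, err
}
// The uncommitted txn buffer may only be read after a clean drain.
return rs.TxnBuffer(), nil
}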

View File

@ -0,0 +1,376 @@
package recovery
import (
"context"
"sync"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
const (
componentRecoveryStorage = "recovery-storage"
recoveryStorageStatePersistRecovering = "persist-recovering"
recoveryStorageStateStreamRecovering = "stream-recovering"
recoveryStorageStateWorking = "working"
)
// RecoverRecoveryStorage creates a new recovery storage.
func RecoverRecoveryStorage(
ctx context.Context,
recoveryStreamBuilder RecoveryStreamBuilder,
lastTimeTickMessage message.ImmutableMessage,
) (*RecoveryStorage, *RecoverySnapshot, error) {
rs := newRecoveryStorage(recoveryStreamBuilder.Channel())
if err := rs.recoverRecoveryInfoFromMeta(ctx, recoveryStreamBuilder.WALName(), recoveryStreamBuilder.Channel(), lastTimeTickMessage); err != nil {
rs.Logger().Warn("recovery storage failed", zap.Error(err))
return nil, nil, err
}
// recover the state from wal and start the background task to persist the state.
snapshot, err := rs.recoverFromStream(ctx, recoveryStreamBuilder, lastTimeTickMessage)
if err != nil {
rs.Logger().Warn("recovery storage failed", zap.Error(err))
return nil, nil, err
}
// recovery storage starts working.
rs.SetLogger(resource.Resource().Logger().With(
log.FieldComponent(componentRecoveryStorage),
zap.String("channel", recoveryStreamBuilder.Channel().String()),
zap.String("state", recoveryStorageStateWorking)))
go rs.backgroundTask()
return rs, snapshot, nil
}
// newRecoveryStorage creates a new recovery storage.
func newRecoveryStorage(channel types.PChannelInfo) *RecoveryStorage {
cfg := newConfig()
return &RecoveryStorage{
backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
cfg: cfg,
mu: sync.Mutex{},
channel: channel,
dirtyCounter: 0,
persistNotifier: make(chan struct{}, 1),
}
}
// RecoveryStorage is a component that manages the recovery info for the streaming service.
// It consumes the messages from the wal and updates the checkpoint for them.
type RecoveryStorage struct {
log.Binder
backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}]
cfg *config
mu sync.Mutex
channel types.PChannelInfo
segments map[int64]*segmentRecoveryInfo
vchannels map[string]*vchannelRecoveryInfo
checkpoint *WALCheckpoint
dirtyCounter int // records the message count since last persist snapshot.
// used to trigger the recovery persist operation.
persistNotifier chan struct{}
}
// ObserveMessage is called when a new message is observed.
func (r *RecoveryStorage) ObserveMessage(msg message.ImmutableMessage) {
r.mu.Lock()
defer r.mu.Unlock()
r.observeMessage(msg)
}
// Close closes the recovery storage and wait the background task stop.
func (r *RecoveryStorage) Close() {
r.backgroundTaskNotifier.Cancel()
r.backgroundTaskNotifier.BlockUntilFinish()
}
// notifyPersist notifies a persist operation.
func (r *RecoveryStorage) notifyPersist() {
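// Non-blocking send: if a persist signal is already pending, dropping the extra one is fine
// because the background task consumes all dirty state in a single round.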
select {
case r.persistNotifier <- struct{}{}:
default:
}
}
// consumeDirtySnapshot consumes the dirty state and returns a snapshot to persist.
// A snapshot is always a consistent state of the recovery storage (a message or a txn message is fully consumed).
func (r *RecoveryStorage) consumeDirtySnapshot() *RecoverySnapshot {
r.mu.Lock()
defer r.mu.Unlock()
segments := make(map[int64]*streamingpb.SegmentAssignmentMeta)
vchannels := make(map[string]*streamingpb.VChannelMeta)
for _, segment := range r.segments {
dirtySnapshot, shouldBeRemoved := segment.ConsumeDirtyAndGetSnapshot()
if shouldBeRemoved {
delete(r.segments, segment.meta.SegmentId)
}
if dirtySnapshot != nil {
segments[segment.meta.SegmentId] = dirtySnapshot
}
}
for _, vchannel := range r.vchannels {
dirtySnapshot, shouldBeRemoved := vchannel.ConsumeDirtyAndGetSnapshot()
if shouldBeRemoved {
delete(r.vchannels, vchannel.meta.Vchannel)
}
if dirtySnapshot != nil {
vchannels[vchannel.meta.Vchannel] = dirtySnapshot
}
}
// clear the dirty counter.
r.dirtyCounter = 0
return &RecoverySnapshot{
VChannels: vchannels,
SegmentAssignments: segments,
Checkpoint: r.checkpoint.Clone(),
}
}
// observeMessage observes a message and update the recovery storage.
func (r *RecoveryStorage) observeMessage(msg message.ImmutableMessage) {
if msg.TimeTick() <= r.checkpoint.TimeTick {
if r.Logger().Level().Enabled(zap.DebugLevel) {
r.Logger().Debug("skip the message before the checkpoint",
log.FieldMessage(msg),
zap.Uint64("checkpoint", r.checkpoint.TimeTick),
zap.Uint64("incoming", msg.TimeTick()),
)
}
return
}
r.handleMessage(msg)
checkpointUpdates := !r.checkpoint.MessageID.EQ(msg.LastConfirmedMessageID())
r.checkpoint.TimeTick = msg.TimeTick()
r.checkpoint.MessageID = msg.LastConfirmedMessageID()
if checkpointUpdates {
// only count the message as dirty if the last confirmed message id is updated.
// we always recover from that point; the write-ahead time tick is just redundant information.
r.dirtyCounter++
}
if r.dirtyCounter > r.cfg.maxDirtyMessages {
r.notifyPersist()
}
}
// handleMessage handles an incoming message; incoming messages are always sorted by time tick.
func (r *RecoveryStorage) handleMessage(msg message.ImmutableMessage) {
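// A vchannel is registered by CreateCollection and may already be removed when a DropCollection is replayed,
// so those two message types are exempt from the existence check below.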
if msg.VChannel() != "" && msg.MessageType() != message.MessageTypeCreateCollection &&
msg.MessageType() != message.MessageTypeDropCollection && r.vchannels[msg.VChannel()] == nil {
r.detectInconsistency(msg, "vchannel not found")
}
switch msg.MessageType() {
case message.MessageTypeInsert:
immutableMsg := message.MustAsImmutableInsertMessageV1(msg)
r.handleInsert(immutableMsg)
case message.MessageTypeDelete:
immutableMsg := message.MustAsImmutableDeleteMessageV1(msg)
r.handleDelete(immutableMsg)
case message.MessageTypeCreateSegment:
immutableMsg := message.MustAsImmutableCreateSegmentMessageV2(msg)
r.handleCreateSegment(immutableMsg)
case message.MessageTypeFlush:
immutableMsg := message.MustAsImmutableFlushMessageV2(msg)
r.handleFlush(immutableMsg)
case message.MessageTypeManualFlush:
immutableMsg := message.MustAsImmutableManualFlushMessageV2(msg)
r.handleManualFlush(immutableMsg)
case message.MessageTypeCreateCollection:
immutableMsg := message.MustAsImmutableCreateCollectionMessageV1(msg)
r.handleCreateCollection(immutableMsg)
case message.MessageTypeDropCollection:
immutableMsg := message.MustAsImmutableDropCollectionMessageV1(msg)
r.handleDropCollection(immutableMsg)
case message.MessageTypeCreatePartition:
immutableMsg := message.MustAsImmutableCreatePartitionMessageV1(msg)
r.handleCreatePartition(immutableMsg)
case message.MessageTypeDropPartition:
immutableMsg := message.MustAsImmutableDropPartitionMessageV1(msg)
r.handleDropPartition(immutableMsg)
case message.MessageTypeTxn:
immutableMsg := message.AsImmutableTxnMessage(msg)
r.handleTxn(immutableMsg)
case message.MessageTypeImport:
immutableMsg := message.MustAsImmutableImportMessageV1(msg)
r.handleImport(immutableMsg)
case message.MessageTypeSchemaChange:
immutableMsg := message.MustAsImmutableCollectionSchemaChangeV2(msg)
r.handleSchemaChange(immutableMsg)
case message.MessageTypeTimeTick:
// nothing, the time tick message make no recovery operation.
default:
panic("unreachable: some message type can not be consumed, there's a critical bug.")
}
}
// handleInsert handles the insert message.
func (r *RecoveryStorage) handleInsert(msg message.ImmutableInsertMessageV1) {
for _, partition := range msg.Header().GetPartitions() {
if segment, ok := r.segments[partition.SegmentAssignment.SegmentId]; ok && segment.IsGrowing() {
segment.ObserveInsert(msg.TimeTick(), partition)
if r.Logger().Level().Enabled(zap.DebugLevel) {
r.Logger().Debug("insert entity", log.FieldMessage(msg), zap.Uint64("segmentRows", segment.Rows()), zap.Uint64("segmentBinary", segment.BinarySize()))
}
} else {
r.detectInconsistency(msg, "segment not found")
}
}
}
// handleDelete handles the delete message.
func (r *RecoveryStorage) handleDelete(msg message.ImmutableDeleteMessageV1) {
// nothing, current delete operation is managed by flowgraph, not recovery storage.
if r.Logger().Level().Enabled(zap.DebugLevel) {
r.Logger().Debug("delete entity", log.FieldMessage(msg))
}
}
// handleCreateSegment handles the create segment message.
func (r *RecoveryStorage) handleCreateSegment(msg message.ImmutableCreateSegmentMessageV2) {
segment := newSegmentRecoveryInfoFromCreateSegmentMessage(msg)
r.segments[segment.meta.SegmentId] = segment
r.Logger().Info("create segment", log.FieldMessage(msg))
}
// handleFlush handles the flush message.
func (r *RecoveryStorage) handleFlush(msg message.ImmutableFlushMessageV2) {
header := msg.Header()
if segment, ok := r.segments[header.SegmentId]; ok {
segment.ObserveFlush(msg.TimeTick())
r.Logger().Info("flush segment", log.FieldMessage(msg), zap.Uint64("rows", segment.Rows()), zap.Uint64("binarySize", segment.BinarySize()))
}
}
// handleManualFlush handles the manual flush message.
func (r *RecoveryStorage) handleManualFlush(msg message.ImmutableManualFlushMessageV2) {
segments := make(map[int64]struct{}, len(msg.Header().SegmentIds))
for _, segmentID := range msg.Header().SegmentIds {
segments[segmentID] = struct{}{}
}
r.flushSegments(msg, segments)
}
// flushSegments flushes the segments in the recovery storage.
func (r *RecoveryStorage) flushSegments(msg message.ImmutableMessage, sealSegmentIDs map[int64]struct{}) {
segmentIDs := make([]int64, 0)
rows := make([]uint64, 0)
binarySize := make([]uint64, 0)
for _, segment := range r.segments {
if _, ok := sealSegmentIDs[segment.meta.SegmentId]; ok {
segment.ObserveFlush(msg.TimeTick())
segmentIDs = append(segmentIDs, segment.meta.SegmentId)
rows = append(rows, segment.Rows())
binarySize = append(binarySize, segment.BinarySize())
}
}
if len(segmentIDs) != len(sealSegmentIDs) {
r.detectInconsistency(msg, "flush segments not exist", zap.Int64s("wanted", lo.Keys(sealSegmentIDs)), zap.Int64s("actually", segmentIDs))
}
r.Logger().Info("flush all segments of collection by manual flush", log.FieldMessage(msg), zap.Uint64s("rows", rows), zap.Uint64s("binarySize", binarySize))
}
// handleCreateCollection handles the create collection message.
func (r *RecoveryStorage) handleCreateCollection(msg message.ImmutableCreateCollectionMessageV1) {
if _, ok := r.vchannels[msg.VChannel()]; ok {
return
}
r.vchannels[msg.VChannel()] = newVChannelRecoveryInfoFromCreateCollectionMessage(msg)
r.Logger().Info("create collection", log.FieldMessage(msg))
}
// handleDropCollection handles the drop collection message.
func (r *RecoveryStorage) handleDropCollection(msg message.ImmutableDropCollectionMessageV1) {
if vchannelInfo, ok := r.vchannels[msg.VChannel()]; !ok || vchannelInfo.meta.State == streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
return
}
r.vchannels[msg.VChannel()].ObserveDropCollection(msg)
// flush all existing segments.
r.flushAllSegmentOfCollection(msg, msg.Header().CollectionId)
r.Logger().Info("drop collection", log.FieldMessage(msg))
}
// flushAllSegmentOfCollection flushes all segments of the collection.
func (r *RecoveryStorage) flushAllSegmentOfCollection(msg message.ImmutableMessage, collectionID int64) {
segmentIDs := make([]int64, 0)
rows := make([]uint64, 0)
for _, segment := range r.segments {
if segment.meta.CollectionId == collectionID {
segment.ObserveFlush(msg.TimeTick())
segmentIDs = append(segmentIDs, segment.meta.SegmentId)
rows = append(rows, segment.Rows())
}
}
r.Logger().Info("flush all segments of collection", log.FieldMessage(msg), zap.Int64s("segmentIDs", segmentIDs), zap.Uint64s("rows", rows))
}
// handleCreatePartition handles the create partition message.
func (r *RecoveryStorage) handleCreatePartition(msg message.ImmutableCreatePartitionMessageV1) {
if vchannelInfo, ok := r.vchannels[msg.VChannel()]; !ok || vchannelInfo.meta.State == streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
return
}
r.vchannels[msg.VChannel()].ObserveCreatePartition(msg)
r.Logger().Info("create partition", log.FieldMessage(msg))
}
// handleDropPartition handles the drop partition message.
func (r *RecoveryStorage) handleDropPartition(msg message.ImmutableDropPartitionMessageV1) {
r.vchannels[msg.VChannel()].ObserveDropPartition(msg)
// flush all existing segments.
r.flushAllSegmentOfPartition(msg, msg.Header().CollectionId, msg.Header().PartitionId)
r.Logger().Info("drop partition", log.FieldMessage(msg))
}
// flushAllSegmentOfPartition flushes all segments of the partition.
func (r *RecoveryStorage) flushAllSegmentOfPartition(msg message.ImmutableMessage, collectionID int64, partitionID int64) {
segmentIDs := make([]int64, 0)
rows := make([]uint64, 0)
for _, segment := range r.segments {
if segment.meta.PartitionId == partitionID {
segment.ObserveFlush(msg.TimeTick())
segmentIDs = append(segmentIDs, segment.meta.SegmentId)
rows = append(rows, segment.Rows())
}
}
r.Logger().Info("flush all segments of partition", log.FieldMessage(msg), zap.Int64s("segmentIDs", segmentIDs), zap.Uint64s("rows", rows))
}
// handleTxn handles the txn message.
func (r *RecoveryStorage) handleTxn(msg message.ImmutableTxnMessage) {
msg.RangeOver(func(im message.ImmutableMessage) error {
r.handleMessage(im)
return nil
})
}
// handleImport handles the import message.
func (r *RecoveryStorage) handleImport(_ message.ImmutableImportMessageV1) {
}
// handleSchemaChange handles the schema change message.
func (r *RecoveryStorage) handleSchemaChange(msg message.ImmutableSchemaChangeMessageV2) {
// When a schema change happens, we need to flush all segments in the collection.
// TODO: add the flush segment list into schema change message.
// TODO: persist the schema change into recoveryinfo.
r.flushAllSegmentOfCollection(msg, msg.Header().CollectionId)
}
// detectInconsistency detects the inconsistency in the recovery storage.
func (r *RecoveryStorage) detectInconsistency(msg message.ImmutableMessage, reason string, extra ...zap.Field) {
fields := make([]zap.Field, 0, len(extra)+2)
fields = append(fields, log.FieldMessage(msg), zap.String("reason", reason))
fields = append(fields, extra...)
// The log is not fatal in some cases,
// because our meta is not updated atomically, so these errors may be logged if a crash happens while the meta is partially updated.
r.Logger().Warn("inconsistency detected", fields...)
}

View File

@ -0,0 +1,583 @@
package recovery
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
func TestRecoveryStorage(t *testing.T) {
paramtable.Get().Save(paramtable.Get().StreamingCfg.WALRecoveryPersistInterval.Key, "1ms")
paramtable.Get().Save(paramtable.Get().StreamingCfg.WALRecoveryGracefulCloseTimeout.Key, "10ms")
vchannelMetas := make(map[string]*streamingpb.VChannelMeta)
segmentMetas := make(map[int64]*streamingpb.SegmentAssignmentMeta)
cp := &streamingpb.WALCheckpoint{
MessageId: &messagespb.MessageID{
Id: rmq.NewRmqID(1).Marshal(),
},
TimeTick: 1,
RecoveryMagic: 0,
}
snCatalog := mock_metastore.NewMockStreamingNodeCataLog(t)
snCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.SegmentAssignmentMeta, error) {
return lo.Values(segmentMetas), nil
})
snCatalog.EXPECT().ListVChannel(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, channel string) ([]*streamingpb.VChannelMeta, error) {
return lo.Values(vchannelMetas), nil
})
segmentSaveFailure := true
snCatalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[int64]*streamingpb.SegmentAssignmentMeta) error {
if segmentSaveFailure {
segmentSaveFailure = false
return errors.New("save failed")
}
for _, v := range m {
if v.State != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED {
segmentMetas[v.SegmentId] = v
} else {
delete(segmentMetas, v.SegmentId)
}
}
return nil
})
vchannelSaveFailure := true
snCatalog.EXPECT().SaveVChannels(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, s string, m map[string]*streamingpb.VChannelMeta) error {
if vchannelSaveFailure {
vchannelSaveFailure = false
return errors.New("save failed")
}
for _, v := range m {
if v.State != streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
vchannelMetas[v.Vchannel] = v
} else {
delete(vchannelMetas, v.Vchannel)
}
}
return nil
})
snCatalog.EXPECT().GetConsumeCheckpoint(mock.Anything, mock.Anything).Return(cp, nil)
checkpointSaveFailure := true
snCatalog.EXPECT().SaveConsumeCheckpoint(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pchannelName string, checkpoint *streamingpb.WALCheckpoint) error {
if checkpointSaveFailure {
checkpointSaveFailure = false
return errors.New("save failed")
}
cp = checkpoint
return nil
})
resource.InitForTest(t, resource.OptStreamingNodeCatalog(snCatalog))
b := &streamBuilder{
channel: types.PChannelInfo{Name: "test_channel"},
lastConfirmedMessageID: 1,
messageID: 1,
timetick: 1,
collectionIDs: make(map[int64]map[int64]map[int64]struct{}),
vchannels: make(map[int64]string),
idAlloc: 1,
}
msg := message.NewTimeTickMessageBuilderV1().
WithAllVChannel().
WithHeader(&message.TimeTickMessageHeader{}).
WithBody(&msgpb.TimeTickMsg{}).
MustBuildMutable().
WithTimeTick(1).
WithLastConfirmed(rmq.NewRmqID(1)).
IntoImmutableMessage(rmq.NewRmqID(1))
b.generateStreamMessage()
for i := 0; i < 3; i++ {
if i == 2 {
// make sure the checkpoint is saved.
paramtable.Get().Save(paramtable.Get().StreamingCfg.WALRecoveryGracefulCloseTimeout.Key, "1000s")
}
rs, snapshot, err := RecoverRecoveryStorage(context.Background(), b, msg)
assert.NoError(t, err)
assert.NotNil(t, rs)
assert.NotNil(t, snapshot)
msgs := b.generateStreamMessage()
for _, msg := range msgs {
rs.ObserveMessage(msg)
}
rs.Close()
var partitionNum int
var collectionNum int
var segmentNum int
for _, v := range rs.vchannels {
if v.meta.State != streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
collectionNum += 1
partitionNum += len(v.meta.CollectionInfo.Partitions)
}
}
for _, v := range rs.segments {
if v.meta.State != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED {
segmentNum += 1
}
}
assert.Equal(t, partitionNum, b.partitionNum())
assert.Equal(t, collectionNum, b.collectionNum())
assert.Equal(t, segmentNum, b.segmentNum())
}
assert.Equal(t, b.collectionNum(), len(vchannelMetas))
partitionNum := 0
for _, v := range vchannelMetas {
partitionNum += len(v.CollectionInfo.Partitions)
}
assert.Equal(t, b.partitionNum(), partitionNum)
assert.Equal(t, b.segmentNum(), len(segmentMetas))
}
type streamBuilder struct {
channel types.PChannelInfo
lastConfirmedMessageID int64
messageID int64
timetick uint64
collectionIDs map[int64]map[int64]map[int64]struct{}
vchannels map[int64]string
idAlloc int64
histories []message.ImmutableMessage
}
func (b *streamBuilder) collectionNum() int {
return len(b.collectionIDs)
}
func (b *streamBuilder) partitionNum() int {
partitionNum := 0
for _, partitions := range b.collectionIDs {
partitionNum += len(partitions)
}
return partitionNum
}
func (b *streamBuilder) segmentNum() int {
segmentNum := 0
for _, partitions := range b.collectionIDs {
for _, segments := range partitions {
segmentNum += len(segments)
}
}
return segmentNum
}
type testRecoveryStream struct {
ch chan message.ImmutableMessage
}
func (ts *testRecoveryStream) Chan() <-chan message.ImmutableMessage {
return ts.ch
}
func (ts *testRecoveryStream) Error() error {
return nil
}
func (ts *testRecoveryStream) TxnBuffer() *utility.TxnBuffer {
return nil
}
func (ts *testRecoveryStream) Close() error {
return nil
}
func (b *streamBuilder) WALName() string {
return "rocksmq"
}
func (b *streamBuilder) Channel() types.PChannelInfo {
return b.channel
}
func (b *streamBuilder) Build(param BuildRecoveryStreamParam) RecoveryStream {
rs := &testRecoveryStream{
ch: make(chan message.ImmutableMessage, len(b.histories)),
}
cp := param.StartCheckpoint
for _, msg := range b.histories {
if cp.LTE(msg.MessageID()) {
rs.ch <- msg
}
}
close(rs.ch)
return rs
}
func (b *streamBuilder) generateStreamMessage() []message.ImmutableMessage {
ops := []func() message.ImmutableMessage{
b.createCollection,
b.createPartition,
b.createSegment,
b.createSegment,
b.dropCollection,
b.dropPartition,
b.flushSegment,
b.flushSegment,
b.createInsert,
b.createInsert,
b.createInsert,
b.createDelete,
b.createDelete,
b.createDelete,
b.createTxn,
b.createTxn,
b.createManualFlush,
}
msgs := make([]message.ImmutableMessage, 0)
for i := 0; i < int(rand.Int63n(1000)+1000); i++ {
op := rand.Int31n(int32(len(ops)))
if msg := ops[op](); msg != nil {
msgs = append(msgs, msg)
}
}
b.histories = append(b.histories, msgs...)
return msgs
}
// createCollection creates a collection message with the given vchannel.
func (b *streamBuilder) createCollection() message.ImmutableMessage {
vchannel := fmt.Sprintf("vchannel_%d", b.allocID())
collectionID := b.allocID()
partitions := rand.Int31n(1023) + 1
partitionIDs := make(map[int64]map[int64]struct{}, partitions)
for i := int32(0); i < partitions; i++ {
partitionIDs[b.allocID()] = make(map[int64]struct{})
}
b.nextMessage()
b.collectionIDs[collectionID] = partitionIDs
b.vchannels[collectionID] = vchannel
return message.NewCreateCollectionMessageBuilderV1().
WithVChannel(vchannel).
WithHeader(&message.CreateCollectionMessageHeader{
CollectionId: collectionID,
PartitionIds: lo.Keys(partitionIDs),
}).
WithBody(&msgpb.CreateCollectionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
func (b *streamBuilder) createPartition() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
partitionID := b.allocID()
collection[partitionID] = make(map[int64]struct{})
b.nextMessage()
return message.NewCreatePartitionMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.CreatePartitionMessageHeader{
CollectionId: collectionID,
PartitionId: partitionID,
}).
WithBody(&msgpb.CreatePartitionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) createSegment() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID, partition := range collection {
if rand.Int31n(3) < 1 {
continue
}
segmentID := b.allocID()
partition[segmentID] = struct{}{}
b.nextMessage()
return message.NewCreateSegmentMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.CreateSegmentMessageHeader{
CollectionId: collectionID,
SegmentId: segmentID,
PartitionId: partitionID,
StorageVersion: 1,
MaxSegmentSize: 1024,
}).
WithBody(&message.CreateSegmentMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
return nil
}
func (b *streamBuilder) dropCollection() message.ImmutableMessage {
for collectionID := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
delete(b.collectionIDs, collectionID)
return message.NewDropCollectionMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.DropCollectionMessageHeader{
CollectionId: collectionID,
}).
WithBody(&msgpb.DropCollectionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) dropPartition() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID := range collection {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
delete(collection, partitionID)
return message.NewDropPartitionMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: collectionID,
PartitionId: partitionID,
}).
WithBody(&msgpb.DropPartitionRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
return nil
}
func (b *streamBuilder) flushSegment() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID := range collection {
if rand.Int31n(3) < 1 {
continue
}
for segmentID := range collection[partitionID] {
if rand.Int31n(4) < 1 {
continue
}
delete(collection[partitionID], segmentID)
b.nextMessage()
return message.NewFlushMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.FlushMessageHeader{
CollectionId: collectionID,
SegmentId: segmentID,
}).
WithBody(&message.FlushMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
}
return nil
}
func (b *streamBuilder) createTxn() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
txnSession := &message.TxnContext{
TxnID: message.TxnID(b.allocID()),
}
begin := message.NewBeginTxnMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.BeginTxnMessageHeader{}).
WithBody(&message.BeginTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithTxnContext(*txnSession).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
builder := message.NewImmutableTxnMessageBuilder(message.MustAsImmutableBeginTxnMessageV2(begin))
for partitionID := range collection {
for segmentID := range collection[partitionID] {
b.nextMessage()
builder.Add(message.NewInsertMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.InsertMessageHeader{
CollectionId: collectionID,
Partitions: []*messagespb.PartitionSegmentAssignment{
{
PartitionId: partitionID,
Rows: uint64(rand.Int31n(100)),
BinarySize: uint64(rand.Int31n(100)),
SegmentAssignment: &messagespb.SegmentAssignment{SegmentId: segmentID},
},
},
}).
WithBody(&msgpb.InsertRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithTxnContext(*txnSession).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID)))
}
}
b.nextMessage()
commit := message.NewCommitTxnMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.CommitTxnMessageHeader{}).
WithBody(&message.CommitTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithTxnContext(*txnSession).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
txnMsg, _ := builder.Build(message.MustAsImmutableCommitTxnMessageV2(commit))
return txnMsg
}
return nil
}
func (b *streamBuilder) createDelete() message.ImmutableMessage {
for collectionID := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
b.nextMessage()
return message.NewDeleteMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.DeleteMessageHeader{
CollectionId: collectionID,
}).
WithBody(&msgpb.DeleteRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) createManualFlush() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
segmentIDs := make([]int64, 0)
for partitionID := range collection {
if rand.Int31n(3) < 1 {
continue
}
for segmentID := range collection[partitionID] {
if rand.Int31n(4) < 2 {
continue
}
segmentIDs = append(segmentIDs, segmentID)
delete(collection[partitionID], segmentID)
}
}
if len(segmentIDs) == 0 {
continue
}
b.nextMessage()
return message.NewManualFlushMessageBuilderV2().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.ManualFlushMessageHeader{
CollectionId: collectionID,
SegmentIds: segmentIDs,
}).
WithBody(&message.ManualFlushMessageBody{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
return nil
}
func (b *streamBuilder) createInsert() message.ImmutableMessage {
for collectionID, collection := range b.collectionIDs {
if rand.Int31n(3) < 1 {
continue
}
for partitionID := range collection {
if rand.Int31n(3) < 1 {
continue
}
for segmentID := range collection[partitionID] {
if rand.Int31n(4) < 2 {
continue
}
b.nextMessage()
return message.NewInsertMessageBuilderV1().
WithVChannel(b.vchannels[collectionID]).
WithHeader(&message.InsertMessageHeader{
CollectionId: collectionID,
Partitions: []*messagespb.PartitionSegmentAssignment{
{
PartitionId: partitionID,
Rows: uint64(rand.Int31n(100)),
BinarySize: uint64(rand.Int31n(100)),
SegmentAssignment: &messagespb.SegmentAssignment{SegmentId: segmentID},
},
},
}).
WithBody(&msgpb.InsertRequest{}).
MustBuildMutable().
WithTimeTick(b.timetick).
WithLastConfirmed(rmq.NewRmqID(b.lastConfirmedMessageID)).
IntoImmutableMessage(rmq.NewRmqID(b.messageID))
}
}
}
return nil
}
func (b *streamBuilder) nextMessage() {
b.messageID++
if rand.Int31n(3) < 2 {
b.lastConfirmedMessageID = b.messageID + 1
}
b.timetick++
}
func (b *streamBuilder) allocID() int64 {
b.idAlloc++
return b.idAlloc
}

View File

@ -0,0 +1,91 @@
package recovery
import (
"context"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
// recoverFromStream recovers the recovery storage from the recovery stream.
func (r *RecoveryStorage) recoverFromStream(
ctx context.Context,
recoveryStreamBuilder RecoveryStreamBuilder,
lastTimeTickMessage message.ImmutableMessage,
) (snapshot *RecoverySnapshot, err error) {
r.SetLogger(resource.Resource().Logger().With(
log.FieldComponent(componentRecoveryStorage),
zap.String("channel", recoveryStreamBuilder.Channel().String()),
zap.String("startMessageID", r.checkpoint.MessageID.String()),
zap.Uint64("fromTimeTick", r.checkpoint.TimeTick),
zap.Uint64("toTimeTick", lastTimeTickMessage.TimeTick()),
zap.String("state", recoveryStorageStateStreamRecovering),
))
r.Logger().Info("recover from wal stream...")
rs := recoveryStreamBuilder.Build(BuildRecoveryStreamParam{
StartCheckpoint: r.checkpoint.MessageID,
EndTimeTick: lastTimeTickMessage.TimeTick(),
})
defer func() {
rs.Close()
if err != nil {
r.Logger().Warn("recovery from wal stream failed", zap.Error(err))
return
}
r.Logger().Info("recovery from wal stream done",
zap.Int("vchannels", len(snapshot.VChannels)),
zap.Int("segments", len(snapshot.SegmentAssignments)),
zap.String("checkpoint", snapshot.Checkpoint.MessageID.String()),
zap.Uint64("timetick", snapshot.Checkpoint.TimeTick),
)
}()
L:
for {
select {
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "failed to recover from wal")
case msg, ok := <-rs.Chan():
if !ok {
// The recovery stream has reached the end, so we can stop the recovery.
break L
}
r.observeMessage(msg)
}
}
if rs.Error() != nil {
return nil, errors.Wrap(rs.Error(), "failed to read the recovery info from wal")
}
snapshot = r.getSnapshot()
snapshot.TxnBuffer = rs.TxnBuffer()
return snapshot, nil
}
// getSnapshot returns the snapshot of the recovery storage.
// Use this function to get the snapshot after recovery is finished,
// and use the snapshot to recover all write ahead components.
func (r *RecoveryStorage) getSnapshot() *RecoverySnapshot {
segments := make(map[int64]*streamingpb.SegmentAssignmentMeta, len(r.segments))
vchannels := make(map[string]*streamingpb.VChannelMeta, len(r.vchannels))
for segmentID, segment := range r.segments {
if segment.IsGrowing() {
segments[segmentID] = proto.Clone(segment.meta).(*streamingpb.SegmentAssignmentMeta)
}
}
for channelName, vchannel := range r.vchannels {
if vchannel.IsActive() {
vchannels[channelName] = proto.Clone(vchannel.meta).(*streamingpb.VChannelMeta)
}
}
return &RecoverySnapshot{
VChannels: vchannels,
SegmentAssignments: segments,
Checkpoint: r.checkpoint.Clone(),
}
}

View File

@ -0,0 +1,121 @@
package recovery
import (
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
)
// newSegmentRecoveryInfoFromSegmentAssignmentMeta creates a new segment recovery info from segment assignment meta.
func newSegmentRecoveryInfoFromSegmentAssignmentMeta(metas []*streamingpb.SegmentAssignmentMeta) map[int64]*segmentRecoveryInfo {
infos := make(map[int64]*segmentRecoveryInfo, len(metas))
for _, m := range metas {
infos[m.SegmentId] = &segmentRecoveryInfo{
meta: m,
// recover from persisted info, so it is not dirty.
dirty: false,
}
}
return infos
}
// newSegmentRecoveryInfoFromCreateSegmentMessage creates a new segment recovery info from a create segment message.
func newSegmentRecoveryInfoFromCreateSegmentMessage(msg message.ImmutableCreateSegmentMessageV2) *segmentRecoveryInfo {
header := msg.Header()
now := tsoutil.PhysicalTime(msg.TimeTick()).Unix()
return &segmentRecoveryInfo{
meta: &streamingpb.SegmentAssignmentMeta{
CollectionId: header.CollectionId,
PartitionId: header.PartitionId,
SegmentId: header.SegmentId,
Vchannel: msg.VChannel(),
State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING,
StorageVersion: header.StorageVersion,
CheckpointTimeTick: msg.TimeTick(),
Stat: &streamingpb.SegmentAssignmentStat{
MaxBinarySize: header.MaxSegmentSize,
InsertedRows: 0,
InsertedBinarySize: 0,
CreateTimestamp: now,
LastModifiedTimestamp: now,
BinlogCounter: 0,
CreateSegmentTimeTick: msg.TimeTick(),
},
},
// a new incoming create segment request is always dirty until it is persisted.
dirty: true,
}
}
// segmentRecoveryInfo is the recovery info for single segment.
type segmentRecoveryInfo struct {
meta *streamingpb.SegmentAssignmentMeta
dirty bool // whether the segment recovery info is dirty.
}
// IsGrowing returns true if the segment is in growing state.
func (info *segmentRecoveryInfo) IsGrowing() bool {
return info.meta.State == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING
}
// CreateSegmentTimeTick returns the time tick when the segment was created.
func (info *segmentRecoveryInfo) CreateSegmentTimeTick() uint64 {
return info.meta.Stat.CreateSegmentTimeTick
}
// Rows returns the number of rows in the segment.
func (info *segmentRecoveryInfo) Rows() uint64 {
return info.meta.Stat.InsertedRows
}
// BinarySize returns the binary size of the segment.
func (info *segmentRecoveryInfo) BinarySize() uint64 {
return info.meta.Stat.InsertedBinarySize
}
// ObserveInsert is called when an insert message is observed.
func (info *segmentRecoveryInfo) ObserveInsert(timetick uint64, assignment *messagespb.PartitionSegmentAssignment) {
if timetick < info.meta.CheckpointTimeTick {
// Messages in the same txn share one time tick, so only messages whose time
// tick is strictly less than the checkpoint time tick are filtered out.
// Consistent state is guaranteed by the recovery storage's mutex.
return
}
info.meta.Stat.InsertedBinarySize += assignment.BinarySize
info.meta.Stat.InsertedRows += assignment.Rows
info.meta.Stat.LastModifiedTimestamp = tsoutil.PhysicalTime(timetick).Unix()
info.meta.CheckpointTimeTick = timetick
info.dirty = true
}
// ObserveFlush is called when a segment should be flushed.
func (info *segmentRecoveryInfo) ObserveFlush(timetick uint64) {
if timetick < info.meta.CheckpointTimeTick {
// Messages in the same txn share one time tick (although the flush operation
// itself is not a txn message), so only messages whose time tick is strictly
// less than the checkpoint time tick are filtered out.
// Consistent state is guaranteed by the recovery storage's mutex.
return
}
if info.meta.State == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED {
// idempotent
return
}
info.meta.State = streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED
info.meta.Stat.LastModifiedTimestamp = tsoutil.PhysicalTime(timetick).Unix()
info.meta.CheckpointTimeTick = timetick
info.dirty = true
}
// ConsumeDirtyAndGetSnapshot consumes the dirty segment recovery info and returns a snapshot to persist.
// Returns nil if the segment recovery info is not dirty.
func (info *segmentRecoveryInfo) ConsumeDirtyAndGetSnapshot() (dirtySnapshot *streamingpb.SegmentAssignmentMeta, shouldBeRemoved bool) {
if !info.dirty {
return nil, false
}
info.dirty = false
return proto.Clone(info.meta).(*streamingpb.SegmentAssignmentMeta), info.meta.State == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED
}
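
The dirty flag together with ConsumeDirtyAndGetSnapshot lets a persister collect only the segments that changed since the recovery info was last written, and drop the ones that are already flushed. A minimal sketch of that pattern is shown below, assuming a hypothetical persist callback; it is not the persister used by this patch.

// Sketch only: collect dirty segment snapshots, persist them, and drop the
// segments that are already flushed so they no longer take part in recovery.
// persist is a hypothetical callback that writes the metas to the catalog.
func consumeDirtySegments(
    segments map[int64]*segmentRecoveryInfo,
    persist func([]*streamingpb.SegmentAssignmentMeta) error,
) error {
    dirty := make([]*streamingpb.SegmentAssignmentMeta, 0, len(segments))
    flushed := make([]int64, 0)
    for segmentID, info := range segments {
        snapshot, shouldBeRemoved := info.ConsumeDirtyAndGetSnapshot()
        if snapshot != nil {
            dirty = append(dirty, snapshot)
        }
        if shouldBeRemoved {
            flushed = append(flushed, segmentID)
        }
    }
    if err := persist(dirty); err != nil {
        // Note: a real implementation would have to re-mark the infos dirty here.
        return err
    }
    // Remove flushed segments only after their snapshots are persisted.
    for _, segmentID := range flushed {
        delete(segments, segmentID)
    }
    return nil
}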

View File

@ -0,0 +1,105 @@
package recovery
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
)
func TestNewSegmentRecoveryInfoFromSegmentAssignmentMeta(t *testing.T) {
// Test with empty input
metas := []*streamingpb.SegmentAssignmentMeta{}
infos := newSegmentRecoveryInfoFromSegmentAssignmentMeta(metas)
assert.Empty(t, infos)
// Test with valid input
metas = []*streamingpb.SegmentAssignmentMeta{
{SegmentId: 1, Vchannel: "vchannel-1"},
{SegmentId: 2, Vchannel: "vchannel-2"},
}
infos = newSegmentRecoveryInfoFromSegmentAssignmentMeta(metas)
assert.Len(t, infos, 2)
assert.Equal(t, int64(1), infos[1].meta.SegmentId)
assert.Equal(t, int64(2), infos[2].meta.SegmentId)
assert.False(t, infos[1].dirty)
assert.False(t, infos[2].dirty)
}
func TestSegmentRecoveryInfo(t *testing.T) {
msg := message.NewCreateSegmentMessageBuilderV2().
WithHeader(&message.CreateSegmentMessageHeader{
CollectionId: 100,
PartitionId: 1,
SegmentId: 2,
StorageVersion: storage.StorageV1,
MaxSegmentSize: 100,
}).
WithBody(&message.CreateSegmentMessageBody{}).
WithVChannel("vchannel-1").
MustBuildMutable()
id := rmq.NewRmqID(1)
ts := uint64(12345)
immutableMsg := msg.WithTimeTick(ts).WithLastConfirmed(id).IntoImmutableMessage(id)
info := newSegmentRecoveryInfoFromCreateSegmentMessage(message.MustAsImmutableCreateSegmentMessageV2(immutableMsg))
assert.Equal(t, int64(2), info.meta.SegmentId)
assert.Equal(t, int64(1), info.meta.PartitionId)
assert.Equal(t, storage.StorageV1, info.meta.StorageVersion)
assert.Equal(t, streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING, info.meta.State)
assert.Equal(t, uint64(100), info.meta.Stat.MaxBinarySize)
ts += 1
assign := &messagespb.PartitionSegmentAssignment{
PartitionId: 1,
Rows: 1,
BinarySize: 10,
SegmentAssignment: &messagespb.SegmentAssignment{
SegmentId: 2,
},
}
info.ObserveInsert(ts, assign)
assert.True(t, info.dirty)
snapshot, shouldBeRemoved := info.ConsumeDirtyAndGetSnapshot()
assert.NotNil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
assert.Equal(t, uint64(10), snapshot.Stat.InsertedBinarySize)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
// an insert from the same txn may arrive with the same time tick.
info.ObserveInsert(ts, assign)
assert.True(t, info.dirty)
ts += 1
info.ObserveFlush(ts)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.NotNil(t, snapshot)
assert.Equal(t, snapshot.State, streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED)
assert.True(t, shouldBeRemoved)
assert.False(t, info.dirty)
// idempotent
info.ObserveFlush(ts)
assert.NotNil(t, snapshot)
assert.Equal(t, snapshot.State, streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED)
assert.True(t, shouldBeRemoved)
assert.False(t, info.dirty)
// idempotent
info.ObserveFlush(ts + 1)
assert.NotNil(t, snapshot)
assert.Equal(t, snapshot.State, streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED)
assert.True(t, shouldBeRemoved)
assert.False(t, info.dirty)
}

View File

@ -0,0 +1,134 @@
package recovery
import (
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
// newVChannelRecoveryInfoFromVChannelMeta creates vchannel recovery infos from the persisted vchannel metas.
func newVChannelRecoveryInfoFromVChannelMeta(meta []*streamingpb.VChannelMeta) map[string]*vchannelRecoveryInfo {
infos := make(map[string]*vchannelRecoveryInfo, len(meta))
for _, m := range meta {
infos[m.Vchannel] = &vchannelRecoveryInfo{
meta: m,
dirty: false, // recover from persisted info, so it is not dirty.
}
}
return infos
}
// newVChannelRecoveryInfoFromCreateCollectionMessage creates a new vchannel recovery info from a create collection message.
func newVChannelRecoveryInfoFromCreateCollectionMessage(msg message.ImmutableCreateCollectionMessageV1) *vchannelRecoveryInfo {
partitions := make([]*streamingpb.PartitionInfoOfVChannel, 0, len(msg.Header().PartitionIds))
for _, partitionId := range msg.Header().PartitionIds {
partitions = append(partitions, &streamingpb.PartitionInfoOfVChannel{
PartitionId: partitionId,
})
}
return &vchannelRecoveryInfo{
meta: &streamingpb.VChannelMeta{
Vchannel: msg.VChannel(),
State: streamingpb.VChannelState_VCHANNEL_STATE_NORMAL,
CollectionInfo: &streamingpb.CollectionInfoOfVChannel{
CollectionId: msg.Header().CollectionId,
Partitions: partitions,
},
CheckpointTimeTick: msg.TimeTick(),
},
// a new incoming create collection request is always dirty until it is persisted.
dirty: true,
}
}
// vchannelRecoveryInfo is the recovery info for a vchannel.
type vchannelRecoveryInfo struct {
meta *streamingpb.VChannelMeta
dirty bool // whether the vchannel recovery info is dirty.
}
// IsActive returns true if the vchannel is active.
func (info *vchannelRecoveryInfo) IsActive() bool {
return info.meta.State != streamingpb.VChannelState_VCHANNEL_STATE_DROPPED
}
// IsPartitionActive returns true if the partition is active.
func (info *vchannelRecoveryInfo) IsPartitionActive(partitionId int64) bool {
for _, partition := range info.meta.CollectionInfo.Partitions {
if partition.PartitionId == partitionId {
return true
}
}
return false
}
// ObserveDropCollection is called when a drop collection message is observed.
func (info *vchannelRecoveryInfo) ObserveDropCollection(msg message.ImmutableDropCollectionMessageV1) {
if msg.TimeTick() < info.meta.CheckpointTimeTick {
// Messages in the same txn share one time tick (although the drop collection
// message itself is not a txn message), so only messages whose time tick is
// strictly less than the checkpoint time tick are filtered out.
// Consistent state is guaranteed by the recovery storage's mutex.
return
}
if info.meta.State == streamingpb.VChannelState_VCHANNEL_STATE_DROPPED {
// Idempotent: only the first drop collection message takes effect.
return
}
info.meta.State = streamingpb.VChannelState_VCHANNEL_STATE_DROPPED
info.meta.CheckpointTimeTick = msg.TimeTick()
info.dirty = true
}
// ObserveDropPartition is called when a drop partition message is observed.
func (info *vchannelRecoveryInfo) ObserveDropPartition(msg message.ImmutableDropPartitionMessageV1) {
if msg.TimeTick() < info.meta.CheckpointTimeTick {
// Messages in the same txn share one time tick (although the drop partition
// message itself is not a txn message), so only messages whose time tick is
// strictly less than the checkpoint time tick are filtered out.
// Consistent state is guaranteed by the recovery storage's mutex.
return
}
for i, partition := range info.meta.CollectionInfo.Partitions {
if partition.PartitionId == msg.Header().PartitionId {
// Idempotent: only the first drop partition message for this partition takes effect.
info.meta.CollectionInfo.Partitions = append(info.meta.CollectionInfo.Partitions[:i], info.meta.CollectionInfo.Partitions[i+1:]...)
info.meta.CheckpointTimeTick = msg.TimeTick()
info.dirty = true
return
}
}
}
// ObserveCreatePartition is called when a create partition message is observed.
func (info *vchannelRecoveryInfo) ObserveCreatePartition(msg message.ImmutableCreatePartitionMessageV1) {
if msg.TimeTick() < info.meta.CheckpointTimeTick {
// Messages in the same txn share one time tick (although the create partition
// message itself is not a txn message), so only messages whose time tick is
// strictly less than the checkpoint time tick are filtered out.
// Consistent state is guaranteed by the recovery storage's mutex.
return
}
for _, partition := range info.meta.CollectionInfo.Partitions {
if partition.PartitionId == msg.Header().PartitionId {
// Idempotent: the partition already exists, so later create partition messages are ignored.
return
}
}
info.meta.CollectionInfo.Partitions = append(info.meta.CollectionInfo.Partitions, &streamingpb.PartitionInfoOfVChannel{
PartitionId: msg.Header().PartitionId,
})
info.meta.CheckpointTimeTick = msg.TimeTick()
info.dirty = true
}
// ConsumeDirtyAndGetSnapshot returns the snapshot of the vchannel recovery info.
// It returns nil if the vchannel recovery info is not dirty.
func (info *vchannelRecoveryInfo) ConsumeDirtyAndGetSnapshot() (dirtySnapshot *streamingpb.VChannelMeta, shouldBeRemoved bool) {
if !info.dirty {
return nil, info.meta.State == streamingpb.VChannelState_VCHANNEL_STATE_DROPPED
}
info.dirty = false
return proto.Clone(info.meta).(*streamingpb.VChannelMeta), info.meta.State == streamingpb.VChannelState_VCHANNEL_STATE_DROPPED
}
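
Each Observe method above applies the same two guards: a time-tick filter against the persisted checkpoint and an idempotency check on the current state, so replaying the wal from the checkpoint is safe. A minimal dispatch sketch is shown below; the switch is illustrative and smaller than the real recovery storage's handler.

// Sketch only: replay ddl messages into a vchannel recovery info.
func observeDDLMessage(info *vchannelRecoveryInfo, msg message.ImmutableMessage) {
    switch msg.MessageType() {
    case message.MessageTypeDropCollection:
        info.ObserveDropCollection(message.MustAsImmutableDropCollectionMessageV1(msg))
    case message.MessageTypeCreatePartition:
        info.ObserveCreatePartition(message.MustAsImmutableCreatePartitionMessageV1(msg))
    case message.MessageTypeDropPartition:
        info.ObserveDropPartition(message.MustAsImmutableDropPartitionMessageV1(msg))
    }
    // Messages with a time tick below the checkpoint are ignored inside the
    // Observe methods, so replaying from the checkpoint is idempotent.
}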

View File

@ -0,0 +1,199 @@
package recovery
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
)
func TestNewVChannelRecoveryInfoFromVChannelMeta(t *testing.T) {
meta := []*streamingpb.VChannelMeta{
{Vchannel: "vchannel-1"},
{Vchannel: "vchannel-2"},
}
info := newVChannelRecoveryInfoFromVChannelMeta(meta)
assert.Len(t, info, 2)
assert.NotNil(t, info["vchannel-1"])
assert.NotNil(t, info["vchannel-2"])
assert.False(t, info["vchannel-1"].dirty)
assert.False(t, info["vchannel-2"].dirty)
snapshot, shouldBeRemoved := info["vchannel-1"].ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
}
func TestNewVChannelRecoveryInfoFromCreateCollectionMessage(t *testing.T) {
// CreateCollection
msg := message.NewCreateCollectionMessageBuilderV1().
WithHeader(&message.CreateCollectionMessageHeader{
CollectionId: 100,
PartitionIds: []int64{101, 102},
}).
WithBody(&msgpb.CreateCollectionRequest{
CollectionName: "test-collection",
CollectionID: 100,
PartitionIDs: []int64{101, 102},
}).
WithVChannel("vchannel-1").
MustBuildMutable()
msgID := rmq.NewRmqID(1)
ts := uint64(12345)
immutableMsg := msg.WithTimeTick(ts).WithLastConfirmed(msgID).IntoImmutableMessage(msgID)
info := newVChannelRecoveryInfoFromCreateCollectionMessage(message.MustAsImmutableCreateCollectionMessageV1(immutableMsg))
assert.Equal(t, "vchannel-1", info.meta.Vchannel)
assert.Equal(t, streamingpb.VChannelState_VCHANNEL_STATE_NORMAL, info.meta.State)
assert.Equal(t, ts, info.meta.CheckpointTimeTick)
assert.Len(t, info.meta.CollectionInfo.Partitions, 2)
assert.True(t, info.dirty)
snapshot, shouldBeRemoved := info.ConsumeDirtyAndGetSnapshot()
assert.NotNil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
// CreatePartition
msg3 := message.NewCreatePartitionMessageBuilderV1().
WithHeader(&message.CreatePartitionMessageHeader{
CollectionId: 100,
PartitionId: 103,
}).
WithBody(&msgpb.CreatePartitionRequest{
CollectionName: "test-collection",
CollectionID: 100,
PartitionID: 103,
}).
WithVChannel("vchannel-1").
MustBuildMutable()
msgID3 := rmq.NewRmqID(3)
ts += 1
immutableMsg3 := msg3.WithTimeTick(ts).WithLastConfirmed(msgID3).IntoImmutableMessage(msgID3)
info.ObserveCreatePartition(message.MustAsImmutableCreatePartitionMessageV1(immutableMsg3))
// idempotent
info.ObserveCreatePartition(message.MustAsImmutableCreatePartitionMessageV1(immutableMsg3))
assert.Equal(t, "vchannel-1", info.meta.Vchannel)
assert.Equal(t, streamingpb.VChannelState_VCHANNEL_STATE_NORMAL, info.meta.State)
assert.Equal(t, ts, info.meta.CheckpointTimeTick)
assert.Len(t, info.meta.CollectionInfo.Partitions, 3)
assert.True(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.NotNil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
ts += 1
immutableMsg3 = msg3.WithTimeTick(ts).WithLastConfirmed(msgID3).IntoImmutableMessage(msgID3)
// idempotent
info.ObserveCreatePartition(message.MustAsImmutableCreatePartitionMessageV1(immutableMsg3))
assert.Len(t, info.meta.CollectionInfo.Partitions, 3)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
// DropPartition
msg4 := message.NewDropPartitionMessageBuilderV1().
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: 100,
PartitionId: 101,
}).
WithBody(&msgpb.DropPartitionRequest{
CollectionName: "test-collection",
CollectionID: 100,
PartitionID: 101,
}).
WithVChannel("vchannel-1").
MustBuildMutable()
msgID4 := rmq.NewRmqID(4)
ts += 1
immutableMsg4 := msg4.WithTimeTick(ts).WithLastConfirmed(msgID4).IntoImmutableMessage(msgID4)
info.ObserveDropPartition(message.MustAsImmutableDropPartitionMessageV1(immutableMsg4))
// idempotent
info.ObserveDropPartition(message.MustAsImmutableDropPartitionMessageV1(immutableMsg4))
assert.Equal(t, "vchannel-1", info.meta.Vchannel)
assert.Equal(t, streamingpb.VChannelState_VCHANNEL_STATE_NORMAL, info.meta.State)
assert.Equal(t, ts, info.meta.CheckpointTimeTick)
assert.Len(t, info.meta.CollectionInfo.Partitions, 2)
assert.NotContains(t, info.meta.CollectionInfo.Partitions, int64(101))
assert.True(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.NotNil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
ts += 1
immutableMsg4 = msg4.WithTimeTick(ts).WithLastConfirmed(msgID4).IntoImmutableMessage(msgID4)
// idempotent
info.ObserveDropPartition(message.MustAsImmutableDropPartitionMessageV1(immutableMsg4))
assert.Len(t, info.meta.CollectionInfo.Partitions, 2)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.False(t, shouldBeRemoved)
assert.False(t, info.dirty)
// DropCollection
msg2 := message.NewDropCollectionMessageBuilderV1().
WithHeader(&message.DropCollectionMessageHeader{
CollectionId: 100,
}).
WithBody(&msgpb.DropCollectionRequest{
CollectionName: "test-collection",
CollectionID: 100,
}).
WithVChannel("vchannel-1").
MustBuildMutable()
msgID2 := rmq.NewRmqID(2)
ts += 1
immutableMsg2 := msg2.WithTimeTick(ts).WithLastConfirmed(msgID2).IntoImmutableMessage(msgID2)
info.ObserveDropCollection(message.MustAsImmutableDropCollectionMessageV1(immutableMsg2))
assert.Equal(t, streamingpb.VChannelState_VCHANNEL_STATE_DROPPED, info.meta.State)
assert.Equal(t, ts, info.meta.CheckpointTimeTick)
assert.Len(t, info.meta.CollectionInfo.Partitions, 2)
assert.True(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.NotNil(t, snapshot)
assert.True(t, shouldBeRemoved)
assert.False(t, info.dirty)
snapshot, shouldBeRemoved = info.ConsumeDirtyAndGetSnapshot()
assert.Nil(t, snapshot)
assert.True(t, shouldBeRemoved)
}

View File

@ -82,17 +82,7 @@ message FlushMessageBody {}
message ManualFlushMessageBody {}
// CreateSegmentMessageBody is the body of create segment message.
message CreateSegmentMessageBody {
int64 collection_id = 1;
repeated CreateSegmentInfo segments = 2;
}
// CreateSegmentInfo is the info of create segment.
message CreateSegmentInfo {
int64 partition_id = 1;
int64 segment_id = 2;
int64 storage_version = 3;
}
message CreateSegmentMessageBody {}
// BeginTxnMessageBody is the body of begin transaction message.
// Just do nothing now.
@ -152,18 +142,23 @@ message DeleteMessageHeader {
// FlushMessageHeader is the header of flush message.
message FlushMessageHeader {
int64 collection_id = 1;
repeated int64 segment_ids = 2;
int64 partition_id = 2;
int64 segment_id = 3;
}
// CreateSegmentMessageHeader is the header of create segment message.
message CreateSegmentMessageHeader {
int64 collection_id = 1;
repeated int64 segment_ids = 2;
int64 collection_id = 1;
int64 partition_id = 2;
int64 segment_id = 3;
int64 storage_version = 4; // the storage version of the segment.
uint64 max_segment_size = 5; // the max size of the segment in bytes.
}
message ManualFlushMessageHeader {
int64 collection_id = 1;
uint64 flush_ts = 2;
repeated int64 segment_ids = 3; // the segment ids to be flushed
}
// CreateCollectionMessageHeader is the header of create collection message.
@ -216,6 +211,7 @@ message ImportMessageHeader {}
// SchemaChangeMessageHeader is the header of CollectionSchema update message.
message SchemaChangeMessageHeader{
int64 collection_id = 1;
repeated int64 flushed_segment_ids = 2;
}
// SchemaChangeMessageBody is the body of CollectionSchema update message.
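
With these header changes a create segment or flush message describes exactly one segment, so the recovery code can treat each message as a single state transition. The following Go sketch builds a flush message against the new header, assuming the flush builder mirrors the create segment builder used in the tests below; the IDs are illustrative values.

// Sketch only: build a flush message with the new per-segment header.
func buildFlushMessage() (message.MutableMessage, error) {
    return message.NewFlushMessageBuilderV2().
        WithHeader(&message.FlushMessageHeader{
            CollectionId: 100,
            PartitionId:  101,
            SegmentId:    102,
        }).
        WithBody(&message.FlushMessageBody{}).
        WithVChannel("vchannel-1").
        BuildMutable()
}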

File diff suppressed because it is too large

View File

@ -487,6 +487,38 @@ message StreamingNodeManagerCollectStatusResponse {
StreamingNodeBalanceAttributes balance_attributes = 1;
}
///
/// VChannelMeta
///
// VChannelMeta is the meta information of a vchannel.
// The vchannel meta is stored in the wal meta so that the wal can recover this information.
// The vchannel meta is also used to store the vchannel operation result, such as shard-splitting.
message VChannelMeta {
string vchannel = 1; // vchannel name.
VChannelState state = 2; // vchannel state.
CollectionInfoOfVChannel collection_info = 3; // if the channel belongs to a collection, the collection info will be set up.
uint64 checkpoint_time_tick = 4; // The time tick of the checkpoint; the meta has already seen the messages up to this time tick.
}
// CollectionInfoOfVChannel is the collection info in vchannel.
message CollectionInfoOfVChannel {
int64 collection_id = 1; // collection id.
repeated PartitionInfoOfVChannel partitions = 2; // partitions.
}
// PartitionInfoOfVChannel is the partition info in vchannel.
message PartitionInfoOfVChannel {
int64 partition_id = 1; // partition id.
}
// VChannelState is the state of vchannel
enum VChannelState {
VCHANNEL_STATE_UNKNOWN = 0; // should never be used.
VCHANNEL_STATE_NORMAL = 1; // vchannel is normal.
VCHANNEL_STATE_DROPPED = 2; // vchannel is dropped.
// VCHANNEL_STATE_SPLITTED = 3; // TODO: vchannel is split into other vchannels, used to support shard-splitting.
}
///
/// SegmentAssignment
///
@ -503,6 +535,7 @@ message SegmentAssignmentMeta {
SegmentAssignmentState state = 5;
SegmentAssignmentStat stat = 6;
int64 storage_version = 7;
uint64 checkpoint_time_tick = 8; // The time tick of the checkpoint; the meta has already seen the messages up to this time tick.
}
// SegmentAssignmentState is the state of segment assignment.
@ -522,13 +555,20 @@ message SegmentAssignmentStat {
uint64 max_binary_size = 1;
uint64 inserted_rows = 2;
uint64 inserted_binary_size = 3;
int64 create_timestamp_nanoseconds = 4;
int64 last_modified_timestamp_nanoseconds = 5;
int64 create_timestamp = 4;
int64 last_modified_timestamp = 5;
uint64 binlog_counter = 6;
uint64 create_segment_time_tick = 7; // The time tick of the create segment message in the wal.
}
// WALCheckpoint is used to recover the wal scanner.
message WALCheckpoint {
messages.MessageID messageID = 1;
messages.MessageID message_id = 1; // Recovery replays the wal from here to rebuild all uncommitted info,
// e.g., primary key index, segment assignment info, vchannel info...
// The data path flush into object storage is slow and managed by the coordinator, so this checkpoint does not apply to it;
// the checkpoint exists so that the wal state can be recovered from the log quickly.
uint64 time_tick = 2; // The time tick of the checkpoint, kept consistent with message_id.
// It's a hint for easier debugging.
int64 recovery_magic = 3; // The recovery version of the checkpoint, used to hint future upgrades of the recovery info.
}
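
The checkpoint pairs a message id with a time tick: the scanner is repositioned at message_id, and the recovery info only re-applies messages that the persisted metas have not seen yet. A simplified Go view of that filter is sketched below; it mirrors the per-meta checks in the recovery code but is not itself part of this patch.

// Sketch only: whether a replayed message still has to be applied to the
// recovery info, given the persisted checkpoint.
func needsReplay(checkpoint *streamingpb.WALCheckpoint, msg message.ImmutableMessage) bool {
    // Messages below the checkpoint time tick were already reflected in the
    // persisted recovery info before the checkpoint was written.
    return msg.TimeTick() >= checkpoint.GetTimeTick()
}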

File diff suppressed because it is too large

View File

@ -79,10 +79,13 @@ func marshalSpecializedHeader(t MessageType, h string, enc zapcore.ObjectEncoder
case *InsertMessageHeader:
enc.AddInt64("collectionID", header.GetCollectionId())
segmentIDs := make([]string, 0, len(header.GetPartitions()))
rows := make([]string, 0)
for _, partition := range header.GetPartitions() {
segmentIDs = append(segmentIDs, strconv.FormatInt(partition.GetSegmentAssignment().GetSegmentId(), 10))
rows = append(rows, strconv.FormatUint(partition.Rows, 10))
}
enc.AddString("segmentIDs", strings.Join(segmentIDs, "|"))
enc.AddString("rows", strings.Join(rows, "|"))
case *DeleteMessageHeader:
enc.AddInt64("collectionID", header.GetCollectionId())
case *CreateCollectionMessageHeader:
@ -97,21 +100,22 @@ func marshalSpecializedHeader(t MessageType, h string, enc zapcore.ObjectEncoder
enc.AddInt64("partitionID", header.GetPartitionId())
case *CreateSegmentMessageHeader:
enc.AddInt64("collectionID", header.GetCollectionId())
segmentIDs := make([]string, 0, len(header.GetSegmentIds()))
for _, segmentID := range header.GetSegmentIds() {
segmentIDs = append(segmentIDs, strconv.FormatInt(segmentID, 10))
}
enc.AddString("segmentIDs", strings.Join(segmentIDs, "|"))
enc.AddInt64("segmentID", header.GetSegmentId())
case *FlushMessageHeader:
enc.AddInt64("collectionID", header.GetCollectionId())
segmentIDs := make([]string, 0, len(header.GetSegmentIds()))
for _, segmentID := range header.GetSegmentIds() {
segmentIDs = append(segmentIDs, strconv.FormatInt(segmentID, 10))
}
enc.AddString("segmentIDs", strings.Join(segmentIDs, "|"))
enc.AddInt64("segmentID", header.GetSegmentId())
case *ManualFlushMessageHeader:
enc.AddInt64("collectionID", header.GetCollectionId())
encodeSegmentIDs(header.GetSegmentIds(), enc)
case *SchemaChangeMessageHeader:
case *ImportMessageHeader:
}
}
func encodeSegmentIDs(segmentIDs []int64, enc zapcore.ObjectEncoder) {
ids := make([]string, 0, len(segmentIDs))
for _, id := range segmentIDs {
ids = append(ids, strconv.FormatInt(id, 10))
}
enc.AddString("segmentIDs", strings.Join(ids, "|"))
}

View File

@ -12,7 +12,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
)
func CreateTestInsertMessage(t *testing.T, segmentID int64, totalRows int, timetick uint64, messageID MessageID) MutableMessage {
@ -171,17 +170,15 @@ func CreateTestCreateCollectionMessage(t *testing.T, collectionID int64, timetic
}
func CreateTestCreateSegmentMessage(t *testing.T, collectionID int64, timetick uint64, messageID MessageID) MutableMessage {
payload := &CreateSegmentMessageBody{
CollectionId: collectionID,
Segments: []*messagespb.CreateSegmentInfo{
{
PartitionId: 1,
SegmentId: 1,
},
},
}
payload := &CreateSegmentMessageBody{}
msg, err := NewCreateSegmentMessageBuilderV2().
WithHeader(&CreateSegmentMessageHeader{}).
WithHeader(&CreateSegmentMessageHeader{
CollectionId: collectionID,
PartitionId: 1,
SegmentId: 1,
StorageVersion: 1,
MaxSegmentSize: 1024,
}).
WithBody(payload).
WithVChannel("v1").
BuildMutable()

View File

@ -5392,6 +5392,11 @@ type streamingConfig struct {
// logging
LoggingAppendSlowThreshold ParamItem `refreshable:"true"`
// recovery configuration.
WALRecoveryPersistInterval ParamItem `refreshable:"true"`
WALRecoveryMaxDirtyMessage ParamItem `refreshable:"true"`
WALRecoveryGracefulCloseTimeout ParamItem `refreshable:"true"`
}
func (p *streamingConfig) init(base *BaseTable) {
@ -5529,6 +5534,39 @@ If the wal implementation is woodpecker, the minimum threshold is 3s`,
Export: true,
}
p.LoggingAppendSlowThreshold.Init(base.mgr)
p.WALRecoveryPersistInterval = ParamItem{
Key: "streaming.walRecovery.persistInterval",
Version: "2.6.0",
Doc: `The interval of persisting recovery info, 10s by default.
At every interval, the recovery info of the wal is persisted and the wal checkpoint can be advanced.
Currently it only affects the recovery of the wal, not the recovery of data flushed into object storage`,
DefaultValue: "10s",
Export: true,
}
p.WALRecoveryPersistInterval.Init(base.mgr)
p.WALRecoveryMaxDirtyMessage = ParamItem{
Key: "streaming.walRecovery.maxDirtyMessage",
Version: "2.6.0",
Doc: `The max dirty message count of wal recovery, 100 by default.
If the wal recovery info accumulates more dirty messages than this count, it is persisted immediately
instead of waiting for the persist interval.`,
DefaultValue: "100",
Export: true,
}
p.WALRecoveryMaxDirtyMessage.Init(base.mgr)
p.WALRecoveryGracefulCloseTimeout = ParamItem{
Key: "streaming.walRecovery.gracefulCloseTimeout",
Version: "2.6.0",
Doc: `The graceful close timeout for wal recovery, 3s by default.
When the wal is closing, the recovery module tries to persist the recovery info so that the next recovery is faster.
If that persist operation exceeds this timeout, the wal recovery module closes immediately.`,
DefaultValue: "3s",
Export: true,
}
p.WALRecoveryGracefulCloseTimeout.Init(base.mgr)
}
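
Together these three walRecovery parameters shape a persist loop: flush the recovery info on a timer, flush early when too many dirty messages accumulate, and bound the final flush on shutdown. A minimal sketch of such a loop is shown below; it only relies on the parameter accessors defined above and reads them through paramtable.Get() as a consumer would, while persistDirty, the dirtyCount channel, and the surrounding imports (context, time) are illustrative.

// Sketch only: a persist loop driven by the walRecovery parameters.
// persistDirty is a hypothetical callback that writes dirty recovery info.
func runPersistLoop(ctx context.Context, dirtyCount <-chan int, persistDirty func(context.Context) error) {
    params := &paramtable.Get().StreamingCfg
    ticker := time.NewTicker(params.WALRecoveryPersistInterval.GetAsDurationByParse())
    defer ticker.Stop()
    for {
        select {
        case <-ticker.C:
            _ = persistDirty(ctx) // periodic persist advances the wal checkpoint
        case n := <-dirtyCount:
            if n >= params.WALRecoveryMaxDirtyMessage.GetAsInt() {
                _ = persistDirty(ctx) // too many dirty messages: persist immediately
            }
        case <-ctx.Done():
            // Graceful close: one last persist bounded by the close timeout.
            closeCtx, cancel := context.WithTimeout(context.Background(),
                params.WALRecoveryGracefulCloseTimeout.GetAsDurationByParse())
            _ = persistDirty(closeCtx)
            cancel()
            return
        }
    }
}
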
// runtimeConfig is just a private environment value table.

View File

@ -625,6 +625,10 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, 30*time.Second, params.StreamingCfg.WALWriteAheadBufferKeepalive.GetAsDurationByParse())
assert.Equal(t, int64(64*1024*1024), params.StreamingCfg.WALWriteAheadBufferCapacity.GetAsSize())
assert.Equal(t, 1*time.Second, params.StreamingCfg.LoggingAppendSlowThreshold.GetAsDurationByParse())
assert.Equal(t, 3*time.Second, params.StreamingCfg.WALRecoveryGracefulCloseTimeout.GetAsDurationByParse())
assert.Equal(t, 100, params.StreamingCfg.WALRecoveryMaxDirtyMessage.GetAsInt())
assert.Equal(t, 10*time.Second, params.StreamingCfg.WALRecoveryPersistInterval.GetAsDurationByParse())
params.Save(params.StreamingCfg.WALBalancerTriggerInterval.Key, "50s")
params.Save(params.StreamingCfg.WALBalancerBackoffInitialInterval.Key, "50s")
params.Save(params.StreamingCfg.WALBalancerBackoffMultiplier.Key, "3.5")
@ -639,6 +643,9 @@ func TestComponentParam(t *testing.T) {
params.Save(params.StreamingCfg.WALBalancerPolicyVChannelFairRebalanceTolerance.Key, "0.02")
params.Save(params.StreamingCfg.WALBalancerPolicyVChannelFairRebalanceMaxStep.Key, "4")
params.Save(params.StreamingCfg.LoggingAppendSlowThreshold.Key, "3s")
params.Save(params.StreamingCfg.WALRecoveryGracefulCloseTimeout.Key, "4s")
params.Save(params.StreamingCfg.WALRecoveryMaxDirtyMessage.Key, "200")
params.Save(params.StreamingCfg.WALRecoveryPersistInterval.Key, "20s")
assert.Equal(t, 50*time.Second, params.StreamingCfg.WALBalancerTriggerInterval.GetAsDurationByParse())
assert.Equal(t, 50*time.Second, params.StreamingCfg.WALBalancerBackoffInitialInterval.GetAsDurationByParse())
assert.Equal(t, 3.5, params.StreamingCfg.WALBalancerBackoffMultiplier.GetAsFloat())
@ -653,6 +660,9 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, 10*time.Second, params.StreamingCfg.WALWriteAheadBufferKeepalive.GetAsDurationByParse())
assert.Equal(t, int64(128*1024), params.StreamingCfg.WALWriteAheadBufferCapacity.GetAsSize())
assert.Equal(t, 3*time.Second, params.StreamingCfg.LoggingAppendSlowThreshold.GetAsDurationByParse())
assert.Equal(t, 4*time.Second, params.StreamingCfg.WALRecoveryGracefulCloseTimeout.GetAsDurationByParse())
assert.Equal(t, 200, params.StreamingCfg.WALRecoveryMaxDirtyMessage.GetAsInt())
assert.Equal(t, 20*time.Second, params.StreamingCfg.WALRecoveryPersistInterval.GetAsDurationByParse())
})
t.Run("channel config priority", func(t *testing.T) {