diff --git a/internal/.mockery.yaml b/internal/.mockery.yaml index 2179959e87..8804105c23 100644 --- a/internal/.mockery.yaml +++ b/internal/.mockery.yaml @@ -12,12 +12,16 @@ packages: github.com/milvus-io/milvus/internal/streamingcoord/server/balancer: interfaces: Balancer: - github.com/milvus-io/milvus/internal/streamingnode/client/manager: + github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster: interfaces: - ManagerClient: + AppendOperator: github.com/milvus-io/milvus/internal/streamingcoord/client: interfaces: Client: + BroadcastService: + github.com/milvus-io/milvus/internal/streamingnode/client/manager: + interfaces: + ManagerClient: github.com/milvus-io/milvus/internal/streamingnode/client/handler: interfaces: HandlerClient: @@ -46,10 +50,10 @@ packages: InterceptorWithReady: InterceptorBuilder: github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector: - interfaces: + interfaces: SealOperator: github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector: - interfaces: + interfaces: TimeTickSyncOperator: google.golang.org/grpc: interfaces: diff --git a/internal/distributed/streaming/append.go b/internal/distributed/streaming/append.go index 2fd0820e54..b4193d8b94 100644 --- a/internal/distributed/streaming/append.go +++ b/internal/distributed/streaming/append.go @@ -17,6 +17,12 @@ func (w *walAccesserImpl) appendToWAL(ctx context.Context, msg message.MutableMe return p.Produce(ctx, msg) } +func (w *walAccesserImpl) broadcastToWAL(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + // The broadcast operation will be sent to the coordinator. + // The coordinator will dispatch the message to all the vchannels with an eventual consistency guarantee. + return w.streamingCoordClient.Broadcast().Broadcast(ctx, msg) +} + // createOrGetProducer creates or get a producer. // vchannel in same pchannel can share the same producer. func (w *walAccesserImpl) getProducer(pchannel string) *producer.ResumableProducer { @@ -40,14 +46,19 @@ func assertValidMessage(msgs ...message.MutableMessage) { if msg.MessageType().IsSystem() { panic("system message is not allowed to append from client") } - } - for _, msg := range msgs { if msg.VChannel() == "" { - panic("vchannel is empty") + panic("we don't support sent all vchannel message at client now") } } } +// assertValidBroadcastMessage asserts the message is not system message. +func assertValidBroadcastMessage(msg message.BroadcastMutableMessage) { + if msg.MessageType().IsSystem() { + panic("system message is not allowed to broadcast append from client") + } +} + // We only support delete and insert message for txn now. func assertIsDmlMessage(msgs ...message.MutableMessage) { for _, msg := range msgs { diff --git a/internal/distributed/streaming/streaming.go b/internal/distributed/streaming/streaming.go index 810b15065d..efd77d5f2a 100644 --- a/internal/distributed/streaming/streaming.go +++ b/internal/distributed/streaming/streaming.go @@ -78,6 +78,7 @@ type Scanner interface { // WALAccesser is the interfaces to interact with the milvus write ahead log. type WALAccesser interface { + // WALName returns the name of the wal. WALName() string // Txn returns a transaction for writing records to the log. @@ -87,6 +88,10 @@ type WALAccesser interface { // RawAppend writes a records to the log. RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) + // BroadcastAppend sends a broadcast message to all target vchannels. + // BroadcastAppend guarantees the atomicity written of the messages and eventual consistency. + BroadcastAppend(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) + // Read returns a scanner for reading records from the wal. Read(ctx context.Context, opts ReadOption) Scanner diff --git a/internal/distributed/streaming/streaming_test.go b/internal/distributed/streaming/streaming_test.go index c24f652616..e44e18e7c2 100644 --- a/internal/distributed/streaming/streaming_test.go +++ b/internal/distributed/streaming/streaming_test.go @@ -14,7 +14,10 @@ import ( "github.com/milvus-io/milvus/pkg/util/paramtable" ) -const vChannel = "by-dev-rootcoord-dml_4" +var vChannels = []string{ + "by-dev-rootcoord-dml_4", + "by-dev-rootcoord-dml_5", +} func TestMain(m *testing.M) { paramtable.Init() @@ -33,10 +36,11 @@ func TestStreamingProduce(t *testing.T) { WithBody(&msgpb.CreateCollectionRequest{ CollectionID: 1, }). - WithVChannel(vChannel). - BuildMutable() - resp, err := streaming.WAL().RawAppend(context.Background(), msg) - fmt.Printf("%+v\t%+v\n", resp, err) + WithBroadcast(vChannels). + BuildBroadcast() + + resp, err := streaming.WAL().BroadcastAppend(context.Background(), msg) + t.Logf("CreateCollection: %+v\t%+v\n", resp, err) for i := 0; i < 500; i++ { time.Sleep(time.Millisecond * 1) @@ -47,17 +51,17 @@ func TestStreamingProduce(t *testing.T) { WithBody(&msgpb.InsertRequest{ CollectionID: 1, }). - WithVChannel(vChannel). + WithVChannel(vChannels[0]). BuildMutable() resp, err := streaming.WAL().RawAppend(context.Background(), msg) - fmt.Printf("%+v\t%+v\n", resp, err) + t.Logf("Insert: %+v\t%+v\n", resp, err) } for i := 0; i < 500; i++ { time.Sleep(time.Millisecond * 1) txn, err := streaming.WAL().Txn(context.Background(), streaming.TxnOption{ - VChannel: vChannel, - Keepalive: 100 * time.Millisecond, + VChannel: vChannels[0], + Keepalive: 500 * time.Millisecond, }) if err != nil { t.Errorf("txn failed: %v", err) @@ -71,7 +75,7 @@ func TestStreamingProduce(t *testing.T) { WithBody(&msgpb.InsertRequest{ CollectionID: 1, }). - WithVChannel(vChannel). + WithVChannel(vChannels[0]). BuildMutable() err := txn.Append(context.Background(), msg) fmt.Printf("%+v\n", err) @@ -80,7 +84,7 @@ func TestStreamingProduce(t *testing.T) { if err != nil { t.Errorf("txn failed: %v", err) } - fmt.Printf("%+v\n", result) + t.Logf("txn commit: %+v\n", result) } msg, _ = message.NewDropCollectionMessageBuilderV1(). @@ -90,10 +94,10 @@ func TestStreamingProduce(t *testing.T) { WithBody(&msgpb.DropCollectionRequest{ CollectionID: 1, }). - WithVChannel(vChannel). - BuildMutable() - resp, err = streaming.WAL().RawAppend(context.Background(), msg) - fmt.Printf("%+v\t%+v\n", resp, err) + WithBroadcast(vChannels). + BuildBroadcast() + resp, err = streaming.WAL().BroadcastAppend(context.Background(), msg) + t.Logf("DropCollection: %+v\t%+v\n", resp, err) } func TestStreamingConsume(t *testing.T) { @@ -102,7 +106,7 @@ func TestStreamingConsume(t *testing.T) { defer streaming.Release() ch := make(message.ChanMessageHandler, 10) s := streaming.WAL().Read(context.Background(), streaming.ReadOption{ - VChannel: vChannel, + VChannel: vChannels[0], DeliverPolicy: options.DeliverPolicyAll(), MessageHandler: ch, }) @@ -115,7 +119,7 @@ func TestStreamingConsume(t *testing.T) { time.Sleep(10 * time.Millisecond) select { case msg := <-ch: - fmt.Printf("msgID=%+v, msgType=%+v, tt=%d, lca=%+v, body=%s, idx=%d\n", + t.Logf("msgID=%+v, msgType=%+v, tt=%d, lca=%+v, body=%s, idx=%d\n", msg.MessageID(), msg.MessageType(), msg.TimeTick(), diff --git a/internal/distributed/streaming/wal.go b/internal/distributed/streaming/wal.go index b4d7fb5f90..f721f2d63b 100644 --- a/internal/distributed/streaming/wal.go +++ b/internal/distributed/streaming/wal.go @@ -29,11 +29,11 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl { // Create a new streamingnode handler client. handlerClient := handler.NewHandlerClient(streamingCoordClient.Assignment()) return &walAccesserImpl{ - lifetime: typeutil.NewLifetime(), - streamingCoordAssignmentClient: streamingCoordClient, - handlerClient: handlerClient, - producerMutex: sync.Mutex{}, - producers: make(map[string]*producer.ResumableProducer), + lifetime: typeutil.NewLifetime(), + streamingCoordClient: streamingCoordClient, + handlerClient: handlerClient, + producerMutex: sync.Mutex{}, + producers: make(map[string]*producer.ResumableProducer), // TODO: optimize the pool size, use the streaming api but not goroutines. appendExecutionPool: conc.NewPool[struct{}](10), @@ -46,8 +46,8 @@ type walAccesserImpl struct { lifetime *typeutil.Lifetime // All services - streamingCoordAssignmentClient client.Client - handlerClient handler.HandlerClient + streamingCoordClient client.Client + handlerClient handler.HandlerClient producerMutex sync.Mutex producers map[string]*producer.ResumableProducer @@ -71,6 +71,16 @@ func (w *walAccesserImpl) RawAppend(ctx context.Context, msg message.MutableMess return w.appendToWAL(ctx, msg) } +func (w *walAccesserImpl) BroadcastAppend(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + assertValidBroadcastMessage(msg) + if !w.lifetime.Add(typeutil.LifetimeStateWorking) { + return nil, ErrWALAccesserClosed + } + defer w.lifetime.Done() + + return w.broadcastToWAL(ctx, msg) +} + // Read returns a scanner for reading records from the wal. func (w *walAccesserImpl) Read(_ context.Context, opts ReadOption) Scanner { if !w.lifetime.Add(typeutil.LifetimeStateWorking) { @@ -149,7 +159,7 @@ func (w *walAccesserImpl) Close() { w.producerMutex.Unlock() w.handlerClient.Close() - w.streamingCoordAssignmentClient.Close() + w.streamingCoordClient.Close() } // newErrScanner creates a scanner that returns an error. diff --git a/internal/distributed/streaming/wal_test.go b/internal/distributed/streaming/wal_test.go index db527c044e..a850b9cce3 100644 --- a/internal/distributed/streaming/wal_test.go +++ b/internal/distributed/streaming/wal_test.go @@ -30,19 +30,33 @@ const ( func TestWAL(t *testing.T) { coordClient := mock_client.NewMockClient(t) coordClient.EXPECT().Close().Return() + broadcastServce := mock_client.NewMockBroadcastService(t) + broadcastServce.EXPECT().Broadcast(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, bmm message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + result := make(map[string]*types.AppendResult) + for idx, msg := range bmm.SplitIntoMutableMessage() { + result[msg.VChannel()] = &types.AppendResult{ + MessageID: walimplstest.NewTestMessageID(int64(idx)), + TimeTick: uint64(time.Now().UnixMilli()), + } + } + return &types.BroadcastAppendResult{ + AppendResults: result, + }, nil + }) + coordClient.EXPECT().Broadcast().Return(broadcastServce) handler := mock_handler.NewMockHandlerClient(t) handler.EXPECT().Close().Return() w := &walAccesserImpl{ - lifetime: typeutil.NewLifetime(), - streamingCoordAssignmentClient: coordClient, - handlerClient: handler, - producerMutex: sync.Mutex{}, - producers: make(map[string]*producer.ResumableProducer), - appendExecutionPool: conc.NewPool[struct{}](10), - dispatchExecutionPool: conc.NewPool[struct{}](10), + lifetime: typeutil.NewLifetime(), + streamingCoordClient: coordClient, + handlerClient: handler, + producerMutex: sync.Mutex{}, + producers: make(map[string]*producer.ResumableProducer), + appendExecutionPool: conc.NewPool[struct{}](10), + dispatchExecutionPool: conc.NewPool[struct{}](10), } - defer w.Close() ctx := context.Background() @@ -114,6 +128,18 @@ func TestWAL(t *testing.T) { newInsertMessage(vChannel3), ) assert.NoError(t, resp.UnwrapFirstError()) + + r, err := w.BroadcastAppend(ctx, newBroadcastMessage([]string{vChannel1, vChannel2, vChannel3})) + assert.NoError(t, err) + assert.Len(t, r.AppendResults, 3) + + w.Close() + + resp = w.AppendMessages(ctx, newInsertMessage(vChannel1)) + assert.Error(t, resp.UnwrapFirstError()) + r, err = w.BroadcastAppend(ctx, newBroadcastMessage([]string{vChannel1, vChannel2, vChannel3})) + assert.Error(t, err) + assert.Nil(t, r) } func newInsertMessage(vChannel string) message.MutableMessage { @@ -127,3 +153,15 @@ func newInsertMessage(vChannel string) message.MutableMessage { } return msg } + +func newBroadcastMessage(vchannels []string) message.BroadcastMutableMessage { + msg, err := message.NewDropCollectionMessageBuilderV1(). + WithBroadcast(vchannels). + WithHeader(&message.DropCollectionMessageHeader{}). + WithBody(&msgpb.DropCollectionRequest{}). + BuildBroadcast() + if err != nil { + panic(err) + } + return msg +} diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go index 090296d11b..c7a2042dd7 100644 --- a/internal/metastore/catalog.go +++ b/internal/metastore/catalog.go @@ -210,6 +210,15 @@ type StreamingCoordCataLog interface { // SavePChannel save a pchannel info to metastore. SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error + + // ListBroadcastTask list all broadcast tasks. + // Used to recovery the broadcast tasks. + ListBroadcastTask(ctx context.Context) ([]*streamingpb.BroadcastTask, error) + + // SaveBroadcastTask save the broadcast task to metastore. + // Make the task recoverable after restart. + // When broadcast task is done, it will be removed from metastore. + SaveBroadcastTask(ctx context.Context, task *streamingpb.BroadcastTask) error } // StreamingNodeCataLog is the interface for streamingnode catalog diff --git a/internal/metastore/kv/streamingcoord/constant.go b/internal/metastore/kv/streamingcoord/constant.go index 5ae1f85b7d..1f92dc9977 100644 --- a/internal/metastore/kv/streamingcoord/constant.go +++ b/internal/metastore/kv/streamingcoord/constant.go @@ -1,6 +1,7 @@ package streamingcoord const ( - MetaPrefix = "streamingcoord-meta" - PChannelMeta = MetaPrefix + "/pchannel" + MetaPrefix = "streamingcoord-meta/" + PChannelMetaPrefix = MetaPrefix + "pchannel/" + BroadcastTaskPrefix = MetaPrefix + "broadcast-task/" ) diff --git a/internal/metastore/kv/streamingcoord/kv_catalog.go b/internal/metastore/kv/streamingcoord/kv_catalog.go index d3d804052e..c0a16a5251 100644 --- a/internal/metastore/kv/streamingcoord/kv_catalog.go +++ b/internal/metastore/kv/streamingcoord/kv_catalog.go @@ -2,6 +2,7 @@ package streamingcoord import ( "context" + "strconv" "github.com/cockroachdb/errors" "google.golang.org/protobuf/proto" @@ -14,6 +15,14 @@ import ( ) // NewCataLog creates a new catalog instance +// streamingcoord-meta +// ├── broadcast +// │   ├── task-1 +// │   └── task-2 +// └── pchannel +// +// ├── pchannel-1 +// └── pchannel-2 func NewCataLog(metaKV kv.MetaKv) metastore.StreamingCoordCataLog { return &catalog{ metaKV: metaKV, @@ -27,7 +36,7 @@ type catalog struct { // ListPChannels returns all pchannels func (c *catalog) ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { - keys, values, err := c.metaKV.LoadWithPrefix(ctx, PChannelMeta) + keys, values, err := c.metaKV.LoadWithPrefix(ctx, PChannelMetaPrefix) if err != nil { return nil, err } @@ -60,7 +69,41 @@ func (c *catalog) SavePChannels(ctx context.Context, infos []*streamingpb.PChann }) } +func (c *catalog) ListBroadcastTask(ctx context.Context) ([]*streamingpb.BroadcastTask, error) { + keys, values, err := c.metaKV.LoadWithPrefix(ctx, BroadcastTaskPrefix) + if err != nil { + return nil, err + } + infos := make([]*streamingpb.BroadcastTask, 0, len(values)) + for k, value := range values { + info := &streamingpb.BroadcastTask{} + err = proto.Unmarshal([]byte(value), info) + if err != nil { + return nil, errors.Wrapf(err, "unmarshal broadcast task %s failed", keys[k]) + } + infos = append(infos, info) + } + return infos, nil +} + +func (c *catalog) SaveBroadcastTask(ctx context.Context, task *streamingpb.BroadcastTask) error { + key := buildBroadcastTaskPath(task.TaskId) + if task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE { + return c.metaKV.Remove(ctx, key) + } + v, err := proto.Marshal(task) + if err != nil { + return errors.Wrapf(err, "marshal broadcast task failed") + } + return c.metaKV.Save(ctx, key, string(v)) +} + // buildPChannelInfoPath builds the path for pchannel info. func buildPChannelInfoPath(name string) string { - return PChannelMeta + "/" + name + return PChannelMetaPrefix + name +} + +// buildBroadcastTaskPath builds the path for broadcast task. +func buildBroadcastTaskPath(id int64) string { + return BroadcastTaskPrefix + strconv.FormatInt(id, 10) } diff --git a/internal/metastore/kv/streamingcoord/kv_catalog_test.go b/internal/metastore/kv/streamingcoord/kv_catalog_test.go index 227ad0469b..215aee3d15 100644 --- a/internal/metastore/kv/streamingcoord/kv_catalog_test.go +++ b/internal/metastore/kv/streamingcoord/kv_catalog_test.go @@ -2,6 +2,7 @@ package streamingcoord import ( "context" + "strings" "testing" "github.com/cockroachdb/errors" @@ -20,8 +21,10 @@ func TestCatalog(t *testing.T) { keys := make([]string, 0, len(kvStorage)) vals := make([]string, 0, len(kvStorage)) for k, v := range kvStorage { - keys = append(keys, k) - vals = append(vals, v) + if strings.HasPrefix(k, s) { + keys = append(keys, k) + vals = append(vals, v) + } } return keys, vals, nil }) @@ -31,12 +34,21 @@ func TestCatalog(t *testing.T) { } return nil }) + kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, key, value string) error { + kvStorage[key] = value + return nil + }) + kv.EXPECT().Remove(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, key string) error { + delete(kvStorage, key) + return nil + }) catalog := NewCataLog(kv) metas, err := catalog.ListPChannel(context.Background()) assert.NoError(t, err) assert.Empty(t, metas) + // PChannel test err = catalog.SavePChannels(context.Background(), []*streamingpb.PChannelMeta{ { Channel: &streamingpb.PChannelInfo{Name: "test", Term: 1}, @@ -53,6 +65,37 @@ func TestCatalog(t *testing.T) { assert.NoError(t, err) assert.Len(t, metas, 2) + // BroadcastTask test + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{ + TaskId: 1, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, + }) + assert.NoError(t, err) + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{ + TaskId: 2, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, + }) + assert.NoError(t, err) + + tasks, err := catalog.ListBroadcastTask(context.Background()) + assert.NoError(t, err) + assert.Len(t, tasks, 2) + for _, task := range tasks { + assert.Equal(t, streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, task.State) + } + + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{ + TaskId: 1, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE, + }) + assert.NoError(t, err) + tasks, err = catalog.ListBroadcastTask(context.Background()) + assert.NoError(t, err) + assert.Len(t, tasks, 1) + for _, task := range tasks { + assert.Equal(t, streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, task.State) + } + // error path. kv.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Unset() kv.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Return(nil, nil, errors.New("load error")) @@ -60,7 +103,19 @@ func TestCatalog(t *testing.T) { assert.Error(t, err) assert.Nil(t, metas) + tasks, err = catalog.ListBroadcastTask(context.Background()) + assert.Error(t, err) + assert.Nil(t, tasks) + kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Unset() kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Return(errors.New("save error")) + kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Unset() + kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("save error")) + err = catalog.SavePChannels(context.Background(), []*streamingpb.PChannelMeta{{ + Channel: &streamingpb.PChannelInfo{Name: "test", Term: 1}, + Node: &streamingpb.StreamingNodeInfo{ServerId: 1}, + }}) + assert.Error(t, err) + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{}) assert.Error(t, err) } diff --git a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go index e077e04030..eb9f7ce2d8 100644 --- a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go +++ b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go @@ -149,6 +149,65 @@ func (_c *MockWALAccesser_AppendMessagesWithOption_Call) RunAndReturn(run func(c return _c } +// BroadcastAppend provides a mock function with given fields: ctx, msg +func (_m *MockWALAccesser) BroadcastAppend(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + ret := _m.Called(ctx, msg) + + if len(ret) == 0 { + panic("no return value specified for BroadcastAppend") + } + + var r0 *types.BroadcastAppendResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok { + r0 = rf(ctx, msg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.BroadcastAppendResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockWALAccesser_BroadcastAppend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BroadcastAppend' +type MockWALAccesser_BroadcastAppend_Call struct { + *mock.Call +} + +// BroadcastAppend is a helper method to define mock.On call +// - ctx context.Context +// - msg message.BroadcastMutableMessage +func (_e *MockWALAccesser_Expecter) BroadcastAppend(ctx interface{}, msg interface{}) *MockWALAccesser_BroadcastAppend_Call { + return &MockWALAccesser_BroadcastAppend_Call{Call: _e.mock.On("BroadcastAppend", ctx, msg)} +} + +func (_c *MockWALAccesser_BroadcastAppend_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockWALAccesser_BroadcastAppend_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage)) + }) + return _c +} + +func (_c *MockWALAccesser_BroadcastAppend_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockWALAccesser_BroadcastAppend_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockWALAccesser_BroadcastAppend_Call) RunAndReturn(run func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockWALAccesser_BroadcastAppend_Call { + _c.Call.Return(run) + return _c +} + // RawAppend provides a mock function with given fields: ctx, msgs, opts func (_m *MockWALAccesser) RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption) (*types.AppendResult, error) { _va := make([]interface{}, len(opts)) diff --git a/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go index b0bc3b7775..651554d48b 100644 --- a/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go +++ b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go @@ -23,6 +23,64 @@ func (_m *MockStreamingCoordCataLog) EXPECT() *MockStreamingCoordCataLog_Expecte return &MockStreamingCoordCataLog_Expecter{mock: &_m.Mock} } +// ListBroadcastTask provides a mock function with given fields: ctx +func (_m *MockStreamingCoordCataLog) ListBroadcastTask(ctx context.Context) ([]*streamingpb.BroadcastTask, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListBroadcastTask") + } + + var r0 []*streamingpb.BroadcastTask + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*streamingpb.BroadcastTask, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*streamingpb.BroadcastTask); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*streamingpb.BroadcastTask) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamingCoordCataLog_ListBroadcastTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListBroadcastTask' +type MockStreamingCoordCataLog_ListBroadcastTask_Call struct { + *mock.Call +} + +// ListBroadcastTask is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockStreamingCoordCataLog_Expecter) ListBroadcastTask(ctx interface{}) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + return &MockStreamingCoordCataLog_ListBroadcastTask_Call{Call: _e.mock.On("ListBroadcastTask", ctx)} +} + +func (_c *MockStreamingCoordCataLog_ListBroadcastTask_Call) Run(run func(ctx context.Context)) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockStreamingCoordCataLog_ListBroadcastTask_Call) Return(_a0 []*streamingpb.BroadcastTask, _a1 error) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamingCoordCataLog_ListBroadcastTask_Call) RunAndReturn(run func(context.Context) ([]*streamingpb.BroadcastTask, error)) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + _c.Call.Return(run) + return _c +} + // ListPChannel provides a mock function with given fields: ctx func (_m *MockStreamingCoordCataLog) ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { ret := _m.Called(ctx) @@ -81,6 +139,53 @@ func (_c *MockStreamingCoordCataLog_ListPChannel_Call) RunAndReturn(run func(con return _c } +// SaveBroadcastTask provides a mock function with given fields: ctx, task +func (_m *MockStreamingCoordCataLog) SaveBroadcastTask(ctx context.Context, task *streamingpb.BroadcastTask) error { + ret := _m.Called(ctx, task) + + if len(ret) == 0 { + panic("no return value specified for SaveBroadcastTask") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.BroadcastTask) error); ok { + r0 = rf(ctx, task) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordCataLog_SaveBroadcastTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveBroadcastTask' +type MockStreamingCoordCataLog_SaveBroadcastTask_Call struct { + *mock.Call +} + +// SaveBroadcastTask is a helper method to define mock.On call +// - ctx context.Context +// - task *streamingpb.BroadcastTask +func (_e *MockStreamingCoordCataLog_Expecter) SaveBroadcastTask(ctx interface{}, task interface{}) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + return &MockStreamingCoordCataLog_SaveBroadcastTask_Call{Call: _e.mock.On("SaveBroadcastTask", ctx, task)} +} + +func (_c *MockStreamingCoordCataLog_SaveBroadcastTask_Call) Run(run func(ctx context.Context, task *streamingpb.BroadcastTask)) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*streamingpb.BroadcastTask)) + }) + return _c +} + +func (_c *MockStreamingCoordCataLog_SaveBroadcastTask_Call) Return(_a0 error) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordCataLog_SaveBroadcastTask_Call) RunAndReturn(run func(context.Context, *streamingpb.BroadcastTask) error) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + _c.Call.Return(run) + return _c +} + // SavePChannels provides a mock function with given fields: ctx, info func (_m *MockStreamingCoordCataLog) SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error { ret := _m.Called(ctx, info) diff --git a/internal/mocks/streamingcoord/mock_client/mock_BroadcastService.go b/internal/mocks/streamingcoord/mock_client/mock_BroadcastService.go new file mode 100644 index 0000000000..3c84e0cce1 --- /dev/null +++ b/internal/mocks/streamingcoord/mock_client/mock_BroadcastService.go @@ -0,0 +1,98 @@ +// Code generated by mockery v2.46.0. DO NOT EDIT. + +package mock_client + +import ( + context "context" + + message "github.com/milvus-io/milvus/pkg/streaming/util/message" + mock "github.com/stretchr/testify/mock" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// MockBroadcastService is an autogenerated mock type for the BroadcastService type +type MockBroadcastService struct { + mock.Mock +} + +type MockBroadcastService_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBroadcastService) EXPECT() *MockBroadcastService_Expecter { + return &MockBroadcastService_Expecter{mock: &_m.Mock} +} + +// Broadcast provides a mock function with given fields: ctx, msg +func (_m *MockBroadcastService) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + ret := _m.Called(ctx, msg) + + if len(ret) == 0 { + panic("no return value specified for Broadcast") + } + + var r0 *types.BroadcastAppendResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok { + r0 = rf(ctx, msg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.BroadcastAppendResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroadcastService_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast' +type MockBroadcastService_Broadcast_Call struct { + *mock.Call +} + +// Broadcast is a helper method to define mock.On call +// - ctx context.Context +// - msg message.BroadcastMutableMessage +func (_e *MockBroadcastService_Expecter) Broadcast(ctx interface{}, msg interface{}) *MockBroadcastService_Broadcast_Call { + return &MockBroadcastService_Broadcast_Call{Call: _e.mock.On("Broadcast", ctx, msg)} +} + +func (_c *MockBroadcastService_Broadcast_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockBroadcastService_Broadcast_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage)) + }) + return _c +} + +func (_c *MockBroadcastService_Broadcast_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockBroadcastService_Broadcast_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroadcastService_Broadcast_Call) RunAndReturn(run func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockBroadcastService_Broadcast_Call { + _c.Call.Return(run) + return _c +} + +// NewMockBroadcastService creates a new instance of MockBroadcastService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockBroadcastService(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBroadcastService { + mock := &MockBroadcastService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingcoord/mock_client/mock_Client.go b/internal/mocks/streamingcoord/mock_client/mock_Client.go index 02923644d2..574e08d015 100644 --- a/internal/mocks/streamingcoord/mock_client/mock_Client.go +++ b/internal/mocks/streamingcoord/mock_client/mock_Client.go @@ -67,6 +67,53 @@ func (_c *MockClient_Assignment_Call) RunAndReturn(run func() client.AssignmentS return _c } +// Broadcast provides a mock function with given fields: +func (_m *MockClient) Broadcast() client.BroadcastService { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Broadcast") + } + + var r0 client.BroadcastService + if rf, ok := ret.Get(0).(func() client.BroadcastService); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.BroadcastService) + } + } + + return r0 +} + +// MockClient_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast' +type MockClient_Broadcast_Call struct { + *mock.Call +} + +// Broadcast is a helper method to define mock.On call +func (_e *MockClient_Expecter) Broadcast() *MockClient_Broadcast_Call { + return &MockClient_Broadcast_Call{Call: _e.mock.On("Broadcast")} +} + +func (_c *MockClient_Broadcast_Call) Run(run func()) *MockClient_Broadcast_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Broadcast_Call) Return(_a0 client.BroadcastService) *MockClient_Broadcast_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_Broadcast_Call) RunAndReturn(run func() client.BroadcastService) *MockClient_Broadcast_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: func (_m *MockClient) Close() { _m.Called() diff --git a/internal/mocks/streamingcoord/server/mock_broadcaster/mock_AppendOperator.go b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_AppendOperator.go new file mode 100644 index 0000000000..8f049c5616 --- /dev/null +++ b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_AppendOperator.go @@ -0,0 +1,100 @@ +// Code generated by mockery v2.46.0. DO NOT EDIT. + +package mock_broadcaster + +import ( + context "context" + + message "github.com/milvus-io/milvus/pkg/streaming/util/message" + mock "github.com/stretchr/testify/mock" + + streaming "github.com/milvus-io/milvus/internal/distributed/streaming" +) + +// MockAppendOperator is an autogenerated mock type for the AppendOperator type +type MockAppendOperator struct { + mock.Mock +} + +type MockAppendOperator_Expecter struct { + mock *mock.Mock +} + +func (_m *MockAppendOperator) EXPECT() *MockAppendOperator_Expecter { + return &MockAppendOperator_Expecter{mock: &_m.Mock} +} + +// AppendMessages provides a mock function with given fields: ctx, msgs +func (_m *MockAppendOperator) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { + _va := make([]interface{}, len(msgs)) + for _i := range msgs { + _va[_i] = msgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AppendMessages") + } + + var r0 streaming.AppendResponses + if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) streaming.AppendResponses); ok { + r0 = rf(ctx, msgs...) + } else { + r0 = ret.Get(0).(streaming.AppendResponses) + } + + return r0 +} + +// MockAppendOperator_AppendMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessages' +type MockAppendOperator_AppendMessages_Call struct { + *mock.Call +} + +// AppendMessages is a helper method to define mock.On call +// - ctx context.Context +// - msgs ...message.MutableMessage +func (_e *MockAppendOperator_Expecter) AppendMessages(ctx interface{}, msgs ...interface{}) *MockAppendOperator_AppendMessages_Call { + return &MockAppendOperator_AppendMessages_Call{Call: _e.mock.On("AppendMessages", + append([]interface{}{ctx}, msgs...)...)} +} + +func (_c *MockAppendOperator_AppendMessages_Call) Run(run func(ctx context.Context, msgs ...message.MutableMessage)) *MockAppendOperator_AppendMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]message.MutableMessage, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(message.MutableMessage) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockAppendOperator_AppendMessages_Call) Return(_a0 streaming.AppendResponses) *MockAppendOperator_AppendMessages_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockAppendOperator_AppendMessages_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) streaming.AppendResponses) *MockAppendOperator_AppendMessages_Call { + _c.Call.Return(run) + return _c +} + +// NewMockAppendOperator creates a new instance of MockAppendOperator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockAppendOperator(t interface { + mock.TestingT + Cleanup(func()) +}) *MockAppendOperator { + mock := &MockAppendOperator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/streamingcoord/client/broadcast/broadcast_impl.go b/internal/streamingcoord/client/broadcast/broadcast_impl.go new file mode 100644 index 0000000000..b6296748d1 --- /dev/null +++ b/internal/streamingcoord/client/broadcast/broadcast_impl.go @@ -0,0 +1,56 @@ +package broadcast + +import ( + "context" + + "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// NewBroadcastService creates a new broadcast service. +func NewBroadcastService(walName string, service lazygrpc.Service[streamingpb.StreamingCoordBroadcastServiceClient]) *BroadcastServiceImpl { + return &BroadcastServiceImpl{ + walName: walName, + service: service, + } +} + +// BroadcastServiceImpl is the implementation of BroadcastService. +type BroadcastServiceImpl struct { + walName string + service lazygrpc.Service[streamingpb.StreamingCoordBroadcastServiceClient] +} + +// Broadcast sends a broadcast message to the streaming coord to perform a broadcast. +func (c *BroadcastServiceImpl) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + client, err := c.service.GetService(ctx) + if err != nil { + return nil, err + } + resp, err := client.Broadcast(ctx, &streamingpb.BroadcastRequest{ + Message: &messagespb.Message{ + Payload: msg.Payload(), + Properties: msg.Properties().ToRawMap(), + }, + }) + if err != nil { + return nil, err + } + results := make(map[string]*types.AppendResult, len(resp.Results)) + for channel, result := range resp.Results { + msgID, err := message.UnmarshalMessageID(c.walName, result.Id.Id) + if err != nil { + return nil, err + } + results[channel] = &types.AppendResult{ + MessageID: msgID, + TimeTick: result.GetTimetick(), + TxnCtx: message.NewTxnContextFromProto(result.GetTxnContext()), + Extra: result.GetExtra(), + } + } + return &types.BroadcastAppendResult{AppendResults: results}, nil +} diff --git a/internal/streamingcoord/client/client.go b/internal/streamingcoord/client/client.go index 83a55fd107..07f0937360 100644 --- a/internal/streamingcoord/client/client.go +++ b/internal/streamingcoord/client/client.go @@ -11,12 +11,15 @@ import ( "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/streamingcoord/client/assignment" + "github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker" streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor" "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" "github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver" + "github.com/milvus-io/milvus/internal/util/streamingutil/util" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/interceptor" @@ -32,8 +35,16 @@ type AssignmentService interface { types.AssignmentDiscoverWatcher } +// BroadcastService is the interface of broadcast service. +type BroadcastService interface { + // Broadcast sends a broadcast message to the streaming service. + Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) +} + // Client is the interface of log service client. type Client interface { + Broadcast() BroadcastService + // Assignment access assignment service. Assignment() AssignmentService @@ -58,10 +69,12 @@ func NewClient(etcdCli *clientv3.Client) Client { ) }) assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient) + broadcastService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordBroadcastServiceClient) return &clientImpl{ conn: conn, rb: rb, assignmentService: assignment.NewAssignmentService(assignmentService), + broadcastService: broadcast.NewBroadcastService(util.MustSelectWALName(), broadcastService), } } diff --git a/internal/streamingcoord/client/client_impl.go b/internal/streamingcoord/client/client_impl.go index ffb0b0355a..88c94794e1 100644 --- a/internal/streamingcoord/client/client_impl.go +++ b/internal/streamingcoord/client/client_impl.go @@ -2,6 +2,7 @@ package client import ( "github.com/milvus-io/milvus/internal/streamingcoord/client/assignment" + "github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast" "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" "github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver" ) @@ -11,6 +12,11 @@ type clientImpl struct { conn lazygrpc.Conn rb resolver.Builder assignmentService *assignment.AssignmentServiceImpl + broadcastService *broadcast.BroadcastServiceImpl +} + +func (c *clientImpl) Broadcast() BroadcastService { + return c.broadcastService } // Assignment access assignment service. diff --git a/internal/streamingcoord/server/broadcaster/append_operator.go b/internal/streamingcoord/server/broadcaster/append_operator.go new file mode 100644 index 0000000000..ec849ea2be --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/append_operator.go @@ -0,0 +1,14 @@ +package broadcaster + +import ( + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/util/streamingutil" +) + +// NewAppendOperator creates an append operator to handle the incoming messages for broadcaster. +func NewAppendOperator() AppendOperator { + if streamingutil.IsStreamingServiceEnabled() { + return streaming.WAL() + } + return nil +} diff --git a/internal/streamingcoord/server/broadcaster/broadcaster.go b/internal/streamingcoord/server/broadcaster/broadcaster.go new file mode 100644 index 0000000000..79e77bb882 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/broadcaster.go @@ -0,0 +1,24 @@ +package broadcaster + +import ( + "context" + + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +type Broadcaster interface { + // Broadcast broadcasts the message to all channels. + Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) + + // Close closes the broadcaster. + Close() +} + +// AppendOperator is used to append messages, there's only two implement of this interface: +// 1. streaming.WAL() +// 2. old msgstream interface +type AppendOperator interface { + AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses +} diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_impl.go b/internal/streamingcoord/server/broadcaster/broadcaster_impl.go new file mode 100644 index 0000000000..2da0e0679f --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/broadcaster_impl.go @@ -0,0 +1,207 @@ +package broadcaster + +import ( + "context" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/contextutil" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func RecoverBroadcaster( + ctx context.Context, + appendOperator AppendOperator, +) (Broadcaster, error) { + logger := resource.Resource().Logger().With(log.FieldComponent("broadcaster")) + tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx) + if err != nil { + return nil, err + } + pendings := make([]*broadcastTask, 0, len(tasks)) + for _, task := range tasks { + if task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING { + // recover pending task + t := newTask(task, logger) + pendings = append(pendings, t) + } + } + b := &broadcasterImpl{ + logger: logger, + lifetime: typeutil.NewLifetime(), + backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](), + pendings: pendings, + backoffs: typeutil.NewHeap[*broadcastTask](&broadcastTaskArray{}), + backoffChan: make(chan *broadcastTask), + pendingChan: make(chan *broadcastTask), + workerChan: make(chan *broadcastTask), + appendOperator: appendOperator, + } + go b.execute() + return b, nil +} + +// broadcasterImpl is the implementation of Broadcaster +type broadcasterImpl struct { + logger *log.MLogger + lifetime *typeutil.Lifetime + backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] + pendings []*broadcastTask + backoffs typeutil.Heap[*broadcastTask] + pendingChan chan *broadcastTask + backoffChan chan *broadcastTask + workerChan chan *broadcastTask + appendOperator AppendOperator +} + +// Broadcast broadcasts the message to all channels. +func (b *broadcasterImpl) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (result *types.BroadcastAppendResult, err error) { + if !b.lifetime.Add(typeutil.LifetimeStateWorking) { + return nil, status.NewOnShutdownError("broadcaster is closing") + } + defer func() { + if err != nil { + b.logger.Warn("broadcast message failed", zap.Error(err)) + return + } + }() + + // Once the task is persisted, it must be successful. + task, err := b.persistBroadcastTask(ctx, msg) + if err != nil { + return nil, err + } + t := newTask(task, b.logger) + select { + case <-b.backgroundTaskNotifier.Context().Done(): + // We can only check the background context but not the request context here. + // Because we want the new incoming task must be delivered to the background task queue + // otherwise the broadcaster is closing + return nil, status.NewOnShutdownError("broadcaster is closing") + case b.pendingChan <- t: + } + + // Wait both request context and the background task context. + ctx, _ = contextutil.MergeContext(ctx, b.backgroundTaskNotifier.Context()) + return t.BlockUntilTaskDone(ctx) +} + +// persistBroadcastTask persists the broadcast task into catalog. +func (b *broadcasterImpl) persistBroadcastTask(ctx context.Context, msg message.BroadcastMutableMessage) (*streamingpb.BroadcastTask, error) { + defer b.lifetime.Done() + + id, err := resource.Resource().IDAllocator().Allocate(ctx) + if err != nil { + return nil, status.NewInner("allocate new id failed, %s", err.Error()) + } + task := &streamingpb.BroadcastTask{ + TaskId: int64(id), + Message: &messagespb.Message{Payload: msg.Payload(), Properties: msg.Properties().ToRawMap()}, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, + } + // Save the task into catalog to help recovery. + if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, task); err != nil { + return nil, status.NewInner("save broadcast task failed, %s", err.Error()) + } + return task, nil +} + +func (b *broadcasterImpl) Close() { + b.lifetime.SetState(typeutil.LifetimeStateStopped) + b.lifetime.Wait() + + b.backgroundTaskNotifier.Cancel() + b.backgroundTaskNotifier.BlockUntilFinish() +} + +// execute the broadcaster +func (b *broadcasterImpl) execute() { + b.logger.Info("broadcaster start to execute") + defer func() { + b.backgroundTaskNotifier.Finish(struct{}{}) + b.logger.Info("broadcaster execute exit") + }() + + // Start n workers to handle the broadcast task. + wg := sync.WaitGroup{} + for i := 0; i < 4; i++ { + i := i + // Start n workers to handle the broadcast task. + wg.Add(1) + go func() { + defer wg.Done() + b.worker(i) + }() + } + defer wg.Wait() + + b.dispatch() +} + +func (b *broadcasterImpl) dispatch() { + for { + var workerChan chan *broadcastTask + var nextTask *broadcastTask + var nextBackOff <-chan time.Time + // Wait for new task. + if len(b.pendings) > 0 { + workerChan = b.workerChan + nextTask = b.pendings[0] + } + if b.backoffs.Len() > 0 { + var nextInterval time.Duration + nextBackOff, nextInterval = b.backoffs.Peek().NextTimer() + b.logger.Info("backoff task", zap.Duration("nextInterval", nextInterval)) + } + + select { + case <-b.backgroundTaskNotifier.Context().Done(): + return + case task := <-b.pendingChan: + b.pendings = append(b.pendings, task) + case task := <-b.backoffChan: + // task is backoff, push it into backoff queue to make a delay retry. + b.backoffs.Push(task) + case <-nextBackOff: + // backoff is done, move all the backoff done task into pending to retry. + for b.backoffs.Len() > 0 && b.backoffs.Peek().NextInterval() < time.Millisecond { + b.pendings = append(b.pendings, b.backoffs.Pop()) + } + case workerChan <- nextTask: + // The task is sent to worker, remove it from pending list. + b.pendings = b.pendings[1:] + } + } +} + +func (b *broadcasterImpl) worker(no int) { + defer func() { + b.logger.Info("broadcaster worker exit", zap.Int("no", no)) + }() + + for { + select { + case <-b.backgroundTaskNotifier.Context().Done(): + return + case task := <-b.workerChan: + if err := task.Poll(b.backgroundTaskNotifier.Context(), b.appendOperator); err != nil { + // If the task is not done, repush it into pendings and retry infinitely. + select { + case <-b.backgroundTaskNotifier.Context().Done(): + return + case b.backoffChan <- task: + } + } + } + } +} diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_test.go b/internal/streamingcoord/server/broadcaster/broadcaster_test.go new file mode 100644 index 0000000000..624535f1c8 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/broadcaster_test.go @@ -0,0 +1,142 @@ +package broadcaster + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/atomic" + + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster" + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + internaltypes "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +func TestBroadcaster(t *testing.T) { + meta := mock_metastore.NewMockStreamingCoordCataLog(t) + meta.EXPECT().ListBroadcastTask(mock.Anything). + RunAndReturn(func(ctx context.Context) ([]*streamingpb.BroadcastTask, error) { + return []*streamingpb.BroadcastTask{ + createNewBroadcastTask(1, []string{"v1"}), + createNewBroadcastTask(2, []string{"v1", "v2"}), + createNewBroadcastTask(3, []string{"v1", "v2", "v3"}), + }, nil + }).Times(1) + done := atomic.NewInt64(0) + meta.EXPECT().SaveBroadcastTask(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, bt *streamingpb.BroadcastTask) error { + // may failure + if rand.Int31n(10) < 5 { + return errors.New("save task failed") + } + if bt.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE { + done.Inc() + } + return nil + }) + rc := idalloc.NewMockRootCoordClient(t) + f := syncutil.NewFuture[internaltypes.RootCoordClient]() + f.Set(rc) + resource.InitForTest(resource.OptStreamingCatalog(meta), resource.OptRootCoordClient(f)) + + operator, appended := createOpeartor(t) + bc, err := RecoverBroadcaster(context.Background(), operator) + assert.NoError(t, err) + assert.NotNil(t, bc) + assert.Eventually(t, func() bool { + return appended.Load() == 6 && done.Load() == 3 + }, 10*time.Second, 10*time.Millisecond) + + var result *types.BroadcastAppendResult + for { + var err error + result, err = bc.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"})) + if err == nil { + break + } + } + assert.Equal(t, int(appended.Load()), 9) + assert.Equal(t, len(result.AppendResults), 3) + + assert.Eventually(t, func() bool { + return done.Load() == 4 + }, 10*time.Second, 10*time.Millisecond) + + // TODO: error path. + bc.Close() + + result, err = bc.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"})) + assert.Error(t, err) + assert.Nil(t, result) +} + +func createOpeartor(t *testing.T) (AppendOperator, *atomic.Int64) { + id := atomic.NewInt64(1) + appended := atomic.NewInt64(0) + operator := mock_broadcaster.NewMockAppendOperator(t) + f := func(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { + resps := streaming.AppendResponses{ + Responses: make([]streaming.AppendResponse, len(msgs)), + } + for idx := range msgs { + newID := walimplstest.NewTestMessageID(id.Inc()) + if rand.Int31n(10) < 5 { + resps.Responses[idx] = streaming.AppendResponse{ + Error: errors.New("append failed"), + } + continue + } + resps.Responses[idx] = streaming.AppendResponse{ + AppendResult: &types.AppendResult{ + MessageID: newID, + TimeTick: uint64(time.Now().UnixMilli()), + }, + Error: nil, + } + appended.Inc() + } + return resps + } + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything).RunAndReturn(f) + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) + return operator, appended +} + +func createNewBroadcastMsg(vchannels []string) message.BroadcastMutableMessage { + msg, err := message.NewDropCollectionMessageBuilderV1(). + WithHeader(&messagespb.DropCollectionMessageHeader{}). + WithBody(&msgpb.DropCollectionRequest{}). + WithBroadcast(vchannels). + BuildBroadcast() + if err != nil { + panic(err) + } + return msg +} + +func createNewBroadcastTask(taskID int64, vchannels []string) *streamingpb.BroadcastTask { + msg := createNewBroadcastMsg(vchannels) + return &streamingpb.BroadcastTask{ + TaskId: taskID, + Message: &messagespb.Message{ + Payload: msg.Payload(), + Properties: msg.Properties().ToRawMap(), + }, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, + } +} diff --git a/internal/streamingcoord/server/broadcaster/task.go b/internal/streamingcoord/server/broadcaster/task.go new file mode 100644 index 0000000000..52a2b0e77d --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/task.go @@ -0,0 +1,126 @@ +package broadcaster + +import ( + "context" + "time" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var errBroadcastTaskIsNotDone = errors.New("broadcast task is not done") + +// newTask creates a new task +func newTask(task *streamingpb.BroadcastTask, logger *log.MLogger) *broadcastTask { + bt := message.NewBroadcastMutableMessage(task.Message.Payload, task.Message.Properties) + msgs := bt.SplitIntoMutableMessage() + return &broadcastTask{ + logger: logger.With(zap.Int64("taskID", task.TaskId), zap.Int("broadcastTotal", len(msgs))), + task: task, + pendingMessages: msgs, + appendResult: make(map[string]*types.AppendResult, len(msgs)), + future: syncutil.NewFuture[*types.BroadcastAppendResult](), + BackoffWithInstant: typeutil.NewBackoffWithInstant(typeutil.BackoffTimerConfig{ + Default: 10 * time.Second, + Backoff: typeutil.BackoffConfig{ + InitialInterval: 10 * time.Millisecond, + Multiplier: 2.0, + MaxInterval: 10 * time.Second, + }, + }), + } +} + +// broadcastTask is the task for broadcasting messages. +type broadcastTask struct { + logger *log.MLogger + task *streamingpb.BroadcastTask + pendingMessages []message.MutableMessage + appendResult map[string]*types.AppendResult + future *syncutil.Future[*types.BroadcastAppendResult] + *typeutil.BackoffWithInstant +} + +// Poll polls the task, return nil if the task is done, otherwise not done. +// Poll can be repeated called until the task is done. +func (b *broadcastTask) Poll(ctx context.Context, operator AppendOperator) error { + if len(b.pendingMessages) > 0 { + b.logger.Debug("broadcast task is polling to make sent...", zap.Int("pendingMessages", len(b.pendingMessages))) + resps := operator.AppendMessages(ctx, b.pendingMessages...) + newPendings := make([]message.MutableMessage, 0) + for idx, resp := range resps.Responses { + if resp.Error != nil { + newPendings = append(newPendings, b.pendingMessages[idx]) + continue + } + b.appendResult[b.pendingMessages[idx].VChannel()] = resp.AppendResult + } + b.pendingMessages = newPendings + if len(newPendings) == 0 { + b.future.Set(&types.BroadcastAppendResult{AppendResults: b.appendResult}) + } + b.logger.Info("broadcast task make a new broadcast done", zap.Int("pendingMessages", len(b.pendingMessages))) + } + if len(b.pendingMessages) == 0 { + // There's no more pending message, mark the task as done. + b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE + if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.task); err != nil { + b.logger.Warn("save broadcast task failed", zap.Error(err)) + b.UpdateInstantWithNextBackOff() + return err + } + return nil + } + b.UpdateInstantWithNextBackOff() + return errBroadcastTaskIsNotDone +} + +// BlockUntilTaskDone blocks until the task is done. +func (b *broadcastTask) BlockUntilTaskDone(ctx context.Context) (*types.BroadcastAppendResult, error) { + return b.future.GetWithContext(ctx) +} + +type broadcastTaskArray []*broadcastTask + +// Len returns the length of the heap. +func (h broadcastTaskArray) Len() int { + return len(h) +} + +// Less returns true if the element at index i is less than the element at index j. +func (h broadcastTaskArray) Less(i, j int) bool { + return h[i].NextInstant().Before(h[j].NextInstant()) +} + +// Swap swaps the elements at indexes i and j. +func (h broadcastTaskArray) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +// Push pushes the last one at len. +func (h *broadcastTaskArray) Push(x interface{}) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *h = append(*h, x.(*broadcastTask)) +} + +// Pop pop the last one at len. +func (h *broadcastTaskArray) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +// Peek returns the element at the top of the heap. +// Panics if the heap is empty. +func (h *broadcastTaskArray) Peek() interface{} { + return (*h)[0] +} diff --git a/internal/streamingcoord/server/builder.go b/internal/streamingcoord/server/builder.go index 4d2215b6df..dcbb5eeb4c 100644 --- a/internal/streamingcoord/server/builder.go +++ b/internal/streamingcoord/server/builder.go @@ -5,6 +5,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/kv/streamingcoord" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/streamingcoord/server/service" "github.com/milvus-io/milvus/internal/types" @@ -52,10 +53,13 @@ func (s *ServerBuilder) Build() *Server { resource.OptRootCoordClient(s.rootCoordClient), ) balancer := syncutil.NewFuture[balancer.Balancer]() + broadcaster := syncutil.NewFuture[broadcaster.Broadcaster]() return &Server{ logger: resource.Resource().Logger().With(log.FieldComponent("server")), session: s.session, assignmentService: service.NewAssignmentService(balancer), + broadcastService: service.NewBroadcastService(broadcaster), balancer: balancer, + broadcaster: broadcaster, } } diff --git a/internal/streamingcoord/server/resource/resource.go b/internal/streamingcoord/server/resource/resource.go index 89b8dee573..96a92e3727 100644 --- a/internal/streamingcoord/server/resource/resource.go +++ b/internal/streamingcoord/server/resource/resource.go @@ -8,6 +8,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/streamingnode/client/manager" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -29,6 +30,7 @@ func OptETCD(etcd *clientv3.Client) optResourceInit { func OptRootCoordClient(rootCoordClient *syncutil.Future[types.RootCoordClient]) optResourceInit { return func(r *resourceImpl) { r.rootCoordClient = rootCoordClient + r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } } @@ -48,6 +50,7 @@ func Init(opts ...optResourceInit) { for _, opt := range opts { opt(newR) } + assertNotNil(newR.IDAllocator()) assertNotNil(newR.RootCoordClient()) assertNotNil(newR.ETCD()) assertNotNil(newR.StreamingCatalog()) @@ -64,6 +67,7 @@ func Resource() *resourceImpl { // resourceImpl is a basic resource dependency for streamingnode server. // All utility on it is concurrent-safe and singleton. type resourceImpl struct { + idAllocator idalloc.Allocator rootCoordClient *syncutil.Future[types.RootCoordClient] etcdClient *clientv3.Client streamingCatalog metastore.StreamingCoordCataLog @@ -76,6 +80,11 @@ func (r *resourceImpl) RootCoordClient() *syncutil.Future[types.RootCoordClient] return r.rootCoordClient } +// IDAllocator returns the IDAllocator client. +func (r *resourceImpl) IDAllocator() idalloc.Allocator { + return r.idAllocator +} + // StreamingCatalog returns the StreamingCatalog client. func (r *resourceImpl) StreamingCatalog() metastore.StreamingCoordCataLog { return r.streamingCatalog diff --git a/internal/streamingcoord/server/server.go b/internal/streamingcoord/server/server.go index 2b9e50f3c2..f465d1b4b6 100644 --- a/internal/streamingcoord/server/server.go +++ b/internal/streamingcoord/server/server.go @@ -8,6 +8,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" _ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" "github.com/milvus-io/milvus/internal/streamingcoord/server/service" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/streamingutil" @@ -27,9 +28,11 @@ type Server struct { // service level variables. assignmentService service.AssignmentService + broadcastService service.BroadcastService // basic component variables can be used at service level. - balancer *syncutil.Future[balancer.Balancer] + balancer *syncutil.Future[balancer.Balancer] + broadcaster *syncutil.Future[broadcaster.Broadcaster] } // Init initializes the streamingcoord server. @@ -46,8 +49,9 @@ func (s *Server) Start(ctx context.Context) (err error) { // initBasicComponent initialize all underlying dependency for streamingcoord. func (s *Server) initBasicComponent(ctx context.Context) (err error) { + futures := make([]*conc.Future[struct{}], 0) if streamingutil.IsStreamingServiceEnabled() { - fBalancer := conc.Go(func() (struct{}, error) { + futures = append(futures, conc.Go(func() (struct{}, error) { s.logger.Info("start recovery balancer...") // Read new incoming topics from configuration, and register it into balancer. newIncomingTopics := util.GetAllTopicsFromConfiguration() @@ -59,10 +63,22 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) { s.balancer.Set(balancer) s.logger.Info("recover balancer done") return struct{}{}, nil - }) - return conc.AwaitAll(fBalancer) + })) } - return nil + // The broadcaster of msgstream is implemented on current streamingcoord to reduce the development complexity. + // So we need to recover it. + futures = append(futures, conc.Go(func() (struct{}, error) { + s.logger.Info("start recovery broadcaster...") + broadcaster, err := broadcaster.RecoverBroadcaster(ctx, broadcaster.NewAppendOperator()) + if err != nil { + s.logger.Warn("recover broadcaster failed", zap.Error(err)) + return struct{}{}, err + } + s.broadcaster.Set(broadcaster) + s.logger.Info("recover broadcaster done") + return struct{}{}, nil + })) + return conc.AwaitAll(futures...) } // RegisterGRPCService register all grpc service to grpc server. @@ -70,6 +86,7 @@ func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) { if streamingutil.IsStreamingServiceEnabled() { streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService) } + streamingpb.RegisterStreamingCoordBroadcastServiceServer(grpcServer, s.broadcastService) } // Close closes the streamingcoord server. @@ -80,5 +97,11 @@ func (s *Server) Stop() { } else { s.logger.Info("balancer not ready, skip close") } + if s.broadcaster.Ready() { + s.logger.Info("start close broadcaster...") + s.broadcaster.Get().Close() + } else { + s.logger.Info("broadcaster not ready, skip close") + } s.logger.Info("streamingcoord server stopped") } diff --git a/internal/streamingcoord/server/service/broadcast.go b/internal/streamingcoord/server/service/broadcast.go new file mode 100644 index 0000000000..6d192615e3 --- /dev/null +++ b/internal/streamingcoord/server/service/broadcast.go @@ -0,0 +1,44 @@ +package service + +import ( + "context" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +// BroadcastService is the interface of the broadcast service. +type BroadcastService interface { + streamingpb.StreamingCoordBroadcastServiceServer +} + +// NewBroadcastService creates a new broadcast service. +func NewBroadcastService(bc *syncutil.Future[broadcaster.Broadcaster]) BroadcastService { + return &broadcastServceImpl{ + broadcaster: bc, + } +} + +// broadcastServiceeeeImpl is the implementation of the broadcast service. +type broadcastServceImpl struct { + broadcaster *syncutil.Future[broadcaster.Broadcaster] +} + +// Broadcast broadcasts the message to all channels. +func (s *broadcastServceImpl) Broadcast(ctx context.Context, req *streamingpb.BroadcastRequest) (*streamingpb.BroadcastResponse, error) { + broadcaster, err := s.broadcaster.GetWithContext(ctx) + if err != nil { + return nil, err + } + results, err := broadcaster.Broadcast(ctx, message.NewBroadcastMutableMessage(req.Message.Payload, req.Message.Properties)) + if err != nil { + return nil, err + } + protoResult := make(map[string]*streamingpb.ProduceMessageResponseResult, len(results.AppendResults)) + for vchannel, result := range results.AppendResults { + protoResult[vchannel] = result.IntoProto() + } + return &streamingpb.BroadcastResponse{Results: protoResult}, nil +} diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go index c97c9b491b..a5c417b64b 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go @@ -33,8 +33,8 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index 0626d9de28..cb762dccdb 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -8,10 +8,10 @@ import ( "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/streamingnode/server/resource/test_utility.go b/internal/streamingnode/server/resource/test_utility.go index a287d85693..a05ec41c69 100644 --- a/internal/streamingnode/server/resource/test_utility.go +++ b/internal/streamingnode/server/resource/test_utility.go @@ -6,10 +6,10 @@ package resource import ( "testing" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/syncutil" ) diff --git a/internal/streamingnode/server/service/handler/producer/produce_server.go b/internal/streamingnode/server/service/handler/producer/produce_server.go index 0bd9b35721..4075efe808 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server.go @@ -13,7 +13,6 @@ import ( "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -221,20 +220,9 @@ func (p *ProduceServer) sendProduceResult(reqID int64, appendResult *wal.AppendR } if err != nil { p.logger.Warn("append message to wal failed", zap.Int64("requestID", reqID), zap.Error(err)) - resp.Response = &streamingpb.ProduceMessageResponse_Error{ - Error: status.AsStreamingError(err).AsPBError(), - } + resp.Response = &streamingpb.ProduceMessageResponse_Error{Error: status.AsStreamingError(err).AsPBError()} } else { - resp.Response = &streamingpb.ProduceMessageResponse_Result{ - Result: &streamingpb.ProduceMessageResponseResult{ - Id: &messagespb.MessageID{ - Id: appendResult.MessageID.Marshal(), - }, - Timetick: appendResult.TimeTick, - TxnContext: appendResult.TxnCtx.IntoProto(), - Extra: appendResult.Extra, - }, - } + resp.Response = &streamingpb.ProduceMessageResponse_Result{Result: appendResult.IntoProto()} } // If server context is canceled, it means the stream has been closed. diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index b217af0d52..8a222b04b3 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -21,10 +21,10 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry" internaltypes "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go index 7093f1139e..4497551c2b 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go @@ -15,12 +15,12 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" internaltypes "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq" diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go index 2ed586859f..a705663ad8 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go @@ -21,7 +21,7 @@ func NewTimeTickMsg(ts uint64, lastConfirmedMessageID message.MessageID, sourceI commonpbutil.WithSourceID(sourceID), ), }). - WithBroadcast(). + WithAllVChannel(). BuildMutable() if err != nil { return nil, err diff --git a/internal/streamingnode/server/resource/idalloc/allocator.go b/internal/util/idalloc/allocator.go similarity index 100% rename from internal/streamingnode/server/resource/idalloc/allocator.go rename to internal/util/idalloc/allocator.go diff --git a/internal/streamingnode/server/resource/idalloc/allocator_test.go b/internal/util/idalloc/allocator_test.go similarity index 100% rename from internal/streamingnode/server/resource/idalloc/allocator_test.go rename to internal/util/idalloc/allocator_test.go diff --git a/internal/streamingnode/server/resource/idalloc/basic_allocator.go b/internal/util/idalloc/basic_allocator.go similarity index 100% rename from internal/streamingnode/server/resource/idalloc/basic_allocator.go rename to internal/util/idalloc/basic_allocator.go diff --git a/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go b/internal/util/idalloc/basic_allocator_test.go similarity index 100% rename from internal/streamingnode/server/resource/idalloc/basic_allocator_test.go rename to internal/util/idalloc/basic_allocator_test.go diff --git a/internal/streamingnode/server/resource/idalloc/mallocator.go b/internal/util/idalloc/mallocator.go similarity index 100% rename from internal/streamingnode/server/resource/idalloc/mallocator.go rename to internal/util/idalloc/mallocator.go diff --git a/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go b/internal/util/idalloc/test_mock_root_coord_client.go similarity index 100% rename from internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go rename to internal/util/idalloc/test_mock_root_coord_client.go diff --git a/pkg/streaming/proto/messages.proto b/pkg/streaming/proto/messages.proto index 091e590427..843556ce45 100644 --- a/pkg/streaming/proto/messages.proto +++ b/pkg/streaming/proto/messages.proto @@ -248,3 +248,8 @@ message RMQMessageLayout { bytes payload = 1; // message body map properties = 2; // message properties } + +// VChannels is a layout to represent the virtual channels for broadcast. +message VChannels { + repeated string vchannels = 1; +} \ No newline at end of file diff --git a/pkg/streaming/proto/streaming.proto b/pkg/streaming/proto/streaming.proto index e4a6943ae2..0a7debc9da 100644 --- a/pkg/streaming/proto/streaming.proto +++ b/pkg/streaming/proto/streaming.proto @@ -60,18 +60,48 @@ message VersionPair { int64 local = 2; } +// BroadcastTaskState is the state of the broadcast task. +enum BroadcastTaskState { + BROADCAST_TASK_STATE_UNKNOWN = 0; // should never used. + BROADCAST_TASK_STATE_PENDING = 1; // task is pending. + BROADCAST_TASK_STATE_DONE = 2; // task is done, the message is broadcasted, and the persisted task can be cleared. +} + +// BroadcastTask is the task to broadcast the message. +message BroadcastTask { + int64 task_id = 1; // task id. + messages.Message message = 2; // message to be broadcast. + BroadcastTaskState state = 3; // state of the task. +} + // // Milvus Service // -service StreamingCoordStateService { +service StreamingNodeStateService { rpc GetComponentStates(milvus.GetComponentStatesRequest) returns (milvus.ComponentStates) {} } -service StreamingNodeStateService { - rpc GetComponentStates(milvus.GetComponentStatesRequest) - returns (milvus.ComponentStates) {} +// +// StreamingCoordBroadcastService +// + +// StreamingCoordBroadcastService is the broadcast service for streaming coord. +service StreamingCoordBroadcastService { + // Broadcast receives broadcast messages from other component and make sure that the message is broadcast to all wal. + // It performs an atomic broadcast to all wal, achieve eventual consistency. + rpc Broadcast(BroadcastRequest) returns (BroadcastResponse) {} +} + +// BroadcastRequest is the request of the Broadcast RPC. +message BroadcastRequest { + messages.Message message = 1; // message to be broadcast. +} + +// BroadcastResponse is the response of the Broadcast RPC. +message BroadcastResponse { + map results = 1; } // diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go index 32bdad9db6..0f941c6851 100644 --- a/pkg/streaming/util/message/builder.go +++ b/pkg/streaming/util/message/builder.go @@ -7,16 +7,32 @@ import ( "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // NewMutableMessage creates a new mutable message. // !!! Only used at server side for streamingnode internal service, don't use it at client side. func NewMutableMessage(payload []byte, properties map[string]string) MutableMessage { - return &messageImpl{ + m := &messageImpl{ payload: payload, properties: properties, } + // make a assertion by vchannel function. + m.assertNotBroadcast() + return m +} + +// NewBroadcastMutableMessage creates a new broadcast mutable message. +// !!! Only used at server side for streamingcoord internal service, don't use it at client side. +func NewBroadcastMutableMessage(payload []byte, properties map[string]string) BroadcastMutableMessage { + m := &messageImpl{ + payload: payload, + properties: properties, + } + m.assertBroadcast() + return m } // NewImmutableMessage creates a new immutable message. @@ -82,10 +98,10 @@ func newMutableMessageBuilder[H proto.Message, B proto.Message](v Version) *muta // mutableMesasgeBuilder is the builder for message. type mutableMesasgeBuilder[H proto.Message, B proto.Message] struct { - header H - body B - properties propertiesImpl - broadcast bool + header H + body B + properties propertiesImpl + allVChannel bool } // WithMessageHeader creates a new builder with determined message type. @@ -102,16 +118,41 @@ func (b *mutableMesasgeBuilder[H, B]) WithBody(body B) *mutableMesasgeBuilder[H, // WithVChannel creates a new builder with virtual channel. func (b *mutableMesasgeBuilder[H, B]) WithVChannel(vchannel string) *mutableMesasgeBuilder[H, B] { - if b.broadcast { - panic("a broadcast message cannot hold vchannel") + if b.allVChannel { + panic("a all vchannel message cannot set up vchannel property") } b.WithProperty(messageVChannel, vchannel) return b } // WithBroadcast creates a new builder with broadcast property. -func (b *mutableMesasgeBuilder[H, B]) WithBroadcast() *mutableMesasgeBuilder[H, B] { - b.broadcast = true +func (b *mutableMesasgeBuilder[H, B]) WithBroadcast(vchannels []string) *mutableMesasgeBuilder[H, B] { + if len(vchannels) < 1 { + panic("broadcast message must have at least one vchannel") + } + if b.allVChannel { + panic("a all vchannel message cannot set up vchannel property") + } + if b.properties.Exist(messageVChannel) { + panic("a broadcast message cannot set up vchannel property") + } + deduplicated := typeutil.NewSet(vchannels...) + vcs, err := EncodeProto(&messagespb.VChannels{ + Vchannels: deduplicated.Collect(), + }) + if err != nil { + panic("failed to encode vchannels") + } + b.properties.Set(messageVChannels, vcs) + return b +} + +// WithAllVChannel creates a new builder with all vchannel property. +func (b *mutableMesasgeBuilder[H, B]) WithAllVChannel() *mutableMesasgeBuilder[H, B] { + if b.properties.Exist(messageVChannel) || b.properties.Exist(messageVChannels) { + panic("a vchannel or broadcast message cannot set up all vchannel property") + } + b.allVChannel = true return b } @@ -135,6 +176,34 @@ func (b *mutableMesasgeBuilder[H, B]) WithProperties(kvs map[string]string) *mut // Panic if not set payload and message type. // should only used at client side. func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) { + if !b.allVChannel && !b.properties.Exist(messageVChannel) { + panic("a non broadcast message builder not ready for vchannel field") + } + + msg, err := b.build() + if err != nil { + return nil, err + } + return msg, nil +} + +// BuildBroadcast builds a broad mutable message. +// Panic if not set payload and message type. +// should only used at client side. +func (b *mutableMesasgeBuilder[H, B]) BuildBroadcast() (BroadcastMutableMessage, error) { + if !b.properties.Exist(messageVChannels) { + panic("a broadcast message builder not ready for vchannel field") + } + + msg, err := b.build() + if err != nil { + return nil, err + } + return msg, nil +} + +// build builds a message. +func (b *mutableMesasgeBuilder[H, B]) build() (*messageImpl, error) { // payload and header must be a pointer if reflect.ValueOf(b.header).IsNil() { panic("message builder not ready for header field") @@ -142,9 +211,6 @@ func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) { if reflect.ValueOf(b.body).IsNil() { panic("message builder not ready for body field") } - if !b.broadcast && !b.properties.Exist(messageVChannel) { - panic("a non broadcast message builder not ready for vchannel field") - } // setup header. sp, err := EncodeProto(b.header) diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go index 733ed568d8..49a7361c82 100644 --- a/pkg/streaming/util/message/message.go +++ b/pkg/streaming/util/message/message.go @@ -29,11 +29,6 @@ type BasicMessage interface { // Should be used with read-only promise. Properties() RProperties - // VChannel returns the virtual channel of current message. - // Available only when the message's version greater than 0. - // Return "" if message is broadcasted. - VChannel() string - // TimeTick returns the time tick of current message. // Available only when the message's version greater than 0. // Otherwise, it will panic. @@ -52,6 +47,11 @@ type BasicMessage interface { type MutableMessage interface { BasicMessage + // VChannel returns the virtual channel of current message. + // Available only when the message's version greater than 0. + // Return "" if message is can be seen by all vchannels on the pchannel. + VChannel() string + // WithBarrierTimeTick sets the barrier time tick of current message. // these time tick is used to promised the message will be sent after that time tick. // and the message which timetick is less than it will never concurrent append with it. @@ -82,6 +82,19 @@ type MutableMessage interface { IntoImmutableMessage(msgID MessageID) ImmutableMessage } +// BroadcastMutableMessage is the broadcast message interface. +// Indicated the message is broadcasted on various vchannels. +type BroadcastMutableMessage interface { + BasicMessage + + // BroadcastVChannels returns the target vchannels of the message broadcast. + // Those vchannels can be on multi pchannels. + BroadcastVChannels() []string + + // SplitIntoMutableMessage splits the broadcast message into multiple mutable messages. + SplitIntoMutableMessage() []MutableMessage +} + // ImmutableMessage is the read-only message interface. // Once a message is persistent by wal or temporary generated by wal, it will be immutable. type ImmutableMessage interface { @@ -90,6 +103,11 @@ type ImmutableMessage interface { // WALName returns the name of message related wal. WALName() string + // VChannel returns the virtual channel of current message. + // Available only when the message's version greater than 0. + // Return "" if message is can be seen by all vchannels on the pchannel. + VChannel() string + // MessageID returns the message id of current message. MessageID() MessageID diff --git a/pkg/streaming/util/message/message_impl.go b/pkg/streaming/util/message/message_impl.go index 41e9ac0379..7e4a4c0be2 100644 --- a/pkg/streaming/util/message/message_impl.go +++ b/pkg/streaming/util/message/message_impl.go @@ -141,8 +141,11 @@ func (m *messageImpl) BarrierTimeTick() uint64 { } // VChannel returns the vchannel of current message. -// If the message is broadcasted, the vchannel will be empty. +// If the message is a all channel message, it will return "". +// If the message is a broadcast message, it will panic. func (m *messageImpl) VChannel() string { + m.assertNotBroadcast() + value, ok := m.properties.Get(messageVChannel) if !ok { return "" @@ -150,6 +153,60 @@ func (m *messageImpl) VChannel() string { return value } +// BroadcastVChannels returns the vchannels of current message that want to broadcast. +// If the message is not a broadcast message, it will panic. +func (m *messageImpl) BroadcastVChannels() []string { + m.assertBroadcast() + + value, _ := m.properties.Get(messageVChannels) + vcs := &messagespb.VChannels{} + if err := DecodeProto(value, vcs); err != nil { + panic("can not decode vchannels") + } + return vcs.Vchannels +} + +// SplitIntoMutableMessage splits the current broadcast message into multiple messages. +func (m *messageImpl) SplitIntoMutableMessage() []MutableMessage { + vchannels := m.BroadcastVChannels() + + vchannelExist := make(map[string]struct{}, len(vchannels)) + msgs := make([]MutableMessage, 0, len(vchannels)) + for _, vchannel := range vchannels { + newPayload := make([]byte, len(m.payload)) + copy(newPayload, m.payload) + + newProperties := make(propertiesImpl, len(m.properties)) + for key, val := range m.properties { + if key != messageVChannels { + newProperties.Set(key, val) + } + } + newProperties.Set(messageVChannel, vchannel) + if _, ok := vchannelExist[vchannel]; ok { + panic("there's a bug in the message codes, duplicate vchannel in broadcast message") + } + msgs = append(msgs, &messageImpl{ + payload: newPayload, + properties: newProperties, + }) + vchannelExist[vchannel] = struct{}{} + } + return msgs +} + +func (m *messageImpl) assertNotBroadcast() { + if m.properties.Exist(messageVChannels) { + panic("current message is a broadcast message") + } +} + +func (m *messageImpl) assertBroadcast() { + if !m.properties.Exist(messageVChannels) { + panic("current message is not a broadcast message") + } +} + type immutableMessageImpl struct { messageImpl id MessageID diff --git a/pkg/streaming/util/message/properties.go b/pkg/streaming/util/message/properties.go index 575c7d2146..3f0d120e32 100644 --- a/pkg/streaming/util/message/properties.go +++ b/pkg/streaming/util/message/properties.go @@ -10,6 +10,7 @@ const ( messageLastConfirmed = "_lc" // message last confirmed message id. messageLastConfirmedIDSameWithMessageID = "_lcs" // message last confirmed message id is the same with message id. messageVChannel = "_vc" // message virtual channel. + messageVChannels = "_vcs" // message virtual channels for broadcast message. messageHeader = "_h" // specialized message header. messageTxnContext = "_tx" // transaction context. ) diff --git a/pkg/streaming/util/types/streaming_node.go b/pkg/streaming/util/types/streaming_node.go index 4c6a13e699..0cca5798e1 100644 --- a/pkg/streaming/util/types/streaming_node.go +++ b/pkg/streaming/util/types/streaming_node.go @@ -7,6 +7,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -88,6 +89,16 @@ func (n *StreamingNodeStatus) ErrorOfNode() error { return n.Err } +// BroadcastAppendResult is the result of broadcast append operation. +type BroadcastAppendResult struct { + AppendResults map[string]*AppendResult // make the channel name to the append result. +} + +// GetAppendResult returns the append result of the given channel. +func (r *BroadcastAppendResult) GetAppendResult(channelName string) *AppendResult { + return r.AppendResults[channelName] +} + // AppendResult is the result of append operation. type AppendResult struct { // MessageID is generated by underlying walimpls. @@ -112,3 +123,15 @@ func (r *AppendResult) GetExtra(m proto.Message) error { AllowPartial: true, }) } + +// IntoProto converts the append result to proto. +func (r *AppendResult) IntoProto() *streamingpb.ProduceMessageResponseResult { + return &streamingpb.ProduceMessageResponseResult{ + Id: &messagespb.MessageID{ + Id: r.MessageID.Marshal(), + }, + Timetick: r.TimeTick, + TxnContext: r.TxnCtx.IntoProto(), + Extra: r.Extra, + } +} diff --git a/pkg/util/contextutil/context_util.go b/pkg/util/contextutil/context_util.go index 8cf699b430..2bded437d1 100644 --- a/pkg/util/contextutil/context_util.go +++ b/pkg/util/contextutil/context_util.go @@ -121,3 +121,15 @@ func WithDeadlineCause(parent context.Context, deadline time.Time, err error) (c cancel(context.Canceled) } } + +// MergeContext create a cancellation context that cancels when any of the given contexts are canceled. +func MergeContext(ctx1 context.Context, ctx2 context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancelCause(ctx1) + stop := context.AfterFunc(ctx2, func() { + cancel(context.Cause(ctx2)) + }) + return ctx, func() { + stop() + cancel(context.Canceled) + } +} diff --git a/pkg/util/typeutil/backoff_timer.go b/pkg/util/typeutil/backoff_timer.go index dd26b136fe..997ccb2839 100644 --- a/pkg/util/typeutil/backoff_timer.go +++ b/pkg/util/typeutil/backoff_timer.go @@ -94,3 +94,49 @@ func (t *BackoffTimer) NextInterval() time.Duration { } return t.configFetcher.DefaultInterval() } + +// NewBackoffWithInstant creates a new backoff with instant +func NewBackoffWithInstant(fetcher BackoffTimerConfigFetcher) *BackoffWithInstant { + cfg := fetcher.BackoffConfig() + defaultInterval := fetcher.DefaultInterval() + backoff := backoff.NewExponentialBackOff() + backoff.InitialInterval = cfg.InitialInterval + backoff.Multiplier = cfg.Multiplier + backoff.MaxInterval = cfg.MaxInterval + backoff.MaxElapsedTime = defaultInterval + backoff.Stop = defaultInterval + backoff.Reset() + return &BackoffWithInstant{ + backoff: backoff, + nextInstant: time.Now(), + } +} + +// BackoffWithInstant is a backoff with instant. +// A instant can be recorded with `UpdateInstantWithNextBackOff` +// NextInstant can be used to make priority decision. +type BackoffWithInstant struct { + backoff *backoff.ExponentialBackOff + nextInstant time.Time +} + +// NextInstant returns the next instant +func (t *BackoffWithInstant) NextInstant() time.Time { + return t.nextInstant +} + +// NextInterval returns the next interval +func (t *BackoffWithInstant) NextInterval() time.Duration { + return time.Until(t.nextInstant) +} + +// NextTimer returns the next timer and the duration of the timer +func (t *BackoffWithInstant) NextTimer() (<-chan time.Time, time.Duration) { + next := time.Until(t.nextInstant) + return time.After(next), next +} + +// UpdateInstantWithNextBackOff updates the next instant with next backoff +func (t *BackoffWithInstant) UpdateInstantWithNextBackOff() { + t.nextInstant = time.Now().Add(t.backoff.NextBackOff()) +}