From 575345ae7b02546dbda01d23f7846986aa89ad29 Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Tue, 26 Aug 2025 14:29:51 +0800 Subject: [PATCH] fix: get streamingnodes from service discovery without channel assign (#44033) issue: #43767 Signed-off-by: chyezh --- .../snmanager/streaming_node_manager.go | 24 ++++--- .../snmanager/streaming_node_manager_test.go | 6 ++ .../coordinator/snmanager/test_utility.go | 22 ++++++ .../server/mock_balancer/mock_Balancer.go | 58 +++++++++++++++ .../client/mock_manager/mock_ManagerClient.go | 58 +++++++++++++++ .../checkers/channel_checker_test.go | 3 + .../meta/channel_dist_manager_test.go | 26 +++---- .../observers/replica_observer_test.go | 70 ++++++++----------- internal/querycoordv2/task/task_test.go | 3 + .../server/balancer/balancer.go | 3 + .../server/balancer/balancer_impl.go | 5 ++ .../client/manager/manager_client.go | 5 ++ .../client/manager/manager_client_impl.go | 23 ++++++ .../client/manager/manager_test.go | 9 +++ 14 files changed, 250 insertions(+), 65 deletions(-) diff --git a/internal/coordinator/snmanager/streaming_node_manager.go b/internal/coordinator/snmanager/streaming_node_manager.go index 1b752a7a4b..8882c79bc4 100644 --- a/internal/coordinator/snmanager/streaming_node_manager.go +++ b/internal/coordinator/snmanager/streaming_node_manager.go @@ -26,7 +26,6 @@ func newStreamingNodeManager() *StreamingNodeManager { balancer: syncutil.NewFuture[balancer.Balancer](), cond: syncutil.NewContextCond(&sync.Mutex{}), latestAssignments: make(map[string]types.PChannelInfoAssigned), - streamingNodes: typeutil.NewUniqueSet(), nodeChangedNotifier: syncutil.NewVersionedNotifier(), } go snm.execute() @@ -69,8 +68,7 @@ type StreamingNodeManager struct { // The coord is merged after 2.6, so we don't need to make distribution safe. cond *syncutil.ContextCond latestAssignments map[string]types.PChannelInfoAssigned // The latest assignments info got from streaming coord balance module. 
- streamingNodes typeutil.UniqueSet - nodeChangedNotifier *syncutil.VersionedNotifier // used to notify that node in streaming node manager has been changed. + nodeChangedNotifier *syncutil.VersionedNotifier // used to notify that node in streaming node manager has been changed. } // GetLatestWALLocated returns the server id of the node that the wal of the vChannel is located. @@ -131,9 +129,19 @@ func (s *StreamingNodeManager) GetWALLocated(vChannel string) int64 { // GetStreamingQueryNodeIDs returns the server ids of the streaming query nodes. func (s *StreamingNodeManager) GetStreamingQueryNodeIDs() typeutil.UniqueSet { - s.cond.L.Lock() - defer s.cond.L.Unlock() - return s.streamingNodes.Clone() + balancer, err := s.balancer.GetWithContext(context.Background()) + if err != nil { + panic(err) + } + streamingNodes, err := balancer.GetAllStreamingNodes(context.Background()) + if err != nil { + panic(err) + } + streamingNodeIDs := typeutil.NewUniqueSet() + for _, streamingNode := range streamingNodes { + streamingNodeIDs.Insert(streamingNode.ServerID) + } + return streamingNodeIDs } // ListenNodeChanged returns a listener for node changed event. 
@@ -160,13 +168,11 @@ func (s *StreamingNodeManager) execute() (err error) { ) error { s.cond.LockAndBroadcast() s.latestAssignments = make(map[string]types.PChannelInfoAssigned) - s.streamingNodes = typeutil.NewUniqueSet() for _, relation := range relations { s.latestAssignments[relation.Channel.Name] = relation - s.streamingNodes.Insert(relation.Node.ServerID) } s.nodeChangedNotifier.NotifyAll() - log.Info("streaming node manager updated", zap.Any("assignments", s.latestAssignments), zap.Any("streamingNodes", s.streamingNodes)) + log.Info("streaming node manager updated", zap.Any("assignments", s.latestAssignments)) s.cond.L.Unlock() return nil }); err != nil { diff --git a/internal/coordinator/snmanager/streaming_node_manager_test.go b/internal/coordinator/snmanager/streaming_node_manager_test.go index 84936a586e..24a80755e9 100644 --- a/internal/coordinator/snmanager/streaming_node_manager_test.go +++ b/internal/coordinator/snmanager/streaming_node_manager_test.go @@ -22,6 +22,7 @@ func TestStreamingNodeManager(t *testing.T) { b := mock_balancer.NewMockBalancer(t) ch := make(chan pChannelInfoAssigned, 1) + b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{}, nil) b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).Run( func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) { for { @@ -58,6 +59,11 @@ func TestStreamingNodeManager(t *testing.T) { node := m.GetWALLocated("a_test") assert.Equal(t, node, int64(1)) + + b.EXPECT().GetAllStreamingNodes(mock.Anything).Unset() + b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{ + 1: {ServerID: 1, Address: "localhost:1"}, + }, nil) streamingNodes = m.GetStreamingQueryNodeIDs() assert.Equal(t, len(streamingNodes), 1) diff --git a/internal/coordinator/snmanager/test_utility.go b/internal/coordinator/snmanager/test_utility.go index f60e9e524b..ee409d518b 100644 --- 
a/internal/coordinator/snmanager/test_utility.go +++ b/internal/coordinator/snmanager/test_utility.go @@ -3,6 +3,28 @@ package snmanager +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + func ResetStreamingNodeManager() { StaticStreamingNodeManager = newStreamingNodeManager() } + +func ResetDoNothingStreamingNodeManager(t *testing.T) { + ResetStreamingNodeManager() + b := mock_balancer.NewMockBalancer(t) + b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error { + <-ctx.Done() + return ctx.Err() + }).Maybe() + b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{}, nil).Maybe() + StaticStreamingNodeManager.SetBalancerReady(b) +} diff --git a/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go index aa110a809c..c042f9d9a1 100644 --- a/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go +++ b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go @@ -60,6 +60,64 @@ func (_c *MockBalancer_Close_Call) RunAndReturn(run func()) *MockBalancer_Close_ return _c } +// GetAllStreamingNodes provides a mock function with given fields: ctx +func (_m *MockBalancer) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetAllStreamingNodes") + } + + var r0 map[int64]*types.StreamingNodeInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*types.StreamingNodeInfo, error)); ok { + return rf(ctx) + } + if rf, ok := 
ret.Get(0).(func(context.Context) map[int64]*types.StreamingNodeInfo); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*types.StreamingNodeInfo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBalancer_GetAllStreamingNodes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAllStreamingNodes' +type MockBalancer_GetAllStreamingNodes_Call struct { + *mock.Call +} + +// GetAllStreamingNodes is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockBalancer_Expecter) GetAllStreamingNodes(ctx interface{}) *MockBalancer_GetAllStreamingNodes_Call { + return &MockBalancer_GetAllStreamingNodes_Call{Call: _e.mock.On("GetAllStreamingNodes", ctx)} +} + +func (_c *MockBalancer_GetAllStreamingNodes_Call) Run(run func(ctx context.Context)) *MockBalancer_GetAllStreamingNodes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockBalancer_GetAllStreamingNodes_Call) Return(_a0 map[int64]*types.StreamingNodeInfo, _a1 error) *MockBalancer_GetAllStreamingNodes_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBalancer_GetAllStreamingNodes_Call) RunAndReturn(run func(context.Context) (map[int64]*types.StreamingNodeInfo, error)) *MockBalancer_GetAllStreamingNodes_Call { + _c.Call.Return(run) + return _c +} + // GetLatestWALLocated provides a mock function with given fields: ctx, pchannel func (_m *MockBalancer) GetLatestWALLocated(ctx context.Context, pchannel string) (int64, bool) { ret := _m.Called(ctx, pchannel) diff --git a/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go b/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go index 4747109d9b..b7f409f35d 100644 --- a/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go +++ 
b/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go @@ -160,6 +160,64 @@ func (_c *MockManagerClient_CollectAllStatus_Call) RunAndReturn(run func(context return _c } +// GetAllStreamingNodes provides a mock function with given fields: ctx +func (_m *MockManagerClient) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetAllStreamingNodes") + } + + var r0 map[int64]*types.StreamingNodeInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*types.StreamingNodeInfo, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) map[int64]*types.StreamingNodeInfo); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*types.StreamingNodeInfo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockManagerClient_GetAllStreamingNodes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAllStreamingNodes' +type MockManagerClient_GetAllStreamingNodes_Call struct { + *mock.Call +} + +// GetAllStreamingNodes is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockManagerClient_Expecter) GetAllStreamingNodes(ctx interface{}) *MockManagerClient_GetAllStreamingNodes_Call { + return &MockManagerClient_GetAllStreamingNodes_Call{Call: _e.mock.On("GetAllStreamingNodes", ctx)} +} + +func (_c *MockManagerClient_GetAllStreamingNodes_Call) Run(run func(ctx context.Context)) *MockManagerClient_GetAllStreamingNodes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockManagerClient_GetAllStreamingNodes_Call) Return(_a0 map[int64]*types.StreamingNodeInfo, _a1 error) *MockManagerClient_GetAllStreamingNodes_Call { + _c.Call.Return(_a0, _a1) + 
return _c +} + +func (_c *MockManagerClient_GetAllStreamingNodes_Call) RunAndReturn(run func(context.Context) (map[int64]*types.StreamingNodeInfo, error)) *MockManagerClient_GetAllStreamingNodes_Call { + _c.Call.Return(run) + return _c +} + // Remove provides a mock function with given fields: ctx, pchannel func (_m *MockManagerClient) Remove(ctx context.Context, pchannel types.PChannelInfoAssigned) error { ret := _m.Called(ctx, pchannel) diff --git a/internal/querycoordv2/checkers/channel_checker_test.go b/internal/querycoordv2/checkers/channel_checker_test.go index f2e03f2f41..a16750aa0c 100644 --- a/internal/querycoordv2/checkers/channel_checker_test.go +++ b/internal/querycoordv2/checkers/channel_checker_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus/internal/coordinator/snmanager" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/querycoordv2/balance" @@ -54,6 +55,8 @@ func (suite *ChannelCheckerTestSuite) SetupSuite() { } func (suite *ChannelCheckerTestSuite) SetupTest() { + snmanager.ResetDoNothingStreamingNodeManager(suite.T()) + var err error config := GenerateEtcdConfig() cli, err := etcd.GetEtcdClient( diff --git a/internal/querycoordv2/meta/channel_dist_manager_test.go b/internal/querycoordv2/meta/channel_dist_manager_test.go index a62aa4221d..702da97029 100644 --- a/internal/querycoordv2/meta/channel_dist_manager_test.go +++ b/internal/querycoordv2/meta/channel_dist_manager_test.go @@ -86,6 +86,7 @@ func (suite *ChannelDistManagerSuite) SetupSuite() { } func (suite *ChannelDistManagerSuite) SetupTest() { + snmanager.ResetDoNothingStreamingNodeManager(suite.T()) suite.dist = NewChannelDistManager() // Distribution: // node 0 contains channel dmc0 @@ -347,27 +348,20 @@ func (suite *ChannelDistManagerSuite) TestGetShardLeader() { suite.Nil(leader) // Test 
streaming node + snmanager.ResetStreamingNodeManager() balancer := mock_balancer.NewMockBalancer(suite.T()) balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { - versions := []typeutil.VersionInt64Pair{ - {Global: 1, Local: 2}, - } - pchans := [][]types.PChannelInfoAssigned{ - { - types.PChannelInfoAssigned{ - Channel: types.PChannelInfo{Name: "pchannel3", Term: 1}, - Node: types.StreamingNodeInfo{ServerID: 4, Address: "localhost:1"}, - }, - }, - } - for i := 0; i < len(versions); i++ { - cb(versions[i], pchans[i]) - } <-ctx.Done() - return context.Cause(ctx) + return ctx.Err() }) - defer snmanager.ResetStreamingNodeManager() + balancer.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{ + 4: { + ServerID: 4, + Address: "localhost:1", + }, + }, nil) snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer) + defer snmanager.ResetStreamingNodeManager() suite.Eventually(func() bool { nodeIDs := snmanager.StaticStreamingNodeManager.GetStreamingQueryNodeIDs() return nodeIDs.Contain(4) diff --git a/internal/querycoordv2/observers/replica_observer_test.go b/internal/querycoordv2/observers/replica_observer_test.go index e8132ca7cf..8e6f6d9a32 100644 --- a/internal/querycoordv2/observers/replica_observer_test.go +++ b/internal/querycoordv2/observers/replica_observer_test.go @@ -66,6 +66,8 @@ func (suite *ReplicaObserverSuite) SetupSuite() { } func (suite *ReplicaObserverSuite) SetupTest() { + snmanager.ResetDoNothingStreamingNodeManager(suite.T()) + var err error config := GenerateEtcdConfig() cli, err := etcd.GetEtcdClient( @@ -212,51 +214,36 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { } func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() { + suite.observer.Stop() + snmanager.ResetStreamingNodeManager() balancer := mock_balancer.NewMockBalancer(suite.T()) change 
:= make(chan struct{}) balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { - versions := []typeutil.VersionInt64Pair{ - {Global: 1, Local: 2}, - {Global: 1, Local: 3}, - } - pchans := [][]types.PChannelInfoAssigned{ - { - types.PChannelInfoAssigned{ - Channel: types.PChannelInfo{Name: "pchannel", Term: 1}, - Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, - }, - types.PChannelInfoAssigned{ - Channel: types.PChannelInfo{Name: "pchannel2", Term: 1}, - Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"}, - }, - types.PChannelInfoAssigned{ - Channel: types.PChannelInfo{Name: "pchannel3", Term: 1}, - Node: types.StreamingNodeInfo{ServerID: 3, Address: "localhost:1"}, - }, - }, - { - types.PChannelInfoAssigned{ - Channel: types.PChannelInfo{Name: "pchannel", Term: 1}, - Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, - }, - types.PChannelInfoAssigned{ - Channel: types.PChannelInfo{Name: "pchannel2", Term: 1}, - Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"}, - }, - types.PChannelInfoAssigned{ - Channel: types.PChannelInfo{Name: "pchannel3", Term: 2}, - Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"}, - }, - }, - } - for i := 0; i < len(versions); i++ { - cb(versions[i], pchans[i]) - <-change - } <-ctx.Done() - return context.Cause(ctx) + return ctx.Err() + }) + balancer.EXPECT().GetAllStreamingNodes(mock.Anything).RunAndReturn(func(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) { + pchans := []map[int64]*types.StreamingNodeInfo{ + { + 1: {ServerID: 1, Address: "localhost:1"}, + 2: {ServerID: 2, Address: "localhost:2"}, + 3: {ServerID: 3, Address: "localhost:3"}, + }, + { + 1: {ServerID: 1, Address: "localhost:1"}, + 2: {ServerID: 2, Address: "localhost:2"}, + }, + } + select { + case <-change: + return pchans[1], nil + 
default: + return pchans[0], nil + } }) snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer) + suite.observer = NewReplicaObserver(suite.meta, suite.distMgr) + suite.observer.Start() ctx := context.Background() err := suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collectionID, 2)) @@ -305,9 +292,12 @@ func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() { suite.Equal(nodes.Len(), 2) } +func (suite *ReplicaObserverSuite) TearDownTest() { + suite.observer.Stop() +} + func (suite *ReplicaObserverSuite) TearDownSuite() { suite.kv.Close() - suite.observer.Stop() streamingutil.UnsetStreamingServiceEnabled() } diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index b3d83068cd..3467bf4257 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/json" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore" @@ -142,6 +143,8 @@ func (suite *TaskSuite) TearDownSuite() { } func (suite *TaskSuite) SetupTest() { + snmanager.ResetDoNothingStreamingNodeManager(suite.T()) + config := GenerateEtcdConfig() suite.ctx = context.Background() cli, err := etcd.GetEtcdClient( diff --git a/internal/streamingcoord/server/balancer/balancer.go b/internal/streamingcoord/server/balancer/balancer.go index b2c41e1bd1..a50eacfedd 100644 --- a/internal/streamingcoord/server/balancer/balancer.go +++ b/internal/streamingcoord/server/balancer/balancer.go @@ -21,6 +21,9 @@ var ( // Balancer is a local component, it should promise all channel can be assigned, and reach the final consistency. // Balancer should be thread safe. 
type Balancer interface { + // GetAllStreamingNodes fetches all streaming node info. + GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) + // UpdateBalancePolicy update the balance policy. UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error) diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go index 367eb971de..9f07043950 100644 --- a/internal/streamingcoord/server/balancer/balancer_impl.go +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -80,6 +80,11 @@ func (b *balancerImpl) RegisterStreamingEnabledNotifier(notifier *syncutil.Async b.channelMetaManager.RegisterStreamingEnabledNotifier(notifier) } +// GetAllStreamingNodes fetches all streaming node info. +func (b *balancerImpl) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) { + return resource.Resource().StreamingNodeManagerClient().GetAllStreamingNodes(ctx) +} + // GetLatestWALLocated returns the server id of the node that the wal of the vChannel is located. func (b *balancerImpl) GetLatestWALLocated(ctx context.Context, pchannel string) (int64, bool) { return b.channelMetaManager.GetLatestWALLocated(ctx, pchannel) diff --git a/internal/streamingnode/client/manager/manager_client.go b/internal/streamingnode/client/manager/manager_client.go index bdea8f20e8..43a82a34d8 100644 --- a/internal/streamingnode/client/manager/manager_client.go +++ b/internal/streamingnode/client/manager/manager_client.go @@ -29,7 +29,12 @@ type ManagerClient interface { // WatchNodeChanged returns a channel that receive the signal that a streaming node change. WatchNodeChanged(ctx context.Context) (<-chan struct{}, error) + // GetAllStreamingNodes fetches all streaming node info. + // The result is fetched from service discovery, so there's no rpc call. 
+ GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) + // CollectAllStatus collects status of all streamingnode, such as load balance attributes. + // The result is fetched from service discovery plus a broadcast rpc call to all streamingnodes. CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) // Assign a wal instance for the channel on streaming node of given server id. diff --git a/internal/streamingnode/client/manager/manager_client_impl.go b/internal/streamingnode/client/manager/manager_client_impl.go index 6393d2bf06..9b2ea5852e 100644 --- a/internal/streamingnode/client/manager/manager_client_impl.go +++ b/internal/streamingnode/client/manager/manager_client_impl.go @@ -53,6 +53,29 @@ func (c *managerClientImpl) WatchNodeChanged(ctx context.Context) (<-chan struct return resultCh, nil } +// GetAllStreamingNodes fetches all streaming node info. +func (c *managerClientImpl) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) { + if !c.lifetime.Add(typeutil.LifetimeStateWorking) { + return nil, status.NewOnShutdownError("manager client is closing") + } + defer c.lifetime.Done() + + // Get all discovered streamingnode. + state, err := c.rb.Resolver().GetLatestState(ctx) + if err != nil { + return nil, err + } + + result := make(map[int64]*types.StreamingNodeInfo, len(state.State.Addresses)) + for serverID, session := range state.Sessions() { + result[serverID] = &types.StreamingNodeInfo{ + ServerID: serverID, + Address: session.Address, + } + } + return result, nil +} + // CollectAllStatus collects status in all underlying streamingnode. 
func (c *managerClientImpl) CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) { if !c.lifetime.Add(typeutil.LifetimeStateWorking) { diff --git a/internal/streamingnode/client/manager/manager_test.go b/internal/streamingnode/client/manager/manager_test.go index 6613ac7e0b..2bb8a7e5e0 100644 --- a/internal/streamingnode/client/manager/manager_test.go +++ b/internal/streamingnode/client/manager/manager_test.go @@ -74,6 +74,8 @@ func TestManager(t *testing.T) { states := []map[uint64]bool{ {1: false, 2: false, 3: true}, {1: true, 2: false}, + {1: true, 2: false}, + {1: true, 2: false}, } r.EXPECT().GetLatestState(mock.Anything).Unset() r.EXPECT().GetLatestState(mock.Anything).RunAndReturn(func(ctx context.Context) (discoverer.VersionedState, error) { @@ -88,6 +90,10 @@ func TestManager(t *testing.T) { assert.ErrorIs(t, nodes[3].Err, types.ErrNotAlive) assert.ErrorIs(t, nodes[1].Err, types.ErrStopping) + nodeInfos, err := m.GetAllStreamingNodes(context.Background()) + assert.NoError(t, err) + assert.Len(t, nodeInfos, 2) + // Test Assign serverID := int64(2) managerServiceClient.EXPECT().Assign(mock.Anything, mock.Anything).RunAndReturn( @@ -123,6 +129,9 @@ func TestManager(t *testing.T) { rb.EXPECT().Close().Return() m.Close() + nodeInfos, err = m.GetAllStreamingNodes(context.Background()) + assert.Nil(t, nodeInfos) + assert.Error(t, err) nodes, err = m.CollectAllStatus(context.Background()) assert.Nil(t, nodes) assert.Error(t, err)