fix: get streamingnodes from service discovery without channel assign (#44033)

issue: #43767

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-08-26 14:29:51 +08:00 committed by GitHub
parent 2ad41872da
commit 575345ae7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 250 additions and 65 deletions

View File

@ -26,7 +26,6 @@ func newStreamingNodeManager() *StreamingNodeManager {
balancer: syncutil.NewFuture[balancer.Balancer](),
cond: syncutil.NewContextCond(&sync.Mutex{}),
latestAssignments: make(map[string]types.PChannelInfoAssigned),
streamingNodes: typeutil.NewUniqueSet(),
nodeChangedNotifier: syncutil.NewVersionedNotifier(),
}
go snm.execute()
@ -69,8 +68,7 @@ type StreamingNodeManager struct {
// The coord is merged after 2.6, so we don't need to make distribution safe.
cond *syncutil.ContextCond
latestAssignments map[string]types.PChannelInfoAssigned // The latest assignments info got from streaming coord balance module.
streamingNodes typeutil.UniqueSet
nodeChangedNotifier *syncutil.VersionedNotifier // used to notify that node in streaming node manager has been changed.
nodeChangedNotifier *syncutil.VersionedNotifier // used to notify that node in streaming node manager has been changed.
}
// GetLatestWALLocated returns the server id of the node that the wal of the vChannel is located.
@ -131,9 +129,19 @@ func (s *StreamingNodeManager) GetWALLocated(vChannel string) int64 {
// GetStreamingQueryNodeIDs returns the server ids of the streaming query nodes.
//
// The ids are read from the balancer's GetAllStreamingNodes on every call
// rather than being derived from the latest channel assignments, so a node is
// reported even when no channel is currently assigned to it.
//
// The balancer is obtained with a background context, so the call is not
// cancellable by the caller. It panics if the balancer cannot be obtained or
// if listing the streaming nodes fails.
func (s *StreamingNodeManager) GetStreamingQueryNodeIDs() typeutil.UniqueSet {
	ctx := context.Background()

	// Named b rather than balancer so the balancer package is not shadowed.
	b, err := s.balancer.GetWithContext(ctx)
	if err != nil {
		panic(err)
	}
	streamingNodes, err := b.GetAllStreamingNodes(ctx)
	if err != nil {
		panic(err)
	}

	streamingNodeIDs := typeutil.NewUniqueSet()
	for _, streamingNode := range streamingNodes {
		streamingNodeIDs.Insert(streamingNode.ServerID)
	}
	return streamingNodeIDs
}
// ListenNodeChanged returns a listener for node changed event.
@ -160,13 +168,11 @@ func (s *StreamingNodeManager) execute() (err error) {
) error {
s.cond.LockAndBroadcast()
s.latestAssignments = make(map[string]types.PChannelInfoAssigned)
s.streamingNodes = typeutil.NewUniqueSet()
for _, relation := range relations {
s.latestAssignments[relation.Channel.Name] = relation
s.streamingNodes.Insert(relation.Node.ServerID)
}
s.nodeChangedNotifier.NotifyAll()
log.Info("streaming node manager updated", zap.Any("assignments", s.latestAssignments), zap.Any("streamingNodes", s.streamingNodes))
log.Info("streaming node manager updated", zap.Any("assignments", s.latestAssignments))
s.cond.L.Unlock()
return nil
}); err != nil {

View File

@ -22,6 +22,7 @@ func TestStreamingNodeManager(t *testing.T) {
b := mock_balancer.NewMockBalancer(t)
ch := make(chan pChannelInfoAssigned, 1)
b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{}, nil)
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).Run(
func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) {
for {
@ -58,6 +59,11 @@ func TestStreamingNodeManager(t *testing.T) {
node := m.GetWALLocated("a_test")
assert.Equal(t, node, int64(1))
b.EXPECT().GetAllStreamingNodes(mock.Anything).Unset()
b.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{
1: {ServerID: 1, Address: "localhost:1"},
}, nil)
streamingNodes = m.GetStreamingQueryNodeIDs()
assert.Equal(t, len(streamingNodes), 1)

View File

@ -3,6 +3,28 @@
package snmanager
import (
"context"
"testing"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// ResetStreamingNodeManager replaces the package-level StaticStreamingNodeManager
// with a freshly constructed manager. The new manager has no balancer set yet.
// NOTE(review): the previously installed manager is only dropped here, not
// explicitly shut down — confirm that is acceptable for tests.
func ResetStreamingNodeManager() {
StaticStreamingNodeManager = newStreamingNodeManager()
}
// ResetDoNothingStreamingNodeManager installs a fresh StaticStreamingNodeManager
// backed by a mock balancer that never reports anything: the assignment watch
// only waits for its context to end, and the node listing is always empty.
// Both expectations are optional, so a test that never touches the manager
// still passes the mock's expectation check.
func ResetDoNothingStreamingNodeManager(t *testing.T) {
	ResetStreamingNodeManager()

	// waitForCancel never invokes the assignment callback; it just parks until
	// the watch context is done and reports why.
	waitForCancel := func(ctx context.Context, _ func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error {
		<-ctx.Done()
		return ctx.Err()
	}
	noNodes := map[int64]*types.StreamingNodeInfo{}

	mockBalancer := mock_balancer.NewMockBalancer(t)
	mockBalancer.EXPECT().
		WatchChannelAssignments(mock.Anything, mock.Anything).
		RunAndReturn(waitForCancel).
		Maybe()
	mockBalancer.EXPECT().
		GetAllStreamingNodes(mock.Anything).
		Return(noNodes, nil).
		Maybe()

	StaticStreamingNodeManager.SetBalancerReady(mockBalancer)
}

View File

@ -60,6 +60,64 @@ func (_c *MockBalancer_Close_Call) RunAndReturn(run func()) *MockBalancer_Close_
return _c
}
// GetAllStreamingNodes provides a mock function with given fields: ctx
//
// The configured return values are resolved as follows:
//   - if the first value is a func(context.Context) (map, error), it is called
//     with ctx and its results are returned directly;
//   - otherwise the first value is a func(context.Context) map (called with
//     ctx), a map value, or nil, and the second value is a
//     func(context.Context) error (called with ctx) or an error value.
//
// It panics when the matched expectation carries no return value.
func (_m *MockBalancer) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) {
ret := _m.Called(ctx)
// No return value was configured for this call.
if len(ret) == 0 {
panic("no return value specified for GetAllStreamingNodes")
}
var r0 map[int64]*types.StreamingNodeInfo
var r1 error
// A single function producing both results takes precedence over per-slot values.
if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*types.StreamingNodeInfo, error)); ok {
return rf(ctx)
}
// First result: computed by a function, or taken as a literal map (nil stays nil).
if rf, ok := ret.Get(0).(func(context.Context) map[int64]*types.StreamingNodeInfo); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]*types.StreamingNodeInfo)
}
}
// Second result: computed by a function, or read as an error value.
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBalancer_GetAllStreamingNodes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAllStreamingNodes'
type MockBalancer_GetAllStreamingNodes_Call struct {
*mock.Call
}
// GetAllStreamingNodes is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockBalancer_Expecter) GetAllStreamingNodes(ctx interface{}) *MockBalancer_GetAllStreamingNodes_Call {
return &MockBalancer_GetAllStreamingNodes_Call{Call: _e.mock.On("GetAllStreamingNodes", ctx)}
}
// Run registers a callback on the call; the callback receives the call's first
// argument asserted to context.Context. It returns the call for chaining.
func (_c *MockBalancer_GetAllStreamingNodes_Call) Run(run func(ctx context.Context)) *MockBalancer_GetAllStreamingNodes_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
// Return sets the fixed map and error the mocked method hands back, and
// returns the call for chaining.
func (_c *MockBalancer_GetAllStreamingNodes_Call) Return(_a0 map[int64]*types.StreamingNodeInfo, _a1 error) *MockBalancer_GetAllStreamingNodes_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// RunAndReturn stores run itself as the call's single return value; the mocked
// GetAllStreamingNodes recognises a function of this signature and invokes it
// with the call's ctx to produce both results.
func (_c *MockBalancer_GetAllStreamingNodes_Call) RunAndReturn(run func(context.Context) (map[int64]*types.StreamingNodeInfo, error)) *MockBalancer_GetAllStreamingNodes_Call {
_c.Call.Return(run)
return _c
}
// GetLatestWALLocated provides a mock function with given fields: ctx, pchannel
func (_m *MockBalancer) GetLatestWALLocated(ctx context.Context, pchannel string) (int64, bool) {
ret := _m.Called(ctx, pchannel)

View File

@ -160,6 +160,64 @@ func (_c *MockManagerClient_CollectAllStatus_Call) RunAndReturn(run func(context
return _c
}
// GetAllStreamingNodes provides a mock function with given fields: ctx
//
// The configured return values are resolved as follows:
//   - if the first value is a func(context.Context) (map, error), it is called
//     with ctx and its results are returned directly;
//   - otherwise the first value is a func(context.Context) map (called with
//     ctx), a map value, or nil, and the second value is a
//     func(context.Context) error (called with ctx) or an error value.
//
// It panics when the matched expectation carries no return value.
func (_m *MockManagerClient) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) {
ret := _m.Called(ctx)
// No return value was configured for this call.
if len(ret) == 0 {
panic("no return value specified for GetAllStreamingNodes")
}
var r0 map[int64]*types.StreamingNodeInfo
var r1 error
// A single function producing both results takes precedence over per-slot values.
if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*types.StreamingNodeInfo, error)); ok {
return rf(ctx)
}
// First result: computed by a function, or taken as a literal map (nil stays nil).
if rf, ok := ret.Get(0).(func(context.Context) map[int64]*types.StreamingNodeInfo); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]*types.StreamingNodeInfo)
}
}
// Second result: computed by a function, or read as an error value.
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockManagerClient_GetAllStreamingNodes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAllStreamingNodes'
type MockManagerClient_GetAllStreamingNodes_Call struct {
*mock.Call
}
// GetAllStreamingNodes is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockManagerClient_Expecter) GetAllStreamingNodes(ctx interface{}) *MockManagerClient_GetAllStreamingNodes_Call {
return &MockManagerClient_GetAllStreamingNodes_Call{Call: _e.mock.On("GetAllStreamingNodes", ctx)}
}
// Run registers a callback on the call; the callback receives the call's first
// argument asserted to context.Context. It returns the call for chaining.
func (_c *MockManagerClient_GetAllStreamingNodes_Call) Run(run func(ctx context.Context)) *MockManagerClient_GetAllStreamingNodes_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
// Return sets the fixed map and error the mocked method hands back, and
// returns the call for chaining.
func (_c *MockManagerClient_GetAllStreamingNodes_Call) Return(_a0 map[int64]*types.StreamingNodeInfo, _a1 error) *MockManagerClient_GetAllStreamingNodes_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// RunAndReturn stores run itself as the call's single return value; the mocked
// GetAllStreamingNodes recognises a function of this signature and invokes it
// with the call's ctx to produce both results.
func (_c *MockManagerClient_GetAllStreamingNodes_Call) RunAndReturn(run func(context.Context) (map[int64]*types.StreamingNodeInfo, error)) *MockManagerClient_GetAllStreamingNodes_Call {
_c.Call.Return(run)
return _c
}
// Remove provides a mock function with given fields: ctx, pchannel
func (_m *MockManagerClient) Remove(ctx context.Context, pchannel types.PChannelInfoAssigned) error {
ret := _m.Called(ctx, pchannel)

View File

@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
"github.com/milvus-io/milvus/internal/querycoordv2/balance"
@ -54,6 +55,8 @@ func (suite *ChannelCheckerTestSuite) SetupSuite() {
}
func (suite *ChannelCheckerTestSuite) SetupTest() {
snmanager.ResetDoNothingStreamingNodeManager(suite.T())
var err error
config := GenerateEtcdConfig()
cli, err := etcd.GetEtcdClient(

View File

@ -86,6 +86,7 @@ func (suite *ChannelDistManagerSuite) SetupSuite() {
}
func (suite *ChannelDistManagerSuite) SetupTest() {
snmanager.ResetDoNothingStreamingNodeManager(suite.T())
suite.dist = NewChannelDistManager()
// Distribution:
// node 0 contains channel dmc0
@ -347,27 +348,20 @@ func (suite *ChannelDistManagerSuite) TestGetShardLeader() {
suite.Nil(leader)
// Test streaming node
snmanager.ResetStreamingNodeManager()
balancer := mock_balancer.NewMockBalancer(suite.T())
balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
versions := []typeutil.VersionInt64Pair{
{Global: 1, Local: 2},
}
pchans := [][]types.PChannelInfoAssigned{
{
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel3", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 4, Address: "localhost:1"},
},
},
}
for i := 0; i < len(versions); i++ {
cb(versions[i], pchans[i])
}
<-ctx.Done()
return context.Cause(ctx)
return ctx.Err()
})
defer snmanager.ResetStreamingNodeManager()
balancer.EXPECT().GetAllStreamingNodes(mock.Anything).Return(map[int64]*types.StreamingNodeInfo{
4: {
ServerID: 4,
Address: "localhost:1",
},
}, nil)
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
defer snmanager.ResetStreamingNodeManager()
suite.Eventually(func() bool {
nodeIDs := snmanager.StaticStreamingNodeManager.GetStreamingQueryNodeIDs()
return nodeIDs.Contain(4)

View File

@ -66,6 +66,8 @@ func (suite *ReplicaObserverSuite) SetupSuite() {
}
func (suite *ReplicaObserverSuite) SetupTest() {
snmanager.ResetDoNothingStreamingNodeManager(suite.T())
var err error
config := GenerateEtcdConfig()
cli, err := etcd.GetEtcdClient(
@ -212,51 +214,36 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() {
}
func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() {
suite.observer.Stop()
snmanager.ResetStreamingNodeManager()
balancer := mock_balancer.NewMockBalancer(suite.T())
change := make(chan struct{})
balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
versions := []typeutil.VersionInt64Pair{
{Global: 1, Local: 2},
{Global: 1, Local: 3},
}
pchans := [][]types.PChannelInfoAssigned{
{
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel2", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel3", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 3, Address: "localhost:1"},
},
},
{
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel2", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel3", Term: 2},
Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"},
},
},
}
for i := 0; i < len(versions); i++ {
cb(versions[i], pchans[i])
<-change
}
<-ctx.Done()
return context.Cause(ctx)
return ctx.Err()
})
balancer.EXPECT().GetAllStreamingNodes(mock.Anything).RunAndReturn(func(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) {
pchans := []map[int64]*types.StreamingNodeInfo{
{
1: {ServerID: 1, Address: "localhost:1"},
2: {ServerID: 2, Address: "localhost:2"},
3: {ServerID: 3, Address: "localhost:3"},
},
{
1: {ServerID: 1, Address: "localhost:1"},
2: {ServerID: 2, Address: "localhost:2"},
},
}
select {
case <-change:
return pchans[1], nil
default:
return pchans[0], nil
}
})
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
suite.observer = NewReplicaObserver(suite.meta, suite.distMgr)
suite.observer.Start()
ctx := context.Background()
err := suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collectionID, 2))
@ -305,9 +292,12 @@ func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() {
suite.Equal(nodes.Len(), 2)
}
// TearDownTest stops the suite's current observer once a test has finished.
func (suite *ReplicaObserverSuite) TearDownTest() {
suite.observer.Stop()
}
// TearDownSuite releases the suite-wide resources: it closes the kv store and
// calls streamingutil.UnsetStreamingServiceEnabled.
//
// The observer is deliberately not stopped here: TearDownTest already stops it
// after every test, so a second Stop issued after the kv store is closed would
// only be redundant.
func (suite *ReplicaObserverSuite) TearDownSuite() {
	suite.kv.Close()
	streamingutil.UnsetStreamingServiceEnabled()
}

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/json"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore"
@ -142,6 +143,8 @@ func (suite *TaskSuite) TearDownSuite() {
}
func (suite *TaskSuite) SetupTest() {
snmanager.ResetDoNothingStreamingNodeManager(suite.T())
config := GenerateEtcdConfig()
suite.ctx = context.Background()
cli, err := etcd.GetEtcdClient(

View File

@ -21,6 +21,9 @@ var (
// Balancer is a local component, it should promise all channel can be assigned, and reach the final consistency.
// Balancer should be thread safe.
type Balancer interface {
// GetAllStreamingNodes fetches all streaming node info.
GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error)
// UpdateBalancePolicy update the balance policy.
UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error)

View File

@ -80,6 +80,11 @@ func (b *balancerImpl) RegisterStreamingEnabledNotifier(notifier *syncutil.Async
b.channelMetaManager.RegisterStreamingEnabledNotifier(notifier)
}
// GetAllStreamingNodes returns the info of every streaming node, keyed by
// server id. The balancer holds no node list of its own for this call; it
// forwards the request to the streaming node manager client of the shared
// resource and returns that client's result unchanged.
func (b *balancerImpl) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) {
	nodeManagerClient := resource.Resource().StreamingNodeManagerClient()
	return nodeManagerClient.GetAllStreamingNodes(ctx)
}
// GetLatestWALLocated returns the server id of the node that the wal of the vChannel is located.
func (b *balancerImpl) GetLatestWALLocated(ctx context.Context, pchannel string) (int64, bool) {
return b.channelMetaManager.GetLatestWALLocated(ctx, pchannel)

View File

@ -29,7 +29,12 @@ type ManagerClient interface {
// WatchNodeChanged returns a channel that receive the signal that a streaming node change.
WatchNodeChanged(ctx context.Context) (<-chan struct{}, error)
// GetAllStreamingNodes fetches all streaming node info.
// The result is fetched from service discovery, so there's no rpc call.
GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error)
// CollectAllStatus collects status of all streamingnode, such as load balance attributes.
// The result is fetched from service discovery, and then a broadcast rpc call is made to all streamingnodes.
CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error)
// Assign a wal instance for the channel on streaming node of given server id.

View File

@ -53,6 +53,29 @@ func (c *managerClientImpl) WatchNodeChanged(ctx context.Context) (<-chan struct
return resultCh, nil
}
// GetAllStreamingNodes lists every streaming node present in the resolver's
// latest state, keyed by server id.
//
// It returns an on-shutdown error when the client lifetime no longer accepts
// work, and forwards any error from reading the latest state.
func (c *managerClientImpl) GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error) {
	if ok := c.lifetime.Add(typeutil.LifetimeStateWorking); !ok {
		return nil, status.NewOnShutdownError("manager client is closing")
	}
	defer c.lifetime.Done()

	// Take the most recent view of the discovered streamingnodes.
	latest, err := c.rb.Resolver().GetLatestState(ctx)
	if err != nil {
		return nil, err
	}

	// Sized from the address list; one entry is produced per session.
	nodes := make(map[int64]*types.StreamingNodeInfo, len(latest.State.Addresses))
	for id, sess := range latest.Sessions() {
		info := &types.StreamingNodeInfo{
			ServerID: id,
			Address:  sess.Address,
		}
		nodes[id] = info
	}
	return nodes, nil
}
// CollectAllStatus collects status in all underlying streamingnode.
func (c *managerClientImpl) CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) {
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {

View File

@ -74,6 +74,8 @@ func TestManager(t *testing.T) {
states := []map[uint64]bool{
{1: false, 2: false, 3: true},
{1: true, 2: false},
{1: true, 2: false},
{1: true, 2: false},
}
r.EXPECT().GetLatestState(mock.Anything).Unset()
r.EXPECT().GetLatestState(mock.Anything).RunAndReturn(func(ctx context.Context) (discoverer.VersionedState, error) {
@ -88,6 +90,10 @@ func TestManager(t *testing.T) {
assert.ErrorIs(t, nodes[3].Err, types.ErrNotAlive)
assert.ErrorIs(t, nodes[1].Err, types.ErrStopping)
nodeInfos, err := m.GetAllStreamingNodes(context.Background())
assert.NoError(t, err)
assert.Len(t, nodeInfos, 2)
// Test Assign
serverID := int64(2)
managerServiceClient.EXPECT().Assign(mock.Anything, mock.Anything).RunAndReturn(
@ -123,6 +129,9 @@ func TestManager(t *testing.T) {
rb.EXPECT().Close().Return()
m.Close()
nodeInfos, err = m.GetAllStreamingNodes(context.Background())
assert.Nil(t, nodeInfos)
assert.Error(t, err)
nodes, err = m.CollectAllStatus(context.Background())
assert.Nil(t, nodes)
assert.Error(t, err)