diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go index e3fd1e5121..ff8103e3b4 100644 --- a/internal/querycoordv2/balance/balance.go +++ b/internal/querycoordv2/balance/balance.go @@ -46,10 +46,13 @@ func (segPlan *SegmentAssignPlan) String() string { } type ChannelAssignPlan struct { - Channel *meta.DmChannel - Replica *meta.Replica - From int64 - To int64 + Channel *meta.DmChannel + Replica *meta.Replica + From int64 + To int64 + FromScore int64 + ToScore int64 + ChannelScore int64 } func (chanPlan *ChannelAssignPlan) String() string { diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index d69a676c46..d5be949695 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -393,7 +393,7 @@ func newNodeItem(currentScore int, nodeID int64) nodeItem { func (b *nodeItem) getPriority() int { // if node lacks more score between assignedScore and currentScore, then higher priority - return int(b.currentScore - b.assignedScore) + return int(math.Ceil(b.currentScore - b.assignedScore)) } func (b *nodeItem) setPriority(priority int) { diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index f7842a4ce4..8e19c8921d 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -191,19 +191,19 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64 } from := int64(-1) - // fromScore := int64(0) + fromScore := int64(0) if sourceNode != nil { from = sourceNode.nodeID - // fromScore = int64(sourceNode.getPriority()) + fromScore = int64(sourceNode.getPriority()) } plan := ChannelAssignPlan{ - From: from, - To: targetNode.nodeID, - Channel: ch, - // FromScore: fromScore, - // ToScore: int64(targetNode.getPriority()), - // SegmentScore: int64(scoreChanges), + From: from, + To: targetNode.nodeID, + Channel: ch, + FromScore: fromScore, + ToScore: int64(targetNode.getPriority()), + ChannelScore: int64(scoreChanges), } br.AddRecord(StrRecordf("add segment plan %s", plan)) plans = append(plans, plan) @@ -487,6 +487,20 @@ func (b *ScoreBasedBalancer) BalanceReplica(ctx context.Context, replica *meta.R return segmentPlans, channelPlans } +func (b *ScoreBasedBalancer) genStoppingChannelPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64, roNodes []int64) []ChannelAssignPlan { + channelPlans := make([]ChannelAssignPlan, 0) + for _, nodeID := range roNodes { + dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID)) + plans := b.AssignChannel(ctx, replica.GetCollectionID(), dmChannels, rwNodes, false) + for i := range plans { + plans[i].From = nodeID + plans[i].Replica = replica + } + channelPlans = append(channelPlans, plans...) + } + return channelPlans +} + func (b *ScoreBasedBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan { segmentPlans := make([]SegmentAssignPlan, 0) for _, nodeID := range offlineNodes { diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go index 3bf70e7894..f24d8394b2 100644 --- a/internal/querycoordv2/balance/score_based_balancer_test.go +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -23,6 +23,7 @@ import ( "github.com/samber/lo" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "go.uber.org/atomic" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" @@ -1470,3 +1471,127 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnChannelExclusive() _, channelPlans = suite.getCollectionBalancePlans(balancer, 3) suite.Len(channelPlans, 2) } + +func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnStoppingNode() { + ctx := context.Background() + balancer := suite.balancer + + // mock 10 collections with each collection has 1 channel + collectionNum := 10 + channelNum := 1 + for i := 1; i <= collectionNum; i++ { + collectionID := int64(i) + collection := utils.CreateTestCollection(collectionID, int32(1)) + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, collectionID)) + balancer.meta.ReplicaManager.Spawn(ctx, collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil) + + channels := make([]*datapb.VchannelInfo, channelNum) + for i := 0; i < channelNum; i++ { + channels[i] = &datapb.VchannelInfo{CollectionID: collectionID, ChannelName: fmt.Sprintf("channel-%d-%d", collectionID, i)} + } + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return( + channels, nil, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe() + balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) + } + + // mock querynode-1 to node manager + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + nodeInfo.SetState(session.NodeStateNormal) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 1) + utils.RecoverAllCollection(balancer.meta) + + // mock channel distribution + channelDist := make([]*meta.DmChannel, 0) + for i := 1; i <= collectionNum; i++ { + collectionID := int64(i) + for i := 0; i < channelNum; i++ { + channelDist = append(channelDist, &meta.DmChannel{ + VchannelInfo: &datapb.VchannelInfo{CollectionID: collectionID, ChannelName: fmt.Sprintf("channel-%d-%d", collectionID, i)}, Node: 1, + }) + } + } + balancer.dist.ChannelDistManager.Update(1, channelDist...) + + // assert balance channel won't happens on 1 querynode + ret := make([]ChannelAssignPlan, 0) + for i := 1; i <= collectionNum; i++ { + collectionID := int64(i) + _, channelPlans := suite.getCollectionBalancePlans(balancer, collectionID) + ret = append(ret, channelPlans...) + } + suite.Len(ret, 0) + + // mock querynode-2 and querynode-3 to node manager + nodeInfo2 := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 2, + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + suite.balancer.nodeManager.Add(nodeInfo2) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 2) + // mock querynode-2 and querynode-3 to node manager + nodeInfo3 := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 3, + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + suite.balancer.nodeManager.Add(nodeInfo3) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 3) + utils.RecoverAllCollection(balancer.meta) + // mock querynode-1 to stopping, trigger stopping balance, expect to generate 10 balance channel task, and 5 for node-2, 5 for node-3 + nodeInfo.SetState(session.NodeStateStopping) + suite.balancer.meta.ResourceManager.HandleNodeDown(ctx, 1) + utils.RecoverAllCollection(balancer.meta) + + node2Counter := atomic.NewInt32(0) + node3Counter := atomic.NewInt32(0) + + suite.mockScheduler.ExpectedCalls = nil + suite.mockScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskDelta(mock.Anything, mock.Anything).RunAndReturn(func(nodeID, collection int64) int { + if collection == -1 { + if nodeID == 2 { + return int(node2Counter.Load()) + } + + if nodeID == 3 { + return int(node3Counter.Load()) + } + } + return 0 + }) + suite.mockScheduler.EXPECT().GetSegmentTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + suite.mockScheduler.EXPECT().GetChannelTaskNum(mock.Anything, mock.Anything).Return(0).Maybe() + + for i := 1; i <= collectionNum; i++ { + collectionID := int64(i) + _, channelPlans := suite.getCollectionBalancePlans(balancer, collectionID) + suite.Len(channelPlans, 1) + if channelPlans[0].To == 2 { + node2Counter.Inc() + } + + if channelPlans[0].To == 3 { + node3Counter.Inc() + } + + if i%2 == 0 { + suite.Equal(node2Counter.Load(), node3Counter.Load()) + } + } + suite.Equal(node2Counter.Load(), int32(5)) + suite.Equal(node3Counter.Load(), int32(5)) +}