diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go
index c1f1f6f178..b0112a176c 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer.go
@@ -29,6 +29,7 @@ import (
 	"github.com/milvus-io/milvus/internal/querycoordv2/session"
 	"github.com/milvus-io/milvus/internal/querycoordv2/task"
 	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
 type RowCountBasedBalancer struct {
@@ -175,26 +176,31 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) ([]Segment
 		return nil, nil
 	}
 
-	if len(offlineNodes) > 0 {
-		log.Info("Balance for stopping nodes",
-			zap.Any("stoppingNodes", offlineNodes),
-			zap.Any("onlineNodes", onlineNodes),
+	segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0)
+	if len(offlineNodes) != 0 {
+		log.Info("Handle stopping nodes",
+			zap.Int64("collection", replica.CollectionID),
+			zap.Int64("replica id", replica.Replica.GetID()),
+			zap.String("replica group", replica.Replica.GetResourceGroup()),
+			zap.Any("stopping nodes", offlineNodes),
+			zap.Any("available nodes", onlineNodes),
 		)
-
-		channelPlans := b.genStoppingChannelPlan(replica, onlineNodes, offlineNodes)
+		// handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score
+		channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, onlineNodes, offlineNodes)...)
 		if len(channelPlans) == 0 {
-			return b.genStoppingSegmentPlan(replica, onlineNodes, offlineNodes), nil
+			segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, onlineNodes, offlineNodes)...)
+		}
+	} else {
+		if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() {
+			channelPlans = append(channelPlans, b.genChannelPlan(replica, onlineNodes)...)
+		}
+
+		if len(channelPlans) == 0 {
+			segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, onlineNodes)...)
 		}
-		return nil, channelPlans
 	}
 
-	// segment balance will count the growing row num in delegator, so it's better to balance channel first,
-	// to avoid balance segment again after balance channel
-	channelPlans := b.genChannelPlan(replica, onlineNodes)
-	if len(channelPlans) == 0 {
-		return b.genSegmentPlan(replica, onlineNodes), nil
-	}
-	return nil, channelPlans
+	return segmentPlans, channelPlans
 }
 
 func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan {
diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go
index 92aa83acb2..81dfb7b8f2 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer_test.go
@@ -17,6 +17,7 @@
 package balance
 
 import (
+	"fmt"
 	"testing"
 
 	"github.com/samber/lo"
@@ -872,6 +873,139 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
 	}
 }
 
+func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() {
+	cases := []struct {
+		name                 string
+		nodes                []int64
+		notExistedNodes      []int64
+		segmentCnts          []int
+		states               []session.State
+		shouldMock           bool
+		distributions        map[int64][]*meta.Segment
+		distributionChannels map[int64][]*meta.DmChannel
+		expectPlans          []SegmentAssignPlan
+		expectChannelPlans   []ChannelAssignPlan
+		multiple             bool
+		enableBalanceChannel bool
+	}{
+		{
+			name:          "balance channel",
+			nodes:         []int64{2, 3},
+			segmentCnts:   []int{2, 2},
+			states:        []session.State{session.NodeStateNormal, session.NodeStateNormal},
+			shouldMock:    true,
+			distributions: map[int64][]*meta.Segment{},
+			distributionChannels: map[int64][]*meta.DmChannel{
+				2: {
+					{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v2"}, Node: 2},
+					{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 2},
+				},
+				3: {},
+			},
+			expectPlans: []SegmentAssignPlan{},
+			expectChannelPlans: []ChannelAssignPlan{
+				{Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 2}, From: 2, To: 3, ReplicaID: 1},
+			},
+			enableBalanceChannel: true,
+		},
+
+		{
+			name:          "disable balance channel",
+			nodes:         []int64{2, 3},
+			segmentCnts:   []int{2, 2},
+			states:        []session.State{session.NodeStateNormal, session.NodeStateNormal},
+			shouldMock:    true,
+			distributions: map[int64][]*meta.Segment{},
+			distributionChannels: map[int64][]*meta.DmChannel{
+				2: {
+					{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v2"}, Node: 2},
+					{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 2},
+				},
+				3: {},
+			},
+			expectPlans:          []SegmentAssignPlan{},
+			expectChannelPlans:   []ChannelAssignPlan{},
+			enableBalanceChannel: false,
+		},
+	}
+
+	for _, c := range cases {
+		suite.Run(c.name, func() {
+			suite.SetupSuite()
+			defer suite.TearDownTest()
+			balancer := suite.balancer
+			segments := []*datapb.SegmentInfo{
+				{
+					ID:          1,
+					PartitionID: 1,
+				},
+				{
+					ID:          2,
+					PartitionID: 1,
+				},
+				{
+					ID:          3,
+					PartitionID: 1,
+				},
+				{
+					ID:          4,
+					PartitionID: 1,
+				},
+				{
+					ID:          5,
+					PartitionID: 1,
+				},
+			}
+			suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
+			collection := utils.CreateTestCollection(1, 1)
+			collection.LoadPercentage = 100
+			collection.Status = querypb.LoadStatus_Loaded
+			collection.LoadType = querypb.LoadType_LoadCollection
+			balancer.meta.CollectionManager.PutCollection(collection)
+			balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1))
+			balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
+			suite.broker.ExpectedCalls = nil
+			suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
+			balancer.targetMgr.UpdateCollectionNextTarget(int64(1))
+			balancer.targetMgr.UpdateCollectionCurrentTarget(1)
+			balancer.targetMgr.UpdateCollectionNextTarget(int64(1))
+			suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0).Maybe()
+			for node, s := range c.distributions {
+				balancer.dist.SegmentDistManager.Update(node, s...)
+			}
+			for node, v := range c.distributionChannels {
+				balancer.dist.ChannelDistManager.Update(node, v...)
+			}
+			for i := range c.nodes {
+				nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0")
+				nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
+				nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]])))
+				nodeInfo.SetState(c.states[i])
+				suite.balancer.nodeManager.Add(nodeInfo)
+				suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, c.nodes[i])
+			}
+
+			Params.Save(Params.QueryCoordCfg.AutoBalanceChannel.Key, fmt.Sprint(c.enableBalanceChannel))
+			segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, 1)
+			if !c.multiple {
+				suite.ElementsMatch(c.expectChannelPlans, channelPlans)
+				suite.ElementsMatch(c.expectPlans, segmentPlans)
+			} else {
+				suite.Subset(c.expectPlans, segmentPlans)
+				suite.Subset(c.expectChannelPlans, channelPlans)
+			}
+
+			// clear distribution
+			for node := range c.distributions {
+				balancer.dist.SegmentDistManager.Update(node)
+			}
+			for node := range c.distributionChannels {
+				balancer.dist.ChannelDistManager.Update(node)
+			}
+		})
+	}
+}
+
 func TestRowCountBasedBalancerSuite(t *testing.T) {
 	suite.Run(t, new(RowCountBasedBalancerTestSuite))
 }
diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go
index 9bfd497768..8340638a14 100644
--- a/internal/querycoordv2/balance/score_based_balancer.go
+++ b/internal/querycoordv2/balance/score_based_balancer.go
@@ -29,6 +29,7 @@ import (
 	"github.com/milvus-io/milvus/internal/querycoordv2/session"
 	"github.com/milvus-io/milvus/internal/querycoordv2/task"
 	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
@@ -193,13 +194,15 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss
 		// handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score
 		channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, lo.Keys(nodesSegments), lo.Keys(stoppingNodesSegments))...)
 		if len(channelPlans) == 0 {
-			segmentPlans = append(segmentPlans, b.getStoppedSegmentPlan(replica, nodesSegments, stoppingNodesSegments)...)
+			segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, nodesSegments, stoppingNodesSegments)...)
 		}
 	} else {
-		// normal balance, find segments from largest score nodes and transfer to smallest score nodes.
-		channelPlans = append(channelPlans, b.genChannelPlan(replica, lo.Keys(nodesSegments))...)
+		if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() {
+			channelPlans = append(channelPlans, b.genChannelPlan(replica, lo.Keys(nodesSegments))...)
+		}
+
 		if len(channelPlans) == 0 {
-			segmentPlans = append(segmentPlans, b.getNormalSegmentPlan(replica, nodesSegments)...)
+			segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, nodesSegments)...)
 		}
 	}
 
@@ -210,7 +213,7 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss
 	return segmentPlans, channelPlans
 }
 
-func (b *ScoreBasedBalancer) getStoppedSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment, stoppingNodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
+func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment, stoppingNodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
 	segmentPlans := make([]SegmentAssignPlan, 0)
 	// generate candidates
 	nodeItems := b.convertToNodeItems(replica.GetCollectionID(), lo.Keys(nodesSegments))
@@ -253,7 +256,7 @@ func (b *ScoreBasedBalancer) getStoppedSegmentPlan(replica *meta.Replica, nodesS
 	return segmentPlans
 }
 
-func (b *ScoreBasedBalancer) getNormalSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
+func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
 	segmentPlans := make([]SegmentAssignPlan, 0)
 
 	// generate candidates
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index dc02775ce1..da4b7f5999 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -1248,6 +1248,7 @@ type queryCoordConfig struct {
 
 	// ---- Balance ---
 	AutoBalance                    ParamItem `refreshable:"true"`
+	AutoBalanceChannel             ParamItem `refreshable:"true"`
 	Balancer                       ParamItem `refreshable:"true"`
 	GlobalRowCountFactor           ParamItem `refreshable:"true"`
 	ScoreUnbalanceTolerationFactor ParamItem `refreshable:"true"`
@@ -1340,6 +1341,16 @@ func (p *queryCoordConfig) init(base *BaseTable) {
 	}
 	p.AutoBalance.Init(base.mgr)
 
+	p.AutoBalanceChannel = ParamItem{
+		Key:          "queryCoord.autoBalanceChannel",
+		Version:      "2.3.4",
+		DefaultValue: "true",
+		PanicIfEmpty: true,
+		Doc:          "Enable auto balance channel",
+		Export:       true,
+	}
+	p.AutoBalanceChannel.Init(base.mgr)
+
 	p.Balancer = ParamItem{
 		Key:     "queryCoord.balancer",
 		Version: "2.0.0",
diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go
index c8f5b6e591..c6ff578230 100644
--- a/pkg/util/paramtable/component_param_test.go
+++ b/pkg/util/paramtable/component_param_test.go
@@ -280,6 +280,7 @@ func TestComponentParam(t *testing.T) {
 		assert.Equal(t, 10000, Params.IndexCheckInterval.GetAsInt())
 		assert.Equal(t, 3, Params.CollectionRecoverTimesLimit.GetAsInt())
 		assert.Equal(t, false, Params.AutoBalance.GetAsBool())
+		assert.Equal(t, true, Params.AutoBalanceChannel.GetAsBool())
 		assert.Equal(t, 10, Params.CheckAutoBalanceConfigInterval.GetAsInt())
 	})
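
Usage note: since the new item is declared with Export: true, it should surface in milvus.yaml under the queryCoord section; a minimal sketch of how an operator would disable channel balance (the nested layout is inferred from how other exported queryCoord keys map to YAML, not copied from this PR):

    queryCoord:
      # queryCoord.autoBalanceChannel, added in 2.3.4, defaults to true.
      # Setting it to false makes both balancers skip genChannelPlan, so only
      # segment plans are generated during normal (non-stopping) balance.
      # The item is refreshable:"true", so it can also be changed at runtime.
      autoBalanceChannel: false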