diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index 88b1d55c17..5c926df966 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -658,6 +658,10 @@ message PartitionLoadInfo { int32 recover_times = 7; } +message ChannelNodeInfo { + repeated int64 rw_nodes =6; +} + message Replica { int64 ID = 1; int64 collectionID = 2; @@ -665,6 +669,7 @@ message Replica { string resource_group = 4; repeated int64 ro_nodes = 5; // the in-using node but should not be assigned to these replica. // can not load new channel or segment on it anymore. + map<string, ChannelNodeInfo> channel_node_infos = 6; } enum SyncType { diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go index c2f3ff8296..26e1acc108 100644 --- a/internal/querycoordv2/balance/balance.go +++ b/internal/querycoordv2/balance/balance.go @@ -52,13 +52,6 @@ func (chanPlan *ChannelAssignPlan) ToString() string { chanPlan.Channel.CollectionID, chanPlan.Channel.ChannelName, chanPlan.Replica.GetID(), chanPlan.From, chanPlan.To) } -var ( - RoundRobinBalancerName = "RoundRobinBalancer" - RowCountBasedBalancerName = "RowCountBasedBalancer" - ScoreBasedBalancerName = "ScoreBasedBalancer" - MultiTargetBalancerName = "MultipleTargetBalancer" -) - type Balance interface { AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan diff --git a/internal/querycoordv2/balance/channel_level_score_balancer.go b/internal/querycoordv2/balance/channel_level_score_balancer.go new file mode 100644 index 0000000000..5e5e69d7c4 --- /dev/null +++ b/internal/querycoordv2/balance/channel_level_score_balancer.go @@ -0,0 +1,282 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package balance + +import ( + "math" + "sort" + + "github.com/samber/lo" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/querycoordv2/task" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// The score-based balancer uses (collection_row_count + global_row_count * factor) as each node's score +// and tries to keep every node's score roughly the same by moving segments.
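For intuition, here is a toy sketch of the scoring rule in the comment above. It is illustrative only, not the balancer's actual calculateScore; the 0.1 factor is borrowed from the worked numbers in the test comments further down (node1: 10 + 0.1 * 310 = 41, node2: 20 + 0.1 * 200 = 40).

```go
package main

import "fmt"

// nodeScore sketches the scoring rule described above: a node's score for a
// collection is its collection row count plus its global row count weighted
// by a small factor. Nodes scoring above the average give up segments.
func nodeScore(collectionRows, globalRows int64, factor float64) int64 {
	return collectionRows + int64(float64(globalRows)*factor)
}

func main() {
	// Mirrors the test comments: node1 holds 10 collection rows out of 310
	// global rows, node2 holds 20 collection rows out of 200 global rows.
	fmt.Println(nodeScore(10, 310, 0.1)) // 41
	fmt.Println(nodeScore(20, 200, 0.1)) // 40
}
```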
+type ChannelLevelScoreBalancer struct { + *ScoreBasedBalancer +} + +func NewChannelLevelScoreBalancer(scheduler task.Scheduler, + nodeManager *session.NodeManager, + dist *meta.DistributionManager, + meta *meta.Meta, + targetMgr *meta.TargetManager, +) *ChannelLevelScoreBalancer { + return &ChannelLevelScoreBalancer{ + ScoreBasedBalancer: NewScoreBasedBalancer(scheduler, nodeManager, dist, meta, targetMgr), + } +} + +func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) { + log := log.With( + zap.Int64("collection", replica.GetCollectionID()), + zap.Int64("replica id", replica.GetID()), + zap.String("replica group", replica.GetResourceGroup()), + ) + + exclusiveMode := true + channels := b.targetMgr.GetDmChannelsByCollection(replica.GetCollectionID(), meta.CurrentTarget) + for channelName := range channels { + if len(replica.GetChannelRWNodes(channelName)) == 0 { + exclusiveMode = false + break + } + } + + // if some channel doesn't own nodes, exit exclusive mode + if !exclusiveMode { + return b.ScoreBasedBalancer.BalanceReplica(replica) + } + + channelPlans := make([]ChannelAssignPlan, 0) + segmentPlans := make([]SegmentAssignPlan, 0) + for channelName := range channels { + if replica.NodesCount() == 0 { + return nil, nil + } + + onlineNodes := make([]int64, 0) + offlineNodes := make([]int64, 0) + // read only nodes is offline in current replica. + if replica.RONodesCount() > 0 { + // if node is stop or transfer to other rg + log.RatedInfo(10, "meet read only node, try to move out all segment/channel", zap.Int64s("node", replica.GetRONodes())) + offlineNodes = append(offlineNodes, replica.GetRONodes()...) + } + + // mark channel's outbound access node as offline + channelRWNode := typeutil.NewUniqueSet(replica.GetChannelRWNodes(channelName)...) + channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithChannelName2Channel(channelName), meta.WithReplica2Channel(replica)) + for _, channel := range channelDist { + if !channelRWNode.Contain(channel.Node) { + offlineNodes = append(offlineNodes, channel.Node) + } + } + segmentDist := b.dist.SegmentDistManager.GetByFilter(meta.WithChannel(channelName), meta.WithReplica(replica)) + for _, segment := range segmentDist { + if !channelRWNode.Contain(segment.Node) { + offlineNodes = append(offlineNodes, segment.Node) + } + } + + for nid := range channelRWNode { + if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil { + log.Info("not existed node", zap.Int64("nid", nid), zap.Error(err)) + continue + } else if isStopping { + offlineNodes = append(offlineNodes, nid) + } else { + onlineNodes = append(onlineNodes, nid) + } + } + + if len(onlineNodes) == 0 { + // no available nodes to balance + return nil, nil + } + + if len(offlineNodes) != 0 { + if !paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { + log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", offlineNodes)) + return nil, nil + } + + log.Info("Handle stopping nodes", + zap.Any("stopping nodes", offlineNodes), + zap.Any("available nodes", onlineNodes), + ) + // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, channelName, onlineNodes, offlineNodes)...) + if len(channelPlans) == 0 { + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, channelName, onlineNodes, offlineNodes)...) 
+ } + } else { + if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { + channelPlans = append(channelPlans, b.genChannelPlan(replica, channelName, onlineNodes)...) + } + + if len(channelPlans) == 0 { + segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, channelName, onlineNodes)...) + } + } + } + + return segmentPlans, channelPlans +} + +func (b *ChannelLevelScoreBalancer) genStoppingChannelPlan(replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan { + channelPlans := make([]ChannelAssignPlan, 0) + for _, nodeID := range offlineNodes { + dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID), meta.WithChannelName2Channel(channelName)) + plans := b.AssignChannel(dmChannels, onlineNodes, false) + for i := range plans { + plans[i].From = nodeID + plans[i].Replica = replica + } + channelPlans = append(channelPlans, plans...) + } + return channelPlans +} + +func (b *ChannelLevelScoreBalancer) genStoppingSegmentPlan(replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan { + segmentPlans := make([]SegmentAssignPlan, 0) + for _, nodeID := range offlineNodes { + dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID), meta.WithChannel(channelName)) + segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { + return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil && + b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil && + segment.GetLevel() != datapb.SegmentLevel_L0 + }) + plans := b.AssignSegment(replica.GetCollectionID(), segments, onlineNodes, false) + for i := range plans { + plans[i].From = nodeID + plans[i].Replica = replica + } + segmentPlans = append(segmentPlans, plans...) 
+ } + return segmentPlans +} + +func (b *ChannelLevelScoreBalancer) genSegmentPlan(replica *meta.Replica, channelName string, onlineNodes []int64) []SegmentAssignPlan { + segmentDist := make(map[int64][]*meta.Segment) + nodeScore := make(map[int64]int, 0) + totalScore := 0 + + // list all segment which could be balanced, and calculate node's score + for _, node := range onlineNodes { + dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node), meta.WithChannel(channelName)) + segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { + return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil && + b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil && + segment.GetLevel() != datapb.SegmentLevel_L0 + }) + segmentDist[node] = segments + + rowCount := b.calculateScore(replica.GetCollectionID(), node) + totalScore += rowCount + nodeScore[node] = rowCount + } + + if totalScore == 0 { + return nil + } + + // find the segment from the node which has more score than the average + segmentsToMove := make([]*meta.Segment, 0) + average := totalScore / len(onlineNodes) + for node, segments := range segmentDist { + leftScore := nodeScore[node] + if leftScore <= average { + continue + } + + sort.Slice(segments, func(i, j int) bool { + return segments[i].GetNumOfRows() < segments[j].GetNumOfRows() + }) + for _, s := range segments { + segmentsToMove = append(segmentsToMove, s) + leftScore -= b.calculateSegmentScore(s) + if leftScore <= average { + break + } + } + } + + // if the segment are redundant, skip it's balance for now + segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool { + return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1 + }) + + if len(segmentsToMove) == 0 { + return nil + } + + segmentPlans := b.AssignSegment(replica.GetCollectionID(), segmentsToMove, onlineNodes, false) + for i := range segmentPlans { + segmentPlans[i].From = segmentPlans[i].Segment.Node + segmentPlans[i].Replica = replica + } + + return segmentPlans +} + +func (b *ChannelLevelScoreBalancer) genChannelPlan(replica *meta.Replica, channelName string, onlineNodes []int64) []ChannelAssignPlan { + channelPlans := make([]ChannelAssignPlan, 0) + if len(onlineNodes) > 1 { + // start to balance channels on all available nodes + channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica), meta.WithChannelName2Channel(channelName)) + if len(channelDist) == 0 { + return nil + } + average := int(math.Ceil(float64(len(channelDist)) / float64(len(onlineNodes)))) + + // find nodes with less channel count than average + nodeWithLessChannel := make([]int64, 0) + channelsToMove := make([]*meta.DmChannel, 0) + for _, node := range onlineNodes { + channels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(node)) + + if len(channels) <= average { + nodeWithLessChannel = append(nodeWithLessChannel, node) + continue + } + + channelsToMove = append(channelsToMove, channels[average:]...) 
+ } + + if len(nodeWithLessChannel) == 0 || len(channelsToMove) == 0 { + return nil + } + + channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false) + for i := range channelPlans { + channelPlans[i].From = channelPlans[i].Channel.Node + channelPlans[i].Replica = replica + } + + return channelPlans + } + return channelPlans +} diff --git a/internal/querycoordv2/balance/channel_level_score_balancer_test.go b/internal/querycoordv2/balance/channel_level_score_balancer_test.go new file mode 100644 index 0000000000..87c0841c71 --- /dev/null +++ b/internal/querycoordv2/balance/channel_level_score_balancer_test.go @@ -0,0 +1,1312 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package balance + +import ( + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/kv" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" + "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/querycoordv2/task" + "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type ChannelLevelScoreBalancerTestSuite struct { + suite.Suite + balancer *ChannelLevelScoreBalancer + kv kv.MetaKv + broker *meta.MockBroker + mockScheduler *task.MockScheduler +} + +func (suite *ChannelLevelScoreBalancerTestSuite) SetupSuite() { + paramtable.Init() +} + +func (suite *ChannelLevelScoreBalancerTestSuite) SetupTest() { + var err error + config := GenerateEtcdConfig() + cli, err := etcd.GetEtcdClient( + config.UseEmbedEtcd.GetAsBool(), + config.EtcdUseSSL.GetAsBool(), + config.Endpoints.GetAsStrings(), + config.EtcdTLSCert.GetValue(), + config.EtcdTLSKey.GetValue(), + config.EtcdTLSCACert.GetValue(), + config.EtcdTLSMinVersion.GetValue()) + suite.Require().NoError(err) + suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.broker = meta.NewMockBroker(suite.T()) + + store := querycoord.NewCatalog(suite.kv) + idAllocator := RandomIncrementIDAllocator() + nodeManager := session.NewNodeManager() + testMeta := meta.NewMeta(idAllocator, store, nodeManager) + testTarget := meta.NewTargetManager(suite.broker, testMeta) + + distManager := meta.NewDistributionManager() + suite.mockScheduler = task.NewMockScheduler(suite.T()) + suite.balancer = NewChannelLevelScoreBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget) +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TearDownTest() { + suite.kv.Close() +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegment() { + cases := []struct { + name string + comment string + distributions map[int64][]*meta.Segment + assignments [][]*meta.Segment + nodes []int64 + collectionIDs []int64 + segmentCnts []int + states []session.State + expectPlans [][]SegmentAssignPlan + }{ + { + name: "test empty cluster assigning one collection", + comment: "this is most simple case in which global row count is zero for all nodes", + distributions: map[int64][]*meta.Segment{}, + assignments: [][]*meta.Segment{ + { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 5, CollectionID: 1}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 10, CollectionID: 1}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 15, CollectionID: 1}}, + }, + }, + nodes: []int64{1, 2, 3}, + collectionIDs: []int64{0}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + segmentCnts: []int{0, 0, 0}, + expectPlans: [][]SegmentAssignPlan{ + { + // as assign segments is used while loading collection, + // all assignPlan should have weight equal to 1(HIGH PRIORITY) + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 3, NumOfRows: 15, + CollectionID: 1, + }}, From: -1, To: 1}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 2, NumOfRows: 10, + CollectionID: 1, + }}, From: -1, To: 3}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 1, NumOfRows: 5, + CollectionID: 1, + }}, From: -1, To: 2}, + }, + }, + }, + { + name: "test non-empty cluster assigning one collection", + comment: "this case will verify the effect of global row for loading segments process, although node1" + + "has only 10 rows at the beginning, but it has so many rows on global 
view, resulting in a lower priority", + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 10, CollectionID: 1}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 300, CollectionID: 2}, Node: 1}, + // base: collection1-node1-priority is 10 + 0.1 * 310 = 41 + // assign3: collection1-node1-priority is 15 + 0.1 * 315 = 46.5 + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 20, CollectionID: 1}, Node: 2}, + {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 180, CollectionID: 2}, Node: 2}, + // base: collection1-node2-priority is 20 + 0.1 * 200 = 40 + // assign2: collection1-node2-priority is 30 + 0.1 * 210 = 51 + }, + 3: { + {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 30, CollectionID: 1}, Node: 3}, + {SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 20, CollectionID: 2}, Node: 3}, + // base: collection1-node2-priority is 30 + 0.1 * 50 = 35 + // assign1: collection1-node2-priority is 45 + 0.1 * 65 = 51.5 + }, + }, + assignments: [][]*meta.Segment{ + { + {SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 5, CollectionID: 1}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 8, NumOfRows: 10, CollectionID: 1}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 9, NumOfRows: 15, CollectionID: 1}}, + }, + }, + nodes: []int64{1, 2, 3}, + collectionIDs: []int64{1}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + segmentCnts: []int{0, 0, 0}, + expectPlans: [][]SegmentAssignPlan{ + { + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 9, NumOfRows: 15, CollectionID: 1}}, From: -1, To: 3}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 8, NumOfRows: 10, CollectionID: 1}}, From: -1, To: 2}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 5, CollectionID: 1}}, From: -1, To: 1}, + }, + }, + }, + { + name: "test non-empty cluster assigning two collections at one round segment checking", + comment: "this case is used to demonstrate the existing assign mechanism having flaws when assigning " + + "multi collections at one round by using the only segment distribution", + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 10, CollectionID: 1}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 20, CollectionID: 1}, Node: 2}, + }, + 3: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 40, CollectionID: 1}, Node: 3}, + }, + }, + assignments: [][]*meta.Segment{ + { + {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 60, CollectionID: 1}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 50, CollectionID: 1}}, + }, + { + {SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 15, CollectionID: 2}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 10, CollectionID: 2}}, + }, + }, + nodes: []int64{1, 2, 3}, + collectionIDs: []int64{1, 2}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + segmentCnts: []int{0, 0, 0}, + expectPlans: [][]SegmentAssignPlan{ + // note that these two segments plans are absolutely unbalanced globally, + // as if the assignment for collection1 could succeed, node1 and node2 will both have 70 rows + // much more than node3, but following assignment will still assign segment based on [10,20,40] + // rather than [70,70,40], this flaw will be mitigated by balance process and maybe fixed in the later versions + { + {Segment: &meta.Segment{SegmentInfo: 
&datapb.SegmentInfo{ID: 4, NumOfRows: 60, CollectionID: 1}}, From: -1, To: 1}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 50, CollectionID: 1}}, From: -1, To: 2}, + }, + { + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 15, CollectionID: 2}}, From: -1, To: 1}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 10, CollectionID: 2}}, From: -1, To: 2}, + }, + }, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) + } + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: c.nodes[i], + Address: "127.0.0.1:0", + Hostname: "localhost", + }) + nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i])) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + } + for i := range c.collectionIDs { + plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false) + assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans[i], plans) + } + }) + } +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestSuspendNode() { + cases := []struct { + name string + distributions map[int64][]*meta.Segment + assignments []*meta.Segment + nodes []int64 + segmentCnts []int + states []session.State + expectPlans []SegmentAssignPlan + }{ + { + name: "test suspend node", + distributions: map[int64][]*meta.Segment{ + 2: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20}, Node: 2}}, + 3: {{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 30}, Node: 3}}, + }, + assignments: []*meta.Segment{ + {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}}, + }, + nodes: []int64{1, 2, 3, 4}, + states: []session.State{session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend}, + segmentCnts: []int{0, 1, 1, 0}, + expectPlans: []SegmentAssignPlan{}, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + // I do not find a better way to do the setup and teardown work for subtests yet. + // If you do, please replace with it. + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) + } + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: c.nodes[i], + Address: "localhost", + Hostname: "localhost", + }) + nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i])) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + } + plans := balancer.AssignSegment(0, c.assignments, c.nodes, false) + // all node has been suspend, so no node to assign segment + suite.ElementsMatch(c.expectPlans, plans) + }) + } +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegmentWithGrowing() { + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + + distributions := map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20, CollectionID: 1}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 20, CollectionID: 1}, Node: 2}, + }, + } + for node, s := range distributions { + balancer.dist.SegmentDistManager.Update(node, s...) 
+ } + + for _, node := range lo.Keys(distributions) { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: node, + Address: "127.0.0.1:0", + Hostname: "localhost", + }) + nodeInfo.UpdateStats(session.WithSegmentCnt(20)) + nodeInfo.SetState(session.NodeStateNormal) + suite.balancer.nodeManager.Add(nodeInfo) + } + + toAssign := []*meta.Segment{ + {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 10, CollectionID: 1}, Node: 3}, + {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10, CollectionID: 1}, Node: 3}, + } + + // mock 50 growing row count in node 1, which is delegator, expect all segment assign to node 2 + leaderView := &meta.LeaderView{ + ID: 1, + CollectionID: 1, + NumOfGrowingRows: 50, + } + suite.balancer.dist.LeaderViewManager.Update(1, leaderView) + plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false) + for _, p := range plans { + suite.Equal(int64(2), p.To) + } +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() { + cases := []struct { + name string + nodes []int64 + collectionID int64 + replicaID int64 + segments []*datapb.SegmentInfo + channels []*datapb.VchannelInfo + states []session.State + shouldMock bool + distributions map[int64][]*meta.Segment + distributionChannels map[int64][]*meta.DmChannel + expectPlans []SegmentAssignPlan + expectChannelPlans []ChannelAssignPlan + }{ + { + name: "normal balance for one collection only", + nodes: []int64{1, 2}, + collectionID: 1, + replicaID: 1, + segments: []*datapb.SegmentInfo{ + {ID: 1, PartitionID: 1}, {ID: 2, PartitionID: 1}, {ID: 3, PartitionID: 1}, + }, + channels: []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}}, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{ + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, From: 2, To: 1, Replica: newReplicaDefaultRG(1)}, + }, + expectChannelPlans: []ChannelAssignPlan{}, + }, + { + name: "already balanced for one collection only", + nodes: []int64{1, 2}, + collectionID: 1, + replicaID: 1, + segments: []*datapb.SegmentInfo{ + {ID: 1, PartitionID: 1}, {ID: 2, PartitionID: 1}, {ID: 3, PartitionID: 1}, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + + // 1. 
set up target for multi collections + collection := utils.CreateTestCollection(c.collectionID, int32(c.replicaID)) + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, c.collectionID).Return( + c.channels, c.segments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + + // 2. set up target for distribution for multi collections + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) + } + for node, v := range c.distributionChannels { + balancer.dist.ChannelDistManager.Update(node, v...) + } + + // 3. set up nodes info and resourceManager for balancer + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: c.nodes[i], + Address: "127.0.0.1:0", + Hostname: "localhost", + }) + nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + } + + // 4. balance and verify result + segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, c.collectionID) + assertChannelAssignPlanElementMatch(&suite.Suite, c.expectChannelPlans, channelPlans) + assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, segmentPlans) + }) + } +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() { + balanceCase := struct { + name string + nodes []int64 + notExistedNodes []int64 + collectionIDs []int64 + replicaIDs []int64 + segments [][]*datapb.SegmentInfo + channels []*datapb.VchannelInfo + states []session.State + shouldMock bool + distributions []map[int64][]*meta.Segment + expectPlans [][]SegmentAssignPlan + }{ + name: "balance considering both global rowCounts and collection rowCounts", + nodes: []int64{1, 2, 3}, + collectionIDs: []int64{1, 2}, + replicaIDs: []int64{1, 2}, + segments: [][]*datapb.SegmentInfo{ + { + {ID: 1, PartitionID: 1}, + {ID: 3, PartitionID: 1}, + }, + { + {ID: 2, PartitionID: 2}, + {ID: 4, PartitionID: 2}, + }, + }, + channels: []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + distributions: []map[int64][]*meta.Segment{ + { + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 20}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 2, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, Node: 2}, + {SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 2, NumOfRows: 30}, Node: 2}, + }, + }, + { + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 20}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 2, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 2, NumOfRows: 30}, Node: 2}, + 
}, + 3: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, Node: 3}, + }, + }, + }, + expectPlans: [][]SegmentAssignPlan{ + { + { + Segment: &meta.Segment{ + SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, + Node: 2, + }, From: 2, To: 3, Replica: newReplicaDefaultRG(1), + }, + }, + {}, + }, + } + + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + + // 1. set up target for multi collections + for i := range balanceCase.collectionIDs { + collection := utils.CreateTestCollection(balanceCase.collectionIDs[i], int32(balanceCase.replicaIDs[i])) + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, balanceCase.collectionIDs[i]).Return( + balanceCase.channels, balanceCase.segments[i], nil) + + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + collection.LoadType = querypb.LoadType_LoadCollection + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])) + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i], + append(balanceCase.nodes, balanceCase.notExistedNodes...))) + balancer.targetMgr.UpdateCollectionNextTarget(balanceCase.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionNextTarget(balanceCase.collectionIDs[i]) + } + + // 2. set up target for distribution for multi collections + for node, s := range balanceCase.distributions[0] { + balancer.dist.SegmentDistManager.Update(node, s...) + } + + // 3. set up nodes info and resourceManager for balancer + for i := range balanceCase.nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: balanceCase.nodes[i], + Address: "127.0.0.1:0", + Hostname: "localhost", + }) + nodeInfo.SetState(balanceCase.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(balanceCase.nodes[i]) + } + + // 4. first round balance + segmentPlans, _ := suite.getCollectionBalancePlans(balancer, balanceCase.collectionIDs[0]) + assertSegmentAssignPlanElementMatch(&suite.Suite, balanceCase.expectPlans[0], segmentPlans) + + // 5. update segment distribution to simulate balance effect + for node, s := range balanceCase.distributions[1] { + balancer.dist.SegmentDistManager.Update(node, s...) + } + + // 6. 
balance again + segmentPlans, _ = suite.getCollectionBalancePlans(balancer, balanceCase.collectionIDs[1]) + assertSegmentAssignPlanElementMatch(&suite.Suite, balanceCase.expectPlans[1], segmentPlans) +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() { + cases := []struct { + name string + nodes []int64 + outBoundNodes []int64 + collectionID int64 + replicaID int64 + segments []*datapb.SegmentInfo + channels []*datapb.VchannelInfo + states []session.State + shouldMock bool + distributions map[int64][]*meta.Segment + distributionChannels map[int64][]*meta.DmChannel + expectPlans []SegmentAssignPlan + expectChannelPlans []ChannelAssignPlan + }{ + { + name: "stopped balance for one collection", + nodes: []int64{1, 2, 3}, + outBoundNodes: []int64{}, + collectionID: 1, + replicaID: 1, + segments: []*datapb.SegmentInfo{ + {ID: 1, PartitionID: 1}, {ID: 2, PartitionID: 1}, {ID: 3, PartitionID: 1}, + }, + channels: []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + }, + states: []session.State{session.NodeStateStopping, session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{ + {Segment: &meta.Segment{ + SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, + Node: 1, + }, From: 1, To: 3, Replica: newReplicaDefaultRG(1)}, + {Segment: &meta.Segment{ + SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, + Node: 1, + }, From: 1, To: 3, Replica: newReplicaDefaultRG(1)}, + }, + expectChannelPlans: []ChannelAssignPlan{}, + }, + { + name: "all nodes stopping", + nodes: []int64{1, 2, 3}, + outBoundNodes: []int64{}, + collectionID: 1, + replicaID: 1, + segments: []*datapb.SegmentInfo{ + {ID: 1}, {ID: 2}, {ID: 3}, + }, + channels: []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + }, + states: []session.State{session.NodeStateStopping, session.NodeStateStopping, session.NodeStateStopping}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, + }, + { + name: "all nodes outbound", + nodes: []int64{1, 2, 3}, + outBoundNodes: []int64{1, 2, 3}, + collectionID: 1, + replicaID: 1, + segments: []*datapb.SegmentInfo{ + {ID: 1}, {ID: 2}, {ID: 3}, + }, + channels: []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, + }, + } + for i, c := range cases { + 
suite.Run(c.name, func() { + if i == 0 { + suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0).Maybe() + } + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + + // 1. set up target for multi collections + collection := utils.CreateTestCollection(c.collectionID, int32(c.replicaID)) + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, c.collectionID).Return( + c.channels, c.segments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + + // 2. set up target for distribution for multi collections + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) + } + for node, v := range c.distributionChannels { + balancer.dist.ChannelDistManager.Update(node, v...) + } + + // 3. set up nodes info and resourceManager for balancer + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: c.nodes[i], + Address: "127.0.0.1:0", + Hostname: "localhost", + }) + nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + } + + for i := range c.outBoundNodes { + suite.balancer.meta.ResourceManager.HandleNodeDown(c.outBoundNodes[i]) + } + utils.RecoverAllCollection(balancer.meta) + + // 4. 
balance and verify result + segmentPlans, channelPlans := suite.getCollectionBalancePlans(suite.balancer, c.collectionID) + assertChannelAssignPlanElementMatch(&suite.Suite, c.expectChannelPlans, channelPlans) + assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, segmentPlans) + }) + } +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() { + cases := []struct { + name string + collectionID int64 + replicaWithNodes map[int64][]int64 + segments []*datapb.SegmentInfo + channels []*datapb.VchannelInfo + states []session.State + shouldMock bool + segmentDist map[int64][]*meta.Segment + channelDist map[int64][]*meta.DmChannel + expectPlans []SegmentAssignPlan + expectChannelPlans []ChannelAssignPlan + }{ + { + name: "normal balance for one collection only", + collectionID: 1, + replicaWithNodes: map[int64][]int64{1: {1, 2}, 2: {3, 4}}, + segments: []*datapb.SegmentInfo{ + {ID: 1, CollectionID: 1, PartitionID: 1}, + {ID: 2, CollectionID: 1, PartitionID: 1}, + {ID: 3, CollectionID: 1, PartitionID: 1}, + {ID: 4, CollectionID: 1, PartitionID: 1}, + }, + channels: []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + { + CollectionID: 1, ChannelName: "channel2", FlushedSegmentIds: []int64{2}, + }, + { + CollectionID: 1, ChannelName: "channel3", FlushedSegmentIds: []int64{3}, + }, + { + CollectionID: 1, ChannelName: "channel4", FlushedSegmentIds: []int64{4}, + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, + segmentDist: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 30}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 30}, Node: 1}, + }, + 3: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 3}, + {SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 1, NumOfRows: 30}, Node: 3}, + }, + }, + channelDist: map[int64][]*meta.DmChannel{ + 1: { + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel1"}, Node: 1}, + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel2"}, Node: 1}, + }, + 3: { + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel3"}, Node: 3}, + {VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel4"}, Node: 3}, + }, + }, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + + // 1. 
set up target for multi collections + collection := utils.CreateTestCollection(c.collectionID, int32(len(c.replicaWithNodes))) + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, c.collectionID).Return( + c.channels, c.segments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) + for replicaID, nodes := range c.replicaWithNodes { + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, c.collectionID, nodes)) + } + balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + + // 2. set up target for distribution for multi collections + for node, s := range c.segmentDist { + balancer.dist.SegmentDistManager.Update(node, s...) + } + for node, v := range c.channelDist { + balancer.dist.ChannelDistManager.Update(node, v...) + } + + // 3. set up nodes info and resourceManager for balancer + for _, nodes := range c.replicaWithNodes { + for i := range nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodes[i], + Address: "127.0.0.1:0", + Version: common.Version, + }) + nodeInfo.UpdateStats(session.WithChannelCnt(len(c.channelDist[nodes[i]]))) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i]) + } + } + + // expected to balance channel first + segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, c.collectionID) + suite.Len(segmentPlans, 0) + suite.Len(channelPlans, 2) + + // mock new distribution after channel balance + balancer.dist.ChannelDistManager.Update(1, &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel1"}, Node: 1}) + balancer.dist.ChannelDistManager.Update(2, &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel2"}, Node: 2}) + balancer.dist.ChannelDistManager.Update(3, &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel3"}, Node: 3}) + balancer.dist.ChannelDistManager.Update(4, &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "channel4"}, Node: 4}) + + // expected to balance segment + segmentPlans, channelPlans = suite.getCollectionBalancePlans(balancer, c.collectionID) + suite.Len(segmentPlans, 2) + suite.Len(channelPlans, 0) + }) + } +} + +func (suite *ChannelLevelScoreBalancerTestSuite) getCollectionBalancePlans(balancer *ChannelLevelScoreBalancer, + collectionID int64, +) ([]SegmentAssignPlan, []ChannelAssignPlan) { + replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID) + segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) + for _, replica := range replicas { + sPlans, cPlans := balancer.BalanceReplica(replica) + segmentPlans = append(segmentPlans, sPlans...) + channelPlans = append(channelPlans, cPlans...) 
+ } + return segmentPlans, channelPlans +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_ChannelOutBound() { + Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) + defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) + Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") + defer Params.Reset(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key) + + balancer := suite.balancer + + collectionID := int64(1) + partitionID := int64(1) + + // 1. set up target for multi collections + segments := []*datapb.SegmentInfo{ + {ID: 1, PartitionID: partitionID}, {ID: 2, PartitionID: partitionID}, + } + + channels := []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + { + CollectionID: 1, ChannelName: "channel2", + }, + } + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return( + channels, segments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe() + + collection := utils.CreateTestCollection(collectionID, int32(1)) + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + + // 3. set up nodes info and resourceManager for balancer + nodeCount := 4 + for i := 0; i < nodeCount; i++ { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: int64(i), + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(session.NodeStateNormal) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + } + utils.RecoverAllCollection(balancer.meta) + + replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + ch1Nodes := replica.GetChannelRWNodes("channel1") + ch2Nodes := replica.GetChannelRWNodes("channel2") + suite.Len(ch1Nodes, 2) + suite.Len(ch2Nodes, 2) + + balancer.dist.ChannelDistManager.Update(ch1Nodes[0], []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "channel2", + }, + Node: ch1Nodes[0], + }, + }...) + + sPlans, cPlans := balancer.BalanceReplica(replica) + suite.Len(sPlans, 0) + suite.Len(cPlans, 1) +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_SegmentOutbound() { + Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) + defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) + Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") + defer Params.Reset(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key) + + balancer := suite.balancer + + collectionID := int64(1) + partitionID := int64(1) + + // 1. 
set up target for multi collections + segments := []*datapb.SegmentInfo{ + {ID: 1, PartitionID: partitionID}, {ID: 2, PartitionID: partitionID}, {ID: 3, PartitionID: partitionID}, + } + + channels := []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + { + CollectionID: 1, ChannelName: "channel2", + }, + } + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return( + channels, segments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe() + + collection := utils.CreateTestCollection(collectionID, int32(1)) + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + + // 3. set up nodes info and resourceManager for balancer + nodeCount := 4 + for i := 0; i < nodeCount; i++ { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: int64(i), + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(session.NodeStateNormal) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + } + utils.RecoverAllCollection(balancer.meta) + + replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + ch1Nodes := replica.GetChannelRWNodes("channel1") + ch2Nodes := replica.GetChannelRWNodes("channel2") + suite.Len(ch1Nodes, 2) + suite.Len(ch2Nodes, 2) + + balancer.dist.ChannelDistManager.Update(ch1Nodes[0], []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "channel1", + }, + Node: ch1Nodes[0], + }, + }...) + + balancer.dist.ChannelDistManager.Update(ch2Nodes[0], []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "channel2", + }, + Node: ch2Nodes[0], + }, + }...) + + balancer.dist.SegmentDistManager.Update(ch1Nodes[0], []*meta.Segment{ + { + SegmentInfo: &datapb.SegmentInfo{ + ID: segments[0].ID, + CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10, + InsertChannel: "channel2", + }, + Node: ch1Nodes[0], + }, + }...) + + sPlans, cPlans := balancer.BalanceReplica(replica) + suite.Len(sPlans, 1) + suite.Len(cPlans, 0) +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_NodeStopping() { + Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) + defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) + Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") + defer Params.Reset(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key) + + balancer := suite.balancer + + collectionID := int64(1) + partitionID := int64(1) + + // 1. 
set up target for multi collections + segments := []*datapb.SegmentInfo{ + {ID: 1, PartitionID: partitionID}, {ID: 2, PartitionID: partitionID}, {ID: 3, PartitionID: partitionID}, + } + + channels := []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + { + CollectionID: 1, ChannelName: "channel2", + }, + } + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return( + channels, segments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe() + + collection := utils.CreateTestCollection(collectionID, int32(1)) + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + + // 3. set up nodes info and resourceManager for balancer + nodeCount := 4 + for i := 0; i < nodeCount; i++ { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: int64(i), + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(session.NodeStateNormal) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + } + utils.RecoverAllCollection(balancer.meta) + + replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + ch1Nodes := replica.GetChannelRWNodes("channel1") + ch2Nodes := replica.GetChannelRWNodes("channel2") + suite.Len(ch1Nodes, 2) + suite.Len(ch2Nodes, 2) + + balancer.dist.ChannelDistManager.Update(ch1Nodes[0], []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "channel1", + }, + Node: ch1Nodes[0], + }, + }...) + + balancer.dist.ChannelDistManager.Update(ch2Nodes[0], []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "channel2", + }, + Node: ch2Nodes[0], + }, + }...) + + balancer.dist.SegmentDistManager.Update(ch1Nodes[0], []*meta.Segment{ + { + SegmentInfo: &datapb.SegmentInfo{ + ID: segments[0].ID, + CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10, + InsertChannel: "channel1", + }, + Node: ch1Nodes[0], + }, + }...) + + balancer.dist.SegmentDistManager.Update(ch2Nodes[0], []*meta.Segment{ + { + SegmentInfo: &datapb.SegmentInfo{ + ID: segments[1].ID, + CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10, + InsertChannel: "channel2", + }, + Node: ch2Nodes[0], + }, + }...) 
+ + suite.balancer.nodeManager.Stopping(ch1Nodes[0]) + suite.balancer.nodeManager.Stopping(ch2Nodes[0]) + sPlans, cPlans := balancer.BalanceReplica(replica) + suite.Len(sPlans, 0) + suite.Len(cPlans, 2) + + balancer.dist.ChannelDistManager.Update(ch1Nodes[0]) + balancer.dist.ChannelDistManager.Update(ch2Nodes[0]) + + sPlans, cPlans = balancer.BalanceReplica(replica) + suite.Len(sPlans, 2) + suite.Len(cPlans, 0) +} + +func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_SegmentUnbalance() { + Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) + defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) + Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") + defer Params.Reset(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key) + + balancer := suite.balancer + + collectionID := int64(1) + partitionID := int64(1) + + // 1. set up target for multi collections + segments := []*datapb.SegmentInfo{ + {ID: 1, PartitionID: partitionID}, {ID: 2, PartitionID: partitionID}, {ID: 3, PartitionID: partitionID}, {ID: 4, PartitionID: partitionID}, + } + + channels := []*datapb.VchannelInfo{ + { + CollectionID: 1, ChannelName: "channel1", + }, + { + CollectionID: 1, ChannelName: "channel2", + }, + } + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return( + channels, segments, nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe() + + collection := utils.CreateTestCollection(collectionID, int32(1)) + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(collectionID) + + // 3. set up nodes info and resourceManager for balancer + nodeCount := 4 + for i := 0; i < nodeCount; i++ { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: int64(i), + Address: "127.0.0.1:0", + Hostname: "localhost", + Version: common.Version, + }) + // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(session.NodeStateNormal) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + } + utils.RecoverAllCollection(balancer.meta) + + replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + ch1Nodes := replica.GetChannelRWNodes("channel1") + ch2Nodes := replica.GetChannelRWNodes("channel2") + suite.Len(ch1Nodes, 2) + suite.Len(ch2Nodes, 2) + + balancer.dist.ChannelDistManager.Update(ch1Nodes[0], []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "channel1", + }, + Node: ch1Nodes[0], + }, + }...) + + balancer.dist.ChannelDistManager.Update(ch2Nodes[0], []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "channel2", + }, + Node: ch2Nodes[0], + }, + }...) 
+ + balancer.dist.SegmentDistManager.Update(ch1Nodes[0], []*meta.Segment{ + { + SegmentInfo: &datapb.SegmentInfo{ + ID: segments[0].ID, + CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10, + InsertChannel: "channel1", + }, + Node: ch1Nodes[0], + }, + { + SegmentInfo: &datapb.SegmentInfo{ + ID: segments[1].ID, + CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10, + InsertChannel: "channel1", + }, + Node: ch1Nodes[0], + }, + }...) + + balancer.dist.SegmentDistManager.Update(ch2Nodes[0], []*meta.Segment{ + { + SegmentInfo: &datapb.SegmentInfo{ + ID: segments[2].ID, + CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10, + InsertChannel: "channel2", + }, + Node: ch2Nodes[0], + }, + { + SegmentInfo: &datapb.SegmentInfo{ + ID: segments[3].ID, + CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10, + InsertChannel: "channel2", + }, + Node: ch2Nodes[0], + }, + }...) + + sPlans, cPlans := balancer.BalanceReplica(replica) + suite.Len(sPlans, 2) + suite.Len(cPlans, 0) +} + +func TestChannelLevelScoreBalancerSuite(t *testing.T) { + suite.Run(t, new(ChannelLevelScoreBalancerTestSuite)) +} diff --git a/internal/querycoordv2/checkers/channel_checker.go b/internal/querycoordv2/checkers/channel_checker.go index e8314b45e9..d220697416 100644 --- a/internal/querycoordv2/checkers/channel_checker.go +++ b/internal/querycoordv2/checkers/channel_checker.go @@ -217,7 +217,16 @@ func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int } func (c *ChannelChecker) createChannelLoadTask(ctx context.Context, channels []*meta.DmChannel, replica *meta.Replica) []task.Task { - plans := c.balancer.AssignChannel(channels, replica.GetNodes(), false) + plans := make([]balance.ChannelAssignPlan, 0) + for _, ch := range channels { + rwNodes := replica.GetChannelRWNodes(ch.GetChannelName()) + if len(rwNodes) == 0 { + rwNodes = replica.GetNodes() + } + plan := c.balancer.AssignChannel([]*meta.DmChannel{ch}, rwNodes, false) + plans = append(plans, plan...) + } + for i := range plans { plans[i].Replica = replica } diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index f7eb202aa1..17e9e7346f 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -383,19 +383,6 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] return nil } - // filter out stopping nodes. - availableNodes := lo.Filter(replica.GetNodes(), func(node int64, _ int) bool { - stop, err := c.nodeMgr.IsStoppingNode(node) - if err != nil { - return false - } - return !stop - }) - - if len(availableNodes) == 0 { - return nil - } - isLevel0 := segments[0].GetLevel() == datapb.SegmentLevel_L0 shardSegments := lo.GroupBy(segments, func(s *datapb.SegmentInfo) string { return s.GetInsertChannel() @@ -409,6 +396,24 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] continue } + rwNodes := replica.GetChannelRWNodes(shard) + if len(rwNodes) == 0 { + rwNodes = replica.GetNodes() + } + + // filter out stopping nodes. 
+ availableNodes := lo.Filter(rwNodes, func(node int64, _ int) bool { + stop, err := c.nodeMgr.IsStoppingNode(node) + if err != nil { + return false + } + return !stop + }) + + if len(availableNodes) == 0 { + return nil + } + // L0 segment can only be assign to shard leader's node if isLevel0 { availableNodes = []int64{leader.ID} diff --git a/internal/querycoordv2/job/job_load.go b/internal/querycoordv2/job/job_load.go index 5dc89c3ad1..d28813f08e 100644 --- a/internal/querycoordv2/job/job_load.go +++ b/internal/querycoordv2/job/job_load.go @@ -152,9 +152,14 @@ func (job *LoadCollectionJob) Execute() error { // 2. create replica if not exist replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID()) if len(replicas) == 0 { + collectionInfo, err := job.broker.DescribeCollection(job.ctx, req.GetCollectionID()) + if err != nil { + return err + } + // API of LoadCollection is wired, we should use map[resourceGroupNames]replicaNumber as input, to keep consistency with `TransferReplica` API. // Then we can implement dynamic replica changed in different resource group independently. - replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber()) + replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) if err != nil { msg := "failed to spawn replica for collection" log.Warn(msg, zap.Error(err)) @@ -337,7 +342,11 @@ func (job *LoadPartitionJob) Execute() error { // 2. create replica if not exist replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID()) if len(replicas) == 0 { - replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber()) + collectionInfo, err := job.broker.DescribeCollection(job.ctx, req.GetCollectionID()) + if err != nil { + return err + } + replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) if err != nil { msg := "failed to spawn replica for collection" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/meta/constant.go b/internal/querycoordv2/meta/constant.go new file mode 100644 index 0000000000..b67d659926 --- /dev/null +++ b/internal/querycoordv2/meta/constant.go @@ -0,0 +1,25 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
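+// Balancer names are defined in the meta package (moved out of the balance package) so that meta code such as ReplicaManager can compare against the configured balance policy without importing balance, which itself imports meta.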
+ +package meta + +const ( + RoundRobinBalancerName = "RoundRobinBalancer" + RowCountBasedBalancerName = "RowCountBasedBalancer" + ScoreBasedBalancerName = "ScoreBasedBalancer" + MultiTargetBalancerName = "MultipleTargetBalancer" + ChannelLevelScoreBalancerName = "ChannelLevelScoreBalancer" +) diff --git a/internal/querycoordv2/meta/replica.go b/internal/querycoordv2/meta/replica.go index 46622967ff..d58565b6a6 100644 --- a/internal/querycoordv2/meta/replica.go +++ b/internal/querycoordv2/meta/replica.go @@ -4,6 +4,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -38,7 +39,7 @@ func NewReplica(replica *querypb.Replica, nodes ...typeutil.UniqueSet) *Replica } // newReplica creates a new replica from pb. -func newReplica(replica *querypb.Replica) *Replica { +func newReplica(replica *querypb.Replica, channels ...string) *Replica { return &Replica{ replicaPB: proto.Clone(replica).(*querypb.Replica), rwNodes: typeutil.NewUniqueSet(replica.Nodes...), @@ -122,20 +123,38 @@ func (replica *Replica) AddRWNode(nodes ...int64) { replica.replicaPB.Nodes = replica.rwNodes.Collect() } +func (replica *Replica) GetChannelRWNodes(channelName string) []int64 { + channelNodeInfos := replica.replicaPB.GetChannelNodeInfos() + if channelNodeInfos[channelName] == nil || len(channelNodeInfos[channelName].GetRwNodes()) == 0 { + return nil + } + return replica.replicaPB.ChannelNodeInfos[channelName].GetRwNodes() +} + // copyForWrite returns a mutable replica for write operations. func (replica *Replica) copyForWrite() *mutableReplica { + exclusiveRWNodeToChannel := make(map[int64]string) + for name, channelNodeInfo := range replica.replicaPB.GetChannelNodeInfos() { + for _, nodeID := range channelNodeInfo.GetRwNodes() { + exclusiveRWNodeToChannel[nodeID] = name + } + } + return &mutableReplica{ - &Replica{ + Replica: &Replica{ replicaPB: proto.Clone(replica.replicaPB).(*querypb.Replica), rwNodes: typeutil.NewUniqueSet(replica.replicaPB.Nodes...), roNodes: typeutil.NewUniqueSet(replica.replicaPB.RoNodes...), }, + exclusiveRWNodeToChannel: exclusiveRWNodeToChannel, } } // mutableReplica is a mutable type (COW) for manipulating replica meta info for replica manager. type mutableReplica struct { *Replica + + exclusiveRWNodeToChannel map[int64]string } // SetResourceGroup sets the resource group name of the replica. @@ -146,6 +165,9 @@ func (replica *mutableReplica) SetResourceGroup(resourceGroup string) { // AddRWNode adds the node to rw nodes of the replica. func (replica *mutableReplica) AddRWNode(nodes ...int64) { replica.Replica.AddRWNode(nodes...) + + // try to update node's assignment between channels + replica.tryBalanceNodeForChannel() } // AddRONode moves the node from rw nodes to ro nodes of the replica. @@ -155,6 +177,12 @@ func (replica *mutableReplica) AddRONode(nodes ...int64) { replica.replicaPB.Nodes = replica.rwNodes.Collect() replica.roNodes.Insert(nodes...) replica.replicaPB.RoNodes = replica.roNodes.Collect() + + // remove node from channel's exclusive list + replica.removeChannelExclusiveNodes(nodes...) + + // try to update node's assignment between channels + replica.tryBalanceNodeForChannel() } // RemoveNode removes the node from rw nodes and ro nodes of the replica. @@ -164,6 +192,84 @@ func (replica *mutableReplica) RemoveNode(nodes ...int64) { replica.replicaPB.RoNodes = replica.roNodes.Collect() replica.rwNodes.Remove(nodes...) 
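+ // sync the protobuf node list with the pruned in-memory rw node set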
replica.replicaPB.Nodes = replica.rwNodes.Collect() + + // remove node from channel's exclusive list + replica.removeChannelExclusiveNodes(nodes...) + + // try to update node's assignment between channels + replica.tryBalanceNodeForChannel() +} + +func (replica *mutableReplica) removeChannelExclusiveNodes(nodes ...int64) { + channelNodeMap := make(map[string][]int64) + for _, nodeID := range nodes { + channelName, ok := replica.exclusiveRWNodeToChannel[nodeID] + if ok { + if channelNodeMap[channelName] == nil { + channelNodeMap[channelName] = make([]int64, 0) + } + channelNodeMap[channelName] = append(channelNodeMap[channelName], nodeID) + } + delete(replica.exclusiveRWNodeToChannel, nodeID) + } + + for channelName, nodeIDs := range channelNodeMap { + channelNodeInfo, ok := replica.replicaPB.ChannelNodeInfos[channelName] + if ok { + channelUsedNodes := typeutil.NewUniqueSet() + channelUsedNodes.Insert(channelNodeInfo.GetRwNodes()...) + channelUsedNodes.Remove(nodeIDs...) + replica.replicaPB.ChannelNodeInfos[channelName].RwNodes = channelUsedNodes.Collect() + } + } +} + +func (replica *mutableReplica) tryBalanceNodeForChannel() { + channelNodeInfos := replica.replicaPB.GetChannelNodeInfos() + if len(channelNodeInfos) == 0 { + return + } + + channelExclusiveFactor := paramtable.Get().QueryCoordCfg.ChannelExclusiveNodeFactor.GetAsInt() + // to do: if query node scale in happens, and the condition does not meet, should we exit channel's exclusive mode? + if len(replica.rwNodes) < len(channelNodeInfos)*channelExclusiveFactor { + for name := range replica.replicaPB.GetChannelNodeInfos() { + replica.replicaPB.ChannelNodeInfos[name] = &querypb.ChannelNodeInfo{} + } + return + } + + if channelNodeInfos != nil { + average := replica.RWNodesCount() / len(channelNodeInfos) + + // release node in channel + for channelName, channelNodeInfo := range channelNodeInfos { + currentNodes := channelNodeInfo.GetRwNodes() + if len(currentNodes) > average { + replica.replicaPB.ChannelNodeInfos[channelName].RwNodes = currentNodes[:average] + for _, nodeID := range currentNodes[average:] { + delete(replica.exclusiveRWNodeToChannel, nodeID) + } + } + } + + // acquire node in channel + for channelName, channelNodeInfo := range channelNodeInfos { + currentNodes := channelNodeInfo.GetRwNodes() + if len(currentNodes) < average { + for _, nodeID := range replica.rwNodes.Collect() { + if _, ok := replica.exclusiveRWNodeToChannel[nodeID]; !ok { + currentNodes = append(currentNodes, nodeID) + replica.exclusiveRWNodeToChannel[nodeID] = channelName + if len(currentNodes) == average { + break + } + } + } + replica.replicaPB.ChannelNodeInfos[channelName].RwNodes = currentNodes + } + } + } } // IntoReplica returns the immutable replica, After calling this method, the mutable replica should not be used again. diff --git a/internal/querycoordv2/meta/replica_manager.go b/internal/querycoordv2/meta/replica_manager.go index 39fdca067f..27640f9d41 100644 --- a/internal/querycoordv2/meta/replica_manager.go +++ b/internal/querycoordv2/meta/replica_manager.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -94,13 +95,16 @@ func (m *ReplicaManager) Get(id typeutil.UniqueID) *Replica { } // Spawn spawns N replicas at resource group for given collection in ReplicaManager. 
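+// When the channel-level score balancer is configured, Spawn also seeds an empty ChannelNodeInfo for every channel passed in, so that RW nodes can later be assigned to channels exclusively.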
-func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int) ([]*Replica, error) { +func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int, channels []string) ([]*Replica, error) { m.rwmutex.Lock() defer m.rwmutex.Unlock() if m.collIDToReplicaIDs[collection] != nil { return nil, fmt.Errorf("replicas of collection %d is already spawned", collection) } + balancePolicy := paramtable.Get().QueryCoordCfg.Balancer.GetValue() + enableChannelExclusiveMode := balancePolicy == ChannelLevelScoreBalancerName + replicas := make([]*Replica, 0) for rgName, replicaNum := range replicaNumInRG { for ; replicaNum > 0; replicaNum-- { @@ -108,10 +112,18 @@ func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int) if err != nil { return nil, err } + + channelExclusiveNodeInfo := make(map[string]*querypb.ChannelNodeInfo) + if enableChannelExclusiveMode { + for _, channel := range channels { + channelExclusiveNodeInfo[channel] = &querypb.ChannelNodeInfo{} + } + } replicas = append(replicas, newReplica(&querypb.Replica{ - ID: id, - CollectionID: collection, - ResourceGroup: rgName, + ID: id, + CollectionID: collection, + ResourceGroup: rgName, + ChannelNodeInfos: channelExclusiveNodeInfo, })) } } @@ -267,7 +279,7 @@ func (m *ReplicaManager) GetByNode(nodeID typeutil.UniqueID) []*Replica { replicas := make([]*Replica, 0) for _, replica := range m.replicas { - if replica.rwNodes.Contain(nodeID) { + if replica.Contains(nodeID) { replicas = append(replicas, replica) } } diff --git a/internal/querycoordv2/meta/replica_manager_test.go b/internal/querycoordv2/meta/replica_manager_test.go index c7cf580f81..1520c89241 100644 --- a/internal/querycoordv2/meta/replica_manager_test.go +++ b/internal/querycoordv2/meta/replica_manager_test.go @@ -111,11 +111,26 @@ func (suite *ReplicaManagerSuite) TestSpawn() { mgr := suite.mgr mgr.idAllocator = ErrorIDAllocator() - _, err := mgr.Spawn(1, map[string]int{DefaultResourceGroupName: 1}) + _, err := mgr.Spawn(1, map[string]int{DefaultResourceGroupName: 1}, nil) suite.Error(err) replicas := mgr.GetByCollection(1) suite.Len(replicas, 0) + + mgr.idAllocator = suite.idAllocator + replicas, err = mgr.Spawn(1, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + suite.NoError(err) + for _, replica := range replicas { + suite.Len(replica.replicaPB.GetChannelNodeInfos(), 0) + } + + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.Balancer.Key, ChannelLevelScoreBalancerName) + defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.Balancer.Key) + replicas, err = mgr.Spawn(2, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + suite.NoError(err) + for _, replica := range replicas { + suite.Len(replica.replicaPB.GetChannelNodeInfos(), 2) + } } func (suite *ReplicaManagerSuite) TestGet() { @@ -262,7 +277,7 @@ func (suite *ReplicaManagerSuite) spawnAll() { mgr := suite.mgr for id, cfg := range suite.collections { - replicas, err := mgr.Spawn(id, cfg.spawnConfig) + replicas, err := mgr.Spawn(id, cfg.spawnConfig, nil) suite.NoError(err) totalSpawn := 0 rgsOfCollection := make(map[string]typeutil.UniqueSet) @@ -277,12 +292,12 @@ func (suite *ReplicaManagerSuite) spawnAll() { func (suite *ReplicaManagerSuite) TestResourceGroup() { mgr := NewReplicaManager(suite.idAllocator, suite.catalog) - replicas1, err := mgr.Spawn(int64(1000), map[string]int{DefaultResourceGroupName: 1}) + replicas1, err := mgr.Spawn(int64(1000), map[string]int{DefaultResourceGroupName: 1}, nil) 
suite.NoError(err) suite.NotNil(replicas1) suite.Len(replicas1, 1) - replica2, err := mgr.Spawn(int64(2000), map[string]int{DefaultResourceGroupName: 1}) + replica2, err := mgr.Spawn(int64(2000), map[string]int{DefaultResourceGroupName: 1}, nil) suite.NoError(err) suite.NotNil(replica2) suite.Len(replica2, 1) @@ -365,7 +380,7 @@ func (suite *ReplicaManagerV2Suite) TestSpawn() { mgr := suite.mgr for id, cfg := range suite.collections { - replicas, err := mgr.Spawn(id, cfg.spawnConfig) + replicas, err := mgr.Spawn(id, cfg.spawnConfig, nil) suite.NoError(err) rgsOfCollection := make(map[string]typeutil.UniqueSet) for rg := range cfg.spawnConfig { diff --git a/internal/querycoordv2/meta/replica_test.go b/internal/querycoordv2/meta/replica_test.go index 07ab0a94eb..1c10c34f71 100644 --- a/internal/querycoordv2/meta/replica_test.go +++ b/internal/querycoordv2/meta/replica_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) type ReplicaSuite struct { @@ -15,6 +16,7 @@ type ReplicaSuite struct { } func (suite *ReplicaSuite) SetupSuite() { + paramtable.Init() suite.replicaPB = &querypb.Replica{ ID: 1, CollectionID: 2, @@ -177,6 +179,60 @@ func (suite *ReplicaSuite) testRead(r *Replica) { suite.False(r.ContainRWNode(4)) } +func (suite *ReplicaSuite) TestChannelExclusiveMode() { + r := newReplica(&querypb.Replica{ + ID: 1, + CollectionID: 2, + ResourceGroup: DefaultResourceGroupName, + ChannelNodeInfos: map[string]*querypb.ChannelNodeInfo{ + "channel1": {}, + "channel2": {}, + "channel3": {}, + "channel4": {}, + }, + }) + + mutableReplica := r.copyForWrite() + // add 10 rw nodes, exclusive mode is false. + for i := 0; i < 10; i++ { + mutableReplica.AddRWNode(int64(i)) + } + r = mutableReplica.IntoReplica() + for _, channelNodeInfo := range r.replicaPB.GetChannelNodeInfos() { + suite.Equal(0, len(channelNodeInfo.GetRwNodes())) + } + + mutableReplica = r.copyForWrite() + // add 10 rw nodes, exclusive mode is true. 
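+ // (now 20 rw nodes >= 4 channels * default factor 4, so each channel is assigned 20/4 = 5 exclusive nodes)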
+ for i := 10; i < 20; i++ { + mutableReplica.AddRWNode(int64(i)) + } + r = mutableReplica.IntoReplica() + for _, channelNodeInfo := range r.replicaPB.GetChannelNodeInfos() { + suite.Equal(5, len(channelNodeInfo.GetRwNodes())) + } + + // 4 node become read only, exclusive mode still be true + mutableReplica = r.copyForWrite() + for i := 0; i < 4; i++ { + mutableReplica.AddRONode(int64(i)) + } + r = mutableReplica.IntoReplica() + for _, channelNodeInfo := range r.replicaPB.GetChannelNodeInfos() { + suite.Equal(4, len(channelNodeInfo.GetRwNodes())) + } + + // 4 node has been removed, exclusive mode back to false + mutableReplica = r.copyForWrite() + for i := 4; i < 8; i++ { + mutableReplica.RemoveNode(int64(i)) + } + r = mutableReplica.IntoReplica() + for _, channelNodeInfo := range r.replicaPB.GetChannelNodeInfos() { + suite.Equal(0, len(channelNodeInfo.GetRwNodes())) + } +} + func TestReplica(t *testing.T) { suite.Run(t, new(ReplicaSuite)) } diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index b336ad6072..b23a573426 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -391,7 +391,7 @@ func (suite *CollectionObserverSuite) loadAll() { func (suite *CollectionObserverSuite) load(collection int64) { // Mock meta data - replicas, err := suite.meta.ReplicaManager.Spawn(collection, map[string]int{meta.DefaultResourceGroupName: int(suite.replicaNumber[collection])}) + replicas, err := suite.meta.ReplicaManager.Spawn(collection, map[string]int{meta.DefaultResourceGroupName: int(suite.replicaNumber[collection])}, nil) suite.NoError(err) for _, replica := range replicas { replica.AddRWNode(suite.nodes...) diff --git a/internal/querycoordv2/observers/replica_observer_test.go b/internal/querycoordv2/observers/replica_observer_test.go index c3ae2bbff7..9ddfb7a019 100644 --- a/internal/querycoordv2/observers/replica_observer_test.go +++ b/internal/querycoordv2/observers/replica_observer_test.go @@ -120,7 +120,7 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { replicas, err := suite.meta.Spawn(suite.collectionID, map[string]int{ "rg1": 1, "rg2": 1, - }) + }, nil) suite.NoError(err) suite.Equal(2, len(replicas)) diff --git a/internal/querycoordv2/observers/target_observer_test.go b/internal/querycoordv2/observers/target_observer_test.go index fa260c98b0..e3553c5c16 100644 --- a/internal/querycoordv2/observers/target_observer_test.go +++ b/internal/querycoordv2/observers/target_observer_test.go @@ -92,7 +92,7 @@ func (suite *TargetObserverSuite) SetupTest() { suite.NoError(err) err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID)) suite.NoError(err) - replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}) + replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil) suite.NoError(err) replicas[0].AddRWNode(2) err = suite.meta.ReplicaManager.Put(replicas...) 
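The node counts asserted in TestChannelExclusiveMode above follow from the threshold and even-split arithmetic in tryBalanceNodeForChannel. A minimal, illustrative sketch of that arithmetic (the helper name is not part of the patch, and it assumes the default queryCoord.channelExclusiveNodeFactor of 4):

package main

import "fmt"

// exclusiveNodesPerChannel mirrors the threshold and even-split arithmetic of
// mutableReplica.tryBalanceNodeForChannel: exclusive mode requires at least
// `factor` RW nodes per channel, and RW nodes are then divided evenly per channel.
func exclusiveNodesPerChannel(rwNodes, channels, factor int) int {
	if channels == 0 || rwNodes < channels*factor {
		return 0 // below the threshold: exclusive mode is off and ChannelNodeInfos are reset
	}
	return rwNodes / channels
}

func main() {
	fmt.Println(exclusiveNodesPerChannel(20, 4, 4)) // 5 -> matches the 5-nodes-per-channel assertion
	fmt.Println(exclusiveNodesPerChannel(16, 4, 4)) // 4 -> after 4 nodes turn read-only
	fmt.Println(exclusiveNodesPerChannel(12, 4, 4)) // 0 -> after 4 nodes are removed, exclusive mode exits
}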
@@ -276,7 +276,7 @@ func (suite *TargetObserverCheckSuite) SetupTest() { suite.NoError(err) err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID)) suite.NoError(err) - replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}) + replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil) suite.NoError(err) replicas[0].AddRWNode(2) err = suite.meta.ReplicaManager.Put(replicas...) diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 3fec78f109..7fc80bd7c7 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -115,8 +115,7 @@ type Server struct { resourceObserver *observers.ResourceObserver leaderCacheObserver *observers.LeaderCacheObserver - balancer balance.Balance - balancerMap map[string]balance.Balance + balancer balance.Balance // Active-standby enableActiveStandBy bool @@ -289,21 +288,21 @@ func (s *Server) initQueryCoord() error { ) // Init balancer map and balancer - log.Info("init all available balancer") - s.balancerMap = make(map[string]balance.Balance) - s.balancerMap[balance.RoundRobinBalancerName] = balance.NewRoundRobinBalancer(s.taskScheduler, s.nodeMgr) - s.balancerMap[balance.RowCountBasedBalancerName] = balance.NewRowCountBasedBalancer(s.taskScheduler, - s.nodeMgr, s.dist, s.meta, s.targetMgr) - s.balancerMap[balance.ScoreBasedBalancerName] = balance.NewScoreBasedBalancer(s.taskScheduler, - s.nodeMgr, s.dist, s.meta, s.targetMgr) - s.balancerMap[balance.MultiTargetBalancerName] = balance.NewMultiTargetBalancer(s.taskScheduler, s.nodeMgr, s.dist, s.meta, s.targetMgr) - - if balancer, ok := s.balancerMap[params.Params.QueryCoordCfg.Balancer.GetValue()]; ok { - s.balancer = balancer - log.Info("use config balancer", zap.String("balancer", params.Params.QueryCoordCfg.Balancer.GetValue())) - } else { - s.balancer = s.balancerMap[balance.RowCountBasedBalancerName] - log.Info("use rowCountBased auto balancer") + log.Info("init balancer") + switch params.Params.QueryCoordCfg.Balancer.GetValue() { + case meta.RoundRobinBalancerName: + s.balancer = balance.NewRoundRobinBalancer(s.taskScheduler, s.nodeMgr) + case meta.RowCountBasedBalancerName: + s.balancer = balance.NewRowCountBasedBalancer(s.taskScheduler, s.nodeMgr, s.dist, s.meta, s.targetMgr) + case meta.ScoreBasedBalancerName: + s.balancer = balance.NewScoreBasedBalancer(s.taskScheduler, s.nodeMgr, s.dist, s.meta, s.targetMgr) + case meta.MultiTargetBalancerName: + s.balancer = balance.NewMultiTargetBalancer(s.taskScheduler, s.nodeMgr, s.dist, s.meta, s.targetMgr) + case meta.ChannelLevelScoreBalancerName: + s.balancer = balance.NewChannelLevelScoreBalancer(s.taskScheduler, s.nodeMgr, s.dist, s.meta, s.targetMgr) + default: + log.Info(fmt.Sprintf("default to use %s", meta.ScoreBasedBalancerName)) + s.balancer = balance.NewScoreBasedBalancer(s.taskScheduler, s.nodeMgr, s.dist, s.meta, s.targetMgr) } // Init checker controller diff --git a/internal/querycoordv2/utils/meta.go b/internal/querycoordv2/utils/meta.go index 8be51b038b..4dff731286 100644 --- a/internal/querycoordv2/utils/meta.go +++ b/internal/querycoordv2/utils/meta.go @@ -169,14 +169,14 @@ func checkResourceGroup(m *meta.Meta, resourceGroups []string, replicaNumber int } // SpawnReplicasWithRG spawns replicas in rgs one by one for given collection. 
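+// The channels argument is forwarded to ReplicaManager.Spawn so that, when the channel-level balancer is enabled, each new replica starts with an (initially empty) ChannelNodeInfo entry per channel.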
-func SpawnReplicasWithRG(m *meta.Meta, collection int64, resourceGroups []string, replicaNumber int32) ([]*meta.Replica, error) { +func SpawnReplicasWithRG(m *meta.Meta, collection int64, resourceGroups []string, replicaNumber int32, channels []string) ([]*meta.Replica, error) { replicaNumInRG, err := checkResourceGroup(m, resourceGroups, replicaNumber) if err != nil { return nil, err } // Spawn it in replica manager. - replicas, err := m.ReplicaManager.Spawn(collection, replicaNumInRG) + replicas, err := m.ReplicaManager.Spawn(collection, replicaNumInRG, channels) if err != nil { return nil, err } diff --git a/internal/querycoordv2/utils/meta_test.go b/internal/querycoordv2/utils/meta_test.go index 02e6a0a091..70a385cc61 100644 --- a/internal/querycoordv2/utils/meta_test.go +++ b/internal/querycoordv2/utils/meta_test.go @@ -118,7 +118,7 @@ func TestSpawnReplicasWithRG(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := SpawnReplicasWithRG(tt.args.m, tt.args.collection, tt.args.resourceGroups, tt.args.replicaNumber) + got, err := SpawnReplicasWithRG(tt.args.m, tt.args.collection, tt.args.resourceGroups, tt.args.replicaNumber, nil) if (err != nil) != tt.wantErr { t.Errorf("SpawnReplicasWithRG() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index ca9186b135..3a743694cc 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1491,6 +1491,7 @@ type queryCoordConfig struct { CheckNodeSessionInterval ParamItem `refreshable:"false"` GracefulStopTimeout ParamItem `refreshable:"true"` EnableStoppingBalance ParamItem `refreshable:"true"` + ChannelExclusiveNodeFactor ParamItem `refreshable:"true"` } func (p *queryCoordConfig) init(base *BaseTable) { @@ -1967,6 +1968,15 @@ func (p *queryCoordConfig) init(base *BaseTable) { Export: true, } p.EnableStoppingBalance.Init(base.mgr) + + p.ChannelExclusiveNodeFactor = ParamItem{ + Key: "queryCoord.channelExclusiveNodeFactor", + Version: "2.4.2", + DefaultValue: "4", + Doc: "the minimum number of nodes per channel required to enable channel's exclusive mode", + Export: true, + } + p.ChannelExclusiveNodeFactor.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index 788ed6a279..3bfe21d06d 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -305,6 +305,8 @@ func TestComponentParam(t *testing.T) { params.Save("queryCoord.gracefulStopTimeout", "100") assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second)) assert.Equal(t, true, Params.EnableStoppingBalance.GetAsBool()) + + assert.Equal(t, 4, Params.ChannelExclusiveNodeFactor.GetAsInt()) }) t.Run("test queryNodeConfig", func(t *testing.T) { diff --git a/tests/integration/balance/channel_exclusive_balance_test.go b/tests/integration/balance/channel_exclusive_balance_test.go new file mode 100644 index 0000000000..08799745d4 --- /dev/null +++ b/tests/integration/balance/channel_exclusive_balance_test.go @@ -0,0 +1,263 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package balance + +import ( + "context" + "fmt" + "strconv" + "strings" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/samber/lo" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/tests/integration" +) + +type ChannelExclusiveBalanceSuit struct { + integration.MiniClusterSuite +} + +func (s *ChannelExclusiveBalanceSuit) SetupSuite() { + paramtable.Init() + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "1000") + paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1") + + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") + + // disable compaction + paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableCompaction.Key, "false") + + s.Require().NoError(s.SetupEmbedEtcd()) +} + +func (s *ChannelExclusiveBalanceSuit) TearDownSuite() { + defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.EnableCompaction.Key) + s.MiniClusterSuite.TearDownSuite() +} + +func (s *ChannelExclusiveBalanceSuit) initCollection(collectionName string, replica int, channelNum int, segmentNum int, segmentRowNum int, segmentDeleteNum int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + dim = 128 + dbName = "" + ) + + schema := integration.ConstructSchema(collectionName, dim, true) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := s.Cluster.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: int32(channelNum), + }) + s.NoError(err) + s.True(merr.Ok(createCollectionStatus)) + + log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus)) + showCollectionsResp, err := s.Cluster.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.Status)) + log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + + for i := 0; i < segmentNum; i++ { + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, segmentRowNum, dim) + hashKeys := 
integration.GenerateHashKeys(segmentRowNum) + insertResult, err := s.Cluster.Proxy.Insert(ctx, &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(segmentRowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.Status)) + + if segmentDeleteNum > 0 { + if segmentDeleteNum > segmentRowNum { + segmentDeleteNum = segmentRowNum + } + + pks := insertResult.GetIDs().GetIntId().GetData() + log.Info("========================delete expr==================", + zap.Int("length of pk", len(pks)), + ) + + expr := fmt.Sprintf("%s in [%s]", integration.Int64Field, strings.Join(lo.Map(pks, func(pk int64, _ int) string { return strconv.FormatInt(pk, 10) }), ",")) + + deleteResp, err := s.Cluster.Proxy.Delete(ctx, &milvuspb.DeleteRequest{ + CollectionName: collectionName, + Expr: expr, + }) + s.Require().NoError(err) + s.Require().True(merr.Ok(deleteResp.GetStatus())) + s.Require().EqualValues(len(pks), deleteResp.GetDeleteCnt()) + } + + // flush + flushResp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) + } + + // create index + createIndexStatus, err := s.Cluster.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2), + }) + s.NoError(err) + s.True(merr.Ok(createIndexStatus)) + s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField) + + for i := 1; i < replica; i++ { + s.Cluster.AddQueryNode() + } + + // load + loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + ReplicaNumber: int32(replica), + }) + s.NoError(err) + s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode()) + s.True(merr.Ok(loadStatus)) + s.WaitForLoad(ctx, collectionName) + log.Info("initCollection Done") +} + +func (s *ChannelExclusiveBalanceSuit) TestBalanceOnSingleReplica() { + name := "test_balance_" + funcutil.GenRandomStr() + channelCount := 5 + channelNodeCount := 3 + + s.initCollection(name, 1, channelCount, 5, 2000, 0) + + ctx := context.Background() + qnList := make([]*grpcquerynode.Server, 0) + // add a querynode, expected balance happens + for i := 1; i < channelCount*channelNodeCount; i++ { + qn := s.Cluster.AddQueryNode() + qnList = append(qnList, qn) + } + + // expected each channel own 3 exclusive node + s.Eventually(func() bool { + channelNodeCounter := make(map[string]int) + for _, node := range s.Cluster.GetAllQueryNodes() { + resp1, err := node.GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{}) + s.NoError(err) + s.True(merr.Ok(resp1.GetStatus())) + + log.Info("resp", zap.Any("segments", resp1.Segments)) + if channel, ok := s.isSameChannel(resp1.GetSegments()); ok { + channelNodeCounter[channel] += 1 + } + } + + log.Info("dist", zap.Any("nodes", channelNodeCounter)) + nodeCountMatch := true + for _, cnt := range channelNodeCounter { + if cnt != channelNodeCount { + nodeCountMatch = false + break + } + } + + return 
nodeCountMatch + }, 60*time.Second, 3*time.Second) + + // add two new query nodes and stop two old query nodes + s.Cluster.AddQueryNode() + s.Cluster.AddQueryNode() + qnList[0].Stop() + qnList[1].Stop() + + // expected each channel own 3 exclusive node + s.Eventually(func() bool { + channelNodeCounter := make(map[string]int) + for _, node := range s.Cluster.GetAllQueryNodes() { + resp1, err := node.GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{}) + // skip the nodes stopped above; only count healthy responses + if err == nil && merr.Ok(resp1.GetStatus()) { + log.Info("resp", zap.Any("segments", resp1.Segments)) + if channel, ok := s.isSameChannel(resp1.GetSegments()); ok { + channelNodeCounter[channel] += 1 + } + } + } + + log.Info("dist", zap.Any("nodes", channelNodeCounter)) + nodeCountMatch := true + for _, cnt := range channelNodeCounter { + if cnt != channelNodeCount { + nodeCountMatch = false + break + } + } + + return nodeCountMatch + }, 60*time.Second, 3*time.Second) +} + +func (s *ChannelExclusiveBalanceSuit) isSameChannel(segments []*querypb.SegmentVersionInfo) (string, bool) { + if len(segments) == 0 { + return "", false + } + + channelName := segments[0].Channel + + _, find := lo.Find(segments, func(segment *querypb.SegmentVersionInfo) bool { + return segment.Channel != channelName + }) + + return channelName, !find +} + +func TestChannelExclusiveBalance(t *testing.T) { + suite.Run(t, new(ChannelExclusiveBalanceSuit)) +}
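Taken together, the feature is driven by two settings: the balancer name checked in Server.initQueryCoord and ReplicaManager.Spawn, and the node factor read by mutableReplica.tryBalanceNodeForChannel. A minimal sketch of wiring them up the way the tests above do (the wrapper function is illustrative, not part of the patch):

package example

import (
	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

// configureChannelExclusiveBalance selects the channel-level score balancer and
// lowers the per-channel node threshold, mirroring the test setup in this patch.
func configureChannelExclusiveBalance() {
	paramtable.Init()
	params := paramtable.Get()

	// query coord picks the balancer implementation from this value at startup.
	params.Save(params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName)

	// require only 2 RW nodes per channel before a replica enters
	// channel-exclusive mode (queryCoord.channelExclusiveNodeFactor, default 4).
	params.Save(params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2")
}

In a deployment the same values would be set through the configuration file rather than paramtable; the new factor is exported as queryCoord.channelExclusiveNodeFactor, alongside the existing queryCoord balancer setting.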