diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go index 49d2b3e883..239e945c2a 100644 --- a/internal/querycoordv2/balance/balance.go +++ b/internal/querycoordv2/balance/balance.go @@ -17,6 +17,7 @@ package balance import ( + "fmt" "sort" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -60,6 +61,11 @@ type SegmentAssignPlan struct { Weight Weight } +func (segPlan SegmentAssignPlan) ToString() string { + return fmt.Sprintf("SegmentPlan:[collectionID: %d, replicaID: %d, segmentID: %d, from: %d, to: %d, weight: %d]\n", + segPlan.Segment.CollectionID, segPlan.ReplicaID, segPlan.Segment.ID, segPlan.From, segPlan.To, segPlan.Weight) +} + type ChannelAssignPlan struct { Channel *meta.DmChannel ReplicaID int64 @@ -68,8 +74,19 @@ type ChannelAssignPlan struct { Weight Weight } +func (chanPlan ChannelAssignPlan) ToString() string { + return fmt.Sprintf("ChannelPlan:[collectionID: %d, channel: %s, replicaID: %d, from: %d, to: %d, weight: %d]\n", + chanPlan.Channel.CollectionID, chanPlan.Channel.ChannelName, chanPlan.ReplicaID, chanPlan.From, chanPlan.To, chanPlan.Weight) +} + +var ( + RoundRobinBalancerName = "RoundRobinBalancer" + RowCountBasedBalancerName = "RowCountBasedBalancer" + ScoreBasedBalancerName = "ScoreBasedBalancer" +) + type Balance interface { - AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan + AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan Balance() ([]SegmentAssignPlan, []ChannelAssignPlan) } @@ -79,7 +96,7 @@ type RoundRobinBalancer struct { nodeManager *session.NodeManager } -func (b *RoundRobinBalancer) AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan { +func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan { nodesInfo := b.getNodes(nodes) if len(nodesInfo) == 0 { return nil diff --git a/internal/querycoordv2/balance/balance_test.go b/internal/querycoordv2/balance/balance_test.go index 4ef24b2d82..dc7b86ffba 100644 --- a/internal/querycoordv2/balance/balance_test.go +++ b/internal/querycoordv2/balance/balance_test.go @@ -92,7 +92,7 @@ func (suite *BalanceTestSuite) TestAssignBalance() { suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i]) } } - plans := suite.roundRobinBalancer.AssignSegment(c.assignments, c.nodeIDs) + plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs) suite.ElementsMatch(c.expectPlans, plans) }) } diff --git a/internal/querycoordv2/balance/mock_balancer.go b/internal/querycoordv2/balance/mock_balancer.go index 552345f0fe..c0434ec24c 100644 --- a/internal/querycoordv2/balance/mock_balancer.go +++ b/internal/querycoordv2/balance/mock_balancer.go @@ -60,13 +60,13 @@ func (_c *MockBalancer_AssignChannel_Call) Return(_a0 []ChannelAssignPlan) *Mock return _c } -// AssignSegment provides a mock function with given fields: segments, nodes -func (_m *MockBalancer) AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan { - ret := _m.Called(segments, nodes) +// AssignSegment provides a mock function with given fields: collectionID, segments, nodes +func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan { + ret := _m.Called(collectionID, segments, nodes) var r0 []SegmentAssignPlan - if rf, ok := ret.Get(0).(func([]*meta.Segment, 
[]int64) []SegmentAssignPlan); ok { - r0 = rf(segments, nodes) + if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64) []SegmentAssignPlan); ok { + r0 = rf(collectionID, segments, nodes) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]SegmentAssignPlan) @@ -82,15 +82,16 @@ type MockBalancer_AssignSegment_Call struct { } // AssignSegment is a helper method to define mock.On call -// - segments []*meta.Segment -// - nodes []int64 -func (_e *MockBalancer_Expecter) AssignSegment(segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call { - return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", segments, nodes)} +// - collectionID int64 +// - segments []*meta.Segment +// - nodes []int64 +func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call { + return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes)} } -func (_c *MockBalancer_AssignSegment_Call) Run(run func(segments []*meta.Segment, nodes []int64)) *MockBalancer_AssignSegment_Call { +func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64)) *MockBalancer_AssignSegment_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*meta.Segment), args[1].([]int64)) + run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64)) }) return _c } diff --git a/internal/querycoordv2/balance/priority_queue_test.go b/internal/querycoordv2/balance/priority_queue_test.go new file mode 100644 index 0000000000..eb1fc53d3d --- /dev/null +++ b/internal/querycoordv2/balance/priority_queue_test.go @@ -0,0 +1,96 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
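The three tests in the new priority_queue_test.go below drive newPriorityQueue and newNodeItem, which are not part of this patch (they live in priority_queue.go alongside this file). For orientation, a minimal sketch of the interface the tests assume — the names come from the test code; the container/heap-based implementation here is an assumption, not the patch's actual code:

package balance

import "container/heap"

// item is what the queue stores; nodeItem ties a mutable priority to a node ID.
type item interface {
	getPriority() int
	setPriority(priority int)
}

type nodeItem struct {
	priority int
	nodeID   int64
}

func newNodeItem(priority int, nodeID int64) nodeItem { return nodeItem{priority, nodeID} }

func (n *nodeItem) getPriority() int         { return n.priority }
func (n *nodeItem) setPriority(priority int) { n.priority = priority }

// priorityQueue is a min-heap keyed on getPriority(): pop returns the item with
// the smallest priority, which is how the balancer finds the lowest-score node.
type priorityQueue struct{ items []item }

func newPriorityQueue() priorityQueue { return priorityQueue{} }

func (pq *priorityQueue) Len() int           { return len(pq.items) }
func (pq *priorityQueue) Less(i, j int) bool { return pq.items[i].getPriority() < pq.items[j].getPriority() }
func (pq *priorityQueue) Swap(i, j int)      { pq.items[i], pq.items[j] = pq.items[j], pq.items[i] }
func (pq *priorityQueue) Push(x interface{}) { pq.items = append(pq.items, x.(item)) }
func (pq *priorityQueue) Pop() interface{} {
	last := pq.items[len(pq.items)-1]
	pq.items = pq.items[:len(pq.items)-1]
	return last
}

func (pq *priorityQueue) push(i item) { heap.Push(pq, i) }
func (pq *priorityQueue) pop() item   { return heap.Pop(pq).(item) }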
+
+package balance
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestMinPriorityQueue(t *testing.T) {
+ pq := newPriorityQueue()
+
+ for i := 0; i < 5; i++ {
+ priority := i % 3
+ nodeItem := newNodeItem(priority, int64(i))
+ pq.push(&nodeItem)
+ }
+
+ item := pq.pop()
+ assert.Equal(t, item.getPriority(), 0)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), 0)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(3))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), 1)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(1))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), 1)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(4))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), 2)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(2))
+}
+
+func TestPopPriorityQueue(t *testing.T) {
+ pq := newPriorityQueue()
+
+ for i := 0; i < 1; i++ {
+ priority := 1
+ nodeItem := newNodeItem(priority, int64(i))
+ pq.push(&nodeItem)
+ }
+
+ item := pq.pop()
+ assert.Equal(t, item.getPriority(), 1)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
+ pq.push(item)
+
+ // pop is purely priority-ordered: pushing the item back and popping again
+ // returns the same node, there is no round-robin rotation
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), 1)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
+}
+
+func TestMaxPriorityQueue(t *testing.T) {
+ pq := newPriorityQueue()
+
+ for i := 0; i < 5; i++ {
+ priority := i % 3
+ nodeItem := newNodeItem(-priority, int64(i))
+ pq.push(&nodeItem)
+ }
+
+ item := pq.pop()
+ assert.Equal(t, item.getPriority(), -2)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(2))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), -1)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(4))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), -1)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(1))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), 0)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(3))
+ item = pq.pop()
+ assert.Equal(t, item.getPriority(), 0)
+ assert.Equal(t, item.(*nodeItem).nodeID, int64(0))
+}
diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go
index 6a665f3457..b0fe8cbc2c 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer.go
@@ -37,7 +37,7 @@ type RowCountBasedBalancer struct {
 targetMgr *meta.TargetManager
}
-func (b *RowCountBasedBalancer) AssignSegment(segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
+func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
 nodeItems := b.convertToNodeItems(nodes)
 if len(nodeItems) == 0 {
 return nil
diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go
index 2c840f65a4..29a06d0a04 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer_test.go
@@ -126,7 +126,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
 nodeInfo.SetState(c.states[i])
 suite.balancer.nodeManager.Add(nodeInfo)
 }
- plans := balancer.AssignSegment(c.assignments, c.nodes)
+ plans := balancer.AssignSegment(0, c.assignments, c.nodes)
 suite.ElementsMatch(c.expectPlans, plans)
 })
}
diff --git a/internal/querycoordv2/balance/score_based_balancer.go
b/internal/querycoordv2/balance/score_based_balancer.go
new file mode 100644
index 0000000000..3f8121dd68
--- /dev/null
+++ b/internal/querycoordv2/balance/score_based_balancer.go
@@ -0,0 +1,380 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package balance
+
+import (
+ "sort"
+
+ "github.com/samber/lo"
+ "go.uber.org/zap"
+ "golang.org/x/exp/maps"
+
+ "github.com/milvus-io/milvus/internal/log"
+ "github.com/milvus-io/milvus/internal/proto/querypb"
+ "github.com/milvus-io/milvus/internal/querycoordv2/meta"
+ "github.com/milvus-io/milvus/internal/querycoordv2/params"
+ "github.com/milvus-io/milvus/internal/querycoordv2/session"
+ "github.com/milvus-io/milvus/internal/querycoordv2/task"
+ "github.com/milvus-io/milvus/internal/util/typeutil"
+)
+
+type ScoreBasedBalancer struct {
+ *RowCountBasedBalancer
+ balancedCollectionsCurrentRound typeutil.UniqueSet
+}
+
+func NewScoreBasedBalancer(scheduler task.Scheduler,
+ nodeManager *session.NodeManager,
+ dist *meta.DistributionManager,
+ meta *meta.Meta,
+ targetMgr *meta.TargetManager) *ScoreBasedBalancer {
+ return &ScoreBasedBalancer{
+ RowCountBasedBalancer: NewRowCountBasedBalancer(scheduler, nodeManager, dist, meta, targetMgr),
+ balancedCollectionsCurrentRound: typeutil.NewUniqueSet(),
+ }
+}
+
+// TODO: AssignChannel still needs to take the global channel distribution into account
+func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
+ nodeItems := b.convertToNodeItems(collectionID, nodes)
+ if len(nodeItems) == 0 {
+ return nil
+ }
+ queue := newPriorityQueue()
+ for _, item := range nodeItems {
+ queue.push(item)
+ }
+
+ sort.Slice(segments, func(i, j int) bool {
+ return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
+ })
+
+ plans := make([]SegmentAssignPlan, 0, len(segments))
+ for _, s := range segments {
+ // pick the node with the lowest score and assign the segment to it.
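+ // a node's score is its queue priority: the rows it already holds for this
+ // collection plus GlobalRowCountFactor times its total rows across all
+ // collections; e.g. with factor 0.1, 20 collection rows out of 300 total
+ // rows score 20 + 0.1*300 = 50, and the lowest score is popped first.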
+ ni := queue.pop().(*nodeItem)
+ plan := SegmentAssignPlan{
+ From: -1,
+ To: ni.nodeID,
+ Weight: GetWeight(1),
+ Segment: s,
+ }
+ plans = append(plans, plan)
+ // bump the node's score and push it back, accounting for both the collection row count and the weighted global row count
+ p := ni.getPriority()
+ ni.setPriority(p + int(s.GetNumOfRows()) +
+ int(float64(s.GetNumOfRows())*params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat()))
+ queue.push(ni)
+ }
+ return plans
+}
+
+func (b *ScoreBasedBalancer) convertToNodeItems(collectionID int64, nodeIDs []int64) []*nodeItem {
+ ret := make([]*nodeItem, 0, len(nodeIDs))
+ for _, nodeInfo := range b.getNodes(nodeIDs) {
+ node := nodeInfo.ID()
+ priority := b.calculatePriority(collectionID, node)
+ nodeItem := newNodeItem(priority, node)
+ ret = append(ret, &nodeItem)
+ }
+ return ret
+}
+
+func (b *ScoreBasedBalancer) calculatePriority(collectionID, nodeID int64) int {
+ globalSegments := b.dist.SegmentDistManager.GetByNode(nodeID)
+ rowCount := 0
+ for _, s := range globalSegments {
+ rowCount += int(s.GetNumOfRows())
+ }
+
+ collectionSegments := b.dist.SegmentDistManager.GetByCollectionAndNode(collectionID, nodeID)
+ collectionRowCount := 0
+ for _, s := range collectionSegments {
+ collectionRowCount += int(s.GetNumOfRows())
+ }
+ return collectionRowCount + int(float64(rowCount)*
+ params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat())
+}
+
+func (b *ScoreBasedBalancer) Balance() ([]SegmentAssignPlan, []ChannelAssignPlan) {
+ ids := b.meta.CollectionManager.GetAll()
+
+ // collections that are still loading are skipped; only loaded collections get balanced
+ loadedCollections := lo.Filter(ids, func(cid int64, _ int) bool {
+ return b.meta.GetCollection(cid).Status == querypb.LoadStatus_Loaded
+ })
+
+ sort.Slice(loadedCollections, func(i, j int) bool {
+ return loadedCollections[i] < loadedCollections[j]
+ })
+
+ segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0)
+ hasUnBalancedCollections := false
+ for _, cid := range loadedCollections {
+ if b.balancedCollectionsCurrentRound.Contain(cid) {
+ log.Debug("ScoreBasedBalancer has balanced collection, skip balancing in this round",
+ zap.Int64("collectionID", cid))
+ continue
+ }
+ hasUnBalancedCollections = true
+ replicas := b.meta.ReplicaManager.GetByCollection(cid)
+ for _, replica := range replicas {
+ sPlans, cPlans := b.balanceReplica(replica)
+ PrintNewBalancePlans(cid, replica.GetID(), sPlans, cPlans)
+ segmentPlans = append(segmentPlans, sPlans...)
+ channelPlans = append(channelPlans, cPlans...)
+ }
+ b.balancedCollectionsCurrentRound.Insert(cid)
+ if len(segmentPlans) != 0 || len(channelPlans) != 0 {
+ log.Debug("ScoreBasedBalancer has generated balance plans", zap.Int64("collectionID", cid))
+ break
+ }
+ }
+ if !hasUnBalancedCollections {
+ b.balancedCollectionsCurrentRound.Clear()
+ log.Debug("ScoreBasedBalancer has balanced all " +
+ "collections in one round, clearing the collectionID set for this round")
+ }
+
+ return segmentPlans, channelPlans
+}
+
+func (b *ScoreBasedBalancer) balanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) {
+ nodes := replica.GetNodes()
+ if len(nodes) == 0 {
+ return nil, nil
+ }
+ nodesSegments := make(map[int64][]*meta.Segment)
+ stoppingNodesSegments := make(map[int64][]*meta.Segment)
+
+ outboundNodes := b.meta.ResourceManager.CheckOutboundNodes(replica)
+
+ // calculate stopping nodes and available nodes.
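+ // the loop below sorts each replica node into one of two buckets: stopping
+ // or outbound nodes, whose segments and channels must all be drained, and
+ // normal nodes, which may receive them; only segments still present in the
+ // current target are considered movable.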
+ for _, nid := range nodes {
+ segments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nid)
+ // Only balance segments in targets
+ segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool {
+ return b.targetMgr.GetHistoricalSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil
+ })
+
+ if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil {
+ log.Info("node not found", zap.Int64("nid", nid), zap.Any("segments", segments), zap.Error(err))
+ continue
+ } else if isStopping {
+ stoppingNodesSegments[nid] = segments
+ } else if outboundNodes.Contain(nid) {
+ // the node has been moved to another resource group; treat it like a stopping node and move everything out
+ log.RatedInfo(10, "meet outbound node, try to move out all segment/channel",
+ zap.Int64("collectionID", replica.GetCollectionID()),
+ zap.Int64("replicaID", replica.GetID()),
+ zap.Int64("node", nid),
+ )
+ stoppingNodesSegments[nid] = segments
+ } else {
+ nodesSegments[nid] = segments
+ }
+ }
+
+ if len(nodes) == len(stoppingNodesSegments) {
+ // no available nodes to balance
+ log.Warn("All nodes are stopping or outbound, skip balancing replica",
+ zap.Int64("collection", replica.CollectionID),
+ zap.Int64("replica id", replica.Replica.GetID()),
+ zap.String("replica group", replica.Replica.GetResourceGroup()),
+ zap.Int64s("nodes", replica.Replica.GetNodes()),
+ )
+ return nil, nil
+ }
+
+ if len(nodesSegments) <= 0 {
+ log.Warn("No node is available in the resource group, skip balancing replica",
+ zap.Int64("collection", replica.CollectionID),
+ zap.Int64("replica id", replica.Replica.GetID()),
+ zap.String("replica group", replica.Replica.GetResourceGroup()),
+ zap.Int64s("nodes", replica.Replica.GetNodes()),
+ )
+ return nil, nil
+ }
+ // print the current distribution before generating plans
+ PrintCurrentReplicaDist(replica, stoppingNodesSegments, nodesSegments, b.dist.ChannelDistManager)
+ if len(stoppingNodesSegments) != 0 {
+ log.Info("Handle stopping nodes",
+ zap.Int64("collection", replica.CollectionID),
+ zap.Int64("replica id", replica.Replica.GetID()),
+ zap.String("replica group", replica.Replica.GetResourceGroup()),
+ zap.Any("stopping nodes", maps.Keys(stoppingNodesSegments)),
+ zap.Any("available nodes", maps.Keys(nodesSegments)),
+ )
+ // handle stopping nodes: reassign their segments to the nodes with the smallest scores
+ return b.getStoppedSegmentPlan(replica, nodesSegments, stoppingNodesSegments), b.getStoppedChannelPlan(replica, lo.Keys(nodesSegments), lo.Keys(stoppingNodesSegments))
+ }
+
+ // normal balance: move segments from the highest-score nodes to the lowest-score nodes.
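+ // getNormalSegmentPlan below repeatedly moves the smallest movable segment
+ // from the highest-score node to the lowest-score node, and stops once a
+ // move would invert the imbalance by more than
+ // ScoreUnbalanceTolerationFactor tolerates.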
+ return b.getNormalSegmentPlan(replica, nodesSegments), b.getNormalChannelPlan(replica, lo.Keys(nodesSegments))
+}
+
+func (b *ScoreBasedBalancer) getStoppedSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment, stoppingNodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
+ segmentPlans := make([]SegmentAssignPlan, 0)
+ // generate candidates
+ nodeItems := b.convertToNodeItems(replica.GetCollectionID(), lo.Keys(nodesSegments))
+ queue := newPriorityQueue()
+ for _, item := range nodeItems {
+ queue.push(item)
+ }
+
+ // collect the segments to reassign and remember each segment's source node
+ var segments []*meta.Segment
+ nodeIndex := make(map[int64]int64)
+ for nodeID, stoppingSegments := range stoppingNodesSegments {
+ for _, segment := range stoppingSegments {
+ segments = append(segments, segment)
+ nodeIndex[segment.GetID()] = nodeID
+ }
+ }
+
+ sort.Slice(segments, func(i, j int) bool {
+ return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
+ })
+
+ for _, s := range segments {
+ // pick the node with the lowest score and assign the segment to it.
+ ni := queue.pop().(*nodeItem)
+ plan := SegmentAssignPlan{
+ ReplicaID: replica.GetID(),
+ From: nodeIndex[s.GetID()],
+ To: ni.nodeID,
+ Weight: GetWeight(1),
+ Segment: s,
+ }
+ segmentPlans = append(segmentPlans, plan)
+ // bump the node's score and push it back, accounting for both the collection row count and the weighted global row count
+ p := ni.getPriority()
+ ni.setPriority(p + int(s.GetNumOfRows()) + int(float64(s.GetNumOfRows())*
+ params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat()))
+ queue.push(ni)
+ }
+
+ return segmentPlans
+}
+
+func (b *ScoreBasedBalancer) getStoppedChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
+ channelPlans := make([]ChannelAssignPlan, 0)
+ for _, nodeID := range offlineNodes {
+ dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
+ plans := b.AssignChannel(dmChannels, onlineNodes)
+ for i := range plans {
+ plans[i].From = nodeID
+ plans[i].ReplicaID = replica.ID
+ plans[i].Weight = GetWeight(1)
+ }
+ channelPlans = append(channelPlans, plans...)
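+ // drain plans carry GetWeight(1), i.e. high priority, so emptying stopping
+ // nodes is scheduled ahead of the weight-0 plans produced by normal balance.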
}
+ return channelPlans
+}
+
+func (b *ScoreBasedBalancer) getNormalSegmentPlan(replica *meta.Replica, nodesSegments map[int64][]*meta.Segment) []SegmentAssignPlan {
+ if b.scheduler.GetSegmentTaskNum() != 0 {
+ // the scheduler is still executing segment tasks, skip this round
+ return nil
+ }
+ segmentPlans := make([]SegmentAssignPlan, 0)
+
+ // generate candidates
+ nodeItems := b.convertToNodeItems(replica.GetCollectionID(), lo.Keys(nodesSegments))
+ lastIdx := len(nodeItems) - 1
+ havingMovedSegments := typeutil.NewUniqueSet()
+
+ for {
+ sort.Slice(nodeItems, func(i, j int) bool {
+ return nodeItems[i].priority <= nodeItems[j].priority
+ })
+ toNode := nodeItems[0]
+ fromNode := nodeItems[lastIdx]
+
+ // sort the candidate segments by row count in ascending order to avoid overshooting the from/to imbalance
+ // TODO: segment infos inside dist manager may change in the process of making balance plan
+ fromSegments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.CollectionID, fromNode.nodeID)
+ sort.Slice(fromSegments, func(i, j int) bool {
+ return fromSegments[i].GetNumOfRows() < fromSegments[j].GetNumOfRows()
+ })
+ var targetSegmentToMove *meta.Segment
+ for _, segment := range fromSegments {
+ targetSegmentToMove = segment
+ if havingMovedSegments.Contain(targetSegmentToMove.GetID()) {
+ targetSegmentToMove = nil
+ continue
+ }
+ break
+ }
+ if targetSegmentToMove == nil {
+ // the node with the highest score has no segment suitable for balancing; stop balancing this round
+ break
+ }
+
+ fromPriority := fromNode.priority
+ toPriority := toNode.priority
+ unbalance := fromPriority - toPriority
+ nextFromPriority := fromPriority - int(targetSegmentToMove.GetNumOfRows()) - int(float64(targetSegmentToMove.GetNumOfRows())*
+ params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat())
+ nextToPriority := toPriority + int(targetSegmentToMove.GetNumOfRows()) + int(float64(targetSegmentToMove.GetNumOfRows())*
+ params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat())
+
+ // still unbalanced in the same direction after this plan executes, so the move is a clear win
+ if nextToPriority <= nextFromPriority {
+ plan := SegmentAssignPlan{
+ ReplicaID: replica.GetID(),
+ From: fromNode.nodeID,
+ To: toNode.nodeID,
+ Segment: targetSegmentToMove,
+ Weight: GetWeight(0),
+ }
+ segmentPlans = append(segmentPlans, plan)
+ } else {
+ // the move would invert the imbalance; accept it only when the resulting
+ // inverted imbalance is far smaller than the original one
+ nextUnbalance := nextToPriority - nextFromPriority
+ if int(float64(nextUnbalance)*params.Params.QueryCoordCfg.ScoreUnbalanceTolerationFactor.GetAsFloat()) < unbalance {
+ plan := SegmentAssignPlan{
+ ReplicaID: replica.GetID(),
+ From: fromNode.nodeID,
+ To: toNode.nodeID,
+ Segment: targetSegmentToMove,
+ Weight: GetWeight(0),
+ }
+ segmentPlans = append(segmentPlans, plan)
+ } else {
+ // if moving even the smallest segment between the highest- and lowest-scored
+ // nodes does not yield sufficient benefit, cease balancing for this round
+ break
+ }
+ }
+ havingMovedSegments.Insert(targetSegmentToMove.GetID())
+
+ // update node priorities
+ toNode.setPriority(nextToPriority)
+ fromNode.setPriority(nextFromPriority)
+ // when no movable segment remains the loop exits above; otherwise try another move
+ // TODO swap segments between toNode and fromNode and see if the cluster becomes more balanced
+ }
+ return segmentPlans
+}
+
+func (b *ScoreBasedBalancer) getNormalChannelPlan(replica *meta.Replica, onlineNodes
[]int64) []ChannelAssignPlan { + // TODO + return make([]ChannelAssignPlan, 0) +} diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go new file mode 100644 index 0000000000..ae083b298f --- /dev/null +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -0,0 +1,604 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package balance + +import ( + "testing" + + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + . "github.com/milvus-io/milvus/internal/querycoordv2/params" + "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/querycoordv2/task" + "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/etcd" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +type ScoreBasedBalancerTestSuite struct { + suite.Suite + balancer *ScoreBasedBalancer + kv *etcdkv.EtcdKV + broker *meta.MockBroker + mockScheduler *task.MockScheduler +} + +func (suite *ScoreBasedBalancerTestSuite) SetupSuite() { + Params.Init() +} + +func (suite *ScoreBasedBalancerTestSuite) SetupTest() { + var err error + config := GenerateEtcdConfig() + cli, err := etcd.GetEtcdClient( + config.UseEmbedEtcd.GetAsBool(), + config.EtcdUseSSL.GetAsBool(), + config.Endpoints.GetAsStrings(), + config.EtcdTLSCert.GetValue(), + config.EtcdTLSKey.GetValue(), + config.EtcdTLSCACert.GetValue(), + config.EtcdTLSMinVersion.GetValue()) + suite.Require().NoError(err) + suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.broker = meta.NewMockBroker(suite.T()) + + store := meta.NewMetaStore(suite.kv) + idAllocator := RandomIncrementIDAllocator() + nodeManager := session.NewNodeManager() + testMeta := meta.NewMeta(idAllocator, store, nodeManager) + testTarget := meta.NewTargetManager(suite.broker, testMeta) + + distManager := meta.NewDistributionManager() + suite.mockScheduler = task.NewMockScheduler(suite.T()) + suite.balancer = NewScoreBasedBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget) +} + +func (suite *ScoreBasedBalancerTestSuite) TearDownTest() { + suite.kv.Close() +} + +func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { + cases := []struct { + name string + comment string + distributions map[int64][]*meta.Segment + assignments [][]*meta.Segment + nodes []int64 + collectionIDs []int64 + segmentCnts []int + states []session.State + expectPlans [][]SegmentAssignPlan + }{ + { + name: "test empty cluster assigning one collection", + comment: "this is most simple case 
in which global row count is zero for all nodes",
+ distributions: map[int64][]*meta.Segment{},
+ assignments: [][]*meta.Segment{
+ {
+ {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 5, CollectionID: 1}},
+ {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 10, CollectionID: 1}},
+ {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 15, CollectionID: 1}},
+ },
+ },
+ nodes: []int64{1, 2, 3},
+ collectionIDs: []int64{0},
+ states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
+ segmentCnts: []int{0, 0, 0},
+ expectPlans: [][]SegmentAssignPlan{
+ {
+ // AssignSegment runs while a collection is loading,
+ // so every plan should carry weight 1 (HIGH PRIORITY)
+ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 15,
+ CollectionID: 1}}, From: -1, To: 1, Weight: 1},
+ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 10,
+ CollectionID: 1}}, From: -1, To: 3, Weight: 1},
+ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 5,
+ CollectionID: 1}}, From: -1, To: 2, Weight: 1},
+ },
+ },
+ },
+ {
+ name: "test non-empty cluster assigning one collection",
+ comment: "this case verifies the effect of the global row count on segment assignment: although node1 " +
+ "has only 10 rows of collection 1 at the beginning, it holds so many rows globally that it gets a lower priority",
+ distributions: map[int64][]*meta.Segment{
+ 1: {
+ {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 10, CollectionID: 1}, Node: 1},
+ {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 300, CollectionID: 2}, Node: 1},
+ //base: collection1-node1-priority is 10 + 0.1 * 310 = 41
+ //assign3: collection1-node1-priority is 15 + 0.1 * 315 = 46.5
+ },
+ 2: {
+ {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 20, CollectionID: 1}, Node: 2},
+ {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 180, CollectionID: 2}, Node: 2},
+ //base: collection1-node2-priority is 20 + 0.1 * 200 = 40
+ //assign2: collection1-node2-priority is 30 + 0.1 * 210 = 51
+ },
+ 3: {
+ {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 30, CollectionID: 1}, Node: 3},
+ {SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 20, CollectionID: 2}, Node: 3},
+ //base: collection1-node3-priority is 30 + 0.1 * 50 = 35
+ //assign1: collection1-node3-priority is 45 + 0.1 * 65 = 51.5
+ },
+ },
+ assignments: [][]*meta.Segment{
+ {
+ {SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 5, CollectionID: 1}},
+ {SegmentInfo: &datapb.SegmentInfo{ID: 8, NumOfRows: 10, CollectionID: 1}},
+ {SegmentInfo: &datapb.SegmentInfo{ID: 9, NumOfRows: 15, CollectionID: 1}},
+ },
+ },
+ nodes: []int64{1, 2, 3},
+ collectionIDs: []int64{1},
+ states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
+ segmentCnts: []int{0, 0, 0},
+ expectPlans: [][]SegmentAssignPlan{
+ {
+ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 9, NumOfRows: 15, CollectionID: 1}}, From: -1, To: 3, Weight: 1},
+ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 8, NumOfRows: 10, CollectionID: 1}}, From: -1, To: 2, Weight: 1},
+ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 5, CollectionID: 1}}, From: -1, To: 1, Weight: 1},
+ },
+ },
+ },
+ {
+ name: "test non-empty cluster assigning two collections in one round of segment checking",
+ comment: "this case demonstrates a flaw of the current assign mechanism when assigning " +
+ "multiple collections in one round, using only the
segment distribution", + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 10, CollectionID: 1}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 20, CollectionID: 1}, Node: 2}, + }, + 3: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 40, CollectionID: 1}, Node: 3}, + }, + }, + assignments: [][]*meta.Segment{ + { + {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 60, CollectionID: 1}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 50, CollectionID: 1}}, + }, + { + {SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 15, CollectionID: 2}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 10, CollectionID: 2}}, + }, + }, + nodes: []int64{1, 2, 3}, + collectionIDs: []int64{1, 2}, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + segmentCnts: []int{0, 0, 0}, + expectPlans: [][]SegmentAssignPlan{ + //note that these two segments plans are absolutely unbalanced globally, + //as if the assignment for collection1 could succeed, node1 and node2 will both have 70 rows + //much more than node3, but following assignment will still assign segment based on [10,20,40] + //rather than [70,70,40], this flaw will be mitigated by balance process and maybe fixed in the later versions + { + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 60, CollectionID: 1}}, From: -1, To: 1, Weight: 1}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 50, CollectionID: 1}}, From: -1, To: 2, Weight: 1}, + }, + { + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 15, CollectionID: 2}}, From: -1, To: 1, Weight: 1}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 7, NumOfRows: 10, CollectionID: 2}}, From: -1, To: 2, Weight: 1}, + }, + }, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) 
+ } + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0") + nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i])) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + } + for i := range c.collectionIDs { + plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes) + suite.ElementsMatch(c.expectPlans[i], plans) + } + }) + } +} + +func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { + cases := []struct { + name string + nodes []int64 + notExistedNodes []int64 + collectionIDs []int64 + replicaIDs []int64 + collectionsSegments [][]*datapb.SegmentBinlogs + states []session.State + shouldMock bool + distributions map[int64][]*meta.Segment + distributionChannels map[int64][]*meta.DmChannel + expectPlans []SegmentAssignPlan + expectChannelPlans []ChannelAssignPlan + }{ + { + name: "normal balance for one collection only", + nodes: []int64{1, 2}, + collectionIDs: []int64{1}, + replicaIDs: []int64{1}, + collectionsSegments: [][]*datapb.SegmentBinlogs{ + { + {SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}}, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{ + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 2}, From: 2, To: 1, ReplicaID: 1}, + }, + expectChannelPlans: []ChannelAssignPlan{}, + }, + { + name: "already balanced for one collection only", + nodes: []int64{1, 2}, + collectionIDs: []int64{1}, + replicaIDs: []int64{1}, + collectionsSegments: [][]*datapb.SegmentBinlogs{ + { + {SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + + //1. 
set up target for multi collections + collections := make([]*meta.Collection, 0, len(c.collectionIDs)) + for i := range c.collectionIDs { + collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i])) + collections = append(collections, collection) + suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, c.collectionIDs[i], c.replicaIDs[i]).Return( + nil, c.collectionsSegments[i], nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionIDs[i]).Return([]int64{c.collectionIDs[i]}, nil).Maybe() + balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i]) + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + collection.LoadType = querypb.LoadType_LoadCollection + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaIDs[i], c.collectionIDs[i], + append(c.nodes, c.notExistedNodes...))) + } + + //2. set up target for distribution for multi collections + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) + } + for node, v := range c.distributionChannels { + balancer.dist.ChannelDistManager.Update(node, v...) + } + + //3. set up nodes info and resourceManager for balancer + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0") + nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, c.nodes[i]) + } + + //4. balance and verify result + segmentPlans, channelPlans := balancer.Balance() + suite.ElementsMatch(c.expectChannelPlans, channelPlans) + suite.ElementsMatch(c.expectPlans, segmentPlans) + }) + } +} + +func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { + balanceCase := struct { + name string + nodes []int64 + notExistedNodes []int64 + collectionIDs []int64 + replicaIDs []int64 + collectionsSegments [][]*datapb.SegmentBinlogs + states []session.State + shouldMock bool + distributions []map[int64][]*meta.Segment + expectPlans [][]SegmentAssignPlan + }{ + name: "balance considering both global rowCounts and collection rowCounts", + nodes: []int64{1, 2, 3}, + collectionIDs: []int64{1, 2}, + replicaIDs: []int64{1, 2}, + collectionsSegments: [][]*datapb.SegmentBinlogs{ + { + {SegmentID: 1}, {SegmentID: 3}, + }, + { + {SegmentID: 2}, {SegmentID: 4}, + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + distributions: []map[int64][]*meta.Segment{ + { + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 20}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 2, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, Node: 2}, + {SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 2, NumOfRows: 30}, Node: 2}, + }, + }, + { + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 20}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 2, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 4, CollectionID: 2, NumOfRows: 30}, Node: 2}, + }, + 3: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, Node: 3}, + }, + }, + }, + 
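+ // round one is expected to move segment 3 (20 rows) from node 2 to node 3;
+ // after the simulated move (the second distribution) the replica is balanced
+ // and no further plans should be produced.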
expectPlans: [][]SegmentAssignPlan{
+ {
+ {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20},
+ Node: 2}, From: 2, To: 3, ReplicaID: 1,
+ },
+ },
+ {},
+ },
+ }
+
+ suite.SetupSuite()
+ defer suite.TearDownTest()
+ balancer := suite.balancer
+
+ //1. set up target for multi collections
+ collections := make([]*meta.Collection, 0, len(balanceCase.collectionIDs))
+ for i := range balanceCase.collectionIDs {
+ collection := utils.CreateTestCollection(balanceCase.collectionIDs[i], int32(balanceCase.replicaIDs[i]))
+ collections = append(collections, collection)
+ suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, balanceCase.collectionIDs[i], balanceCase.replicaIDs[i]).Return(
+ nil, balanceCase.collectionsSegments[i], nil)
+ suite.broker.EXPECT().GetPartitions(mock.Anything, balanceCase.collectionIDs[i]).Return([]int64{balanceCase.collectionIDs[i]}, nil).Maybe()
+ balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])
+ balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])
+ collection.LoadPercentage = 100
+ collection.Status = querypb.LoadStatus_Loaded
+ collection.LoadType = querypb.LoadType_LoadCollection
+ balancer.meta.CollectionManager.PutCollection(collection)
+ balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i],
+ append(balanceCase.nodes, balanceCase.notExistedNodes...)))
+ }
+
+ //2. set up target for distribution for multi collections
+ for node, s := range balanceCase.distributions[0] {
+ balancer.dist.SegmentDistManager.Update(node, s...)
+ }
+
+ //3. set up nodes info and resourceManager for balancer
+ for i := range balanceCase.nodes {
+ nodeInfo := session.NewNodeInfo(balanceCase.nodes[i], "127.0.0.1:0")
+ nodeInfo.SetState(balanceCase.states[i])
+ suite.balancer.nodeManager.Add(nodeInfo)
+ suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, balanceCase.nodes[i])
+ }
+
+ //4. first round balance
+ segmentPlans, _ := balancer.Balance()
+ suite.ElementsMatch(balanceCase.expectPlans[0], segmentPlans)
+
+ //5. update segment distribution to simulate balance effect
+ for node, s := range balanceCase.distributions[1] {
+ balancer.dist.SegmentDistManager.Update(node, s...)
+ }
+
+ //6. balance again
+ segmentPlans, _ = balancer.Balance()
+ suite.ElementsMatch(balanceCase.expectPlans[1], segmentPlans)
+
+ //7.
balance one more and finish this round + segmentPlans, _ = balancer.Balance() + suite.ElementsMatch(balanceCase.expectPlans[1], segmentPlans) +} + +func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { + cases := []struct { + name string + nodes []int64 + outBoundNodes []int64 + notExistedNodes []int64 + collectionIDs []int64 + replicaIDs []int64 + collectionsSegments [][]*datapb.SegmentBinlogs + states []session.State + shouldMock bool + distributions map[int64][]*meta.Segment + distributionChannels map[int64][]*meta.DmChannel + expectPlans []SegmentAssignPlan + expectChannelPlans []ChannelAssignPlan + }{ + { + name: "stopped balance for one collection", + nodes: []int64{1, 2, 3}, + outBoundNodes: []int64{}, + collectionIDs: []int64{1}, + replicaIDs: []int64{1}, + collectionsSegments: [][]*datapb.SegmentBinlogs{ + { + {SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, + }, + }, + states: []session.State{session.NodeStateStopping, session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{ + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, + Node: 1}, From: 1, To: 3, ReplicaID: 1, Weight: 1}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, + Node: 1}, From: 1, To: 3, ReplicaID: 1, Weight: 1}, + }, + expectChannelPlans: []ChannelAssignPlan{}, + }, + { + name: "all nodes stopping", + nodes: []int64{1, 2, 3}, + outBoundNodes: []int64{}, + collectionIDs: []int64{1}, + replicaIDs: []int64{1}, + collectionsSegments: [][]*datapb.SegmentBinlogs{ + { + {SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, + }, + }, + states: []session.State{session.NodeStateStopping, session.NodeStateStopping, session.NodeStateStopping}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, + }, + { + name: "all nodes outbound", + nodes: []int64{1, 2, 3}, + outBoundNodes: []int64{1, 2, 3}, + collectionIDs: []int64{1}, + replicaIDs: []int64{1}, + collectionsSegments: [][]*datapb.SegmentBinlogs{ + { + {SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, + }, + }, + states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, + distributions: map[int64][]*meta.Segment{ + 1: { + {SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, Node: 1}, + {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 1}, + }, + 2: { + {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 2}, + }, + }, + expectPlans: []SegmentAssignPlan{}, + expectChannelPlans: []ChannelAssignPlan{}, + }, + } + for i, c := range cases { + suite.Run(c.name, func() { + if i == 0 { + suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0) + } + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + + //1. 
set up target for multi collections + collections := make([]*meta.Collection, 0, len(c.collectionIDs)) + for i := range c.collectionIDs { + collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i])) + collections = append(collections, collection) + suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, c.collectionIDs[i], c.replicaIDs[i]).Return( + nil, c.collectionsSegments[i], nil) + suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionIDs[i]).Return([]int64{c.collectionIDs[i]}, nil).Maybe() + balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i]) + collection.LoadPercentage = 100 + collection.Status = querypb.LoadStatus_Loaded + collection.LoadType = querypb.LoadType_LoadCollection + balancer.meta.CollectionManager.PutCollection(collection) + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaIDs[i], c.collectionIDs[i], + append(c.nodes, c.notExistedNodes...))) + } + + //2. set up target for distribution for multi collections + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) + } + for node, v := range c.distributionChannels { + balancer.dist.ChannelDistManager.Update(node, v...) + } + + //3. set up nodes info and resourceManager for balancer + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0") + nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, c.nodes[i]) + } + + for i := range c.outBoundNodes { + suite.balancer.meta.ResourceManager.UnassignNode(meta.DefaultResourceGroupName, c.outBoundNodes[i]) + } + + //4. 
balance and verify result + segmentPlans, channelPlans := balancer.Balance() + suite.ElementsMatch(c.expectChannelPlans, channelPlans) + suite.ElementsMatch(c.expectPlans, segmentPlans) + }) + } +} + +func TestScoreBasedBalancerSuite(t *testing.T) { + suite.Run(t, new(ScoreBasedBalancerTestSuite)) +} diff --git a/internal/querycoordv2/balance/utils.go b/internal/querycoordv2/balance/utils.go index e7fe0d9c24..cf94d6bdab 100644 --- a/internal/querycoordv2/balance/utils.go +++ b/internal/querycoordv2/balance/utils.go @@ -18,13 +18,19 @@ package balance import ( "context" + "fmt" "time" "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/task" "go.uber.org/zap" ) +const ( + InfoPrefix = "Balance-Info:" +) + func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout time.Duration, plans []SegmentAssignPlan) []task.Task { ret := make([]task.Task, 0) for _, p := range plans { @@ -105,3 +111,75 @@ func CreateChannelTasksFromPlans(ctx context.Context, checkerID int64, timeout t } return ret } + +func PrintNewBalancePlans(collectionID int64, replicaID int64, segmentPlans []SegmentAssignPlan, + channelPlans []ChannelAssignPlan) { + balanceInfo := fmt.Sprintf("%s{collectionID:%d, replicaID:%d, ", InfoPrefix, collectionID, replicaID) + for _, segmentPlan := range segmentPlans { + balanceInfo += segmentPlan.ToString() + } + for _, channelPlan := range channelPlans { + balanceInfo += channelPlan.ToString() + } + balanceInfo += "}" + log.Info(balanceInfo) +} + +func PrintCurrentReplicaDist(replica *meta.Replica, + stoppingNodesSegments map[int64][]*meta.Segment, nodeSegments map[int64][]*meta.Segment, + channelManager *meta.ChannelDistManager) { + distInfo := fmt.Sprintf("%s {collectionID:%d, replicaID:%d, ", InfoPrefix, replica.CollectionID, replica.GetID()) + //1. print stopping nodes segment distribution + distInfo += "[stoppingNodesSegmentDist:" + for stoppingNodeID, stoppedSegments := range stoppingNodesSegments { + distInfo += fmt.Sprintf("[nodeID:%d, ", stoppingNodeID) + distInfo += "stopped-segments:[" + for _, stoppedSegment := range stoppedSegments { + distInfo += fmt.Sprintf("%d,", stoppedSegment.GetID()) + } + distInfo += "]]" + } + distInfo += "]\n" + //2. print normal nodes segment distribution + distInfo += "[normalNodesSegmentDist:" + for normalNodeID, normalNodeSegments := range nodeSegments { + distInfo += fmt.Sprintf("[nodeID:%d, ", normalNodeID) + distInfo += "loaded-segments:[" + nodeRowSum := int64(0) + for _, normalSegment := range normalNodeSegments { + distInfo += fmt.Sprintf("[segmentID: %d, rowCount: %d] ", + normalSegment.GetID(), normalSegment.GetNumOfRows()) + nodeRowSum += normalSegment.GetNumOfRows() + } + distInfo += fmt.Sprintf("] nodeRowSum:%d]", nodeRowSum) + } + distInfo += "]\n" + + //3. print stopping nodes channel distribution + distInfo += "[stoppingNodesChannelDist:" + for stoppingNodeID := range stoppingNodesSegments { + stoppingNodeChannels := channelManager.GetByNode(stoppingNodeID) + distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", stoppingNodeID, len(stoppingNodeChannels)) + distInfo += "channels:[" + for _, stoppingChan := range stoppingNodeChannels { + distInfo += fmt.Sprintf("%s,", stoppingChan.GetChannelName()) + } + distInfo += "]]" + } + distInfo += "]\n" + + //4. 
print normal nodes channel distribution + distInfo += "[normalNodesChannelDist:" + for normalNodeID := range nodeSegments { + normalNodeChannels := channelManager.GetByNode(normalNodeID) + distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", normalNodeID, len(normalNodeChannels)) + distInfo += "channels:[" + for _, normalNodeChan := range normalNodeChannels { + distInfo += fmt.Sprintf("%s,", normalNodeChan.GetChannelName()) + } + distInfo += "]]" + } + distInfo += "]\n" + + log.Info(distInfo) +} diff --git a/internal/querycoordv2/checkers/controller.go b/internal/querycoordv2/checkers/controller.go index 935067b40b..404ef8ccef 100644 --- a/internal/querycoordv2/checkers/controller.go +++ b/internal/querycoordv2/checkers/controller.go @@ -54,13 +54,14 @@ func NewCheckerController( dist *meta.DistributionManager, targetMgr *meta.TargetManager, balancer balance.Balance, + nodeMgr *session.NodeManager, scheduler task.Scheduler) *CheckerController { // CheckerController runs checkers with the order, // the former checker has higher priority checkers := []Checker{ NewChannelChecker(meta, dist, targetMgr, balancer), - NewSegmentChecker(meta, dist, targetMgr, balancer), + NewSegmentChecker(meta, dist, targetMgr, balancer, nodeMgr), NewBalanceChecker(balancer), } for i, checker := range checkers { diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index 821c428dbf..3b5423c272 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -26,9 +26,11 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/balance" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . "github.com/milvus-io/milvus/internal/querycoordv2/params" + "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/samber/lo" "go.uber.org/zap" ) @@ -39,6 +41,7 @@ type SegmentChecker struct { dist *meta.DistributionManager targetMgr *meta.TargetManager balancer balance.Balance + nodeMgr *session.NodeManager } func NewSegmentChecker( @@ -46,12 +49,14 @@ func NewSegmentChecker( dist *meta.DistributionManager, targetMgr *meta.TargetManager, balancer balance.Balance, + nodeMgr *session.NodeManager, ) *SegmentChecker { return &SegmentChecker{ meta: meta, dist: dist, targetMgr: targetMgr, balancer: balancer, + nodeMgr: nodeMgr, } } @@ -274,9 +279,13 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] } outboundNodes := c.meta.ResourceManager.CheckOutboundNodes(replica) availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { - return !outboundNodes.Contain(node) + stop, err := c.nodeMgr.IsStoppingNode(node) + if err != nil { + return false + } + return !outboundNodes.Contain(node) && !stop }) - plans := c.balancer.AssignSegment(packedSegments, availableNodes) + plans := c.balancer.AssignSegment(replica.CollectionID, packedSegments, availableNodes) for i := range plans { plans[i].ReplicaID = replica.GetID() } diff --git a/internal/querycoordv2/checkers/segment_checker_test.go b/internal/querycoordv2/checkers/segment_checker_test.go index 5294368ae7..1bfea5bdb1 100644 --- a/internal/querycoordv2/checkers/segment_checker_test.go +++ b/internal/querycoordv2/checkers/segment_checker_test.go @@ -73,7 +73,7 @@ func (suite *SegmentCheckerTestSuite) SetupTest() { 
targetManager := meta.NewTargetManager(suite.broker, suite.meta) balancer := suite.createMockBalancer() - suite.checker = NewSegmentChecker(suite.meta, distManager, targetManager, balancer) + suite.checker = NewSegmentChecker(suite.meta, distManager, targetManager, balancer, suite.nodeMgr) suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe() } @@ -84,7 +84,7 @@ func (suite *SegmentCheckerTestSuite) TearDownTest() { func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance { balancer := balance.NewMockBalancer(suite.T()) - balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything).Maybe().Return(func(segments []*meta.Segment, nodes []int64) []balance.SegmentAssignPlan { + balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64) []balance.SegmentAssignPlan { plans := make([]balance.SegmentAssignPlan, 0, len(segments)) for i, s := range segments { plan := balance.SegmentAssignPlan{ diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index 98a3e3d40c..34501d6709 100644 --- a/internal/querycoordv2/handlers.go +++ b/internal/querycoordv2/handlers.go @@ -98,7 +98,11 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe if dstNodeSet.Len() == 0 { outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica) availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { - return !outboundNodes.Contain(node) + stop, err := s.nodeMgr.IsStoppingNode(node) + if err != nil { + return false + } + return !outboundNodes.Contain(node) && !stop }) dstNodeSet.Insert(availableNodes...) } @@ -132,7 +136,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe zap.Int64("srcNodeID", srcNode), zap.Int64s("destNodeIDs", dstNodeSet.Collect()), ) - plans := s.balancer.AssignSegment(toBalance.Collect(), dstNodeSet.Collect()) + plans := s.balancer.AssignSegment(req.GetCollectionID(), toBalance.Collect(), dstNodeSet.Collect()) tasks := make([]task.Task, 0, len(plans)) for _, plan := range plans { log.Info("manually balance segment...", diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index bb2d8db377..1d43f3c043 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -109,7 +109,8 @@ type Server struct { replicaObserver *observers.ReplicaObserver resourceObserver *observers.ResourceObserver - balancer balance.Balance + balancer balance.Balance + balancerMap map[string]balance.Balance // Active-standby enableActiveStandBy bool @@ -249,15 +250,21 @@ func (s *Server) initQueryCoord() error { s.taskScheduler, ) - // Init balancer - log.Info("init balancer") - s.balancer = balance.NewRowCountBasedBalancer( - s.taskScheduler, - s.nodeMgr, - s.dist, - s.meta, - s.targetMgr, - ) + // Init balancer map and balancer + log.Info("init all available balancer") + s.balancerMap = make(map[string]balance.Balance) + s.balancerMap[balance.RoundRobinBalancerName] = balance.NewRoundRobinBalancer(s.taskScheduler, s.nodeMgr) + s.balancerMap[balance.RowCountBasedBalancerName] = balance.NewRowCountBasedBalancer(s.taskScheduler, + s.nodeMgr, s.dist, s.meta, s.targetMgr) + s.balancerMap[balance.ScoreBasedBalancerName] = balance.NewScoreBasedBalancer(s.taskScheduler, + s.nodeMgr, s.dist, s.meta, s.targetMgr) + if balancer, ok := s.balancerMap[params.Params.QueryCoordCfg.Balancer.GetValue()]; ok { + s.balancer = 
diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go
index bb2d8db377..1d43f3c043 100644
--- a/internal/querycoordv2/server.go
+++ b/internal/querycoordv2/server.go
@@ -109,7 +109,8 @@ type Server struct {
 	replicaObserver  *observers.ReplicaObserver
 	resourceObserver *observers.ResourceObserver
 
-	balancer balance.Balance
+	balancer    balance.Balance
+	balancerMap map[string]balance.Balance
 
 	// Active-standby
 	enableActiveStandBy bool
@@ -249,15 +250,21 @@ func (s *Server) initQueryCoord() error {
 		s.taskScheduler,
 	)
 
-	// Init balancer
-	log.Info("init balancer")
-	s.balancer = balance.NewRowCountBasedBalancer(
-		s.taskScheduler,
-		s.nodeMgr,
-		s.dist,
-		s.meta,
-		s.targetMgr,
-	)
+	// Init balancer map and balancer
+	log.Info("init all available balancer")
+	s.balancerMap = make(map[string]balance.Balance)
+	s.balancerMap[balance.RoundRobinBalancerName] = balance.NewRoundRobinBalancer(s.taskScheduler, s.nodeMgr)
+	s.balancerMap[balance.RowCountBasedBalancerName] = balance.NewRowCountBasedBalancer(s.taskScheduler,
+		s.nodeMgr, s.dist, s.meta, s.targetMgr)
+	s.balancerMap[balance.ScoreBasedBalancerName] = balance.NewScoreBasedBalancer(s.taskScheduler,
+		s.nodeMgr, s.dist, s.meta, s.targetMgr)
+	if balancer, ok := s.balancerMap[params.Params.QueryCoordCfg.Balancer.GetValue()]; ok {
+		s.balancer = balancer
+		log.Info("use config balancer", zap.String("balancer", params.Params.QueryCoordCfg.Balancer.GetValue()))
+	} else {
+		s.balancer = s.balancerMap[balance.RowCountBasedBalancerName]
+		log.Info("use rowCountBased auto balancer")
+	}
 
 	// Init checker controller
 	log.Info("init checker controller")
@@ -266,6 +273,7 @@ func (s *Server) initQueryCoord() error {
 		s.dist,
 		s.targetMgr,
 		s.balancer,
+		s.nodeMgr,
 		s.taskScheduler,
 	)
diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go
index 69f3c88bc4..cc705fbda1 100644
--- a/internal/querycoordv2/server_test.go
+++ b/internal/querycoordv2/server_test.go
@@ -475,6 +475,7 @@ func (suite *ServerSuite) hackServer() {
 		suite.server.dist,
 		suite.server.targetMgr,
 		suite.server.balancer,
+		suite.server.nodeMgr,
 		suite.server.taskScheduler,
 	)
 	suite.server.targetObserver = observers.NewTargetObserver(
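initQueryCoord now constructs all three balancers up front and resolves the active one from queryCoord.balancer, falling back to RowCountBasedBalancer when the configured name is not registered. A hypothetical helper showing just that selection pattern (pickBalancer itself is not part of this diff):

// pickBalancer resolves a balancer by its configured name, mirroring the
// lookup-with-fallback in initQueryCoord above.
func pickBalancer(balancers map[string]balance.Balance, name string) balance.Balance {
	if b, ok := balancers[name]; ok {
		return b
	}
	// Unknown or empty name: default to the row-count based balancer.
	return balancers[balance.RowCountBasedBalancerName]
}

Falling back silently rather than failing keeps a typo in the queryCoord.balancer setting from blocking coordinator startup.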
diff --git a/internal/querycoordv2/task/mock_scheduler.go b/internal/querycoordv2/task/mock_scheduler.go
index ae70e7b551..43a47b9a6c 100644
--- a/internal/querycoordv2/task/mock_scheduler.go
+++ b/internal/querycoordv2/task/mock_scheduler.go
@@ -13,6 +13,14 @@ type MockScheduler struct {
 	mock.Mock
 }
 
+func (_m *MockScheduler) GetChannelTaskNum() int {
+	return 0
+}
+
+func (_m *MockScheduler) GetSegmentTaskNum() int {
+	return 0
+}
+
 type MockScheduler_Expecter struct {
 	mock *mock.Mock
 }
@@ -41,7 +49,7 @@ type MockScheduler_Add_Call struct {
 }
 
 // Add is a helper method to define mock.On call
-//  - task Task
+//   - task Task
 func (_e *MockScheduler_Expecter) Add(task interface{}) *MockScheduler_Add_Call {
 	return &MockScheduler_Add_Call{Call: _e.mock.On("Add", task)}
 }
@@ -69,7 +77,7 @@ type MockScheduler_AddExecutor_Call struct {
 }
 
 // AddExecutor is a helper method to define mock.On call
-//  - nodeID int64
+//   - nodeID int64
 func (_e *MockScheduler_Expecter) AddExecutor(nodeID interface{}) *MockScheduler_AddExecutor_Call {
 	return &MockScheduler_AddExecutor_Call{Call: _e.mock.On("AddExecutor", nodeID)}
 }
@@ -97,7 +105,7 @@ type MockScheduler_Dispatch_Call struct {
 }
 
 // Dispatch is a helper method to define mock.On call
-//  - node int64
+//   - node int64
 func (_e *MockScheduler_Expecter) Dispatch(node interface{}) *MockScheduler_Dispatch_Call {
 	return &MockScheduler_Dispatch_Call{Call: _e.mock.On("Dispatch", node)}
 }
@@ -134,7 +142,7 @@ type MockScheduler_GetNodeChannelDelta_Call struct {
 }
 
 // GetNodeChannelDelta is a helper method to define mock.On call
-//  - nodeID int64
+//   - nodeID int64
 func (_e *MockScheduler_Expecter) GetNodeChannelDelta(nodeID interface{}) *MockScheduler_GetNodeChannelDelta_Call {
 	return &MockScheduler_GetNodeChannelDelta_Call{Call: _e.mock.On("GetNodeChannelDelta", nodeID)}
 }
@@ -171,7 +179,7 @@ type MockScheduler_GetNodeSegmentDelta_Call struct {
 }
 
 // GetNodeSegmentDelta is a helper method to define mock.On call
-//  - nodeID int64
+//   - nodeID int64
 func (_e *MockScheduler_Expecter) GetNodeSegmentDelta(nodeID interface{}) *MockScheduler_GetNodeSegmentDelta_Call {
 	return &MockScheduler_GetNodeSegmentDelta_Call{Call: _e.mock.On("GetNodeSegmentDelta", nodeID)}
 }
@@ -199,7 +207,7 @@ type MockScheduler_RemoveByNode_Call struct {
 }
 
 // RemoveByNode is a helper method to define mock.On call
-//  - node int64
+//   - node int64
 func (_e *MockScheduler_Expecter) RemoveByNode(node interface{}) *MockScheduler_RemoveByNode_Call {
 	return &MockScheduler_RemoveByNode_Call{Call: _e.mock.On("RemoveByNode", node)}
 }
@@ -227,7 +235,7 @@ type MockScheduler_RemoveExecutor_Call struct {
 }
 
 // RemoveExecutor is a helper method to define mock.On call
-//  - nodeID int64
+//   - nodeID int64
 func (_e *MockScheduler_Expecter) RemoveExecutor(nodeID interface{}) *MockScheduler_RemoveExecutor_Call {
 	return &MockScheduler_RemoveExecutor_Call{Call: _e.mock.On("RemoveExecutor", nodeID)}
 }
@@ -255,7 +263,7 @@ type MockScheduler_Start_Call struct {
 }
 
 // Start is a helper method to define mock.On call
-//  - ctx context.Context
+//   - ctx context.Context
 func (_e *MockScheduler_Expecter) Start(ctx interface{}) *MockScheduler_Start_Call {
 	return &MockScheduler_Start_Call{Call: _e.mock.On("Start", ctx)}
 }
diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go
index 8a0e62a4bf..0947548892 100644
--- a/internal/querycoordv2/task/scheduler.go
+++ b/internal/querycoordv2/task/scheduler.go
@@ -120,6 +120,8 @@ type Scheduler interface {
 	RemoveByNode(node int64)
 	GetNodeSegmentDelta(nodeID int64) int
 	GetNodeChannelDelta(nodeID int64) int
+	GetChannelTaskNum() int
+	GetSegmentTaskNum() int
 }
 
 type taskScheduler struct {
@@ -292,7 +294,6 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
 			return merr.WrapErrServiceInternal("task with the same channel exists")
 		}
-
 		if GetTaskType(task) == TaskTypeGrow {
 			nodesWithChannel := scheduler.distMgr.LeaderViewManager.GetChannelDist(task.Channel())
 			replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel)
@@ -300,11 +301,9 @@ func (scheduler *taskScheduler) preAdd(task Task) error {
 				return merr.WrapErrServiceInternal("channel subscribed, it can be only balanced")
 			}
 		}
-
 	default:
 		panic(fmt.Sprintf("preAdd: forget to process task type: %+v", task))
 	}
-
 	return nil
 }
@@ -386,6 +385,20 @@ func (scheduler *taskScheduler) GetNodeChannelDelta(nodeID int64) int {
 	return calculateNodeDelta(nodeID, scheduler.channelTasks)
 }
 
+func (scheduler *taskScheduler) GetChannelTaskNum() int {
+	scheduler.rwmutex.RLock()
+	defer scheduler.rwmutex.RUnlock()
+
+	return len(scheduler.channelTasks)
+}
+
+func (scheduler *taskScheduler) GetSegmentTaskNum() int {
+	scheduler.rwmutex.RLock()
+	defer scheduler.rwmutex.RUnlock()
+
+	return len(scheduler.segmentTasks)
+}
+
 func calculateNodeDelta[K comparable, T ~map[K]Task](nodeID int64, tasks T) int {
 	delta := 0
 	for _, task := range tasks {
diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go
index a559896b67..18632f5224 100644
--- a/internal/querynode/impl.go
+++ b/internal/querynode/impl.go
@@ -435,7 +435,6 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
 // LoadSegments load historical data into query node, historical data can be vector data or index
 func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
 	nodeID := node.session.ServerID
-	log.Info("wayblink", zap.Int64("nodeID", nodeID))
 	// check node healthy
 	if !node.lifetime.Add(commonpbutil.IsHealthy) {
 		err := fmt.Errorf("query node %d is not ready", nodeID)
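The two new Scheduler getters expose only coarse task counts, each taken under the scheduler's read lock. One plausible consumer, purely illustrative and not part of this diff, is a balancer that defers generating new plans while earlier ones are still in flight:

// shouldDeferBalance is a hypothetical guard: skip a balance round while the
// scheduler still holds unfinished segment or channel tasks, letting the
// current distribution settle before producing more assign plans.
func shouldDeferBalance(scheduler task.Scheduler) bool {
	return scheduler.GetSegmentTaskNum() > 0 || scheduler.GetChannelTaskNum() > 0
}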
diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go
index 86187933e0..645a5d3129 100644
--- a/internal/util/paramtable/component_param.go
+++ b/internal/util/paramtable/component_param.go
@@ -1084,6 +1084,9 @@ type queryCoordConfig struct {
 	//---- Balance ---
 	AutoBalance                         ParamItem `refreshable:"true"`
+	Balancer                            ParamItem `refreshable:"true"`
+	GlobalRowCountFactor                ParamItem `refreshable:"true"`
+	ScoreUnbalanceTolerationFactor      ParamItem `refreshable:"true"`
 	OverloadedMemoryThresholdPercentage ParamItem `refreshable:"true"`
 	BalanceIntervalSeconds              ParamItem `refreshable:"true"`
 	MemoryUsageMaxDifferencePercentage  ParamItem `refreshable:"true"`
@@ -1149,13 +1152,43 @@ func (p *queryCoordConfig) init(base *BaseTable) {
 	p.AutoBalance = ParamItem{
 		Key:          "queryCoord.autoBalance",
 		Version:      "2.0.0",
-		DefaultValue: "tru",
+		DefaultValue: "true",
 		PanicIfEmpty: true,
 		Doc:          "Enable auto balance",
 		Export:       true,
 	}
 	p.AutoBalance.Init(base.mgr)
 
+	p.Balancer = ParamItem{
+		Key:          "queryCoord.balancer",
+		Version:      "2.0.0",
+		DefaultValue: "RowCountBasedBalancer",
+		PanicIfEmpty: true,
+		Doc:          "auto balancer used for segments on queryNodes",
+		Export:       true,
+	}
+	p.Balancer.Init(base.mgr)
+
+	p.GlobalRowCountFactor = ParamItem{
+		Key:          "queryCoord.globalRowCountFactor",
+		Version:      "2.0.0",
+		DefaultValue: "0.1",
+		PanicIfEmpty: true,
+		Doc:          "the weight used when balancing segments among queryNodes",
+		Export:       true,
+	}
+	p.GlobalRowCountFactor.Init(base.mgr)
+
+	p.ScoreUnbalanceTolerationFactor = ParamItem{
+		Key:          "queryCoord.scoreUnbalanceTolerationFactor",
+		Version:      "2.0.0",
+		DefaultValue: "1.3",
+		PanicIfEmpty: true,
+		Doc:          "the largest value for unbalanced extent between from and to nodes when doing balance",
+		Export:       true,
+	}
+	p.ScoreUnbalanceTolerationFactor.Init(base.mgr)
+
 	p.OverloadedMemoryThresholdPercentage = ParamItem{
 		Key:     "queryCoord.overloadedMemoryThresholdPercentage",
 		Version: "2.0.0",
diff --git a/internal/util/typeutil/set.go b/internal/util/typeutil/set.go
index d408f976c0..2baff5f505 100644
--- a/internal/util/typeutil/set.go
+++ b/internal/util/typeutil/set.go
@@ -91,6 +91,10 @@ func (set Set[T]) Remove(elements ...T) {
 	}
 }
 
+func (set Set[T]) Clear() {
+	set.Remove(set.Collect()...)
+}
+
 // Get all elements in the set
 func (set Set[T]) Collect() []T {
 	elements := make([]T, 0, len(set))
diff --git a/internal/util/typeutil/set_test.go b/internal/util/typeutil/set_test.go
index eb33342c19..fe402956f0 100644
--- a/internal/util/typeutil/set_test.go
+++ b/internal/util/typeutil/set_test.go
@@ -36,3 +36,18 @@ func TestUniqueSet(t *testing.T) {
 	assert.True(t, set.Contain(9))
 	assert.False(t, set.Contain(5, 7, 9))
 }
+
+func TestUniqueSetClear(t *testing.T) {
+	set := make(UniqueSet)
+	set.Insert(5, 7, 9)
+	assert.True(t, set.Contain(5))
+	assert.True(t, set.Contain(7))
+	assert.True(t, set.Contain(9))
+	assert.Equal(t, 3, set.Len())
+
+	set.Clear()
+	assert.False(t, set.Contain(5))
+	assert.False(t, set.Contain(7))
+	assert.False(t, set.Contain(9))
+	assert.Equal(t, 0, set.Len())
+}
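The two new ScoreBasedBalancer knobs read naturally together. The following numeric sketch is hypothetical, inferred only from the Doc strings above; the balancer's actual scoring lives in score_based_balancer.go, which this patch wires up but does not show:

// nodeScore blends a node's row count for the collection being balanced with
// its row count across all collections, weighted by
// queryCoord.globalRowCountFactor (default 0.1).
func nodeScore(collectionRows, globalRows int64, globalRowCountFactor float64) float64 {
	return float64(collectionRows) + globalRowCountFactor*float64(globalRows)
}

// worthBalancing moves segments from the highest- to the lowest-scored node
// only when their scores diverge by more than
// queryCoord.scoreUnbalanceTolerationFactor (default 1.3), damping
// oscillating migrations between nearly even nodes.
func worthBalancing(maxScore, minScore, tolerationFactor float64) bool {
	return minScore > 0 && maxScore/minScore > tolerationFactor
}

For example, with the factor at 0.1, a node holding 1000 rows of the target collection and 5000 rows overall scores 1500, while a peer at 900 and 1000 scores 1000; the ratio 1.5 exceeds 1.3, so under this reading a move would be considered.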