diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index 54b43ccc30..52d17e2653 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -19,6 +19,7 @@ package checkers import ( "context" "sort" + "strings" "time" "github.com/samber/lo" @@ -85,24 +86,26 @@ func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) b func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int64 { ids := b.meta.GetAll(ctx) + + // Sort collections using the configured sort order + ids = b.sortCollections(ctx, ids) + if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - // balance collections influenced by stopping nodes - stoppingReplicas := make([]int64, 0) for _, cid := range ids { // if target and meta isn't ready, skip balance this collection if !b.readyToCheck(ctx, cid) { continue } replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid) + stoppingReplicas := make([]int64, 0) for _, replica := range replicas { if replica.RONodesCount() > 0 { stoppingReplicas = append(stoppingReplicas, replica.GetID()) } } - } - // do stopping balance only in this round - if len(stoppingReplicas) > 0 { - return stoppingReplicas + if len(stoppingReplicas) > 0 { + return stoppingReplicas + } } } @@ -123,9 +126,6 @@ func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64 collection := b.meta.GetCollection(ctx, cid) return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded }) - sort.Slice(loadedCollections, func(i, j int) bool { - return loadedCollections[i] < loadedCollections[j] - }) // Before performing balancing, check the CurrentTarget/LeaderView/Distribution for all collections. // If any collection has unready info, skip the balance operation to avoid inconsistencies. @@ -138,6 +138,9 @@ func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64 return nil } + // Sort collections using the configured sort order + loadedCollections = b.sortCollections(ctx, loadedCollections) + // iterator one normal collection in one round normalReplicasToBalance := make([]int64, 0) hasUnbalancedCollection := false @@ -187,6 +190,11 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task { if len(stoppingReplicas) > 0 { // check for stopping balance first segmentPlans, channelPlans = b.balanceReplicas(ctx, stoppingReplicas) + // iterate all collection to find a collection to balance + for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 { + replicasToBalance := b.getReplicaForStoppingBalance(ctx) + segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance) + } } else { // then check for auto balance if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) { @@ -212,3 +220,37 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task { ret = append(ret, tasks...) return ret } + +func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 { + sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) + if sortOrder == "" { + sortOrder = "byrowcount" // Default to ByRowCount + } + + // Define sorting functions + sortByRowCount := func(i, j int) bool { + rowCount1 := b.targetMgr.GetCollectionRowCount(ctx, collections[i], meta.CurrentTargetFirst) + rowCount2 := b.targetMgr.GetCollectionRowCount(ctx, collections[j], meta.CurrentTargetFirst) + return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j]) + } + + sortByCollectionID := func(i, j int) bool { + return collections[i] < collections[j] + } + + // Select the appropriate sorting function + var sortFunc func(i, j int) bool + switch sortOrder { + case "byrowcount": + sortFunc = sortByRowCount + case "bycollectionid": + sortFunc = sortByCollectionID + default: + log.Warn("Invalid balance sort order configuration, using default ByRowCount", zap.String("sortOrder", sortOrder)) + sortFunc = sortByRowCount + } + + // Sort the collections + sort.Slice(collections, sortFunc) + return collections +} diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go index 8b62b6c781..97e76e1152 100644 --- a/internal/querycoordv2/checkers/balance_checker_test.go +++ b/internal/querycoordv2/checkers/balance_checker_test.go @@ -49,7 +49,7 @@ type BalanceCheckerTestSuite struct { broker *meta.MockBroker nodeMgr *session.NodeManager scheduler *task.MockScheduler - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface } func (suite *BalanceCheckerTestSuite) SetupSuite() { @@ -290,7 +290,7 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) // test stopping balance - idsToBalance := []int64{int64(replicaID1), int64(replicaID2)} + idsToBalance := []int64{int64(replicaID1)} replicasToBalance := suite.checker.getReplicaForStoppingBalance(ctx) suite.ElementsMatch(idsToBalance, replicasToBalance) @@ -305,7 +305,7 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { segPlans = append(segPlans, mockPlan) suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans) tasks := suite.checker.Check(context.TODO()) - suite.Len(tasks, 2) + suite.Len(tasks, 1) } func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { @@ -349,14 +349,16 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { // test normal balance when one collection has unready target mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true) mockTarget.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false) + mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) suite.Len(replicasToBalance, 0) // test stopping balance with target not ready mockTarget.ExpectedCalls = nil mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(false) - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true) - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid2), mock.Anything).Return(false) + mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true).Maybe() + mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid2), mock.Anything).Return(false).Maybe() + mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() mr1 := replica1.CopyForWrite() mr1.AddRONode(1) suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) @@ -440,6 +442,299 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceInterval() { suite.Equal(funcCallCounter.Load(), int64(1)) } +func (suite *BalanceCheckerTestSuite) TestBalanceOrder() { + ctx := context.Background() + nodeID1, nodeID2 := int64(1), int64(2) + + // set collections meta + cid1, replicaID1, partitionID1 := 1, 1, 1 + collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) + collection1.Status = querypb.LoadStatus_Loaded + replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2}) + partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + + cid2, replicaID2, partitionID2 := 2, 2, 2 + collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) + collection2.Status = querypb.LoadStatus_Loaded + replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2}) + partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + + // mock collection row count + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(100)).Maybe() + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(200)).Maybe() + mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + + // mock stopping node + mr1 := replica1.CopyForWrite() + mr1.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) + mr2 := replica2.CopyForWrite() + mr2.AddRONode(nodeID2) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + + // test normal balance order + replicas := suite.checker.getReplicaForNormalBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID2)}) + + // test stopping balance order + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID2)}) + + // mock collection row count + mockTargetManager.ExpectedCalls = nil + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(200)).Maybe() + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(100)).Maybe() + mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + + // test normal balance order + replicas = suite.checker.getReplicaForNormalBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID1)}) + + // test stopping balance order + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID1)}) +} + +func (suite *BalanceCheckerTestSuite) TestSortCollections() { + ctx := context.Background() + + // Set up test collections + cid1, cid2, cid3 := int64(1), int64(2), int64(3) + + // Mock the target manager for row count returns + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + + // Collection 1: Low ID, High row count + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() + + // Collection 2: Middle ID, Low row count + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() + + // Collection 3: High ID, Middle row count + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() + + collections := []int64{cid1, cid2, cid3} + + // Test ByRowCount sorting (default) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") + sortedCollections := suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Collections should be sorted by row count (highest first)") + + // Test ByCollectionID sorting + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Collections should be sorted by collection ID (ascending)") + + // Test with empty sort order (should default to ByRowCount) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is empty") + + // Test with invalid sort order (should default to ByRowCount) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "InvalidOrder") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is invalid") + + // Test with mixed case (should be case-insensitive) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "bYcOlLeCtIoNiD") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Should handle case-insensitive sort order names") + + // Test with empty collection list + emptyCollections := []int64{} + sortedCollections = suite.checker.sortCollections(ctx, emptyCollections) + suite.Equal([]int64{}, sortedCollections, "Should handle empty collection list") +} + +func (suite *BalanceCheckerTestSuite) TestSortCollectionsIntegration() { + ctx := context.Background() + + // Set up test collections and nodes + nodeID1 := int64(1) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID1, + Address: "localhost", + Hostname: "localhost", + })) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) + + // Create two collections to ensure sorting is triggered + cid1, replicaID1 := int64(1), int64(101) + collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) + collection1.Status = querypb.LoadStatus_Loaded + replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + + // Add a second collection with different characteristics + cid2, replicaID2 := int64(2), int64(102) + collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) + collection2.Status = querypb.LoadStatus_Loaded + replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + + // Mock target manager + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + + // Setup different row counts to test sorting + // Collection 1 has more rows than Collection 2 + var getRowCountCallCount int + mockTargetManager.On("GetCollectionRowCount", mock.Anything, mock.Anything, mock.Anything). + Return(func(ctx context.Context, collectionID int64, scope int32) int64 { + getRowCountCallCount++ + if collectionID == cid1 { + return 200 // More rows in collection 1 + } + return 100 // Fewer rows in collection 2 + }) + + mockTargetManager.On("IsCurrentTargetReady", mock.Anything, mock.Anything).Return(true) + mockTargetManager.On("IsNextTargetExist", mock.Anything, mock.Anything).Return(true) + + // Configure for testing + paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") + + // Clear first to avoid previous test state + suite.checker.normalBalanceCollectionsCurrentRound.Clear() + + // Call normal balance + _ = suite.checker.getReplicaForNormalBalance(ctx) + + // Verify GetCollectionRowCount was called at least twice (once for each collection) + // This confirms that the collections were sorted + suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during normal balance") + + // Reset counter and test stopping balance + getRowCountCallCount = 0 + + // Set up for stopping balance test + mr1 := replica1.CopyForWrite() + mr1.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) + + mr2 := replica2.CopyForWrite() + mr2.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + + paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") + + // Call stopping balance + _ = suite.checker.getReplicaForStoppingBalance(ctx) + + // Verify GetCollectionRowCount was called at least twice during stopping balance + suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during stopping balance") +} + +func (suite *BalanceCheckerTestSuite) TestBalanceTriggerOrder() { + ctx := context.Background() + + // Set up nodes + nodeID1, nodeID2 := int64(1), int64(2) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID1, + Address: "localhost", + Hostname: "localhost", + })) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID2, + Address: "localhost", + Hostname: "localhost", + })) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) + + // Create collections with different row counts + cid1, replicaID1 := int64(1), int64(101) + collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) + collection1.Status = querypb.LoadStatus_Loaded + replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1, nodeID2}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + + cid2, replicaID2 := int64(2), int64(102) + collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) + collection2.Status = querypb.LoadStatus_Loaded + replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1, nodeID2}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + + cid3, replicaID3 := int64(3), int64(103) + collection3 := utils.CreateTestCollection(cid3, int32(replicaID3)) + collection3.Status = querypb.LoadStatus_Loaded + replica3 := utils.CreateTestReplica(replicaID3, cid3, []int64{nodeID1, nodeID2}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection3) + suite.checker.meta.ReplicaManager.Put(ctx, replica3) + + // Mock the target manager + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + + // Set row counts: Collection 1 (highest), Collection 3 (middle), Collection 2 (lowest) + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() + + // Mark the current target as ready + mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() + + // Enable auto balance + paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") + + // Test with ByRowCount order (default) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") + suite.checker.normalBalanceCollectionsCurrentRound.Clear() + + // Normal balance should pick the collection with highest row count first + replicas := suite.checker.getReplicaForNormalBalance(ctx) + suite.Contains(replicas, replicaID1, "Should balance collection with highest row count first") + + // Add stopping nodes to test stopping balance + mr1 := replica1.CopyForWrite() + mr1.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) + + mr2 := replica2.CopyForWrite() + mr2.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + + mr3 := replica3.CopyForWrite() + mr3.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr3.IntoReplica()) + + // Enable stopping balance + paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") + + // Stopping balance should also pick the collection with highest row count first + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with highest row count") + + // Test with ByCollectionID order + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") + suite.checker.normalBalanceCollectionsCurrentRound.Clear() + + // Normal balance should pick the collection with lowest ID first + replicas = suite.checker.getReplicaForNormalBalance(ctx) + suite.Contains(replicas, replicaID1, "Should balance collection with lowest ID first") + + // Stopping balance should also pick the collection with lowest ID first + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with lowest ID") +} + func TestBalanceCheckerSuite(t *testing.T) { suite.Run(t, new(BalanceCheckerTestSuite)) } diff --git a/internal/querycoordv2/meta/mock_target_manager.go b/internal/querycoordv2/meta/mock_target_manager.go index d82a61f157..1ffeb2dfd5 100644 --- a/internal/querycoordv2/meta/mock_target_manager.go +++ b/internal/querycoordv2/meta/mock_target_manager.go @@ -74,6 +74,54 @@ func (_c *MockTargetManager_CanSegmentBeMoved_Call) RunAndReturn(run func(contex return _c } +// GetCollectionRowCount provides a mock function with given fields: ctx, collectionID, scope +func (_m *MockTargetManager) GetCollectionRowCount(ctx context.Context, collectionID int64, scope int32) int64 { + ret := _m.Called(ctx, collectionID, scope) + + if len(ret) == 0 { + panic("no return value specified for GetCollectionRowCount") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) int64); ok { + r0 = rf(ctx, collectionID, scope) + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// MockTargetManager_GetCollectionRowCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionRowCount' +type MockTargetManager_GetCollectionRowCount_Call struct { + *mock.Call +} + +// GetCollectionRowCount is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetCollectionRowCount(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionRowCount_Call { + return &MockTargetManager_GetCollectionRowCount_Call{Call: _e.mock.On("GetCollectionRowCount", ctx, collectionID, scope)} +} + +func (_c *MockTargetManager_GetCollectionRowCount_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetCollectionRowCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetCollectionRowCount_Call) Return(_a0 int64) *MockTargetManager_GetCollectionRowCount_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetCollectionRowCount_Call) RunAndReturn(run func(context.Context, int64, int32) int64) *MockTargetManager_GetCollectionRowCount_Call { + _c.Call.Return(run) + return _c +} + // GetCollectionTargetVersion provides a mock function with given fields: ctx, collectionID, scope func (_m *MockTargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope int32) int64 { ret := _m.Called(ctx, collectionID, scope) diff --git a/internal/querycoordv2/meta/target.go b/internal/querycoordv2/meta/target.go index ac055df009..38b745f1e8 100644 --- a/internal/querycoordv2/meta/target.go +++ b/internal/querycoordv2/meta/target.go @@ -42,11 +42,15 @@ type CollectionTarget struct { // record target status, if target has been save before milvus v2.4.19, then the target will lack of segment info. lackSegmentInfo bool + + // cache collection total row count + totalRowCount int64 } func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[string]*DmChannel, partitionIDs []int64) *CollectionTarget { channel2Segments := make(map[string][]*datapb.SegmentInfo, len(dmChannels)) partition2Segments := make(map[int64][]*datapb.SegmentInfo, len(partitionIDs)) + totalRowCount := int64(0) for _, segment := range segments { channel := segment.GetInsertChannel() if _, ok := channel2Segments[channel]; !ok { @@ -58,6 +62,7 @@ func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[ partition2Segments[partitionID] = make([]*datapb.SegmentInfo, 0) } partition2Segments[partitionID] = append(partition2Segments[partitionID], segment) + totalRowCount += segment.GetNumOfRows() } return &CollectionTarget{ segments: segments, @@ -66,6 +71,7 @@ func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[ dmChannels: dmChannels, partitions: typeutil.NewSet(partitionIDs...), version: time.Now().UnixNano(), + totalRowCount: totalRowCount, } } @@ -77,6 +83,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget var partitions []int64 lackSegmentInfo := false + totalRowCount := int64(0) for _, t := range target.GetChannelTargets() { if _, ok := channel2Segments[t.GetChannelName()]; !ok { channel2Segments[t.GetChannelName()] = make([]*datapb.SegmentInfo, 0) @@ -100,6 +107,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget segments[segment.GetID()] = info channel2Segments[t.GetChannelName()] = append(channel2Segments[t.GetChannelName()], info) partition2Segments[partition.GetPartitionID()] = append(partition2Segments[partition.GetPartitionID()], info) + totalRowCount += segment.GetNumOfRows() } partitions = append(partitions, partition.GetPartitionID()) } @@ -128,6 +136,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget partitions: typeutil.NewSet(partitions...), version: target.GetVersion(), lackSegmentInfo: lackSegmentInfo, + totalRowCount: totalRowCount, } } @@ -222,6 +231,10 @@ func (p *CollectionTarget) Ready() bool { return !p.lackSegmentInfo } +func (p *CollectionTarget) GetRowCount() int64 { + return p.totalRowCount +} + type target struct { keyLock *lock.KeyLock[int64] // guards updateCollectionTarget // just maintain target at collection level diff --git a/internal/querycoordv2/meta/target_manager.go b/internal/querycoordv2/meta/target_manager.go index bf7893bcb7..d69a863a6f 100644 --- a/internal/querycoordv2/meta/target_manager.go +++ b/internal/querycoordv2/meta/target_manager.go @@ -72,6 +72,7 @@ type TargetManagerInterface interface { GetTargetJSON(ctx context.Context, scope TargetScope, collectionID int64) string GetPartitions(ctx context.Context, collectionID int64, scope TargetScope) ([]int64, error) IsCurrentTargetReady(ctx context.Context, collectionID int64) bool + GetCollectionRowCount(ctx context.Context, collectionID int64, scope TargetScope) int64 } type TargetManager struct { @@ -609,3 +610,11 @@ func (mgr *TargetManager) IsCurrentTargetReady(ctx context.Context, collectionID return target.Ready() } + +func (mgr *TargetManager) GetCollectionRowCount(ctx context.Context, collectionID int64, scope TargetScope) int64 { + target := mgr.getCollectionTarget(scope, collectionID) + if len(target) == 0 { + return 0 + } + return target[0].GetRowCount() +} diff --git a/internal/querycoordv2/meta/target_manager_test.go b/internal/querycoordv2/meta/target_manager_test.go index 29875f28dd..cf9dd4f4cf 100644 --- a/internal/querycoordv2/meta/target_manager_test.go +++ b/internal/querycoordv2/meta/target_manager_test.go @@ -620,6 +620,7 @@ func (suite *TargetManagerSuite) TestRecover() { suite.Equal(int64(100), segment.GetNumOfRows()) } suite.True(target.Ready()) + suite.Equal(int64(200), target.GetRowCount()) // after recover, target info should be cleaned up targets, err := suite.catalog.GetCollectionTargets(ctx) diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index cfebe80441..78076218cd 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1969,6 +1969,7 @@ type queryCoordConfig struct { AutoBalance ParamItem `refreshable:"true"` AutoBalanceChannel ParamItem `refreshable:"true"` Balancer ParamItem `refreshable:"true"` + BalanceTriggerOrder ParamItem `refreshable:"true"` GlobalRowCountFactor ParamItem `refreshable:"true"` ScoreUnbalanceTolerationFactor ParamItem `refreshable:"true"` ReverseUnbalanceTolerationFactor ParamItem `refreshable:"true"` @@ -2106,6 +2107,16 @@ If this parameter is set false, Milvus simply searches the growing segments with } p.Balancer.Init(base.mgr) + p.BalanceTriggerOrder = ParamItem{ + Key: "queryCoord.balanceTriggerOrder", + Version: "2.5.8", + DefaultValue: "ByRowCount", + PanicIfEmpty: false, + Doc: "sorting order for collection balancing, options: ByRowCount, ByCollectionID", + Export: false, + } + p.BalanceTriggerOrder.Init(base.mgr) + p.GlobalRowCountFactor = ParamItem{ Key: "queryCoord.globalRowCountFactor", Version: "2.0.0",