From c02892e9fb9f349cb133811be9f9a68140b73797 Mon Sep 17 00:00:00 2001 From: wei liu Date: Mon, 31 Mar 2025 16:00:19 +0800 Subject: [PATCH] enhance: Balance the collection with the largest row count first (#40297) issue: #37651 This PR enables balancing the collection with the largest row count first. This avoids temporarily migrating data from small tables to new nodes during their onboarding, only for it to be moved out again after the large table is balanced, which would cause unnecessary load. --------- Signed-off-by: Wei Liu --- .../querycoordv2/checkers/balance_checker.go | 60 +++- .../checkers/balance_checker_test.go | 305 +++++++++++++++++- .../querycoordv2/meta/mock_target_manager.go | 48 +++ internal/querycoordv2/meta/target.go | 13 + internal/querycoordv2/meta/target_manager.go | 9 + .../querycoordv2/meta/target_manager_test.go | 1 + pkg/util/paramtable/component_param.go | 11 + 7 files changed, 433 insertions(+), 14 deletions(-) diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index 54b43ccc30..52d17e2653 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -19,6 +19,7 @@ package checkers import ( "context" "sort" + "strings" "time" "github.com/samber/lo" @@ -85,24 +86,26 @@ func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) b func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int64 { ids := b.meta.GetAll(ctx) + + // Sort collections using the configured sort order + ids = b.sortCollections(ctx, ids) + if paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - // balance collections influenced by stopping nodes - stoppingReplicas := make([]int64, 0) for _, cid := range ids { // if target and meta isn't ready, skip balance this collection if !b.readyToCheck(ctx, cid) { continue } replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid) + stoppingReplicas := make([]int64, 0) 
for _, replica := range replicas { if replica.RONodesCount() > 0 { stoppingReplicas = append(stoppingReplicas, replica.GetID()) } } - } - // do stopping balance only in this round - if len(stoppingReplicas) > 0 { - return stoppingReplicas + if len(stoppingReplicas) > 0 { + return stoppingReplicas + } } } @@ -123,9 +126,6 @@ func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64 collection := b.meta.GetCollection(ctx, cid) return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded }) - sort.Slice(loadedCollections, func(i, j int) bool { - return loadedCollections[i] < loadedCollections[j] - }) // Before performing balancing, check the CurrentTarget/LeaderView/Distribution for all collections. // If any collection has unready info, skip the balance operation to avoid inconsistencies. @@ -138,6 +138,9 @@ func (b *BalanceChecker) getReplicaForNormalBalance(ctx context.Context) []int64 return nil } + // Sort collections using the configured sort order + loadedCollections = b.sortCollections(ctx, loadedCollections) + // iterator one normal collection in one round normalReplicasToBalance := make([]int64, 0) hasUnbalancedCollection := false @@ -187,6 +190,11 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task { if len(stoppingReplicas) > 0 { // check for stopping balance first segmentPlans, channelPlans = b.balanceReplicas(ctx, stoppingReplicas) + // iterate all collection to find a collection to balance + for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 { + replicasToBalance := b.getReplicaForStoppingBalance(ctx) + segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance) + } } else { // then check for auto balance if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) { @@ -212,3 +220,37 @@ func (b *BalanceChecker) Check(ctx context.Context) []task.Task { ret = append(ret, tasks...) 
return ret } + +func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 { + sortOrder := strings.ToLower(Params.QueryCoordCfg.BalanceTriggerOrder.GetValue()) + if sortOrder == "" { + sortOrder = "byrowcount" // Default to ByRowCount + } + + // Define sorting functions + sortByRowCount := func(i, j int) bool { + rowCount1 := b.targetMgr.GetCollectionRowCount(ctx, collections[i], meta.CurrentTargetFirst) + rowCount2 := b.targetMgr.GetCollectionRowCount(ctx, collections[j], meta.CurrentTargetFirst) + return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j]) + } + + sortByCollectionID := func(i, j int) bool { + return collections[i] < collections[j] + } + + // Select the appropriate sorting function + var sortFunc func(i, j int) bool + switch sortOrder { + case "byrowcount": + sortFunc = sortByRowCount + case "bycollectionid": + sortFunc = sortByCollectionID + default: + log.Warn("Invalid balance sort order configuration, using default ByRowCount", zap.String("sortOrder", sortOrder)) + sortFunc = sortByRowCount + } + + // Sort the collections + sort.Slice(collections, sortFunc) + return collections +} diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go index 8b62b6c781..97e76e1152 100644 --- a/internal/querycoordv2/checkers/balance_checker_test.go +++ b/internal/querycoordv2/checkers/balance_checker_test.go @@ -49,7 +49,7 @@ type BalanceCheckerTestSuite struct { broker *meta.MockBroker nodeMgr *session.NodeManager scheduler *task.MockScheduler - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface } func (suite *BalanceCheckerTestSuite) SetupSuite() { @@ -290,7 +290,7 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) // test stopping balance - idsToBalance := []int64{int64(replicaID1), int64(replicaID2)} + idsToBalance := 
[]int64{int64(replicaID1)} replicasToBalance := suite.checker.getReplicaForStoppingBalance(ctx) suite.ElementsMatch(idsToBalance, replicasToBalance) @@ -305,7 +305,7 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { segPlans = append(segPlans, mockPlan) suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans) tasks := suite.checker.Check(context.TODO()) - suite.Len(tasks, 2) + suite.Len(tasks, 1) } func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { @@ -349,14 +349,16 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { // test normal balance when one collection has unready target mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true) mockTarget.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(false) + mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() replicasToBalance := suite.checker.getReplicaForNormalBalance(ctx) suite.Len(replicasToBalance, 0) // test stopping balance with target not ready mockTarget.ExpectedCalls = nil mockTarget.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(false) - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true) - mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid2), mock.Anything).Return(false) + mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid1), mock.Anything).Return(true).Maybe() + mockTarget.EXPECT().IsCurrentTargetExist(mock.Anything, int64(cid2), mock.Anything).Return(false).Maybe() + mockTarget.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(100).Maybe() mr1 := replica1.CopyForWrite() mr1.AddRONode(1) suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) @@ -440,6 +442,299 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceInterval() { suite.Equal(funcCallCounter.Load(), int64(1)) } +func (suite 
*BalanceCheckerTestSuite) TestBalanceOrder() { + ctx := context.Background() + nodeID1, nodeID2 := int64(1), int64(2) + + // set collections meta + cid1, replicaID1, partitionID1 := 1, 1, 1 + collection1 := utils.CreateTestCollection(int64(cid1), int32(replicaID1)) + collection1.Status = querypb.LoadStatus_Loaded + replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2}) + partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + + cid2, replicaID2, partitionID2 := 2, 2, 2 + collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) + collection2.Status = querypb.LoadStatus_Loaded + replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2}) + partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + + // mock collection row count + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(100)).Maybe() + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(200)).Maybe() + mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + + // mock stopping node + mr1 := replica1.CopyForWrite() + mr1.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) + mr2 := replica2.CopyForWrite() + mr2.AddRONode(nodeID2) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + + // test normal balance order + 
replicas := suite.checker.getReplicaForNormalBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID2)}) + + // test stopping balance order + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID2)}) + + // mock collection row count + mockTargetManager.ExpectedCalls = nil + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid1), mock.Anything).Return(int64(200)).Maybe() + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, int64(cid2), mock.Anything).Return(int64(100)).Maybe() + mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + + // test normal balance order + replicas = suite.checker.getReplicaForNormalBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID1)}) + + // test stopping balance order + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Equal(replicas, []int64{int64(replicaID1)}) +} + +func (suite *BalanceCheckerTestSuite) TestSortCollections() { + ctx := context.Background() + + // Set up test collections + cid1, cid2, cid3 := int64(1), int64(2), int64(3) + + // Mock the target manager for row count returns + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + + // Collection 1: Low ID, High row count + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() + + // Collection 2: Middle ID, Low row count + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() + + // Collection 3: High ID, Middle row count + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() + + collections := []int64{cid1, cid2, cid3} + + // Test ByRowCount sorting (default) + 
paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") + sortedCollections := suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Collections should be sorted by row count (highest first)") + + // Test ByCollectionID sorting + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Collections should be sorted by collection ID (ascending)") + + // Test with empty sort order (should default to ByRowCount) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is empty") + + // Test with invalid sort order (should default to ByRowCount) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "InvalidOrder") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid3, cid2}, sortedCollections, "Should default to ByRowCount when sort order is invalid") + + // Test with mixed case (should be case-insensitive) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "bYcOlLeCtIoNiD") + sortedCollections = suite.checker.sortCollections(ctx, collections) + suite.Equal([]int64{cid1, cid2, cid3}, sortedCollections, "Should handle case-insensitive sort order names") + + // Test with empty collection list + emptyCollections := []int64{} + sortedCollections = suite.checker.sortCollections(ctx, emptyCollections) + suite.Equal([]int64{}, sortedCollections, "Should handle empty collection list") +} + +func (suite *BalanceCheckerTestSuite) TestSortCollectionsIntegration() { + ctx := context.Background() + + // Set up test collections and nodes + nodeID1 := int64(1) + 
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID1, + Address: "localhost", + Hostname: "localhost", + })) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) + + // Create two collections to ensure sorting is triggered + cid1, replicaID1 := int64(1), int64(101) + collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) + collection1.Status = querypb.LoadStatus_Loaded + replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + + // Add a second collection with different characteristics + cid2, replicaID2 := int64(2), int64(102) + collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) + collection2.Status = querypb.LoadStatus_Loaded + replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + + // Mock target manager + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + + // Setup different row counts to test sorting + // Collection 1 has more rows than Collection 2 + var getRowCountCallCount int + mockTargetManager.On("GetCollectionRowCount", mock.Anything, mock.Anything, mock.Anything). 
+ Return(func(ctx context.Context, collectionID int64, scope int32) int64 { + getRowCountCallCount++ + if collectionID == cid1 { + return 200 // More rows in collection 1 + } + return 100 // Fewer rows in collection 2 + }) + + mockTargetManager.On("IsCurrentTargetReady", mock.Anything, mock.Anything).Return(true) + mockTargetManager.On("IsNextTargetExist", mock.Anything, mock.Anything).Return(true) + + // Configure for testing + paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") + + // Clear first to avoid previous test state + suite.checker.normalBalanceCollectionsCurrentRound.Clear() + + // Call normal balance + _ = suite.checker.getReplicaForNormalBalance(ctx) + + // Verify GetCollectionRowCount was called at least twice (once for each collection) + // This confirms that the collections were sorted + suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during normal balance") + + // Reset counter and test stopping balance + getRowCountCallCount = 0 + + // Set up for stopping balance test + mr1 := replica1.CopyForWrite() + mr1.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) + + mr2 := replica2.CopyForWrite() + mr2.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + + paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") + + // Call stopping balance + _ = suite.checker.getReplicaForStoppingBalance(ctx) + + // Verify GetCollectionRowCount was called at least twice during stopping balance + suite.True(getRowCountCallCount >= 2, "GetCollectionRowCount should be called at least twice during stopping balance") +} + +func (suite *BalanceCheckerTestSuite) TestBalanceTriggerOrder() { + ctx := context.Background() + + // Set up nodes + nodeID1, nodeID2 := int64(1), int64(2) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + 
NodeID: nodeID1, + Address: "localhost", + Hostname: "localhost", + })) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: nodeID2, + Address: "localhost", + Hostname: "localhost", + })) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) + + // Create collections with different row counts + cid1, replicaID1 := int64(1), int64(101) + collection1 := utils.CreateTestCollection(cid1, int32(replicaID1)) + collection1.Status = querypb.LoadStatus_Loaded + replica1 := utils.CreateTestReplica(replicaID1, cid1, []int64{nodeID1, nodeID2}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + + cid2, replicaID2 := int64(2), int64(102) + collection2 := utils.CreateTestCollection(cid2, int32(replicaID2)) + collection2.Status = querypb.LoadStatus_Loaded + replica2 := utils.CreateTestReplica(replicaID2, cid2, []int64{nodeID1, nodeID2}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + + cid3, replicaID3 := int64(3), int64(103) + collection3 := utils.CreateTestCollection(cid3, int32(replicaID3)) + collection3.Status = querypb.LoadStatus_Loaded + replica3 := utils.CreateTestReplica(replicaID3, cid3, []int64{nodeID1, nodeID2}) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection3) + suite.checker.meta.ReplicaManager.Put(ctx, replica3) + + // Mock the target manager + mockTargetManager := meta.NewMockTargetManager(suite.T()) + suite.checker.targetMgr = mockTargetManager + + // Set row counts: Collection 1 (highest), Collection 3 (middle), Collection 2 (lowest) + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid1, mock.Anything).Return(int64(300)).Maybe() + mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid2, mock.Anything).Return(int64(100)).Maybe() + 
mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, cid3, mock.Anything).Return(int64(200)).Maybe() + + // Mark the current target as ready + mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe() + mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe() + + // Enable auto balance + paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") + + // Test with ByRowCount order (default) + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByRowCount") + suite.checker.normalBalanceCollectionsCurrentRound.Clear() + + // Normal balance should pick the collection with highest row count first + replicas := suite.checker.getReplicaForNormalBalance(ctx) + suite.Contains(replicas, replicaID1, "Should balance collection with highest row count first") + + // Add stopping nodes to test stopping balance + mr1 := replica1.CopyForWrite() + mr1.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) + + mr2 := replica2.CopyForWrite() + mr2.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) + + mr3 := replica3.CopyForWrite() + mr3.AddRONode(nodeID1) + suite.checker.meta.ReplicaManager.Put(ctx, mr3.IntoReplica()) + + // Enable stopping balance + paramtable.Get().Save(Params.QueryCoordCfg.EnableStoppingBalance.Key, "true") + + // Stopping balance should also pick the collection with highest row count first + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with highest row count") + + // Test with ByCollectionID order + paramtable.Get().Save(Params.QueryCoordCfg.BalanceTriggerOrder.Key, "ByCollectionID") + suite.checker.normalBalanceCollectionsCurrentRound.Clear() + + // Normal balance should pick 
the collection with lowest ID first + replicas = suite.checker.getReplicaForNormalBalance(ctx) + suite.Contains(replicas, replicaID1, "Should balance collection with lowest ID first") + + // Stopping balance should also pick the collection with lowest ID first + replicas = suite.checker.getReplicaForStoppingBalance(ctx) + suite.Contains(replicas, replicaID1, "Stopping balance should prioritize collection with lowest ID") +} + func TestBalanceCheckerSuite(t *testing.T) { suite.Run(t, new(BalanceCheckerTestSuite)) } diff --git a/internal/querycoordv2/meta/mock_target_manager.go b/internal/querycoordv2/meta/mock_target_manager.go index d82a61f157..1ffeb2dfd5 100644 --- a/internal/querycoordv2/meta/mock_target_manager.go +++ b/internal/querycoordv2/meta/mock_target_manager.go @@ -74,6 +74,54 @@ func (_c *MockTargetManager_CanSegmentBeMoved_Call) RunAndReturn(run func(contex return _c } +// GetCollectionRowCount provides a mock function with given fields: ctx, collectionID, scope +func (_m *MockTargetManager) GetCollectionRowCount(ctx context.Context, collectionID int64, scope int32) int64 { + ret := _m.Called(ctx, collectionID, scope) + + if len(ret) == 0 { + panic("no return value specified for GetCollectionRowCount") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) int64); ok { + r0 = rf(ctx, collectionID, scope) + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// MockTargetManager_GetCollectionRowCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionRowCount' +type MockTargetManager_GetCollectionRowCount_Call struct { + *mock.Call +} + +// GetCollectionRowCount is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetCollectionRowCount(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionRowCount_Call { + return 
&MockTargetManager_GetCollectionRowCount_Call{Call: _e.mock.On("GetCollectionRowCount", ctx, collectionID, scope)} +} + +func (_c *MockTargetManager_GetCollectionRowCount_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetCollectionRowCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetCollectionRowCount_Call) Return(_a0 int64) *MockTargetManager_GetCollectionRowCount_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetCollectionRowCount_Call) RunAndReturn(run func(context.Context, int64, int32) int64) *MockTargetManager_GetCollectionRowCount_Call { + _c.Call.Return(run) + return _c +} + // GetCollectionTargetVersion provides a mock function with given fields: ctx, collectionID, scope func (_m *MockTargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope int32) int64 { ret := _m.Called(ctx, collectionID, scope) diff --git a/internal/querycoordv2/meta/target.go b/internal/querycoordv2/meta/target.go index ac055df009..38b745f1e8 100644 --- a/internal/querycoordv2/meta/target.go +++ b/internal/querycoordv2/meta/target.go @@ -42,11 +42,15 @@ type CollectionTarget struct { // record target status, if target has been save before milvus v2.4.19, then the target will lack of segment info. 
lackSegmentInfo bool + + // cache collection total row count + totalRowCount int64 } func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[string]*DmChannel, partitionIDs []int64) *CollectionTarget { channel2Segments := make(map[string][]*datapb.SegmentInfo, len(dmChannels)) partition2Segments := make(map[int64][]*datapb.SegmentInfo, len(partitionIDs)) + totalRowCount := int64(0) for _, segment := range segments { channel := segment.GetInsertChannel() if _, ok := channel2Segments[channel]; !ok { @@ -58,6 +62,7 @@ func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[ partition2Segments[partitionID] = make([]*datapb.SegmentInfo, 0) } partition2Segments[partitionID] = append(partition2Segments[partitionID], segment) + totalRowCount += segment.GetNumOfRows() } return &CollectionTarget{ segments: segments, @@ -66,6 +71,7 @@ func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[ dmChannels: dmChannels, partitions: typeutil.NewSet(partitionIDs...), version: time.Now().UnixNano(), + totalRowCount: totalRowCount, } } @@ -77,6 +83,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget var partitions []int64 lackSegmentInfo := false + totalRowCount := int64(0) for _, t := range target.GetChannelTargets() { if _, ok := channel2Segments[t.GetChannelName()]; !ok { channel2Segments[t.GetChannelName()] = make([]*datapb.SegmentInfo, 0) @@ -100,6 +107,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget segments[segment.GetID()] = info channel2Segments[t.GetChannelName()] = append(channel2Segments[t.GetChannelName()], info) partition2Segments[partition.GetPartitionID()] = append(partition2Segments[partition.GetPartitionID()], info) + totalRowCount += segment.GetNumOfRows() } partitions = append(partitions, partition.GetPartitionID()) } @@ -128,6 +136,7 @@ func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget 
partitions: typeutil.NewSet(partitions...), version: target.GetVersion(), lackSegmentInfo: lackSegmentInfo, + totalRowCount: totalRowCount, } } @@ -222,6 +231,10 @@ func (p *CollectionTarget) Ready() bool { return !p.lackSegmentInfo } +func (p *CollectionTarget) GetRowCount() int64 { + return p.totalRowCount +} + type target struct { keyLock *lock.KeyLock[int64] // guards updateCollectionTarget // just maintain target at collection level diff --git a/internal/querycoordv2/meta/target_manager.go b/internal/querycoordv2/meta/target_manager.go index bf7893bcb7..d69a863a6f 100644 --- a/internal/querycoordv2/meta/target_manager.go +++ b/internal/querycoordv2/meta/target_manager.go @@ -72,6 +72,7 @@ type TargetManagerInterface interface { GetTargetJSON(ctx context.Context, scope TargetScope, collectionID int64) string GetPartitions(ctx context.Context, collectionID int64, scope TargetScope) ([]int64, error) IsCurrentTargetReady(ctx context.Context, collectionID int64) bool + GetCollectionRowCount(ctx context.Context, collectionID int64, scope TargetScope) int64 } type TargetManager struct { @@ -609,3 +610,11 @@ func (mgr *TargetManager) IsCurrentTargetReady(ctx context.Context, collectionID return target.Ready() } + +func (mgr *TargetManager) GetCollectionRowCount(ctx context.Context, collectionID int64, scope TargetScope) int64 { + target := mgr.getCollectionTarget(scope, collectionID) + if len(target) == 0 { + return 0 + } + return target[0].GetRowCount() +} diff --git a/internal/querycoordv2/meta/target_manager_test.go b/internal/querycoordv2/meta/target_manager_test.go index 29875f28dd..cf9dd4f4cf 100644 --- a/internal/querycoordv2/meta/target_manager_test.go +++ b/internal/querycoordv2/meta/target_manager_test.go @@ -620,6 +620,7 @@ func (suite *TargetManagerSuite) TestRecover() { suite.Equal(int64(100), segment.GetNumOfRows()) } suite.True(target.Ready()) + suite.Equal(int64(200), target.GetRowCount()) // after recover, target info should be cleaned up targets, err 
:= suite.catalog.GetCollectionTargets(ctx) diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index cfebe80441..78076218cd 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1969,6 +1969,7 @@ type queryCoordConfig struct { AutoBalance ParamItem `refreshable:"true"` AutoBalanceChannel ParamItem `refreshable:"true"` Balancer ParamItem `refreshable:"true"` + BalanceTriggerOrder ParamItem `refreshable:"true"` GlobalRowCountFactor ParamItem `refreshable:"true"` ScoreUnbalanceTolerationFactor ParamItem `refreshable:"true"` ReverseUnbalanceTolerationFactor ParamItem `refreshable:"true"` @@ -2106,6 +2107,16 @@ If this parameter is set false, Milvus simply searches the growing segments with } p.Balancer.Init(base.mgr) + p.BalanceTriggerOrder = ParamItem{ + Key: "queryCoord.balanceTriggerOrder", + Version: "2.5.8", + DefaultValue: "ByRowCount", + PanicIfEmpty: false, + Doc: "sorting order for collection balancing, options: ByRowCount, ByCollectionID", + Export: false, + } + p.BalanceTriggerOrder.Init(base.mgr) + p.GlobalRowCountFactor = ParamItem{ Key: "queryCoord.globalRowCountFactor", Version: "2.0.0",