diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index 8a5de7c014..a751006461 100644 --- a/internal/querycoordv2/handlers.go +++ b/internal/querycoordv2/handlers.go @@ -99,7 +99,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe toBalance := typeutil.NewSet[*meta.Segment]() // Only balance segments in targets - segments := s.dist.SegmentDistManager.GetByNode(srcNode) + segments := s.dist.SegmentDistManager.GetByCollectionAndNode(req.GetCollectionID(), srcNode) segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool { return s.targetMgr.GetHistoricalSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil }) diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 19a32333bf..3943073485 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/samber/lo" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -584,6 +585,65 @@ func (suite *ServiceSuite) TestLoadBalance() { suite.Contains(resp.Reason, ErrNotHealthy.Error()) } +func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() { + suite.loadAll() + ctx := context.Background() + server := suite.server + + srcNode := int64(1001) + dstNode := int64(1002) + metaSegments := make([]*meta.Segment, 0) + segmentOnCollection := make(map[int64][]int64) + + // update two collection's dist + for _, collection := range suite.collections { + replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas[0].AddNode(srcNode) + replicas[0].AddNode(dstNode) + defer replicas[0].RemoveNode(srcNode) + defer replicas[0].RemoveNode(dstNode) + suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + + for partition, segments := range suite.segments[collection] { + for _, segment := range segments { + metaSegments = append(metaSegments, + utils.CreateTestSegment(collection, partition, segment, srcNode, 1, "test-channel")) + + if segmentOnCollection[collection] == nil { + segmentOnCollection[collection] = make([]int64, 0) + } + segmentOnCollection[collection] = append(segmentOnCollection[collection], segment) + } + } + } + suite.dist.SegmentDistManager.Update(srcNode, metaSegments...) + + // expect each collection can only trigger its own segment's balance + for _, collection := range suite.collections { + req := &querypb.LoadBalanceRequest{ + CollectionID: collection, + SourceNodeIDs: []int64{srcNode}, + DstNodeIDs: []int64{dstNode}, + } + suite.taskScheduler.ExpectedCalls = make([]*mock.Call, 0) + suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(t task.Task) { + actions := t.Actions() + suite.Len(actions, 2) + growAction := actions[0].(*task.SegmentAction) + reduceAction := actions[1].(*task.SegmentAction) + suite.True(lo.Contains(segmentOnCollection[collection], growAction.SegmentID())) + suite.True(lo.Contains(segmentOnCollection[collection], reduceAction.SegmentID())) + suite.Equal(dstNode, growAction.Node()) + suite.Equal(srcNode, reduceAction.Node()) + t.Cancel() + }).Return(nil) + resp, err := server.LoadBalance(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) + suite.taskScheduler.AssertExpectations(suite.T()) + } +} + func (suite *ServiceSuite) TestLoadBalanceFailed() { suite.loadAll() ctx := context.Background()