diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go
index 3fa20e2c26..cfc64df282 100644
--- a/internal/proxy/lb_policy.go
+++ b/internal/proxy/lb_policy.go
@@ -144,7 +144,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
 
     // if all available delegator has been excluded even after refresh shard leader cache
     // we should clear excludeNodes and try to select node again instead of failing the request at selectNode
-    if len(shardLeaders) > 0 && len(shardLeaders) == excludeNodes.Len() {
+    if len(shardLeaders) > 0 && len(shardLeaders) <= excludeNodes.Len() {
         allReplicaExcluded := true
         for _, node := range shardLeaders {
             if !excludeNodes.Contain(node.nodeID) {
diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go
index ffba4fa1de..d6df8a4d1a 100644
--- a/internal/querycoordv2/task/executor.go
+++ b/internal/querycoordv2/task/executor.go
@@ -239,6 +239,11 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
     }
     log = log.With(zap.Int64("shardLeader", view.Node))
 
+    // NOTE: for balance segment task, expected load and release execution on the same shard leader
+    if GetTaskType(task) == TaskTypeMove {
+        task.SetShardLeaderID(view.Node)
+    }
+
     startTs := time.Now()
     log.Info("load segments...")
     status, err := ex.cluster.LoadSegments(task.Context(), view.Node, req)
@@ -270,6 +275,12 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
     )
 
     ctx := task.Context()
+    var err error
+    defer func() {
+        if err != nil {
+            task.Fail(err)
+        }
+    }()
 
     dstNode := action.Node()
 
@@ -300,7 +311,14 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
         view := ex.dist.ChannelDistManager.GetShardLeader(task.Shard(), replica)
         if view == nil {
             msg := "no shard leader for the segment to execute releasing"
-            err := merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found")
+            err = merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found")
+            log.Warn(msg, zap.Error(err))
+            return
+        }
+        // NOTE: for balance segment task, expected load and release execution on the same shard leader
+        if GetTaskType(task) == TaskTypeMove && task.ShardLeaderID() != view.Node {
+            msg := "shard leader changed, skip release"
+            err = merr.WrapErrServiceInternal(fmt.Sprintf("shard leader changed from %d to %d", task.ShardLeaderID(), view.Node))
             log.Warn(msg, zap.Error(err))
             return
         }
diff --git a/internal/querycoordv2/task/task.go b/internal/querycoordv2/task/task.go
index 32f339b0bb..d03d74a9be 100644
--- a/internal/querycoordv2/task/task.go
+++ b/internal/querycoordv2/task/task.go
@@ -327,6 +327,8 @@ type SegmentTask struct {
 
     segmentID    typeutil.UniqueID
     loadPriority commonpb.LoadPriority
+    // for balance segment task, expected load and release execution on the same shard leader
+    shardLeaderID int64
 }
 
 // NewSegmentTask creates a SegmentTask with actions,
@@ -362,9 +364,10 @@ func NewSegmentTask(ctx context.Context,
     base := newBaseTask(ctx, source, collectionID, replica, shard, fmt.Sprintf("SegmentTask-%s-%d", actions[0].Type().String(), segmentID))
     base.actions = actions
     return &SegmentTask{
-        baseTask:     base,
-        segmentID:    segmentID,
-        loadPriority: loadPriority,
+        baseTask:      base,
+        segmentID:     segmentID,
+        loadPriority:  loadPriority,
+        shardLeaderID: -1,
     }, nil
 }
 
@@ -392,6 +395,14 @@ func (task *SegmentTask) MarshalJSON() ([]byte, error) {
     return marshalJSON(task)
 }
 
+func (task *SegmentTask) ShardLeaderID() int64 {
+    return task.shardLeaderID
+}
+
+func (task *SegmentTask) SetShardLeaderID(id int64) {
+    task.shardLeaderID = id
+}
+
 type ChannelTask struct {
     *baseTask
 }
diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go
index 37e459c4e6..7dabb0128a 100644
--- a/internal/querycoordv2/task/task_test.go
+++ b/internal/querycoordv2/task/task_test.go
@@ -2032,3 +2032,187 @@ func newReplicaDefaultRG(replicaID int64) *meta.Replica {
         typeutil.NewUniqueSet(),
     )
 }
+
+func (suite *TaskSuite) TestSegmentTaskShardLeaderID() {
+    ctx := context.Background()
+    timeout := 10 * time.Second
+
+    // Create a segment task
+    action := NewSegmentActionWithScope(1, ActionTypeGrow, "", 100, querypb.DataScope_Historical, 100)
+    segmentTask, err := NewSegmentTask(
+        ctx,
+        timeout,
+        WrapIDSource(0),
+        suite.collection,
+        suite.replica,
+        commonpb.LoadPriority_LOW,
+        action,
+    )
+    suite.NoError(err)
+
+    // Test initial shard leader ID (should be -1)
+    suite.Equal(int64(-1), segmentTask.ShardLeaderID())
+
+    // Test setting shard leader ID
+    expectedLeaderID := int64(123)
+    segmentTask.SetShardLeaderID(expectedLeaderID)
+    suite.Equal(expectedLeaderID, segmentTask.ShardLeaderID())
+
+    // Test setting another value
+    anotherLeaderID := int64(456)
+    segmentTask.SetShardLeaderID(anotherLeaderID)
+    suite.Equal(anotherLeaderID, segmentTask.ShardLeaderID())
+
+    // Test with zero value
+    segmentTask.SetShardLeaderID(0)
+    suite.Equal(int64(0), segmentTask.ShardLeaderID())
+}
+
+func (suite *TaskSuite) TestExecutor_MoveSegmentTask() {
+    ctx := context.Background()
+    timeout := 10 * time.Second
+    sourceNode := int64(2)
+    targetNode := int64(3)
+    channel := &datapb.VchannelInfo{
+        CollectionID: suite.collection,
+        ChannelName:  Params.CommonCfg.RootCoordDml.GetValue() + "-test",
+    }
+
+    suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collection, 1))
+    suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{sourceNode, targetNode}))
+
+    // Create move task with both grow and reduce actions to simulate TaskTypeMove
+    segmentID := suite.loadSegments[0]
+    growAction := NewSegmentAction(targetNode, ActionTypeGrow, channel.ChannelName, segmentID)
+    reduceAction := NewSegmentAction(sourceNode, ActionTypeReduce, channel.ChannelName, segmentID)
+
+    // Create a move task that has both actions
+    moveTask, err := NewSegmentTask(
+        ctx,
+        timeout,
+        WrapIDSource(0),
+        suite.collection,
+        suite.replica,
+        commonpb.LoadPriority_LOW,
+        growAction,
+        reduceAction,
+    )
+    suite.NoError(err)
+
+    // Mock cluster expectations for load segment
+    suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil)
+    suite.cluster.EXPECT().ReleaseSegments(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+    suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).RunAndReturn(func(ctx context.Context, i int64) (*milvuspb.DescribeCollectionResponse, error) {
+        return &milvuspb.DescribeCollectionResponse{
+            Schema: &schemapb.CollectionSchema{
+                Name: "TestMoveSegmentTask",
+                Fields: []*schemapb.FieldSchema{
+                    {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector},
+                },
+            },
+        }, nil
+    })
+    suite.broker.EXPECT().ListIndexes(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{
+        {
+            CollectionID: suite.collection,
+        },
+    }, nil)
+    suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segmentID).Return([]*datapb.SegmentInfo{
+        {
+            ID:            segmentID,
+            CollectionID:  suite.collection,
+            PartitionID:   -1,
+            InsertChannel: channel.ChannelName,
+        },
+    }, nil)
+    suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segmentID).Return(nil, nil)
+
+    // Set up distribution with leader view
+    view := &meta.LeaderView{
+        ID:           targetNode,
+        CollectionID: suite.collection,
+        Channel:      channel.ChannelName,
+        Segments:     make(map[int64]*querypb.SegmentDist),
+        Status:       &querypb.LeaderViewStatus{Serviceable: true},
+    }
+
+    suite.dist.ChannelDistManager.Update(targetNode, &meta.DmChannel{
+        VchannelInfo: channel,
+        Node:         targetNode,
+        Version:      1,
+        View:         view,
+    })
+
+    // Add segments to original node distribution for release
+    segments := []*meta.Segment{
+        {
+            SegmentInfo: &datapb.SegmentInfo{
+                ID:            segmentID,
+                CollectionID:  suite.collection,
+                PartitionID:   1,
+                InsertChannel: channel.ChannelName,
+            },
+        },
+    }
+    suite.dist.SegmentDistManager.Update(sourceNode, segments...)
+
+    // Set up broker expectations
+    segmentInfos := []*datapb.SegmentInfo{
+        {
+            ID:            segmentID,
+            CollectionID:  suite.collection,
+            PartitionID:   1,
+            InsertChannel: channel.ChannelName,
+        },
+    }
+    suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segmentInfos, nil)
+    suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
+
+    // Test that move task sets shard leader ID during load step
+    suite.Equal(TaskTypeMove, GetTaskType(moveTask))
+    suite.Equal(int64(-1), moveTask.ShardLeaderID()) // Initial value
+
+    // Set up task executor
+    executor := NewExecutor(suite.meta,
+        suite.dist,
+        suite.broker,
+        suite.target,
+        suite.cluster,
+        suite.nodeMgr)
+
+    // Verify shard leader ID was set for load action in move task
+    executor.executeSegmentAction(moveTask, 0)
+    suite.Equal(targetNode, moveTask.ShardLeaderID())
+    suite.NoError(moveTask.Err())
+
+    // expect release action will execute successfully
+    executor.executeSegmentAction(moveTask, 1)
+    suite.Equal(targetNode, moveTask.ShardLeaderID())
+    suite.True(moveTask.actions[0].IsFinished(suite.dist))
+    suite.NoError(moveTask.Err())
+
+    // test shard leader change before release action
+    newLeaderID := sourceNode
+    view1 := &meta.LeaderView{
+        ID:           newLeaderID,
+        CollectionID: suite.collection,
+        Channel:      channel.ChannelName,
+        Segments:     make(map[int64]*querypb.SegmentDist),
+        Status:       &querypb.LeaderViewStatus{Serviceable: true},
+        Version:      100,
+    }
+
+    suite.dist.ChannelDistManager.Update(newLeaderID, &meta.DmChannel{
+        VchannelInfo: channel,
+        Node:         newLeaderID,
+        Version:      100,
+        View:         view1,
+    })
+
+    // expect release action will skip and task will fail
+    suite.broker.ExpectedCalls = nil
+    executor.executeSegmentAction(moveTask, 1)
+    suite.True(moveTask.actions[1].IsFinished(suite.dist))
+    suite.ErrorContains(moveTask.Err(), "shard leader changed")
+}
diff --git a/internal/querynodev2/delegator/exclude_info.go b/internal/querynodev2/delegator/exclude_info.go
index 9d86d81f33..dcafdc86fd 100644
--- a/internal/querynodev2/delegator/exclude_info.go
+++ b/internal/querynodev2/delegator/exclude_info.go
@@ -77,7 +77,7 @@ func (s *ExcludedSegments) CleanInvalid(ts uint64) {
 
     for _, segmentID := range invalidExcludedInfos {
         delete(s.segments, segmentID)
-        log.Ctx(context.TODO()).Info("remove segment from exclude info", zap.Int64("segmentID", segmentID))
+        log.Ctx(context.TODO()).Debug("remove segment from exclude info", zap.Int64("segmentID", segmentID))
     }
     s.lastClean.Store(time.Now())
 }
diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go
index 29ae5be5d7..afdf05592a 100644
--- a/internal/querynodev2/services.go
+++ b/internal/querynodev2/services.go
@@ -303,6 +303,16 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
     })
     delegator.AddExcludedSegments(growingInfo)
 
+    flushedInfo := lo.SliceToMap(channel.GetFlushedSegmentIds(), func(id int64) (int64, uint64) {
+        return id, typeutil.MaxTimestamp
+    })
+    delegator.AddExcludedSegments(flushedInfo)
+
+    droppedInfo := lo.SliceToMap(channel.GetDroppedSegmentIds(), func(id int64) (int64, uint64) {
+        return id, typeutil.MaxTimestamp
+    })
+    delegator.AddExcludedSegments(droppedInfo)
+
     defer func() {
         if err != nil {
             // remove legacy growing
diff --git a/tests/integration/balance/balance_test.go b/tests/integration/balance/balance_test.go
index 98ec3bf82d..25f5507b87 100644
--- a/tests/integration/balance/balance_test.go
+++ b/tests/integration/balance/balance_test.go
@@ -21,11 +21,13 @@ import (
     "fmt"
     "strconv"
     "strings"
+    "sync"
     "testing"
     "time"
 
     "github.com/samber/lo"
     "github.com/stretchr/testify/suite"
+    "go.uber.org/atomic"
     "go.uber.org/zap"
     "google.golang.org/protobuf/proto"
 
@@ -309,6 +311,65 @@ func (s *BalanceTestSuit) TestNodeDown() {
     }, 30*time.Second, 1*time.Second)
 }
 
+func (s *BalanceTestSuit) TestConcurrentBalanceChannelAndSegment() {
+    ctx := context.Background()
+
+    // speed up balance trigger
+    paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "500")
+    paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "500")
+
+    // init collection with 10 channel, each channel has 10 segment, each segment has 2000 row
+    // and load it with 1 replicas on 2 nodes.
+    name := "test_balance_" + funcutil.GenRandomStr()
+    s.initCollection(name, 1, 10, 10, 2000, 500)
+
+    stopSearchCh := make(chan struct{})
+    failCounter := atomic.NewInt64(0)
+
+    // keep query during balance
+    wg := sync.WaitGroup{}
+    wg.Add(1)
+    go func() {
+        defer wg.Done()
+        for {
+            select {
+            case <-stopSearchCh:
+                log.Info("stop search")
+                return
+            default:
+                queryResult, err := s.Cluster.Proxy.Query(ctx, &milvuspb.QueryRequest{
+                    DbName:         "",
+                    CollectionName: name,
+                    Expr:           "",
+                    OutputFields:   []string{"count(*)"},
+                })
+
+                if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
+                    log.Info("query failed", zap.Error(err))
+                    failCounter.Inc()
+                }
+            }
+        }
+    }()
+
+    // then we add 1 query node, expected segment and channel will be move to new query node concurrently
+    qn1 := s.Cluster.AddQueryNode()
+
+    // wait until balance channel finished
+    s.Eventually(func() bool {
+        resp, err := qn1.GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{})
+        s.NoError(err)
+        s.True(merr.Ok(resp.GetStatus()))
+        log.Info("resp", zap.Any("channel", len(resp.Channels)), zap.Any("segments", len(resp.Segments)))
+        return len(resp.Channels) == 5
+    }, 30*time.Second, 1*time.Second)
+
+    // expected concurrent balance will execute successfully, shard serviceable won't be broken
+    close(stopSearchCh)
+    wg.Wait()
+    s.Equal(int64(0), failCounter.Load())
+}
+
 func TestBalance(t *testing.T) {
     g := integration.WithoutStreamingService()
     defer g()