fix: Prevent delegator from becoming unserviceable due to shard leader change (#42689)

issue: #42098 #42404
Fix a critical issue where concurrent balance-segment and balance-channel
operations cause delegator view inconsistency. When the shard leader
switches between the load and release phases of a segment balance, the
segment is loaded on the old delegator but released on the new one,
leaving the new delegator unserviceable.

The root cause is that a balance-segment task modifies the delegator's
view in two steps; if the load and the release land on different
delegators because the leader changed in between, the delegator state is
corrupted and query availability suffers.

Changes include:
- Add a shardLeaderID field to SegmentTask to track the delegator used for loading
- Record the shard leader ID during segment loading in move operations
- Skip the release if the shard leader changed from the one used for loading
- Add comprehensive unit tests for leader change scenarios

This ensures a balance-segment operation executes atomically on a single
delegator, preventing view corruption and keeping the delegator
serviceable.
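
As a minimal illustration of the guard described above, here is a self-contained Go sketch using hypothetical simplified types (delegator, moveTask, and releaseOnLeader are illustrative stand-ins, not the real Milvus structures): record which shard leader served the load phase, then skip the release if the current leader differs.

package main

import "fmt"

// delegator is a hypothetical, heavily simplified stand-in for a shard
// delegator's segment view; the real Milvus types carry much more state.
type delegator struct {
    nodeID int64
    view   map[int64]bool // segmentID -> present in this delegator's view
}

func (d *delegator) load(segmentID int64)    { d.view[segmentID] = true }
func (d *delegator) release(segmentID int64) { delete(d.view, segmentID) }

// moveTask mirrors the idea behind SegmentTask.shardLeaderID: remember which
// delegator served the load phase so the release phase can detect a switch.
type moveTask struct {
    segmentID     int64
    shardLeaderID int64 // -1 until the load phase records the leader
}

// releaseOnLeader applies the guard: if the current shard leader differs from
// the one recorded at load time, skip the release instead of mutating a
// delegator that never saw the load.
func releaseOnLeader(task *moveTask, current *delegator) error {
    if task.shardLeaderID != current.nodeID {
        return fmt.Errorf("shard leader changed from %d to %d, skip release",
            task.shardLeaderID, current.nodeID)
    }
    current.release(task.segmentID)
    return nil
}

func main() {
    oldLeader := &delegator{nodeID: 1, view: map[int64]bool{}}
    newLeader := &delegator{nodeID: 2, view: map[int64]bool{}}
    task := &moveTask{segmentID: 100, shardLeaderID: -1}

    // Load phase: executed on the leader at that moment; record it on the task.
    oldLeader.load(task.segmentID)
    task.shardLeaderID = oldLeader.nodeID

    // A concurrent balance-channel operation moves the shard leader to node 2.
    // The release phase now sees a different leader and must not touch its view.
    if err := releaseOnLeader(task, newLeader); err != nil {
        fmt.Println(err) // prints: shard leader changed from 1 to 2, skip release
    }
}

Run as-is, the sketch prints the skip message, mirroring how the real release step now fails the move task instead of corrupting the new delegator's view.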

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2025-06-19 12:10:38 +08:00 committed by GitHub
parent 2cb296ff3b
commit bf5fde1431
7 changed files with 290 additions and 6 deletions

@@ -144,7 +144,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
     // if all available delegator has been excluded even after refresh shard leader cache
     // we should clear excludeNodes and try to select node again instead of failing the request at selectNode
-    if len(shardLeaders) > 0 && len(shardLeaders) == excludeNodes.Len() {
+    if len(shardLeaders) > 0 && len(shardLeaders) <= excludeNodes.Len() {
         allReplicaExcluded := true
         for _, node := range shardLeaders {
             if !excludeNodes.Contain(node.nodeID) {

@@ -239,6 +239,11 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
     }
     log = log.With(zap.Int64("shardLeader", view.Node))

+    // NOTE: for balance segment task, expected load and release execution on the same shard leader
+    if GetTaskType(task) == TaskTypeMove {
+        task.SetShardLeaderID(view.Node)
+    }
+
     startTs := time.Now()
     log.Info("load segments...")
     status, err := ex.cluster.LoadSegments(task.Context(), view.Node, req)

@@ -270,6 +275,12 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
     )
     ctx := task.Context()

+    var err error
+    defer func() {
+        if err != nil {
+            task.Fail(err)
+        }
+    }()
     dstNode := action.Node()

@@ -300,7 +311,14 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) {
     view := ex.dist.ChannelDistManager.GetShardLeader(task.Shard(), replica)
     if view == nil {
         msg := "no shard leader for the segment to execute releasing"
-        err := merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found")
+        err = merr.WrapErrChannelNotFound(task.Shard(), "shard delegator not found")
+        log.Warn(msg, zap.Error(err))
+        return
+    }
+    // NOTE: for balance segment task, expected load and release execution on the same shard leader
+    if GetTaskType(task) == TaskTypeMove && task.ShardLeaderID() != view.Node {
+        msg := "shard leader changed, skip release"
+        err = merr.WrapErrServiceInternal(fmt.Sprintf("shard leader changed from %d to %d", task.ShardLeaderID(), view.Node))
         log.Warn(msg, zap.Error(err))
         return
     }

@@ -327,6 +327,8 @@ type SegmentTask struct {
     segmentID    typeutil.UniqueID
     loadPriority commonpb.LoadPriority
+    // for balance segment task, expected load and release execution on the same shard leader
+    shardLeaderID int64
 }

 // NewSegmentTask creates a SegmentTask with actions,

@@ -365,6 +367,7 @@ func NewSegmentTask(ctx context.Context,
         baseTask:      base,
         segmentID:     segmentID,
         loadPriority:  loadPriority,
+        shardLeaderID: -1,
     }, nil
 }

@@ -392,6 +395,14 @@ func (task *SegmentTask) MarshalJSON() ([]byte, error) {
     return marshalJSON(task)
 }

+func (task *SegmentTask) ShardLeaderID() int64 {
+    return task.shardLeaderID
+}
+
+func (task *SegmentTask) SetShardLeaderID(id int64) {
+    task.shardLeaderID = id
+}
+
 type ChannelTask struct {
     *baseTask
 }

@@ -2032,3 +2032,187 @@ func newReplicaDefaultRG(replicaID int64) *meta.Replica {
         typeutil.NewUniqueSet(),
     )
 }
+
+func (suite *TaskSuite) TestSegmentTaskShardLeaderID() {
+    ctx := context.Background()
+    timeout := 10 * time.Second
+
+    // Create a segment task
+    action := NewSegmentActionWithScope(1, ActionTypeGrow, "", 100, querypb.DataScope_Historical, 100)
+    segmentTask, err := NewSegmentTask(
+        ctx,
+        timeout,
+        WrapIDSource(0),
+        suite.collection,
+        suite.replica,
+        commonpb.LoadPriority_LOW,
+        action,
+    )
+    suite.NoError(err)
+
+    // Test initial shard leader ID (should be -1)
+    suite.Equal(int64(-1), segmentTask.ShardLeaderID())
+
+    // Test setting shard leader ID
+    expectedLeaderID := int64(123)
+    segmentTask.SetShardLeaderID(expectedLeaderID)
+    suite.Equal(expectedLeaderID, segmentTask.ShardLeaderID())
+
+    // Test setting another value
+    anotherLeaderID := int64(456)
+    segmentTask.SetShardLeaderID(anotherLeaderID)
+    suite.Equal(anotherLeaderID, segmentTask.ShardLeaderID())
+
+    // Test with zero value
+    segmentTask.SetShardLeaderID(0)
+    suite.Equal(int64(0), segmentTask.ShardLeaderID())
+}
+
+func (suite *TaskSuite) TestExecutor_MoveSegmentTask() {
+    ctx := context.Background()
+    timeout := 10 * time.Second
+    sourceNode := int64(2)
+    targetNode := int64(3)
+    channel := &datapb.VchannelInfo{
+        CollectionID: suite.collection,
+        ChannelName:  Params.CommonCfg.RootCoordDml.GetValue() + "-test",
+    }
+    suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collection, 1))
+    suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{sourceNode, targetNode}))
+
+    // Create move task with both grow and reduce actions to simulate TaskTypeMove
+    segmentID := suite.loadSegments[0]
+    growAction := NewSegmentAction(targetNode, ActionTypeGrow, channel.ChannelName, segmentID)
+    reduceAction := NewSegmentAction(sourceNode, ActionTypeReduce, channel.ChannelName, segmentID)
+
+    // Create a move task that has both actions
+    moveTask, err := NewSegmentTask(
+        ctx,
+        timeout,
+        WrapIDSource(0),
+        suite.collection,
+        suite.replica,
+        commonpb.LoadPriority_LOW,
+        growAction,
+        reduceAction,
+    )
+    suite.NoError(err)
+
+    // Mock cluster expectations for load segment
+    suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil)
+    suite.cluster.EXPECT().ReleaseSegments(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil)
+    suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).RunAndReturn(func(ctx context.Context, i int64) (*milvuspb.DescribeCollectionResponse, error) {
+        return &milvuspb.DescribeCollectionResponse{
+            Schema: &schemapb.CollectionSchema{
+                Name: "TestMoveSegmentTask",
+                Fields: []*schemapb.FieldSchema{
+                    {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector},
+                },
+            },
+        }, nil
+    })
+    suite.broker.EXPECT().ListIndexes(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{
+        {
+            CollectionID: suite.collection,
+        },
+    }, nil)
+    suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segmentID).Return([]*datapb.SegmentInfo{
+        {
+            ID:            segmentID,
+            CollectionID:  suite.collection,
+            PartitionID:   -1,
+            InsertChannel: channel.ChannelName,
+        },
+    }, nil)
+    suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segmentID).Return(nil, nil)
+
+    // Set up distribution with leader view
+    view := &meta.LeaderView{
+        ID:           targetNode,
+        CollectionID: suite.collection,
+        Channel:      channel.ChannelName,
+        Segments:     make(map[int64]*querypb.SegmentDist),
+        Status:       &querypb.LeaderViewStatus{Serviceable: true},
+    }
+    suite.dist.ChannelDistManager.Update(targetNode, &meta.DmChannel{
+        VchannelInfo: channel,
+        Node:         targetNode,
+        Version:      1,
+        View:         view,
+    })
+
+    // Add segments to original node distribution for release
+    segments := []*meta.Segment{
+        {
+            SegmentInfo: &datapb.SegmentInfo{
+                ID:            segmentID,
+                CollectionID:  suite.collection,
+                PartitionID:   1,
+                InsertChannel: channel.ChannelName,
+            },
+        },
+    }
+    suite.dist.SegmentDistManager.Update(sourceNode, segments...)
+
+    // Set up broker expectations
+    segmentInfos := []*datapb.SegmentInfo{
+        {
+            ID:            segmentID,
+            CollectionID:  suite.collection,
+            PartitionID:   1,
+            InsertChannel: channel.ChannelName,
+        },
+    }
+    suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segmentInfos, nil)
+    suite.target.UpdateCollectionNextTarget(ctx, suite.collection)
+
+    // Test that move task sets shard leader ID during load step
+    suite.Equal(TaskTypeMove, GetTaskType(moveTask))
+    suite.Equal(int64(-1), moveTask.ShardLeaderID()) // Initial value
+
+    // Set up task executor
+    executor := NewExecutor(suite.meta,
+        suite.dist,
+        suite.broker,
+        suite.target,
+        suite.cluster,
+        suite.nodeMgr)
+
+    // Verify shard leader ID was set for load action in move task
+    executor.executeSegmentAction(moveTask, 0)
+    suite.Equal(targetNode, moveTask.ShardLeaderID())
+    suite.NoError(moveTask.Err())
+
+    // expect release action will execute successfully
+    executor.executeSegmentAction(moveTask, 1)
+    suite.Equal(targetNode, moveTask.ShardLeaderID())
+    suite.True(moveTask.actions[0].IsFinished(suite.dist))
+    suite.NoError(moveTask.Err())
+
+    // test shard leader change before release action
+    newLeaderID := sourceNode
+    view1 := &meta.LeaderView{
+        ID:           newLeaderID,
+        CollectionID: suite.collection,
+        Channel:      channel.ChannelName,
+        Segments:     make(map[int64]*querypb.SegmentDist),
+        Status:       &querypb.LeaderViewStatus{Serviceable: true},
+        Version:      100,
+    }
+    suite.dist.ChannelDistManager.Update(newLeaderID, &meta.DmChannel{
+        VchannelInfo: channel,
+        Node:         newLeaderID,
+        Version:      100,
+        View:         view1,
+    })
+
+    // expect release action will skip and task will fail
+    suite.broker.ExpectedCalls = nil
+    executor.executeSegmentAction(moveTask, 1)
+    suite.True(moveTask.actions[1].IsFinished(suite.dist))
+    suite.ErrorContains(moveTask.Err(), "shard leader changed")
+}

@@ -77,7 +77,7 @@ func (s *ExcludedSegments) CleanInvalid(ts uint64) {
     for _, segmentID := range invalidExcludedInfos {
         delete(s.segments, segmentID)
-        log.Ctx(context.TODO()).Info("remove segment from exclude info", zap.Int64("segmentID", segmentID))
+        log.Ctx(context.TODO()).Debug("remove segment from exclude info", zap.Int64("segmentID", segmentID))
     }
     s.lastClean.Store(time.Now())
 }

@@ -303,6 +303,16 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
     })
     delegator.AddExcludedSegments(growingInfo)

+    flushedInfo := lo.SliceToMap(channel.GetFlushedSegmentIds(), func(id int64) (int64, uint64) {
+        return id, typeutil.MaxTimestamp
+    })
+    delegator.AddExcludedSegments(flushedInfo)
+
+    droppedInfo := lo.SliceToMap(channel.GetDroppedSegmentIds(), func(id int64) (int64, uint64) {
+        return id, typeutil.MaxTimestamp
+    })
+    delegator.AddExcludedSegments(droppedInfo)
+
     defer func() {
         if err != nil {
             // remove legacy growing

@@ -21,11 +21,13 @@ import (
     "fmt"
     "strconv"
     "strings"
+    "sync"
     "testing"
     "time"

     "github.com/samber/lo"
     "github.com/stretchr/testify/suite"
+    "go.uber.org/atomic"
     "go.uber.org/zap"
     "google.golang.org/protobuf/proto"

@@ -309,6 +311,65 @@ func (s *BalanceTestSuit) TestNodeDown() {
     }, 30*time.Second, 1*time.Second)
 }
+
+func (s *BalanceTestSuit) TestConcurrentBalanceChannelAndSegment() {
+    ctx := context.Background()
+
+    // speed up balance trigger
+    paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "500")
+    paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "500")
+
+    // init collection with 10 channel, each channel has 10 segment, each segment has 2000 row
+    // and load it with 1 replicas on 2 nodes.
+    name := "test_balance_" + funcutil.GenRandomStr()
+    s.initCollection(name, 1, 10, 10, 2000, 500)
+
+    stopSearchCh := make(chan struct{})
+    failCounter := atomic.NewInt64(0)
+
+    // keep query during balance
+    wg := sync.WaitGroup{}
+    wg.Add(1)
+    go func() {
+        defer wg.Done()
+        for {
+            select {
+            case <-stopSearchCh:
+                log.Info("stop search")
+                return
+            default:
+                queryResult, err := s.Cluster.Proxy.Query(ctx, &milvuspb.QueryRequest{
+                    DbName:         "",
+                    CollectionName: name,
+                    Expr:           "",
+                    OutputFields:   []string{"count(*)"},
+                })
+                if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
+                    log.Info("query failed", zap.Error(err))
+                    failCounter.Inc()
+                }
+            }
+        }
+    }()
+
+    // then we add 1 query node, expected segment and channel will be move to new query node concurrently
+    qn1 := s.Cluster.AddQueryNode()
+
+    // wait until balance channel finished
+    s.Eventually(func() bool {
+        resp, err := qn1.GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{})
+        s.NoError(err)
+        s.True(merr.Ok(resp.GetStatus()))
+        log.Info("resp", zap.Any("channel", len(resp.Channels)), zap.Any("segments", len(resp.Segments)))
+        return len(resp.Channels) == 5
+    }, 30*time.Second, 1*time.Second)
+
+    // expected concurrent balance will execute successfully, shard serviceable won't be broken
+    close(stopSearchCh)
+    wg.Wait()
+    s.Equal(int64(0), failCounter.Load())
+}

 func TestBalance(t *testing.T) {
     g := integration.WithoutStreamingService()
     defer g()