fix: make broadcaster ack non-cancelable by the client (#45145)

issue: #45141

- make the broadcaster ack non-cancelable by the RPC client.
- clone the assignment snapshot of the WAL balancer before mutating it.
- attach the server ID for GetReplicateCheckpoint to avoid request failures.

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye 2025-10-29 20:34:11 +08:00 committed by GitHub
parent 653dfcca41
commit 6e5189fe19
10 changed files with 97 additions and 36 deletions

View File

@@ -45,6 +45,18 @@ type assignmentSnapshot struct {
GlobalUnbalancedScore float64
}
// Clone returns a copy of the assignment snapshot with its own Assignments map.
func (s *assignmentSnapshot) Clone() assignmentSnapshot {
assignments := make(map[types.ChannelID]types.PChannelInfoAssigned, len(s.Assignments))
for channelID, assignment := range s.Assignments {
assignments[channelID] = assignment
}
return assignmentSnapshot{
Assignments: assignments,
GlobalUnbalancedScore: s.GlobalUnbalancedScore,
}
}
// streamingNodeInfo is the streaming node info for vchannel fair policy.
type streamingNodeInfo struct {
AssignedVChannelCount int
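For background (not part of the diff): assigning a Go struct that embeds a map copies only the map header, so both copies share the same underlying table; that aliasing is exactly what Clone guards against. A minimal standalone sketch:

package main

import "fmt"

func main() {
    original := map[string]int{"c1": 1}

    alias := original // copies the map header, not the entries
    alias["c2"] = 2   // the write is visible through original as well
    fmt.Println(len(original)) // 2

    cloned := make(map[string]int, len(original)) // same copy loop as Clone above
    for k, v := range original {
        cloned[k] = v
    }
    cloned["c3"] = 3
    fmt.Println(len(original), len(cloned)) // 2 3
}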

View File

@@ -256,3 +256,20 @@ func newLayout(channels map[string]int, vchannels map[string]map[string]int64, s
}
return layout
}
func TestAssignmentClone(t *testing.T) {
snapshot := assignmentSnapshot{
Assignments: map[types.ChannelID]types.PChannelInfoAssigned{
newChannelID("c1"): {
Channel: types.PChannelInfo{
Name: "c1",
},
},
},
}
clonedSnapshot := snapshot.Clone()
clonedSnapshot.Assignments[newChannelID("c2")] = types.PChannelInfoAssigned{}
assert.Len(t, snapshot.Assignments, 1)
assert.Equal(t, snapshot.Assignments[newChannelID("c1")], clonedSnapshot.Assignments[newChannelID("c1")])
assert.Len(t, clonedSnapshot.Assignments, 2)
}

View File

@@ -94,7 +94,7 @@ func (p *policy) Balance(currentLayout balancer.CurrentLayout) (layout balancer.
// 4. Do a DFS to build the best-scoring snapshot.
// The DFS finds the assignment that minimizes the unbalance score based on the current layout.
-greatestSnapshot := snapshot
+greatestSnapshot := snapshot.Clone()
p.assignChannels(expectedLayout, reassignChannelIDs, &greatestSnapshot)
if greatestSnapshot.GlobalUnbalancedScore < snapshot.GlobalUnbalancedScore-p.cfg.RebalanceTolerance {
if p.Logger().Level().Enabled(zap.DebugLevel) {

View File

@@ -416,7 +416,7 @@ func (b *broadcastTask) saveTaskIfDirty(ctx context.Context, logger *log.MLogger
logger = logger.With(zap.String("state", b.task.State.String()), zap.Int("ackedVChannelCount", ackedCount(b.task)))
if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.Header().BroadcastID, b.task); err != nil {
logger.Warn("save broadcast task failed", zap.Error(err))
-if ctx.Err() != nil {
+if ctx.Err() == nil {
panic("critical error: saving the broadcast task failed before the context is done")
}
return err
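The old condition was inverted: the intent is to tolerate a failed save only when the failure can be blamed on context cancellation, and to treat any other failure as fatal. A runnable sketch of that invariant, with save as a hypothetical stand-in for StreamingCatalog().SaveBroadcastTask:

package main

import (
    "context"
    "errors"
    "fmt"
)

// save is a hypothetical stand-in for the real persistence call; for the
// demo it fails exactly when the context is already done.
func save(ctx context.Context) error {
    return ctx.Err()
}

func saveOrPanic(ctx context.Context) error {
    if err := save(ctx); err != nil {
        if ctx.Err() == nil {
            // The context is still live, so cancellation cannot explain the
            // failure; treat it as a fatal storage error.
            panic("save failed while the context was still alive")
        }
        return err // failure caused by cancellation/deadline: propagate it
    }
    return nil
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    cancel()
    fmt.Println(errors.Is(saveOrPanic(ctx), context.Canceled)) // true, no panic
}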

View File

@@ -52,6 +52,8 @@ func (s *broadcastServceImpl) Ack(ctx context.Context, req *streamingpb.Broadcas
if err != nil {
return nil, err
}
// Once the ack reaches the streamingcoord, the ack operation must not be cancelable.
ctx = context.WithoutCancel(ctx)
if req.Message == nil {
// Before 2.6.1, the request didn't have the message field, only the broadcast id and vchannel,
// so we need to use the legacy ack interface.
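For reference, context.WithoutCancel (standard library since Go 1.21) derives a context that keeps the parent's values but is never canceled and carries no deadline, so a client disconnecting mid-RPC can no longer abort the ack. A minimal sketch:

package main

import (
    "context"
    "fmt"
)

func main() {
    parent, cancel := context.WithCancel(context.Background())
    detached := context.WithoutCancel(parent) // keeps values, drops cancellation

    cancel()
    fmt.Println(parent.Err())   // context canceled
    fmt.Println(detached.Err()) // <nil>: work running on detached continues
}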

View File

@@ -3,7 +3,9 @@ package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@@ -54,4 +56,35 @@ func TestBroadcastService(t *testing.T) {
Properties: map[string]string{"key": "value"},
},
})
ctx, cancel := context.WithCancel(context.Background())
reached := make(chan struct{})
done := make(chan struct{})
mb.EXPECT().Ack(mock.Anything, mock.Anything).Unset()
mb.EXPECT().Ack(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, msg message.ImmutableMessage) error {
close(reached)
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
})
go func() {
<-reached
cancel()
time.Sleep(10 * time.Millisecond)
close(done)
}()
_, err := service.Ack(ctx, &streamingpb.BroadcastAckRequest{
BroadcastId: 1,
Vchannel: "v1",
Message: &commonpb.ImmutableMessage{
Id: walimplstest.NewTestMessageID(1).IntoProto(),
Payload: []byte("payload"),
Properties: map[string]string{"key": "value"},
},
})
assert.NoError(t, err)
}

View File

@@ -43,15 +43,14 @@ func CreateConsumer(
opts *ConsumerOptions,
handlerClient streamingpb.StreamingNodeHandlerServiceClient,
) (Consumer, error) {
-ctxWithReq, err := createConsumeRequest(ctx, opts)
-if err != nil {
-return nil, err
-}
+ctx = contextutil.WithCreateConsumer(ctx, &streamingpb.CreateConsumerRequest{
+Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel),
+})
// TODO: configurable or auto adjust grpc.MaxCallRecvMsgSize
// The messages are always managed by the milvus cluster, so the message size shouldn't be capped here;
// otherwise an oversized message could block the stream indefinitely.
-streamClient, err := handlerClient.Consume(ctxWithReq, grpc.MaxCallRecvMsgSize(math.MaxInt32))
+streamClient, err := handlerClient.Consume(ctx, grpc.MaxCallRecvMsgSize(math.MaxInt32))
if err != nil {
return nil, err
}
@@ -84,16 +83,6 @@ func CreateConsumer(
return cli, nil
}
-// createConsumeRequest creates the consume request.
-func createConsumeRequest(ctx context.Context, opts *ConsumerOptions) (context.Context, error) {
-// select server to consume.
-ctx = contextutil.WithPickServerID(ctx, opts.Assignment.Node.ServerID)
-// create the consumer request.
-return contextutil.WithCreateConsumer(ctx, &streamingpb.CreateConsumerRequest{
-Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel),
-}), nil
-}
type consumerImpl struct {
ctx context.Context // TODO: the cancel method of consumer should be managed by consumerImpl, fix it in future.
walName string

View File

@@ -15,6 +15,7 @@ import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
@@ -218,6 +219,8 @@ func (hc *handlerClientImpl) createHandlerAfterStreamingNodeReady(ctx context.Co
assign := hc.watcher.Get(ctx, pchannel)
if assign != nil {
// Find assignment, try to create producer on this assignment.
// pick the target streaming node to serve.
ctx = contextutil.WithPickServerID(ctx, assign.Node.ServerID)
createResult, err := create(ctx, assign)
if err == nil {
logger.Info("create handler success", zap.Any("assignment", assign), zap.Bool("isLocal", registry.IsLocal(createResult)))
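Before this change only CreateProducer and CreateConsumer attached the target server ID, so direct calls such as GetReplicateCheckpoint went out without one and could fail to reach the assigned node; attaching it once here covers every call made against the assignment. The real encoding inside contextutil is internal to Milvus; the sketch below is only an assumed illustration of the typed-context-key pattern, not the actual implementation:

package main

import (
    "context"
    "fmt"
)

// pickKey is an unexported key type so the value cannot collide with
// context values set by other packages.
type pickKey struct{}

// withPickServerID mirrors the shape of contextutil.WithPickServerID (assumed).
func withPickServerID(ctx context.Context, serverID int64) context.Context {
    return context.WithValue(ctx, pickKey{}, serverID)
}

// getPickServerID mirrors the shape of contextutil.GetPickServerID (assumed).
func getPickServerID(ctx context.Context) (int64, bool) {
    id, ok := ctx.Value(pickKey{}).(int64)
    return id, ok
}

func main() {
    ctx := withPickServerID(context.Background(), 7)
    if id, ok := getPickServerID(ctx); ok {
        fmt.Println("route RPCs to streaming node", id) // 7
    }
}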

View File

@@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_assignment"
@@ -16,6 +17,7 @@ import (
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_resolver"
"github.com/milvus-io/milvus/internal/streamingnode/client/handler/consumer"
"github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/mocks/proto/mock_streamingpb"
"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_types"
@@ -35,14 +37,19 @@ func TestHandlerClient(t *testing.T) {
service := mock_lazygrpc.NewMockService[streamingpb.StreamingNodeHandlerServiceClient](t)
handlerServiceClient := mock_streamingpb.NewMockStreamingNodeHandlerServiceClient(t)
-handlerServiceClient.EXPECT().GetReplicateCheckpoint(mock.Anything, mock.Anything).Return(&streamingpb.GetReplicateCheckpointResponse{
-Checkpoint: &commonpb.ReplicateCheckpoint{
-ClusterId: "pchannel",
-Pchannel: "pchannel",
-MessageId: nil,
-TimeTick: 0,
-},
-}, nil)
+handlerServiceClient.EXPECT().GetReplicateCheckpoint(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, grcr *streamingpb.GetReplicateCheckpointRequest, co ...grpc.CallOption) (*streamingpb.GetReplicateCheckpointResponse, error) {
+serverID, ok := contextutil.GetPickServerID(ctx)
+assert.True(t, ok)
+assert.Equal(t, serverID, assignment.Node.ServerID)
+return &streamingpb.GetReplicateCheckpointResponse{
+Checkpoint: &commonpb.ReplicateCheckpoint{
+ClusterId: "pchannel",
+Pchannel: "pchannel",
+MessageId: nil,
+TimeTick: 0,
+},
+}, nil
+})
service.EXPECT().GetService(mock.Anything).Return(handlerServiceClient, nil)
rb := mock_resolver.NewMockBuilder(t)
rb.EXPECT().Close().Run(func() {})
@@ -65,6 +72,9 @@
watcher: w,
rebalanceTrigger: rebalanceTrigger,
newProducer: func(ctx context.Context, opts *producer.ProducerOptions, handler streamingpb.StreamingNodeHandlerServiceClient) (Producer, error) {
serverID, ok := contextutil.GetPickServerID(ctx)
assert.True(t, ok)
assert.Equal(t, serverID, assignment.Node.ServerID)
if pK == 0 {
pK++
return nil, status.NewUnmatchedChannelTerm("pchannel", 1, 2)
@@ -72,6 +82,9 @@
return p, nil
},
newConsumer: func(ctx context.Context, opts *consumer.ConsumerOptions, handlerClient streamingpb.StreamingNodeHandlerServiceClient) (Consumer, error) {
serverID, ok := contextutil.GetPickServerID(ctx)
assert.True(t, ok)
assert.Equal(t, serverID, assignment.Node.ServerID)
return c, nil
},
}

View File

@@ -32,7 +32,9 @@ func CreateProducer(
opts *ProducerOptions,
handler streamingpb.StreamingNodeHandlerServiceClient,
) (Producer, error) {
-ctx = createProduceRequest(ctx, opts)
+ctx = contextutil.WithCreateProducer(ctx, &streamingpb.CreateProducerRequest{
+Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel),
+})
streamClient, err := handler.Produce(ctx)
if err != nil {
return nil, err
@@ -80,16 +82,6 @@
return cli, nil
}
-// createProduceRequest creates the produce request.
-func createProduceRequest(ctx context.Context, opts *ProducerOptions) context.Context {
-// select server to consume.
-ctx = contextutil.WithPickServerID(ctx, opts.Assignment.Node.ServerID)
-// select channel to consume.
-return contextutil.WithCreateProducer(ctx, &streamingpb.CreateProducerRequest{
-Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel),
-})
-}
// Expected message sequence:
// CreateProducer
// ProduceRequest 1 -> ProduceResponse Or Error 1