From 6e5189fe1962d9327b71cab05bd2ace87bdb1bc0 Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Wed, 29 Oct 2025 20:34:11 +0800 Subject: [PATCH] fix: make ack of broadcaster non-cancelable by client (#45145) issue: #45141 - make the ack of the broadcaster non-cancelable by rpc. - clone the assignment snapshot in the wal balancer before mutating it. - add server id for GetReplicateCheckpoint to avoid failure. Signed-off-by: chyezh --- .../policy/vchannelfair/expected_layout.go | 12 +++++++ .../vchannelfair/pchannel_count_fair_test.go | 17 ++++++++++ .../vchannelfair/vchannel_fair_policy.go | 2 +- .../server/broadcaster/broadcast_task.go | 2 +- .../server/service/broadcast.go | 2 ++ .../server/service/broadcast_test.go | 33 +++++++++++++++++++ .../client/handler/consumer/consumer_impl.go | 19 +++-------- .../client/handler/handler_client_impl.go | 3 ++ .../client/handler/handler_client_test.go | 29 +++++++++++----- .../client/handler/producer/producer_impl.go | 14 ++------ 10 files changed, 97 insertions(+), 36 deletions(-) diff --git a/internal/streamingcoord/server/balancer/policy/vchannelfair/expected_layout.go b/internal/streamingcoord/server/balancer/policy/vchannelfair/expected_layout.go index e91c86f49c..a90ae8492d 100644 --- a/internal/streamingcoord/server/balancer/policy/vchannelfair/expected_layout.go +++ b/internal/streamingcoord/server/balancer/policy/vchannelfair/expected_layout.go @@ -45,6 +45,18 @@ type assignmentSnapshot struct { GlobalUnbalancedScore float64 } +// Clone will clone the assignment snapshot. +func (s *assignmentSnapshot) Clone() assignmentSnapshot { + assignments := make(map[types.ChannelID]types.PChannelInfoAssigned, len(s.Assignments)) + for channelID, assignment := range s.Assignments { + assignments[channelID] = assignment + } + return assignmentSnapshot{ + Assignments: assignments, + GlobalUnbalancedScore: s.GlobalUnbalancedScore, + } +} + // streamingNodeInfo is the streaming node info for vchannel fair policy. 
type streamingNodeInfo struct { AssignedVChannelCount int diff --git a/internal/streamingcoord/server/balancer/policy/vchannelfair/pchannel_count_fair_test.go b/internal/streamingcoord/server/balancer/policy/vchannelfair/pchannel_count_fair_test.go index c40acb463c..7995d30f72 100644 --- a/internal/streamingcoord/server/balancer/policy/vchannelfair/pchannel_count_fair_test.go +++ b/internal/streamingcoord/server/balancer/policy/vchannelfair/pchannel_count_fair_test.go @@ -256,3 +256,20 @@ func newLayout(channels map[string]int, vchannels map[string]map[string]int64, s } return layout } + +func TestAssignmentClone(t *testing.T) { + snapshot := assignmentSnapshot{ + Assignments: map[types.ChannelID]types.PChannelInfoAssigned{ + newChannelID("c1"): { + Channel: types.PChannelInfo{ + Name: "c1", + }, + }, + }, + } + clonedSnapshot := snapshot.Clone() + clonedSnapshot.Assignments[newChannelID("c2")] = types.PChannelInfoAssigned{} + assert.Len(t, snapshot.Assignments, 1) + assert.Equal(t, snapshot.Assignments[newChannelID("c1")], clonedSnapshot.Assignments[newChannelID("c1")]) + assert.Len(t, clonedSnapshot.Assignments, 2) +} diff --git a/internal/streamingcoord/server/balancer/policy/vchannelfair/vchannel_fair_policy.go b/internal/streamingcoord/server/balancer/policy/vchannelfair/vchannel_fair_policy.go index 5d0af7a528..70ac89c3f7 100644 --- a/internal/streamingcoord/server/balancer/policy/vchannelfair/vchannel_fair_policy.go +++ b/internal/streamingcoord/server/balancer/policy/vchannelfair/vchannel_fair_policy.go @@ -94,7 +94,7 @@ func (p *policy) Balance(currentLayout balancer.CurrentLayout) (layout balancer. // 4. Do a DFS to make a greatest snapshot. // The DFS will find the unbalance score minimized assignment based on current layout. 
- greatestSnapshot := snapshot + greatestSnapshot := snapshot.Clone() p.assignChannels(expectedLayout, reassignChannelIDs, &greatestSnapshot) if greatestSnapshot.GlobalUnbalancedScore < snapshot.GlobalUnbalancedScore-p.cfg.RebalanceTolerance { if p.Logger().Level().Enabled(zap.DebugLevel) { diff --git a/internal/streamingcoord/server/broadcaster/broadcast_task.go b/internal/streamingcoord/server/broadcaster/broadcast_task.go index cf6a0c3c4b..154e926212 100644 --- a/internal/streamingcoord/server/broadcaster/broadcast_task.go +++ b/internal/streamingcoord/server/broadcaster/broadcast_task.go @@ -416,7 +416,7 @@ func (b *broadcastTask) saveTaskIfDirty(ctx context.Context, logger *log.MLogger logger = logger.With(zap.String("state", b.task.State.String()), zap.Int("ackedVChannelCount", ackedCount(b.task))) if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.Header().BroadcastID, b.task); err != nil { logger.Warn("save broadcast task failed", zap.Error(err)) - if ctx.Err() != nil { + if ctx.Err() == nil { panic("critical error: the save broadcast task is failed before the context is done") } return err diff --git a/internal/streamingcoord/server/service/broadcast.go b/internal/streamingcoord/server/service/broadcast.go index 2b75fbd03e..845310b115 100644 --- a/internal/streamingcoord/server/service/broadcast.go +++ b/internal/streamingcoord/server/service/broadcast.go @@ -52,6 +52,8 @@ func (s *broadcastServceImpl) Ack(ctx context.Context, req *streamingpb.Broadcas if err != nil { return nil, err } + // Once the ack is reached at streamingcoord, the ack operation should not be cancelable. + ctx = context.WithoutCancel(ctx) if req.Message == nil { // before 2.6.1, the request don't have the message field, only have the broadcast id and vchannel. // so we need to use the legacy ack interface. 
diff --git a/internal/streamingcoord/server/service/broadcast_test.go b/internal/streamingcoord/server/service/broadcast_test.go index c298cbfca2..13cfea2129 100644 --- a/internal/streamingcoord/server/service/broadcast_test.go +++ b/internal/streamingcoord/server/service/broadcast_test.go @@ -3,7 +3,9 @@ package service import ( "context" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -54,4 +56,35 @@ func TestBroadcastService(t *testing.T) { Properties: map[string]string{"key": "value"}, }, }) + + ctx, cancel := context.WithCancel(context.Background()) + reached := make(chan struct{}) + done := make(chan struct{}) + mb.EXPECT().Ack(mock.Anything, mock.Anything).Unset() + mb.EXPECT().Ack(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, msg message.ImmutableMessage) error { + close(reached) + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } + }) + go func() { + <-reached + cancel() + time.Sleep(10 * time.Millisecond) + close(done) + }() + _, err := service.Ack(ctx, &streamingpb.BroadcastAckRequest{ + BroadcastId: 1, + Vchannel: "v1", + Message: &commonpb.ImmutableMessage{ + Id: walimplstest.NewTestMessageID(1).IntoProto(), + Payload: []byte("payload"), + Properties: map[string]string{"key": "value"}, + }, + }) + + assert.NoError(t, err) } diff --git a/internal/streamingnode/client/handler/consumer/consumer_impl.go b/internal/streamingnode/client/handler/consumer/consumer_impl.go index bc9d76b9db..78a046c882 100644 --- a/internal/streamingnode/client/handler/consumer/consumer_impl.go +++ b/internal/streamingnode/client/handler/consumer/consumer_impl.go @@ -43,15 +43,14 @@ func CreateConsumer( opts *ConsumerOptions, handlerClient streamingpb.StreamingNodeHandlerServiceClient, ) (Consumer, error) { - ctxWithReq, err := createConsumeRequest(ctx, opts) - if err != nil { - return nil, err - } + ctx = 
contextutil.WithCreateConsumer(ctx, &streamingpb.CreateConsumerRequest{ + Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel), + }) // TODO: configurable or auto adjust grpc.MaxCallRecvMsgSize // The messages are always managed by milvus cluster, so the size of message shouldn't be controlled here // to avoid infinitely blocks. - streamClient, err := handlerClient.Consume(ctxWithReq, grpc.MaxCallRecvMsgSize(math.MaxInt32)) + streamClient, err := handlerClient.Consume(ctx, grpc.MaxCallRecvMsgSize(math.MaxInt32)) if err != nil { return nil, err } @@ -84,16 +83,6 @@ func CreateConsumer( return cli, nil } -// createConsumeRequest creates the consume request. -func createConsumeRequest(ctx context.Context, opts *ConsumerOptions) (context.Context, error) { - // select server to consume. - ctx = contextutil.WithPickServerID(ctx, opts.Assignment.Node.ServerID) - // create the consumer request. - return contextutil.WithCreateConsumer(ctx, &streamingpb.CreateConsumerRequest{ - Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel), - }), nil -} - type consumerImpl struct { ctx context.Context // TODO: the cancel method of consumer should be managed by consumerImpl, fix it in future. 
walName string diff --git a/internal/streamingnode/client/handler/handler_client_impl.go b/internal/streamingnode/client/handler/handler_client_impl.go index 4786e834c9..dc51506a62 100644 --- a/internal/streamingnode/client/handler/handler_client_impl.go +++ b/internal/streamingnode/client/handler/handler_client_impl.go @@ -15,6 +15,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" "github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver" "github.com/milvus-io/milvus/internal/util/streamingutil/status" @@ -218,6 +219,8 @@ func (hc *handlerClientImpl) createHandlerAfterStreamingNodeReady(ctx context.Co assign := hc.watcher.Get(ctx, pchannel) if assign != nil { // Find assignment, try to create producer on this assignment. + // pick the target streaming node to serve. 
+ ctx = contextutil.WithPickServerID(ctx, assign.Node.ServerID) createResult, err := create(ctx, assign) if err == nil { logger.Info("create handler success", zap.Any("assignment", assign), zap.Bool("isLocal", registry.IsLocal(createResult))) diff --git a/internal/streamingnode/client/handler/handler_client_test.go b/internal/streamingnode/client/handler/handler_client_test.go index ab8f6fce66..e9d995375c 100644 --- a/internal/streamingnode/client/handler/handler_client_test.go +++ b/internal/streamingnode/client/handler/handler_client_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_assignment" @@ -16,6 +17,7 @@ import ( "github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_resolver" "github.com/milvus-io/milvus/internal/streamingnode/client/handler/consumer" "github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/v2/mocks/proto/mock_streamingpb" "github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_types" @@ -35,14 +37,19 @@ func TestHandlerClient(t *testing.T) { service := mock_lazygrpc.NewMockService[streamingpb.StreamingNodeHandlerServiceClient](t) handlerServiceClient := mock_streamingpb.NewMockStreamingNodeHandlerServiceClient(t) - handlerServiceClient.EXPECT().GetReplicateCheckpoint(mock.Anything, mock.Anything).Return(&streamingpb.GetReplicateCheckpointResponse{ - Checkpoint: &commonpb.ReplicateCheckpoint{ - ClusterId: "pchannel", - Pchannel: "pchannel", - MessageId: nil, - TimeTick: 0, - }, - }, nil) + handlerServiceClient.EXPECT().GetReplicateCheckpoint(mock.Anything, mock.Anything).RunAndReturn(func(ctx 
context.Context, grcr *streamingpb.GetReplicateCheckpointRequest, co ...grpc.CallOption) (*streamingpb.GetReplicateCheckpointResponse, error) { + serverID, ok := contextutil.GetPickServerID(ctx) + assert.True(t, ok) + assert.Equal(t, serverID, assignment.Node.ServerID) + return &streamingpb.GetReplicateCheckpointResponse{ + Checkpoint: &commonpb.ReplicateCheckpoint{ + ClusterId: "pchannel", + Pchannel: "pchannel", + MessageId: nil, + TimeTick: 0, + }, + }, nil + }) service.EXPECT().GetService(mock.Anything).Return(handlerServiceClient, nil) rb := mock_resolver.NewMockBuilder(t) rb.EXPECT().Close().Run(func() {}) @@ -65,6 +72,9 @@ func TestHandlerClient(t *testing.T) { watcher: w, rebalanceTrigger: rebalanceTrigger, newProducer: func(ctx context.Context, opts *producer.ProducerOptions, handler streamingpb.StreamingNodeHandlerServiceClient) (Producer, error) { + serverID, ok := contextutil.GetPickServerID(ctx) + assert.True(t, ok) + assert.Equal(t, serverID, assignment.Node.ServerID) if pK == 0 { pK++ return nil, status.NewUnmatchedChannelTerm("pchannel", 1, 2) @@ -72,6 +82,9 @@ func TestHandlerClient(t *testing.T) { return p, nil }, newConsumer: func(ctx context.Context, opts *consumer.ConsumerOptions, handlerClient streamingpb.StreamingNodeHandlerServiceClient) (Consumer, error) { + serverID, ok := contextutil.GetPickServerID(ctx) + assert.True(t, ok) + assert.Equal(t, serverID, assignment.Node.ServerID) return c, nil }, } diff --git a/internal/streamingnode/client/handler/producer/producer_impl.go b/internal/streamingnode/client/handler/producer/producer_impl.go index fcd6cbe632..6fb79b0bbb 100644 --- a/internal/streamingnode/client/handler/producer/producer_impl.go +++ b/internal/streamingnode/client/handler/producer/producer_impl.go @@ -32,7 +32,9 @@ func CreateProducer( opts *ProducerOptions, handler streamingpb.StreamingNodeHandlerServiceClient, ) (Producer, error) { - ctx = createProduceRequest(ctx, opts) + ctx = contextutil.WithCreateProducer(ctx, 
&streamingpb.CreateProducerRequest{ + Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel), + }) streamClient, err := handler.Produce(ctx) if err != nil { return nil, err @@ -80,16 +82,6 @@ func CreateProducer( return cli, nil } -// createProduceRequest creates the produce request. -func createProduceRequest(ctx context.Context, opts *ProducerOptions) context.Context { - // select server to consume. - ctx = contextutil.WithPickServerID(ctx, opts.Assignment.Node.ServerID) - // select channel to consume. - return contextutil.WithCreateProducer(ctx, &streamingpb.CreateProducerRequest{ - Pchannel: types.NewProtoFromPChannelInfo(opts.Assignment.Channel), - }) -} - // Expected message sequence: // CreateProducer // ProduceRequest 1 -> ProduceResponse Or Error 1