From 446e0b7bf569855fe91ddb9eaf01010fcd2a3373 Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Mon, 24 Nov 2025 20:05:07 +0800 Subject: [PATCH] fix: keep memory state consistent when recovering broadcast task from proto (#45787) issue: #45782 - Proto encoding does not preserve the zero values of repeated and bytes fields: they are dropped or decoded as empty values rather than nil pointers. Fix up the recovery info of the broadcast task after loading it from proto so that the in-memory state stays consistent. Signed-off-by: chyezh --- .../server/broadcaster/broadcast_task.go | 21 +++++ .../server/broadcaster/broadcaster_test.go | 81 +++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/internal/streamingcoord/server/broadcaster/broadcast_task.go b/internal/streamingcoord/server/broadcaster/broadcast_task.go index 44a374a587..d7dee9e699 100644 --- a/internal/streamingcoord/server/broadcaster/broadcast_task.go +++ b/internal/streamingcoord/server/broadcaster/broadcast_task.go @@ -20,6 +20,9 @@ import ( func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask { msg := message.NewBroadcastMutableMessageBeforeAppend(proto.Message.Payload, proto.Message.Properties) m := metrics.NewBroadcastTask(msg.MessageType(), proto.GetState(), msg.BroadcastHeader().ResourceKeys.Collect()) + + fixAckInfoFromProto(proto, len(msg.BroadcastHeader().VChannels)) + bt := &broadcastTask{ mu: sync.Mutex{}, taskMetricsGuard: m, @@ -40,6 +43,24 @@ func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask, metrics *broadc return bt } +// fixAckInfoFromProto fixes the recovery info of the broadcast task. +// Proto encoding does not preserve the zero values of repeated and bytes fields: they are dropped or decoded as empty values rather than nil pointers, +// so the ack info recovered from proto is normalized here to keep the in-memory state consistent. 
+func fixAckInfoFromProto(proto *streamingpb.BroadcastTask, vchannelCount int) {
+	bitmap := make([]byte, vchannelCount)
+	copy(bitmap, proto.AckedVchannelBitmap)
+	// copy, like the bitmap above, truncates to vchannelCount instead of indexing out of range.
+	checkpoints := make([]*streamingpb.AckedCheckpoint, vchannelCount)
+	copy(checkpoints, proto.AckedCheckpoints)
+	for i, cp := range checkpoints {
+		if cp != nil && cp.TimeTick == 0 {
+			checkpoints[i] = nil // a zero TimeTick is treated as "no checkpoint recorded".
+		}
+	}
+	proto.AckedVchannelBitmap = bitmap
+	proto.AckedCheckpoints = checkpoints
+}
+
 // newBroadcastTaskFromBroadcastMessage creates a new broadcast task from the broadcast message.
 func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage, metrics *broadcasterMetrics, ackCallbackScheduler *ackCallbackScheduler) *broadcastTask {
 	m := metrics.NewBroadcastTask(msg.MessageType(), streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, msg.BroadcastHeader().ResourceKeys.Collect())
diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_test.go b/internal/streamingcoord/server/broadcaster/broadcaster_test.go
index 88f5231092..f6d8e12558 100644
--- a/internal/streamingcoord/server/broadcaster/broadcaster_test.go
+++ b/internal/streamingcoord/server/broadcaster/broadcaster_test.go
@@ -9,7 +9,9 @@ import (
 	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
 	"go.uber.org/atomic"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus/internal/distributed/streaming"
@@ -267,3 +269,82 @@ func createNewWaitAckBroadcastTaskFromMessage(
 		AckedCheckpoints:    acks,
 	}
 }
+
+func TestRecoverBroadcastTaskFromProto(t *testing.T) {
+	task := createNewBroadcastTask(8, []string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c1"))
+	b, err := proto.Marshal(task)
+	require.NoError(t, err)
+
+	task = unmarshalTask(t, b, 3)
+	assert.Equal(t, []byte{0x00, 0x00, 0x00}, task.AckedVchannelBitmap)
+	assert.Len(t, task.AckedCheckpoints, 3)
+	assert.Nil(t, task.AckedCheckpoints[0])
+	assert.Nil(t, task.AckedCheckpoints[1])
+	assert.Nil(t, task.AckedCheckpoints[2])
+
+	cp := &streamingpb.AckedCheckpoint{
+		MessageId:              walimplstest.NewTestMessageID(1).IntoProto(),
+		LastConfirmedMessageId: walimplstest.NewTestMessageID(1).IntoProto(),
+		TimeTick:               1,
+	}
+
+	task.AckedCheckpoints[2] = cp
+	task.AckedVchannelBitmap[2] = 0x01
+	b, err = proto.Marshal(task)
+	require.NoError(t, err)
+	task = unmarshalTask(t, b, 3)
+	assert.Equal(t, []byte{0x00, 0x00, 0x01}, task.AckedVchannelBitmap)
+	assert.Len(t, task.AckedCheckpoints, 3)
+	assert.Nil(t, task.AckedCheckpoints[0])
+	assert.Nil(t, task.AckedCheckpoints[1])
+	assert.NotNil(t, task.AckedCheckpoints[2])
+
+	task.AckedCheckpoints[2] = nil
+	task.AckedVchannelBitmap[2] = 0x0
+	task.AckedCheckpoints[0] = cp
+	task.AckedVchannelBitmap[0] = 0x01
+	b, err = proto.Marshal(task)
+	require.NoError(t, err)
+	task = unmarshalTask(t, b, 3)
+	assert.Equal(t, []byte{0x01, 0x00, 0x00}, task.AckedVchannelBitmap)
+	assert.Len(t, task.AckedCheckpoints, 3)
+	assert.NotNil(t, task.AckedCheckpoints[0])
+	assert.Nil(t, task.AckedCheckpoints[1])
+	assert.Nil(t, task.AckedCheckpoints[2])
+
+	task.AckedCheckpoints[0] = nil
+	task.AckedVchannelBitmap[0] = 0x0
+	task.AckedCheckpoints[1] = cp
+	task.AckedVchannelBitmap[1] = 0x01
+	b, err = proto.Marshal(task)
+	require.NoError(t, err)
+	task = unmarshalTask(t, b, 3)
+	assert.Equal(t, []byte{0x00, 0x01, 0x00}, task.AckedVchannelBitmap)
+	assert.Len(t, task.AckedCheckpoints, 3)
+	assert.Nil(t, task.AckedCheckpoints[0])
+	assert.NotNil(t, task.AckedCheckpoints[1])
+	assert.Nil(t, task.AckedCheckpoints[2])
+
+	task.AckedVchannelBitmap = []byte{0x01, 0x01, 0x01}
+	task.AckedCheckpoints = []*streamingpb.AckedCheckpoint{
+		cp,
+		cp,
+		cp,
+	}
+	b, err = proto.Marshal(task)
+	require.NoError(t, err)
+	task = unmarshalTask(t, b, 3)
+	assert.Equal(t, []byte{0x01, 0x01, 0x01}, task.AckedVchannelBitmap)
+	assert.Len(t, task.AckedCheckpoints, 3)
+	assert.NotNil(t, task.AckedCheckpoints[0])
+	assert.NotNil(t, task.AckedCheckpoints[1])
+	assert.NotNil(t, task.AckedCheckpoints[2])
+}
+
+func unmarshalTask(t *testing.T, b []byte, vchannelCount int) *streamingpb.BroadcastTask {
+	task := &streamingpb.BroadcastTask{}
+	err := proto.Unmarshal(b, task)
+	require.NoError(t, err)
+	fixAckInfoFromProto(task, vchannelCount)
+	return task
+}