Zhen Ye 446e0b7bf5
fix: keep memory state consistent when recovering broadcast task from proto (#45787)
issue: #45782

- Zero values of repeated fields and bytes fields in proto are either
ignored or treated as empty values rather than nil pointers, so the
recovery info of a broadcast task restored from proto must be fixed up
to keep the in-memory state consistent.

Signed-off-by: chyezh <chyezh@outlook.com>
2025-11-24 20:05:07 +08:00

351 lines
13 KiB
Go

package broadcaster
import (
"context"
"math/rand"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
internaltypes "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/idalloc"
"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_message"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// TestBroadcaster drives a broadcaster recovered from a mocked catalog through
// recovery, manual acking, concurrent broadcasting and closing, and checks the
// number of appended messages and tombstoned tasks at every stage.
func TestBroadcaster(t *testing.T) {
	registry.ResetRegistration()
	paramtable.Init()
	// Override the tombstone-related settings with small values for the test.
	paramtable.Get().StreamingCfg.WALBroadcasterTombstoneCheckInternal.SwapTempValue("10ms")
	paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxCount.SwapTempValue("2")
	paramtable.Get().StreamingCfg.WALBroadcasterTombstoneMaxLifetime.SwapTempValue("20ms")

	mb := mock_balancer.NewMockBalancer(t)
	mb.EXPECT().ReplicateRole().Return(replicateutil.RolePrimary)
	mb.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
		time.Sleep(100 * time.Second)
		return nil
	})
	balance.Register(mb)
	registry.RegisterDropCollectionV1AckCallback(func(ctx context.Context, msg message.BroadcastResultDropCollectionMessageV1) error {
		return nil
	})

	// The catalog returns seven tasks on recovery. Tasks 8, 9, 3, 4, 5 and 6
	// are PENDING with 1, 2, 3, 2, 1 and 0 vchannels left un-acked in their
	// bitmaps, which adds up to the 9 appends asserted below. Task 7 is
	// REPLICATED with nothing acked.
	meta := mock_metastore.NewMockStreamingCoordCataLog(t)
	meta.EXPECT().ListBroadcastTask(mock.Anything).
		RunAndReturn(func(ctx context.Context) ([]*streamingpb.BroadcastTask, error) {
			return []*streamingpb.BroadcastTask{
				createNewBroadcastTask(8, []string{"v1"}, message.NewCollectionNameResourceKey("c1")),
				createNewBroadcastTask(9, []string{"v1", "v2"}, message.NewCollectionNameResourceKey("c2")),
				createNewBroadcastTask(3, []string{"v1", "v2", "v3"}),
				createNewWaitAckBroadcastTaskFromMessage(
					createNewBroadcastMsg([]string{"v1", "v2", "v3"}).WithBroadcastID(4),
					streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
					[]byte{0x00, 0x01, 0x00}),
				createNewWaitAckBroadcastTaskFromMessage(
					createNewBroadcastMsg([]string{"v1", "v2", "v3"}).WithBroadcastID(5),
					streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
					[]byte{0x01, 0x01, 0x00}),
				createNewWaitAckBroadcastTaskFromMessage(
					createNewBroadcastMsg([]string{"v1", "v2", "v3"}).WithBroadcastID(6), // will be done directly.
					streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
					[]byte{0x01, 0x01, 0x01}),
				createNewWaitAckBroadcastTaskFromMessage(
					createNewBroadcastMsg([]string{"v1", "v2", "v3"},
						message.NewCollectionNameResourceKey("c3"),
						message.NewCollectionNameResourceKey("c4")).WithBroadcastID(7),
					streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED,
					[]byte{0x00, 0x00, 0x00}),
			}, nil
		}).Times(1)

	// done collects the broadcast IDs whose task was saved in TOMBSTONE state.
	done := typeutil.NewConcurrentSet[uint64]()
	meta.EXPECT().SaveBroadcastTask(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, broadcastID uint64, bt *streamingpb.BroadcastTask) error {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		if bt.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_TOMBSTONE {
			done.Insert(broadcastID)
		}
		return nil
	})

	rc := idalloc.NewMockRootCoordClient(t)
	f := syncutil.NewFuture[internaltypes.MixCoordClient]()
	f.Set(rc)
	resource.InitForTest(resource.OptStreamingCatalog(meta), resource.OptMixCoordClient(f))

	fbc := syncutil.NewFuture[Broadcaster]()
	appended := createOpeartor(t, fbc)
	bc, err := RecoverBroadcaster(context.Background())
	fbc.Set(bc)
	// Everything below calls methods on bc, so stop right here when the
	// recovery failed instead of running into a nil dereference later.
	require.NoError(t, err)
	require.NotNil(t, bc)

	assert.Eventually(t, func() bool {
		return appended.Load() == 9 && len(done.Collect()) == 6
	}, 30*time.Second, 10*time.Millisecond)

	// only task 7 is not done.
	ack(t, bc, 7, "v1")
	ack(t, bc, 7, "v1") // acking an already acked vchannel must be idempotent.
	assert.Len(t, done.Collect(), 6)
	ack(t, bc, 7, "v2")
	ack(t, bc, 7, "v2")
	assert.Len(t, done.Collect(), 6)
	ack(t, bc, 7, "v3")
	ack(t, bc, 7, "v3")
	assert.Eventually(t, func() bool {
		return appended.Load() == 9 && len(done.Collect()) == 7
	}, 30*time.Second, 10*time.Millisecond)

	// Test broadcast here.
	// The closure runs in its own goroutine, so it only uses variables local
	// to the closure and non-fatal assertions. Each failure returns early
	// because the following statement would dereference a nil value and
	// panic the whole test binary.
	broadcastWithSameRK := func() {
		b, err := bc.WithResourceKeys(context.Background(), message.NewCollectionNameResourceKey("c7"))
		if !assert.NoError(t, err) {
			return
		}
		result, err := b.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c7")))
		if !assert.NoError(t, err) || !assert.NotNil(t, result) {
			return
		}
		assert.Len(t, result.AppendResults, 3)
	}
	go broadcastWithSameRK()
	go broadcastWithSameRK()
	// Two broadcasts over three vchannels each: 6 more appends, 2 more tasks done.
	assert.Eventually(t, func() bool {
		return appended.Load() == 15 && len(done.Collect()) == 9
	}, 30*time.Second, 10*time.Millisecond)

	// Test close before broadcast
	broadcastAPI, err := bc.WithResourceKeys(context.Background(), message.NewExclusiveClusterResourceKey())
	assert.NoError(t, err)
	broadcastAPI.Close()

	broadcastAPI, err = bc.WithResourceKeys(context.Background(), message.NewExclusiveClusterResourceKey())
	assert.NoError(t, err)
	broadcastAPI.Close()

	// After the broadcaster is closed, Broadcast and Ack are expected to fail.
	bc.Close()
	broadcastAPI, err = bc.WithResourceKeys(context.Background())
	assert.NoError(t, err)
	_, err = broadcastAPI.Broadcast(context.Background(), nil)
	assert.Error(t, err)
	err = bc.Ack(context.Background(), mock_message.NewMockImmutableMessage(t))
	assert.Error(t, err)
}
// ack reports to the broadcaster that the drop-collection broadcast message
// identified by broadcastID has been written on vchannel. A fresh message is
// built and Ack is called again for as long as Ack returns an error.
func ack(t *testing.T, broadcaster Broadcaster, broadcastID uint64, vchannel string) {
	for acked := false; !acked; {
		broadcastMsg := message.NewDropCollectionMessageBuilderV1().
			WithHeader(&message.DropCollectionMessageHeader{}).
			WithBody(&msgpb.DropCollectionRequest{}).
			WithBroadcast([]string{vchannel}).
			MustBuildBroadcast().
			WithBroadcastID(broadcastID)
		// The broadcast targets exactly one vchannel, so the first split
		// message is the one to ack.
		mutableMsg := broadcastMsg.SplitIntoMutableMessage()[0]
		immutableMsg := mutableMsg.
			WithTimeTick(100).
			WithLastConfirmed(walimplstest.NewTestMessageID(1)).
			IntoImmutableMessage(walimplstest.NewTestMessageID(1))
		acked = broadcaster.Ack(context.Background(), immutableMsg) == nil
	}
}
// createOpeartor installs a mocked WAL accesser via streaming.SetWALForTest and
// returns a counter of the messages it appended successfully.
//
// Every message of an AppendMessages call consumes one message id. With a
// chance of 3 in 10 the message is answered with an "append failed" error;
// otherwise it is answered with an append result, the counter is incremented
// and a goroutine acks the message on the broadcaster (taken from the given
// future) after a random delay below 100ms.
func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) *atomic.Int64 {
	nextID := atomic.NewInt64(1)
	appended := atomic.NewInt64(0)
	wal := mock_streaming.NewMockWALAccesser(t)

	// ackLater acks the given broadcast on vchannel in the background.
	ackLater := func(broadcastID uint64, vchannel string) {
		go func() {
			time.Sleep(time.Duration(rand.Int31n(100)) * time.Millisecond)
			ack(t, broadcaster.Get(), broadcastID, vchannel)
		}()
	}

	appendFn := func(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
		responses := make([]types.AppendResponse, len(msgs))
		for i, m := range msgs {
			// The id is allocated before the failure roll, so failed
			// appends consume an id as well.
			msgID := walimplstest.NewTestMessageID(nextID.Inc())
			if rand.Int31n(10) >= 3 {
				responses[i] = types.AppendResponse{
					AppendResult: &types.AppendResult{
						MessageID: msgID,
						TimeTick:  uint64(time.Now().UnixMilli()),
					},
					Error: nil,
				}
				appended.Inc()
				ackLater(m.BroadcastHeader().BroadcastID, m.VChannel())
			} else {
				responses[i] = types.AppendResponse{
					Error: errors.New("append failed"),
				}
			}
		}
		return types.AppendResponses{Responses: responses}
	}

	// One expectation per supported number of messages (1 to 4) in a call.
	operatorExpect := wal.EXPECT()
	operatorExpect.AppendMessages(mock.Anything, mock.Anything).RunAndReturn(appendFn)
	operatorExpect.AppendMessages(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(appendFn)
	operatorExpect.AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(appendFn)
	operatorExpect.AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(appendFn)
	streaming.SetWALForTest(wal)
	return appended
}
// createNewBroadcastMsg builds a drop-collection broadcast message with an
// empty header and body that targets the given vchannels and carries the given
// resource keys. It panics when the message cannot be built.
func createNewBroadcastMsg(vchannels []string, rks ...message.ResourceKey) message.BroadcastMutableMessage {
	builder := message.NewDropCollectionMessageBuilderV1().
		WithHeader(&messagespb.DropCollectionMessageHeader{}).
		WithBody(&msgpb.DropCollectionRequest{}).
		WithBroadcast(vchannels, rks...)

	broadcastMsg, buildErr := builder.BuildBroadcast()
	if buildErr != nil {
		panic(buildErr)
	}
	return broadcastMsg
}
// createNewBroadcastTask returns a PENDING broadcast task proto for a new
// broadcast message with the given broadcast id, vchannels and resource keys.
// The acked bitmap has one zero byte per vchannel and no acked checkpoints are
// set.
func createNewBroadcastTask(broadcastID uint64, vchannels []string, rks ...message.ResourceKey) *streamingpb.BroadcastTask {
	msgProto := createNewBroadcastMsg(vchannels, rks...).
		WithBroadcastID(broadcastID).
		IntoMessageProto()

	task := &streamingpb.BroadcastTask{
		State:               streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
		AckedVchannelBitmap: make([]byte, len(vchannels)),
	}
	task.Message = &messagespb.Message{
		Payload:    msgProto.Payload,
		Properties: msgProto.Properties,
	}
	return task
}
// createNewWaitAckBroadcastTaskFromMessage returns a broadcast task proto for
// msg in the given state with the given acked bitmap. For every non-zero byte
// of the bitmap the checkpoint at the same index is filled with a test message
// id derived from that index and time tick 1; all other checkpoints stay nil.
func createNewWaitAckBroadcastTaskFromMessage(
	msg message.BroadcastMutableMessage,
	state streamingpb.BroadcastTaskState,
	bitmap []byte,
) *streamingpb.BroadcastTask {
	msgProto := msg.IntoMessageProto()

	checkpoints := make([]*streamingpb.AckedCheckpoint, len(bitmap))
	for idx, acked := range bitmap {
		if acked == 0 {
			continue
		}
		checkpoints[idx] = &streamingpb.AckedCheckpoint{
			MessageId:              walimplstest.NewTestMessageID(int64(idx)).IntoProto(),
			LastConfirmedMessageId: walimplstest.NewTestMessageID(int64(idx)).IntoProto(),
			TimeTick:               1,
		}
	}

	task := &streamingpb.BroadcastTask{
		State:               state,
		AckedVchannelBitmap: bitmap,
		AckedCheckpoints:    checkpoints,
	}
	task.Message = &messagespb.Message{
		Payload:    msgProto.Payload,
		Properties: msgProto.Properties,
	}
	return task
}
func TestRecoverBroadcastTaskFromProto(t *testing.T) {
task := createNewBroadcastTask(8, []string{"v1", "v2", "v3"}, message.NewCollectionNameResourceKey("c1"))
b, err := proto.Marshal(task)
require.NoError(t, err)
task = unmarshalTask(t, b, 3)
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x00, 0x00, 0x00})
assert.Len(t, task.AckedCheckpoints, 3)
assert.Nil(t, task.AckedCheckpoints[0])
assert.Nil(t, task.AckedCheckpoints[1])
assert.Nil(t, task.AckedCheckpoints[2])
cp := &streamingpb.AckedCheckpoint{
MessageId: walimplstest.NewTestMessageID(1).IntoProto(),
LastConfirmedMessageId: walimplstest.NewTestMessageID(1).IntoProto(),
TimeTick: 1,
}
task.AckedCheckpoints[2] = cp
task.AckedVchannelBitmap[2] = 0x01
b, err = proto.Marshal(task)
require.NoError(t, err)
task = unmarshalTask(t, b, 3)
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x00, 0x00, 0x01})
assert.Len(t, task.AckedCheckpoints, 3)
assert.Nil(t, task.AckedCheckpoints[0])
assert.Nil(t, task.AckedCheckpoints[1])
assert.NotNil(t, task.AckedCheckpoints[2])
task.AckedCheckpoints[2] = nil
task.AckedVchannelBitmap[2] = 0x0
task.AckedCheckpoints[0] = cp
task.AckedVchannelBitmap[0] = 0x01
b, err = proto.Marshal(task)
require.NoError(t, err)
task = unmarshalTask(t, b, 3)
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x01, 0x00, 0x00})
assert.Len(t, task.AckedCheckpoints, 3)
assert.NotNil(t, task.AckedCheckpoints[0])
assert.Nil(t, task.AckedCheckpoints[1])
assert.Nil(t, task.AckedCheckpoints[2])
task.AckedCheckpoints[0] = nil
task.AckedVchannelBitmap[0] = 0x0
task.AckedCheckpoints[1] = cp
task.AckedVchannelBitmap[1] = 0x01
b, err = proto.Marshal(task)
require.NoError(t, err)
task = unmarshalTask(t, b, 3)
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x00, 0x01, 0x00})
assert.Len(t, task.AckedCheckpoints, 3)
assert.Nil(t, task.AckedCheckpoints[0])
assert.NotNil(t, task.AckedCheckpoints[1])
assert.Nil(t, task.AckedCheckpoints[2])
task.AckedVchannelBitmap = []byte{0x01, 0x01, 0x01}
task.AckedCheckpoints = []*streamingpb.AckedCheckpoint{
cp,
cp,
cp,
}
b, err = proto.Marshal(task)
require.NoError(t, err)
task = unmarshalTask(t, b, 3)
assert.Equal(t, task.AckedVchannelBitmap, []byte{0x01, 0x01, 0x01})
assert.Len(t, task.AckedCheckpoints, 3)
assert.NotNil(t, task.AckedCheckpoints[0])
assert.NotNil(t, task.AckedCheckpoints[1])
assert.NotNil(t, task.AckedCheckpoints[2])
}
// unmarshalTask decodes b into a BroadcastTask and applies fixAckInfoFromProto
// with the given vchannel count to it. It fails the test immediately when b
// cannot be unmarshaled.
func unmarshalTask(t *testing.T, b []byte, vchannelCount int) *streamingpb.BroadcastTask {
	// Attribute a failure to the calling test line, not to this helper.
	t.Helper()

	task := &streamingpb.BroadcastTask{}
	err := proto.Unmarshal(b, task)
	require.NoError(t, err)
	fixAckInfoFromProto(task, vchannelCount)
	return task
}