enhance: make immutable message as the param of ack operation for cdc (#43900)

issue: #43897

- The original broadcast ack operation need to recover message from
etcd, which can not support cdc.
- immutable message will set as the ack parameter to fix it.

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-09-01 10:21:52 +08:00 committed by GitHub
parent 90b4571aee
commit 3327df72e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 1059 additions and 999 deletions

View File

@ -353,14 +353,12 @@ func (s *Server) initDataCoord() error {
// initMessageCallback initializes the message callback.
// TODO: we should build a ddl framework to handle the message ack callback for ddl messages
func (s *Server) initMessageCallback() {
registry.RegisterMessageAckCallback(message.MessageTypeDropPartition, func(ctx context.Context, msg message.MutableMessage) error {
dropPartitionMsg := message.MustAsMutableDropPartitionMessageV1(msg)
return s.NotifyDropPartition(ctx, msg.VChannel(), []int64{dropPartitionMsg.Header().PartitionId})
registry.RegisterDropPartitionMessageV1AckCallback(func(ctx context.Context, msg message.ImmutableDropPartitionMessageV1) error {
return s.NotifyDropPartition(ctx, msg.VChannel(), []int64{msg.Header().PartitionId})
})
registry.RegisterMessageAckCallback(message.MessageTypeImport, func(ctx context.Context, msg message.MutableMessage) error {
importMsg := message.MustAsMutableImportMessageV1(msg)
body := importMsg.MustBody()
registry.RegisterImportMessageV1AckCallback(func(ctx context.Context, msg message.ImmutableImportMessageV1) error {
body := msg.MustBody()
importResp, err := s.ImportV2(ctx, &internalpb.ImportRequestInternal{
CollectionID: body.GetCollectionID(),
CollectionName: body.GetCollectionName(),
@ -390,14 +388,10 @@ func (s *Server) initMessageCallback() {
return nil
})
registry.RegisterMessageCheckCallback(message.MessageTypeImport, func(ctx context.Context, msg message.BroadcastMutableMessage) error {
importMsg := message.MustAsMutableImportMessageV1(msg)
b, err := importMsg.Body()
if err != nil {
return err
}
registry.RegisterImportMessageV1CheckCallback(func(ctx context.Context, msg message.BroadcastImportMessageV1) error {
b := msg.MustBody()
options := funcutil.Map2KeyValuePair(b.GetOptions())
_, err = importutilv2.GetTimeoutTs(options)
_, err := importutilv2.GetTimeoutTs(options)
if err != nil {
return err
}

View File

@ -61,6 +61,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
@ -2880,7 +2881,7 @@ func TestServer_InitMessageCallback(t *testing.T) {
}).
BuildMutable()
assert.NoError(t, err)
err = registry.CallMessageAckCallback(ctx, dropPartitionMsg)
err = registry.CallMessageAckCallback(ctx, dropPartitionMsg.IntoImmutableMessage(rmq.NewRmqID(1)))
assert.Error(t, err) // server not healthy
// Test Import message check callback
@ -2908,6 +2909,6 @@ func TestServer_InitMessageCallback(t *testing.T) {
}).
BuildMutable()
assert.NoError(t, err)
err = registry.CallMessageAckCallback(ctx, importMsg)
err = registry.CallMessageAckCallback(ctx, importMsg.IntoImmutableMessage(rmq.NewRmqID(1)))
assert.Error(t, err) // server not healthy
}

View File

@ -27,7 +27,7 @@ func (b broadcast) Append(ctx context.Context, msg message.BroadcastMutableMessa
return b.streamingCoordClient.Broadcast().Broadcast(ctx, msg)
}
func (b broadcast) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
func (b broadcast) Ack(ctx context.Context, msg message.ImmutableMessage) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
// should be an unreachable error.
return ErrWALAccesserClosed
@ -35,7 +35,5 @@ func (b broadcast) Ack(ctx context.Context, req types.BroadcastAckRequest) error
defer b.lifetime.Done()
// retry until the ctx is canceled.
return retry.Do(ctx, func() error {
return b.streamingCoordClient.Broadcast().Ack(ctx, req)
}, retry.AttemptAlways())
return retry.Do(ctx, func() error { return b.streamingCoordClient.Broadcast().Ack(ctx, msg) }, retry.AttemptAlways())
}

View File

@ -169,7 +169,7 @@ type Broadcast interface {
// Ack acknowledges a broadcast message at the specified vchannel.
// It must be called after the message is comsumed by the unique-consumer.
// It will only return error when the ctx is canceled.
Ack(ctx context.Context, req types.BroadcastAckRequest) error
Ack(ctx context.Context, msg message.ImmutableMessage) error
}
// Txn is the interface for writing transaction into the wal.

View File

@ -109,7 +109,7 @@ func (n *noopBroadcast) Append(ctx context.Context, msg message.BroadcastMutable
}, nil
}
func (n *noopBroadcast) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
func (n *noopBroadcast) Ack(ctx context.Context, msg message.ImmutableMessage) error {
return nil
}

View File

@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
@ -149,7 +150,12 @@ func TestWAL(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, r.AppendResults, 3)
err = w.Broadcast().Ack(ctx, types.BroadcastAckRequest{BroadcastID: 1, VChannel: vChannel1})
err = w.Broadcast().Ack(ctx, message.NewDropCollectionMessageBuilderV1().
WithVChannel(vChannel1).
WithHeader(&message.DropCollectionMessageHeader{}).
WithBody(&msgpb.DropCollectionRequest{}).
MustBuildMutable().
IntoImmutableMessage(rmq.NewRmqID(1)))
assert.NoError(t, err)
cnt := atomic.NewInt32(0)
@ -192,7 +198,12 @@ func TestWAL(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, r)
err = w.Broadcast().Ack(ctx, types.BroadcastAckRequest{BroadcastID: 1, VChannel: vChannel1})
err = w.Broadcast().Ack(ctx, message.NewDropCollectionMessageBuilderV1().
WithVChannel(vChannel1).
WithHeader(&message.DropCollectionMessageHeader{}).
WithBody(&msgpb.DropCollectionRequest{}).
MustBuildMutable().
IntoImmutableMessage(rmq.NewRmqID(1)))
assert.Error(t, err)
}

View File

@ -24,17 +24,17 @@ func (_m *MockBroadcast) EXPECT() *MockBroadcast_Expecter {
return &MockBroadcast_Expecter{mock: &_m.Mock}
}
// Ack provides a mock function with given fields: ctx, req
func (_m *MockBroadcast) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
ret := _m.Called(ctx, req)
// Ack provides a mock function with given fields: ctx, msg
func (_m *MockBroadcast) Ack(ctx context.Context, msg message.ImmutableMessage) error {
ret := _m.Called(ctx, msg)
if len(ret) == 0 {
panic("no return value specified for Ack")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, types.BroadcastAckRequest) error); ok {
r0 = rf(ctx, req)
if rf, ok := ret.Get(0).(func(context.Context, message.ImmutableMessage) error); ok {
r0 = rf(ctx, msg)
} else {
r0 = ret.Error(0)
}
@ -49,14 +49,14 @@ type MockBroadcast_Ack_Call struct {
// Ack is a helper method to define mock.On call
// - ctx context.Context
// - req types.BroadcastAckRequest
func (_e *MockBroadcast_Expecter) Ack(ctx interface{}, req interface{}) *MockBroadcast_Ack_Call {
return &MockBroadcast_Ack_Call{Call: _e.mock.On("Ack", ctx, req)}
// - msg message.ImmutableMessage
func (_e *MockBroadcast_Expecter) Ack(ctx interface{}, msg interface{}) *MockBroadcast_Ack_Call {
return &MockBroadcast_Ack_Call{Call: _e.mock.On("Ack", ctx, msg)}
}
func (_c *MockBroadcast_Ack_Call) Run(run func(ctx context.Context, req types.BroadcastAckRequest)) *MockBroadcast_Ack_Call {
func (_c *MockBroadcast_Ack_Call) Run(run func(ctx context.Context, msg message.ImmutableMessage)) *MockBroadcast_Ack_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(types.BroadcastAckRequest))
run(args[0].(context.Context), args[1].(message.ImmutableMessage))
})
return _c
}
@ -66,7 +66,7 @@ func (_c *MockBroadcast_Ack_Call) Return(_a0 error) *MockBroadcast_Ack_Call {
return _c
}
func (_c *MockBroadcast_Ack_Call) RunAndReturn(run func(context.Context, types.BroadcastAckRequest) error) *MockBroadcast_Ack_Call {
func (_c *MockBroadcast_Ack_Call) RunAndReturn(run func(context.Context, message.ImmutableMessage) error) *MockBroadcast_Ack_Call {
_c.Call.Return(run)
return _c
}

View File

@ -24,17 +24,17 @@ func (_m *MockBroadcastService) EXPECT() *MockBroadcastService_Expecter {
return &MockBroadcastService_Expecter{mock: &_m.Mock}
}
// Ack provides a mock function with given fields: ctx, req
func (_m *MockBroadcastService) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
ret := _m.Called(ctx, req)
// Ack provides a mock function with given fields: ctx, msg
func (_m *MockBroadcastService) Ack(ctx context.Context, msg message.ImmutableMessage) error {
ret := _m.Called(ctx, msg)
if len(ret) == 0 {
panic("no return value specified for Ack")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, types.BroadcastAckRequest) error); ok {
r0 = rf(ctx, req)
if rf, ok := ret.Get(0).(func(context.Context, message.ImmutableMessage) error); ok {
r0 = rf(ctx, msg)
} else {
r0 = ret.Error(0)
}
@ -49,14 +49,14 @@ type MockBroadcastService_Ack_Call struct {
// Ack is a helper method to define mock.On call
// - ctx context.Context
// - req types.BroadcastAckRequest
func (_e *MockBroadcastService_Expecter) Ack(ctx interface{}, req interface{}) *MockBroadcastService_Ack_Call {
return &MockBroadcastService_Ack_Call{Call: _e.mock.On("Ack", ctx, req)}
// - msg message.ImmutableMessage
func (_e *MockBroadcastService_Expecter) Ack(ctx interface{}, msg interface{}) *MockBroadcastService_Ack_Call {
return &MockBroadcastService_Ack_Call{Call: _e.mock.On("Ack", ctx, msg)}
}
func (_c *MockBroadcastService_Ack_Call) Run(run func(ctx context.Context, req types.BroadcastAckRequest)) *MockBroadcastService_Ack_Call {
func (_c *MockBroadcastService_Ack_Call) Run(run func(ctx context.Context, msg message.ImmutableMessage)) *MockBroadcastService_Ack_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(types.BroadcastAckRequest))
run(args[0].(context.Context), args[1].(message.ImmutableMessage))
})
return _c
}
@ -66,7 +66,7 @@ func (_c *MockBroadcastService_Ack_Call) Return(_a0 error) *MockBroadcastService
return _c
}
func (_c *MockBroadcastService_Ack_Call) RunAndReturn(run func(context.Context, types.BroadcastAckRequest) error) *MockBroadcastService_Ack_Call {
func (_c *MockBroadcastService_Ack_Call) RunAndReturn(run func(context.Context, message.ImmutableMessage) error) *MockBroadcastService_Ack_Call {
_c.Call.Return(run)
return _c
}

View File

@ -63,6 +63,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
_ "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/pulsar"
"github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"

View File

@ -54,14 +54,15 @@ func (c *GRPCBroadcastServiceImpl) Broadcast(ctx context.Context, msg message.Br
}, nil
}
func (c *GRPCBroadcastServiceImpl) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
func (c *GRPCBroadcastServiceImpl) Ack(ctx context.Context, msg message.ImmutableMessage) error {
client, err := c.service.GetService(ctx)
if err != nil {
return err
}
_, err = client.Ack(ctx, &streamingpb.BroadcastAckRequest{
BroadcastId: req.BroadcastID,
Vchannel: req.VChannel,
BroadcastId: msg.BroadcastHeader().BroadcastID,
Vchannel: msg.VChannel(),
Message: msg.IntoImmutableMessageProto(),
})
return err
}

View File

@ -15,7 +15,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
)
@ -29,10 +29,9 @@ func TestBroadcast(t *testing.T) {
MustBuildBroadcast()
_, err := bs.Broadcast(context.Background(), msg)
assert.NoError(t, err)
err = bs.Ack(context.Background(), types.BroadcastAckRequest{
VChannel: "v1",
BroadcastID: 1,
})
msg1 := msg.WithBroadcastID(1).SplitIntoMutableMessage()
immutableMsg1 := msg1[0].IntoImmutableMessage(rmq.NewRmqID(1))
err = bs.Ack(context.Background(), immutableMsg1)
assert.NoError(t, err)
}

View File

@ -47,7 +47,7 @@ type BroadcastService interface {
Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)
// Ack sends a broadcast ack to the streaming service.
Ack(ctx context.Context, req types.BroadcastAckRequest) error
Ack(ctx context.Context, msg message.ImmutableMessage) error
}
// Client is the interface of log service client.

View File

@ -5,18 +5,14 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
func TestDial(t *testing.T) {
paramtable.Init()
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
assert.NoError(t, err)
defer etcd.StopEtcdServer()
c, err := etcd.GetEmbedEtcdClient()
assert.NoError(t, err)
c, _ := kvfactory.GetEtcdAndPath()
assert.NotNil(t, c)
client := NewClient(c)

View File

@ -242,7 +242,6 @@ func TestBalancer(t *testing.T) {
func TestBalancer_WithRecoveryLag(t *testing.T) {
paramtable.Init()
etcdClient, _ := kvfactory.GetEtcdAndPath()
channel.ResetStaticPChannelStatsManager()
channel.RecoverPChannelStatsManager([]string{})

View File

@ -8,6 +8,7 @@ import (
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
@ -86,17 +87,40 @@ func (bm *broadcastTaskManager) assignID(ctx context.Context, msg message.Broadc
return msg, nil
}
// Ack acknowledges the message at the specified vchannel.
func (bm *broadcastTaskManager) Ack(ctx context.Context, broadcastID uint64, vchannel string) error {
// LegacyAck is the legacy ack function for the broadcast task.
// It will not be used after upgrading to 2.6.1, only used for compatibility.
func (bm *broadcastTaskManager) LegacyAck(ctx context.Context, broadcastID uint64, vchannel string) error {
task, ok := bm.getBroadcastTaskByID(broadcastID)
if !ok {
bm.Logger().Warn("broadcast task not found, it may already acked, ignore the request", zap.Uint64("broadcastID", broadcastID), zap.String("vchannel", vchannel))
return nil
}
if err := task.Ack(ctx, vchannel); err != nil {
msg := task.GetImmutableMessageFromVChannel(vchannel)
if msg == nil {
task.Logger().Warn("vchannel is already acked, ignore the ack request", zap.String("vchannel", vchannel))
return nil
}
return bm.Ack(ctx, msg)
}
// Ack acknowledges the message at the specified vchannel.
func (bm *broadcastTaskManager) Ack(ctx context.Context, msg message.ImmutableMessage) error {
if err := registry.CallMessageAckCallback(ctx, msg); err != nil {
bm.Logger().Warn("message ack callback failed", log.FieldMessage(msg), zap.Error(err))
return err
}
bm.Logger().Warn("message ack callback success", log.FieldMessage(msg))
broadcastID := msg.BroadcastHeader().BroadcastID
vchannel := msg.VChannel()
task, ok := bm.getBroadcastTaskByID(broadcastID)
if !ok {
bm.Logger().Warn("broadcast task not found, it may already acked, ignore the request", zap.Uint64("broadcastID", broadcastID), zap.String("vchannel", vchannel))
return nil
}
if err := task.Ack(ctx, msg); err != nil {
return err
}
if task.State() == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE {
bm.removeBroadcastTask(broadcastID)
}
@ -131,7 +155,6 @@ func (bm *broadcastTaskManager) addBroadcastTask(ctx context.Context, msg messag
}
bm.tasks[header.BroadcastID] = newIncomingTask
bm.cond.L.Unlock()
// TODO: perform a task checker here to make sure the task is vaild to be broadcasted in future.
return newIncomingTask, nil
}

View File

@ -8,7 +8,6 @@ import (
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
@ -116,35 +115,28 @@ func (b *broadcastTask) InitializeRecovery(ctx context.Context) error {
return nil
}
// getImmutableMessageFromVChannel gets the immutable message from the vchannel.
// GetImmutableMessageFromVChannel gets the immutable message from the vchannel.
// If the vchannel is already acked, it returns nil.
func (b *broadcastTask) getImmutableMessageFromVChannel(vchannel string) message.MutableMessage {
msgs := b.PendingBroadcastMessages()
func (b *broadcastTask) GetImmutableMessageFromVChannel(vchannel string) message.ImmutableMessage {
b.mu.Lock()
defer b.mu.Unlock()
msg := message.NewBroadcastMutableMessageBeforeAppend(b.task.Message.Payload, b.task.Message.Properties)
msgs := msg.SplitIntoMutableMessage()
for _, msg := range msgs {
if msg.VChannel() == vchannel {
return msg
// The legacy message don't have timetick, so we need to set it to 0.
return msg.WithTimeTick(0).IntoImmutableMessage(nil)
}
}
return nil
}
// Ack acknowledges the message at the specified vchannel.
func (b *broadcastTask) Ack(ctx context.Context, vchannel string) error {
// TODO: after all status is recovered from wal, we need make a async framework to handle the callback asynchronously.
msg := b.getImmutableMessageFromVChannel(vchannel)
if msg == nil {
b.Logger().Warn("vchannel is already acked, ignore the ack request", zap.String("vchannel", vchannel))
return nil
}
if err := registry.CallMessageAckCallback(ctx, msg); err != nil {
b.Logger().Warn("message ack callback failed", log.FieldMessage(msg), zap.Error(err))
return err
}
b.Logger().Warn("message ack callback success", log.FieldMessage(msg))
func (b *broadcastTask) Ack(ctx context.Context, msg message.ImmutableMessage) error {
b.mu.Lock()
defer b.mu.Unlock()
task, ok := b.copyAndSetVChannelAcked(vchannel)
task, ok := b.copyAndSetVChannelAcked(msg.VChannel())
if !ok {
return nil
}
@ -152,7 +144,7 @@ func (b *broadcastTask) Ack(ctx context.Context, vchannel string) error {
// We should always save the task after acked.
// Even if the task mark as done in memory.
// Because the task is set as done in memory before save the recovery info.
if err := b.saveTask(ctx, task, b.Logger().With(zap.String("ackVChannel", vchannel))); err != nil {
if err := b.saveTask(ctx, task, b.Logger().With(zap.String("ackVChannel", msg.VChannel()))); err != nil {
return err
}
b.task = task
@ -218,6 +210,8 @@ func (b *broadcastTask) BroadcastDone(ctx context.Context) error {
}
// copyAndMarkBroadcastDone copies the task and mark the broadcast task as done.
// !!! The ack state of the task should not be removed, because the task is a lock-hint of resource key held by a broadcast operation.
// It can be removed only after the broadcast message is acked by all the vchannels.
func (b *broadcastTask) copyAndMarkBroadcastDone() *streamingpb.BroadcastTask {
task := proto.Clone(b.task).(*streamingpb.BroadcastTask)
if isAllDone(task) {

View File

@ -11,8 +11,11 @@ type Broadcaster interface {
// Broadcast broadcasts the message to all channels.
Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)
// LegacyAck is the legacy ack interface for the 2.6.0 import message.
LegacyAck(ctx context.Context, broadcastID uint64, vchannel string) error
// Ack acknowledges the message at the specified vchannel.
Ack(ctx context.Context, req types.BroadcastAckRequest) error
Ack(ctx context.Context, msg message.ImmutableMessage) error
// Close closes the broadcaster.
Close()

View File

@ -102,14 +102,23 @@ func (b *broadcasterImpl) Broadcast(ctx context.Context, msg message.BroadcastMu
return r, nil
}
// Ack acknowledges the message at the specified vchannel.
func (b *broadcasterImpl) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
func (b *broadcasterImpl) LegacyAck(ctx context.Context, broadcastID uint64, vchannel string) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("broadcaster is closing")
}
defer b.lifetime.Done()
return b.manager.Ack(ctx, req.BroadcastID, req.VChannel)
return b.manager.LegacyAck(ctx, broadcastID, vchannel)
}
// Ack acknowledges the message at the specified vchannel.
func (b *broadcasterImpl) Ack(ctx context.Context, msg message.ImmutableMessage) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("broadcaster is closing")
}
defer b.lifetime.Done()
return b.manager.Ack(ctx, msg)
}
func (b *broadcasterImpl) Close() {

View File

@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
internaltypes "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/idalloc"
"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_message"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
@ -87,14 +88,14 @@ func TestBroadcaster(t *testing.T) {
}, 30*time.Second, 10*time.Millisecond)
// only task 7 is not done.
ack(bc, 7, "v1")
ack(bc, 7, "v1") // test already acked, make the idempotent.
ack(t, bc, 7, "v1")
ack(t, bc, 7, "v1") // test already acked, make the idempotent.
assert.Equal(t, len(done.Collect()), 6)
ack(bc, 7, "v2")
ack(bc, 7, "v2")
ack(t, bc, 7, "v2")
ack(t, bc, 7, "v2")
assert.Equal(t, len(done.Collect()), 6)
ack(bc, 7, "v3")
ack(bc, 7, "v3")
ack(t, bc, 7, "v3")
ack(t, bc, 7, "v3")
assert.Eventually(t, func() bool {
return appended.Load() == 9 && len(done.Collect()) == 7
}, 30*time.Second, 10*time.Millisecond)
@ -121,16 +122,20 @@ func TestBroadcaster(t *testing.T) {
bc.Close()
_, err = bc.Broadcast(context.Background(), nil)
assert.Error(t, err)
err = bc.Ack(context.Background(), types.BroadcastAckRequest{})
err = bc.Ack(context.Background(), mock_message.NewMockImmutableMessage(t))
assert.Error(t, err)
}
func ack(broadcaster Broadcaster, broadcastID uint64, vchannel string) {
func ack(t *testing.T, broadcaster Broadcaster, broadcastID uint64, vchannel string) {
for {
if err := broadcaster.Ack(context.Background(), types.BroadcastAckRequest{
msg := mock_message.NewMockImmutableMessage(t)
msg.EXPECT().VChannel().Return(vchannel)
msg.EXPECT().MessageTypeWithVersion().Return(message.MessageTypeTimeTickV1)
msg.EXPECT().BroadcastHeader().Return(&message.BroadcastHeader{
BroadcastID: broadcastID,
VChannel: vchannel,
}); err == nil {
})
msg.EXPECT().MarshalLogObject(mock.Anything).Return(nil).Maybe()
if err := broadcaster.Ack(context.Background(), msg); err == nil {
break
}
}
@ -165,7 +170,7 @@ func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) *at
vchannel := msg.VChannel()
go func() {
time.Sleep(time.Duration(rand.Int31n(100)) * time.Millisecond)
ack(broadcaster.Get(), broadcastID, vchannel)
ack(t, broadcaster.Get(), broadcastID, vchannel)
}()
}
return resps

View File

@ -5,33 +5,24 @@ import (
"fmt"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// init the message ack callbacks
func init() {
resetMessageAckCallbacks()
resetMessageCheckCallbacks()
}
// resetMessageAckCallbacks resets the message ack callbacks.
func resetMessageAckCallbacks() {
messageAckCallbacks = map[message.MessageType]*syncutil.Future[MessageAckCallback]{
message.MessageTypeDropPartition: syncutil.NewFuture[MessageAckCallback](),
message.MessageTypeImport: syncutil.NewFuture[MessageAckCallback](),
}
}
// MessageAckCallback is the callback function for the message type.
type MessageAckCallback = func(ctx context.Context, msg message.MutableMessage) error
type (
MessageAckCallback[H proto.Message, B proto.Message] = func(ctx context.Context, params message.SpecializedImmutableMessage[H, B]) error
messageInnerAckCallback = func(ctx context.Context, msgs message.ImmutableMessage) error
)
// messageAckCallbacks is the map of message type to the callback function.
var messageAckCallbacks map[message.MessageType]*syncutil.Future[MessageAckCallback]
var messageAckCallbacks map[message.MessageTypeWithVersion]*syncutil.Future[messageInnerAckCallback]
// RegisterMessageAckCallback registers the callback function for the message type.
func RegisterMessageAckCallback(typ message.MessageType, callback MessageAckCallback) {
// registerMessageAckCallback registers the callback function for the message type.
func registerMessageAckCallback[H proto.Message, B proto.Message](callback MessageAckCallback[H, B]) {
typ := message.MustGetMessageTypeWithVersion[H, B]()
future, ok := messageAckCallbacks[typ]
if !ok {
panic(fmt.Sprintf("the future of message callback for type %s is not registered", typ))
@ -40,12 +31,15 @@ func RegisterMessageAckCallback(typ message.MessageType, callback MessageAckCall
// only for test, the register callback should be called once and only once
return
}
future.Set(callback)
future.Set(func(ctx context.Context, msgs message.ImmutableMessage) error {
specializedMsg := message.MustAsSpecializedImmutableMessage[H, B](msgs)
return callback(ctx, specializedMsg)
})
}
// CallMessageAckCallback calls the callback function for the message type.
func CallMessageAckCallback(ctx context.Context, msg message.MutableMessage) error {
callbackFuture, ok := messageAckCallbacks[msg.MessageType()]
func CallMessageAckCallback(ctx context.Context, msg message.ImmutableMessage) error {
callbackFuture, ok := messageAckCallbacks[msg.MessageTypeWithVersion()]
if !ok {
// No callback need tobe called, return nil
return nil

View File

@ -8,8 +8,8 @@ import (
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
)
func TestMessageCallbackRegistration(t *testing.T) {
@ -18,22 +18,26 @@ func TestMessageCallbackRegistration(t *testing.T) {
// Test registering a callback
called := false
callback := func(ctx context.Context, msg message.MutableMessage) error {
callback := func(ctx context.Context, msg message.ImmutableDropPartitionMessageV1) error {
called = true
return nil
}
// Register callback for DropPartition message type
RegisterMessageAckCallback(message.MessageTypeDropPartition, callback)
RegisterDropPartitionMessageV1AckCallback(callback)
// Verify callback was registered
callbackFuture, ok := messageAckCallbacks[message.MessageTypeDropPartition]
callbackFuture, ok := messageAckCallbacks[message.MessageTypeDropPartitionV1]
assert.True(t, ok)
assert.NotNil(t, callbackFuture)
// Create a mock message
msg := mock_message.NewMockMutableMessage(t)
msg.EXPECT().MessageType().Return(message.MessageTypeDropPartition)
msg := message.NewDropPartitionMessageBuilderV1().
WithHeader(&message.DropPartitionMessageHeader{}).
WithBody(&message.DropPartitionRequest{}).
WithVChannel("v1").
MustBuildMutable().
WithTimeTick(1).
IntoImmutableMessage(rmq.NewRmqID(1))
// Call the callback
err := CallMessageAckCallback(context.Background(), msg)

View File

@ -5,26 +5,24 @@ import (
"fmt"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// MessageCheckCallback is the callback function for the message type.
type MessageCheckCallback = func(ctx context.Context, msg message.BroadcastMutableMessage) error
// resetMessageCheckCallbacks resets the message check callbacks.
func resetMessageCheckCallbacks() {
messageCheckCallbacks = map[message.MessageType]*syncutil.Future[MessageCheckCallback]{
message.MessageTypeImport: syncutil.NewFuture[MessageCheckCallback](),
}
}
type (
// MessageCheckCallback is the callback function for the message type.
MessageCheckCallback[H proto.Message, B proto.Message] = func(ctx context.Context, msg message.SpecializedBroadcastMessage[H, B]) error
messageInnerCheckCallback = func(ctx context.Context, msg message.BroadcastMutableMessage) error
)
// messageCheckCallbacks is the map of message type to the callback function.
var messageCheckCallbacks map[message.MessageType]*syncutil.Future[MessageCheckCallback]
var messageCheckCallbacks map[message.MessageTypeWithVersion]*syncutil.Future[messageInnerCheckCallback]
// RegisterMessageCheckCallback registers the callback function for the message type.
func RegisterMessageCheckCallback(typ message.MessageType, callback MessageCheckCallback) {
// registerMessageCheckCallback registers the callback function for the message type.
func registerMessageCheckCallback[H proto.Message, B proto.Message](callback MessageCheckCallback[H, B]) {
typ := message.MustGetMessageTypeWithVersion[H, B]()
future, ok := messageCheckCallbacks[typ]
if !ok {
panic(fmt.Sprintf("the future of check message callback for type %s is not registered", typ))
@ -33,12 +31,15 @@ func RegisterMessageCheckCallback(typ message.MessageType, callback MessageCheck
// only for test, the register callback should be called once and only once
return
}
future.Set(callback)
future.Set(func(ctx context.Context, msg message.BroadcastMutableMessage) error {
specializedMsg := message.MustAsSpecializedBroadcastMessage[H, B](msg)
return callback(ctx, specializedMsg)
})
}
// CallMessageCheckCallback calls the callback function for the message type.
func CallMessageCheckCallback(ctx context.Context, msg message.BroadcastMutableMessage) error {
callbackFuture, ok := messageCheckCallbacks[msg.MessageType()]
callbackFuture, ok := messageCheckCallbacks[msg.MessageTypeWithVersion()]
if !ok {
// No callback need tobe called, return nil
return nil

View File

@ -8,7 +8,6 @@ import (
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
@ -18,22 +17,24 @@ func TestCheckMessageCallbackRegistration(t *testing.T) {
// Test registering a callback
called := false
callback := func(ctx context.Context, msg message.BroadcastMutableMessage) error {
callback := func(ctx context.Context, msg message.BroadcastImportMessageV1) error {
called = true
return nil
}
// Register callback for DropPartition message type
RegisterMessageCheckCallback(message.MessageTypeImport, callback)
RegisterImportMessageV1CheckCallback(callback)
// Verify callback was registered
callbackFuture, ok := messageCheckCallbacks[message.MessageTypeImport]
callbackFuture, ok := messageCheckCallbacks[message.MessageTypeImportV1]
assert.True(t, ok)
assert.NotNil(t, callbackFuture)
// Create a mock message
msg := mock_message.NewMockBroadcastMutableMessage(t)
msg.EXPECT().MessageType().Return(message.MessageTypeImport)
msg := message.NewImportMessageBuilderV1().
WithHeader(&message.ImportMessageHeader{}).
WithBody(&message.ImportMsg{}).
WithBroadcast([]string{"v1"}).MustBuildBroadcast()
// Call the callback
err := CallMessageCheckCallback(context.Background(), msg)

View File

@ -0,0 +1,35 @@
package registry
import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// init the message ack callbacks
func init() {
resetMessageAckCallbacks()
resetMessageCheckCallbacks()
}
// resetMessageCheckCallbacks resets the message check callbacks.
func resetMessageCheckCallbacks() {
messageCheckCallbacks = map[message.MessageTypeWithVersion]*syncutil.Future[messageInnerCheckCallback]{
message.MessageTypeImportV1: syncutil.NewFuture[messageInnerCheckCallback](),
}
}
var (
RegisterDropPartitionMessageV1AckCallback = registerMessageAckCallback[*message.DropPartitionMessageHeader, *msgpb.DropPartitionRequest]
RegisterImportMessageV1AckCallback = registerMessageAckCallback[*message.ImportMessageHeader, *msgpb.ImportMsg]
)
// resetMessageAckCallbacks resets the message ack callbacks.
func resetMessageAckCallbacks() {
messageAckCallbacks = map[message.MessageTypeWithVersion]*syncutil.Future[messageInnerAckCallback]{
message.MessageTypeDropPartitionV1: syncutil.NewFuture[messageInnerAckCallback](),
message.MessageTypeImportV1: syncutil.NewFuture[messageInnerAckCallback](),
}
}
var RegisterImportMessageV1CheckCallback = registerMessageCheckCallback[*message.ImportMessageHeader, *msgpb.ImportMsg]

View File

@ -4,9 +4,9 @@ import (
"context"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
@ -19,12 +19,14 @@ type BroadcastService interface {
func NewBroadcastService(bc *syncutil.Future[broadcaster.Broadcaster]) BroadcastService {
return &broadcastServceImpl{
broadcaster: bc,
walName: util.MustSelectWALName(),
}
}
// broadcastServiceeeeImpl is the implementation of the broadcast service.
type broadcastServceImpl struct {
broadcaster *syncutil.Future[broadcaster.Broadcaster]
walName string
}
// Broadcast broadcasts the message to all channels.
@ -53,10 +55,15 @@ func (s *broadcastServceImpl) Ack(ctx context.Context, req *streamingpb.Broadcas
if err != nil {
return nil, err
}
if err := broadcaster.Ack(ctx, types.BroadcastAckRequest{
BroadcastID: req.BroadcastId,
VChannel: req.Vchannel,
}); err != nil {
if req.Message == nil {
// before 2.6.1, the request don't have the message field, only have the broadcast id and vchannel.
// so we need to use the legacy ack interface.
if err := broadcaster.LegacyAck(ctx, req.BroadcastId, req.Vchannel); err != nil {
return nil, err
}
return &streamingpb.BroadcastAckResponse{}, nil
}
if err := broadcaster.Ack(ctx, message.NewImmutableMessageFromProto(s.walName, req.Message)); err != nil {
return nil, err
}
return &streamingpb.BroadcastAckResponse{}, nil

View File

@ -11,7 +11,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/v2/mocks/proto/mock_streamingpb"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
@ -182,17 +181,11 @@ func newMockedConsumerImpl(t *testing.T, ctx context.Context, h message.Handler)
func newConsumeResponse(id message.MessageID, msg message.MutableMessage) *streamingpb.ConsumeResponse {
msg.WithTimeTick(tsoutil.GetCurrentTime())
msg.WithLastConfirmed(walimplstest.NewTestMessageID(0))
pb := msg.IntoMessageProto()
immutableMsg := msg.IntoImmutableMessage(id)
return &streamingpb.ConsumeResponse{
Response: &streamingpb.ConsumeResponse_Consume{
Consume: &streamingpb.ConsumeMessageReponse{
Message: &messagespb.ImmutableMessage{
Id: &messagespb.MessageID{
Id: id.Marshal(),
},
Payload: pb.Payload,
Properties: pb.Properties,
},
Message: immutableMsg.IntoImmutableMessageProto(),
},
},
}

View File

@ -13,6 +13,7 @@ import (
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_lazygrpc"
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_resolver"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
@ -20,7 +21,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/mocks/proto/mock_streamingpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -169,11 +169,7 @@ func newVersionedState(version int64, serverIDs map[uint64]bool) discoverer.Vers
func TestDial(t *testing.T) {
paramtable.Init()
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
assert.NoError(t, err)
defer etcd.StopEtcdServer()
c, err := etcd.GetEmbedEtcdClient()
assert.NoError(t, err)
c, _ := kvfactory.GetEtcdAndPath()
assert.NotNil(t, c)
client := NewManagerClient(c)

View File

@ -12,7 +12,6 @@ import (
"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
@ -187,17 +186,10 @@ func (c *ConsumeServer) sendImmutableMessage(msg message.ImmutableMessage) (err
metricsGuard.Finish(err)
}()
pb := msg.IntoMessageProto()
// Send Consumed message to client and do metrics.
if err := c.consumeServer.SendConsumeMessage(&streamingpb.ConsumeMessageReponse{
ConsumerId: c.consumerID,
Message: &messagespb.ImmutableMessage{
Id: &messagespb.MessageID{
Id: msg.MessageID().Marshal(),
},
Payload: pb.Payload,
Properties: pb.Properties,
},
Message: msg.IntoImmutableMessageProto(),
}); err != nil {
return status.NewInner("send consume message failed: %s", err.Error())
}

View File

@ -24,7 +24,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
@ -189,9 +188,8 @@ func TestConsumerServeSendArm(t *testing.T) {
// test send.
msg := mock_message.NewMockImmutableMessage(t)
msg.EXPECT().MessageID().Return(walimplstest.NewTestMessageID(1))
msg.EXPECT().EstimateSize().Return(0)
msg.EXPECT().IntoMessageProto().Return(&messagespb.Message{})
msg.EXPECT().IntoImmutableMessageProto().Return(&messagespb.ImmutableMessage{})
scanCh <- msg
// test send txn message.

View File

@ -140,10 +140,7 @@ func (r *recoveryStorageImpl) GetSchema(ctx context.Context, vchannel string, ti
// ObserveMessage is called when a new message is observed.
func (r *recoveryStorageImpl) ObserveMessage(ctx context.Context, msg message.ImmutableMessage) error {
if h := msg.BroadcastHeader(); h != nil {
if err := streaming.WAL().Broadcast().Ack(ctx, types.BroadcastAckRequest{
BroadcastID: h.BroadcastID,
VChannel: msg.VChannel(),
}); err != nil {
if err := streaming.WAL().Broadcast().Ack(ctx, msg); err != nil {
r.Logger().Warn("failed to ack broadcast message", zap.Error(err))
return err
}

View File

@ -11,19 +11,14 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/json"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestSessionDiscoverer(t *testing.T) {
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
assert.NoError(t, err)
defer etcd.StopEtcdServer()
etcdClient, err := etcd.GetEmbedEtcdClient()
assert.NoError(t, err)
etcdClient, _ := kvfactory.GetEtcdAndPath()
targetVersion := "0.1.0"
d := NewSessionDiscoverer(etcdClient, "session/", false, ">="+targetVersion)
@ -58,7 +53,7 @@ func TestSessionDiscoverer(t *testing.T) {
var lastVersion typeutil.Version = typeutil.VersionInt64(-1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err = d.Discover(ctx, func(state VersionedState) error {
err := d.Discover(ctx, func(state VersionedState) error {
sessions := state.Sessions()
expectedSessions := make(map[int64]*sessionutil.SessionRaw, len(expected[idx]))

View File

@ -129,8 +129,9 @@ message BroadcastResponse {
}
message BroadcastAckRequest {
uint64 broadcast_id = 1; // broadcast id.
string vchannel = 2; // the vchannel that acked the message.
uint64 broadcast_id = 1 [deprecated = true]; // broadcast id.
string vchannel = 2 [deprecated = true]; // the vchannel that acked the message.
messages.ImmutableMessage message = 3; // the message that to be acked.
}
message BroadcastAckResponse {

File diff suppressed because it is too large Load Diff

View File

@ -9,12 +9,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
type BroadcastAckRequest struct {
// BroadcastID is the broadcast id of the ack request.
BroadcastID uint64
VChannel string
}
// BroadcastAppendResult is the result of broadcast append operation.
type BroadcastAppendResult struct {
BroadcastID uint64 // the broadcast id of the append operation.