enhance: refactor update replicate config operation using wal-broadcast-based DDL/DCL framework (#44560)

issue: #43897

- The UpdateReplicateConfig operation now broadcasts an AlterReplicateConfig
message to all pchannels while holding the cluster-exclusive lock (see the
sketch below).
- The begin txn message now uses the commit message's timetick (to avoid a
timetick rollback when CDC replicates txn messages).
- If the current cluster is secondary, UpdateReplicateConfig waits until the
replicate configuration is consistent with the config replicated from the
primary.
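
A minimal sketch of the new control flow, abbreviated from the assignment service changes in this diff (the "same configuration" fast path and the real request/response types are omitted):

	// Sketch (simplified): a primary cluster broadcasts the config under the cluster-exclusive
	// lock; a secondary cluster waits for CDC to replicate the same config from the primary.
	func updateReplicateConfigFlow(ctx context.Context, s *assignmentServiceImpl, config *commonpb.ReplicateConfiguration) error {
		api, err := broadcast.StartBroadcastWithResourceKeys(ctx, message.NewExclusiveClusterResourceKey())
		if errors.Is(err, broadcast.ErrNotPrimary) {
			// Not primary: wait until the configuration replicated from the primary matches the request.
			return s.waitUntilPrimaryChangeOrConfigurationSame(ctx, config)
		}
		if err != nil {
			return err
		}
		defer api.Close() // release the cluster-exclusive lock, per the Broadcaster contract
		msg, err := s.validateReplicateConfiguration(ctx, config)
		if err != nil {
			return err
		}
		// Broadcast the AlterReplicateConfig message to all pchannels while the lock is held.
		_, err = api.Broadcast(ctx, msg)
		return err
	}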

---------

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye 2025-10-15 15:26:01 +08:00 committed by GitHub
parent 822588302a
commit 8bf7d6ae72
22 changed files with 692 additions and 476 deletions

View File

@ -6,16 +6,20 @@ import (
"testing"
"time"
"github.com/apache/pulsar-client-go/pulsar"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
pulsar2 "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/pulsar"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
@ -67,8 +71,8 @@ func TestReplicate(t *testing.T) {
},
CrossClusterTopology: []*commonpb.CrossClusterTopology{
{
SourceClusterId: "by-dev",
TargetClusterId: "primary",
SourceClusterId: "primary",
TargetClusterId: "by-dev",
},
},
})
@ -82,6 +86,53 @@ func TestReplicate(t *testing.T) {
t.Logf("cfg: %+v\n", cfg)
}
func TestReplicateCreateCollection(t *testing.T) {
t.Skip("cat not running without streaming service at background")
streaming.Init()
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "ID", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "Vector", DataType: schemapb.DataType_FloatVector},
},
}
schemaBytes, err := proto.Marshal(schema)
if err != nil {
panic(err)
}
msg := message.NewCreateCollectionMessageBuilderV1().
WithHeader(&message.CreateCollectionMessageHeader{
CollectionId: 1,
PartitionIds: []int64{2},
}).
WithBody(&msgpb.CreateCollectionRequest{
CollectionID: 1,
CollectionName: collectionName,
PartitionName: "partition",
PhysicalChannelNames: []string{
"primary-rootcoord-dml_0",
"primary-rootcoord-dml_1",
},
VirtualChannelNames: []string{
"primary-rootcoord-dml_0_1v0",
"primary-rootcoord-dml_1_1v1",
},
Schema: schemaBytes,
}).
WithBroadcast([]string{"primary-rootcoord-dml_0_1v0", "primary-rootcoord-dml_1_1v1"}).
MustBuildBroadcast()
msgs := msg.WithBroadcastID(100).SplitIntoMutableMessage()
for _, msg := range msgs {
immutableMsg := msg.WithLastConfirmedUseMessageID().WithTimeTick(1).IntoImmutableMessage(pulsar2.NewPulsarID(
pulsar.NewMessageID(1, 2, 3, 4),
))
_, err := streaming.WAL().Replicate().Append(context.Background(), message.NewReplicateMessage("primary", immutableMsg.IntoImmutableMessageProto()))
if err != nil {
panic(err)
}
}
}
func TestStreamingBroadcast(t *testing.T) {
t.Skip("cat not running without streaming service at background")
streamingutil.SetStreamingServiceEnabled()

View File

@ -221,6 +221,8 @@ type ReplicationCatalog interface {
}
// StreamingCoordCataLog is the interface for streamingcoord catalog
// All write operation of catalog is reliable, the error will only be returned if the ctx is canceled,
// otherwise it will retry until success.
type StreamingCoordCataLog interface {
ReplicationCatalog
@ -228,12 +230,14 @@ type StreamingCoordCataLog interface {
GetCChannel(ctx context.Context) (*streamingpb.CChannelMeta, error)
// SaveCChannel save the control channel to metastore.
// Only return error if the ctx is canceled, otherwise it will retry until success.
SaveCChannel(ctx context.Context, info *streamingpb.CChannelMeta) error
// GetVersion get the streaming version from metastore.
GetVersion(ctx context.Context) (*streamingpb.StreamingVersion, error)
// SaveVersion save the streaming version to metastore.
// Only return error if the ctx is canceled, otherwise it will retry until success.
SaveVersion(ctx context.Context, version *streamingpb.StreamingVersion) error
// physical channel watch related
@ -242,6 +246,7 @@ type StreamingCoordCataLog interface {
ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error)
// SavePChannel save a pchannel info to metastore.
// Only return error if the ctx is canceled, otherwise it will retry until success.
SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error
// ListBroadcastTask list all broadcast tasks.
@ -251,9 +256,11 @@ type StreamingCoordCataLog interface {
// SaveBroadcastTask save the broadcast task to metastore.
// Make the task recoverable after restart.
// When broadcast task is done, it will be removed from metastore.
// Only return error if the ctx is canceled, otherwise it will retry until success.
SaveBroadcastTask(ctx context.Context, broadcastID uint64, task *streamingpb.BroadcastTask) error
// SaveReplicateConfiguration saves the replicate configuration to metastore.
// Only return error if the ctx is canceled, otherwise it will retry until success.
SaveReplicateConfiguration(ctx context.Context, config *streamingpb.ReplicateConfigurationMeta, replicatingTasks []*streamingpb.ReplicatePChannelMeta) error
// GetReplicateConfiguration gets the replicate configuration from metastore.

View File

@ -36,6 +36,7 @@ import (
// │   └── cluster-2-pchannel-2
func NewCataLog(metaKV kv.MetaKv) metastore.StreamingCoordCataLog {
return &catalog{
// catalog should be reliable to write, ensure the data is consistent in memory and underlying meta storage.
metaKV: kv.NewReliableWriteMetaKv(metaKV),
}
}
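
The catalog now wraps the metaKV so that every write honors the reliability contract documented in the interface above (retry until success, return an error only if the ctx is canceled). A minimal sketch of such a wrapper, assuming a Save(ctx, key, value) method on kv.MetaKv and reusing the existing retry helper:

	// reliableWriteMetaKv (sketch): retries every write until it succeeds or the ctx is canceled.
	type reliableWriteMetaKv struct {
		kv.MetaKv
	}

	func (r *reliableWriteMetaKv) Save(ctx context.Context, key, value string) error {
		// retry.AttemptAlways never gives up on transient errors; only ctx cancellation is terminal.
		return retry.Do(ctx, func() error {
			return r.MetaKv.Save(ctx, key, value)
		}, retry.AttemptAlways())
	}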

View File

@ -467,24 +467,17 @@ func (_c *MockBalancer_UpdateBalancePolicy_Call) RunAndReturn(run func(context.C
return _c
}
// UpdateReplicateConfiguration provides a mock function with given fields: ctx, msgs
func (_m *MockBalancer) UpdateReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
_va := make([]interface{}, len(msgs))
for _i := range msgs {
_va[_i] = msgs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// UpdateReplicateConfiguration provides a mock function with given fields: ctx, result
func (_m *MockBalancer) UpdateReplicateConfiguration(ctx context.Context, result message.BroadcastResultAlterReplicateConfigMessageV2) error {
ret := _m.Called(ctx, result)
if len(ret) == 0 {
panic("no return value specified for UpdateReplicateConfiguration")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, ...message.ImmutableAlterReplicateConfigMessageV2) error); ok {
r0 = rf(ctx, msgs...)
if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastResultAlterReplicateConfigMessageV2) error); ok {
r0 = rf(ctx, result)
} else {
r0 = ret.Error(0)
}
@ -499,21 +492,14 @@ type MockBalancer_UpdateReplicateConfiguration_Call struct {
// UpdateReplicateConfiguration is a helper method to define mock.On call
// - ctx context.Context
// - msgs ...message.ImmutableAlterReplicateConfigMessageV2
func (_e *MockBalancer_Expecter) UpdateReplicateConfiguration(ctx interface{}, msgs ...interface{}) *MockBalancer_UpdateReplicateConfiguration_Call {
return &MockBalancer_UpdateReplicateConfiguration_Call{Call: _e.mock.On("UpdateReplicateConfiguration",
append([]interface{}{ctx}, msgs...)...)}
// - result message.BroadcastResultAlterReplicateConfigMessageV2
func (_e *MockBalancer_Expecter) UpdateReplicateConfiguration(ctx interface{}, result interface{}) *MockBalancer_UpdateReplicateConfiguration_Call {
return &MockBalancer_UpdateReplicateConfiguration_Call{Call: _e.mock.On("UpdateReplicateConfiguration", ctx, result)}
}
func (_c *MockBalancer_UpdateReplicateConfiguration_Call) Run(run func(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2)) *MockBalancer_UpdateReplicateConfiguration_Call {
func (_c *MockBalancer_UpdateReplicateConfiguration_Call) Run(run func(ctx context.Context, result message.BroadcastResultAlterReplicateConfigMessageV2)) *MockBalancer_UpdateReplicateConfiguration_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]message.ImmutableAlterReplicateConfigMessageV2, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(message.ImmutableAlterReplicateConfigMessageV2)
}
}
run(args[0].(context.Context), variadicArgs...)
run(args[0].(context.Context), args[1].(message.BroadcastResultAlterReplicateConfigMessageV2))
})
return _c
}
@ -523,7 +509,7 @@ func (_c *MockBalancer_UpdateReplicateConfiguration_Call) Return(_a0 error) *Moc
return _c
}
func (_c *MockBalancer_UpdateReplicateConfiguration_Call) RunAndReturn(run func(context.Context, ...message.ImmutableAlterReplicateConfigMessageV2) error) *MockBalancer_UpdateReplicateConfiguration_Call {
func (_c *MockBalancer_UpdateReplicateConfiguration_Call) RunAndReturn(run func(context.Context, message.BroadcastResultAlterReplicateConfigMessageV2) error) *MockBalancer_UpdateReplicateConfiguration_Call {
_c.Call.Return(run)
return _c
}

View File

@ -60,7 +60,7 @@ type Balancer interface {
MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error
// UpdateReplicateConfiguration updates the replicate configuration.
UpdateReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error
UpdateReplicateConfiguration(ctx context.Context, result message.BroadcastResultAlterReplicateConfigMessageV2) error
// Trigger is a hint to trigger a balance.
Trigger(ctx context.Context) error

View File

@ -46,6 +46,7 @@ func RecoverBalancer(
if err != nil {
return nil, errors.Wrap(err, "fail to recover channel manager")
}
manager.SetLogger(resource.Resource().Logger().With(log.FieldComponent("channel-manager")))
ctx, cancel := context.WithCancelCause(context.Background())
b := &balancerImpl{
ctx: ctx,
@ -122,7 +123,7 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb WatchChan
}
// UpdateReplicateConfiguration updates the replicate configuration.
func (b *balancerImpl) UpdateReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
func (b *balancerImpl) UpdateReplicateConfiguration(ctx context.Context, result message.BroadcastResultAlterReplicateConfigMessageV2) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
return status.NewOnShutdownError("balancer is closing")
}
@ -131,7 +132,7 @@ func (b *balancerImpl) UpdateReplicateConfiguration(ctx context.Context, msgs ..
ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
defer cancel()
if err := b.channelMetaManager.UpdateReplicateConfiguration(ctx, msgs...); err != nil {
if err := b.channelMetaManager.UpdateReplicateConfiguration(ctx, result); err != nil {
return err
}
return nil

View File

@ -5,17 +5,17 @@ import (
"sync"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -114,10 +114,10 @@ func recoverFromConfigurationAndMeta(ctx context.Context, streamingVersion *stre
var c *PChannelMeta
if streamingVersion == nil {
// if streaming service has never been enabled, we treat all channels as read-only.
c = newPChannelMeta(newChannel, types.AccessModeRO)
c = NewPChannelMeta(newChannel, types.AccessModeRO)
} else {
// once the streaming service is enabled, we treat all channels as read-write.
c = newPChannelMeta(newChannel, types.AccessModeRW)
c = NewPChannelMeta(newChannel, types.AccessModeRW)
}
if _, ok := channels[c.ChannelID()]; !ok {
channels[c.ChannelID()] = c
@ -126,18 +126,23 @@ func recoverFromConfigurationAndMeta(ctx context.Context, streamingVersion *stre
return channels, metrics, nil
}
func recoverReplicateConfiguration(ctx context.Context) (*replicateConfigHelper, error) {
func recoverReplicateConfiguration(ctx context.Context) (*replicateutil.ConfigHelper, error) {
config, err := resource.Resource().StreamingCatalog().GetReplicateConfiguration(ctx)
if err != nil {
return nil, err
}
return newReplicateConfigHelper(config), nil
return replicateutil.MustNewConfigHelper(
paramtable.Get().CommonCfg.ClusterPrefix.GetValue(),
config.GetReplicateConfiguration(),
), nil
}
// ChannelManager manages the channels.
// ChannelManager is the `wal` of channel assignment and unassignment.
// Every operation applied to the streaming node should be recorded in ChannelManager first.
type ChannelManager struct {
log.Binder
cond *syncutil.ContextCond
channels map[ChannelID]*PChannelMeta
version typeutil.VersionInt64Pair
@ -147,7 +152,7 @@ type ChannelManager struct {
// null if no streaming service has been run.
// 1 if streaming service has been run once.
streamingEnableNotifiers []*syncutil.AsyncTaskNotifier[struct{}]
replicateConfig *replicateConfigHelper
replicateConfig *replicateutil.ConfigHelper
}
// RegisterStreamingEnabledNotifier registers a notifier into the balancer.
@ -202,9 +207,8 @@ func (cm *ChannelManager) MarkStreamingHasEnabled(ctx context.Context) error {
Version: 1,
}
if err := retry.Do(ctx, func() error {
return resource.Resource().StreamingCatalog().SaveVersion(ctx, cm.streamingVersion)
}, retry.AttemptAlways()); err != nil {
if err := resource.Resource().StreamingCatalog().SaveVersion(ctx, cm.streamingVersion); err != nil {
cm.Logger().Error("failed to save streaming version", zap.Error(err))
return err
}
@ -329,9 +333,8 @@ func (cm *ChannelManager) updatePChannelMeta(ctx context.Context, pChannelMetas
return nil
}
if err := retry.Do(ctx, func() error {
return resource.Resource().StreamingCatalog().SavePChannels(ctx, pChannelMetas)
}, retry.AttemptAlways()); err != nil {
if err := resource.Resource().StreamingCatalog().SavePChannels(ctx, pChannelMetas); err != nil {
cm.Logger().Error("failed to save pchannels", zap.Error(err))
return err
}
@ -391,51 +394,71 @@ func (cm *ChannelManager) WatchAssignmentResult(ctx context.Context, cb WatchCha
}
// UpdateReplicateConfiguration updates the in-memory replicate configuration.
func (cm *ChannelManager) UpdateReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
config := replicateutil.MustNewConfigHelper(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), msgs[0].Header().ReplicateConfiguration)
pchannels := make([]types.AckedCheckpoint, 0, len(msgs))
for _, msg := range msgs {
pchannels = append(pchannels, types.AckedCheckpoint{
Channel: funcutil.ToPhysicalChannel(msg.VChannel()),
MessageID: msg.LastConfirmedMessageID(),
LastConfirmedMessageID: msg.LastConfirmedMessageID(),
TimeTick: msg.TimeTick(),
})
}
func (cm *ChannelManager) UpdateReplicateConfiguration(ctx context.Context, result message.BroadcastResultAlterReplicateConfigMessageV2) error {
msg := result.Message
config := replicateutil.MustNewConfigHelper(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), msg.Header().ReplicateConfiguration)
cm.cond.L.Lock()
defer cm.cond.L.Unlock()
if cm.replicateConfig == nil {
cm.replicateConfig = newReplicateConfigHelperFromMessage(msgs[0])
} else {
// StartUpdating starts the updating process.
if !cm.replicateConfig.StartUpdating(config.GetReplicateConfiguration(), msgs[0].BroadcastHeader().VChannels) {
return nil
}
}
cm.replicateConfig.Apply(config.GetReplicateConfiguration(), pchannels)
dirtyConfig, dirtyCDCTasks, dirty := cm.replicateConfig.ConsumeIfDirty(config.GetReplicateConfiguration())
if !dirty {
// the meta is not dirty, so nothing updated, return it directly.
if cm.replicateConfig != nil && proto.Equal(config.GetReplicateConfiguration(), cm.replicateConfig.GetReplicateConfiguration()) {
// check if the replicate configuration is changed.
// if not changed, return it directly.
return nil
}
if err := resource.Resource().StreamingCatalog().SaveReplicateConfiguration(ctx, dirtyConfig, dirtyCDCTasks); err != nil {
newIncomingCDCTasks := cm.getNewIncomingTask(config, result.Results)
if err := resource.Resource().StreamingCatalog().SaveReplicateConfiguration(ctx,
&streamingpb.ReplicateConfigurationMeta{ReplicateConfiguration: config.GetReplicateConfiguration()},
newIncomingCDCTasks); err != nil {
cm.Logger().Error("failed to save replicate configuration", zap.Error(err))
return err
}
// If the acked result is nil, it means the all the channels are acked,
// so we can update the version and push the new replicate configuration into client.
if dirtyConfig.AckedResult == nil {
// update metrics.
cm.cond.UnsafeBroadcast()
cm.version.Local++
cm.metrics.UpdateAssignmentVersion(cm.version.Local)
}
cm.replicateConfig = config
cm.cond.UnsafeBroadcast()
cm.version.Local++
cm.metrics.UpdateAssignmentVersion(cm.version.Local)
return nil
}
// getNewIncomingTask gets the new incoming task from replicatingTasks.
func (cm *ChannelManager) getNewIncomingTask(newConfig *replicateutil.ConfigHelper, appendResults map[string]*message.AppendResult) []*streamingpb.ReplicatePChannelMeta {
incoming := newConfig.GetCurrentCluster()
var current *replicateutil.MilvusCluster
if cm.replicateConfig != nil {
current = cm.replicateConfig.GetCurrentCluster()
}
incomingReplicatingTasks := make([]*streamingpb.ReplicatePChannelMeta, 0, len(incoming.TargetClusters()))
for _, targetCluster := range incoming.TargetClusters() {
if current != nil && current.TargetCluster(targetCluster.GetClusterId()) != nil {
// target already exists, skip it.
continue
}
// TODO: support adding new pchannels into existing clusters.
for _, pchannel := range targetCluster.GetPchannels() {
sourceClusterID := targetCluster.SourceCluster().ClusterId
sourcePChannel := targetCluster.MustGetSourceChannel(pchannel)
incomingReplicatingTasks = append(incomingReplicatingTasks, &streamingpb.ReplicatePChannelMeta{
SourceChannelName: sourcePChannel,
TargetChannelName: pchannel,
TargetCluster: targetCluster.MilvusCluster,
// The checkpoint here is the initialized checkpoint for a cdc-task:
// at the startup of a cdc-task, the checkpoint returned from the target cluster is nil,
// so we set the initialized checkpoint as the point to start replicating from.
// The InitializedCheckpoint always keeps the same semantics as the checkpoint at the target cluster,
// so the cluster id is the source cluster id (i.e. the current cluster id).
InitializedCheckpoint: &commonpb.ReplicateCheckpoint{
ClusterId: sourceClusterID,
Pchannel: sourcePChannel,
MessageId: appendResults[sourcePChannel].LastConfirmedMessageID.IntoProto(),
TimeTick: appendResults[sourcePChannel].TimeTick,
},
})
}
}
return incomingReplicatingTasks
}
// applyAssignments applies the assignments.
func (cm *ChannelManager) applyAssignments(cb WatchChannelAssignmentsCallback) (typeutil.VersionInt64Pair, error) {
cm.cond.L.Lock()

View File

@ -2,17 +2,22 @@ package channel
import (
"context"
"math/rand"
"strings"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
@ -77,9 +82,6 @@ func TestChannelManager(t *testing.T) {
// Test success.
catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Unset()
catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pm []*streamingpb.PChannelMeta) error {
if rand.Int31n(3) == 0 {
return errors.New("save meta failure")
}
return nil
})
modified, err = m.AssignPChannels(ctx, map[ChannelID]types.PChannelInfoAssigned{newChannelID("test-channel"): {
@ -116,6 +118,242 @@ func TestChannelManager(t *testing.T) {
nodeID, ok = m.GetLatestWALLocated(ctx, "test-channel")
assert.False(t, ok)
assert.Zero(t, nodeID)
t.Run("UpdateReplicateConfiguration", func(t *testing.T) {
param, err := m.GetLatestChannelAssignment()
oldLocalVersion := param.Version.Local
assert.NoError(t, err)
assert.Equal(t, m.ReplicateRole(), replicateutil.RolePrimary)
// Test update replicate configurations
cfg := &commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev", Pchannels: []string{"by-dev-test-channel-1", "by-dev-test-channel-2"}},
{ClusterId: "by-dev2", Pchannels: []string{"by-dev2-test-channel-1", "by-dev2-test-channel-2"}},
},
CrossClusterTopology: []*commonpb.CrossClusterTopology{
{SourceClusterId: "by-dev", TargetClusterId: "by-dev2"},
},
}
msg := message.NewAlterReplicateConfigMessageBuilderV2().
WithHeader(&message.AlterReplicateConfigMessageHeader{
ReplicateConfiguration: cfg,
}).
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithBroadcast([]string{"by-dev-test-channel-1", "by-dev-test-channel-2"}).
MustBuildBroadcast()
result := message.BroadcastResultAlterReplicateConfigMessageV2{
Message: message.MustAsBroadcastAlterReplicateConfigMessageV2(msg),
Results: map[string]*message.AppendResult{
"by-dev-test-channel-1": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(2),
TimeTick: 1,
},
"by-dev-test-channel-2": {
MessageID: walimplstest.NewTestMessageID(3),
LastConfirmedMessageID: walimplstest.NewTestMessageID(4),
TimeTick: 1,
},
},
}
catalog.EXPECT().SaveReplicateConfiguration(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, config *streamingpb.ReplicateConfigurationMeta, replicatingTasks []*streamingpb.ReplicatePChannelMeta) error {
assert.True(t, proto.Equal(config.ReplicateConfiguration, cfg))
assert.Len(t, replicatingTasks, 2)
for _, task := range replicatingTasks {
result := result.Results[task.GetSourceChannelName()]
assert.True(t, result.LastConfirmedMessageID.EQ(message.MustUnmarshalMessageID(task.InitializedCheckpoint.MessageId)))
assert.Equal(t, result.TimeTick, task.InitializedCheckpoint.TimeTick)
assert.Equal(t, task.GetTargetChannelName(), strings.Replace(task.GetSourceChannelName(), "by-dev", "by-dev2", 1))
assert.Equal(t, task.GetTargetCluster().GetClusterId(), "by-dev2")
}
return nil
})
err = m.UpdateReplicateConfiguration(ctx, result)
assert.NoError(t, err)
param, err = m.GetLatestChannelAssignment()
assert.Equal(t, param.Version.Local, oldLocalVersion+1)
assert.NoError(t, err)
assert.Equal(t, m.ReplicateRole(), replicateutil.RolePrimary)
// test idempotency
err = m.UpdateReplicateConfiguration(ctx, result)
assert.NoError(t, err)
param, err = m.GetLatestChannelAssignment()
assert.Equal(t, param.Version.Local, oldLocalVersion+1)
assert.NoError(t, err)
assert.Equal(t, m.ReplicateRole(), replicateutil.RolePrimary)
// TODO: support adding new pchannels into existing clusters.
// Add more pchannels into existing clusters.
// Clusters: []*commonpb.MilvusCluster{
// {ClusterId: "by-dev", Pchannels: []string{"by-dev-test-channel-1", "by-dev-test-channel-2", "by-dev-test-channel-3"}},
// {ClusterId: "by-dev2", Pchannels: []string{"by-dev2-test-channel-1", "by-dev2-test-channel-2", "by-dev2-test-channel-3"}},
// },
// CrossClusterTopology: []*commonpb.CrossClusterTopology{
// {SourceClusterId: "by-dev", TargetClusterId: "by-dev2"},
// },
// }
// msg = message.NewAlterReplicateConfigMessageBuilderV2().
// WithHeader(&message.AlterReplicateConfigMessageHeader{
// ReplicateConfiguration: cfg,
// }).
// WithBody(&message.AlterReplicateConfigMessageBody{}).
// WithBroadcast([]string{"by-dev-test-channel-1", "by-dev-test-channel-2", "by-dev-test-channel-3"}).
// MustBuildBroadcast()
// result = message.BroadcastResultAlterReplicateConfigMessageV2{
// Message: message.MustAsBroadcastAlterReplicateConfigMessageV2(msg),
// Results: map[string]*message.AppendResult{
// "by-dev-test-channel-1": {
// MessageID: walimplstest.NewTestMessageID(1),
// LastConfirmedMessageID: walimplstest.NewTestMessageID(2),
// TimeTick: 1,
// },
// "by-dev-test-channel-2": {
// MessageID: walimplstest.NewTestMessageID(3),
// LastConfirmedMessageID: walimplstest.NewTestMessageID(4),
// TimeTick: 1,
// },
// "by-dev-test-channel-3": {
// MessageID: walimplstest.NewTestMessageID(5),
// LastConfirmedMessageID: walimplstest.NewTestMessageID(6),
// TimeTick: 1,
// },
// },
// }
// catalog.EXPECT().SaveReplicateConfiguration(mock.Anything, mock.Anything, mock.Anything).Unset()
// catalog.EXPECT().SaveReplicateConfiguration(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
// func(ctx context.Context, config *streamingpb.ReplicateConfigurationMeta, replicatingTasks []*streamingpb.ReplicatePChannelMeta) error {
// assert.True(t, proto.Equal(config.ReplicateConfiguration, cfg))
// assert.Len(t, replicatingTasks, 1) // here should be two new incoming tasks.
// for _, task := range replicatingTasks {
// assert.Equal(t, task.GetSourceChannelName(), "by-dev-test-channel-3")
// result := result.Results[task.GetSourceChannelName()]
// assert.True(t, result.LastConfirmedMessageID.EQ(message.MustUnmarshalMessageID(task.InitializedCheckpoint.MessageId)))
// assert.Equal(t, result.TimeTick, task.InitializedCheckpoint.TimeTick)
// assert.Equal(t, task.GetTargetChannelName(), strings.Replace(task.GetSourceChannelName(), "by-dev", "by-dev2", 1))
// assert.Equal(t, task.GetTargetCluster().GetClusterId(), "by-dev2")
// }
// return nil
// })
// err = m.UpdateReplicateConfiguration(ctx, result)
// assert.NoError(t, err)
// Add new cluster into existing config.
cfg = &commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev", Pchannels: []string{"by-dev-test-channel-1", "by-dev-test-channel-2"}},
{ClusterId: "by-dev2", Pchannels: []string{"by-dev2-test-channel-1", "by-dev2-test-channel-2"}},
{ClusterId: "by-dev3", Pchannels: []string{"by-dev3-test-channel-1", "by-dev3-test-channel-2"}},
},
CrossClusterTopology: []*commonpb.CrossClusterTopology{
{SourceClusterId: "by-dev", TargetClusterId: "by-dev2"},
{SourceClusterId: "by-dev", TargetClusterId: "by-dev3"},
},
}
msg = message.NewAlterReplicateConfigMessageBuilderV2().
WithHeader(&message.AlterReplicateConfigMessageHeader{
ReplicateConfiguration: cfg,
}).
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithBroadcast([]string{"by-dev-test-channel-1", "by-dev-test-channel-2"}).
MustBuildBroadcast()
result = message.BroadcastResultAlterReplicateConfigMessageV2{
Message: message.MustAsBroadcastAlterReplicateConfigMessageV2(msg),
Results: map[string]*message.AppendResult{
"by-dev-test-channel-1": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(2),
TimeTick: 1,
},
"by-dev-test-channel-2": {
MessageID: walimplstest.NewTestMessageID(3),
LastConfirmedMessageID: walimplstest.NewTestMessageID(4),
TimeTick: 1,
},
},
}
catalog.EXPECT().SaveReplicateConfiguration(mock.Anything, mock.Anything, mock.Anything).Unset()
catalog.EXPECT().SaveReplicateConfiguration(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, config *streamingpb.ReplicateConfigurationMeta, replicatingTasks []*streamingpb.ReplicatePChannelMeta) error {
assert.True(t, proto.Equal(config.ReplicateConfiguration, cfg))
assert.Len(t, replicatingTasks, 2) // here should be two new incoming tasks.
for _, task := range replicatingTasks {
assert.Equal(t, task.GetTargetCluster().GetClusterId(), "by-dev3")
result := result.Results[task.GetSourceChannelName()]
assert.True(t, result.LastConfirmedMessageID.EQ(message.MustUnmarshalMessageID(task.InitializedCheckpoint.MessageId)))
assert.Equal(t, result.TimeTick, task.InitializedCheckpoint.TimeTick)
assert.Equal(t, task.GetTargetChannelName(), strings.Replace(task.GetSourceChannelName(), "by-dev", "by-dev3", 1))
assert.Equal(t, task.GetTargetCluster().GetClusterId(), "by-dev3")
}
return nil
})
err = m.UpdateReplicateConfiguration(ctx, result)
assert.NoError(t, err)
param, err = m.GetLatestChannelAssignment()
assert.NoError(t, err)
assert.Equal(t, param.Version.Local, oldLocalVersion+2)
assert.True(t, proto.Equal(param.ReplicateConfiguration, cfg))
assert.Equal(t, m.ReplicateRole(), replicateutil.RolePrimary)
// switch into secondary
cfg = &commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev", Pchannels: []string{"by-dev-test-channel-1", "by-dev-test-channel-2"}},
{ClusterId: "by-dev2", Pchannels: []string{"by-dev2-test-channel-1", "by-dev2-test-channel-2"}},
{ClusterId: "by-dev3", Pchannels: []string{"by-dev3-test-channel-1", "by-dev3-test-channel-2"}},
},
CrossClusterTopology: []*commonpb.CrossClusterTopology{
{SourceClusterId: "by-dev2", TargetClusterId: "by-dev"},
{SourceClusterId: "by-dev2", TargetClusterId: "by-dev3"},
},
}
msg = message.NewAlterReplicateConfigMessageBuilderV2().
WithHeader(&message.AlterReplicateConfigMessageHeader{
ReplicateConfiguration: cfg,
}).
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithBroadcast([]string{"by-dev-test-channel-1", "by-dev-test-channel-2"}).
MustBuildBroadcast()
result = message.BroadcastResultAlterReplicateConfigMessageV2{
Message: message.MustAsBroadcastAlterReplicateConfigMessageV2(msg),
Results: map[string]*message.AppendResult{
"by-dev-test-channel-1": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(2),
TimeTick: 1,
},
"by-dev-test-channel-2": {
MessageID: walimplstest.NewTestMessageID(3),
LastConfirmedMessageID: walimplstest.NewTestMessageID(4),
TimeTick: 1,
},
},
}
catalog.EXPECT().SaveReplicateConfiguration(mock.Anything, mock.Anything, mock.Anything).Unset()
catalog.EXPECT().SaveReplicateConfiguration(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, config *streamingpb.ReplicateConfigurationMeta, replicatingTasks []*streamingpb.ReplicatePChannelMeta) error {
assert.True(t, proto.Equal(config.ReplicateConfiguration, cfg))
assert.Len(t, replicatingTasks, 0) // no new incoming tasks: the current cluster becomes secondary.
return nil
})
err = m.UpdateReplicateConfiguration(ctx, result)
assert.NoError(t, err)
err = m.UpdateReplicateConfiguration(ctx, result)
assert.NoError(t, err)
param, err = m.GetLatestChannelAssignment()
assert.NoError(t, err)
assert.Equal(t, param.Version.Local, oldLocalVersion+3)
assert.True(t, proto.Equal(param.ReplicateConfiguration, cfg))
assert.Equal(t, m.ReplicateRole(), replicateutil.RoleSecondary)
})
}
func TestStreamingEnableChecker(t *testing.T) {

View File

@ -9,8 +9,8 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
// newPChannelMeta creates a new PChannelMeta.
func newPChannelMeta(name string, accessMode types.AccessMode) *PChannelMeta {
// NewPChannelMeta creates a new PChannelMeta.
func NewPChannelMeta(name string, accessMode types.AccessMode) *PChannelMeta {
return &PChannelMeta{
inner: &streamingpb.PChannelMeta{
Channel: &streamingpb.PChannelInfo{

View File

@ -39,7 +39,7 @@ func TestPChannel(t *testing.T) {
},
}, pchannel.CurrentAssignment())
pchannel = newPChannelMeta("test-channel", types.AccessModeRW)
pchannel = NewPChannelMeta("test-channel", types.AccessModeRW)
assert.Equal(t, "test-channel", pchannel.Name())
assert.Equal(t, int64(1), pchannel.CurrentTerm())
assert.Empty(t, pchannel.AssignHistories())

View File

@ -1,117 +0,0 @@
package channel
import (
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
)
// replicateConfigHelper is a helper to manage the replicate configuration.
type replicateConfigHelper struct {
*replicateutil.ConfigHelper
ackedPendings *types.AckedResult
dirty bool
}
// newReplicateConfigHelperFromMessage creates a new replicate config helper from message.
func newReplicateConfigHelperFromMessage(replicateConfig message.ImmutableAlterReplicateConfigMessageV2) *replicateConfigHelper {
return newReplicateConfigHelper(&streamingpb.ReplicateConfigurationMeta{
ReplicateConfiguration: nil,
AckedResult: types.NewAckedPendings(replicateConfig.BroadcastHeader().VChannels).AckedResult,
})
}
// newReplicateConfigHelper creates a new replicate config helper from proto.
func newReplicateConfigHelper(replicateConfig *streamingpb.ReplicateConfigurationMeta) *replicateConfigHelper {
if replicateConfig == nil {
return nil
}
return &replicateConfigHelper{
ConfigHelper: replicateutil.MustNewConfigHelper(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), replicateConfig.GetReplicateConfiguration()),
ackedPendings: types.NewAckedPendingsFromProto(replicateConfig.GetAckedResult()),
dirty: false,
}
}
// StartUpdating starts the updating process.
// return true if the replicate configuration is changed, false otherwise.
func (rc *replicateConfigHelper) StartUpdating(config *commonpb.ReplicateConfiguration, pchannels []string) bool {
if rc.ConfigHelper != nil && proto.Equal(config, rc.GetReplicateConfiguration()) {
return false
}
if rc.ackedPendings == nil {
rc.ackedPendings = types.NewAckedPendings(pchannels)
}
return true
}
// Apply applies the replicate configuration to the wal.
func (rc *replicateConfigHelper) Apply(config *commonpb.ReplicateConfiguration, cp []types.AckedCheckpoint) {
if rc.ackedPendings == nil {
panic("ackedPendings is nil when applying replicate configuration")
}
for _, cp := range cp {
if rc.ackedPendings.Ack(cp) {
rc.dirty = true
}
}
}
// ConsumeIfDirty consumes the dirty part of the replicate configuration.
func (rc *replicateConfigHelper) ConsumeIfDirty(incoming *commonpb.ReplicateConfiguration) (config *streamingpb.ReplicateConfigurationMeta, replicatingTasks []*streamingpb.ReplicatePChannelMeta, dirty bool) {
if !rc.dirty {
return nil, nil, false
}
rc.dirty = false
if !rc.ackedPendings.IsAllAcked() {
// not all the channels are acked, return the current replicate configuration and acked result.
var cfg *commonpb.ReplicateConfiguration
if rc.ConfigHelper != nil {
cfg = rc.ConfigHelper.GetReplicateConfiguration()
}
return &streamingpb.ReplicateConfigurationMeta{
ReplicateConfiguration: cfg,
AckedResult: rc.ackedPendings.AckedResult,
}, nil, true
}
// all the channels are acked, return the new replicate configuration and acked result.
newConfig := replicateutil.MustNewConfigHelper(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), incoming)
newIncomingCDCTasks := rc.getNewIncomingTask(newConfig)
rc.ConfigHelper = newConfig
rc.ackedPendings = nil
return &streamingpb.ReplicateConfigurationMeta{
ReplicateConfiguration: incoming,
AckedResult: nil,
}, newIncomingCDCTasks, true
}
// getNewIncomingTask gets the new incoming task from replicatingTasks.
func (cm *replicateConfigHelper) getNewIncomingTask(newConfig *replicateutil.ConfigHelper) []*streamingpb.ReplicatePChannelMeta {
incoming := newConfig.GetCurrentCluster()
var current *replicateutil.MilvusCluster
if cm.ConfigHelper != nil {
current = cm.ConfigHelper.GetCurrentCluster()
}
incomingReplicatingTasks := make([]*streamingpb.ReplicatePChannelMeta, 0, len(incoming.TargetClusters()))
for _, targetCluster := range incoming.TargetClusters() {
if current != nil && current.TargetCluster(targetCluster.GetClusterId()) != nil {
// target already exists, skip it.
continue
}
for _, pchannel := range targetCluster.GetPchannels() {
incomingReplicatingTasks = append(incomingReplicatingTasks, &streamingpb.ReplicatePChannelMeta{
SourceChannelName: targetCluster.MustGetSourceChannel(pchannel),
TargetChannelName: pchannel,
TargetCluster: targetCluster.MilvusCluster,
})
}
}
return incomingReplicatingTasks
}

View File

@ -1,129 +0,0 @@
package channel
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
)
type ReplicateConfigHelperSuite struct {
suite.Suite
helper *replicateConfigHelper
}
func TestReplicateConfigHelperSuite(t *testing.T) {
suite.Run(t, new(ReplicateConfigHelperSuite))
}
func (s *ReplicateConfigHelperSuite) SetupTest() {
s.helper = nil
}
func (s *ReplicateConfigHelperSuite) TestNewReplicateConfigHelper() {
// Test nil input
helper := newReplicateConfigHelper(nil)
s.Nil(helper)
// Test valid input
meta := &streamingpb.ReplicateConfigurationMeta{
ReplicateConfiguration: &commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev"},
},
},
AckedResult: types.NewAckedPendings([]string{"p1", "p2"}).AckedResult,
}
helper = newReplicateConfigHelper(meta)
s.NotNil(helper)
s.NotNil(helper.ConfigHelper)
s.NotNil(helper.ackedPendings)
s.False(helper.dirty)
}
func (s *ReplicateConfigHelperSuite) TestStartUpdating() {
s.helper = &replicateConfigHelper{
ConfigHelper: nil,
ackedPendings: nil,
dirty: false,
}
config := &commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev"},
},
}
pchannels := []string{"p1", "p2"}
// First update should return true
changed := s.helper.StartUpdating(config, pchannels)
s.True(changed)
s.NotNil(s.helper.ackedPendings)
s.helper.Apply(config, []types.AckedCheckpoint{
{Channel: "p1", MessageID: walimplstest.NewTestMessageID(1), LastConfirmedMessageID: walimplstest.NewTestMessageID(1), TimeTick: 1},
{Channel: "p2", MessageID: walimplstest.NewTestMessageID(1), LastConfirmedMessageID: walimplstest.NewTestMessageID(1), TimeTick: 1},
})
s.helper.ConsumeIfDirty(config)
// Same config should return false
changed = s.helper.StartUpdating(config, pchannels)
s.False(changed)
}
func (s *ReplicateConfigHelperSuite) TestApply() {
s.helper = &replicateConfigHelper{
ConfigHelper: nil,
ackedPendings: types.NewAckedPendings([]string{"p1", "p2"}),
dirty: false,
}
config := &commonpb.ReplicateConfiguration{}
checkpoints := []types.AckedCheckpoint{
{Channel: "p1", MessageID: walimplstest.NewTestMessageID(1), LastConfirmedMessageID: walimplstest.NewTestMessageID(1), TimeTick: 1},
{Channel: "p2", MessageID: walimplstest.NewTestMessageID(1), LastConfirmedMessageID: walimplstest.NewTestMessageID(1), TimeTick: 1},
}
s.helper.Apply(config, checkpoints)
s.True(s.helper.dirty)
s.True(s.helper.ackedPendings.IsAllAcked())
}
func (s *ReplicateConfigHelperSuite) TestConsumeIfDirty() {
s.helper = &replicateConfigHelper{
ConfigHelper: nil,
ackedPendings: types.NewAckedPendings([]string{"p1", "p2"}),
dirty: true,
}
// Not all acked case
config, tasks, dirty := s.helper.ConsumeIfDirty(&commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev"},
},
})
s.NotNil(config)
s.Nil(tasks)
s.True(dirty)
s.False(s.helper.dirty)
// All acked case
s.helper.dirty = true
s.helper.ackedPendings.Ack(types.AckedCheckpoint{Channel: "p1", MessageID: walimplstest.NewTestMessageID(1), LastConfirmedMessageID: walimplstest.NewTestMessageID(1), TimeTick: 1})
s.helper.ackedPendings.Ack(types.AckedCheckpoint{Channel: "p2", MessageID: walimplstest.NewTestMessageID(1), LastConfirmedMessageID: walimplstest.NewTestMessageID(1), TimeTick: 1})
config, tasks, dirty = s.helper.ConsumeIfDirty(&commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev"},
},
})
s.NotNil(config)
s.NotNil(tasks)
s.True(dirty)
s.False(s.helper.dirty)
s.Nil(s.helper.ackedPendings)
}

View File

@ -8,7 +8,10 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
var singleton = syncutil.NewFuture[broadcaster.Broadcaster]()
var (
singleton = syncutil.NewFuture[broadcaster.Broadcaster]()
ErrNotPrimary = broadcaster.ErrNotPrimary
)
// Register registers the broadcaster.
func Register(broadcaster broadcaster.Broadcaster) {
@ -21,6 +24,7 @@ func GetWithContext(ctx context.Context) (broadcaster.Broadcaster, error) {
}
// StartBroadcastWithResourceKeys starts a broadcast with resource keys.
// Return ErrNotPrimary if the cluster is not primary, so no DDL message can be broadcasted.
func StartBroadcastWithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (broadcaster.BroadcastAPI, error) {
broadcaster, err := singleton.GetWithContext(ctx)
if err != nil {

View File

@ -60,7 +60,8 @@ func newBroadcastTaskManager(protos []*streamingpb.BroadcastTask) *broadcastTask
// if there's some pending messages that is not appended, it should be continued to be appended.
pendingTasks = append(pendingTasks, newPending)
} else {
// if there's no pending messages, it should be added to the pending ack callback tasks.
// if there are no pending messages, the task should be added to the pending ack callback tasks
// so that its ack callback function gets called.
pendingAckCallbackTasks = append(pendingAckCallbackTasks, task)
}
case streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_REPLICATED:
@ -98,7 +99,6 @@ type broadcastTaskManager struct {
lifetime *typeutil.Lifetime
mu *sync.Mutex
tasks map[uint64]*broadcastTask // map the broadcastID to the broadcastTaskState
tombstoneTasks []uint64 // the broadcastID of the tombstone tasks
resourceKeyLocker *resourceKeyLocker
metrics *broadcasterMetrics
broadcastScheduler *broadcasterScheduler // the scheduler of the broadcast task
@ -113,10 +113,7 @@ func (bm *broadcastTaskManager) WithResourceKeys(ctx context.Context, resourceKe
}
resourceKeys = bm.appendSharedClusterRK(resourceKeys...)
guards, err := bm.resourceKeyLocker.Lock(resourceKeys...)
if err != nil {
return nil, err
}
guards := bm.resourceKeyLocker.Lock(resourceKeys...)
if err := bm.checkClusterRole(ctx); err != nil {
// unlock the guards if the cluster role is not primary.
@ -138,7 +135,8 @@ func (bm *broadcastTaskManager) checkClusterRole(ctx context.Context) error {
return err
}
if b.ReplicateRole() != replicateutil.RolePrimary {
return status.NewReplicateViolation("cluster is not primary, cannot do any DDL/DCL")
// a non-primary cluster cannot do any broadcast operation.
return ErrNotPrimary
}
return nil
}

View File

@ -3,14 +3,19 @@ package broadcaster
import (
"context"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
var ErrNotPrimary = errors.New("cluster is not primary, cannot do any DDL/DCL")
type Broadcaster interface {
// WithResourceKeys sets the resource keys of the broadcast operation.
// It will acquire locks of the resource keys and return the broadcast api.
// Once the broadcast api is returned, the Close() method of the broadcast api should be called to release the resource safely.
// Return ErrNotPrimary if the cluster is not primary, so no DDL message can be broadcasted.
WithResourceKeys(ctx context.Context, resourceKeys ...message.ResourceKey) (BroadcastAPI, error)
// LegacyAck is the legacy ack interface for the 2.6.0 import message.

View File

@ -104,7 +104,7 @@ func (r *resourceKeyLocker) FastLock(keys ...message.ResourceKey) (*lockGuards,
}
// Lock locks the resource keys.
func (r *resourceKeyLocker) Lock(keys ...message.ResourceKey) (*lockGuards, error) {
func (r *resourceKeyLocker) Lock(keys ...message.ResourceKey) *lockGuards {
// lock the keys in order to avoid deadlock.
sortResourceKeys(keys)
g := &lockGuards{}
@ -116,7 +116,7 @@ func (r *resourceKeyLocker) Lock(keys ...message.ResourceKey) (*lockGuards, erro
}
g.append(&lockGuard{locker: r, key: key})
}
return g, nil
return g
}
// unlockWithKey unlocks the resource key.

View File

@ -47,21 +47,14 @@ func TestResourceKeyLocker(t *testing.T) {
n := rand.Intn(10)
if n < 3 {
// Lock the keys
guards, err := locker.Lock(keysToLock...)
if err != nil {
t.Errorf("Failed to lock keys: %v", err)
return
}
guards := locker.Lock(keysToLock...)
// Hold lock briefly
time.Sleep(time.Millisecond)
// Unlock the keys
guards.Unlock()
} else {
guards, err := locker.Lock(keysToLock...)
if err == nil {
guards.Unlock()
}
guards := locker.Lock(keysToLock...)
guards.Unlock()
}
}
done <- true
@ -84,11 +77,7 @@ func TestResourceKeyLocker(t *testing.T) {
go func() {
for i := 0; i < 100; i++ {
// Lock key1 then key2
guards, err := locker.Lock(key1, key2)
if err != nil {
t.Errorf("Failed to lock keys in order 1->2: %v", err)
return
}
guards := locker.Lock(key1, key2)
time.Sleep(time.Millisecond)
guards.Unlock()
}
@ -98,11 +87,7 @@ func TestResourceKeyLocker(t *testing.T) {
go func() {
for i := 0; i < 100; i++ {
// Lock key2 then key1
guards, err := locker.Lock(key2, key1)
if err != nil {
t.Errorf("Failed to lock keys in order 2->1: %v", err)
return
}
guards := locker.Lock(key2, key1)
time.Sleep(time.Millisecond)
guards.Unlock()
}

View File

@ -3,32 +3,37 @@ package service
import (
"context"
"github.com/cockroachdb/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/samber/lo"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/service/discover"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
)
var _ streamingpb.StreamingCoordAssignmentServiceServer = (*assignmentServiceImpl)(nil)
var errReplicateConfigurationSame = errors.New("same replicate configuration")
// NewAssignmentService returns a new assignment service.
func NewAssignmentService() streamingpb.StreamingCoordAssignmentServiceServer {
assignmentService := &assignmentServiceImpl{
listenerTotal: metrics.StreamingCoordAssignmentListenerTotal.WithLabelValues(paramtable.GetStringNodeID()),
}
// TODO: after recovering from wal, add it to here.
// registry.RegisterAlterReplicateConfigV2AckCallback(assignmentService.AlterReplicateConfiguration)
registry.RegisterAlterReplicateConfigV2AckCallback(assignmentService.alterReplicateConfiguration)
return assignmentService
}
@ -61,29 +66,63 @@ func (s *assignmentServiceImpl) UpdateReplicateConfiguration(ctx context.Context
log.Ctx(ctx).Info("UpdateReplicateConfiguration received", replicateutil.ConfigLogFields(config)...)
// TODO: after recovering from wal, do a broadcast operation here.
// check if the configuration is the same,
// so even if the current cluster is not primary, we can still return an idempotent success result.
if _, err := s.validateReplicateConfiguration(ctx, config); err != nil {
if errors.Is(err, errReplicateConfigurationSame) {
log.Ctx(ctx).Info("configuration is same, ignored")
return &streamingpb.UpdateReplicateConfigurationResponse{}, nil
}
return nil, err
}
broadcaster, err := broadcast.StartBroadcastWithResourceKeys(ctx, message.NewExclusiveClusterResourceKey())
if err != nil {
if errors.Is(err, broadcast.ErrNotPrimary) {
// The current cluster is not primary, but we support an idempotent broadcast across the replication clusters.
// For example, with three clusters A/B/C where A is primary in the replicating topology,
// the milvus client can broadcast the UpdateReplicateConfiguration to all clusters;
// if all clusters return success, the UpdateReplicateConfiguration is considered successful and synced up between A/B/C.
// So if the current cluster is not primary, its UpdateReplicateConfiguration will be replicated by CDC,
// and we should wait until the replicate configuration is changed into the same one.
return &streamingpb.UpdateReplicateConfigurationResponse{}, s.waitUntilPrimaryChangeOrConfigurationSame(ctx, config)
}
return nil, err
}
msg, err := s.validateReplicateConfiguration(ctx, config)
if err != nil {
if errors.Is(err, errReplicateConfigurationSame) {
log.Ctx(ctx).Info("configuration is same after cluster resource key is acquired, ignored")
return &streamingpb.UpdateReplicateConfigurationResponse{}, nil
}
return nil, err
}
_, err = broadcaster.Broadcast(ctx, msg)
if err != nil {
return nil, err
}
// TODO: After recovering from wal, we can get the immutable message from wal system.
// Now, we just mock the immutable message here.
mutableMsg := msg.SplitIntoMutableMessage()
mockMessages := make([]message.ImmutableAlterReplicateConfigMessageV2, 0)
for _, msg := range mutableMsg {
mockMessages = append(mockMessages,
message.MustAsImmutableAlterReplicateConfigMessageV2(msg.WithTimeTick(0).WithLastConfirmedUseMessageID().IntoImmutableMessage(rmq.NewRmqID(1))),
)
}
// TODO: After recovering from wal, remove the operation here.
if err := s.AlterReplicateConfiguration(ctx, mockMessages...); err != nil {
return nil, err
}
return &streamingpb.UpdateReplicateConfigurationResponse{}, nil
}
// waitUntilPrimaryChangeOrConfigurationSame waits until the primary changes or the configuration is the same.
func (s *assignmentServiceImpl) waitUntilPrimaryChangeOrConfigurationSame(ctx context.Context, config *commonpb.ReplicateConfiguration) error {
b, err := balance.GetWithContext(ctx)
if err != nil {
return err
}
errDone := errors.New("done")
err = b.WatchChannelAssignments(ctx, func(param balancer.WatchChannelAssignmentsCallbackParam) error {
if proto.Equal(config, param.ReplicateConfiguration) {
return errDone
}
return nil
})
if errors.Is(err, errDone) {
return nil
}
return err
}
// validateReplicateConfiguration validates the replicate configuration.
func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Context, config *commonpb.ReplicateConfiguration) (message.BroadcastMutableMessage, error) {
balancer, err := balance.GetWithContext(ctx)
@ -96,6 +135,12 @@ func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Conte
if err != nil {
return nil, err
}
// double check whether the configuration is the same after the resource key is acquired.
if proto.Equal(config, latestAssignment.ReplicateConfiguration) {
return nil, errReplicateConfigurationSame
}
pchannels := lo.MapToSlice(latestAssignment.PChannelView.Channels, func(_ channel.ChannelID, channel *channel.PChannelMeta) string {
return channel.Name()
})
@ -121,20 +166,17 @@ func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Conte
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithBroadcast(pchannels).
MustBuildBroadcast()
// TODO: After recovering from wal, remove the operation here.
b.WithBroadcastID(1)
return b, nil
}
// AlterReplicateConfiguration puts the replicate configuration into the balancer.
// alterReplicateConfiguration puts the replicate configuration into the balancer.
// It's a callback function of the broadcast service.
func (s *assignmentServiceImpl) AlterReplicateConfiguration(ctx context.Context, msgs ...message.ImmutableAlterReplicateConfigMessageV2) error {
func (s *assignmentServiceImpl) alterReplicateConfiguration(ctx context.Context, result message.BroadcastResultAlterReplicateConfigMessageV2) error {
balancer, err := balance.GetWithContext(ctx)
if err != nil {
return err
}
return balancer.UpdateReplicateConfiguration(ctx, msgs...)
return balancer.UpdateReplicateConfiguration(ctx, result)
}
// UpdateWALBalancePolicy is used to update the WAL balance policy.

View File

@ -0,0 +1,178 @@
package service
import (
"context"
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/pkg/v2/mocks/proto/mock_streamingpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
func TestAssignmentService(t *testing.T) {
resource.InitForTest()
broadcast.ResetBroadcaster()
// Set up the balancer
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
})
balance.Register(b)
// Set up the broadcaster
fb := syncutil.NewFuture[broadcaster.Broadcaster]()
mba := mock_broadcaster.NewMockBroadcastAPI(t)
mb := mock_broadcaster.NewMockBroadcaster(t)
fb.Set(mb)
mba.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(mba, nil).Maybe()
mb.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil).Maybe()
mb.EXPECT().LegacyAck(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
broadcast.Register(mb)
// Test assignment discover
as := NewAssignmentService()
ss := mock_streamingpb.NewMockStreamingCoordAssignmentService_AssignmentDiscoverServer(t)
ss.EXPECT().Context().Return(context.Background()).Maybe()
ss.EXPECT().Recv().Return(nil, io.EOF).Maybe()
ss.EXPECT().Send(mock.Anything).Return(io.EOF).Maybe()
err := as.AssignmentDiscover(ss)
assert.Error(t, err)
// Test update WAL balance policy
b.EXPECT().UpdateBalancePolicy(context.Background(), mock.Anything).Return(&streamingpb.UpdateWALBalancePolicyResponse{}, nil).Maybe()
as.UpdateWALBalancePolicy(context.Background(), &streamingpb.UpdateWALBalancePolicyRequest{})
// Test update replicate configuration
// Test illegal replicate configuration
cfg := &commonpb.ReplicateConfiguration{}
b.EXPECT().GetLatestChannelAssignment().Return(&balancer.WatchChannelAssignmentsCallbackParam{
PChannelView: &channel.PChannelView{
Channels: map[channel.ChannelID]*channel.PChannelMeta{
{Name: "by-dev-1"}: channel.NewPChannelMeta("by-dev-1", types.AccessModeRW),
},
},
}, nil).Maybe()
_, err = as.UpdateReplicateConfiguration(context.Background(), &streamingpb.UpdateReplicateConfigurationRequest{
Configuration: cfg,
})
assert.Error(t, err)
//
cfg = &commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev", Pchannels: []string{"by-dev-1"}, ConnectionParam: &commonpb.ConnectionParam{Uri: "http://test:19530", Token: "by-dev"}},
{ClusterId: "test2", Pchannels: []string{"test2"}, ConnectionParam: &commonpb.ConnectionParam{Uri: "http://test2:19530", Token: "test2"}},
},
CrossClusterTopology: []*commonpb.CrossClusterTopology{
{SourceClusterId: "by-dev", TargetClusterId: "test2"},
},
}
// Test update pass.
_, err = as.UpdateReplicateConfiguration(context.Background(), &streamingpb.UpdateReplicateConfigurationRequest{
Configuration: cfg,
})
assert.NoError(t, err)
// Test idempotent
b.EXPECT().GetLatestChannelAssignment().Unset()
b.EXPECT().GetLatestChannelAssignment().Return(&balancer.WatchChannelAssignmentsCallbackParam{
PChannelView: &channel.PChannelView{
Channels: map[channel.ChannelID]*channel.PChannelMeta{
{Name: "by-dev-1"}: channel.NewPChannelMeta("by-dev-1", types.AccessModeRW),
},
},
ReplicateConfiguration: cfg,
}, nil).Maybe()
_, err = as.UpdateReplicateConfiguration(context.Background(), &streamingpb.UpdateReplicateConfigurationRequest{
Configuration: cfg,
})
assert.NoError(t, err)
// Test secondary path.
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Unset()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, rk ...message.ResourceKey) (broadcaster.BroadcastAPI, error) {
return nil, broadcaster.ErrNotPrimary
})
// Still idempotent.
_, err = as.UpdateReplicateConfiguration(context.Background(), &streamingpb.UpdateReplicateConfigurationRequest{
Configuration: cfg,
})
assert.NoError(t, err)
// Test update on secondary path; it should block until the replicate configuration is changed.
b.EXPECT().GetLatestChannelAssignment().Unset()
b.EXPECT().GetLatestChannelAssignment().Return(&balancer.WatchChannelAssignmentsCallbackParam{
PChannelView: &channel.PChannelView{
Channels: map[channel.ChannelID]*channel.PChannelMeta{
{Name: "by-dev-1"}: channel.NewPChannelMeta("by-dev-1", types.AccessModeRW),
},
},
ReplicateConfiguration: &commonpb.ReplicateConfiguration{
Clusters: []*commonpb.MilvusCluster{
{ClusterId: "by-dev", Pchannels: []string{"by-dev-1"}, ConnectionParam: &commonpb.ConnectionParam{Uri: "http://test:19530", Token: "by-dev"}},
{ClusterId: "test2", Pchannels: []string{"test2"}, ConnectionParam: &commonpb.ConnectionParam{Uri: "http://test2:19530", Token: "test2"}},
},
CrossClusterTopology: []*commonpb.CrossClusterTopology{
{SourceClusterId: "test2", TargetClusterId: "by-dev"},
},
},
}, nil).Maybe()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).Unset()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
select {
case <-time.After(500 * time.Millisecond):
return cb(balancer.WatchChannelAssignmentsCallbackParam{
ReplicateConfiguration: cfg,
})
case <-ctx.Done():
return ctx.Err()
}
})
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err = as.UpdateReplicateConfiguration(ctx, &streamingpb.UpdateReplicateConfigurationRequest{
Configuration: cfg,
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
_, err = as.UpdateReplicateConfiguration(context.Background(), &streamingpb.UpdateReplicateConfigurationRequest{
Configuration: cfg,
})
assert.NoError(t, err)
// Test callback
b.EXPECT().UpdateReplicateConfiguration(mock.Anything, mock.Anything).Return(nil)
msg := message.NewAlterReplicateConfigMessageBuilderV2().
WithHeader(&message.AlterReplicateConfigMessageHeader{
ReplicateConfiguration: cfg,
}).
WithBody(&message.AlterReplicateConfigMessageBody{}).
WithBroadcast([]string{"v1"}).
MustBuildBroadcast()
assert.NoError(t, registry.CallMessageAckCallback(context.Background(), msg, map[string]*message.AppendResult{}))
}
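
The secondary-path assertions above reduce to a wait-for-consistency loop: the update request on a secondary cluster blocks until the configuration replicated from primary matches the requested one, or the caller's context expires. The sketch below is illustrative only and is not the streamingcoord API; waitForReplicatedConfig and the string-valued configuration are hypothetical stand-ins for the real protobuf types.

// Illustrative sketch, not the streamingcoord API: waitForReplicatedConfig and the
// string configuration are hypothetical stand-ins for the real protobuf types.
package main

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"
)

// waitForReplicatedConfig polls the latest replicated configuration until it equals
// want or the context is done, mirroring the blocking behavior tested above.
func waitForReplicatedConfig(ctx context.Context, latest func() string, want string) error {
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		if latest() == want {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}
}

func main() {
	var applied atomic.Value
	applied.Store("old")
	// Simulate the primary pushing the new configuration after 50ms.
	go func() {
		time.Sleep(50 * time.Millisecond)
		applied.Store("new")
	}()

	// A short deadline expires before the configuration converges, like the first call in the test.
	short, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	fmt.Println(waitForReplicatedConfig(short, func() string { return applied.Load().(string) }, "new"))

	// With a generous deadline the call returns nil once the configuration matches.
	long, cancel2 := context.WithTimeout(context.Background(), time.Second)
	defer cancel2()
	fmt.Println(waitForReplicatedConfig(long, func() string { return applied.Load().(string) }, "new"))
}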

View File

@ -19,6 +19,8 @@ import (
)
func TestBroadcastService(t *testing.T) {
broadcast.ResetBroadcaster()
fb := syncutil.NewFuture[broadcaster.Broadcaster]()
mba := mock_broadcaster.NewMockBroadcastAPI(t)
mb := mock_broadcaster.NewMockBroadcaster(t)

View File

@ -374,7 +374,10 @@ func newImmutableTxnMesasgeFromWAL(
if err != nil {
return nil, err
}
// we don't need to modify the begin message's timetick, but set all the timetick of body messages.
// the begin message will also be replicated, so we need to set its timetick and last confirmed message id from the commit message.
var beginImmutable ImmutableMessage = begin
beginImmutable = beginImmutable.(*specializedImmutableMessageImpl[*BeginTxnMessageHeader, *BeginTxnMessageBody]).cloneForTxnBody(commit.TimeTick(), commit.LastConfirmedMessageID())
for idx, m := range body {
body[idx] = m.(*immutableMessageImpl).cloneForTxnBody(commit.TimeTick(), commit.LastConfirmedMessageID())
}
@ -385,7 +388,7 @@ func newImmutableTxnMesasgeFromWAL(
IntoImmutableMessage(commit.MessageID())
return &immutableTxnMessageImpl{
immutableMessageImpl: *immutableMessage.(*immutableMessageImpl),
begin: begin,
begin: MustAsImmutableBeginTxnMessageV2(beginImmutable),
messages: body,
commit: commit,
}, nil
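
The hunk above keeps the original begin message but restamps it with the commit's timetick and last confirmed message ID before it is packed into the immutable txn message. The snippet below is a self-contained illustration of that rule using a hypothetical msg struct rather than the real ImmutableMessage types: every member of a committed txn, begin included, ends up carrying the commit timetick, so a replicating consumer never sees a timetick go backwards.

// Illustrative only; msg and stampWithCommit are hypothetical stand-ins for the real
// ImmutableMessage / cloneForTxnBody machinery shown in the diff above.
package main

import "fmt"

type msg struct {
	kind     string
	timeTick uint64
}

// stampWithCommit returns copies of the begin and body messages carrying the commit's timetick.
func stampWithCommit(begin msg, body []msg, commit msg) (msg, []msg) {
	begin.timeTick = commit.timeTick
	stamped := make([]msg, len(body))
	for i, m := range body {
		m.timeTick = commit.timeTick
		stamped[i] = m
	}
	return begin, stamped
}

func main() {
	begin := msg{kind: "begin", timeTick: 5}
	body := []msg{{kind: "insert", timeTick: 6}, {kind: "delete", timeTick: 7}}
	commit := msg{kind: "commit", timeTick: 9}
	b, bs := stampWithCommit(begin, body, commit)
	fmt.Println(b, bs) // all replicated messages now carry timetick 9
}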

View File

@ -83,73 +83,22 @@ func NewKeyLock[K comparable]() *KeyLock[K] {
return &keyLock
}
// Lock acquires a write lock for a given key.
func (k *KeyLock[K]) Lock(key K) {
k.keyLocksMutex.Lock()
// update the key map
if keyLock, ok := k.refLocks[key]; ok {
keyLock.ref()
k.keyLocksMutex.Unlock()
keyLock.mutex.Lock()
} else {
obj, err := refLockPoolPool.BorrowObject(ctx)
if err != nil {
log.Ctx(ctx).Error("BorrowObject failed", zap.Error(err))
k.keyLocksMutex.Unlock()
return
}
newKLock := obj.(*RefLock)
// newKLock := newRefLock()
newKLock.mutex.Lock()
k.refLocks[key] = newKLock
newKLock.ref()
k.keyLocksMutex.Unlock()
return
}
}
func (k *KeyLock[K]) TryLock(key K) bool {
k.keyLocksMutex.Lock()
// update the key map
if keyLock, ok := k.refLocks[key]; ok {
keyLock.ref()
k.keyLocksMutex.Unlock()
locked := keyLock.mutex.TryLock()
if !locked {
k.keyLocksMutex.Lock()
keyLock.unref()
if keyLock.refCounter == 0 {
_ = refLockPoolPool.ReturnObject(ctx, keyLock)
delete(k.refLocks, key)
}
k.keyLocksMutex.Unlock()
}
return locked
} else {
obj, err := refLockPoolPool.BorrowObject(ctx)
if err != nil {
log.Ctx(ctx).Error("BorrowObject failed", zap.Error(err))
k.keyLocksMutex.Unlock()
return false
}
newKLock := obj.(*RefLock)
// newKLock := newRefLock()
locked := newKLock.mutex.TryLock()
if !locked {
_ = refLockPoolPool.ReturnObject(ctx, newKLock)
k.keyLocksMutex.Unlock()
return false
}
k.refLocks[key] = newKLock
newKLock.ref()
k.keyLocksMutex.Unlock()
_ = k.tryLockInternal(key, func(mutex *sync.RWMutex) bool {
mutex.Lock()
return true
}
})
}
// TryLock attempts to acquire a write lock for a given key without blocking.
func (k *KeyLock[K]) TryLock(key K) bool {
return k.tryLockInternal(key, func(mutex *sync.RWMutex) bool {
return mutex.TryLock()
})
}
// Unlock releases a lock for a given key.
func (k *KeyLock[K]) Unlock(lockedKey K) {
k.keyLocksMutex.Lock()
defer k.keyLocksMutex.Unlock()
@ -166,40 +115,30 @@ func (k *KeyLock[K]) Unlock(lockedKey K) {
keyLock.mutex.Unlock()
}
// RLock acquires a read lock for a given key.
func (k *KeyLock[K]) RLock(key K) {
k.keyLocksMutex.Lock()
// update the key map
if keyLock, ok := k.refLocks[key]; ok {
keyLock.ref()
k.keyLocksMutex.Unlock()
keyLock.mutex.RLock()
} else {
obj, err := refLockPoolPool.BorrowObject(ctx)
if err != nil {
log.Ctx(ctx).Error("BorrowObject failed", zap.Error(err))
k.keyLocksMutex.Unlock()
return
}
newKLock := obj.(*RefLock)
// newKLock := newRefLock()
newKLock.mutex.RLock()
k.refLocks[key] = newKLock
newKLock.ref()
k.keyLocksMutex.Unlock()
return
}
_ = k.tryLockInternal(key, func(mutex *sync.RWMutex) bool {
mutex.RLock()
return true
})
}
// TryRLock attempts to acquire a read lock for a given key without blocking.
func (k *KeyLock[K]) TryRLock(key K) bool {
return k.tryLockInternal(key, func(mutex *sync.RWMutex) bool {
return mutex.TryRLock()
})
}
// tryLockInternal is the shared helper that resolves the per-key mutex and tries to lock it with the given tryLocker function.
func (k *KeyLock[K]) tryLockInternal(key K, tryLocker func(mutex *sync.RWMutex) bool) bool {
k.keyLocksMutex.Lock()
// update the key map
if keyLock, ok := k.refLocks[key]; ok {
keyLock.ref()
k.keyLocksMutex.Unlock()
locked := keyLock.mutex.TryRLock()
locked := tryLocker(&keyLock.mutex)
if !locked {
k.keyLocksMutex.Lock()
keyLock.unref()
@ -218,8 +157,7 @@ func (k *KeyLock[K]) TryRLock(key K) bool {
return false
}
newKLock := obj.(*RefLock)
// newKLock := newRefLock()
locked := newKLock.mutex.TryRLock()
locked := tryLocker(&newKLock.mutex)
if !locked {
_ = refLockPoolPool.ReturnObject(ctx, newKLock)
k.keyLocksMutex.Unlock()
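
The refactor above folds the four public entry points (Lock, TryLock, RLock, TryRLock) into a single tryLockInternal helper parameterized by a tryLocker callback. The sketch below shows that consolidation pattern in a self-contained form; it deliberately omits the refLock pool and reference counting that the real KeyLock keeps for cleanup, so it is a simplified model, not a drop-in replacement. Note that sync.RWMutex.TryLock and TryRLock require Go 1.18+.

// Simplified sketch of the consolidation pattern (no refLock pool, no reference
// counting): every entry point delegates to one helper that resolves the per-key
// mutex and applies a tryLocker function to it.
package main

import (
	"fmt"
	"sync"
)

type KeyLock[K comparable] struct {
	mu    sync.Mutex
	locks map[K]*sync.RWMutex
}

func NewKeyLock[K comparable]() *KeyLock[K] {
	return &KeyLock[K]{locks: make(map[K]*sync.RWMutex)}
}

// tryLockInternal resolves (or creates) the mutex for key and applies tryLocker to it.
func (k *KeyLock[K]) tryLockInternal(key K, tryLocker func(*sync.RWMutex) bool) bool {
	k.mu.Lock()
	m, ok := k.locks[key]
	if !ok {
		m = &sync.RWMutex{}
		k.locks[key] = m
	}
	k.mu.Unlock()
	return tryLocker(m)
}

// Lock acquires a write lock for key, blocking until it is available.
func (k *KeyLock[K]) Lock(key K) {
	_ = k.tryLockInternal(key, func(m *sync.RWMutex) bool { m.Lock(); return true })
}

// TryLock attempts to acquire a write lock for key without blocking.
func (k *KeyLock[K]) TryLock(key K) bool {
	return k.tryLockInternal(key, func(m *sync.RWMutex) bool { return m.TryLock() })
}

// RLock acquires a read lock for key, blocking until it is available.
func (k *KeyLock[K]) RLock(key K) {
	_ = k.tryLockInternal(key, func(m *sync.RWMutex) bool { m.RLock(); return true })
}

// TryRLock attempts to acquire a read lock for key without blocking.
func (k *KeyLock[K]) TryRLock(key K) bool {
	return k.tryLockInternal(key, func(m *sync.RWMutex) bool { return m.TryRLock() })
}

// Unlock releases the write lock for key (the real KeyLock also unrefs and frees the entry).
func (k *KeyLock[K]) Unlock(key K) {
	k.mu.Lock()
	m := k.locks[key]
	k.mu.Unlock()
	m.Unlock()
}

func main() {
	kl := NewKeyLock[string]()
	kl.Lock("a")
	fmt.Println(kl.TryLock("a")) // false: "a" is write-locked
	kl.Unlock("a")
	fmt.Println(kl.TryLock("a")) // true: lock is free again
}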