diff --git a/internal/cdc/replication/replicatemanager/replicate_manager.go b/internal/cdc/replication/replicatemanager/replicate_manager.go index 226dcf6ba5..d10de54eac 100644 --- a/internal/cdc/replication/replicatemanager/replicate_manager.go +++ b/internal/cdc/replication/replicatemanager/replicate_manager.go @@ -18,12 +18,12 @@ package replicatemanager import ( "context" - "fmt" "strings" "github.com/samber/lo" "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/metastore/kv/streamingcoord" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" @@ -46,10 +46,6 @@ func NewReplicateManager() *replicateManager { } } -func bindReplicatorKey(replicateInfo *streamingpb.ReplicatePChannelMeta) string { - return fmt.Sprintf("%s_%s", replicateInfo.GetSourceChannelName(), replicateInfo.GetTargetChannelName()) -} - func (r *replicateManager) CreateReplicator(replicateInfo *streamingpb.ReplicatePChannelMeta) { logger := log.With( zap.String("sourceChannel", replicateInfo.GetSourceChannelName()), @@ -60,7 +56,7 @@ func (r *replicateManager) CreateReplicator(replicateInfo *streamingpb.Replicate // current cluster is not source cluster, skip create replicator return } - replicatorKey := bindReplicatorKey(replicateInfo) + replicatorKey := streamingcoord.BuildReplicatePChannelMetaKey(replicateInfo) _, ok := r.replicators[replicatorKey] if ok { logger.Debug("replicator already exists, skip create replicator") @@ -74,7 +70,7 @@ func (r *replicateManager) CreateReplicator(replicateInfo *streamingpb.Replicate } func (r *replicateManager) RemoveOutOfTargetReplicators(targetReplicatePChannels []*streamingpb.ReplicatePChannelMeta) { - targets := lo.KeyBy(targetReplicatePChannels, bindReplicatorKey) + targets := lo.KeyBy(targetReplicatePChannels, streamingcoord.BuildReplicatePChannelMetaKey) for replicatorKey, replicator := range r.replicators { if pchannelMeta, ok := targets[replicatorKey]; !ok { replicator.StopReplicate() diff --git a/internal/cdc/replication/replicatemanager/replicate_manager_test.go b/internal/cdc/replication/replicatemanager/replicate_manager_test.go index 8157ab6e49..cb0fe447fd 100644 --- a/internal/cdc/replication/replicatemanager/replicate_manager_test.go +++ b/internal/cdc/replication/replicatemanager/replicate_manager_test.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/cdc/cluster" "github.com/milvus-io/milvus/internal/cdc/resource" + "github.com/milvus-io/milvus/internal/metastore/kv/streamingcoord" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) @@ -60,7 +61,8 @@ func TestReplicateManager_CreateReplicator(t *testing.T) { // Verify replicator was created assert.Equal(t, 1, len(manager.replicators)) - replicator, exists := manager.replicators["test-source-channel-1_test-target-channel-1"] + key := streamingcoord.BuildReplicatePChannelMetaKey(replicateInfo) + replicator, exists := manager.replicators[key] assert.True(t, exists) assert.NotNil(t, replicator) @@ -77,12 +79,13 @@ func TestReplicateManager_CreateReplicator(t *testing.T) { // Verify second replicator was created assert.Equal(t, 2, len(manager.replicators)) - replicator2, exists := manager.replicators["test-source-channel-2_test-target-channel-2"] + key2 := streamingcoord.BuildReplicatePChannelMetaKey(replicateInfo2) + replicator2, exists := manager.replicators[key2] assert.True(t, exists) 
assert.NotNil(t, replicator2) // Verify first replicator still exists - replicator1, exists := manager.replicators["test-source-channel-1_test-target-channel-1"] + replicator1, exists := manager.replicators[key] assert.True(t, exists) assert.NotNil(t, replicator1) } diff --git a/internal/cdc/replication/replicatestream/msg_queue.go b/internal/cdc/replication/replicatestream/msg_queue.go index cecdd03985..786ba47412 100644 --- a/internal/cdc/replication/replicatestream/msg_queue.go +++ b/internal/cdc/replication/replicatestream/msg_queue.go @@ -31,11 +31,11 @@ type MsgQueue interface { // (via CleanupConfirmedMessages) or ctx is canceled. Enqueue(ctx context.Context, msg message.ImmutableMessage) error - // Dequeue returns the next message from the current read cursor and advances + // ReadNext returns the next message from the current read cursor and advances // the cursor by one. It does NOT delete the message from the queue storage. // Blocks when there are no readable messages (i.e., cursor is at tail) until // a new message is Enqueued or ctx is canceled. - Dequeue(ctx context.Context) (message.ImmutableMessage, error) + ReadNext(ctx context.Context) (message.ImmutableMessage, error) // SeekToHead moves the read cursor to the first not-yet-deleted message. SeekToHead() @@ -116,8 +116,8 @@ func (q *msgQueue) Enqueue(ctx context.Context, msg message.ImmutableMessage) er return nil } -// Dequeue returns the next message at the read cursor. Does not delete it. -func (q *msgQueue) Dequeue(ctx context.Context) (message.ImmutableMessage, error) { +// ReadNext returns the next message at the read cursor. Does not delete it. +func (q *msgQueue) ReadNext(ctx context.Context) (message.ImmutableMessage, error) { q.mu.Lock() defer q.mu.Unlock() diff --git a/internal/cdc/replication/replicatestream/msg_queue_test.go b/internal/cdc/replication/replicatestream/msg_queue_test.go index f08a1a3f91..8c10251a97 100644 --- a/internal/cdc/replication/replicatestream/msg_queue_test.go +++ b/internal/cdc/replication/replicatestream/msg_queue_test.go @@ -43,7 +43,7 @@ func TestMsgQueue_BasicOperations(t *testing.T) { assert.Equal(t, 1, queue.Len()) // Test dequeue - dequeuedMsg, err := queue.Dequeue(ctx) + dequeuedMsg, err := queue.ReadNext(ctx) assert.NoError(t, err) assert.Equal(t, msg1, dequeuedMsg) assert.Equal(t, 1, queue.Len()) // Length doesn't change after dequeue @@ -89,7 +89,7 @@ func TestMsgQueue_DequeueBlocking(t *testing.T) { ctxWithTimeout, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() - _, err := queue.Dequeue(ctxWithTimeout) + _, err := queue.ReadNext(ctxWithTimeout) assert.Error(t, err) // Context timeout will cause context.Canceled error, not DeadlineExceeded assert.Equal(t, context.Canceled, err) @@ -112,7 +112,7 @@ func TestMsgQueue_SeekToHead(t *testing.T) { assert.NoError(t, err) // Dequeue first message - dequeuedMsg, err := queue.Dequeue(ctx) + dequeuedMsg, err := queue.ReadNext(ctx) assert.NoError(t, err) assert.Equal(t, msg1, dequeuedMsg) @@ -120,7 +120,7 @@ func TestMsgQueue_SeekToHead(t *testing.T) { queue.SeekToHead() // Should be able to dequeue first message again - dequeuedMsg, err = queue.Dequeue(ctx) + dequeuedMsg, err = queue.ReadNext(ctx) assert.NoError(t, err) assert.Equal(t, msg1, dequeuedMsg) } @@ -155,7 +155,7 @@ func TestMsgQueue_CleanupConfirmedMessages(t *testing.T) { assert.Equal(t, msg2, cleanedMessages[1]) // First two messages should be removed - dequeuedMsg, err := queue.Dequeue(ctx) + dequeuedMsg, err := queue.ReadNext(ctx) 
assert.NoError(t, err) assert.Equal(t, msg3, dequeuedMsg) // Only msg3 remains } @@ -181,7 +181,7 @@ func TestMsgQueue_CleanupWithReadCursor(t *testing.T) { assert.NoError(t, err) // Dequeue first message (advance read cursor) - dequeuedMsg, err := queue.Dequeue(ctx) + dequeuedMsg, err := queue.ReadNext(ctx) assert.NoError(t, err) assert.Equal(t, msg1, dequeuedMsg) assert.Equal(t, 1, queue.readIdx) @@ -256,7 +256,7 @@ func TestMsgQueue_ConcurrentOperations(t *testing.T) { go func() { defer wg.Done() for i := 0; i < numMessages; i++ { - dequeuedMsg, err := queue.Dequeue(ctx) + dequeuedMsg, err := queue.ReadNext(ctx) assert.NoError(t, err) cleanedMessages := queue.CleanupConfirmedMessages(dequeuedMsg.TimeTick()) assert.Equal(t, 1, len(cleanedMessages)) diff --git a/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go b/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go index 60ada98061..e639155a3f 100644 --- a/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go +++ b/internal/cdc/replication/replicatestream/replicate_stream_client_impl.go @@ -19,7 +19,6 @@ package replicatestream import ( "context" "fmt" - "sync" "time" "github.com/cenkalti/backoff/v4" @@ -47,9 +46,9 @@ type replicateStreamClient struct { pendingMessages MsgQueue metrics ReplicateMetrics - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + finishedCh chan struct{} } // NewReplicateStreamClient creates a new ReplicateStreamClient. @@ -64,6 +63,7 @@ func NewReplicateStreamClient(ctx context.Context, replicateInfo *streamingpb.Re metrics: NewReplicateMetrics(replicateInfo), ctx: ctx1, cancel: cancel, + finishedCh: make(chan struct{}), } rs.metrics.OnConnect() @@ -80,6 +80,7 @@ func (r *replicateStreamClient) startInternal() { defer func() { r.metrics.OnDisconnect() logger.Info("replicate stream client closed") + close(r.finishedCh) }() backoff := backoff.NewExponentialBackOff() @@ -88,62 +89,49 @@ func (r *replicateStreamClient) startInternal() { backoff.MaxElapsedTime = 0 backoff.Reset() - disconnect := func(stopCh chan struct{}, err error) (reconnect bool) { - r.metrics.OnDisconnect() - close(stopCh) - r.client.CloseSend() - r.wg.Wait() - time.Sleep(backoff.NextBackOff()) - log.Warn("restart replicate stream client", zap.Error(err)) - return err != nil - } - for { + // Create a local context for this connection that can be canceled + // when we need to stop the send/recv loops + connCtx, connCancel := context.WithCancel(r.ctx) + + milvusClient, err := resource.Resource().ClusterClient().CreateMilvusClient(connCtx, r.replicateInfo.GetTargetCluster()) + if err != nil { + logger.Warn("create milvus client failed, retry...", zap.Error(err)) + time.Sleep(backoff.NextBackOff()) + continue + } + client, err := milvusClient.CreateReplicateStream(connCtx) + if err != nil { + logger.Warn("create milvus replicate stream failed, retry...", zap.Error(err)) + time.Sleep(backoff.NextBackOff()) + continue + } + logger.Info("replicate stream client service started") + + // reset client and pending messages + r.client = client + r.pendingMessages.SeekToHead() + + sendCh := r.startSendLoop(connCtx) + recvCh := r.startRecvLoop(connCtx) + select { case <-r.ctx.Done(): + case <-sendCh: + case <-recvCh: + } + + connCancel() // Cancel the connection context + <-sendCh + <-recvCh // wait for send/recv loops to exit + + if r.ctx.Err() != nil { + logger.Info("close replicate stream client by ctx done") return - default: - 
milvusClient, err := resource.Resource().ClusterClient().CreateMilvusClient(r.ctx, r.replicateInfo.GetTargetCluster()) - if err != nil { - logger.Warn("create milvus client failed, retry...", zap.Error(err)) - time.Sleep(backoff.NextBackOff()) - continue - } - client, err := milvusClient.CreateReplicateStream(r.ctx) - if err != nil { - logger.Warn("create milvus replicate stream failed, retry...", zap.Error(err)) - time.Sleep(backoff.NextBackOff()) - continue - } - logger.Info("replicate stream client service started") - - // reset client and pending messages - if oldClient := r.client; oldClient != nil { - r.metrics.OnReconnect() - } - r.client = client - r.pendingMessages.SeekToHead() - - stopCh := make(chan struct{}) - sendErrCh := r.startSendLoop(stopCh) - recvErrCh := r.startRecvLoop(stopCh) - - select { - case <-r.ctx.Done(): - r.client.CloseSend() - r.wg.Wait() - return - case err := <-sendErrCh: - reconnect := disconnect(stopCh, err) - if !reconnect { - return - } - case err := <-recvErrCh: - reconnect := disconnect(stopCh, err) - if !reconnect { - return - } - } + } else { + logger.Warn("restart replicate stream client") + r.metrics.OnDisconnect() + time.Sleep(backoff.NextBackOff()) } } } @@ -160,41 +148,43 @@ func (r *replicateStreamClient) Replicate(msg message.ImmutableMessage) error { } } -func (r *replicateStreamClient) startSendLoop(stopCh <-chan struct{}) <-chan error { - errCh := make(chan error, 1) - r.wg.Add(1) +func (r *replicateStreamClient) startSendLoop(ctx context.Context) <-chan struct{} { + ch := make(chan struct{}) go func() { - defer r.wg.Done() - errCh <- r.sendLoop(stopCh) + _ = r.sendLoop(ctx) + close(ch) }() - return errCh + return ch } -func (r *replicateStreamClient) startRecvLoop(stopCh <-chan struct{}) <-chan error { - errCh := make(chan error, 1) - r.wg.Add(1) +func (r *replicateStreamClient) startRecvLoop(ctx context.Context) <-chan struct{} { + ch := make(chan struct{}) go func() { - defer r.wg.Done() - errCh <- r.recvLoop(stopCh) + _ = r.recvLoop(ctx) + close(ch) }() - return errCh + return ch } -func (r *replicateStreamClient) sendLoop(stopCh <-chan struct{}) error { +func (r *replicateStreamClient) sendLoop(ctx context.Context) (err error) { logger := log.With( zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()), zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()), ) + defer func() { + if err != nil { + logger.Warn("send loop closed by unexpected error", zap.Error(err)) + } else { + logger.Info("send loop closed") + } + r.client.CloseSend() + }() for { select { - case <-r.ctx.Done(): - logger.Info("send loop closed by ctx done") - return nil - case <-stopCh: - logger.Info("send loop closed by stopCh") + case <-ctx.Done(): return nil default: - msg, err := r.pendingMessages.Dequeue(r.ctx) + msg, err := r.pendingMessages.ReadNext(ctx) if err != nil { // context canceled, return nil return nil @@ -211,11 +201,7 @@ func (r *replicateStreamClient) sendLoop(stopCh <-chan struct{}) error { // send txn messages err = txnMsg.RangeOver(func(msg message.ImmutableMessage) error { - err = r.sendMessage(msg) - if err != nil { - return err - } - return nil + return r.sendMessage(msg) }) if err != nil { return err @@ -227,11 +213,11 @@ func (r *replicateStreamClient) sendLoop(stopCh <-chan struct{}) error { if err != nil { return err } - continue - } - err = r.sendMessage(msg) - if err != nil { - return err + } else { + err = r.sendMessage(msg) + if err != nil { + return err + } } } } @@ -266,18 +252,21 @@ func (r 
*replicateStreamClient) sendMessage(msg message.ImmutableMessage) (err e return r.client.Send(req) } -func (r *replicateStreamClient) recvLoop(stopCh <-chan struct{}) error { +func (r *replicateStreamClient) recvLoop(ctx context.Context) (err error) { logger := log.With( zap.String("sourceChannel", r.replicateInfo.GetSourceChannelName()), zap.String("targetChannel", r.replicateInfo.GetTargetChannelName()), ) + defer func() { + if err != nil { + logger.Warn("recv loop closed by unexpected error", zap.Error(err)) + } else { + logger.Info("recv loop closed") + } + }() for { select { - case <-r.ctx.Done(): - logger.Info("recv loop closed by ctx done") - return nil - case <-stopCh: - logger.Info("recv loop closed by stopCh") + case <-ctx.Done(): return nil default: resp, err := r.client.Recv() @@ -331,5 +320,5 @@ func (r *replicateStreamClient) handleAlterReplicateConfigMessage(msg message.Im func (r *replicateStreamClient) Close() { r.cancel() - r.wg.Wait() + <-r.finishedCh } diff --git a/internal/metastore/kv/streamingcoord/kv_catalog.go b/internal/metastore/kv/streamingcoord/kv_catalog.go index 2ac153bd48..e38198b0ae 100644 --- a/internal/metastore/kv/streamingcoord/kv_catalog.go +++ b/internal/metastore/kv/streamingcoord/kv_catalog.go @@ -238,6 +238,12 @@ func (c *catalog) ListReplicatePChannels(ctx context.Context) ([]*streamingpb.Re return infos, nil } +func BuildReplicatePChannelMetaKey(meta *streamingpb.ReplicatePChannelMeta) string { + targetClusterID := meta.GetTargetCluster().GetClusterId() + sourceChannelName := meta.GetSourceChannelName() + return buildReplicatePChannelPath(targetClusterID, sourceChannelName) +} + func buildReplicatePChannelPath(targetClusterID, sourceChannelName string) string { return fmt.Sprintf("%s%s-%s", ReplicatePChannelMetaPrefix, targetClusterID, sourceChannelName) } diff --git a/internal/streamingcoord/server/service/assignment.go b/internal/streamingcoord/server/service/assignment.go index c067238c4a..5fd33f5c27 100644 --- a/internal/streamingcoord/server/service/assignment.go +++ b/internal/streamingcoord/server/service/assignment.go @@ -107,7 +107,9 @@ func (s *assignmentServiceImpl) validateReplicateConfiguration(ctx context.Conte // validate the configuration itself currentClusterID := paramtable.Get().CommonCfg.ClusterPrefix.GetValue() - validator := replicateutil.NewReplicateConfigValidator(config, currentClusterID, pchannels) + currentConfig := latestAssignment.ReplicateConfiguration + incomingConfig := config + validator := replicateutil.NewReplicateConfigValidator(incomingConfig, currentConfig, currentClusterID, pchannels) if err := validator.Validate(); err != nil { log.Ctx(ctx).Warn("UpdateReplicateConfiguration fail", zap.Error(err)) return nil, err diff --git a/pkg/util/replicateutil/config_validator.go b/pkg/util/replicateutil/config_validator.go index 64786b9d77..94632e8332 100644 --- a/pkg/util/replicateutil/config_validator.go +++ b/pkg/util/replicateutil/config_validator.go @@ -19,6 +19,7 @@ package replicateutil import ( "fmt" "net/url" + "slices" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -29,26 +30,28 @@ type ReplicateConfigValidator struct { currentClusterID string currentPChannels []string clusterMap map[string]*commonpb.MilvusCluster - config *commonpb.ReplicateConfiguration + incomingConfig *commonpb.ReplicateConfiguration + currentConfig *commonpb.ReplicateConfiguration } // NewReplicateConfigValidator creates a new validator instance with the given configuration -func NewReplicateConfigValidator(config 
*commonpb.ReplicateConfiguration, currentClusterID string, currentPChannels []string) *ReplicateConfigValidator { +func NewReplicateConfigValidator(incomingConfig, currentConfig *commonpb.ReplicateConfiguration, currentClusterID string, currentPChannels []string) *ReplicateConfigValidator { validator := &ReplicateConfigValidator{ currentClusterID: currentClusterID, currentPChannels: currentPChannels, clusterMap: make(map[string]*commonpb.MilvusCluster), - config: config, + incomingConfig: incomingConfig, + currentConfig: currentConfig, } return validator } // Validate performs all validation checks on the configuration func (v *ReplicateConfigValidator) Validate() error { - if v.config == nil { + if v.incomingConfig == nil { return fmt.Errorf("config cannot be nil") } - clusters := v.config.GetClusters() + clusters := v.incomingConfig.GetClusters() if len(clusters) == 0 { return fmt.Errorf("clusters list cannot be empty") } @@ -59,13 +62,19 @@ func (v *ReplicateConfigValidator) Validate() error { if err := v.validateRelevance(); err != nil { return err } - topologies := v.config.GetCrossClusterTopology() + topologies := v.incomingConfig.GetCrossClusterTopology() if err := v.validateTopologyEdgeUniqueness(topologies); err != nil { return err } if err := v.validateTopologyTypeConstraint(topologies); err != nil { return err } + // If currentConfig is provided, perform comparison validation + if v.currentConfig != nil { + if err := v.validateConfigComparison(); err != nil { + return err + } + } return nil } @@ -73,6 +82,7 @@ func (v *ReplicateConfigValidator) Validate() error { func (v *ReplicateConfigValidator) validateClusterBasic(clusters []*commonpb.MilvusCluster) error { var expectedPchannelCount int var firstClusterID string + uriSet := make(map[string]string) for i, cluster := range clusters { if cluster == nil { return fmt.Errorf("cluster at index %d is nil", i) @@ -98,6 +108,11 @@ func (v *ReplicateConfigValidator) validateClusterBasic(clusters []*commonpb.Mil if err != nil { return fmt.Errorf("cluster '%s' has invalid URI format: '%s'", clusterID, uri) } + // Check URI uniqueness + if existingClusterID, exists := uriSet[uri]; exists { + return fmt.Errorf("duplicate URI found: '%s' is used by both cluster '%s' and cluster '%s'", uri, existingClusterID, clusterID) + } + uriSet[uri] = clusterID // pchannels validation: non-empty pchannels := cluster.GetPchannels() if len(pchannels) == 0 { @@ -112,10 +127,6 @@ func (v *ReplicateConfigValidator) validateClusterBasic(clusters []*commonpb.Mil if pchannelSet[pchannel] { return fmt.Errorf("cluster '%s' has duplicate pchannel: '%s'", clusterID, pchannel) } - // Validate that pchannel starts with clusterID as prefix - if !strings.HasPrefix(pchannel, clusterID) { - return fmt.Errorf("cluster '%s' has pchannel '%s' that does not start with clusterID as prefix", clusterID, pchannel) - } pchannelSet[pchannel] = true } // pchannels count consistency across all clusters @@ -225,6 +236,59 @@ func (v *ReplicateConfigValidator) validateTopologyTypeConstraint(topologies []* return nil } +// validateConfigComparison validates that for clusters with the same ClusterID, +// no cluster attributes can be changed +func (v *ReplicateConfigValidator) validateConfigComparison() error { + currentClusters := v.currentConfig.GetClusters() + currentClusterMap := make(map[string]*commonpb.MilvusCluster) + + // Build current cluster map + for _, cluster := range currentClusters { + if cluster != nil { + currentClusterMap[cluster.GetClusterId()] = cluster + } + } + + // 
Compare each incoming cluster with current cluster + for _, incomingCluster := range v.incomingConfig.GetClusters() { + clusterID := incomingCluster.GetClusterId() + currentCluster, exists := currentClusterMap[clusterID] + if exists { + // Cluster exists in current config, validate that only ConnectionParam can change + if err := v.validateClusterConsistency(currentCluster, incomingCluster); err != nil { + return err + } + } + // If cluster doesn't exist in current config, it's a new cluster, which is allowed + } + + return nil +} + +// validateClusterConsistency validates that no cluster attributes can be changed between current and incoming cluster +func (v *ReplicateConfigValidator) validateClusterConsistency(current, incoming *commonpb.MilvusCluster) error { + // Check Pchannels consistency + if !slices.Equal(current.GetPchannels(), incoming.GetPchannels()) { + return fmt.Errorf("cluster '%s' pchannels cannot be changed: current=%v, incoming=%v", + current.GetClusterId(), current.GetPchannels(), incoming.GetPchannels()) + } + + // Check ConnectionParam consistency + currentConn := current.GetConnectionParam() + incomingConn := incoming.GetConnectionParam() + + if currentConn.GetUri() != incomingConn.GetUri() { + return fmt.Errorf("cluster '%s' connection_param.uri cannot be changed: current=%s, incoming=%s", + current.GetClusterId(), currentConn.GetUri(), incomingConn.GetUri()) + } + if currentConn.GetToken() != incomingConn.GetToken() { + return fmt.Errorf("cluster '%s' connection_param.token cannot be changed", + current.GetClusterId()) + } + + return nil +} + func equalIgnoreOrder(a, b []string) bool { if len(a) != len(b) { return false diff --git a/pkg/util/replicateutil/config_validator_test.go b/pkg/util/replicateutil/config_validator_test.go index bec9afb396..d1848cf886 100644 --- a/pkg/util/replicateutil/config_validator_test.go +++ b/pkg/util/replicateutil/config_validator_test.go @@ -34,7 +34,7 @@ func createValidValidatorConfig() *commonpb.ReplicateConfiguration { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1", "cluster-1-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, { ClusterId: "cluster-2", @@ -42,7 +42,7 @@ func createValidValidatorConfig() *commonpb.ReplicateConfiguration { Uri: "localhost:19531", Token: "test-token", }, - Pchannels: []string{"cluster-2-channel-1", "cluster-2-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, }, CrossClusterTopology: []*commonpb.CrossClusterTopology{ @@ -64,7 +64,7 @@ func createStarTopologyConfig() *commonpb.ReplicateConfiguration { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"center-cluster-channel-1", "center-cluster-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, { ClusterId: "leaf-cluster-1", @@ -72,7 +72,7 @@ func createStarTopologyConfig() *commonpb.ReplicateConfiguration { Uri: "localhost:19531", Token: "test-token", }, - Pchannels: []string{"leaf-cluster-1-channel-1", "leaf-cluster-1-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, { ClusterId: "leaf-cluster-2", @@ -80,7 +80,7 @@ func createStarTopologyConfig() *commonpb.ReplicateConfiguration { Uri: "localhost:19532", Token: "test-token", }, - Pchannels: []string{"leaf-cluster-2-channel-1", "leaf-cluster-2-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, }, CrossClusterTopology: []*commonpb.CrossClusterTopology{ @@ -98,12 +98,24 @@ func createStarTopologyConfig() *commonpb.ReplicateConfiguration { func 
TestNewReplicateConfigValidator(t *testing.T) { config := createValidValidatorConfig() - currentPChannels := []string{"cluster-1-channel-1", "cluster-1-channel-2"} + currentPChannels := []string{"channel-1", "channel-2"} - t.Run("success - create validator", func(t *testing.T) { - validator := NewReplicateConfigValidator(config, "cluster-1", currentPChannels) + t.Run("success - create validator without current config", func(t *testing.T) { + validator := NewReplicateConfigValidator(config, nil, "cluster-1", currentPChannels) assert.NotNil(t, validator) - assert.Equal(t, config, validator.config) + assert.Equal(t, config, validator.incomingConfig) + assert.Equal(t, currentPChannels, validator.currentPChannels) + assert.NotNil(t, validator.clusterMap) + assert.Equal(t, 0, len(validator.clusterMap)) // clusterMap is built during validation + assert.Nil(t, validator.currentConfig) + }) + + t.Run("success - create validator with current config", func(t *testing.T) { + currentConfig := createValidValidatorConfig() + validator := NewReplicateConfigValidator(config, currentConfig, "cluster-1", currentPChannels) + assert.NotNil(t, validator) + assert.Equal(t, config, validator.incomingConfig) + assert.Equal(t, currentConfig, validator.currentConfig) assert.Equal(t, currentPChannels, validator.currentPChannels) assert.NotNil(t, validator.clusterMap) assert.Equal(t, 0, len(validator.clusterMap)) // clusterMap is built during validation @@ -111,17 +123,27 @@ func TestNewReplicateConfigValidator(t *testing.T) { } func TestReplicateConfigValidator_Validate(t *testing.T) { - t.Run("success - valid configuration", func(t *testing.T) { + t.Run("success - valid configuration without current config", func(t *testing.T) { config := createValidValidatorConfig() - currentPChannels := []string{"cluster-1-channel-1", "cluster-1-channel-2"} - validator := NewReplicateConfigValidator(config, "cluster-1", currentPChannels) + currentPChannels := []string{"channel-1", "channel-2"} + validator := NewReplicateConfigValidator(config, nil, "cluster-1", currentPChannels) err := validator.Validate() assert.NoError(t, err) }) - t.Run("error - nil config", func(t *testing.T) { - validator := NewReplicateConfigValidator(nil, "cluster-1", []string{}) + t.Run("success - valid configuration with current config", func(t *testing.T) { + config := createValidValidatorConfig() + currentConfig := createValidValidatorConfig() + currentPChannels := []string{"channel-1", "channel-2"} + validator := NewReplicateConfigValidator(config, currentConfig, "cluster-1", currentPChannels) + + err := validator.Validate() + assert.NoError(t, err) + }) + + t.Run("error - nil incoming config", func(t *testing.T) { + validator := NewReplicateConfigValidator(nil, nil, "cluster-1", []string{}) err := validator.Validate() assert.Error(t, err) assert.Contains(t, err.Error(), "config cannot be nil") @@ -132,7 +154,7 @@ func TestReplicateConfigValidator_Validate(t *testing.T) { Clusters: []*commonpb.MilvusCluster{}, CrossClusterTopology: []*commonpb.CrossClusterTopology{}, } - validator := NewReplicateConfigValidator(config, "cluster-1", []string{}) + validator := NewReplicateConfigValidator(config, nil, "cluster-1", []string{}) err := validator.Validate() assert.Error(t, err) assert.Contains(t, err.Error(), "clusters list cannot be empty") @@ -148,7 +170,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1", "cluster-1-channel-2"}, + Pchannels: 
[]string{"channel-1", "channel-2"}, }, { ClusterId: "cluster-2", @@ -156,7 +178,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19531", Token: "test-token", }, - Pchannels: []string{"cluster-2-channel-1", "cluster-2-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, } @@ -180,7 +202,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1"}, + Pchannels: []string{"channel-1"}, }, } @@ -222,7 +244,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1"}, + Pchannels: []string{"channel-1"}, }, } @@ -240,7 +262,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { { ClusterId: "cluster-1", ConnectionParam: nil, - Pchannels: []string{"cluster-1-channel-1"}, + Pchannels: []string{"channel-1"}, }, } @@ -261,7 +283,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1"}, + Pchannels: []string{"channel-1"}, }, } @@ -282,7 +304,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "invalid-uri-format", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1"}, + Pchannels: []string{"channel-1"}, }, } @@ -324,7 +346,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"", "cluster-1-channel-2"}, + Pchannels: []string{"", "channel-2"}, }, } @@ -345,7 +367,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1", "cluster-1-channel-1"}, + Pchannels: []string{"channel-1", "channel-1"}, }, } @@ -358,27 +380,6 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { assert.Contains(t, err.Error(), "has duplicate pchannel") }) - t.Run("error - pchannel doesn't start with cluster ID", func(t *testing.T) { - clusters := []*commonpb.MilvusCluster{ - { - ClusterId: "cluster-1", - ConnectionParam: &commonpb.ConnectionParam{ - Uri: "localhost:19530", - Token: "test-token", - }, - Pchannels: []string{"wrong-prefix-channel"}, - }, - } - - validator := &ReplicateConfigValidator{ - clusterMap: make(map[string]*commonpb.MilvusCluster), - } - - err := validator.validateClusterBasic(clusters) - assert.Error(t, err) - assert.Contains(t, err.Error(), "does not start with clusterID as prefix") - }) - t.Run("error - inconsistent pchannel count", func(t *testing.T) { clusters := []*commonpb.MilvusCluster{ { @@ -387,7 +388,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1", "cluster-1-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, { ClusterId: "cluster-2", @@ -395,7 +396,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19531", Token: "test-token", }, - Pchannels: []string{"cluster-2-channel-1"}, // Only 1 channel instead of 2 + Pchannels: []string{"channel-1"}, // Only 1 channel instead of 2 }, } @@ -416,7 +417,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19530", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1"}, + Pchannels: 
[]string{"channel-1"}, }, { ClusterId: "cluster-1", // Duplicate cluster ID @@ -424,7 +425,7 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { Uri: "localhost:19531", Token: "test-token", }, - Pchannels: []string{"cluster-1-channel-1"}, + Pchannels: []string{"channel-1"}, }, } @@ -436,17 +437,46 @@ func TestReplicateConfigValidator_validateClusterBasic(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "duplicate clusterID found") }) + + t.Run("error - duplicate URI across clusters", func(t *testing.T) { + clusters := []*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1"}, + }, + { + ClusterId: "cluster-2", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", // Same URI as cluster-1 + Token: "test-token", + }, + Pchannels: []string{"channel-1"}, + }, + } + + validator := &ReplicateConfigValidator{ + clusterMap: make(map[string]*commonpb.MilvusCluster), + } + + err := validator.validateClusterBasic(clusters) + assert.Error(t, err) + assert.Contains(t, err.Error(), "duplicate URI found") + }) } func TestReplicateConfigValidator_validateRelevance(t *testing.T) { t.Run("success - current cluster included and pchannels match", func(t *testing.T) { validator := &ReplicateConfigValidator{ currentClusterID: "cluster-1", - currentPChannels: []string{"cluster-1-channel-1", "cluster-1-channel-2"}, + currentPChannels: []string{"channel-1", "channel-2"}, clusterMap: map[string]*commonpb.MilvusCluster{ "cluster-1": { ClusterId: "cluster-1", - Pchannels: []string{"cluster-1-channel-1", "cluster-1-channel-2"}, + Pchannels: []string{"channel-1", "channel-2"}, }, }, } @@ -458,11 +488,11 @@ func TestReplicateConfigValidator_validateRelevance(t *testing.T) { t.Run("error - current cluster not included", func(t *testing.T) { validator := &ReplicateConfigValidator{ currentClusterID: "cluster-1", - currentPChannels: []string{"cluster-1-channel-1"}, + currentPChannels: []string{"channel-1"}, clusterMap: map[string]*commonpb.MilvusCluster{ "cluster-2": { ClusterId: "cluster-2", - Pchannels: []string{"cluster-2-channel-1"}, + Pchannels: []string{"channel-1"}, }, }, } @@ -475,11 +505,11 @@ func TestReplicateConfigValidator_validateRelevance(t *testing.T) { t.Run("error - pchannels don't match", func(t *testing.T) { validator := &ReplicateConfigValidator{ currentClusterID: "cluster-1", - currentPChannels: []string{"cluster-1-channel-1", "cluster-1-channel-2"}, + currentPChannels: []string{"channel-1", "channel-2"}, clusterMap: map[string]*commonpb.MilvusCluster{ "cluster-1": { ClusterId: "cluster-1", - Pchannels: []string{"cluster-1-channel-1", "cluster-1-channel-3"}, // Different channels + Pchannels: []string{"channel-1", "channel-3"}, // Different channels }, }, } @@ -715,3 +745,189 @@ func TestEqualIgnoreOrder(t *testing.T) { assert.False(t, result) }) } + +func TestReplicateConfigValidator_validateConfigComparison(t *testing.T) { + // Helper function to create a config with specific clusters + createConfigWithClusters := func(clusters []*commonpb.MilvusCluster) *commonpb.ReplicateConfiguration { + return &commonpb.ReplicateConfiguration{ + Clusters: clusters, + CrossClusterTopology: []*commonpb.CrossClusterTopology{}, + } + } + + t.Run("success - no current config", func(t *testing.T) { + config := createValidValidatorConfig() + currentPChannels := []string{"channel-1", "channel-2"} + validator := 
NewReplicateConfigValidator(config, nil, "cluster-1", currentPChannels) + + err := validator.Validate() + assert.NoError(t, err) + }) + + t.Run("success - new cluster added", func(t *testing.T) { + currentConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + incomingConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + { + ClusterId: "cluster-2", // New cluster + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19531", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + validator := NewReplicateConfigValidator(incomingConfig, currentConfig, "cluster-1", []string{"channel-1", "channel-2"}) + err := validator.Validate() + assert.NoError(t, err) + }) + + t.Run("error - ConnectionParam changed", func(t *testing.T) { + currentConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "old-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + incomingConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "new-token", // Token changed - should fail + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + // Test the config comparison validation directly + validator := &ReplicateConfigValidator{ + incomingConfig: incomingConfig, + currentConfig: currentConfig, + } + err := validator.validateConfigComparison() + assert.Error(t, err) + assert.Contains(t, err.Error(), "connection_param.token cannot be changed") + }) + + t.Run("error - pchannels changed", func(t *testing.T) { + currentConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + incomingConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-3"}, // Different pchannels + }, + }) + + // Test the config comparison validation directly + validator := &ReplicateConfigValidator{ + incomingConfig: incomingConfig, + currentConfig: currentConfig, + } + err := validator.validateConfigComparison() + assert.Error(t, err) + assert.Contains(t, err.Error(), "pchannels cannot be changed") + }) + + t.Run("error - ConnectionParam URI changed", func(t *testing.T) { + currentConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + incomingConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19531", // URI changed - should fail + Token: "test-token", + }, + Pchannels: []string{"channel-1", 
"channel-2"}, + }, + }) + + // Test the config comparison validation directly + validator := &ReplicateConfigValidator{ + incomingConfig: incomingConfig, + currentConfig: currentConfig, + } + err := validator.validateConfigComparison() + assert.Error(t, err) + assert.Contains(t, err.Error(), "connection_param.uri cannot be changed") + }) + + t.Run("success - same cluster with no changes", func(t *testing.T) { + currentConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + incomingConfig := createConfigWithClusters([]*commonpb.MilvusCluster{ + { + ClusterId: "cluster-1", // Same cluster ID + ConnectionParam: &commonpb.ConnectionParam{ + Uri: "localhost:19530", + Token: "test-token", + }, + Pchannels: []string{"channel-1", "channel-2"}, + }, + }) + + // Test the config comparison validation directly + validator := &ReplicateConfigValidator{ + incomingConfig: incomingConfig, + currentConfig: currentConfig, + } + err := validator.validateConfigComparison() + assert.NoError(t, err) // This should pass since it's the same cluster + }) +}