enhance: support milvus version when coordinator startup (#46456)

issue: #46451

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * Session versioning added to validate coordinator compatibility during registration and active takeover.

* **Changes**
  * Active-standby flow simplified: standby-to-active activation is now always enabled and initialized unconditionally.
  * Registration uses version-aware transactions to keep the recorded version consistent during takeover (see the sketch after this summary).
  * Startup and health-update path streamlined.

* **Tests**
  * Added a version-key integration test; removed the test for disabling active-standby.
  * Updated the flush test to assert that rate-limiter errors occur.

* **Chores**
  * Removed the centralized connection manager and its test suite.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
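As a rough illustration of the version-aware registration described above (not the code in this PR), the sketch below combines the etcd `clientv3` transaction API with `blang/semver`; the helper name, parameters, and error messages are invented for the example.

```go
package sketch

import (
	"context"
	"fmt"

	"github.com/blang/semver/v4"
	clientv3 "go.etcd.io/etcd/client/v3"
)

// registerCoordinator is a hypothetical helper: it refuses to register when the
// version recorded in etcd is newer (by major.minor) than the local build, and
// otherwise writes the session key and the version key in a single transaction
// so takeover stays consistent.
func registerCoordinator(ctx context.Context, cli *clientv3.Client,
	sessionKey, versionKey, sessionJSON, localVersion string, lease clientv3.LeaseID,
) error {
	local, err := semver.Parse(localVersion)
	if err != nil {
		return err
	}

	// Read the currently recorded cluster version, if any.
	resp, err := cli.Get(ctx, versionKey)
	if err != nil {
		return err
	}
	expectedRev := int64(0) // etcd per-key version; 0 means "key must not exist yet"
	if resp.Count > 0 {
		recorded, perr := semver.Parse(string(resp.Kvs[0].Value))
		if perr != nil {
			return perr
		}
		if local.Major < recorded.Major ||
			(local.Major == recorded.Major && local.Minor < recorded.Minor) {
			return fmt.Errorf("local version %s is older than recorded version %s", local, recorded)
		}
		expectedRev = resp.Kvs[0].Version
	}

	// The transaction only succeeds if nobody else holds the session key and the
	// version key was not modified between our read and this commit.
	txn, err := cli.Txn(ctx).If(
		clientv3.Compare(clientv3.Version(sessionKey), "=", 0),
		clientv3.Compare(clientv3.Version(versionKey), "=", expectedRev),
	).Then(
		clientv3.OpPut(sessionKey, sessionJSON, clientv3.WithLease(lease)),
		clientv3.OpPut(versionKey, localVersion),
	).Commit()
	if err != nil {
		return err
	}
	if !txn.Succeeded {
		return fmt.Errorf("registration transaction rejected for key %s, retry later", sessionKey)
	}
	return nil
}
```

The same compare-and-put shape appears in the `registerService` and `ProcessActiveStandBy` changes in the session diff below.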

---------

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye authored 2025-12-22 20:29:18 +08:00; committed by GitHub
parent 341388479a
commit 2edc9ee236
7 changed files with 197 additions and 897 deletions

View File

@@ -75,9 +75,6 @@ type mixCoordImpl struct {
factory dependency.Factory
enableActiveStandBy bool
activateFunc func() error
metricsRequest *metricsinfo.MetricsRequest
metaKVCreator func() kv.MetaKv
@@ -97,13 +94,12 @@ func NewMixCoordServer(c context.Context, factory dependency.Factory) (*mixCoord
dataCoordServer := datacoord.CreateServer(c, factory)
return &mixCoordImpl{
ctx: ctx,
cancel: cancel,
rootcoordServer: rootCoordServer,
queryCoordServer: queryCoordServer,
datacoordServer: dataCoordServer,
enableActiveStandBy: Params.MixCoordCfg.EnableActiveStandby.GetAsBool(),
factory: factory,
ctx: ctx,
cancel: cancel,
rootcoordServer: rootCoordServer,
queryCoordServer: queryCoordServer,
datacoordServer: dataCoordServer,
factory: factory,
}, nil
}
@@ -115,21 +111,17 @@ func (s *mixCoordImpl) Register() error {
metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.MixCoordRole).Inc()
log.Info("MixCoord Register Finished")
}
if s.enableActiveStandBy {
go func() {
if err := s.session.ProcessActiveStandBy(s.activateFunc); err != nil {
if s.ctx.Err() == context.Canceled {
log.Info("standby process canceled due to server shutdown")
return
}
log.Error("failed to activate standby server", zap.Error(err))
panic(err)
go func() {
if err := s.session.ProcessActiveStandBy(s.activateFunc); err != nil {
if s.ctx.Err() == context.Canceled {
log.Info("standby process canceled due to server shutdown")
return
}
afterRegister()
}()
} else {
log.Error("failed to activate standby server", zap.Error(err))
panic(err)
}
afterRegister()
}
}()
return nil
}
@@ -142,33 +134,25 @@ func (s *mixCoordImpl) Init() error {
s.factory.Init(Params)
s.initKVCreator()
s.initStreamingCoord()
if s.enableActiveStandBy {
s.activateFunc = func() error {
log.Info("mixCoord switch from standby to active, activating")
s.UpdateStateCode(commonpb.StateCode_StandBy)
log.Info("MixCoord enter standby mode successfully")
return nil
}
var err error
s.initOnce.Do(func() {
if err = s.initInternal(); err != nil {
log.Error("mixCoord init failed", zap.Error(err))
}
})
if err != nil {
return err
}
log.Info("mixCoord startup success", zap.String("address", s.session.GetAddress()))
s.startAndUpdateHealthy()
return err
func (s *mixCoordImpl) activateFunc() error {
log.Info("mixCoord switch from standby to active, activating")
var err error
s.initOnce.Do(func() {
if err = s.initInternal(); err != nil {
log.Error("mixCoord init failed", zap.Error(err))
}
s.UpdateStateCode(commonpb.StateCode_StandBy)
log.Info("MixCoord enter standby mode successfully")
} else {
s.initOnce.Do(func() {
if initErr = s.initInternal(); initErr != nil {
log.Error("mixCoord init failed", zap.Error(initErr))
}
})
})
if err != nil {
return err
}
return initErr
log.Info("mixCoord startup success", zap.String("address", s.session.GetAddress()))
s.startAndUpdateHealthy()
return err
}
func (s *mixCoordImpl) initInternal() error {
@@ -235,9 +219,6 @@ func (s *mixCoordImpl) initKVCreator() {
}
func (s *mixCoordImpl) Start() error {
if !s.enableActiveStandBy {
s.startAndUpdateHealthy()
}
return nil
}
@@ -380,7 +361,7 @@ func (s *mixCoordImpl) initStreamingCoord() {
func (s *mixCoordImpl) initSession() error {
s.session = sessionutil.NewSession(s.ctx)
s.session.Init(typeutil.MixCoordRole, s.address, true, true)
s.session.SetEnableActiveStandBy(s.enableActiveStandBy)
s.session.SetEnableActiveStandBy(true)
s.rootcoordServer.SetSession(s.session)
s.datacoordServer.SetSession(s.session)
s.queryCoordServer.SetSession(s.session)
@@ -388,9 +369,6 @@
return nil
}
func (s *mixCoordImpl) startHealthCheck() {
}
func (s *mixCoordImpl) SetAddress(address string) {
s.address = address
s.rootcoordServer.SetAddress(address)
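For orientation, here is a minimal sketch of the flow mixcoord ends up with after the hunks above: active-standby is always on, `Register` always spawns the standby loop, and the heavy initialization happens lazily inside the activation callback. The names (`server`, `activate`, `register`) are illustrative assumptions, not the real `mixCoordImpl` API.

```go
package sketch

import (
	"context"
	"log"
	"sync"
)

// server mirrors the pattern above: initialization is deferred until the standby
// replica is promoted to active, guarded by sync.Once.
type server struct {
	ctx      context.Context
	initOnce sync.Once
}

func newServer(ctx context.Context) *server {
	return &server{ctx: ctx}
}

// activate runs initialization exactly once, when the session layer decides this
// replica should become active.
func (s *server) activate() error {
	var err error
	s.initOnce.Do(func() {
		err = s.initInternal()
	})
	return err
}

func (s *server) initInternal() error {
	// placeholder for starting the real coordinator components
	return nil
}

// register always enters the standby flow; there is no longer a branch for
// "active-standby disabled".
func (s *server) register(processActiveStandBy func(activate func() error) error) {
	go func() {
		if err := processActiveStandBy(s.activate); err != nil {
			if s.ctx.Err() == context.Canceled {
				log.Println("standby process canceled due to server shutdown")
				return
			}
			log.Panicf("failed to activate standby server: %v", err)
		}
	}()
}
```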

View File

@@ -103,61 +103,6 @@ func TestMixcoord_EnableActiveStandby(t *testing.T) {
assert.NoError(t, err)
}
// make sure the main functions work well when EnableActiveStandby=false
func TestMixcoord_DisableActiveStandby(t *testing.T) {
randVal := rand.Int()
paramtable.Init()
testutil.ResetEnvironment()
Params.Save("etcd.rootPath", fmt.Sprintf("/%d", randVal))
// Need to reset global etcd to follow new path
kvfactory.CloseEtcdClient()
paramtable.Get().Save(Params.MixCoordCfg.EnableActiveStandby.Key, "false")
paramtable.Get().Save(Params.CommonCfg.RootCoordTimeTick.Key, fmt.Sprintf("rootcoord-time-tick-%d", randVal))
paramtable.Get().Save(Params.CommonCfg.RootCoordStatistics.Key, fmt.Sprintf("rootcoord-statistics-%d", randVal))
paramtable.Get().Save(Params.CommonCfg.RootCoordDml.Key, fmt.Sprintf("rootcoord-dml-test-%d", randVal))
ctx := context.Background()
coreFactory := dependency.NewDefaultFactory(true)
etcdCli, err := etcd.GetEtcdClient(
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
Params.EtcdCfg.Endpoints.GetAsStrings(),
Params.EtcdCfg.EtcdTLSCert.GetValue(),
Params.EtcdCfg.EtcdTLSKey.GetValue(),
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
assert.NoError(t, err)
defer etcdCli.Close()
core, err := NewMixCoordServer(ctx, coreFactory)
core.SetEtcdClient(etcdCli)
assert.NoError(t, err)
core.SetTiKVClient(tikv.SetupLocalTxn())
err = core.Init()
assert.NoError(t, err)
assert.Equal(t, commonpb.StateCode_Initializing, core.GetStateCode())
err = core.Start()
assert.NoError(t, err)
core.session.TriggerKill = false
err = core.Register()
assert.NoError(t, err)
assert.Equal(t, commonpb.StateCode_Healthy, core.GetStateCode())
resp, err := core.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DescribeCollection,
MsgID: 0,
Timestamp: 0,
SourceID: paramtable.GetNodeID(),
},
CollectionName: "unexist",
})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
err = core.Stop()
assert.NoError(t, err)
}
func TestMixCoord_FlushAll(t *testing.T) {
t.Run("success", func(t *testing.T) {
mockey.PatchConvey("test flush all success", t, func() {

View File

@@ -1,455 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package distributed
import (
"context"
"os"
"sync"
"syscall"
"time"
"github.com/cockroachdb/errors"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/samber/lo"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// ConnectionManager handles connection to other components of the system
type ConnectionManager struct {
session *sessionutil.Session
dependencies map[string]struct{}
rootCoord rootcoordpb.RootCoordClient
rootCoordMu sync.RWMutex
queryCoord querypb.QueryCoordClient
queryCoordMu sync.RWMutex
dataCoord datapb.DataCoordClient
dataCoordMu sync.RWMutex
queryNodes map[int64]querypb.QueryNodeClient
queryNodesMu sync.RWMutex
dataNodes map[int64]datapb.DataNodeClient
dataNodesMu sync.RWMutex
indexNodes map[int64]workerpb.IndexNodeClient
indexNodesMu sync.RWMutex
taskMu sync.RWMutex
buildTasks map[int64]*buildClientTask
notify chan int64
connMu sync.RWMutex
connections map[int64]*grpc.ClientConn
closeCh chan struct{}
}
// NewConnectionManager creates a new connection manager.
func NewConnectionManager(session *sessionutil.Session) *ConnectionManager {
return &ConnectionManager{
session: session,
dependencies: make(map[string]struct{}),
queryNodes: make(map[int64]querypb.QueryNodeClient),
dataNodes: make(map[int64]datapb.DataNodeClient),
indexNodes: make(map[int64]workerpb.IndexNodeClient),
buildTasks: make(map[int64]*buildClientTask),
notify: make(chan int64),
connections: make(map[int64]*grpc.ClientConn),
}
}
// AddDependency add a dependency by role name.
func (cm *ConnectionManager) AddDependency(roleName string) error {
if !cm.checkroleName(roleName) {
return errors.New("roleName is illegal")
}
log := log.Ctx(context.TODO())
_, ok := cm.dependencies[roleName]
if ok {
log.Warn("Dependency is already added", zap.String("roleName", roleName))
return nil
}
cm.dependencies[roleName] = struct{}{}
msess, rev, err := cm.session.GetSessions(context.TODO(), roleName)
if err != nil {
log.Debug("ClientManager GetSessions failed", zap.String("roleName", roleName))
return err
}
if len(msess) == 0 {
log.Debug("No nodes are currently alive", zap.String("roleName", roleName))
} else {
for _, value := range msess {
cm.buildConnections(value)
}
}
watcher := cm.session.WatchServices(roleName, rev, nil)
go cm.processEvent(watcher.EventChannel())
return nil
}
func (cm *ConnectionManager) Start() {
go cm.receiveFinishTask()
}
func (cm *ConnectionManager) GetRootCoordClient() (rootcoordpb.RootCoordClient, bool) {
cm.rootCoordMu.RLock()
defer cm.rootCoordMu.RUnlock()
_, ok := cm.dependencies[typeutil.RootCoordRole]
if !ok {
log.Ctx(context.TODO()).Error("RootCoord dependency has not been added yet")
return nil, false
}
return cm.rootCoord, true
}
func (cm *ConnectionManager) GetQueryCoordClient() (querypb.QueryCoordClient, bool) {
cm.queryCoordMu.RLock()
defer cm.queryCoordMu.RUnlock()
_, ok := cm.dependencies[typeutil.QueryCoordRole]
if !ok {
log.Ctx(context.TODO()).Error("QueryCoord dependency has not been added yet")
return nil, false
}
return cm.queryCoord, true
}
func (cm *ConnectionManager) GetDataCoordClient() (datapb.DataCoordClient, bool) {
cm.dataCoordMu.RLock()
defer cm.dataCoordMu.RUnlock()
_, ok := cm.dependencies[typeutil.DataCoordRole]
if !ok {
log.Ctx(context.TODO()).Error("DataCoord dependency has not been added yet")
return nil, false
}
return cm.dataCoord, true
}
func (cm *ConnectionManager) GetQueryNodeClients() ([]lo.Tuple2[int64, querypb.QueryNodeClient], bool) {
cm.queryNodesMu.RLock()
defer cm.queryNodesMu.RUnlock()
_, ok := cm.dependencies[typeutil.QueryNodeRole]
if !ok {
log.Ctx(context.TODO()).Error("QueryNode dependency has not been added yet")
return nil, false
}
nodes := lo.MapToSlice(cm.queryNodes, func(id int64, client querypb.QueryNodeClient) lo.Tuple2[int64, querypb.QueryNodeClient] {
return lo.Tuple2[int64, querypb.QueryNodeClient]{A: id, B: client}
})
return nodes, true
}
func (cm *ConnectionManager) GetDataNodeClients() ([]lo.Tuple2[int64, datapb.DataNodeClient], bool) {
cm.dataNodesMu.RLock()
defer cm.dataNodesMu.RUnlock()
_, ok := cm.dependencies[typeutil.DataNodeRole]
if !ok {
log.Ctx(context.TODO()).Error("DataNode dependency has not been added yet")
return nil, false
}
return lo.MapToSlice(cm.dataNodes, func(id int64, client datapb.DataNodeClient) lo.Tuple2[int64, datapb.DataNodeClient] {
return lo.Tuple2[int64, datapb.DataNodeClient]{A: id, B: client}
}), true
}
func (cm *ConnectionManager) GetIndexNodeClients() ([]lo.Tuple2[int64, workerpb.IndexNodeClient], bool) {
cm.indexNodesMu.RLock()
defer cm.indexNodesMu.RUnlock()
_, ok := cm.dependencies[typeutil.IndexNodeRole]
if !ok {
log.Ctx(context.TODO()).Error("IndexNode dependency has not been added yet")
return nil, false
}
return lo.MapToSlice(cm.indexNodes, func(id int64, client workerpb.IndexNodeClient) lo.Tuple2[int64, workerpb.IndexNodeClient] {
return lo.Tuple2[int64, workerpb.IndexNodeClient]{A: id, B: client}
}), true
}
func (cm *ConnectionManager) Stop() {
for _, task := range cm.buildTasks {
task.Stop()
}
close(cm.closeCh)
for _, conn := range cm.connections {
conn.Close()
}
}
// fix datarace in unittest
// startWatchService will only be invoked at start procedure
// otherwise, remove the annotation and add atomic protection
//
//go:norace
func (cm *ConnectionManager) processEvent(channel <-chan *sessionutil.SessionEvent) {
for {
select {
case _, ok := <-cm.closeCh:
if !ok {
return
}
case ev, ok := <-channel:
if !ok {
log.Ctx(context.TODO()).Error("watch service channel closed", zap.Int64("serverID", cm.session.ServerID))
go cm.Stop()
if cm.session.TriggerKill {
if p, err := os.FindProcess(os.Getpid()); err == nil {
p.Signal(syscall.SIGINT)
}
}
return
}
switch ev.EventType {
case sessionutil.SessionAddEvent:
log.Ctx(context.TODO()).Debug("ConnectionManager", zap.Any("add event", ev.Session))
cm.buildConnections(ev.Session)
case sessionutil.SessionDelEvent:
cm.removeTask(ev.Session.ServerID)
cm.removeConnection(ev.Session.ServerID)
}
}
}
}
func (cm *ConnectionManager) receiveFinishTask() {
log := log.Ctx(context.TODO())
for {
select {
case _, ok := <-cm.closeCh:
if !ok {
return
}
case serverID := <-cm.notify:
cm.taskMu.Lock()
task, ok := cm.buildTasks[serverID]
log.Debug("ConnectionManager", zap.Int64("receive finish", serverID))
if ok {
log.Debug("ConnectionManager", zap.Int64("get task ok", serverID))
log.Debug("ConnectionManager", zap.Any("task state", task.state))
if task.state == buildClientSuccess {
log.Debug("ConnectionManager", zap.Int64("build success", serverID))
cm.addConnection(task.sess.ServerID, task.result)
cm.buildClients(task.sess, task.result)
}
delete(cm.buildTasks, serverID)
}
cm.taskMu.Unlock()
}
}
}
func (cm *ConnectionManager) buildClients(session *sessionutil.Session, connection *grpc.ClientConn) {
switch session.ServerName {
case typeutil.RootCoordRole:
cm.rootCoordMu.Lock()
defer cm.rootCoordMu.Unlock()
cm.rootCoord = rootcoordpb.NewRootCoordClient(connection)
case typeutil.DataCoordRole:
cm.dataCoordMu.Lock()
defer cm.dataCoordMu.Unlock()
cm.dataCoord = datapb.NewDataCoordClient(connection)
case typeutil.QueryCoordRole:
cm.queryCoordMu.Lock()
defer cm.queryCoordMu.Unlock()
cm.queryCoord = querypb.NewQueryCoordClient(connection)
case typeutil.QueryNodeRole:
cm.queryNodesMu.Lock()
defer cm.queryNodesMu.Unlock()
cm.queryNodes[session.ServerID] = querypb.NewQueryNodeClient(connection)
case typeutil.DataNodeRole:
cm.dataNodesMu.Lock()
defer cm.dataNodesMu.Unlock()
cm.dataNodes[session.ServerID] = datapb.NewDataNodeClient(connection)
case typeutil.IndexNodeRole:
cm.indexNodesMu.Lock()
defer cm.indexNodesMu.Unlock()
cm.indexNodes[session.ServerID] = workerpb.NewIndexNodeClient(connection)
}
}
func (cm *ConnectionManager) buildConnections(session *sessionutil.Session) {
task := newBuildClientTask(session, cm.notify)
cm.addTask(session.ServerID, task)
task.Run()
}
func (cm *ConnectionManager) addConnection(id int64, conn *grpc.ClientConn) {
cm.connMu.Lock()
cm.connections[id] = conn
cm.connMu.Unlock()
}
func (cm *ConnectionManager) removeConnection(id int64) {
cm.connMu.Lock()
conn, ok := cm.connections[id]
if ok {
conn.Close()
delete(cm.connections, id)
}
cm.connMu.Unlock()
}
func (cm *ConnectionManager) addTask(id int64, task *buildClientTask) {
cm.taskMu.Lock()
cm.buildTasks[id] = task
cm.taskMu.Unlock()
}
func (cm *ConnectionManager) removeTask(id int64) {
cm.taskMu.Lock()
task, ok := cm.buildTasks[id]
if ok {
task.Stop()
delete(cm.buildTasks, id)
}
cm.taskMu.Unlock()
}
type buildConnectionstate int
const (
buildConnectionstart buildConnectionstate = iota
buildClientRunning
buildClientSuccess
buildClientFailed
)
type buildClientTask struct {
ctx context.Context
cancel context.CancelFunc
sess *sessionutil.Session
state buildConnectionstate
retryOptions []retry.Option
result *grpc.ClientConn
notify chan int64
}
func newBuildClientTask(session *sessionutil.Session, notify chan int64, retryOptions ...retry.Option) *buildClientTask {
ctx, cancel := context.WithCancel(context.Background())
return &buildClientTask{
ctx: ctx,
cancel: cancel,
sess: session,
retryOptions: retryOptions,
notify: notify,
}
}
func (bct *buildClientTask) Run() {
bct.state = buildClientRunning
go func() {
defer bct.finish()
connectGrpcFunc := func() error {
opts := tracer.GetInterceptorOpts()
log.Ctx(bct.ctx).Debug("Grpc connect", zap.String("Address", bct.sess.Address))
ctx, cancel := context.WithTimeout(bct.ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(ctx, bct.sess.Address,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
grpc.WithDisableRetry(),
grpc.WithUnaryInterceptor(
grpc_middleware.ChainUnaryClient(
grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(3),
grpc_retry.WithCodes(codes.Aborted, codes.Unavailable),
),
otelgrpc.UnaryClientInterceptor(opts...),
)),
grpc.WithStreamInterceptor(
grpc_middleware.ChainStreamClient(
grpc_retry.StreamClientInterceptor(
grpc_retry.WithMax(3),
grpc_retry.WithCodes(codes.Aborted, codes.Unavailable),
),
otelgrpc.StreamClientInterceptor(opts...),
)),
)
if err != nil {
return err
}
bct.result = conn
bct.state = buildClientSuccess
return nil
}
err := retry.Do(bct.ctx, connectGrpcFunc, bct.retryOptions...)
log.Ctx(bct.ctx).Debug("ConnectionManager", zap.Int64("build connection finish", bct.sess.ServerID))
if err != nil {
log.Ctx(bct.ctx).Debug("BuildClientTask try connect failed",
zap.String("roleName", bct.sess.ServerName), zap.Error(err))
bct.state = buildClientFailed
return
}
}()
}
func (bct *buildClientTask) Stop() {
bct.cancel()
}
func (bct *buildClientTask) finish() {
log.Ctx(bct.ctx).Debug("ConnectionManager", zap.Int64("notify connection finish", bct.sess.ServerID))
bct.notify <- bct.sess.ServerID
}
var roles = map[string]struct{}{
typeutil.RootCoordRole: {},
typeutil.QueryCoordRole: {},
typeutil.DataCoordRole: {},
typeutil.QueryNodeRole: {},
typeutil.DataNodeRole: {},
typeutil.IndexNodeRole: {},
}
func (cm *ConnectionManager) checkroleName(roleName string) bool {
_, ok := roles[roleName]
return ok
}

View File

@@ -1,296 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package distributed
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"strings"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestMain(t *testing.M) {
// init embed etcd
embedetcdServer, tempDir, err := etcd.StartTestEmbedEtcdServer()
if err != nil {
log.Fatal("failed to start embed etcd server for unittest", zap.Error(err))
}
defer os.RemoveAll(tempDir)
defer embedetcdServer.Server.Stop()
addrs := etcd.GetEmbedEtcdEndpoints(embedetcdServer)
paramtable.Init()
paramtable.Get().Save(paramtable.Get().EtcdCfg.Endpoints.Key, strings.Join(addrs, ","))
os.Exit(t.Run())
}
func TestConnectionManager(t *testing.T) {
ctx := context.Background()
testPath := fmt.Sprintf("TestConnectionManager-%d", time.Now().Unix())
paramtable.Get().Save(paramtable.Get().EtcdCfg.RootPath.Key, testPath)
session := initSession(ctx)
cm := NewConnectionManager(session)
cm.AddDependency(typeutil.RootCoordRole)
cm.AddDependency(typeutil.QueryCoordRole)
cm.AddDependency(typeutil.DataCoordRole)
cm.AddDependency(typeutil.QueryNodeRole)
cm.AddDependency(typeutil.DataNodeRole)
cm.AddDependency(typeutil.IndexNodeRole)
cm.Start()
t.Run("rootCoord", func(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(t, err)
defer lis.Close()
rootCoord := &testRootCoord{}
grpcServer := grpc.NewServer()
defer grpcServer.Stop()
rootcoordpb.RegisterRootCoordServer(grpcServer, rootCoord)
go grpcServer.Serve(lis)
session.Init(typeutil.RootCoordRole, lis.Addr().String(), true, false)
session.Register()
assert.Eventually(t, func() bool {
rootCoord, ok := cm.GetRootCoordClient()
return rootCoord != nil && ok
}, 10*time.Second, 100*time.Millisecond)
})
t.Run("queryCoord", func(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(t, err)
defer lis.Close()
queryCoord := &testQueryCoord{}
grpcServer := grpc.NewServer()
defer grpcServer.Stop()
querypb.RegisterQueryCoordServer(grpcServer, queryCoord)
go grpcServer.Serve(lis)
session.Init(typeutil.QueryCoordRole, lis.Addr().String(), true, false)
session.Register()
assert.Eventually(t, func() bool {
queryCoord, ok := cm.GetQueryCoordClient()
return queryCoord != nil && ok
}, 10*time.Second, 100*time.Millisecond)
})
t.Run("dataCoord", func(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(t, err)
defer lis.Close()
dataCoord := &testDataCoord{}
grpcServer := grpc.NewServer()
defer grpcServer.Stop()
datapb.RegisterDataCoordServer(grpcServer, dataCoord)
go grpcServer.Serve(lis)
session.Init(typeutil.DataCoordRole, lis.Addr().String(), true, false)
session.Register()
assert.Eventually(t, func() bool {
dataCoord, ok := cm.GetDataCoordClient()
return dataCoord != nil && ok
}, 10*time.Second, 100*time.Millisecond)
})
t.Run("queryNode", func(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(t, err)
defer lis.Close()
queryNode := &testQueryNode{}
grpcServer := grpc.NewServer()
defer grpcServer.Stop()
querypb.RegisterQueryNodeServer(grpcServer, queryNode)
go grpcServer.Serve(lis)
session.Init(typeutil.QueryNodeRole, lis.Addr().String(), true, false)
session.Register()
assert.Eventually(t, func() bool {
queryNodes, ok := cm.GetQueryNodeClients()
return len(queryNodes) == 1 && ok
}, 10*time.Second, 100*time.Millisecond)
})
t.Run("dataNode", func(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(t, err)
defer lis.Close()
dataNode := &testDataNode{}
grpcServer := grpc.NewServer()
defer grpcServer.Stop()
datapb.RegisterDataNodeServer(grpcServer, dataNode)
go grpcServer.Serve(lis)
session.Init(typeutil.DataNodeRole, lis.Addr().String(), true, false)
session.Register()
assert.Eventually(t, func() bool {
dataNodes, ok := cm.GetDataNodeClients()
return len(dataNodes) == 1 && ok
}, 10*time.Second, 100*time.Millisecond)
})
t.Run("indexNode", func(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(t, err)
defer lis.Close()
indexNode := &testIndexNode{}
grpcServer := grpc.NewServer()
defer grpcServer.Stop()
workerpb.RegisterIndexNodeServer(grpcServer, indexNode)
go grpcServer.Serve(lis)
session.Init(typeutil.IndexNodeRole, lis.Addr().String(), true, false)
session.Register()
assert.Eventually(t, func() bool {
indexNodes, ok := cm.GetIndexNodeClients()
return len(indexNodes) == 1 && ok
}, 10*time.Second, 100*time.Millisecond)
})
}
func TestConnectionManager_processEvent(t *testing.T) {
t.Run("close closeCh", func(t *testing.T) {
cm := &ConnectionManager{
closeCh: make(chan struct{}),
}
ech := make(chan *sessionutil.SessionEvent)
flag := false
signal := make(chan struct{}, 1)
go func() {
assert.Panics(t, func() {
cm.processEvent(ech)
})
flag = true
signal <- struct{}{}
}()
close(ech)
<-signal
assert.True(t, flag)
ech = make(chan *sessionutil.SessionEvent)
flag = false
go func() {
cm.processEvent(ech)
flag = true
signal <- struct{}{}
}()
close(cm.closeCh)
<-signal
assert.True(t, flag)
})
t.Run("close watch chan", func(t *testing.T) {
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT)
defer signal.Reset(syscall.SIGINT)
sigQuit := make(chan struct{}, 1)
cm := &ConnectionManager{
closeCh: make(chan struct{}),
session: &sessionutil.Session{
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
TriggerKill: true,
},
},
}
ech := make(chan *sessionutil.SessionEvent)
go func() {
<-sc
sigQuit <- struct{}{}
}()
go func() {
cm.processEvent(ech)
}()
close(ech)
<-sigQuit
})
}
type testRootCoord struct {
rootcoordpb.RootCoordServer
}
type testQueryCoord struct {
querypb.QueryCoordServer
}
type testDataCoord struct {
datapb.DataCoordServer
}
type testQueryNode struct {
querypb.QueryNodeServer
}
type testDataNode struct {
datapb.DataNodeServer
}
type testIndexNode struct {
workerpb.IndexNodeServer
}
func initSession(ctx context.Context) *sessionutil.Session {
baseTable := paramtable.GetBaseTable()
rootPath, err := baseTable.Load("etcd.rootPath")
if err != nil {
panic(err)
}
subPath, err := baseTable.Load("etcd.metaSubPath")
if err != nil {
panic(err)
}
metaRootPath := rootPath + "/" + subPath
endpoints := baseTable.GetWithDefault("etcd.endpoints", paramtable.DefaultEtcdEndpoints)
etcdEndpoints := strings.Split(endpoints, ",")
log.Ctx(context.TODO()).Debug("metaRootPath", zap.Any("metaRootPath", metaRootPath))
log.Ctx(context.TODO()).Debug("etcdPoints", zap.Any("etcdPoints", etcdEndpoints))
etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints)
if err != nil {
panic(err)
}
session := sessionutil.NewSessionWithEtcd(ctx, metaRootPath, etcdCli)
return session
}

View File

@@ -57,8 +57,17 @@ const (
LabelStandalone = "STANDALONE"
MilvusNodeIDForTesting = "MILVUS_NODE_ID_FOR_TESTING"
exitCodeSessionLeaseExpired = 1
serverVersionKey = "version"
)
var errSessionVersionCheckFailure = errors.New("session version check failure")
// isNotSessionVersionCheckFailure checks if the error is not a session version check failure.
func isNotSessionVersionCheckFailure(err error) bool {
return !errors.Is(err, errSessionVersionCheckFailure)
}
// EnableEmbededQueryNodeLabel set server labels for embedded query node.
func EnableEmbededQueryNodeLabel() {
os.Setenv(SupportedLabelPrefix+LabelStreamingNodeEmbeddedQueryNode, "1")
@@ -169,6 +178,7 @@ type Session struct {
isStandby atomic.Value
enableActiveStandBy bool
activeKey string
versionKey string
sessionTTL int64
sessionRetryTimes int64
@@ -300,6 +310,7 @@ func (s *Session) Init(serverName, address string, exclusive bool, triggerKill b
}
s.ServerID = serverID
s.ServerLabels = GetServerLabelsFromEnv(serverName)
s.versionKey = path.Join(s.metaRoot, DefaultServiceRoot, serverVersionKey)
s.SetLogger(log.With(
log.FieldComponent("service-registration"),
@@ -325,6 +336,35 @@ func (s *Session) Register() {
s.startKeepAliveLoop()
}
// isCoordinator checks if the session needs to check the version.
func (s *Session) isCoordinator() bool {
return s.ServerName == typeutil.MixCoordRole ||
s.ServerName == typeutil.QueryCoordRole ||
s.ServerName == typeutil.DataCoordRole ||
s.ServerName == typeutil.RootCoordRole ||
s.ServerName == typeutil.IndexCoordRole
}
// checkVersionForCoordinator reads the recorded coordinator version key and returns an error if the running build's major.minor is older than the recorded version.
func (s *Session) checkVersionForCoordinator() (*mvccpb.KeyValue, error) {
resp, err := s.etcdCli.Get(s.ctx, s.versionKey)
if err != nil {
return nil, err
}
if resp.Count <= 0 {
// no version key found.
return nil, nil
}
version, err := semver.Parse(string(resp.Kvs[0].Value))
if err != nil {
return nil, err
}
if common.Version.Major < version.Major || (common.Version.Major == version.Major && common.Version.Minor < version.Minor) {
return nil, errors.Wrapf(errSessionVersionCheckFailure, "current version(%s), session version(%s)", common.Version.String(), version.String())
}
return resp.Kvs[0], nil
}
var serverIDMu sync.Mutex
func (s *Session) getServerID() (int64, error) {
@@ -462,54 +502,69 @@ func (s *Session) registerService() error {
return err
}
txnResp, err := s.etcdCli.Txn(s.ctx).If(
clientv3.Compare(
clientv3.Version(completeKey),
"=",
0)).
Then(clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID))).Commit()
compareOps := []clientv3.Cmp{
clientv3.Compare(clientv3.Version(completeKey), "=", 0),
}
ops := []clientv3.Op{
clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID)),
}
// if active-standby is enabled, we don't need to check the version now;
// the version is only checked when the standby is activated.
if s.isCoordinator() && !s.enableActiveStandBy {
if ops, compareOps, err = s.getOpsForCoordinator(ops, compareOps, sessionJSON); err != nil {
return err
}
}
txnResp, err := s.etcdCli.Txn(s.ctx).If(compareOps...).Then(ops...).Commit()
if err != nil {
s.Logger().Warn("register on etcd error, check the availability of etcd", zap.Error(err))
return err
}
if txnResp != nil && !txnResp.Succeeded {
s.handleRestart(completeKey)
return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", s.ServerName)
}
s.Logger().Info("put session key into etcd, service registered successfully", zap.String("key", completeKey), zap.String("value", string(sessionJSON)))
return nil
}
return retry.Do(s.ctx, registerFn, retry.Attempts(uint(s.sessionRetryTimes)))
return retry.Do(s.ctx, registerFn, retry.Attempts(uint(s.sessionRetryTimes)), retry.RetryErr(isNotSessionVersionCheckFailure))
}
// Handle restart is fast path to handle node restart.
// This should be only a fast path for coordinator
// If we find previous session have same address as current , simply purge the old one so the recovery can be much faster
func (s *Session) handleRestart(key string) {
resp, err := s.etcdCli.Get(s.ctx, key)
log := log.With(zap.String("key", key))
// getOpsForCoordinator gets the ops and compare ops for coordinator.
func (s *Session) getOpsForCoordinator(ops []clientv3.Op, compareOps []clientv3.Cmp, sessionJSON []byte) ([]clientv3.Op, []clientv3.Cmp, error) {
previousVersion, err := s.checkVersionForCoordinator()
if err != nil {
log.Warn("failed to read old session from etcd, ignore", zap.Error(err))
return
return nil, nil, err
}
for _, kv := range resp.Kvs {
session := &Session{}
err = json.Unmarshal(kv.Value, session)
expectedVersion := int64(0)
if previousVersion != nil {
expectedVersion = previousVersion.Version
}
legacyCoord := []string{
typeutil.QueryCoordRole,
typeutil.DataCoordRole,
typeutil.RootCoordRole,
}
for _, role := range legacyCoord {
key := path.Join(s.metaRoot, DefaultServiceRoot, role)
var newSession SessionRaw
if err := json.Unmarshal(sessionJSON, &newSession); err != nil {
return nil, nil, err
}
newSession.ServerName = role
newSessionJSON, err := json.Marshal(newSession)
if err != nil {
log.Warn("failed to unmarshal old session from etcd, ignore", zap.Error(err))
return
}
if session.Address == s.Address && session.ServerID < s.ServerID {
log.Warn("find old session is same as current node, assume it as restart, purge old session", zap.String("key", key),
zap.String("address", session.Address))
_, err := s.etcdCli.Delete(s.ctx, key)
if err != nil {
log.Warn("failed to unmarshal old session from etcd, ignore", zap.Error(err))
return
}
return nil, nil, err
}
ops = append(ops, clientv3.OpPut(key, string(newSessionJSON), clientv3.WithLease(*s.LeaseID)))
compareOps = append(compareOps, clientv3.Compare(clientv3.Version(key), "=", 0))
}
// guarantee the version key has not been modified since it was read.
compareOps = append(compareOps, clientv3.Compare(clientv3.Version(s.versionKey), "=", expectedVersion))
// set up the version key for the coordinator.
ops = append(ops, clientv3.OpPut(s.versionKey, common.Version.String()))
return ops, compareOps, nil
}
// processKeepAliveResponse processes the response of etcd keepAlive interface
@@ -980,12 +1035,21 @@ func (s *Session) ProcessActiveStandBy(activateFunc func() error) error {
log.Error("json marshal error", zap.Error(err))
return false, -1, err
}
txnResp, err := s.etcdCli.Txn(s.ctx).If(
clientv3.Compare(
clientv3.Version(s.activeKey),
"=",
0)).
Then(clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID))).Commit()
compareOps := []clientv3.Cmp{
clientv3.Compare(clientv3.Version(s.activeKey), "=", 0),
}
ops := []clientv3.Op{
clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID)),
}
if s.isCoordinator() {
if ops, compareOps, err = s.getOpsForCoordinator(ops, compareOps, sessionJSON); err != nil {
return false, -1, err
}
}
txnResp, err := s.etcdCli.Txn(s.ctx).If(compareOps...).Then(ops...).Commit()
if err != nil {
log.Error("register active key to etcd failed", zap.Error(err))
return false, -1, err
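The rule enforced by checkVersionForCoordinator reduces to a major.minor comparison: a coordinator may not register when its build is older, by major.minor, than the version already recorded in etcd, while patch-level differences are tolerated. Below is a standalone sketch of that rule; the function name `compatible` and the recorded version 2.6.6 are invented for the example, but the allowed/rejected pattern mirrors the TestVersionKey case shown in the next file.

```go
package main

import (
	"fmt"

	"github.com/blang/semver/v4"
)

// compatible reports whether a coordinator at version local may register when
// version recorded is already stored in etcd: rejected only if local's
// major.minor is older; patch-level downgrades are allowed.
func compatible(local, recorded semver.Version) bool {
	if local.Major != recorded.Major {
		return local.Major > recorded.Major
	}
	return local.Minor >= recorded.Minor
}

func main() {
	recorded := semver.MustParse("2.6.6")
	for _, v := range []string{"2.5.6", "2.6.4", "2.6.7", "3.0.0"} {
		fmt.Printf("%s -> allowed=%v\n", v, compatible(semver.MustParse(v), recorded))
	}
	// Only 2.5.6 is rejected here, because its minor version is older than the
	// recorded 2.6.x; the others register and overwrite the version key.
}
```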

View File

@@ -771,6 +771,67 @@ func (s *SessionSuite) TestGetSessions() {
assert.Equal(s.T(), "value2", ret["key2"])
}
func (s *SessionSuite) TestVersionKey() {
ctx := context.Background()
session := NewSessionWithEtcd(ctx, s.metaRoot, s.client)
session.Init(typeutil.MixCoordRole, "normal", false, false)
session.Register()
resp, err := s.client.Get(ctx, session.versionKey)
s.Require().NoError(err)
s.Equal(1, len(resp.Kvs))
s.Equal(common.Version.String(), string(resp.Kvs[0].Value))
common.Version = semver.MustParse("2.5.6")
s.Panics(func() {
session2 := NewSessionWithEtcd(ctx, s.metaRoot, s.client)
session2.Init(typeutil.MixCoordRole, "normal", false, false)
session2.Register()
resp, err = s.client.Get(ctx, session2.versionKey)
s.Require().NoError(err)
s.Equal(1, len(resp.Kvs))
s.Equal(common.Version.String(), string(resp.Kvs[0].Value))
})
session.Stop()
common.Version = semver.MustParse("2.6.4")
session = NewSessionWithEtcd(ctx, s.metaRoot, s.client)
session.Init(typeutil.MixCoordRole, "normal", false, false)
session.Register()
resp, err = s.client.Get(ctx, session.versionKey)
s.Require().NoError(err)
s.Equal(1, len(resp.Kvs))
s.Equal(common.Version.String(), string(resp.Kvs[0].Value))
session.Stop()
common.Version = semver.MustParse("2.6.7")
session = NewSessionWithEtcd(ctx, s.metaRoot, s.client)
session.Init(typeutil.MixCoordRole, "normal", false, false)
session.Register()
resp, err = s.client.Get(ctx, session.versionKey)
s.Require().NoError(err)
s.Equal(1, len(resp.Kvs))
s.Equal(common.Version.String(), string(resp.Kvs[0].Value))
session.Stop()
common.Version = semver.MustParse("3.0.0")
session = NewSessionWithEtcd(ctx, s.metaRoot, s.client)
session.Init(typeutil.MixCoordRole, "normal", false, false)
session.Register()
resp, err = s.client.Get(ctx, session.versionKey)
s.Require().NoError(err)
s.Equal(1, len(resp.Kvs))
s.Equal(common.Version.String(), string(resp.Kvs[0].Value))
}
func (s *SessionSuite) TestSessionLifetime() {
ctx := context.Background()
session := NewSessionWithEtcd(ctx, s.metaRoot, s.client)

View File

@@ -911,9 +911,12 @@ func TestFlushRate(t *testing.T) {
}
wg.Wait()
errCnt := 0
for _, err := range errs {
if err != nil {
common.CheckErr(t, err, false, "request is rejected by grpc RateLimiter middleware, please retry later: rate limit exceeded")
errCnt++
}
}
require.NotZero(t, errCnt)
}