Integration test framework (#21283)

Signed-off-by: wayblink <anyang.wang@zilliz.com>
wayblink, 2023-01-12 19:49:40 +08:00, committed by GitHub
parent 76c0292bca
commit 6a722396bd
67 changed files with 3102 additions and 527 deletions


@@ -121,7 +121,7 @@ func TestServer_CreateIndex(t *testing.T) {
 				Value: "DISKANN",
 			},
 		}
-		s.indexNodeManager = NewNodeManager(ctx)
+		s.indexNodeManager = NewNodeManager(ctx, defaultIndexNodeCreatorFunc)
 		resp, err := s.CreateIndex(ctx, req)
 		assert.NoError(t, err)
 		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode())


@@ -24,7 +24,6 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/milvuspb"
-	grpcindexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/metastore/model"
 	"github.com/milvus-io/milvus/internal/metrics"
@@ -34,19 +33,21 @@ import (
 // IndexNodeManager is used to manage the client of IndexNode.
 type IndexNodeManager struct {
 	nodeClients      map[UniqueID]types.IndexNode
 	stoppingNodes    map[UniqueID]struct{}
 	lock             sync.RWMutex
 	ctx              context.Context
+	indexNodeCreator indexNodeCreatorFunc
 }

 // NewNodeManager is used to create a new IndexNodeManager.
-func NewNodeManager(ctx context.Context) *IndexNodeManager {
+func NewNodeManager(ctx context.Context, indexNodeCreator indexNodeCreatorFunc) *IndexNodeManager {
 	return &IndexNodeManager{
 		nodeClients:      make(map[UniqueID]types.IndexNode),
 		stoppingNodes:    make(map[UniqueID]struct{}),
 		lock:             sync.RWMutex{},
 		ctx:              ctx,
+		indexNodeCreator: indexNodeCreator,
 	}
 }
@@ -84,7 +85,7 @@ func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error {
 		err        error
 	)
-	nodeClient, err = grpcindexnodeclient.NewClient(context.TODO(), address, Params.DataCoordCfg.WithCredential.GetAsBool())
+	nodeClient, err = nm.indexNodeCreator(context.TODO(), address)
 	if err != nil {
 		log.Error("create IndexNode client fail", zap.Error(err))
 		return err
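The creator function threaded through NewNodeManager is what lets the integration-test framework run DataCoord against in-process IndexNodes: production wiring passes a gRPC-client factory, tests pass a closure returning a mock. A minimal self-contained sketch of the pattern (all types here are illustrative stand-ins, not the Milvus ones):

```go
package main

import (
	"context"
	"fmt"
)

type IndexNode interface{ Addr() string }

type grpcIndexNode struct{ addr string }

func (n grpcIndexNode) Addr() string { return n.addr }

// creatorFunc mirrors the indexNodeCreatorFunc idea: the manager no longer
// knows how clients are built, only that it can ask for one.
type creatorFunc func(ctx context.Context, addr string) (IndexNode, error)

type NodeManager struct{ create creatorFunc }

func NewNodeManager(create creatorFunc) *NodeManager { return &NodeManager{create: create} }

func (m *NodeManager) AddNode(ctx context.Context, addr string) (IndexNode, error) {
	return m.create(ctx, addr) // a gRPC dial in production, a mock in tests
}

func main() {
	nm := NewNodeManager(func(ctx context.Context, addr string) (IndexNode, error) {
		return grpcIndexNode{addr: addr}, nil
	})
	n, _ := nm.AddNode(context.Background(), "indexnode-1")
	fmt.Println(n.Addr())
}
```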


@@ -31,7 +31,7 @@ import (
 )

 func TestIndexNodeManager_AddNode(t *testing.T) {
-	nm := NewNodeManager(context.Background())
+	nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc)
 	nodeID, client := nm.PeekClient(&model.SegmentIndex{})
 	assert.Equal(t, int64(-1), nodeID)
 	assert.Nil(t, client)
@@ -255,7 +255,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
 }

 func TestNodeManager_StoppingNode(t *testing.T) {
-	nm := NewNodeManager(context.Background())
+	nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc)
 	err := nm.AddNode(1, "indexnode-1")
 	assert.NoError(t, err)
 	assert.Equal(t, 1, len(nm.GetAllClients()))


@@ -229,6 +229,18 @@ func newMockDataNodeClient(id int64, ch chan interface{}) (*mockDataNodeClient,
 	}, nil
 }

+type mockIndexNodeClient struct {
+	id    int64
+	state commonpb.StateCode
+}
+
+func newMockIndexNodeClient(id int64) (*mockIndexNodeClient, error) {
+	return &mockIndexNodeClient{
+		id:    id,
+		state: commonpb.StateCode_Initializing,
+	}, nil
+}
+
 func (c *mockDataNodeClient) Init() error {
 	return nil
 }
@@ -417,7 +429,7 @@ func (m *mockRootCoordService) GetStatisticsChannel(ctx context.Context) (*milvu
 	panic("not implemented") // TODO: Implement
 }

-//DDL request
+// DDL request
 func (m *mockRootCoordService) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
 	panic("not implemented") // TODO: Implement
 }
@@ -489,7 +501,7 @@ func (m *mockRootCoordService) ShowPartitionsInternal(ctx context.Context, req *
 	return m.ShowPartitions(ctx, req)
 }

-//global timestamp allocator
+// global timestamp allocator
 func (m *mockRootCoordService) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) {
 	if m.state != commonpb.StateCode_Healthy {
 		return &rootcoordpb.AllocTimestampResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil
@@ -523,7 +535,7 @@ func (m *mockRootCoordService) AllocID(ctx context.Context, req *rootcoordpb.All
 	}, nil
 }

-//segment
+// segment
 func (m *mockRootCoordService) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) {
 	panic("not implemented") // TODO: Implement
 }
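The new mockIndexNodeClient carries only an id and a state, but to satisfy the types.IndexNode interface it must also implement the RPC surface. One method it plausibly implements (an assumption, this hunk does not show it) is a health report built from those two fields:

```go
// Assumed companion method, not shown in the hunk above: report the stored
// state so health polling can drive the mock through its lifecycle.
func (c *mockIndexNodeClient) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
	return &milvuspb.ComponentStates{
		State: &milvuspb.ComponentInfo{
			NodeID:    c.id,
			StateCode: c.state,
		},
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
	}, nil
}
```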


@@ -34,6 +34,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/milvuspb"
 	datanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client"
+	indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
 	rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
 	etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
 	"github.com/milvus-io/milvus/internal/log"
@@ -79,6 +80,7 @@ type (
 )

 type dataNodeCreatorFunc func(ctx context.Context, addr string) (types.DataNode, error)
+type indexNodeCreatorFunc func(ctx context.Context, addr string) (types.IndexNode, error)
 type rootCoordCreatorFunc func(ctx context.Context, metaRootPath string, etcdClient *clientv3.Client) (types.RootCoord, error)

 // makes sure Server implements `DataCoord`
@@ -131,6 +133,7 @@ type Server struct {
 	activateFunc func()

 	dataNodeCreator        dataNodeCreatorFunc
+	indexNodeCreator       indexNodeCreatorFunc
 	rootCoordClientCreator rootCoordCreatorFunc
 	//indexCoord             types.IndexCoord
@@ -153,36 +156,36 @@ func defaultServerHelper() ServerHelper {
 // Option utility function signature to set DataCoord server attributes
 type Option func(svr *Server)

-// SetRootCoordCreator returns an `Option` setting RootCoord creator with provided parameter
-func SetRootCoordCreator(creator rootCoordCreatorFunc) Option {
+// WithRootCoordCreator returns an `Option` setting RootCoord creator with provided parameter
+func WithRootCoordCreator(creator rootCoordCreatorFunc) Option {
 	return func(svr *Server) {
 		svr.rootCoordClientCreator = creator
 	}
 }

-// SetServerHelper returns an `Option` setting ServerHelp with provided parameter
-func SetServerHelper(helper ServerHelper) Option {
+// WithServerHelper returns an `Option` setting ServerHelp with provided parameter
+func WithServerHelper(helper ServerHelper) Option {
 	return func(svr *Server) {
 		svr.helper = helper
 	}
 }

-// SetCluster returns an `Option` setting Cluster with provided parameter
-func SetCluster(cluster *Cluster) Option {
+// WithCluster returns an `Option` setting Cluster with provided parameter
+func WithCluster(cluster *Cluster) Option {
 	return func(svr *Server) {
 		svr.cluster = cluster
 	}
 }

-// SetDataNodeCreator returns an `Option` setting DataNode create function
-func SetDataNodeCreator(creator dataNodeCreatorFunc) Option {
+// WithDataNodeCreator returns an `Option` setting DataNode create function
+func WithDataNodeCreator(creator dataNodeCreatorFunc) Option {
 	return func(svr *Server) {
 		svr.dataNodeCreator = creator
 	}
 }

-// SetSegmentManager returns an Option to set SegmentManager
-func SetSegmentManager(manager Manager) Option {
+// WithSegmentManager returns an Option to set SegmentManager
+func WithSegmentManager(manager Manager) Option {
 	return func(svr *Server) {
 		svr.segmentManager = manager
 	}
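The Set* to With* rename aligns these constructors with the common Go functional-options convention. A self-contained sketch of the pattern, independent of Milvus types (all names here are illustrative):

```go
package main

import "fmt"

type Server struct {
	addr    string
	replica int
}

// Option mutates a Server during construction.
type Option func(*Server)

func WithReplica(n int) Option {
	return func(s *Server) { s.replica = n }
}

func NewServer(addr string, opts ...Option) *Server {
	s := &Server{addr: addr, replica: 1} // defaults first
	for _, opt := range opts {
		opt(s) // options override the defaults
	}
	return s
}

func main() {
	fmt.Println(NewServer("localhost:19530", WithReplica(3)))
}
```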
@@ -199,6 +202,7 @@ func CreateServer(ctx context.Context, factory dependency.Factory, opts ...Optio
 		buildIndexCh:           make(chan UniqueID, 1024),
 		notifyIndexChan:        make(chan UniqueID),
 		dataNodeCreator:        defaultDataNodeCreatorFunc,
+		indexNodeCreator:       defaultIndexNodeCreatorFunc,
 		rootCoordClientCreator: defaultRootCoordCreatorFunc,
 		helper:                 defaultServerHelper(),
 		metricsCacheManager:    metricsinfo.NewMetricsCacheManager(),
@@ -215,6 +219,10 @@ func defaultDataNodeCreatorFunc(ctx context.Context, addr string) (types.DataNod
 	return datanodeclient.NewClient(ctx, addr)
 }

+func defaultIndexNodeCreatorFunc(ctx context.Context, addr string) (types.IndexNode, error) {
+	return indexnodeclient.NewClient(context.TODO(), addr, Params.DataCoordCfg.WithCredential.GetAsBool())
+}
+
 func defaultRootCoordCreatorFunc(ctx context.Context, metaRootPath string, client *clientv3.Client) (types.RootCoord, error) {
 	return rootcoordclient.NewClient(ctx, metaRootPath, client)
 }
@@ -374,6 +382,18 @@ func (s *Server) SetEtcdClient(client *clientv3.Client) {
 	s.etcdCli = client
 }

+func (s *Server) SetRootCoord(rootCoord types.RootCoord) {
+	s.rootCoordClient = rootCoord
+}
+
+func (s *Server) SetDataNodeCreator(f func(context.Context, string) (types.DataNode, error)) {
+	s.dataNodeCreator = f
+}
+
+func (s *Server) SetIndexNodeCreator(f func(context.Context, string) (types.IndexNode, error)) {
+	s.indexNodeCreator = f
+}
+
 func (s *Server) createCompactionHandler() {
 	s.compactionHandler = newCompactionPlanHandler(s.sessionManager, s.channelManager, s.meta, s.allocator, s.flushCh)
 }
@@ -465,6 +485,9 @@ func (s *Server) initSegmentManager() {
 }

 func (s *Server) initMeta(chunkManager storage.ChunkManager) error {
+	if s.meta != nil {
+		return nil
+	}
 	etcdKV := etcdkv.NewEtcdKV(s.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue())
 	s.kvClient = etcdKV
@@ -488,7 +511,7 @@ func (s *Server) initIndexBuilder(manager storage.ChunkManager) {
 func (s *Server) initIndexNodeManager() {
 	if s.indexNodeManager == nil {
-		s.indexNodeManager = NewNodeManager(s.ctx)
+		s.indexNodeManager = NewNodeManager(s.ctx, s.indexNodeCreator)
 	}
 }
@@ -834,8 +857,10 @@ func (s *Server) handleFlushingSegments(ctx context.Context) {
 func (s *Server) initRootCoordClient() error {
 	var err error
-	if s.rootCoordClient, err = s.rootCoordClientCreator(s.ctx, Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli); err != nil {
-		return err
+	if s.rootCoordClient == nil {
+		if s.rootCoordClient, err = s.rootCoordClientCreator(s.ctx, Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli); err != nil {
+			return err
+		}
 	}
 	if err = s.rootCoordClient.Init(); err != nil {
 		return err
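initMeta and initRootCoordClient now follow a create-only-if-absent pattern: a test can pre-inject a mock via the setters above, and Init will leave it in place instead of dialing a real client. A self-contained sketch of the idea (illustrative names, not the Milvus code):

```go
package main

import "fmt"

type RootCoord interface{ Init() error }

type realRootCoord struct{}

func (realRootCoord) Init() error { return nil }

type Server struct {
	rootCoord RootCoord                 // may be pre-injected by a test
	creator   func() (RootCoord, error) // default factory
}

// SetRootCoord lets integration tests inject an in-process mock.
func (s *Server) SetRootCoord(rc RootCoord) { s.rootCoord = rc }

func (s *Server) initRootCoord() error {
	if s.rootCoord == nil { // only dial the real client when nothing was injected
		rc, err := s.creator()
		if err != nil {
			return err
		}
		s.rootCoord = rc
	}
	return s.rootCoord.Init()
}

func main() {
	s := &Server{creator: func() (RootCoord, error) { return realRootCoord{}, nil }}
	fmt.Println(s.initRootCoord())
}
```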


@@ -40,6 +40,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/milvuspb"
 	"github.com/milvus-io/milvus-proto/go-api/schemapb"
 	"github.com/milvus-io/milvus/internal/common"
+	"github.com/milvus-io/milvus/internal/indexnode"
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/metastore/model"
 	"github.com/milvus-io/milvus/internal/mq/msgstream"
@@ -1302,7 +1303,7 @@ func TestSaveBinlogPaths(t *testing.T) {
 	/*
 		t.Run("test save dropped segment and remove channel", func(t *testing.T) {
 			spyCh := make(chan struct{}, 1)
-			svr := newTestServer(t, nil, SetSegmentManager(&spySegmentManager{spyCh: spyCh}))
+			svr := newTestServer(t, nil, WithSegmentManager(&spySegmentManager{spyCh: spyCh}))
 			defer closeTestServer(t, svr)

 			svr.meta.AddCollection(&collectionInfo{ID: 1})
@@ -1333,7 +1334,7 @@ func TestSaveBinlogPaths(t *testing.T) {
 func TestDropVirtualChannel(t *testing.T) {
 	t.Run("normal DropVirtualChannel", func(t *testing.T) {
 		spyCh := make(chan struct{}, 1)
-		svr := newTestServer(t, nil, SetSegmentManager(&spySegmentManager{spyCh: spyCh}))
+		svr := newTestServer(t, nil, WithSegmentManager(&spySegmentManager{spyCh: spyCh}))
 		defer closeTestServer(t, svr)
@@ -1668,7 +1669,7 @@ func TestDataNodeTtChannel(t *testing.T) {
 	helper := ServerHelper{
 		eventAfterHandleDataNodeTt: func() { ch <- struct{}{} },
 	}
-	svr := newTestServer(t, nil, SetServerHelper(helper))
+	svr := newTestServer(t, nil, WithServerHelper(helper))
 	defer closeTestServer(t, svr)

 	svr.meta.AddCollection(&collectionInfo{
@@ -2835,13 +2836,13 @@ func TestOptions(t *testing.T) {
 		kv.Close()
 	}()

-	t.Run("SetRootCoordCreator", func(t *testing.T) {
+	t.Run("WithRootCoordCreator", func(t *testing.T) {
 		svr := newTestServer(t, nil)
 		defer closeTestServer(t, svr)
 		var crt rootCoordCreatorFunc = func(ctx context.Context, metaRoot string, etcdClient *clientv3.Client) (types.RootCoord, error) {
 			return nil, errors.New("dummy")
 		}
-		opt := SetRootCoordCreator(crt)
+		opt := WithRootCoordCreator(crt)
 		assert.NotNil(t, opt)
 		svr.rootCoordClientCreator = nil
 		opt(svr)
@@ -2850,7 +2851,7 @@ func TestOptions(t *testing.T) {
 		assert.NotNil(t, crt)
 		assert.NotNil(t, svr.rootCoordClientCreator)
 	})
-	t.Run("SetCluster", func(t *testing.T) {
+	t.Run("WithCluster", func(t *testing.T) {
 		defer kv.RemoveWithPrefix("")

 		sessionManager := NewSessionManager()
@@ -2859,17 +2860,17 @@ func TestOptions(t *testing.T) {
 		cluster := NewCluster(sessionManager, channelManager)
 		assert.Nil(t, err)

-		opt := SetCluster(cluster)
+		opt := WithCluster(cluster)
 		assert.NotNil(t, opt)
 		svr := newTestServer(t, nil, opt)
 		defer closeTestServer(t, svr)

 		assert.Same(t, cluster, svr.cluster)
 	})
-	t.Run("SetDataNodeCreator", func(t *testing.T) {
+	t.Run("WithDataNodeCreator", func(t *testing.T) {
 		var target int64
 		var val = rand.Int63()
-		opt := SetDataNodeCreator(func(context.Context, string) (types.DataNode, error) {
+		opt := WithDataNodeCreator(func(context.Context, string) (types.DataNode, error) {
 			target = val
 			return nil, nil
 		})
@@ -2918,7 +2919,7 @@ func TestHandleSessionEvent(t *testing.T) {
 	assert.Nil(t, err)
 	defer cluster.Close()

-	svr := newTestServer(t, nil, SetCluster(cluster))
+	svr := newTestServer(t, nil, WithCluster(cluster))
 	defer closeTestServer(t, svr)
 	t.Run("handle events", func(t *testing.T) {
 		// None event
@@ -3779,6 +3780,7 @@ func testDataCoordBase(t *testing.T, opts ...Option) *Server {
 	paramtable.Get().Save(Params.CommonCfg.DataCoordTimeTick.Key, Params.CommonCfg.DataCoordTimeTick.GetValue()+strconv.Itoa(rand.Int()))
 	factory := dependency.NewDefaultFactory(true)

+	ctx := context.Background()
 	etcdCli, err := etcd.GetEtcdClient(
 		Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
 		Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
@@ -3789,17 +3791,18 @@ func testDataCoordBase(t *testing.T, opts ...Option) *Server {
 		Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
 	assert.Nil(t, err)
 	sessKey := path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot)
-	_, err = etcdCli.Delete(context.Background(), sessKey, clientv3.WithPrefix())
+	_, err = etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix())
 	assert.Nil(t, err)

-	svr := CreateServer(context.TODO(), factory, opts...)
+	svr := CreateServer(ctx, factory, opts...)
 	svr.SetEtcdClient(etcdCli)
-	svr.dataNodeCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
-		return newMockDataNodeClient(0, nil)
-	}
-	svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
-		return newMockRootCoordService(), nil
-	}
+	svr.SetDataNodeCreator(func(ctx context.Context, addr string) (types.DataNode, error) {
+		return newMockDataNodeClient(0, nil)
+	})
+	svr.SetIndexNodeCreator(func(ctx context.Context, addr string) (types.IndexNode, error) {
+		return indexnode.NewMockIndexNodeComponent(ctx)
+	})
+	svr.SetRootCoord(newMockRootCoordService())
 	err = svr.Init()
 	assert.Nil(t, err)
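The reworked testDataCoordBase is the kernel of the integration-test framework in this file: a real DataCoord whose remote dependencies (RootCoord, DataNode, IndexNode) are all mocked in-process via the new setters. A hypothetical test built on it only needs to pass options for what it wants to vary (helper and option names as in the diff, the test itself is illustrative):

```go
// Hypothetical usage of the helper above; the base function wires the
// mock RootCoord/DataNode/IndexNode, options customize the rest.
func TestDataCoord_WithSpySegmentManager(t *testing.T) {
	spyCh := make(chan struct{}, 1)
	svr := testDataCoordBase(t, WithSegmentManager(&spySegmentManager{spyCh: spyCh}))
	defer closeTestServer(t, svr)
	// drive requests against svr and assert on spyCh ...
}
```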


@@ -117,6 +117,9 @@ type DataNode struct {
 	rootCoord types.RootCoord
 	dataCoord types.DataCoord

+	//call once
+	initOnce sync.Once
+	sessionMu sync.Mutex // to fix data race
 	session      *sessionutil.Session
 	watchKv      kv.MetaKv
 	chunkManager storage.ChunkManager
@@ -153,6 +156,10 @@ func (node *DataNode) SetAddress(address string) {
 	node.address = address
 }

+func (node *DataNode) GetAddress() string {
+	return node.address
+}
+
 // SetEtcdClient sets etcd client for DataNode
 func (node *DataNode) SetEtcdClient(etcdCli *clientv3.Client) {
 	node.etcdCli = etcdCli
@@ -186,7 +193,7 @@ func (node *DataNode) Register() error {

 	// Start liveness check
 	go node.session.LivenessCheck(node.ctx, func() {
-		log.Error("Data Node disconnected from etcd, process will exit", zap.Int64("Server Id", node.session.ServerID))
+		log.Error("Data Node disconnected from etcd, process will exit", zap.Int64("Server Id", node.GetSession().ServerID))
 		if err := node.Stop(); err != nil {
 			log.Fatal("failed to stop server", zap.Error(err))
 		}
@@ -222,37 +229,42 @@ func (node *DataNode) initRateCollector() error {
 	return nil
 }

-// Init function does nothing now.
 func (node *DataNode) Init() error {
-	log.Info("DataNode server initializing",
-		zap.String("TimeTickChannelName", Params.CommonCfg.DataCoordTimeTick.GetValue()),
-	)
-	if err := node.initSession(); err != nil {
-		log.Error("DataNode server init session failed", zap.Error(err))
-		return err
-	}
+	var initError error
+	node.initOnce.Do(func() {
+		logutil.Logger(node.ctx).Info("DataNode server initializing",
+			zap.String("TimeTickChannelName", Params.CommonCfg.DataCoordTimeTick.GetValue()),
+		)
+		if err := node.initSession(); err != nil {
+			log.Error("DataNode server init session failed", zap.Error(err))
+			initError = err
+			return
+		}

-	err := node.initRateCollector()
-	if err != nil {
-		log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", paramtable.GetNodeID()), zap.Error(err))
-		return err
-	}
-	log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
+		err := node.initRateCollector()
+		if err != nil {
+			log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", paramtable.GetNodeID()), zap.Error(err))
+			initError = err
+			return
+		}
+		log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))

-	idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, paramtable.GetNodeID())
-	if err != nil {
-		log.Error("failed to create id allocator",
-			zap.Error(err),
-			zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", paramtable.GetNodeID()))
-		return err
-	}
-	node.rowIDAllocator = idAllocator
+		idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, paramtable.GetNodeID())
+		if err != nil {
+			log.Error("failed to create id allocator",
+				zap.Error(err),
+				zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", paramtable.GetNodeID()))
+			initError = err
+			return
+		}
+		node.rowIDAllocator = idAllocator

-	node.factory.Init(Params)
-	log.Info("DataNode server init succeeded",
-		zap.String("MsgChannelSubName", Params.CommonCfg.DataNodeSubName.GetValue()))
-
-	return nil
+		node.factory.Init(Params)
+		log.Info("DataNode server init succeeded",
+			zap.String("MsgChannelSubName", Params.CommonCfg.DataNodeSubName.GetValue()))
+	})
+	return initError
 }

 // StartWatchChannels start loop to watch channel allocation status via kv(etcd for now)
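Wrapping the body of Init in sync.Once makes initialization idempotent, which matters once integration tests construct and re-Init components inside one process. A self-contained sketch of the pattern:

```go
package main

import (
	"fmt"
	"sync"
)

type component struct {
	initOnce sync.Once
	ready    bool
}

// Init runs its body exactly once. The closure cannot return an error, so it
// records one in a local variable instead. Note that, as in the diff, a
// second call skips the body entirely and returns nil.
func (c *component) Init() error {
	var initErr error
	c.initOnce.Do(func() {
		// expensive setup goes here
		c.ready = true
	})
	return initErr
}

func main() {
	c := &component{}
	fmt.Println(c.Init(), c.Init(), c.ready) // <nil> <nil> true
}
```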
@@ -260,7 +272,8 @@ func (node *DataNode) StartWatchChannels(ctx context.Context) {
 	defer logutil.LogPanic()
 	// REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name}
 	// TODO, this is risky, we'd better watch etcd with revision rather simply a path
-	watchPrefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID()))
+	watchPrefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID))
+	log.Info("Start watch channel", zap.String("prefix", watchPrefix))
 	evtChan := node.watchKv.WatchWithPrefix(watchPrefix)
 	// after watch, first check all exists nodes first
 	err := node.checkWatchedList()
@@ -412,7 +425,7 @@ func (node *DataNode) handlePutEvent(watchInfo *datapb.ChannelWatchInfo, version
 		return fmt.Errorf("fail to marshal watchInfo with state, vChanName: %s, state: %s ,err: %w", vChanName, watchInfo.State.String(), err)
 	}

-	key := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID()), vChanName)
+	key := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID), vChanName)

 	success, err := node.watchKv.CompareVersionAndSwap(key, version, string(v))
 	// etcd error, retrying
@@ -558,3 +571,17 @@ func (node *DataNode) Stop() error {
 	return nil
 }
+
+// to fix data race
+func (node *DataNode) SetSession(session *sessionutil.Session) {
+	node.sessionMu.Lock()
+	defer node.sessionMu.Unlock()
+	node.session = session
+}
+
+// to fix data race
+func (node *DataNode) GetSession() *sessionutil.Session {
+	node.sessionMu.Lock()
+	defer node.sessionMu.Unlock()
+	return node.session
+}
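Guarding the session behind SetSession/GetSession serializes an access that previously raced between the registration goroutine and test code; `go test -race` flags the unguarded version. A self-contained sketch of the accessor pattern:

```go
package main

import (
	"fmt"
	"sync"
)

type Session struct{ ServerID int64 }

type Node struct {
	mu      sync.Mutex
	session *Session
}

// SetSession and GetSession take the same mutex, so a reader can never
// observe a half-published pointer while another goroutine swaps it.
func (n *Node) SetSession(s *Session) {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.session = s
}

func (n *Node) GetSession() *Session {
	n.mu.Lock()
	defer n.mu.Unlock()
	return n.session
}

func main() {
	n := &Node{}
	n.SetSession(&Session{ServerID: 1})
	fmt.Println(n.GetSession().ServerID)
}
```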


@@ -102,6 +102,9 @@ func TestDataNode(t *testing.T) {
 	assert.Nil(t, err)
 	err = node.Start()
 	assert.Nil(t, err)
+	assert.Empty(t, node.GetAddress())
+	node.SetAddress("address")
+	assert.Equal(t, "address", node.GetAddress())
 	defer node.Stop()

 	node.chunkManager = storage.NewLocalChunkManager(storage.RootPath("/tmp/milvus_test/datanode"))
@@ -155,7 +158,7 @@ func TestDataNode(t *testing.T) {
 	t.Run("Test getSystemInfoMetrics", func(t *testing.T) {
 		emptyNode := &DataNode{}
-		emptyNode.session = &sessionutil.Session{ServerID: 1}
+		emptyNode.SetSession(&sessionutil.Session{ServerID: 1})
 		emptyNode.flowgraphManager = newFlowgraphManager()

 		req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics)
@@ -170,7 +173,7 @@ func TestDataNode(t *testing.T) {
 	t.Run("Test getSystemInfoMetrics with quotaMetric error", func(t *testing.T) {
 		emptyNode := &DataNode{}
-		emptyNode.session = &sessionutil.Session{ServerID: 1}
+		emptyNode.SetSession(&sessionutil.Session{ServerID: 1})
 		emptyNode.flowgraphManager = newFlowgraphManager()

 		req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics)


@@ -61,6 +61,7 @@ type dataSyncService struct {
 	chunkManager storage.ChunkManager
 	compactor    *compactionExecutor // reference to compaction executor
+	serverID     int64

 	stopOnce      sync.Once
 	flushListener chan *segmentFlushPack // chan to listen flush event
 }
@@ -77,6 +78,7 @@ func newDataSyncService(ctx context.Context,
 	flushingSegCache *Cache,
 	chunkManager storage.ChunkManager,
 	compactor *compactionExecutor,
+	serverID int64,
 ) (*dataSyncService, error) {
 	if channel == nil {
@@ -108,6 +110,7 @@ func newDataSyncService(ctx context.Context,
 		flushingSegCache: flushingSegCache,
 		chunkManager:     chunkManager,
 		compactor:        compactor,
+		serverID:         serverID,
 	}

 	if err := service.initNodes(vchan); err != nil {
@@ -127,7 +130,7 @@ type nodeConfig struct {
 	vChannelName string
 	channel      Channel // Channel info
 	allocator    allocatorInterface
+	serverID     int64
 	// defaults
 	parallelConfig
 }
@@ -280,6 +283,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo) erro
 		allocator:      dsService.idAllocator,
 		parallelConfig: newParallelConfig(),
+		serverID:       dsService.serverID,
 	}
 	var dmStreamNode Node
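Threading serverID from the session into dataSyncService and nodeConfig replaces reads of the process-global paramtable.GetNodeID() in the hunks below; plausibly the motivation is that once an integration test runs several nodes in one process, a global ID is correct for at most one of them. A self-contained sketch of the difference:

```go
package main

import "fmt"

// A process-global ID is fine with one node per process, wrong with many.
var globalNodeID int64 = 1

type nodeConfig struct {
	serverID int64 // identity travels with the component instead
}

func sourceID(cfg nodeConfig) int64 {
	return cfg.serverID // not globalNodeID
}

func main() {
	// Two in-process nodes keep distinct identities.
	fmt.Println(sourceID(nodeConfig{serverID: 1}), sourceID(nodeConfig{serverID: 2}))
}
```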


@@ -172,6 +172,7 @@ func TestDataSyncService_newDataSyncService(te *testing.T) {
 			newCache(),
 			cm,
 			newCompactionExecutor(),
+			0,
 		)

 		if !test.isValidCase {
@@ -269,7 +270,7 @@ func TestDataSyncService_Start(t *testing.T) {
 		},
 	}

-	sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor())
+	sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
 	assert.Nil(t, err)

 	sync.flushListener = make(chan *segmentFlushPack)
@@ -424,7 +425,7 @@ func TestDataSyncService_Close(t *testing.T) {
 		},
 	}

-	sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor())
+	sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, allocFactory, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), 0)
 	assert.Nil(t, err)

 	sync.flushListener = make(chan *segmentFlushPack, 10)


@@ -681,7 +681,7 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De
 			commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
 			commonpbutil.WithMsgID(0),
 			commonpbutil.WithTimeStamp(ts),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(config.serverID),
 		),
 		ChannelName: config.vChannelName,
 		Timestamp:   ts,


@@ -48,7 +48,7 @@ func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo
 	var alloc allocatorInterface = newAllocator(dn.rootCoord)

 	dataSyncService, err := newDataSyncService(dn.ctx, make(chan flushMsg, 100), make(chan resendTTMsg, 100), channel,
-		alloc, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor)
+		alloc, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, dn.GetSession().ServerID)
 	if err != nil {
 		log.Warn("new data sync service fail", zap.String("vChannelName", vchan.GetChannelName()), zap.Error(err))
 		return err


@@ -801,7 +801,7 @@ func flushNotifyFunc(dsService *dataSyncService, opts ...retry.Option) notifyMet
 			Base: commonpbutil.NewMsgBase(
 				commonpbutil.WithMsgType(0),
 				commonpbutil.WithMsgID(0),
-				commonpbutil.WithSourceID(paramtable.GetNodeID()),
+				commonpbutil.WithSourceID(dsService.serverID),
 			),
 			SegmentID:    pack.segmentID,
 			CollectionID: dsService.collectionID,


@@ -94,7 +94,7 @@ func (node *DataNode) getSystemInfoMetrics(ctx context.Context, req *milvuspb.Ge
 			CreatedTime: paramtable.GetCreateTime().String(),
 			UpdatedTime: paramtable.GetUpdateTime().String(),
 			Type:        typeutil.DataNodeRole,
-			ID:          node.session.ServerID,
+			ID:          node.GetSession().ServerID,
 		},
 		SystemConfigurations: metricsinfo.DataNodeConfiguration{
 			FlushInsertBufferSize: Params.DataNodeCfg.FlushInsertBufferSize.GetAsInt64(),


@@ -37,6 +37,7 @@ import (
 	s "github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/internal/util/dependency"
+	"github.com/milvus-io/milvus/internal/util/sessionutil"
 	"github.com/milvus-io/milvus/internal/util/tsoutil"
 	"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -79,6 +80,7 @@ var emptyFlushAndDropFunc flushAndDropFunc = func(_ []*segmentFlushPack) {}
 func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode {
 	factory := dependency.NewDefaultFactory(true)
 	node := NewDataNode(ctx, factory)
+	node.SetSession(&sessionutil.Session{ServerID: 1})

 	rc := &RootCoordFactory{
 		ID: 0,


@@ -65,8 +65,8 @@ func (node *DataNode) WatchDmChannels(ctx context.Context, in *datapb.WatchDmCha
 func (node *DataNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
 	log.Debug("DataNode current state", zap.Any("State", node.stateCode.Load()))
 	nodeID := common.NotRegisteredID
-	if node.session != nil && node.session.Registered() {
-		nodeID = node.session.ServerID
+	if node.GetSession() != nil && node.session.Registered() {
+		nodeID = node.GetSession().ServerID
 	}
 	states := &milvuspb.ComponentStates{
 		State: &milvuspb.ComponentInfo{
@@ -100,14 +100,15 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen
 		return errStatus, nil
 	}

-	if req.GetBase().GetTargetID() != node.session.ServerID {
+	serverID := node.GetSession().ServerID
+	if req.GetBase().GetTargetID() != serverID {
 		log.Warn("flush segment target id not matched",
 			zap.Int64("targetID", req.GetBase().GetTargetID()),
-			zap.Int64("serverID", node.session.ServerID),
+			zap.Int64("serverID", serverID),
 		)
 		status := &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
-			Reason:    common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), node.session.ServerID),
+			Reason:    common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), serverID),
 		}
 		return status, nil
 	}
@@ -814,7 +815,7 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo
 				commonpbutil.WithMsgType(0),
 				commonpbutil.WithMsgID(0),
 				commonpbutil.WithTimeStamp(ts),
-				commonpbutil.WithSourceID(paramtable.GetNodeID()),
+				commonpbutil.WithSourceID(node.session.ServerID),
 			),
 			SegmentID:    segmentID,
 			CollectionID: req.GetImportTask().GetCollectionId(),


@@ -123,7 +123,7 @@ func (s *DataNodeServicesSuite) TestGetComponentStates() {
 	s.Assert().Equal(commonpb.ErrorCode_Success, resp.Status.ErrorCode)
 	s.Assert().Equal(common.NotRegisteredID, resp.State.NodeID)

-	s.node.session = &sessionutil.Session{}
+	s.node.SetSession(&sessionutil.Session{})
 	s.node.session.UpdateRegistered(true)
 	resp, err = s.node.GetComponentStates(context.Background())
 	s.Assert().NoError(err)
@@ -203,7 +203,7 @@ func (s *DataNodeServicesSuite) TestFlushSegments() {
 	req := &datapb.FlushSegmentsRequest{
 		Base: &commonpb.MsgBase{
-			TargetID: s.node.session.ServerID,
+			TargetID: s.node.GetSession().ServerID,
 		},
 		DbID:         0,
 		CollectionID: 1,
@@ -277,7 +277,7 @@ func (s *DataNodeServicesSuite) TestFlushSegments() {
 	req = &datapb.FlushSegmentsRequest{
 		Base: &commonpb.MsgBase{
-			TargetID: s.node.session.ServerID,
+			TargetID: s.node.GetSession().ServerID,
 		},
 		DbID:         0,
 		CollectionID: 1,
@@ -290,7 +290,7 @@ func (s *DataNodeServicesSuite) TestFlushSegments() {
 	req = &datapb.FlushSegmentsRequest{
 		Base: &commonpb.MsgBase{
-			TargetID: s.node.session.ServerID,
+			TargetID: s.node.GetSession().ServerID,
 		},
 		DbID:         0,
 		CollectionID: 1,
@@ -314,7 +314,7 @@ func (s *DataNodeServicesSuite) TestShowConfigurations() {
 	//test closed server
 	node := &DataNode{}
-	node.session = &sessionutil.Session{ServerID: 1}
+	node.SetSession(&sessionutil.Session{ServerID: 1})
 	node.stateCode.Store(commonpb.StateCode_Abnormal)

 	resp, err := node.ShowConfigurations(s.ctx, req)
@@ -331,7 +331,7 @@ func (s *DataNodeServicesSuite) TestGetMetrics() {
 	node := &DataNode{}
-	node.session = &sessionutil.Session{ServerID: 1}
+	node.SetSession(&sessionutil.Session{ServerID: 1})
 	node.flowgraphManager = newFlowgraphManager()
 	// server is closed
 	node.stateCode.Store(commonpb.StateCode_Abnormal)


@@ -32,7 +32,6 @@ import (
 	clientv3 "go.etcd.io/etcd/client/v3"
 )

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockDataCoord struct {
 	types.DataCoord

@@ -102,6 +101,15 @@ func (*MockDataCoord) SetAddress(address string) {
 func (m *MockDataCoord) SetEtcdClient(etcdClient *clientv3.Client) {
 }

+func (m *MockDataCoord) SetRootCoord(rootCoord types.RootCoord) {
+}
+
+func (m *MockDataCoord) SetDataNodeCreator(func(context.Context, string) (types.DataNode, error)) {
+}
+
+func (m *MockDataCoord) SetIndexNodeCreator(func(context.Context, string) (types.IndexNode, error)) {
+}
+
 func (m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
 	return m.states, m.err
 }
@@ -264,7 +272,6 @@ func (m *MockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexReq
 	return m.dropIndexResp, m.err
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 func Test_NewServer(t *testing.T) {
 	paramtable.Init()
 	ctx := context.Background()


@@ -42,6 +42,7 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/tracer"
 	"github.com/milvus-io/milvus/internal/types"
+	"github.com/milvus-io/milvus/internal/util/componentutil"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/internal/util/etcd"
 	"github.com/milvus-io/milvus/internal/util/funcutil"
@@ -260,7 +261,7 @@ func (s *Server) init() error {
 		log.Error("failed to start RootCoord client", zap.Error(err))
 		panic(err)
 	}
-	if err = funcutil.WaitForComponentHealthy(ctx, rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
+	if err = componentutil.WaitForComponentHealthy(ctx, rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
 		log.Error("failed to wait for RootCoord client to be ready", zap.Error(err))
 		panic(err)
 	}
@@ -286,7 +287,7 @@ func (s *Server) init() error {
 		log.Error("failed to start DataCoord client", zap.Error(err))
 		panic(err)
 	}
-	if err = funcutil.WaitForComponentInitOrHealthy(ctx, dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
+	if err = componentutil.WaitForComponentInitOrHealthy(ctx, dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
 		log.Error("failed to wait for DataCoord client to be ready", zap.Error(err))
 		panic(err)
 	}
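Several services swap funcutil.WaitForComponentHealthy for the same helper relocated to a new componentutil package; the proxy and querycoord hunks below make the identical substitution. The helper's job is a bounded poll of the component's reported state. A self-contained sketch of that idea, not the actual implementation:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

type StateCode int

const Healthy StateCode = 3

type component interface {
	GetStateCode(ctx context.Context) (StateCode, error)
}

// waitForHealthy polls until the component reports Healthy or attempts run out.
func waitForHealthy(ctx context.Context, c component, name string, attempts int, interval time.Duration) error {
	for i := 0; i < attempts; i++ {
		if code, err := c.GetStateCode(ctx); err == nil && code == Healthy {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(interval):
		}
	}
	return fmt.Errorf("%s not healthy after %d attempts", name, attempts)
}

type alwaysHealthy struct{}

func (alwaysHealthy) GetStateCode(ctx context.Context) (StateCode, error) { return Healthy, nil }

func main() {
	fmt.Println(waitForHealthy(context.Background(), alwaysHealthy{}, "RootCoord", 5, 10*time.Millisecond))
}
```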


@@ -84,6 +84,10 @@ func (m *MockDataNode) GetStateCode() commonpb.StateCode {
 func (m *MockDataNode) SetAddress(address string) {
 }

+func (m *MockDataNode) GetAddress() string {
+	return ""
+}
+
 func (m *MockDataNode) SetRootCoord(rc types.RootCoord) error {
 	return m.err
 }


@@ -259,11 +259,7 @@ func (s *Server) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsReq
 // NewServer create a new IndexNode grpc server.
 func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) {
 	ctx1, cancel := context.WithCancel(ctx)
-	node, err := indexnode.NewIndexNode(ctx1, factory)
-	if err != nil {
-		defer cancel()
-		return nil, err
-	}
+	node := indexnode.NewIndexNode(ctx1, factory)

 	return &Server{
 		loopCtx: ctx1,


@@ -49,6 +49,7 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/proxypb"
 	"github.com/milvus-io/milvus/internal/proxy"
 	"github.com/milvus-io/milvus/internal/types"
+	"github.com/milvus-io/milvus/internal/util/componentutil"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/internal/util/etcd"
 	"github.com/milvus-io/milvus/internal/util/funcutil"
@@ -372,7 +373,7 @@ func (s *Server) init() error {
 	log.Debug("init RootCoord client for Proxy done")

 	log.Debug("Proxy wait for RootCoord to be healthy")
-	if err := funcutil.WaitForComponentHealthy(s.ctx, s.rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
+	if err := componentutil.WaitForComponentHealthy(s.ctx, s.rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil {
 		log.Warn("Proxy failed to wait for RootCoord to be healthy", zap.Error(err))
 		return err
 	}
@@ -401,7 +402,7 @@ func (s *Server) init() error {
 	log.Debug("init DataCoord client for Proxy done")

 	log.Debug("Proxy wait for DataCoord to be healthy")
-	if err := funcutil.WaitForComponentHealthy(s.ctx, s.dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
+	if err := componentutil.WaitForComponentHealthy(s.ctx, s.dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil {
 		log.Warn("Proxy failed to wait for DataCoord to be healthy", zap.Error(err))
 		return err
 	}
@@ -430,7 +431,7 @@ func (s *Server) init() error {
 	log.Debug("init QueryCoord client for Proxy done")

 	log.Debug("Proxy wait for QueryCoord to be healthy")
-	if err := funcutil.WaitForComponentHealthy(s.ctx, s.queryCoordClient, "QueryCoord", 1000000, time.Millisecond*200); err != nil {
+	if err := componentutil.WaitForComponentHealthy(s.ctx, s.queryCoordClient, "QueryCoord", 1000000, time.Millisecond*200); err != nil {
 		log.Warn("Proxy failed to wait for QueryCoord to be healthy", zap.Error(err))
 		return err
 	}


@@ -797,6 +797,10 @@ func (m *MockProxy) SetQueryCoordClient(queryCoord types.QueryCoord) {

 }

+func (m *MockProxy) SetQueryNodeCreator(func(ctx context.Context, addr string) (types.QueryNode, error)) {
+}
+
 func (m *MockProxy) GetRateLimiter() (types.Limiter, error) {
 	return nil, nil
 }
@@ -808,6 +812,10 @@ func (m *MockProxy) UpdateStateCode(stateCode commonpb.StateCode) {
 func (m *MockProxy) SetAddress(address string) {
 }

+func (m *MockProxy) GetAddress() string {
+	return ""
+}
+
 func (m *MockProxy) SetEtcdClient(etcdClient *clientv3.Client) {
 }


@@ -40,6 +40,7 @@ import (
 	qc "github.com/milvus-io/milvus/internal/querycoordv2"
 	"github.com/milvus-io/milvus/internal/tracer"
 	"github.com/milvus-io/milvus/internal/types"
+	"github.com/milvus-io/milvus/internal/util/componentutil"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/internal/util/etcd"
 	"github.com/milvus-io/milvus/internal/util/funcutil"
@@ -132,25 +133,25 @@ func (s *Server) init() error {
 	if s.rootCoord == nil {
 		s.rootCoord, err = rcc.NewClient(s.loopCtx, qc.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
 		if err != nil {
-			log.Debug("QueryCoord try to new RootCoord client failed", zap.Error(err))
+			log.Error("QueryCoord try to new RootCoord client failed", zap.Error(err))
 			panic(err)
 		}
 	}

 	if err = s.rootCoord.Init(); err != nil {
-		log.Debug("QueryCoord RootCoordClient Init failed", zap.Error(err))
+		log.Error("QueryCoord RootCoordClient Init failed", zap.Error(err))
 		panic(err)
 	}

 	if err = s.rootCoord.Start(); err != nil {
-		log.Debug("QueryCoord RootCoordClient Start failed", zap.Error(err))
+		log.Error("QueryCoord RootCoordClient Start failed", zap.Error(err))
 		panic(err)
 	}
 	// wait for master init or healthy
 	log.Debug("QueryCoord try to wait for RootCoord ready")
-	err = funcutil.WaitForComponentHealthy(s.loopCtx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200)
+	err = componentutil.WaitForComponentHealthy(s.loopCtx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200)
 	if err != nil {
-		log.Debug("QueryCoord wait for RootCoord ready failed", zap.Error(err))
+		log.Error("QueryCoord wait for RootCoord ready failed", zap.Error(err))
 		panic(err)
 	}
@@ -163,23 +164,23 @@ func (s *Server) init() error {
 	if s.dataCoord == nil {
 		s.dataCoord, err = dcc.NewClient(s.loopCtx, qc.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
 		if err != nil {
-			log.Debug("QueryCoord try to new DataCoord client failed", zap.Error(err))
+			log.Error("QueryCoord try to new DataCoord client failed", zap.Error(err))
 			panic(err)
 		}
 	}

 	if err = s.dataCoord.Init(); err != nil {
-		log.Debug("QueryCoord DataCoordClient Init failed", zap.Error(err))
+		log.Error("QueryCoord DataCoordClient Init failed", zap.Error(err))
 		panic(err)
 	}

 	if err = s.dataCoord.Start(); err != nil {
-		log.Debug("QueryCoord DataCoordClient Start failed", zap.Error(err))
+		log.Error("QueryCoord DataCoordClient Start failed", zap.Error(err))
 		panic(err)
 	}
 	log.Debug("QueryCoord try to wait for DataCoord ready")
-	err = funcutil.WaitForComponentHealthy(s.loopCtx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200)
+	err = componentutil.WaitForComponentHealthy(s.loopCtx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200)
 	if err != nil {
-		log.Debug("QueryCoord wait for DataCoord ready failed", zap.Error(err))
+		log.Error("QueryCoord wait for DataCoord ready failed", zap.Error(err))
 		panic(err)
 	}
 	if err := s.SetDataCoord(s.dataCoord); err != nil {


@@ -34,7 +34,7 @@ import (
 	clientv3 "go.etcd.io/etcd/client/v3"
 )

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockQueryCoord struct {
 	states *milvuspb.ComponentStates
 	status *commonpb.Status
@@ -88,6 +88,9 @@ func (m *MockQueryCoord) SetDataCoord(types.DataCoord) error {
 	return nil
 }

+func (m *MockQueryCoord) SetQueryNodeCreator(func(ctx context.Context, addr string) (types.QueryNode, error)) {
+}
+
 func (m *MockQueryCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
 	log.Debug("MockQueryCoord::WaitForComponentStates")
 	return m.states, m.err
@@ -159,7 +162,7 @@ func (m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHea
 	}, m.err
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockRootCoord struct {
 	types.RootCoord
 	initErr error
@@ -192,7 +195,7 @@ func (m *MockRootCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo
 	}, nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockDataCoord struct {
 	types.DataCoord
 	initErr error
@@ -225,7 +228,7 @@ func (m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo
 	}, nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 func TestMain(m *testing.M) {
 	paramtable.Init()
 	code := m.Run()


@@ -34,7 +34,7 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 )

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockQueryNode struct {
 	states *milvuspb.ComponentStates
 	status *commonpb.Status
@@ -128,6 +128,10 @@ func (m *MockQueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Sy
 func (m *MockQueryNode) SetAddress(address string) {
 }

+func (m *MockQueryNode) GetAddress() string {
+	return ""
+}
+
 func (m *MockQueryNode) SetEtcdClient(client *clientv3.Client) {
 }


@ -170,19 +170,35 @@ func (s *Server) init() error {
if s.newDataCoordClient != nil { if s.newDataCoordClient != nil {
log.Debug("RootCoord start to create DataCoord client") log.Debug("RootCoord start to create DataCoord client")
dataCoord := s.newDataCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli) dataCoord := s.newDataCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
if err := s.rootCoord.SetDataCoord(s.ctx, dataCoord); err != nil { s.dataCoord = dataCoord
if err = s.dataCoord.Init(); err != nil {
log.Error("RootCoord DataCoordClient Init failed", zap.Error(err))
panic(err)
}
if err = s.dataCoord.Start(); err != nil {
log.Error("RootCoord DataCoordClient Start failed", zap.Error(err))
panic(err)
}
if err := s.rootCoord.SetDataCoord(dataCoord); err != nil {
panic(err) panic(err)
} }
s.dataCoord = dataCoord
} }
if s.newQueryCoordClient != nil { if s.newQueryCoordClient != nil {
log.Debug("RootCoord start to create QueryCoord client") log.Debug("RootCoord start to create QueryCoord client")
queryCoord := s.newQueryCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli) queryCoord := s.newQueryCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli)
s.queryCoord = queryCoord
if err := s.queryCoord.Init(); err != nil {
log.Error("RootCoord QueryCoordClient Init failed", zap.Error(err))
panic(err)
}
if err := s.queryCoord.Start(); err != nil {
log.Error("RootCoord QueryCoordClient Start failed", zap.Error(err))
panic(err)
}
if err := s.rootCoord.SetQueryCoord(queryCoord); err != nil { if err := s.rootCoord.SetQueryCoord(queryCoord); err != nil {
panic(err) panic(err)
} }
s.queryCoord = queryCoord
} }
return s.rootCoord.Init() return s.rootCoord.Init()
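Both client paths now follow the same shape: construct, Init, Start, then hand the running client to rootCoord, panicking on any failure so a broken dependency surfaces at startup rather than at first use. A condensed sketch of that shape; the mustStart helper is hypothetical, not part of this change:

// Hypothetical helper; both coordinator clients expose Init/Start.
func mustStart(name string, c interface {
	Init() error
	Start() error
}) {
	if err := c.Init(); err != nil {
		log.Error(name+" Init failed", zap.Error(err))
		panic(err)
	}
	if err := c.Start(); err != nil {
		log.Error(name+" Start failed", zap.Error(err))
		panic(err)
	}
}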
View File
@ -18,6 +18,7 @@ package grpcrootcoord
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"path" "path"
@ -35,7 +36,6 @@ import (
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
) )
@ -66,7 +66,7 @@ func (m *mockCore) SetAddress(address string) {
func (m *mockCore) SetEtcdClient(etcdClient *clientv3.Client) { func (m *mockCore) SetEtcdClient(etcdClient *clientv3.Client) {
} }
func (m *mockCore) SetDataCoord(context.Context, types.DataCoord) error { func (m *mockCore) SetDataCoord(types.DataCoord) error {
return nil return nil
} }
@ -74,6 +74,9 @@ func (m *mockCore) SetQueryCoord(types.QueryCoord) error {
return nil return nil
} }
func (m *mockCore) SetProxyCreator(func(ctx context.Context, addr string) (types.Proxy, error)) {
}
func (m *mockCore) Register() error { func (m *mockCore) Register() error {
return nil return nil
} }
@ -92,13 +95,15 @@ func (m *mockCore) Stop() error {
type mockDataCoord struct { type mockDataCoord struct {
types.DataCoord types.DataCoord
initErr error
startErr error
} }
func (m *mockDataCoord) Init() error { func (m *mockDataCoord) Init() error {
return nil return m.initErr
} }
func (m *mockDataCoord) Start() error { func (m *mockDataCoord) Start() error {
return nil return m.startErr
} }
func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return &milvuspb.ComponentStates{ return &milvuspb.ComponentStates{
@ -119,19 +124,21 @@ func (m *mockDataCoord) Stop() error {
return fmt.Errorf("stop error") return fmt.Errorf("stop error")
} }
type mockQuery struct { type mockQueryCoord struct {
types.QueryCoord types.QueryCoord
initErr error
startErr error
} }
func (m *mockQuery) Init() error { func (m *mockQueryCoord) Init() error {
return nil return m.initErr
} }
func (m *mockQuery) Start() error { func (m *mockQueryCoord) Start() error {
return nil return m.startErr
} }
func (m *mockQuery) Stop() error { func (m *mockQueryCoord) Stop() error {
return fmt.Errorf("stop error") return fmt.Errorf("stop error")
} }
@ -154,7 +161,7 @@ func TestRun(t *testing.T) {
return &mockDataCoord{} return &mockDataCoord{}
} }
svr.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord { svr.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord {
return &mockQuery{} return &mockQueryCoord{}
} }
paramtable.Get().Save(rcServerConfig.Port.Key, fmt.Sprintf("%d", rand.Int()%100+10000)) paramtable.Get().Save(rcServerConfig.Port.Key, fmt.Sprintf("%d", rand.Int()%100+10000))
@ -192,19 +199,66 @@ func TestRun(t *testing.T) {
} }
func initEtcd(etcdEndpoints []string) (*clientv3.Client, error) { func TestServerRun_DataCoordClientInitErr(t *testing.T) {
var etcdCli *clientv3.Client paramtable.Init()
connectEtcdFn := func() error { ctx := context.Background()
etcd, err := clientv3.New(clientv3.Config{Endpoints: etcdEndpoints, DialTimeout: 5 * time.Second}) server, err := NewServer(ctx, nil)
if err != nil { assert.Nil(t, err)
return err assert.NotNil(t, server)
}
etcdCli = etcd server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord {
return nil return &mockDataCoord{initErr: errors.New("mock datacoord init error")}
} }
err := retry.Do(context.TODO(), connectEtcdFn, retry.Attempts(100)) assert.Panics(t, func() { server.Run() })
if err != nil {
return nil, err err = server.Stop()
} assert.Nil(t, err)
return etcdCli, nil }
func TestServerRun_DataCoordClientStartErr(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord {
return &mockDataCoord{startErr: errors.New("mock datacoord start error")}
}
assert.Panics(t, func() { server.Run() })
err = server.Stop()
assert.Nil(t, err)
}
func TestServerRun_QueryCoordClientInitErr(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord {
return &mockQueryCoord{initErr: errors.New("mock querycoord init error")}
}
assert.Panics(t, func() { server.Run() })
err = server.Stop()
assert.Nil(t, err)
}
func TestServer_QueryCoordClientStartErr(t *testing.T) {
paramtable.Init()
ctx := context.Background()
server, err := NewServer(ctx, nil)
assert.Nil(t, err)
assert.NotNil(t, server)
server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord {
return &mockQueryCoord{startErr: errors.New("mock querycoord start error")}
}
assert.Panics(t, func() { server.Run() })
err = server.Stop()
assert.Nil(t, err)
} }
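Each of the four tests follows the same inject-and-expect-panic pattern, so the same coverage could also be written table-driven. A sketch under the signatures shown above; the case table and test name are illustrative only:

func TestServerRun_CoordClientErrors(t *testing.T) {
	paramtable.Init()
	cases := []struct {
		name   string
		mutate func(*Server)
	}{
		{"datacoord init err", func(s *Server) {
			s.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord {
				return &mockDataCoord{initErr: errors.New("mock init error")}
			}
		}},
		{"querycoord start err", func(s *Server) {
			s.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord {
				return &mockQueryCoord{startErr: errors.New("mock start error")}
			}
		}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			server, err := NewServer(context.Background(), nil)
			assert.Nil(t, err)
			tc.mutate(server)
			assert.Panics(t, func() { server.Run() })
			assert.Nil(t, server.Stop())
		})
	}
}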
View File
@ -97,7 +97,7 @@ type IndexNode struct {
} }
// NewIndexNode creates a new IndexNode component. // NewIndexNode creates a new IndexNode component.
func NewIndexNode(ctx context.Context, factory dependency.Factory) (*IndexNode, error) { func NewIndexNode(ctx context.Context, factory dependency.Factory) *IndexNode {
log.Debug("New IndexNode ...") log.Debug("New IndexNode ...")
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
ctx1, cancel := context.WithCancel(ctx) ctx1, cancel := context.WithCancel(ctx)
@ -109,13 +109,10 @@ func NewIndexNode(ctx context.Context, factory dependency.Factory) (*IndexNode,
tasks: map[taskKey]*taskInfo{}, tasks: map[taskKey]*taskInfo{},
} }
b.UpdateStateCode(commonpb.StateCode_Abnormal) b.UpdateStateCode(commonpb.StateCode_Abnormal)
sc, err := NewTaskScheduler(b.loopCtx) sc := NewTaskScheduler(b.loopCtx)
if err != nil {
return nil, err
}
b.sched = sc b.sched = sc
return b, nil return b
} }
// Register registers the index node at etcd. // Register registers the index node at etcd.
@ -349,3 +346,7 @@ func (i *IndexNode) ShowConfigurations(ctx context.Context, req *internalpb.Show
func (i *IndexNode) SetAddress(address string) { func (i *IndexNode) SetAddress(address string) {
i.address = address i.address = address
} }
func (i *IndexNode) GetAddress() string {
return i.address
}
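With NewTaskScheduler made infallible (see below), the IndexNode constructor no longer returns an error, and GetAddress completes the accessor pair. Call sites shrink accordingly; a small sketch, with an illustrative address value:

node := NewIndexNode(ctx, factory) // no error to propagate anymore
node.SetAddress("localhost:21121") // illustrative address
addr := node.GetAddress()          // "localhost:21121"
_ = addr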
View File
@ -18,10 +18,7 @@ func NewMockIndexNodeComponent(ctx context.Context) (types.IndexNodeComponent, e
chunkMgr: &mockChunkmgr{}, chunkMgr: &mockChunkmgr{},
} }
node, err := NewIndexNode(ctx, factory) node := NewIndexNode(ctx, factory)
if err != nil {
return nil, err
}
startEmbedEtcd() startEmbedEtcd()
etcdCli := getEtcdClient() etcdCli := getEtcdClient()
View File
@ -183,6 +183,11 @@ func (m *Mock) Register() error {
func (m *Mock) SetAddress(address string) { func (m *Mock) SetAddress(address string) {
m.CallSetAddress(address) m.CallSetAddress(address)
} }
func (m *Mock) GetAddress() string {
return ""
}
func (m *Mock) SetEtcdClient(etcdClient *clientv3.Client) { func (m *Mock) SetEtcdClient(etcdClient *clientv3.Client) {
} }
@ -209,7 +214,7 @@ func (m *Mock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest)
return m.CallGetMetrics(ctx, req) return m.CallGetMetrics(ctx, req)
} }
//ShowConfigurations returns the configurations of Mock indexNode matching req.Pattern // ShowConfigurations returns the configurations of Mock indexNode matching req.Pattern
func (m *Mock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { func (m *Mock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
return m.CallShowConfigurations(ctx, req) return m.CallShowConfigurations(ctx, req)
} }
View File
@ -461,8 +461,7 @@ func TestComponentState(t *testing.T) {
ctx = context.TODO() ctx = context.TODO()
) )
Params.Init() Params.Init()
in, err := NewIndexNode(ctx, factory) in := NewIndexNode(ctx, factory)
assert.Nil(t, err)
in.SetEtcdClient(getEtcdClient()) in.SetEtcdClient(getEtcdClient())
state, err := in.GetComponentStates(ctx) state, err := in.GetComponentStates(ctx)
assert.Nil(t, err) assert.Nil(t, err)
@ -497,8 +496,7 @@ func TestGetTimeTickChannel(t *testing.T) {
ctx = context.TODO() ctx = context.TODO()
) )
Params.Init() Params.Init()
in, err := NewIndexNode(ctx, factory) in := NewIndexNode(ctx, factory)
assert.Nil(t, err)
ret, err := in.GetTimeTickChannel(ctx) ret, err := in.GetTimeTickChannel(ctx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, ret.Status.ErrorCode, commonpb.ErrorCode_Success) assert.Equal(t, ret.Status.ErrorCode, commonpb.ErrorCode_Success)
@ -512,8 +510,7 @@ func TestGetStatisticChannel(t *testing.T) {
ctx = context.TODO() ctx = context.TODO()
) )
Params.Init() Params.Init()
in, err := NewIndexNode(ctx, factory) in := NewIndexNode(ctx, factory)
assert.Nil(t, err)
ret, err := in.GetStatisticsChannel(ctx) ret, err := in.GetStatisticsChannel(ctx)
assert.Nil(t, err) assert.Nil(t, err)
@ -528,8 +525,7 @@ func TestIndexTaskWhenStoppingNode(t *testing.T) {
ctx = context.TODO() ctx = context.TODO()
) )
Params.Init() Params.Init()
in, err := NewIndexNode(ctx, factory) in := NewIndexNode(ctx, factory)
assert.Nil(t, err)
in.loadOrStoreTask("cluster-1", 1, &taskInfo{ in.loadOrStoreTask("cluster-1", 1, &taskInfo{
state: commonpb.IndexState_InProgress, state: commonpb.IndexState_InProgress,
@ -555,6 +551,19 @@ func TestIndexTaskWhenStoppingNode(t *testing.T) {
} }
} }
func TestGetSetAddress(t *testing.T) {
var (
factory = &mockFactory{
chunkMgr: &mockChunkmgr{},
}
ctx = context.TODO()
)
Params.Init()
in := NewIndexNode(ctx, factory)
in.SetAddress("address")
assert.Equal(t, "address", in.GetAddress())
}
func TestInitErr(t *testing.T) { func TestInitErr(t *testing.T) {
// var ( // var (
// factory = &mockFactory{} // factory = &mockFactory{}
View File
@ -172,7 +172,7 @@ type TaskScheduler struct {
} }
// NewTaskScheduler creates a new task scheduler of indexing tasks. // NewTaskScheduler creates a new task scheduler of indexing tasks.
func NewTaskScheduler(ctx context.Context) (*TaskScheduler, error) { func NewTaskScheduler(ctx context.Context) *TaskScheduler {
ctx1, cancel := context.WithCancel(ctx) ctx1, cancel := context.WithCancel(ctx)
s := &TaskScheduler{ s := &TaskScheduler{
ctx: ctx1, ctx: ctx1,
@ -181,7 +181,7 @@ func NewTaskScheduler(ctx context.Context) (*TaskScheduler, error) {
} }
s.IndexBuildQueue = NewIndexBuildTaskQueue(s) s.IndexBuildQueue = NewIndexBuildTaskQueue(s)
return s, nil return s
} }
func (sched *TaskScheduler) scheduleIndexBuildTask() []task { func (sched *TaskScheduler) scheduleIndexBuildTask() []task {
View File
@ -157,8 +157,7 @@ func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expect
func TestIndexTaskScheduler(t *testing.T) { func TestIndexTaskScheduler(t *testing.T) {
Params.Init() Params.Init()
scheduler, err := NewTaskScheduler(context.TODO()) scheduler := NewTaskScheduler(context.TODO())
assert.Nil(t, err)
scheduler.Start() scheduler.Start()
tasks := make([]task, 0) tasks := make([]task, 0)
@ -188,15 +187,14 @@ func TestIndexTaskScheduler(t *testing.T) {
assert.Equal(t, tasks[len(tasks)-1].GetState(), tasks[len(tasks)-1].(*fakeTask).expectedState) assert.Equal(t, tasks[len(tasks)-1].GetState(), tasks[len(tasks)-1].(*fakeTask).expectedState)
assert.Equal(t, tasks[len(tasks)-1].Ctx().(*stagectx).curstate, fakeTaskState(fakeTaskSavedIndexes)) assert.Equal(t, tasks[len(tasks)-1].Ctx().(*stagectx).curstate, fakeTaskState(fakeTaskSavedIndexes))
scheduler, err = NewTaskScheduler(context.TODO()) scheduler = NewTaskScheduler(context.TODO())
assert.Nil(t, err)
tasks = make([]task, 0, 1024) tasks = make([]task, 0, 1024)
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {
tasks = append(tasks, newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished)) tasks = append(tasks, newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished))
assert.Nil(t, scheduler.IndexBuildQueue.Enqueue(tasks[len(tasks)-1])) assert.Nil(t, scheduler.IndexBuildQueue.Enqueue(tasks[len(tasks)-1]))
} }
failTask := newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished) failTask := newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished)
err = scheduler.IndexBuildQueue.Enqueue(failTask) err := scheduler.IndexBuildQueue.Enqueue(failTask)
assert.Error(t, err) assert.Error(t, err)
failTask.Reset() failTask.Reset()
View File
@ -438,6 +438,10 @@ func (node *Proxy) SetAddress(address string) {
node.address = address node.address = address
} }
func (node *Proxy) GetAddress() string {
return node.address
}
// SetEtcdClient sets etcd client for proxy. // SetEtcdClient sets etcd client for proxy.
func (node *Proxy) SetEtcdClient(client *clientv3.Client) { func (node *Proxy) SetEtcdClient(client *clientv3.Client) {
node.etcdCli = client node.etcdCli = client
@ -458,6 +462,10 @@ func (node *Proxy) SetQueryCoordClient(cli types.QueryCoord) {
node.queryCoord = cli node.queryCoord = cli
} }
func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string) (types.QueryNode, error)) {
node.shardMgr.clientCreator = f
}
// GetRateLimiter returns the rateLimiter in Proxy. // GetRateLimiter returns the rateLimiter in Proxy.
func (node *Proxy) GetRateLimiter() (types.Limiter, error) { func (node *Proxy) GetRateLimiter() (types.Limiter, error) {
if node.multiRateLimiter == nil { if node.multiRateLimiter == nil {
View File
@ -58,6 +58,7 @@ import (
"github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/tracer" "github.com/milvus-io/milvus/internal/tracer"
"github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/crypto" "github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/distance" "github.com/milvus-io/milvus/internal/util/distance"
@ -474,6 +475,7 @@ func TestProxy(t *testing.T) {
var p paramtable.GrpcServerConfig var p paramtable.GrpcServerConfig
p.Init(typeutil.ProxyRole, &base) p.Init(typeutil.ProxyRole, &base)
testServer.Proxy.SetAddress(p.GetAddress()) testServer.Proxy.SetAddress(p.GetAddress())
assert.Equal(t, p.GetAddress(), testServer.Proxy.GetAddress())
go testServer.startGrpc(ctx, &wg, &p) go testServer.startGrpc(ctx, &wg, &p)
assert.NoError(t, testServer.waitForGrpcReady()) assert.NoError(t, testServer.waitForGrpcReady())
@ -482,7 +484,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = rootCoordClient.Init() err = rootCoordClient.Init()
assert.NoError(t, err) assert.NoError(t, err)
err = funcutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration) err = componentutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration)
assert.NoError(t, err) assert.NoError(t, err)
proxy.SetRootCoordClient(rootCoordClient) proxy.SetRootCoordClient(rootCoordClient)
log.Info("Proxy set root coordinator client") log.Info("Proxy set root coordinator client")
@ -491,7 +493,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = dataCoordClient.Init() err = dataCoordClient.Init()
assert.NoError(t, err) assert.NoError(t, err)
err = funcutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration) err = componentutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration)
assert.NoError(t, err) assert.NoError(t, err)
proxy.SetDataCoordClient(dataCoordClient) proxy.SetDataCoordClient(dataCoordClient)
log.Info("Proxy set data coordinator client") log.Info("Proxy set data coordinator client")
@ -500,9 +502,10 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = queryCoordClient.Init() err = queryCoordClient.Init()
assert.NoError(t, err) assert.NoError(t, err)
err = funcutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration) err = componentutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration)
assert.NoError(t, err) assert.NoError(t, err)
proxy.SetQueryCoordClient(queryCoordClient) proxy.SetQueryCoordClient(queryCoordClient)
proxy.SetQueryNodeCreator(defaultQueryNodeClientCreator)
log.Info("Proxy set query coordinator client") log.Info("Proxy set query coordinator client")
proxy.UpdateStateCode(commonpb.StateCode_Initializing) proxy.UpdateStateCode(commonpb.StateCode_Initializing)
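WaitForComponentHealthy moves from funcutil to the new componentutil package; only the call sites are visible here. A rough sketch of what such a polling helper does, assuming the obvious shape — the helper below is illustrative, not the componentutil implementation:

// Illustrative polling loop; componentutil's real signature and
// implementation may differ (imports of fmt/time/commonpb assumed).
func waitForHealthy(ctx context.Context, c types.Component, role string,
	attempts int, sleep time.Duration) error {
	for i := 0; i < attempts; i++ {
		states, err := c.GetComponentStates(ctx)
		if err == nil && states.GetState().GetStateCode() == commonpb.StateCode_Healthy {
			return nil
		}
		time.Sleep(sleep)
	}
	return fmt.Errorf("%s not healthy after %d attempts", role, attempts)
}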
View File
@ -312,6 +312,7 @@ func (sa *segIDAssigner) syncSegments() (bool, error) {
sa.segReqs = nil sa.segReqs = nil
log.Debug("syncSegments call dataCoord.AssignSegmentID", zap.String("request", req.String()))
resp, err := sa.dataCoord.AssignSegmentID(context.Background(), req) resp, err := sa.dataCoord.AssignSegmentID(context.Background(), req)
if err != nil { if err != nil {
View File
@ -106,7 +106,7 @@ func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
return func(s *shardClientMgr) { s.clientCreator = creator } return func(s *shardClientMgr) { s.clientCreator = creator }
} }
func defaultShardClientCreator(ctx context.Context, addr string) (types.QueryNode, error) { func defaultQueryNodeClientCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return qnClient.NewClient(ctx, addr) return qnClient.NewClient(ctx, addr)
} }
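The rename makes the proxy's shard-client factory explicit as a QueryNode creator, and the existing withShardClientCreator option already lets tests substitute it. A sketch of that substitution, where mockQueryNodeStub is a hypothetical test double:

mgr := newShardClientMgr(withShardClientCreator(
	func(ctx context.Context, addr string) (types.QueryNode, error) {
		return &mockQueryNodeStub{addr: addr}, nil // hypothetical stub
	},
))
_ = mgr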
@ -117,7 +117,7 @@ func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgr {
sync.RWMutex sync.RWMutex
data map[UniqueID]*shardClient data map[UniqueID]*shardClient
}{data: make(map[UniqueID]*shardClient)}, }{data: make(map[UniqueID]*shardClient)},
clientCreator: defaultShardClientCreator, clientCreator: defaultQueryNodeClientCreator,
} }
for _, opt := range options { for _, opt := range options {
opt(s) opt(s)
View File
@ -85,8 +85,9 @@ type Server struct {
broker meta.Broker broker meta.Broker
// Session // Session
cluster session.Cluster cluster session.Cluster
nodeMgr *session.NodeManager nodeMgr *session.NodeManager
queryNodeCreator session.QueryNodeCreator
// Schedulers // Schedulers
jobScheduler *job.Scheduler jobScheduler *job.Scheduler
@ -117,6 +118,7 @@ func NewQueryCoord(ctx context.Context) (*Server, error) {
cancel: cancel, cancel: cancel,
} }
server.UpdateStateCode(commonpb.StateCode_Abnormal) server.UpdateStateCode(commonpb.StateCode_Abnormal)
server.queryNodeCreator = session.DefaultQueryNodeCreator
return server, nil return server, nil
} }
@ -182,7 +184,7 @@ func (s *Server) Init() error {
// Init session // Init session
log.Info("init session") log.Info("init session")
s.nodeMgr = session.NewNodeManager() s.nodeMgr = session.NewNodeManager()
s.cluster = session.NewCluster(s.nodeMgr) s.cluster = session.NewCluster(s.nodeMgr, s.queryNodeCreator)
// Init schedulers // Init schedulers
log.Info("init schedulers") log.Info("init schedulers")
@ -479,6 +481,10 @@ func (s *Server) SetDataCoord(dataCoord types.DataCoord) error {
return nil return nil
} }
func (s *Server) SetQueryNodeCreator(f func(ctx context.Context, addr string) (types.QueryNode, error)) {
s.queryNodeCreator = f
}
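Init() captures s.queryNodeCreator when it builds the cluster, so the creator must be injected before Init is called — the constructor's DefaultQueryNodeCreator default and the test wiring below both respect this. A sketch of the expected call order:

server, err := NewQueryCoord(ctx)
if err != nil {
	panic(err)
}
server.SetQueryNodeCreator(session.DefaultQueryNodeCreator) // or a test double
if err := server.Init(); err != nil { // cluster now uses the injected creator
	panic(err)
}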
func (s *Server) recover() error { func (s *Server) recover() error {
// Recover target managers // Recover target managers
group, ctx := errgroup.WithContext(s.ctx) group, ctx := errgroup.WithContext(s.ctx)
View File
@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/mocks" "github.com/milvus-io/milvus/internal/querycoordv2/mocks"
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/paramtable"
@ -429,7 +430,7 @@ func newQueryCoord() (*Server, error) {
return nil, err return nil, err
} }
server.SetEtcdClient(etcdCli) server.SetEtcdClient(etcdCli)
server.SetQueryNodeCreator(session.DefaultQueryNodeCreator)
err = server.Init() err = server.Init()
return server, err return server, err
} }
View File
@ -29,6 +29,7 @@ import (
grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -68,9 +69,15 @@ type QueryCluster struct {
stopOnce sync.Once stopOnce sync.Once
} }
func NewCluster(nodeManager *NodeManager) *QueryCluster { type QueryNodeCreator func(ctx context.Context, addr string) (types.QueryNode, error)
func DefaultQueryNodeCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return grpcquerynodeclient.NewClient(ctx, addr)
}
func NewCluster(nodeManager *NodeManager, queryNodeCreator QueryNodeCreator) *QueryCluster {
c := &QueryCluster{ c := &QueryCluster{
clients: newClients(), clients: newClients(queryNodeCreator),
nodeManager: nodeManager, nodeManager: nodeManager,
ch: make(chan struct{}), ch: make(chan struct{}),
} }
@ -112,7 +119,7 @@ func (c *QueryCluster) updateLoop() {
func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
var status *commonpb.Status var status *commonpb.Status
var err error var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) { err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.LoadSegmentsRequest) req := proto.Clone(req).(*querypb.LoadSegmentsRequest)
req.Base.TargetID = nodeID req.Base.TargetID = nodeID
status, err = cli.LoadSegments(ctx, req) status, err = cli.LoadSegments(ctx, req)
@ -126,7 +133,7 @@ func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *quer
func (c *QueryCluster) WatchDmChannels(ctx context.Context, nodeID int64, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { func (c *QueryCluster) WatchDmChannels(ctx context.Context, nodeID int64, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
var status *commonpb.Status var status *commonpb.Status
var err error var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) { err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.WatchDmChannelsRequest) req := proto.Clone(req).(*querypb.WatchDmChannelsRequest)
req.Base.TargetID = nodeID req.Base.TargetID = nodeID
status, err = cli.WatchDmChannels(ctx, req) status, err = cli.WatchDmChannels(ctx, req)
@ -140,7 +147,7 @@ func (c *QueryCluster) WatchDmChannels(ctx context.Context, nodeID int64, req *q
func (c *QueryCluster) UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { func (c *QueryCluster) UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
var status *commonpb.Status var status *commonpb.Status
var err error var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) { err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.UnsubDmChannelRequest) req := proto.Clone(req).(*querypb.UnsubDmChannelRequest)
req.Base.TargetID = nodeID req.Base.TargetID = nodeID
status, err = cli.UnsubDmChannel(ctx, req) status, err = cli.UnsubDmChannel(ctx, req)
@ -154,7 +161,7 @@ func (c *QueryCluster) UnsubDmChannel(ctx context.Context, nodeID int64, req *qu
func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
var status *commonpb.Status var status *commonpb.Status
var err error var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) { err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.ReleaseSegmentsRequest) req := proto.Clone(req).(*querypb.ReleaseSegmentsRequest)
req.Base.TargetID = nodeID req.Base.TargetID = nodeID
status, err = cli.ReleaseSegments(ctx, req) status, err = cli.ReleaseSegments(ctx, req)
@ -168,7 +175,7 @@ func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *q
func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
var resp *querypb.GetDataDistributionResponse var resp *querypb.GetDataDistributionResponse
var err error var err error
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) { err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.GetDataDistributionRequest) req := proto.Clone(req).(*querypb.GetDataDistributionRequest)
req.Base = &commonpb.MsgBase{ req.Base = &commonpb.MsgBase{
TargetID: nodeID, TargetID: nodeID,
@ -186,7 +193,7 @@ func (c *QueryCluster) GetMetrics(ctx context.Context, nodeID int64, req *milvus
resp *milvuspb.GetMetricsResponse resp *milvuspb.GetMetricsResponse
err error err error
) )
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) { err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
resp, err = cli.GetMetrics(ctx, req) resp, err = cli.GetMetrics(ctx, req)
}) })
if err1 != nil { if err1 != nil {
@ -200,7 +207,7 @@ func (c *QueryCluster) SyncDistribution(ctx context.Context, nodeID int64, req *
resp *commonpb.Status resp *commonpb.Status
err error err error
) )
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) { err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.SyncDistributionRequest) req := proto.Clone(req).(*querypb.SyncDistributionRequest)
req.Base.TargetID = nodeID req.Base.TargetID = nodeID
resp, err = cli.SyncDistribution(ctx, req) resp, err = cli.SyncDistribution(ctx, req)
@ -216,18 +223,16 @@ func (c *QueryCluster) GetComponentStates(ctx context.Context, nodeID int64) (*m
resp *milvuspb.ComponentStates resp *milvuspb.ComponentStates
err error err error
) )
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
err1 := c.send(ctx, nodeID, func(cli *grpcquerynodeclient.Client) {
resp, err = cli.GetComponentStates(ctx) resp, err = cli.GetComponentStates(ctx)
}) })
if err1 != nil { if err1 != nil {
return nil, err1 return nil, err1
} }
return resp, err return resp, err
} }
func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli *grpcquerynodeclient.Client)) error { func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli types.QueryNode)) error {
node := c.nodeManager.Get(nodeID) node := c.nodeManager.Get(nodeID)
if node == nil { if node == nil {
return WrapErrNodeNotFound(nodeID) return WrapErrNodeNotFound(nodeID)
@ -244,7 +249,8 @@ func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli *grpc
type clients struct { type clients struct {
sync.RWMutex sync.RWMutex
clients map[int64]*grpcquerynodeclient.Client // nodeID -> client clients map[int64]types.QueryNode // nodeID -> client
queryNodeCreator QueryNodeCreator
} }
func (c *clients) getAllNodeIDs() []int64 { func (c *clients) getAllNodeIDs() []int64 {
@ -258,15 +264,15 @@ func (c *clients) getAllNodeIDs() []int64 {
return ret return ret
} }
func (c *clients) getOrCreate(ctx context.Context, node *NodeInfo) (*grpcquerynodeclient.Client, error) { func (c *clients) getOrCreate(ctx context.Context, node *NodeInfo) (types.QueryNode, error) {
if cli := c.get(node.ID()); cli != nil { if cli := c.get(node.ID()); cli != nil {
return cli, nil return cli, nil
} }
return c.create(node) return c.create(node)
} }
func createNewClient(ctx context.Context, addr string) (*grpcquerynodeclient.Client, error) { func createNewClient(ctx context.Context, addr string, queryNodeCreator QueryNodeCreator) (types.QueryNode, error) {
newCli, err := grpcquerynodeclient.NewClient(ctx, addr) newCli, err := queryNodeCreator(ctx, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -279,13 +285,13 @@ func createNewClient(ctx context.Context, addr string) (*grpcquerynodeclient.Cli
return newCli, nil return newCli, nil
} }
func (c *clients) create(node *NodeInfo) (*grpcquerynodeclient.Client, error) { func (c *clients) create(node *NodeInfo) (types.QueryNode, error) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if cli, ok := c.clients[node.ID()]; ok { if cli, ok := c.clients[node.ID()]; ok {
return cli, nil return cli, nil
} }
cli, err := createNewClient(context.Background(), node.Addr()) cli, err := createNewClient(context.Background(), node.Addr(), c.queryNodeCreator)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -293,7 +299,7 @@ func (c *clients) create(node *NodeInfo) (*grpcquerynodeclient.Client, error) {
return cli, nil return cli, nil
} }
func (c *clients) get(nodeID int64) *grpcquerynodeclient.Client { func (c *clients) get(nodeID int64) types.QueryNode {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
return c.clients[nodeID] return c.clients[nodeID]
@ -320,6 +326,9 @@ func (c *clients) closeAll() {
} }
} }
func newClients() *clients { func newClients(queryNodeCreator QueryNodeCreator) *clients {
return &clients{clients: make(map[int64]*grpcquerynodeclient.Client)} return &clients{
clients: make(map[int64]types.QueryNode),
queryNodeCreator: queryNodeCreator,
}
} }
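With the cluster programmed against types.QueryNode and an injectable QueryNodeCreator, an integration test can register in-process nodes without going through gRPC at all. A sketch of a test-only creator, assuming stubs that satisfy types.QueryNode:

// Hypothetical test helper: resolves addresses to pre-registered stubs.
func newInMemoryCreator(nodes map[string]types.QueryNode) QueryNodeCreator {
	return func(ctx context.Context, addr string) (types.QueryNode, error) {
		if n, ok := nodes[addr]; ok {
			return n, nil
		}
		return nil, fmt.Errorf("no stub registered for %s", addr)
	}
}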
View File
@ -90,7 +90,7 @@ func (suite *ClusterTestSuite) setupCluster() {
node := NewNodeInfo(int64(i), lis.Addr().String()) node := NewNodeInfo(int64(i), lis.Addr().String())
suite.nodeManager.Add(node) suite.nodeManager.Add(node)
} }
suite.cluster = NewCluster(suite.nodeManager) suite.cluster = NewCluster(suite.nodeManager, DefaultQueryNodeCreator)
} }
func (suite *ClusterTestSuite) createTestServers() []querypb.QueryNodeServer { func (suite *ClusterTestSuite) createTestServers() []querypb.QueryNodeServer {
View File
@ -41,7 +41,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
) )
@ -75,7 +75,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon
} }
nodeID := common.NotRegisteredID nodeID := common.NotRegisteredID
if node.session != nil && node.session.Registered() { if node.session != nil && node.session.Registered() {
nodeID = paramtable.GetNodeID() nodeID = node.GetSession().ServerID
} }
info := &milvuspb.ComponentInfo{ info := &milvuspb.ComponentInfo{
NodeID: nodeID, NodeID: nodeID,
@ -172,7 +172,7 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
} }
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID()) failRet.Status.Reason = msgQueryNodeIsUnhealthy(node.GetSession().ServerID)
return failRet, nil return failRet, nil
} }
node.wg.Add(1) node.wg.Add(1)
@ -299,9 +299,10 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que
// WatchDmChannels creates consumers on dmChannels to receive incremental data, which is an important part of real-time query // WatchDmChannels creates consumers on dmChannels to receive incremental data, which is an important part of real-time query
func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
// check node healthy // check node healthy
if !node.isHealthy() { if !node.isHealthy() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -312,17 +313,17 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
defer node.wg.Done() defer node.wg.Done()
// check target matches // check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() { if in.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), paramtable.GetNodeID()), Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), nodeID),
} }
return status, nil return status, nil
} }
log := log.With( log := log.With(
zap.Int64("collectionID", in.GetCollectionID()), zap.Int64("collectionID", in.GetCollectionID()),
zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("nodeID", nodeID),
zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string { zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string {
return info.GetChannelName() return info.GetChannelName()
})), })),
@ -390,8 +391,9 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
// check node healthy // check node healthy
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -402,10 +404,10 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
defer node.wg.Done() defer node.wg.Done()
// check target matches // check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() { if req.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
} }
return status, nil return status, nil
} }
@ -451,9 +453,10 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
// LoadSegments loads historical data into the query node; the historical data can be vector data or an index // LoadSegments loads historical data into the query node; the historical data can be vector data or an index
func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
// check node healthy // check node healthy
if !node.isHealthy() { if !node.isHealthy() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -464,10 +467,10 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
defer node.wg.Done() defer node.wg.Done()
// check target matches // check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() { if in.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), paramtable.GetNodeID()), Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), nodeID),
} }
return status, nil return status, nil
} }
@ -496,7 +499,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
startTs := time.Now() startTs := time.Now()
log.Info("loadSegmentsTask init", zap.Int64("collectionID", in.CollectionID), log.Info("loadSegmentsTask init", zap.Int64("collectionID", in.CollectionID),
zap.Int64s("segmentIDs", segmentIDs), zap.Int64s("segmentIDs", segmentIDs),
zap.Int64("nodeID", paramtable.GetNodeID())) zap.Int64("nodeID", nodeID))
// TODO remove concurrent load segment for now, unless we solve the memory issue // TODO remove concurrent load segment for now, unless we solve the memory issue
log.Info("loadSegmentsTask start ", zap.Int64("collectionID", in.CollectionID), log.Info("loadSegmentsTask start ", zap.Int64("collectionID", in.CollectionID),
@ -512,7 +515,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
return status, nil return status, nil
} }
log.Info("loadSegmentsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", paramtable.GetNodeID())) log.Info("loadSegmentsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", nodeID))
waitFunc := func() (*commonpb.Status, error) { waitFunc := func() (*commonpb.Status, error) {
err = task.WaitToFinish() err = task.WaitToFinish()
@ -527,7 +530,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
log.Warn(err.Error()) log.Warn(err.Error())
return status, nil return status, nil
} }
log.Info("loadSegmentsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", paramtable.GetNodeID())) log.Info("loadSegmentsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", nodeID))
return &commonpb.Status{ return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
}, nil }, nil
@ -539,7 +542,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
// ReleaseCollection clears all data related to this collection on the querynode // ReleaseCollection clears all data related to this collection on the querynode
func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -587,7 +590,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
// ReleasePartitions clears all data related to this partition on the querynode // ReleasePartitions clears all data related to this partition on the querynode
func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -634,8 +637,9 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
// ReleaseSegments removes the specified segments from the query node according to segmentIDs, partitionIDs, and collectionID // ReleaseSegments removes the specified segments from the query node according to segmentIDs, partitionIDs, and collectionID
func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -646,10 +650,10 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
defer node.wg.Done() defer node.wg.Done()
// check target matches // check target matches
if in.GetBase().GetTargetID() != paramtable.GetNodeID() { if in.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), paramtable.GetNodeID()), Reason: common.WrapNodeIDNotMatchMsg(in.GetBase().GetTargetID(), nodeID),
} }
return status, nil return status, nil
} }
@ -684,7 +688,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS
// GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ... // GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ...
func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID)
res := &querypb.GetSegmentInfoResponse{ res := &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -732,13 +736,14 @@ func filterSegmentInfo(segmentInfos []*querypb.SegmentInfo, segmentIDs map[int64
// Search performs replica search tasks. // Search performs replica search tasks.
func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
if !node.IsStandAlone && req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() { nodeID := node.GetSession().ServerID
if !node.IsStandAlone && req.GetReq().GetBase().GetTargetID() != nodeID {
return &internalpb.SearchResults{ return &internalpb.SearchResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s", Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s",
paramtable.GetNodeID(), nodeID,
common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID())), common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), nodeID)),
}, },
}, nil }, nil
} }
@ -807,13 +812,14 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
tr.CtxElapse(ctx, "search done in all shards") tr.CtxElapse(ctx, "search done in all shards")
rateCol.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq())) rateCol.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq()))
rateCol.Add(metricsinfo.SearchThroughput, float64(proto.Size(req))) rateCol.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Add(float64(proto.Size(req))) metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(nodeID, 10), metrics.SearchLabel).Add(float64(proto.Size(req)))
} }
return ret, nil return ret, nil
} }
func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.SearchRequest, dmlChannel string) (*internalpb.SearchResults, error) { func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.SearchRequest, dmlChannel string) (*internalpb.SearchResults, error) {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel).Inc() nodeID := node.GetSession().ServerID
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.TotalLabel).Inc()
failRet := &internalpb.SearchResults{ failRet := &internalpb.SearchResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -822,11 +828,11 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
defer func() { defer func() {
if failRet.Status.ErrorCode != commonpb.ErrorCode_Success { if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
} }
}() }()
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID()) failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
return failRet, nil return failRet, nil
} }
node.wg.Add(1) node.wg.Add(1)
@ -876,13 +882,13 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
tr.CtxElapse(ctx, fmt.Sprintf("do subsearch done, vChannel = %s, segmentIDs = %v", dmlChannel, req.GetSegmentIDs())) tr.CtxElapse(ctx, fmt.Sprintf("do subsearch done, vChannel = %s, segmentIDs = %v", dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(historicalTask.queueDur.Milliseconds())) metrics.SearchLabel).Observe(float64(historicalTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(historicalTask.reduceDur.Milliseconds())) metrics.SearchLabel).Observe(float64(historicalTask.reduceDur.Milliseconds()))
latency := tr.ElapseSpan() latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.SuccessLabel).Inc()
return historicalTask.Ret, nil return historicalTask.Ret, nil
} }
@ -923,9 +929,9 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
if err != nil { if err != nil {
return err return err
} }
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds())) metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds())) metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
streamingResult = streamingTask.Ret streamingResult = streamingTask.Ret
return nil return nil
@ -951,16 +957,17 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
tr.CtxElapse(ctx, fmt.Sprintf("do reduce done in shard cluster, vChannel = %s, segmentIDs = %v", dmlChannel, req.GetSegmentIDs())) tr.CtxElapse(ctx, fmt.Sprintf("do reduce done in shard cluster, vChannel = %s, segmentIDs = %v", dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.Leader).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq())) metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(nodeID)).Observe(float64(req.Req.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk())) metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(nodeID)).Observe(float64(req.Req.GetTopk()))
return ret, nil return ret, nil
} }
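Reading the ID from the session instead of the process-global paramtable is what allows several query nodes to run inside one test process, each reporting its own ServerID. A sketch of the distinction, where newTestQueryNode is a hypothetical constructor:

nodeA, nodeB := newTestQueryNode(1), newTestQueryNode(2) // hypothetical ctor
fmt.Println(nodeA.GetSession().ServerID) // 1
fmt.Println(nodeB.GetSession().ServerID) // 2
// paramtable.GetNodeID() is process-wide and could only report one value.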
func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.QueryRequest, dmlChannel string) (*internalpb.RetrieveResults, error) { func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.QueryRequest, dmlChannel string) (*internalpb.RetrieveResults, error) {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel).Inc() nodeID := node.GetSession().ServerID
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.TotalLabel).Inc()
failRet := &internalpb.RetrieveResults{ failRet := &internalpb.RetrieveResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -972,11 +979,11 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
defer func() { defer func() {
if failRet.Status.ErrorCode != commonpb.ErrorCode_Success { if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc()
} }
}() }()
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID()) failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID)
return failRet, nil return failRet, nil
} }
node.wg.Add(1) node.wg.Add(1)
@ -1030,13 +1037,13 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(queryTask.queueDur.Milliseconds())) metrics.QueryLabel).Observe(float64(queryTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(queryTask.reduceDur.Milliseconds())) metrics.QueryLabel).Observe(float64(queryTask.reduceDur.Milliseconds()))
latency := tr.ElapseSpan() latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return queryTask.Ret, nil return queryTask.Ret, nil
} }
@ -1067,9 +1074,9 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
if err != nil { if err != nil {
return err return err
} }
metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQLatencyInQueue.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds())) metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds())) metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
streamingResult = streamingTask.Ret streamingResult = streamingTask.Ret
return nil return nil
@ -1101,8 +1108,8 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
latency := tr.ElapseSpan() latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return ret, nil return ret, nil
} }
@ -1115,13 +1122,14 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()), zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp())) zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
if req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() { nodeID := node.GetSession().ServerID
if req.GetReq().GetBase().GetTargetID() != nodeID {
return &internalpb.RetrieveResults{ return &internalpb.RetrieveResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s", Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s",
paramtable.GetNodeID(), nodeID,
common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID())), common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), nodeID)),
}, },
}, nil }, nil
} }
@ -1185,7 +1193,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
if !req.FromShardLeader { if !req.FromShardLeader {
rateCol.Add(metricsinfo.NQPerSecond, 1) rateCol.Add(metricsinfo.NQPerSecond, 1)
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(nodeID, 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
} }
return ret, nil return ret, nil
} }
@ -1195,7 +1203,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
if !node.isHealthy() { if !node.isHealthy() {
return &commonpb.Status{ return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()), Reason: msgQueryNodeIsUnhealthy(node.GetSession().ServerID),
}, nil }, nil
} }
node.wg.Add(1) node.wg.Add(1)
@ -1219,16 +1227,17 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn
// ShowConfigurations returns the configurations of queryNode matching req.Pattern // ShowConfigurations returns the configurations of queryNode matching req.Pattern
func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
log.Warn("QueryNode.ShowConfigurations failed", log.Warn("QueryNode.ShowConfigurations failed",
zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("nodeId", nodeID),
zap.String("req", req.Pattern), zap.String("req", req.Pattern),
zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID()))) zap.Error(errQueryNodeIsUnhealthy(nodeID)))
return &internalpb.ShowConfigurationsResponse{ return &internalpb.ShowConfigurationsResponse{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()), Reason: msgQueryNodeIsUnhealthy(nodeID),
}, },
Configuations: nil, Configuations: nil,
}, nil }, nil
@ -1256,16 +1265,17 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S
// GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ... // GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ...
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
nodeID := node.GetSession().ServerID
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
log.Ctx(ctx).Warn("QueryNode.GetMetrics failed", log.Ctx(ctx).Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("nodeId", nodeID),
zap.String("req", req.Request), zap.String("req", req.Request),
zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID()))) zap.Error(errQueryNodeIsUnhealthy(nodeID)))
return &milvuspb.GetMetricsResponse{ return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()), Reason: msgQueryNodeIsUnhealthy(nodeID),
}, },
Response: "", Response: "",
}, nil }, nil
@ -1276,7 +1286,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
metricType, err := metricsinfo.ParseMetricType(req.Request) metricType, err := metricsinfo.ParseMetricType(req.Request)
if err != nil { if err != nil {
log.Ctx(ctx).Warn("QueryNode.GetMetrics failed to parse metric type", log.Ctx(ctx).Warn("QueryNode.GetMetrics failed to parse metric type",
zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("nodeId", nodeID),
zap.String("req", req.Request), zap.String("req", req.Request),
zap.Error(err)) zap.Error(err))
@ -1292,7 +1302,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
queryNodeMetrics, err := getSystemInfoMetrics(ctx, req, node) queryNodeMetrics, err := getSystemInfoMetrics(ctx, req, node)
if err != nil { if err != nil {
log.Ctx(ctx).Warn("QueryNode.GetMetrics failed", log.Ctx(ctx).Warn("QueryNode.GetMetrics failed",
zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("nodeId", nodeID),
zap.String("req", req.Request), zap.String("req", req.Request),
zap.String("metricType", metricType), zap.String("metricType", metricType),
zap.Error(err)) zap.Error(err))
@ -1307,7 +1317,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
} }
log.Ctx(ctx).RatedDebug(60, "QueryNode.GetMetrics failed, request metric type is not implemented yet", log.Ctx(ctx).RatedDebug(60, "QueryNode.GetMetrics failed, request metric type is not implemented yet",
zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("nodeId", nodeID),
zap.String("req", req.Request), zap.String("req", req.Request),
zap.String("metricType", metricType)) zap.String("metricType", metricType))
@ -1321,18 +1331,19 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR
} }
func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
nodeID := node.GetSession().ServerID
log := log.With( log := log.With(
zap.Int64("msg-id", req.GetBase().GetMsgID()), zap.Int64("msg-id", req.GetBase().GetMsgID()),
zap.Int64("node-id", paramtable.GetNodeID()), zap.Int64("node-id", nodeID),
) )
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
log.Warn("QueryNode.GetMetrics failed", log.Warn("QueryNode.GetMetrics failed",
zap.Error(errQueryNodeIsUnhealthy(paramtable.GetNodeID()))) zap.Error(errQueryNodeIsUnhealthy(nodeID)))
return &querypb.GetDataDistributionResponse{ return &querypb.GetDataDistributionResponse{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msgQueryNodeIsUnhealthy(paramtable.GetNodeID()), Reason: msgQueryNodeIsUnhealthy(nodeID),
}, },
}, nil }, nil
} }
@ -1340,12 +1351,12 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
defer node.wg.Done() defer node.wg.Done()
// check target matches // check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() { if req.GetBase().GetTargetID() != nodeID {
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s", Reason: fmt.Sprintf("QueryNode %d can't serve, recovering: %s",
paramtable.GetNodeID(), nodeID,
common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID())), common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID)),
} }
return &querypb.GetDataDistributionResponse{Status: status}, nil return &querypb.GetDataDistributionResponse{Status: status}, nil
} }
@ -1407,7 +1418,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
return &querypb.GetDataDistributionResponse{ return &querypb.GetDataDistributionResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
NodeID: paramtable.GetNodeID(), NodeID: nodeID,
Segments: segmentVersionInfos, Segments: segmentVersionInfos,
Channels: channelVersionInfos, Channels: channelVersionInfos,
LeaderViews: leaderViews, LeaderViews: leaderViews,
@ -1416,9 +1427,10 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) { func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel())) log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel()))
nodeID := node.GetSession().ServerID
// check node healthy // check node healthy
if !node.isHealthyOrStopping() { if !node.isHealthyOrStopping() {
err := fmt.Errorf("query node %d is not ready", paramtable.GetNodeID()) err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -1429,11 +1441,11 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
defer node.wg.Done() defer node.wg.Done()
// check target matches // check target matches
if req.GetBase().GetTargetID() != paramtable.GetNodeID() { if req.GetBase().GetTargetID() != nodeID {
log.Warn("failed to do match target id when sync ", zap.Int64("expect", req.GetBase().GetTargetID()), zap.Int64("actual", node.session.ServerID)) log.Warn("failed to do match target id when sync ", zap.Int64("expect", req.GetBase().GetTargetID()), zap.Int64("actual", nodeID))
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
} }
return status, nil return status, nil
} }
@ -1476,3 +1488,17 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
Reason: "", Reason: "",
}, nil }, nil
} }
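Throughout this hunk the process-global `paramtable.GetNodeID()` gives way to the per-instance `node.GetSession().ServerID`. A standalone sketch of why that matters once an integration test hosts several query nodes in one process (`Node` and `globalID` here are illustrative stand-ins, not Milvus code):

```go
package main

import "fmt"

// globalID mimics a process-wide value such as paramtable.GetNodeID():
// every node hosted in the same process observes the same number.
var globalID int64 = 1

// Node carries a session-scoped identity, mirroring what
// node.GetSession().ServerID provides in the diff above.
type Node struct{ serverID int64 }

func (n *Node) canServe(targetID int64) bool { return targetID == n.serverID }

func main() {
	// Two query nodes in one process, as an integration test starts them.
	n1, n2 := &Node{serverID: 1}, &Node{serverID: 2}

	// A request targeted at node 2: the process-global check gives the
	// same answer on both nodes, while the per-node check routes it.
	target := int64(2)
	fmt.Println("global check:", target == globalID) // identical on n1 and n2
	fmt.Println("per-node n1:", n1.canServe(target)) // false
	fmt.Println("per-node n2:", n2.canServe(target)) // true
}
```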
// SetSession replaces the node session under sessionMu to fix a data race
func (node *QueryNode) SetSession(session *sessionutil.Session) {
node.sessionMu.Lock()
defer node.sessionMu.Unlock()
node.session = session
}
// GetSession returns the node session under sessionMu to fix a data race
func (node *QueryNode) GetSession() *sessionutil.Session {
node.sessionMu.Lock()
defer node.sessionMu.Unlock()
return node.session
}
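The accessor pair above serializes session access behind `sessionMu`. A minimal standalone sketch of the race it closes, using stand-in `Session` and `Node` types rather than the real `sessionutil` ones:

```go
package main

import (
	"fmt"
	"sync"
)

// Session stands in for sessionutil.Session in this sketch.
type Session struct{ ServerID int64 }

// Node mirrors the mutex-guarded accessor pair added above.
type Node struct {
	mu      sync.Mutex
	session *Session
}

func (n *Node) SetSession(s *Session) {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.session = s
}

func (n *Node) GetSession() *Session {
	n.mu.Lock()
	defer n.mu.Unlock()
	return n.session
}

func main() {
	n := &Node{}
	var wg sync.WaitGroup
	wg.Add(2)
	// One goroutine swaps the session while another reads it; without
	// the mutex, `go run -race` reports a data race on n.session.
	go func() { defer wg.Done(); n.SetSession(&Session{ServerID: 7}) }()
	go func() {
		defer wg.Done()
		if s := n.GetSession(); s != nil {
			fmt.Println("serverID:", s.ServerID)
		}
	}()
	wg.Wait()
}
```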
View File
@ -100,7 +100,7 @@ func TestImpl_WatchDmChannels(t *testing.T) {
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels, MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(), MsgID: rand.Int63(),
TargetID: node.session.ServerID, TargetID: node.GetSession().ServerID,
}, },
NodeID: 0, NodeID: 0,
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
@ -187,7 +187,7 @@ func TestImpl_WatchDmChannels(t *testing.T) {
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels, MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(), MsgID: rand.Int63(),
TargetID: node.session.ServerID, TargetID: node.GetSession().ServerID,
}, },
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID}, PartitionIDs: []UniqueID{defaultPartitionID},
@ -218,7 +218,7 @@ func TestImpl_UnsubDmChannel(t *testing.T) {
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels, MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(), MsgID: rand.Int63(),
TargetID: node.session.ServerID, TargetID: node.GetSession().ServerID,
}, },
NodeID: 0, NodeID: 0,
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
@ -241,7 +241,7 @@ func TestImpl_UnsubDmChannel(t *testing.T) {
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_UnsubDmChannel, MsgType: commonpb.MsgType_UnsubDmChannel,
MsgID: rand.Int63(), MsgID: rand.Int63(),
TargetID: node.session.ServerID, TargetID: node.GetSession().ServerID,
}, },
NodeID: 0, NodeID: 0,
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
@ -299,7 +299,7 @@ func TestImpl_LoadSegments(t *testing.T) {
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels, MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(), MsgID: rand.Int63(),
TargetID: node.session.ServerID, TargetID: node.GetSession().ServerID,
}, },
DstNodeID: 0, DstNodeID: 0,
Schema: schema, Schema: schema,
@ -540,11 +540,11 @@ func TestImpl_ShowConfigurations(t *testing.T) {
t.Run("test ShowConfigurations", func(t *testing.T) { t.Run("test ShowConfigurations", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx) node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err) assert.NoError(t, err)
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
pattern := "Cache" pattern := "Cache"
req := &internalpb.ShowConfigurationsRequest{ req := &internalpb.ShowConfigurationsRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.GetSession().ServerID),
Pattern: pattern, Pattern: pattern,
} }
@ -556,12 +556,12 @@ func TestImpl_ShowConfigurations(t *testing.T) {
t.Run("test ShowConfigurations node failed", func(t *testing.T) { t.Run("test ShowConfigurations node failed", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx) node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err) assert.NoError(t, err)
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
node.UpdateStateCode(commonpb.StateCode_Abnormal) node.UpdateStateCode(commonpb.StateCode_Abnormal)
pattern := "Cache" pattern := "Cache"
req := &internalpb.ShowConfigurationsRequest{ req := &internalpb.ShowConfigurationsRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.GetSession().ServerID),
Pattern: pattern, Pattern: pattern,
} }
@ -592,7 +592,7 @@ func TestImpl_GetMetrics(t *testing.T) {
defer wg.Done() defer wg.Done()
node, err := genSimpleQueryNode(ctx) node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err) assert.NoError(t, err)
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
metricReq := make(map[string]string) metricReq := make(map[string]string)
metricReq[metricsinfo.MetricTypeKey] = "system_info" metricReq[metricsinfo.MetricTypeKey] = "system_info"
@ -644,7 +644,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
defer node.Stop() defer node.Stop()
req := &queryPb.ReleaseSegmentsRequest{ req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID}, PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID}, SegmentIDs: []UniqueID{defaultSegmentID},
@ -669,7 +669,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
defer node.Stop() defer node.Stop()
req := &queryPb.ReleaseSegmentsRequest{ req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID}, PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID}, SegmentIDs: []UniqueID{defaultSegmentID},
@ -704,7 +704,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
defer node.Stop() defer node.Stop()
req := &queryPb.ReleaseSegmentsRequest{ req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID}, PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID}, SegmentIDs: []UniqueID{defaultSegmentID},
@ -725,7 +725,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
req := &queryPb.ReleaseSegmentsRequest{ req := &queryPb.ReleaseSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_ReleaseSegments, node.GetSession().ServerID),
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
} }
@ -1056,7 +1056,7 @@ func TestSyncDistribution(t *testing.T) {
defer node.Stop() defer node.Stop()
resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{ resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{TargetID: node.session.ServerID}, Base: &commonpb.MsgBase{TargetID: node.GetSession().ServerID},
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
Channel: defaultDMLChannel, Channel: defaultDMLChannel,
Actions: []*querypb.SyncAction{ Actions: []*querypb.SyncAction{
@ -1086,7 +1086,7 @@ func TestSyncDistribution(t *testing.T) {
cs.SetupFirstVersion() cs.SetupFirstVersion()
resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{ resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{TargetID: node.session.ServerID}, Base: &commonpb.MsgBase{TargetID: node.GetSession().ServerID},
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
Channel: defaultDMLChannel, Channel: defaultDMLChannel,
Actions: []*querypb.SyncAction{ Actions: []*querypb.SyncAction{
@ -1109,7 +1109,7 @@ func TestSyncDistribution(t *testing.T) {
assert.Equal(t, segmentStateLoaded, segment.state) assert.Equal(t, segmentStateLoaded, segment.state)
assert.EqualValues(t, 1, segment.version) assert.EqualValues(t, 1, segment.version)
resp, err = node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{ resp, err = node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{TargetID: node.session.ServerID}, Base: &commonpb.MsgBase{TargetID: node.GetSession().ServerID},
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
Channel: defaultDMLChannel, Channel: defaultDMLChannel,
Actions: []*querypb.SyncAction{ Actions: []*querypb.SyncAction{
View File
@ -38,7 +38,7 @@ type ImplUtilsSuite struct {
func (s *ImplUtilsSuite) SetupSuite() { func (s *ImplUtilsSuite) SetupSuite() {
s.querynode = newQueryNodeMock() s.querynode = newQueryNodeMock()
client := v3client.New(embedetcdServer.Server) client := v3client.New(embedetcdServer.Server)
s.querynode.session = sessionutil.NewSession(context.Background(), "milvus_ut/sessions", client) s.querynode.SetSession(sessionutil.NewSession(context.Background(), "milvus_ut/sessions", client))
s.querynode.UpdateStateCode(commonpb.StateCode_Healthy) s.querynode.UpdateStateCode(commonpb.StateCode_Healthy)
s.querynode.ShardClusterService = newShardClusterService(client, s.querynode.session, s.querynode) s.querynode.ShardClusterService = newShardClusterService(client, s.querynode.session, s.querynode)
@ -52,8 +52,8 @@ func (s *ImplUtilsSuite) SetupTest() {
nodeEvent := []nodeEvent{ nodeEvent := []nodeEvent{
{ {
nodeID: s.querynode.session.ServerID, nodeID: s.querynode.GetSession().ServerID,
nodeAddr: s.querynode.session.ServerName, nodeAddr: s.querynode.GetSession().ServerName,
isLeader: true, isLeader: true,
}, },
} }
@ -75,9 +75,9 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("normal transfer load", func() { s.Run("normal transfer load", func() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{ status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID, TargetID: s.querynode.GetSession().ServerID,
}, },
DstNodeID: s.querynode.session.ServerID, DstNodeID: s.querynode.GetSession().ServerID,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
{ {
SegmentID: defaultSegmentID, SegmentID: defaultSegmentID,
@ -95,9 +95,9 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("transfer non-exist channel load", func() { s.Run("transfer non-exist channel load", func() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{ status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID, TargetID: s.querynode.GetSession().ServerID,
}, },
DstNodeID: s.querynode.session.ServerID, DstNodeID: s.querynode.GetSession().ServerID,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
{ {
SegmentID: defaultSegmentID, SegmentID: defaultSegmentID,
@ -115,9 +115,9 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
s.Run("transfer empty load segments", func() { s.Run("transfer empty load segments", func() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{ status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID, TargetID: s.querynode.GetSession().ServerID,
}, },
DstNodeID: s.querynode.session.ServerID, DstNodeID: s.querynode.GetSession().ServerID,
Infos: []*querypb.SegmentLoadInfo{}, Infos: []*querypb.SegmentLoadInfo{},
}) })
@ -141,7 +141,7 @@ func (s *ImplUtilsSuite) TestTransferLoad() {
status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{ status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID, TargetID: s.querynode.GetSession().ServerID,
}, },
DstNodeID: 100, DstNodeID: 100,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
@ -197,12 +197,12 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
s.Run("normal transfer release", func() { s.Run("normal transfer release", func() {
status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{ status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID, TargetID: s.querynode.GetSession().ServerID,
}, },
SegmentIDs: []int64{}, SegmentIDs: []int64{},
Scope: querypb.DataScope_All, Scope: querypb.DataScope_All,
Shard: defaultChannelName, Shard: defaultChannelName,
NodeID: s.querynode.session.ServerID, NodeID: s.querynode.GetSession().ServerID,
}) })
s.NoError(err) s.NoError(err)
@ -212,12 +212,12 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
s.Run("transfer non-exist channel release", func() { s.Run("transfer non-exist channel release", func() {
status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{ status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID, TargetID: s.querynode.GetSession().ServerID,
}, },
SegmentIDs: []int64{}, SegmentIDs: []int64{},
Scope: querypb.DataScope_All, Scope: querypb.DataScope_All,
Shard: "invalid_channel", Shard: "invalid_channel",
NodeID: s.querynode.session.ServerID, NodeID: s.querynode.GetSession().ServerID,
}) })
s.NoError(err) s.NoError(err)
@ -239,7 +239,7 @@ func (s *ImplUtilsSuite) TestTransferRelease() {
status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{ status, err := s.querynode.TransferRelease(ctx, &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: s.querynode.session.ServerID, TargetID: s.querynode.GetSession().ServerID,
}, },
SegmentIDs: []int64{}, SegmentIDs: []int64{},
Scope: querypb.DataScope_All, Scope: querypb.DataScope_All,
View File
@ -89,7 +89,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema, Schema: schema,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
{ {
@ -117,7 +117,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema, Schema: schema,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
{ {
@ -210,7 +210,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
} }
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema, Schema: schema,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo}, Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
DeltaPositions: []*internalpb.MsgPosition{ DeltaPositions: []*internalpb.MsgPosition{
@ -305,7 +305,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema, Schema: schema,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
{ {
@ -360,7 +360,7 @@ func TestTask_loadSegmentsTask(t *testing.T) {
segmentID1 := defaultSegmentID segmentID1 := defaultSegmentID
segmentID2 := defaultSegmentID + 1 segmentID2 := defaultSegmentID + 1
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema, Schema: schema,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
{ {
@ -428,7 +428,7 @@ func TestTask_loadSegmentsTaskLoadDelta(t *testing.T) {
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
} }
loadReq := &querypb.LoadSegmentsRequest{ loadReq := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema, Schema: schema,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo}, Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
DeltaPositions: []*internalpb.MsgPosition{ DeltaPositions: []*internalpb.MsgPosition{
@ -458,7 +458,7 @@ func TestTask_loadSegmentsTaskLoadDelta(t *testing.T) {
// load second segments with same channel // load second segments with same channel
loadReq = &querypb.LoadSegmentsRequest{ loadReq = &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.GetSession().ServerID),
Schema: schema, Schema: schema,
Infos: []*querypb.SegmentLoadInfo{ Infos: []*querypb.SegmentLoadInfo{
{ {
View File
@ -81,7 +81,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
}, nil }, nil
} }
hardwareInfos := metricsinfo.HardwareMetrics{ hardwareInfos := metricsinfo.HardwareMetrics{
IP: node.session.Address, IP: node.GetSession().Address,
CPUCoreCount: hardware.GetCPUNum(), CPUCoreCount: hardware.GetCPUNum(),
CPUCoreUsage: hardware.GetCPUUsage(), CPUCoreUsage: hardware.GetCPUUsage(),
Memory: totalMem, Memory: totalMem,
@ -99,7 +99,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
CreatedTime: paramtable.GetCreateTime().String(), CreatedTime: paramtable.GetCreateTime().String(),
UpdatedTime: paramtable.GetUpdateTime().String(), UpdatedTime: paramtable.GetUpdateTime().String(),
Type: typeutil.QueryNodeRole, Type: typeutil.QueryNodeRole,
ID: node.session.ServerID, ID: node.GetSession().ServerID,
}, },
SystemConfigurations: metricsinfo.QueryNodeConfiguration{ SystemConfigurations: metricsinfo.QueryNodeConfiguration{
SimdType: Params.CommonCfg.SimdType.GetValue(), SimdType: Params.CommonCfg.SimdType.GetValue(),
View File
@ -48,10 +48,10 @@ func TestGetSystemInfoMetrics(t *testing.T) {
Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
assert.NoError(t, err) assert.NoError(t, err)
defer etcdCli.Close() defer etcdCli.Close()
node.session = sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) node.SetSession(sessionutil.NewSession(node.queryNodeLoopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli))
req := &milvuspb.GetMetricsRequest{ req := &milvuspb.GetMetricsRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.session.ServerID), Base: genCommonMsgBase(commonpb.MsgType_WatchQueryChannels, node.GetSession().ServerID),
} }
resp, err := getSystemInfoMetrics(ctx, req, node) resp, err := getSystemInfoMetrics(ctx, req, node)
assert.NoError(t, err) assert.NoError(t, err)
View File
@ -1741,7 +1741,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory
} }
// init shard cluster service // init shard cluster service
node.ShardClusterService = newShardClusterService(node.etcdCli, node.session, node) node.ShardClusterService = newShardClusterService(node.etcdCli, node.GetSession(), node)
node.queryShardService, err = newQueryShardService(node.queryNodeLoopCtx, node.queryShardService, err = newQueryShardService(node.queryNodeLoopCtx,
node.metaReplica, node.tSafeReplica, node.metaReplica, node.tSafeReplica,
View File
@ -110,8 +110,9 @@ type QueryNode struct {
factory dependency.Factory factory dependency.Factory
scheduler *taskScheduler scheduler *taskScheduler
session *sessionutil.Session sessionMu sync.Mutex
eventCh <-chan *sessionutil.SessionEvent session *sessionutil.Session
eventCh <-chan *sessionutil.SessionEvent
vectorStorage storage.ChunkManager vectorStorage storage.ChunkManager
etcdKV *etcdkv.EtcdKV etcdKV *etcdkv.EtcdKV
@ -393,3 +394,7 @@ func (node *QueryNode) SetEtcdClient(client *clientv3.Client) {
func (node *QueryNode) SetAddress(address string) { func (node *QueryNode) SetAddress(address string) {
node.address = address node.address = address
} }
func (node *QueryNode) GetAddress() string {
return node.address
}
View File
@ -197,6 +197,9 @@ func TestQueryNode_init(t *testing.T) {
node.SetEtcdClient(etcdcli) node.SetEtcdClient(etcdcli)
err = node.Init() err = node.Init()
assert.Nil(t, err) assert.Nil(t, err)
assert.Empty(t, node.GetAddress())
node.SetAddress("address")
assert.Equal(t, "address", node.GetAddress())
} }
func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) { func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) {
View File
@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
@ -34,7 +35,21 @@ import (
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
) )
type proxyCreator func(sess *sessionutil.Session) (types.Proxy, error) type proxyCreator func(ctx context.Context, addr string) (types.Proxy, error)
func DefaultProxyCreator(ctx context.Context, addr string) (types.Proxy, error) {
cli, err := grpcproxyclient.NewClient(ctx, addr)
if err != nil {
return nil, err
}
if err := cli.Init(); err != nil {
return nil, err
}
if err := cli.Start(); err != nil {
return nil, err
}
return cli, nil
}
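`DefaultProxyCreator` is the production implementation of the new address-based `proxyCreator` signature; the indirection exists so tests can swap in a stub. A self-contained sketch of that injection (the `proxy`, `manager`, and `creator` names are invented for illustration, not Milvus APIs):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// proxy stands in for types.Proxy; only the creator shape matters here.
type proxy interface{ Addr() string }

type fakeProxy struct{ addr string }

func (f *fakeProxy) Addr() string { return f.addr }

// creator mirrors the new proxyCreator signature: it takes an address
// instead of a *sessionutil.Session, so a test can hand back a stub
// without a live etcd session.
type creator func(ctx context.Context, addr string) (proxy, error)

type manager struct{ create creator }

func (m *manager) connect(addr string) (proxy, error) {
	return m.create(context.Background(), addr)
}

func main() {
	// Production wiring would install a gRPC-backed creator such as
	// DefaultProxyCreator; an integration test injects a stub instead.
	m := &manager{create: func(ctx context.Context, addr string) (proxy, error) {
		if addr == "" {
			return nil, errors.New("empty address")
		}
		return &fakeProxy{addr: addr}, nil
	}}
	p, err := m.connect("localhost:19530")
	fmt.Println(p.Addr(), err)
}
```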
type proxyClientManager struct { type proxyClientManager struct {
creator proxyCreator creator proxyCreator
@ -85,7 +100,7 @@ func (p *proxyClientManager) GetProxyCount() int {
} }
func (p *proxyClientManager) connect(session *sessionutil.Session) { func (p *proxyClientManager) connect(session *sessionutil.Session) {
pc, err := p.creator(session) pc, err := p.creator(context.Background(), session.Address)
if err != nil { if err != nil {
log.Warn("failed to create proxy client", zap.String("address", session.Address), zap.Int64("serverID", session.ServerID), zap.Error(err)) log.Warn("failed to create proxy client", zap.String("address", session.Address), zap.Int64("serverID", session.ServerID), zap.Error(err))
return return
View File
@ -117,7 +117,7 @@ func TestProxyClientManager_GetProxyClients(t *testing.T) {
defer cli.Close() defer cli.Close()
assert.Nil(t, err) assert.Nil(t, err)
core.etcdCli = cli core.etcdCli = cli
core.proxyCreator = func(se *sessionutil.Session) (types.Proxy, error) { core.proxyCreator = func(ctx context.Context, addr string) (types.Proxy, error) {
return nil, errors.New("failed") return nil, errors.New("failed")
} }
@ -149,7 +149,7 @@ func TestProxyClientManager_AddProxyClient(t *testing.T) {
defer cli.Close() defer cli.Close()
core.etcdCli = cli core.etcdCli = cli
core.proxyCreator = func(se *sessionutil.Session) (types.Proxy, error) { core.proxyCreator = func(ctx context.Context, addr string) (types.Proxy, error) {
return nil, errors.New("failed") return nil, errors.New("failed")
} }
View File
@ -36,7 +36,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/common"
pnc "github.com/milvus-io/milvus/internal/distributed/proxy/client"
"github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
@ -150,19 +149,7 @@ func NewCore(c context.Context, factory dependency.Factory) (*Core, error) {
} }
core.UpdateStateCode(commonpb.StateCode_Abnormal) core.UpdateStateCode(commonpb.StateCode_Abnormal)
core.proxyCreator = func(se *sessionutil.Session) (types.Proxy, error) { core.SetProxyCreator(DefaultProxyCreator)
cli, err := pnc.NewClient(c, se.Address)
if err != nil {
return nil, err
}
if err := cli.Init(); err != nil {
return nil, err
}
if err := cli.Start(); err != nil {
return nil, err
}
return cli, nil
}
return core, nil return core, nil
} }
@ -263,23 +250,21 @@ func (c *Core) tsLoop() {
} }
} }
func (c *Core) SetDataCoord(ctx context.Context, s types.DataCoord) error { func (c *Core) SetProxyCreator(f func(ctx context.Context, addr string) (types.Proxy, error)) {
if err := s.Init(); err != nil { c.proxyCreator = f
return err }
}
if err := s.Start(); err != nil { func (c *Core) SetDataCoord(s types.DataCoord) error {
return err if s == nil {
return errors.New("null DataCoord interface")
} }
c.dataCoord = s c.dataCoord = s
return nil return nil
} }
func (c *Core) SetQueryCoord(s types.QueryCoord) error { func (c *Core) SetQueryCoord(s types.QueryCoord) error {
if err := s.Init(); err != nil { if s == nil {
return err return errors.New("null QueryCoord interface")
}
if err := s.Start(); err != nil {
return err
} }
c.queryCoord = s c.queryCoord = s
return nil return nil
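The setters no longer call `Init()` and `Start()` on the injected coordinator: the caller, such as an in-process test cluster, owns that lifecycle, and the setter only rejects nil. A small sketch of the slimmed-down contract (stand-in types, not the real `types.DataCoord`):

```go
package main

import (
	"errors"
	"fmt"
)

// dataCoord stands in for types.DataCoord in this sketch.
type dataCoord interface{ Ping() string }

type core struct{ dc dataCoord }

// SetDataCoord mirrors the slimmed-down setter: the client is expected
// to be initialized and started by the caller, so the setter only
// rejects a nil dependency.
func (c *core) SetDataCoord(s dataCoord) error {
	if s == nil {
		return errors.New("null DataCoord interface")
	}
	c.dc = s
	return nil
}

type stubDataCoord struct{}

func (stubDataCoord) Ping() string { return "ok" }

func main() {
	c := &core{}
	fmt.Println(c.SetDataCoord(nil))             // null DataCoord interface
	fmt.Println(c.SetDataCoord(stubDataCoord{})) // <nil>
}
```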
View File
@ -52,6 +52,8 @@ type Component interface {
GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error)
GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error)
Register() error Register() error
//SetAddress(address string)
//GetAddress() string
} }
// DataNode is the interface `datanode` package implements // DataNode is the interface `datanode` package implements
@ -112,6 +114,7 @@ type DataNodeComponent interface {
GetStateCode() commonpb.StateCode GetStateCode() commonpb.StateCode
SetAddress(address string) SetAddress(address string)
GetAddress() string
// SetEtcdClient set etcd client for DataNode // SetEtcdClient set etcd client for DataNode
SetEtcdClient(etcdClient *clientv3.Client) SetEtcdClient(etcdClient *clientv3.Client)
@ -370,6 +373,14 @@ type DataCoordComponent interface {
// SetEtcdClient set EtcdClient for DataCoord // SetEtcdClient set EtcdClient for DataCoord
// `etcdClient` is a client of etcd // `etcdClient` is a client of etcd
SetEtcdClient(etcdClient *clientv3.Client) SetEtcdClient(etcdClient *clientv3.Client)
SetRootCoord(rootCoord RootCoord)
// SetDataNodeCreator set DataNode client creator func for DataCoord
SetDataNodeCreator(func(context.Context, string) (DataNode, error))
// SetIndexNodeCreator set IndexNode client creator func for DataCoord
SetIndexNodeCreator(func(context.Context, string) (IndexNode, error))
} }
// IndexNode is the interface `indexnode` package implements // IndexNode is the interface `indexnode` package implements
@ -406,7 +417,7 @@ type IndexNodeComponent interface {
IndexNode IndexNode
SetAddress(address string) SetAddress(address string)
GetAddress() string
// SetEtcdClient set etcd client for IndexNodeComponent // SetEtcdClient set etcd client for IndexNodeComponent
SetEtcdClient(etcdClient *clientv3.Client) SetEtcdClient(etcdClient *clientv3.Client)
@ -763,10 +774,9 @@ type RootCoordComponent interface {
// SetDataCoord set DataCoord for RootCoord // SetDataCoord set DataCoord for RootCoord
// `dataCoord` is a client of data coordinator. // `dataCoord` is a client of data coordinator.
// `ctx` is the context pass to DataCoord api.
// //
// Always return nil. // Always return nil.
SetDataCoord(ctx context.Context, dataCoord DataCoord) error SetDataCoord(dataCoord DataCoord) error
// SetQueryCoord set QueryCoord for RootCoord // SetQueryCoord set QueryCoord for RootCoord
// `queryCoord` is a client of query coordinator. // `queryCoord` is a client of query coordinator.
@ -774,6 +784,9 @@ type RootCoordComponent interface {
// Always return nil. // Always return nil.
SetQueryCoord(queryCoord QueryCoord) error SetQueryCoord(queryCoord QueryCoord) error
// SetProxyCreator set Proxy client creator func for RootCoord
SetProxyCreator(func(ctx context.Context, addr string) (Proxy, error))
// GetMetrics notifies RootCoordComponent to collect metrics for specified component // GetMetrics notifies RootCoordComponent to collect metrics for specified component
GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)
} }
@ -826,6 +839,7 @@ type ProxyComponent interface {
Proxy Proxy
SetAddress(address string) SetAddress(address string)
GetAddress() string
// SetEtcdClient set EtcdClient for Proxy // SetEtcdClient set EtcdClient for Proxy
// `etcdClient` is a client of etcd // `etcdClient` is a client of etcd
SetEtcdClient(etcdClient *clientv3.Client) SetEtcdClient(etcdClient *clientv3.Client)
@ -846,6 +860,9 @@ type ProxyComponent interface {
// `queryCoord` is a client of query coordinator. // `queryCoord` is a client of query coordinator.
SetQueryCoordClient(queryCoord QueryCoord) SetQueryCoordClient(queryCoord QueryCoord)
// SetQueryNodeCreator set QueryNode client creator func for Proxy
SetQueryNodeCreator(func(ctx context.Context, addr string) (QueryNode, error))
// GetRateLimiter returns the rateLimiter in Proxy // GetRateLimiter returns the rateLimiter in Proxy
GetRateLimiter() (Limiter, error) GetRateLimiter() (Limiter, error)
@ -1326,6 +1343,7 @@ type QueryNodeComponent interface {
UpdateStateCode(stateCode commonpb.StateCode) UpdateStateCode(stateCode commonpb.StateCode)
SetAddress(address string) SetAddress(address string)
GetAddress() string
// SetEtcdClient set etcd client for QueryNode // SetEtcdClient set etcd client for QueryNode
SetEtcdClient(etcdClient *clientv3.Client) SetEtcdClient(etcdClient *clientv3.Client)
@ -1385,4 +1403,7 @@ type QueryCoordComponent interface {
// Return nil in status: // Return nil in status:
// The rootCoord is not nil. // The rootCoord is not nil.
SetRootCoord(rootCoord RootCoord) error SetRootCoord(rootCoord RootCoord) error
// SetQueryNodeCreator set QueryNode client creator func for QueryCoord
SetQueryNodeCreator(func(ctx context.Context, addr string) (QueryNode, error))
} }
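All of the new `Set*Creator` hooks share one shape: the component stores a `func(ctx, addr) (client, error)` and invokes it lazily when a peer registers, so an in-process cluster can return local instances instead of gRPC clients. A generic sketch under that assumption (names are illustrative, not Milvus APIs):

```go
package main

import (
	"context"
	"fmt"
)

// queryNodeClient stands in for types.QueryNode in this sketch.
type queryNodeClient interface{ Address() string }

type grpcQueryNode struct{ addr string }

func (g *grpcQueryNode) Address() string { return g.addr }

// coord stores the creator func, as QueryCoord does after
// SetQueryNodeCreator, and invokes it when a node comes up.
type coord struct {
	newQueryNode func(ctx context.Context, addr string) (queryNodeClient, error)
}

func (c *coord) SetQueryNodeCreator(f func(ctx context.Context, addr string) (queryNodeClient, error)) {
	c.newQueryNode = f
}

func (c *coord) onNodeUp(addr string) (queryNodeClient, error) {
	return c.newQueryNode(context.Background(), addr)
}

func main() {
	c := &coord{}
	// Production installs a gRPC-backed creator; an in-process
	// integration cluster installs one returning local instances.
	c.SetQueryNodeCreator(func(ctx context.Context, addr string) (queryNodeClient, error) {
		return &grpcQueryNode{addr: addr}, nil
	})
	qn, err := c.onNodeUp("localhost:21123")
	fmt.Println(qn.Address(), err)
}
```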
View File
@ -0,0 +1,73 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package componentutil
import (
"context"
"errors"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/retry"
)
// WaitForComponentStates wait for component's state to be one of the specific states
func WaitForComponentStates(ctx context.Context, service types.Component, serviceName string, states []commonpb.StateCode, attempts uint, sleep time.Duration) error {
checkFunc := func() error {
resp, err := service.GetComponentStates(ctx)
if err != nil {
return err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(resp.Status.Reason)
}
meet := false
for _, state := range states {
if resp.State.StateCode == state {
meet = true
break
}
}
if !meet {
return fmt.Errorf(
"WaitForComponentStates, not meet, %s current state: %s",
serviceName,
resp.State.StateCode.String())
}
return nil
}
return retry.Do(ctx, checkFunc, retry.Attempts(attempts), retry.Sleep(sleep))
}
// WaitForComponentInitOrHealthy wait for component's state to be initializing or healthy
func WaitForComponentInitOrHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}, attempts, sleep)
}
// WaitForComponentInit wait for component's state to be initializing
func WaitForComponentInit(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing}, attempts, sleep)
}
// WaitForComponentHealthy wait for component's state to be healthy
func WaitForComponentHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Healthy}, attempts, sleep)
}
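A typical caller in an integration test polls a freshly started component until it reports Healthy. A short usage sketch, assuming the new package lands at `internal/util/componentutil`; `waitHealthy` is a hypothetical test helper:

```go
package integration

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/componentutil"
)

// waitHealthy polls GetComponentStates up to 10 times, 100ms apart,
// until the component reports StateCode_Healthy.
func waitHealthy(ctx context.Context, c types.Component, name string) error {
	return componentutil.WaitForComponentHealthy(ctx, c, name, 10, 100*time.Millisecond)
}
```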
View File
@ -0,0 +1,133 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package componentutil
import (
"context"
"errors"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
)
type MockComponent struct {
compState *milvuspb.ComponentStates
strResp *milvuspb.StringResponse
compErr error
}
func (mc *MockComponent) SetCompState(state *milvuspb.ComponentStates) {
mc.compState = state
}
func (mc *MockComponent) SetStrResp(resp *milvuspb.StringResponse) {
mc.strResp = resp
}
func (mc *MockComponent) Init() error {
return nil
}
func (mc *MockComponent) Start() error {
return nil
}
func (mc *MockComponent) Stop() error {
return nil
}
func (mc *MockComponent) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return mc.compState, mc.compErr
}
func (mc *MockComponent) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return mc.strResp, nil
}
func (mc *MockComponent) Register() error {
return nil
}
func buildMockComponent(code commonpb.StateCode) *MockComponent {
mc := &MockComponent{
compState: &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: "role",
StateCode: code,
},
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
},
strResp: nil,
compErr: nil,
}
return mc
}
func Test_WaitForComponentInitOrHealthy(t *testing.T) {
mc := &MockComponent{
compState: nil,
strResp: nil,
compErr: errors.New("error"),
}
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
mc = &MockComponent{
compState: &milvuspb.ComponentStates{
State: nil,
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
},
strResp: nil,
compErr: nil,
}
err = WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if funcutil.SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_WaitForComponentInit(t *testing.T) {
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInit(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if funcutil.SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
View File
@ -39,7 +39,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/retry"
) )
@ -67,51 +66,6 @@ func GetLocalIP() string {
return "127.0.0.1" return "127.0.0.1"
} }
// WaitForComponentStates wait for component's state to be one of the specific states
func WaitForComponentStates(ctx context.Context, service types.Component, serviceName string, states []commonpb.StateCode, attempts uint, sleep time.Duration) error {
checkFunc := func() error {
resp, err := service.GetComponentStates(ctx)
if err != nil {
return err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(resp.Status.Reason)
}
meet := false
for _, state := range states {
if resp.State.StateCode == state {
meet = true
break
}
}
if !meet {
return fmt.Errorf(
"WaitForComponentStates, not meet, %s current state: %s",
serviceName,
resp.State.StateCode.String())
}
return nil
}
return retry.Do(ctx, checkFunc, retry.Attempts(attempts), retry.Sleep(sleep))
}
// WaitForComponentInitOrHealthy wait for component's state to be initializing or healthy
func WaitForComponentInitOrHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}, attempts, sleep)
}
// WaitForComponentInit wait for component's state to be initializing
func WaitForComponentInit(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing}, attempts, sleep)
}
// WaitForComponentHealthy wait for component's state to be healthy
func WaitForComponentHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error {
return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Healthy}, attempts, sleep)
}
// JSONToMap parse the jsonic index parameters to map // JSONToMap parse the jsonic index parameters to map
func JSONToMap(mStr string) (map[string]string, error) { func JSONToMap(mStr string) (map[string]string, error) {
buffer := make(map[string]any) buffer := make(map[string]any)
View File
@ -35,62 +35,6 @@ import (
grpcStatus "google.golang.org/grpc/status" grpcStatus "google.golang.org/grpc/status"
) )
type MockComponent struct {
compState *milvuspb.ComponentStates
strResp *milvuspb.StringResponse
compErr error
}
func (mc *MockComponent) SetCompState(state *milvuspb.ComponentStates) {
mc.compState = state
}
func (mc *MockComponent) SetStrResp(resp *milvuspb.StringResponse) {
mc.strResp = resp
}
func (mc *MockComponent) Init() error {
return nil
}
func (mc *MockComponent) Start() error {
return nil
}
func (mc *MockComponent) Stop() error {
return nil
}
func (mc *MockComponent) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {
return mc.compState, mc.compErr
}
func (mc *MockComponent) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return mc.strResp, nil
}
func (mc *MockComponent) Register() error {
return nil
}
func buildMockComponent(code commonpb.StateCode) *MockComponent {
mc := &MockComponent{
compState: &milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
NodeID: 0,
Role: "role",
StateCode: code,
},
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
},
strResp: nil,
compErr: nil,
}
return mc
}
func Test_CheckGrpcReady(t *testing.T) { func Test_CheckGrpcReady(t *testing.T) {
errChan := make(chan error) errChan := make(chan error)
@ -112,68 +56,6 @@ func Test_GetLocalIP(t *testing.T) {
assert.NotZero(t, len(ip)) assert.NotZero(t, len(ip))
} }
func Test_WaitForComponentInitOrHealthy(t *testing.T) {
mc := &MockComponent{
compState: nil,
strResp: nil,
compErr: errors.New("error"),
}
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
mc = &MockComponent{
compState: &milvuspb.ComponentStates{
State: nil,
SubcomponentStates: nil,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
},
strResp: nil,
compErr: nil,
}
err = WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
assert.NotNil(t, err)
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInitOrHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_WaitForComponentInit(t *testing.T) {
validCodes := []commonpb.StateCode{commonpb.StateCode_Initializing}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentInit(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_WaitForComponentHealthy(t *testing.T) {
validCodes := []commonpb.StateCode{commonpb.StateCode_Healthy}
testCodes := []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal}
for _, code := range testCodes {
mc := buildMockComponent(code)
err := WaitForComponentHealthy(context.TODO(), mc, "mockService", 1, 10*time.Millisecond)
if SliceContain(validCodes, code) {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
}
}
func Test_ParseIndexParamsMap(t *testing.T) {
	num := 10
	keys := make([]string, 0)

View File

@ -83,6 +83,8 @@ type ComponentParam struct {
	DataNodeGrpcClientCfg   GrpcClientConfig
	IndexCoordGrpcClientCfg GrpcClientConfig
	IndexNodeGrpcClientCfg  GrpcClientConfig
	IntegrationTestCfg      integrationTestConfig
}
// InitOnce initializes once
@ -126,6 +128,8 @@ func (p *ComponentParam) Init() {
	p.DataCoordGrpcClientCfg.Init(typeutil.DataCoordRole, &p.BaseTable)
	p.DataNodeGrpcClientCfg.Init(typeutil.DataNodeRole, &p.BaseTable)
	p.IndexNodeGrpcClientCfg.Init(typeutil.IndexNodeRole, &p.BaseTable)
p.IntegrationTestCfg.init(&p.BaseTable)
}
func (p *ComponentParam) RocksmqEnable() bool {
@ -1732,3 +1736,17 @@ func (p *indexNodeConfig) init(base *BaseTable) {
	}
	p.GracefulStopTimeout.Init(base.mgr)
}
type integrationTestConfig struct {
IntegrationMode ParamItem `refreshable:"false"`
}
func (p *integrationTestConfig) init(base *BaseTable) {
p.IntegrationMode = ParamItem{
Key: "integration.test.mode",
Version: "2.2.0",
DefaultValue: "false",
PanicIfEmpty: true,
}
p.IntegrationMode.Init(base.mgr)
}
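This new param gates test-only behavior behind the integration.test.mode key; downstream code reads it through paramtable at runtime, as the session change below does. A hedged sketch of flipping it on in a test (mirrors TestIntegrationMode further down; the surrounding test body is assumed):

	// Sketch: enable integration mode before constructing sessions.
	params := paramtable.Get()
	params.Init()
	params.Save(params.IntegrationTestCfg.IntegrationMode.Key, "true")
	if params.IntegrationTestCfg.IntegrationMode.GetAsBool() {
		// sessions created from here on will not reuse a node ID
	}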

View File

@ -209,6 +209,11 @@ func NewSession(ctx context.Context, metaRoot string, client *clientv3.Client, o
		reuseNodeID: true,
	}
	// integration tests create a cluster with different node IDs in one process
if paramtable.Get().IntegrationTestCfg.IntegrationMode.GetAsBool() {
session.reuseNodeID = false
}
	session.apply(opts...)
	session.UpdateRegistered(false)

View File

@ -709,3 +709,28 @@ func TestSession_apply(t *testing.T) {
	assert.Equal(t, int64(100), session.sessionTTL)
	assert.Equal(t, int64(200), session.sessionRetryTimes)
}
func TestIntegrationMode(t *testing.T) {
ctx := context.Background()
params := paramtable.Get()
params.Init()
params.Save(params.IntegrationTestCfg.IntegrationMode.Key, "true")
endpoints := params.GetWithDefault("etcd.endpoints", paramtable.DefaultEtcdEndpoints)
metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot)
etcdEndpoints := strings.Split(endpoints, ",")
etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints)
require.NoError(t, err)
etcdKV := etcdkv.NewEtcdKV(etcdCli, metaRoot)
err = etcdKV.RemoveWithPrefix("")
assert.NoError(t, err)
s1 := NewSession(ctx, metaRoot, etcdCli)
assert.Equal(t, false, s1.reuseNodeID)
s2 := NewSession(ctx, metaRoot, etcdCli)
assert.Equal(t, false, s2.reuseNodeID)
s1.Init("inittest1", "testAddr1", false, false)
s1.Init("inittest2", "testAddr2", false, false)
assert.NotEqual(t, s1.ServerID, s2.ServerID)
}

View File

@ -0,0 +1,372 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"math/rand"
"strconv"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func TestHelloMilvus(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
prefix := "TestHelloMilvus"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createCollectionStatus, err := c.proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
assert.NoError(t, err)
if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
}
assert.Equal(t, createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
assert.NoError(t, err)
assert.Equal(t, showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
insertResult, err := c.proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
	assert.NoError(t, err)
	assert.Equal(t, insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
// flush
flushResp, err := c.proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
assert.NoError(t, err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
segments, err := c.metaWatcher.ShowSegments()
assert.NoError(t, err)
assert.NotEmpty(t, segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
			if err != nil {
				panic(errors.New("GetFlushState failed"))
			}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
})
	assert.NoError(t, err)
	if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
	}
assert.Equal(t, commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
// load
loadStatus, err := c.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
	assert.NoError(t, err)
	if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
	}
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
for {
select {
case <-ctx.Done():
			panic(errors.New("context deadline exceeded"))
default:
}
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
// search
expr := fmt.Sprintf("%s > 0", "int64")
nq := 10
topk := 10
roundDecimal := -1
nprobe := 10
searchReq := constructSearchRequest("", collectionName, expr,
floatVecField, nq, dim, nprobe, topk, roundDecimal)
searchResult, err := c.proxy.Search(ctx, searchReq)
	assert.NoError(t, err)
	if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
	}
	assert.Equal(t, commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
log.Info("TestHelloMilvus succeed")
}
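Both wait loops in this test poll every 500ms until the condition holds or the context ends. A hedged generic form of that pattern (the helper name waitUntil is illustrative, not part of this commit):

	// waitUntil polls cond at the given interval until it returns true,
	// or returns the context error if the context is cancelled first.
	func waitUntil(ctx context.Context, interval time.Duration, cond func() bool) error {
		for !cond() {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(interval):
			}
		}
		return nil
	}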
const (
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
LimitKey = "limit"
)
func constructSearchRequest(
dbName, collectionName string,
expr string,
floatVecField string,
nq, dim, nprobe, topk, roundDecimal int,
) *milvuspb.SearchRequest {
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
if err != nil {
panic(err)
}
plg := constructPlaceholderGroup(nq, dim)
plgBs, err := proto.Marshal(plg)
if err != nil {
panic(err)
}
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: string(b),
},
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: strconv.Itoa(topk),
},
{
Key: RoundDecimalKey,
Value: strconv.Itoa(roundDecimal),
},
},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
func constructPlaceholderGroup(
nq, dim int,
) *commonpb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
if err != nil {
panic(err)
}
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
return &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: values,
},
},
}
}
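constructPlaceholderGroup serializes each of the nq query vectors as dim float32 values, 4 bytes apiece, written with common.Endian. An equivalent encoding sketch without a bytes.Buffer per element (assumes common.Endian is little-endian, and that encoding/binary and math are imported):

	// encodeFloatVector packs one vector as dim little-endian float32 values.
	func encodeFloatVector(vec []float32) []byte {
		bs := make([]byte, 0, len(vec)*4)
		for _, f := range vec {
			var b [4]byte
			binary.LittleEndian.PutUint32(b[:], math.Float32bits(f))
			bs = append(bs, b[:]...)
		}
		return bs
	}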
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: generateFloatVectors(numRows, dim),
},
},
},
},
}
}
func generateFloatVectors(numRows, dim int) []float32 {
total := numRows * dim
ret := make([]float32, 0, total)
for i := 0; i < total; i++ {
ret = append(ret, rand.Float32())
}
return ret
}
func generateHashKeys(numRows int) []uint32 {
ret := make([]uint32, 0, numRows)
for i := 0; i < numRows; i++ {
ret = append(ret, rand.Uint32())
}
return ret
}

View File

@ -0,0 +1,141 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"encoding/json"
"fmt"
"path"
"sort"
"time"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/util/sessionutil"
clientv3 "go.etcd.io/etcd/client/v3"
)
// MetaWatcher observes the metadata of a Milvus cluster
type MetaWatcher interface {
ShowSessions() ([]*sessionutil.Session, error)
ShowSegments() ([]*datapb.SegmentInfo, error)
ShowReplicas() ([]*milvuspb.ReplicaInfo, error)
}
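Tests normally reach this interface through the mini cluster's metaWatcher field, but the EtcdMetaWatcher implementation below can also be pointed at a running etcd directly. A hedged sketch (the endpoint and the by-dev root path are assumptions, not values this commit fixes):

	// Sketch: point a watcher at etcd and list segment metadata.
	etcdCli, err := clientv3.New(clientv3.Config{Endpoints: []string{"localhost:2379"}})
	if err != nil {
		panic(err)
	}
	watcher := &EtcdMetaWatcher{rootPath: "by-dev", etcdCli: etcdCli}
	segments, err := watcher.ShowSegments()
	if err != nil {
		panic(err)
	}
	fmt.Println("segments:", len(segments))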
type EtcdMetaWatcher struct {
MetaWatcher
rootPath string
etcdCli *clientv3.Client
}
func (watcher *EtcdMetaWatcher) ShowSessions() ([]*sessionutil.Session, error) {
metaPath := watcher.rootPath + "/meta/session"
return listSessionsByPrefix(watcher.etcdCli, metaPath)
}
func (watcher *EtcdMetaWatcher) ShowSegments() ([]*datapb.SegmentInfo, error) {
metaBasePath := path.Join(watcher.rootPath, "/meta/datacoord-meta/s/")
return listSegments(watcher.etcdCli, metaBasePath, func(s *datapb.SegmentInfo) bool {
return true
})
}
func (watcher *EtcdMetaWatcher) ShowReplicas() ([]*milvuspb.ReplicaInfo, error) {
metaBasePath := path.Join(watcher.rootPath, "/meta/querycoord-replica/")
return listReplicas(watcher.etcdCli, metaBasePath)
}
//=================== Below largely copied from birdwatcher ========================
// listSessionsByPrefix returns all sessions under the given prefix
func listSessionsByPrefix(cli *clientv3.Client, prefix string) ([]*sessionutil.Session, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp, err := cli.Get(ctx, prefix, clientv3.WithPrefix())
if err != nil {
return nil, err
}
sessions := make([]*sessionutil.Session, 0, len(resp.Kvs))
for _, kv := range resp.Kvs {
session := &sessionutil.Session{}
err := json.Unmarshal(kv.Value, session)
if err != nil {
continue
}
sessions = append(sessions, session)
}
return sessions, nil
}
func listSegments(cli *clientv3.Client, prefix string, filter func(*datapb.SegmentInfo) bool) ([]*datapb.SegmentInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp, err := cli.Get(ctx, prefix, clientv3.WithPrefix())
if err != nil {
return nil, err
}
segments := make([]*datapb.SegmentInfo, 0, len(resp.Kvs))
for _, kv := range resp.Kvs {
info := &datapb.SegmentInfo{}
err = proto.Unmarshal(kv.Value, info)
if err != nil {
continue
}
if filter == nil || filter(info) {
segments = append(segments, info)
}
}
sort.Slice(segments, func(i, j int) bool {
return segments[i].GetID() < segments[j].GetID()
})
return segments, nil
}
func listReplicas(cli *clientv3.Client, prefix string) ([]*milvuspb.ReplicaInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp, err := cli.Get(ctx, prefix, clientv3.WithPrefix())
if err != nil {
return nil, err
}
replicas := make([]*milvuspb.ReplicaInfo, 0, len(resp.Kvs))
for _, kv := range resp.Kvs {
replica := &milvuspb.ReplicaInfo{}
		if err := proto.Unmarshal(kv.Value, replica); err != nil {
continue
}
replicas = append(replicas, replica)
}
return replicas, nil
}
func PrettyReplica(replica *milvuspb.ReplicaInfo) string {
res := fmt.Sprintf("ReplicaID: %d CollectionID: %d\n", replica.ReplicaID, replica.CollectionID)
for _, shardReplica := range replica.ShardReplicas {
res = res + fmt.Sprintf("Channel %s leader %d\n", shardReplica.DmChannelName, shardReplica.LeaderID)
}
res = res + fmt.Sprintf("Nodes:%v\n", replica.NodeIds)
return res
}

View File

@ -0,0 +1,334 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"errors"
"strconv"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func TestShowSessions(t *testing.T) {
ctx := context.Background()
	c, err := StartMiniCluster(ctx)
	assert.NoError(t, err)
	err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
sessions, err := c.metaWatcher.ShowSessions()
assert.NoError(t, err)
assert.NotEmpty(t, sessions)
for _, session := range sessions {
log.Info("ShowSessions result", zap.String("session", session.String()))
}
}
func TestShowSegments(t *testing.T) {
ctx := context.Background()
	c, err := StartMiniCluster(ctx)
	assert.NoError(t, err)
	err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
prefix := "TestShowSegments"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createCollectionStatus, err := c.proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
assert.NoError(t, err)
assert.Equal(t, createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
assert.NoError(t, err)
assert.Equal(t, showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
insertResult, err := c.proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
	assert.NoError(t, err)
	assert.Equal(t, insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
segments, err := c.metaWatcher.ShowSegments()
assert.NoError(t, err)
assert.NotEmpty(t, segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
}
func TestShowReplicas(t *testing.T) {
ctx := context.Background()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
prefix := "TestShowReplicas"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createCollectionStatus, err := c.proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: 2,
})
assert.NoError(t, err)
if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
}
assert.Equal(t, createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
showCollectionsResp, err := c.proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
assert.NoError(t, err)
assert.Equal(t, showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
insertResult, err := c.proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
	assert.NoError(t, err)
	assert.Equal(t, insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
// flush
flushResp, err := c.proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
assert.NoError(t, err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
segments, err := c.metaWatcher.ShowSegments()
assert.NoError(t, err)
assert.NotEmpty(t, segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
			if err != nil {
				panic(errors.New("GetFlushState failed"))
			}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
})
	assert.NoError(t, err)
	if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
	}
assert.Equal(t, commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
// load
loadStatus, err := c.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
	assert.NoError(t, err)
	if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
	}
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
for {
select {
case <-ctx.Done():
			panic(errors.New("context deadline exceeded"))
default:
}
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
replicas, err := c.metaWatcher.ShowReplicas()
assert.NoError(t, err)
assert.NotEmpty(t, replicas)
for _, replica := range replicas {
log.Info("ShowReplicas result", zap.String("replica", PrettyReplica(replica)))
}
log.Info("TestShowReplicas succeed")
}

File diff suppressed because it is too large

View File

@ -0,0 +1,187 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/datanode"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/querynode"
"github.com/stretchr/testify/assert"
)
func TestAddRemoveDataNode(t *testing.T) {
ctx := context.Background()
	c, err := StartMiniCluster(ctx)
	assert.NoError(t, err)
	err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
datanode := datanode.NewDataNode(ctx, c.factory)
datanode.SetEtcdClient(c.etcdCli)
//datanode := c.CreateDefaultDataNode()
err = c.AddDataNode(datanode)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, len(c.dataNodes))
err = c.RemoveDataNode(datanode)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.DataNodeNum)
assert.Equal(t, 1, len(c.dataNodes))
// add default node and remove randomly
err = c.AddDataNode(nil)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, len(c.dataNodes))
err = c.RemoveDataNode(nil)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.DataNodeNum)
assert.Equal(t, 1, len(c.dataNodes))
}
func TestAddRemoveQueryNode(t *testing.T) {
ctx := context.Background()
	c, err := StartMiniCluster(ctx)
	assert.NoError(t, err)
	err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
queryNode := querynode.NewQueryNode(ctx, c.factory)
queryNode.SetEtcdClient(c.etcdCli)
//queryNode := c.CreateDefaultQueryNode()
err = c.AddQueryNode(queryNode)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 2, len(c.queryNodes))
err = c.RemoveQueryNode(queryNode)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 1, len(c.queryNodes))
// add default node and remove randomly
err = c.AddQueryNode(nil)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 2, len(c.queryNodes))
err = c.RemoveQueryNode(nil)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 1, len(c.queryNodes))
}
func TestAddRemoveIndexNode(t *testing.T) {
ctx := context.Background()
	c, err := StartMiniCluster(ctx)
	assert.NoError(t, err)
	err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
indexNode := indexnode.NewIndexNode(ctx, c.factory)
indexNode.SetEtcdClient(c.etcdCli)
//indexNode := c.CreateDefaultIndexNode()
err = c.AddIndexNode(indexNode)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 2, len(c.indexNodes))
err = c.RemoveIndexNode(indexNode)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 1, len(c.indexNodes))
// add default node and remove randomly
err = c.AddIndexNode(nil)
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 2, len(c.indexNodes))
err = c.RemoveIndexNode(nil)
assert.NoError(t, err)
assert.Equal(t, 1, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 1, len(c.indexNodes))
}
func TestUpdateClusterSize(t *testing.T) {
ctx := context.Background()
	c, err := StartMiniCluster(ctx)
	assert.NoError(t, err)
	err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
err = c.UpdateClusterSize(ClusterConfig{
QueryNodeNum: -1,
DataNodeNum: -1,
IndexNodeNum: -1,
})
assert.Error(t, err)
err = c.UpdateClusterSize(ClusterConfig{
QueryNodeNum: 2,
DataNodeNum: 2,
IndexNodeNum: 2,
})
assert.NoError(t, err)
assert.Equal(t, 2, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 2, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 2, len(c.dataNodes))
assert.Equal(t, 2, len(c.queryNodes))
assert.Equal(t, 2, len(c.indexNodes))
err = c.UpdateClusterSize(ClusterConfig{
DataNodeNum: 3,
QueryNodeNum: 2,
IndexNodeNum: 1,
})
assert.NoError(t, err)
assert.Equal(t, 3, c.clusterConfig.DataNodeNum)
assert.Equal(t, 2, c.clusterConfig.QueryNodeNum)
assert.Equal(t, 1, c.clusterConfig.IndexNodeNum)
assert.Equal(t, 3, len(c.dataNodes))
assert.Equal(t, 2, len(c.queryNodes))
assert.Equal(t, 1, len(c.indexNodes))
}