enhance: Move datanode/indexnode manager to session pkg (#35634)

Related to #28861

Move the session manager and worker manager to the session package. Also rename
each manager after the corresponding node name (datanode, indexnode).
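
For reference, a minimal sketch (not part of this change) of how the relocated constructors could be wired together after the move, assuming the names visible in the diff below (session.NewDataNodeManagerImpl, session.WithDataNodeCreator, session.NewNodeManager, session.NodeInfo); the package name, function name, and sample NodeInfo values are illustrative only:

package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/datacoord/session"
)

// wireNodeManagers wires the relocated managers. Constructor and option names
// are taken from this commit's diff; everything else is illustrative.
func wireNodeManagers(ctx context.Context,
	dnCreator session.DataNodeCreatorFunc,
	inCreator session.IndexNodeCreatorFunc,
) (session.DataNodeManager, *session.IndexNodeManager) {
	// Formerly datacoord.SessionManager, created via NewSessionManagerImpl.
	dataNodeMgr := session.NewDataNodeManagerImpl(session.WithDataNodeCreator(dnCreator))
	// NodeInfo now lives in the session package as well.
	dataNodeMgr.AddSession(&session.NodeInfo{NodeID: 1, Address: "addr1"})

	// Formerly datacoord.IndexNodeManager, created via NewNodeManager.
	indexNodeMgr := session.NewNodeManager(ctx, inCreator)
	return dataNodeMgr, indexNodeMgr
}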

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
congqixia 2024-08-22 16:02:56 +08:00 committed by GitHub
parent 3107701fe8
commit 582d2eec79
35 changed files with 1370 additions and 1289 deletions


@@ -481,18 +481,18 @@ generate-mockery-querynode: getdeps build-cpp
 generate-mockery-datacoord: getdeps
 $(INSTALL_PATH)/mockery --name=compactionPlanContext --dir=internal/datacoord --filename=mock_compaction_plan_context.go --output=internal/datacoord --structname=MockCompactionPlanContext --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=Handler --dir=internal/datacoord --filename=mock_handler.go --output=internal/datacoord --structname=NMockHandler --with-expecter --inpackage
-$(INSTALL_PATH)/mockery --name=allocator --dir=internal/datacoord --filename=mock_allocator_test.go --output=internal/datacoord --structname=NMockAllocator --with-expecter --inpackage
+$(INSTALL_PATH)/mockery --name=Allocator --dir=internal/datacoord/allocator --filename=mock_allocator.go --output=internal/datacoord/allocator --structname=MockAllocator --with-expecter --inpackage
+$(INSTALL_PATH)/mockery --name=DataNodeManager --dir=internal/datacoord/session --filename=mock_datanode_manager.go --output=internal/datacoord/session --structname=MockDataNodeManager --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=RWChannelStore --dir=internal/datacoord --filename=mock_channel_store.go --output=internal/datacoord --structname=MockRWChannelStore --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=IndexEngineVersionManager --dir=internal/datacoord --filename=mock_index_engine_version_manager.go --output=internal/datacoord --structname=MockVersionManager --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=TriggerManager --dir=internal/datacoord --filename=mock_trigger_manager.go --output=internal/datacoord --structname=MockTriggerManager --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=Cluster --dir=internal/datacoord --filename=mock_cluster.go --output=internal/datacoord --structname=MockCluster --with-expecter --inpackage
-$(INSTALL_PATH)/mockery --name=SessionManager --dir=internal/datacoord --filename=mock_session_manager.go --output=internal/datacoord --structname=MockSessionManager --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=compactionPlanContext --dir=internal/datacoord --filename=mock_compaction_plan_context.go --output=internal/datacoord --structname=MockCompactionPlanContext --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=CompactionMeta --dir=internal/datacoord --filename=mock_compaction_meta.go --output=internal/datacoord --structname=MockCompactionMeta --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=SubCluster --dir=internal/datacoord --filename=mock_subcluster.go --output=internal/datacoord --structname=MockSubCluster --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage
-$(INSTALL_PATH)/mockery --name=WorkerManager --dir=internal/datacoord --filename=mock_worker_manager.go --output=internal/datacoord --structname=MockWorkerManager --with-expecter --inpackage
+$(INSTALL_PATH)/mockery --name=WorkerManager --dir=internal/datacoord/session --filename=mock_worker_manager.go --output=internal/datacoord/session --structname=MockWorkerManager --with-expecter --inpackage
 $(INSTALL_PATH)/mockery --name=Manager --dir=internal/datacoord --filename=mock_segment_manager.go --output=internal/datacoord --structname=MockManager --with-expecter --inpackage
 generate-mockery-datanode: getdeps


@@ -25,6 +25,7 @@ import (
 "go.uber.org/zap"
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 "github.com/milvus-io/milvus/pkg/log"
 "github.com/milvus-io/milvus/pkg/util/commonpbutil"
@@ -35,9 +36,9 @@ import (
 //
 //go:generate mockery --name=Cluster --structname=MockCluster --output=./ --filename=mock_cluster.go --with-expecter --inpackage
 type Cluster interface {
-Startup(ctx context.Context, nodes []*NodeInfo) error
-Register(node *NodeInfo) error
-UnRegister(node *NodeInfo) error
+Startup(ctx context.Context, nodes []*session.NodeInfo) error
+Register(node *session.NodeInfo) error
+UnRegister(node *session.NodeInfo) error
 Watch(ctx context.Context, ch RWChannel) error
 Flush(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error
 FlushChannels(ctx context.Context, nodeID int64, flushTs Timestamp, channels []string) error
@@ -47,19 +48,19 @@ type Cluster interface {
 QueryImport(nodeID int64, in *datapb.QueryImportRequest) (*datapb.QueryImportResponse, error)
 DropImport(nodeID int64, in *datapb.DropImportRequest) error
 QuerySlots() map[int64]int64
-GetSessions() []*Session
+GetSessions() []*session.Session
 Close()
 }
 var _ Cluster = (*ClusterImpl)(nil)
 type ClusterImpl struct {
-sessionManager SessionManager
+sessionManager session.DataNodeManager
 channelManager ChannelManager
 }
 // NewClusterImpl creates a new cluster
-func NewClusterImpl(sessionManager SessionManager, channelManager ChannelManager) *ClusterImpl {
+func NewClusterImpl(sessionManager session.DataNodeManager, channelManager ChannelManager) *ClusterImpl {
 c := &ClusterImpl{
 sessionManager: sessionManager,
 channelManager: channelManager,
@@ -69,7 +70,7 @@ func NewClusterImpl(sessionManager SessionManager, channelManager ChannelManager
 }
 // Startup inits the cluster with the given data nodes.
-func (c *ClusterImpl) Startup(ctx context.Context, nodes []*NodeInfo) error {
+func (c *ClusterImpl) Startup(ctx context.Context, nodes []*session.NodeInfo) error {
 for _, node := range nodes {
 c.sessionManager.AddSession(node)
 }
@@ -79,7 +80,7 @@ func (c *ClusterImpl) Startup(ctx context.Context, nodes []*NodeInfo) error {
 allNodes []int64
 )
-lo.ForEach(nodes, func(info *NodeInfo, _ int) {
+lo.ForEach(nodes, func(info *session.NodeInfo, _ int) {
 if info.IsLegacy {
 legacyNodes = append(legacyNodes, info.NodeID)
 }
@@ -89,13 +90,13 @@ func (c *ClusterImpl) Startup(ctx context.Context, nodes []*NodeInfo) error {
 }
 // Register registers a new node in cluster
-func (c *ClusterImpl) Register(node *NodeInfo) error {
+func (c *ClusterImpl) Register(node *session.NodeInfo) error {
 c.sessionManager.AddSession(node)
 return c.channelManager.AddNode(node.NodeID)
 }
 // UnRegister removes a node from cluster
-func (c *ClusterImpl) UnRegister(node *NodeInfo) error {
+func (c *ClusterImpl) UnRegister(node *session.NodeInfo) error {
 c.sessionManager.DeleteSession(node)
 return c.channelManager.DeleteNode(node.NodeID)
 }
@@ -204,7 +205,7 @@ func (c *ClusterImpl) QuerySlots() map[int64]int64 {
 }
 // GetSessions returns all sessions
-func (c *ClusterImpl) GetSessions() []*Session {
+func (c *ClusterImpl) GetSessions() []*session.Session {
 return c.sessionManager.GetSessions()
 }


@@ -26,6 +26,7 @@ import (
 "github.com/stretchr/testify/require"
 "github.com/stretchr/testify/suite"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
 "github.com/milvus-io/milvus/internal/kv/mocks"
 "github.com/milvus-io/milvus/internal/proto/datapb"
@@ -50,19 +51,19 @@ type ClusterSuite struct {
 mockKv *mocks.WatchKV
 mockChManager *MockChannelManager
-mockSession *MockSessionManager
+mockSession *session.MockDataNodeManager
 }
 func (suite *ClusterSuite) SetupTest() {
 suite.mockKv = mocks.NewWatchKV(suite.T())
 suite.mockChManager = NewMockChannelManager(suite.T())
-suite.mockSession = NewMockSessionManager(suite.T())
+suite.mockSession = session.NewMockDataNodeManager(suite.T())
 }
 func (suite *ClusterSuite) TearDownTest() {}
 func (suite *ClusterSuite) TestStartup() {
-nodes := []*NodeInfo{
+nodes := []*session.NodeInfo{
 {NodeID: 1, Address: "addr1"},
 {NodeID: 2, Address: "addr2"},
 {NodeID: 3, Address: "addr3"},
@@ -71,7 +72,7 @@ func (suite *ClusterSuite) TestStartup() {
 suite.mockSession.EXPECT().AddSession(mock.Anything).Return().Times(len(nodes))
 suite.mockChManager.EXPECT().Startup(mock.Anything, mock.Anything, mock.Anything).
 RunAndReturn(func(ctx context.Context, legacys []int64, nodeIDs []int64) error {
-suite.ElementsMatch(lo.Map(nodes, func(info *NodeInfo, _ int) int64 { return info.NodeID }), nodeIDs)
+suite.ElementsMatch(lo.Map(nodes, func(info *session.NodeInfo, _ int) int64 { return info.NodeID }), nodeIDs)
 return nil
 }).Once()
@@ -81,7 +82,7 @@ func (suite *ClusterSuite) TestStartup() {
 }
 func (suite *ClusterSuite) TestRegister() {
-info := &NodeInfo{NodeID: 1, Address: "addr1"}
+info := &session.NodeInfo{NodeID: 1, Address: "addr1"}
 suite.mockSession.EXPECT().AddSession(mock.Anything).Return().Once()
 suite.mockChManager.EXPECT().AddNode(mock.Anything).
@@ -96,7 +97,7 @@ func (suite *ClusterSuite) TestRegister() {
 }
 func (suite *ClusterSuite) TestUnregister() {
-info := &NodeInfo{NodeID: 1, Address: "addr1"}
+info := &session.NodeInfo{NodeID: 1, Address: "addr1"}
 suite.mockSession.EXPECT().DeleteSession(mock.Anything).Return().Once()
 suite.mockChManager.EXPECT().DeleteNode(mock.Anything).


@@ -32,6 +32,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 "github.com/milvus-io/milvus/pkg/log"
 "github.com/milvus-io/milvus/pkg/metrics"
@@ -82,7 +83,7 @@ type compactionPlanHandler struct {
 meta CompactionMeta
 allocator allocator.Allocator
 chManager ChannelManager
-sessions SessionManager
+sessions session.DataNodeManager
 cluster Cluster
 analyzeScheduler *taskScheduler
 handler Handler
@@ -177,7 +178,7 @@ func (c *compactionPlanHandler) getCompactionTasksNumBySignalID(triggerID int64)
 return cnt
 }
-func newCompactionPlanHandler(cluster Cluster, sessions SessionManager, cm ChannelManager, meta CompactionMeta, allocator allocator.Allocator, analyzeScheduler *taskScheduler, handler Handler,
+func newCompactionPlanHandler(cluster Cluster, sessions session.DataNodeManager, cm ChannelManager, meta CompactionMeta, allocator allocator.Allocator, analyzeScheduler *taskScheduler, handler Handler,
 ) *compactionPlanHandler {
 return &compactionPlanHandler{
 queueTasks: make(map[int64]CompactionTask),


@@ -29,6 +29,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 "github.com/milvus-io/milvus/internal/proto/indexpb"
 "github.com/milvus-io/milvus/pkg/common"
@@ -49,14 +50,14 @@ type clusteringCompactionTask struct {
 span trace.Span
 allocator allocator.Allocator
 meta CompactionMeta
-sessions SessionManager
+sessions session.DataNodeManager
 handler Handler
 analyzeScheduler *taskScheduler
 maxRetryTimes int32
 }
-func newClusteringCompactionTask(t *datapb.CompactionTask, allocator allocator.Allocator, meta CompactionMeta, session SessionManager, handler Handler, analyzeScheduler *taskScheduler) *clusteringCompactionTask {
+func newClusteringCompactionTask(t *datapb.CompactionTask, allocator allocator.Allocator, meta CompactionMeta, session session.DataNodeManager, handler Handler, analyzeScheduler *taskScheduler) *clusteringCompactionTask {
 return &clusteringCompactionTask{
 CompactionTask: t,
 allocator: allocator,


@@ -29,6 +29,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/metastore/kv/datacoord"
 "github.com/milvus-io/milvus/internal/metastore/model"
 "github.com/milvus-io/milvus/internal/proto/datapb"
@@ -48,9 +49,8 @@ type ClusteringCompactionTaskSuite struct {
 mockID atomic.Int64
 mockAlloc *allocator.MockAllocator
 meta *meta
-mockSessMgr *MockSessionManager
 handler *NMockHandler
-session *MockSessionManager
+mockSessionMgr *session.MockDataNodeManager
 analyzeScheduler *taskScheduler
 }
@@ -62,8 +62,6 @@ func (s *ClusteringCompactionTaskSuite) SetupTest() {
 s.NoError(err)
 s.meta = meta
-s.mockSessMgr = NewMockSessionManager(s.T())
 s.mockID.Store(time.Now().UnixMilli())
 s.mockAlloc = allocator.NewMockAllocator(s.T())
 s.mockAlloc.EXPECT().AllocN(mock.Anything).RunAndReturn(func(x int64) (int64, int64, error) {
@@ -79,7 +77,7 @@ func (s *ClusteringCompactionTaskSuite) SetupTest() {
 s.handler = NewNMockHandler(s.T())
 s.handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
-s.session = NewMockSessionManager(s.T())
+s.mockSessionMgr = session.NewMockDataNodeManager(s.T())
 scheduler := newTaskScheduler(ctx, s.meta, nil, cm, newIndexEngineVersionManager(), nil)
 s.analyzeScheduler = scheduler
@@ -105,7 +103,7 @@ func (s *ClusteringCompactionTaskSuite) TestClusteringCompactionSegmentMetaChang
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(nil)
+s.mockSessionMgr.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(nil)
 task := s.generateBasicTask(false)
@@ -190,7 +188,7 @@ func (s *ClusteringCompactionTaskSuite) generateBasicTask(vectorClusteringKey bo
 ResultSegments: []int64{1000, 1100},
 }
-task := newClusteringCompactionTask(compactionTask, s.mockAlloc, s.meta, s.session, s.handler, s.analyzeScheduler)
+task := newClusteringCompactionTask(compactionTask, s.mockAlloc, s.meta, s.mockSessionMgr, s.handler, s.analyzeScheduler)
 task.maxRetryTimes = 0
 return task
 }
@@ -236,7 +234,7 @@ func (s *ClusteringCompactionTaskSuite) TestProcessPipelining() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(merr.WrapErrDataNodeSlotExhausted())
+s.mockSessionMgr.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(merr.WrapErrDataNodeSlotExhausted())
 task.State = datapb.CompactionTaskState_pipelining
 s.False(task.Process())
 s.Equal(int64(NullNodeID), task.GetNodeID())
@@ -260,7 +258,7 @@ func (s *ClusteringCompactionTaskSuite) TestProcessPipelining() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(nil)
+s.mockSessionMgr.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(nil)
 task.State = datapb.CompactionTaskState_pipelining
 s.Equal(false, task.Process())
 s.Equal(datapb.CompactionTaskState_executing, task.GetState())
@@ -309,7 +307,7 @@ func (s *ClusteringCompactionTaskSuite) TestProcessExecuting() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(nil, merr.WrapErrNodeNotFound(1)).Once()
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(nil, merr.WrapErrNodeNotFound(1)).Once()
 s.Equal(false, task.Process())
 s.Equal(datapb.CompactionTaskState_pipelining, task.GetState())
 })
@@ -332,10 +330,10 @@ func (s *ClusteringCompactionTaskSuite) TestProcessExecuting() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(nil, nil).Once()
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(nil, nil).Once()
 s.Equal(false, task.Process())
 s.Equal(datapb.CompactionTaskState_executing, task.GetState())
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_executing,
 }, nil).Once()
 s.Equal(false, task.Process())
@@ -360,7 +358,7 @@ func (s *ClusteringCompactionTaskSuite) TestProcessExecuting() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_completed,
 Segments: []*datapb.CompactionSegment{
 {
@@ -393,7 +391,7 @@ func (s *ClusteringCompactionTaskSuite) TestProcessExecuting() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_completed,
 Segments: []*datapb.CompactionSegment{
 {
@@ -428,7 +426,7 @@ func (s *ClusteringCompactionTaskSuite) TestProcessExecuting() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_executing,
 Segments: []*datapb.CompactionSegment{
 {
@@ -447,31 +445,31 @@ func (s *ClusteringCompactionTaskSuite) TestProcessExecuting() {
 func (s *ClusteringCompactionTaskSuite) TestProcessExecutingState() {
 task := s.generateBasicTask(false)
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_failed,
 }, nil).Once()
 s.NoError(task.processExecuting())
 s.Equal(datapb.CompactionTaskState_failed, task.GetState())
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_failed,
 }, nil).Once()
 s.NoError(task.processExecuting())
 s.Equal(datapb.CompactionTaskState_failed, task.GetState())
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_pipelining,
 }, nil).Once()
 s.NoError(task.processExecuting())
 s.Equal(datapb.CompactionTaskState_failed, task.GetState())
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_completed,
 }, nil).Once()
 s.Error(task.processExecuting())
 s.Equal(datapb.CompactionTaskState_failed, task.GetState())
-s.session.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
+s.mockSessionMgr.EXPECT().GetCompactionPlanResult(mock.Anything, mock.Anything).Return(&datapb.CompactionPlanResult{
 State: datapb.CompactionTaskState_completed,
 Segments: []*datapb.CompactionSegment{
 {
@@ -608,7 +606,7 @@ func (s *ClusteringCompactionTaskSuite) TestProcessAnalyzingState() {
 PartitionStatsVersion: 10000,
 },
 })
-s.session.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(nil)
+s.mockSessionMgr.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything).Return(nil)
 s.False(task.Process())
 s.Equal(datapb.CompactionTaskState_executing, task.GetState())


@@ -28,6 +28,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 "github.com/milvus-io/milvus/pkg/common"
 "github.com/milvus-io/milvus/pkg/log"
@@ -43,7 +44,7 @@ type l0CompactionTask struct {
 span trace.Span
 allocator allocator.Allocator
-sessions SessionManager
+sessions session.DataNodeManager
 meta CompactionMeta
 }


@@ -29,6 +29,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 "github.com/milvus-io/milvus/pkg/util/merr"
 )
@@ -42,12 +43,12 @@ type L0CompactionTaskSuite struct {
 mockAlloc *allocator.MockAllocator
 mockMeta *MockCompactionMeta
-mockSessMgr *MockSessionManager
+mockSessMgr *session.MockDataNodeManager
 }
 func (s *L0CompactionTaskSuite) SetupTest() {
 s.mockMeta = NewMockCompactionMeta(s.T())
-s.mockSessMgr = NewMockSessionManager(s.T())
+s.mockSessMgr = session.NewMockDataNodeManager(s.T())
 s.mockAlloc = allocator.NewMockAllocator(s.T())
 }


@@ -10,6 +10,7 @@ import (
 "go.uber.org/zap"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 "github.com/milvus-io/milvus/pkg/log"
 "github.com/milvus-io/milvus/pkg/util/merr"
@@ -24,7 +25,7 @@ type mixCompactionTask struct {
 span trace.Span
 allocator allocator.Allocator
-sessions SessionManager
+sessions session.DataNodeManager
 meta CompactionMeta
 newSegment *SegmentInfo
 }


@@ -4,6 +4,8 @@ import (
 "testing"
 "github.com/stretchr/testify/suite"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 )
 func TestCompactionTaskSuite(t *testing.T) {
@@ -14,10 +16,10 @@ type CompactionTaskSuite struct {
 suite.Suite
 mockMeta *MockCompactionMeta
-mockSessMgr *MockSessionManager
+mockSessMgr *session.MockDataNodeManager
 }
 func (s *CompactionTaskSuite) SetupTest() {
 s.mockMeta = NewMockCompactionMeta(s.T())
-s.mockSessMgr = NewMockSessionManager(s.T())
+s.mockSessMgr = session.NewMockDataNodeManager(s.T())
 }


@@ -27,6 +27,7 @@ import (
 "github.com/stretchr/testify/suite"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/metastore/kv/binlog"
 "github.com/milvus-io/milvus/internal/metastore/kv/datacoord"
 "github.com/milvus-io/milvus/internal/proto/datapb"
@@ -44,7 +45,7 @@ type CompactionPlanHandlerSuite struct {
 mockMeta *MockCompactionMeta
 mockAlloc *allocator.MockAllocator
 mockCm *MockChannelManager
-mockSessMgr *MockSessionManager
+mockSessMgr *session.MockDataNodeManager
 handler *compactionPlanHandler
 cluster Cluster
 }
@@ -53,7 +54,7 @@ func (s *CompactionPlanHandlerSuite) SetupTest() {
 s.mockMeta = NewMockCompactionMeta(s.T())
 s.mockAlloc = allocator.NewMockAllocator(s.T())
 s.mockCm = NewMockChannelManager(s.T())
-s.mockSessMgr = NewMockSessionManager(s.T())
+s.mockSessMgr = session.NewMockDataNodeManager(s.T())
 s.cluster = NewMockCluster(s.T())
 s.handler = newCompactionPlanHandler(s.cluster, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc, nil, nil)
 }


@@ -27,6 +27,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/metastore/kv/binlog"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 "github.com/milvus-io/milvus/internal/proto/internalpb"
@@ -145,8 +146,8 @@ func (s *importScheduler) process() {
 }
 func (s *importScheduler) peekSlots() map[int64]int64 {
-nodeIDs := lo.Map(s.cluster.GetSessions(), func(s *Session, _ int) int64 {
-return s.info.NodeID
+nodeIDs := lo.Map(s.cluster.GetSessions(), func(s *session.Session, _ int) int64 {
+return s.NodeID()
 })
 nodeSlots := make(map[int64]int64)
 mu := &lock.Mutex{}


@@ -28,6 +28,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/metastore/mocks"
 "github.com/milvus-io/milvus/internal/proto/datapb"
 )
@@ -106,12 +107,11 @@ func (s *ImportSchedulerSuite) TestProcessPreImport() {
 Slots: 1,
 }, nil)
 s.cluster.EXPECT().PreImport(mock.Anything, mock.Anything).Return(nil)
-s.cluster.EXPECT().GetSessions().Return([]*Session{
-{
-info: &NodeInfo{
+s.cluster.EXPECT().GetSessions().RunAndReturn(func() []*session.Session {
+sess := session.NewSession(&session.NodeInfo{
 NodeID: nodeID,
-},
-},
+}, nil)
+return []*session.Session{sess}
 })
 s.scheduler.process()
 task = s.imeta.GetTask(task.GetTaskID())
@@ -181,12 +181,11 @@ func (s *ImportSchedulerSuite) TestProcessImport() {
 Slots: 1,
 }, nil)
 s.cluster.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(nil)
-s.cluster.EXPECT().GetSessions().Return([]*Session{
-{
-info: &NodeInfo{
+s.cluster.EXPECT().GetSessions().RunAndReturn(func() []*session.Session {
+sess := session.NewSession(&session.NodeInfo{
 NodeID: nodeID,
-},
-},
+}, nil)
+return []*session.Session{sess}
 })
 s.scheduler.process()
 task = s.imeta.GetTask(task.GetTaskID())
@@ -243,12 +242,11 @@ func (s *ImportSchedulerSuite) TestProcessFailed() {
 s.cluster.EXPECT().QueryImport(mock.Anything, mock.Anything).Return(&datapb.QueryImportResponse{
 Slots: 1,
 }, nil)
-s.cluster.EXPECT().GetSessions().Return([]*Session{
-{
-info: &NodeInfo{
+s.cluster.EXPECT().GetSessions().RunAndReturn(func() []*session.Session {
+sess := session.NewSession(&session.NodeInfo{
 NodeID: 6,
-},
-},
+}, nil)
+return []*session.Session{sess}
 })
 for _, id := range task.(*importTask).GetSegmentIDs() {
 segment := &SegmentInfo{


@@ -32,6 +32,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 "github.com/milvus-io/milvus/internal/datacoord/allocator"
 "github.com/milvus-io/milvus/internal/datacoord/broker"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 mockkv "github.com/milvus-io/milvus/internal/kv/mocks"
 "github.com/milvus-io/milvus/internal/metastore/kv/datacoord"
 catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks"
@@ -214,7 +215,7 @@ func TestServer_CreateIndex(t *testing.T) {
 Value: "DISKANN",
 },
 }
-s.indexNodeManager = NewNodeManager(ctx, defaultIndexNodeCreatorFunc)
+s.indexNodeManager = session.NewNodeManager(ctx, defaultIndexNodeCreatorFunc)
 resp, err := s.CreateIndex(ctx, req)
 assert.Error(t, merr.CheckRPCCall(resp, err))
 })
@@ -232,12 +233,10 @@ func TestServer_CreateIndex(t *testing.T) {
 Value: "true",
 },
 }
-nodeManager := NewNodeManager(ctx, defaultIndexNodeCreatorFunc)
+nodeManager := session.NewNodeManager(ctx, defaultIndexNodeCreatorFunc)
 s.indexNodeManager = nodeManager
 mockNode := mocks.NewMockIndexNodeClient(t)
-s.indexNodeManager.lock.Lock()
-s.indexNodeManager.nodeClients[1001] = mockNode
-s.indexNodeManager.lock.Unlock()
+nodeManager.SetClient(1001, mockNode)
 mockNode.EXPECT().GetJobStats(mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{
 Status: merr.Success(),
 EnableDisk: true,


@@ -24,6 +24,7 @@ import (
 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/proto/indexpb"
 "github.com/milvus-io/milvus/internal/types"
 "github.com/milvus-io/milvus/pkg/log"
@@ -172,7 +173,7 @@ func (s *Server) getDataCoordMetrics(ctx context.Context) metricsinfo.DataCoordI
 // getDataNodeMetrics composes DataNode infos
 // this function will invoke GetMetrics with DataNode specified in NodeInfo
-func (s *Server) getDataNodeMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, node *Session) (metricsinfo.DataNodeInfos, error) {
+func (s *Server) getDataNodeMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, node *session.Session) (metricsinfo.DataNodeInfos, error) {
 infos := metricsinfo.DataNodeInfos{
 BaseComponentInfos: metricsinfo.BaseComponentInfos{
 HasError: true,


@@ -25,6 +25,7 @@ import (
 "google.golang.org/grpc"
 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
+"github.com/milvus-io/milvus/internal/datacoord/session"
 "github.com/milvus-io/milvus/internal/types"
 "github.com/milvus-io/milvus/pkg/util/merr"
 "github.com/milvus-io/milvus/pkg/util/metricsinfo"
@@ -66,7 +67,7 @@ func TestGetDataNodeMetrics(t *testing.T) {
 assert.Error(t, err)
 // nil client node
-_, err = svr.getDataNodeMetrics(ctx, req, NewSession(&NodeInfo{}, nil))
+_, err = svr.getDataNodeMetrics(ctx, req, session.NewSession(&session.NodeInfo{}, nil))
 assert.Error(t, err)
 creator := func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
@@ -74,13 +75,13 @@ func TestGetDataNodeMetrics(t *testing.T) {
 }
 // mock datanode client
-session := NewSession(&NodeInfo{}, creator)
-info, err := svr.getDataNodeMetrics(ctx, req, session)
+sess := session.NewSession(&session.NodeInfo{}, creator)
+info, err := svr.getDataNodeMetrics(ctx, req, sess)
 assert.NoError(t, err)
 assert.False(t, info.HasError)
 assert.Equal(t, metricsinfo.ConstructComponentName(typeutil.DataNodeRole, 100), info.BaseComponentInfos.Name)
-getMockFailedClientCreator := func(mockFunc func() (*milvuspb.GetMetricsResponse, error)) dataNodeCreatorFunc {
+getMockFailedClientCreator := func(mockFunc func() (*milvuspb.GetMetricsResponse, error)) session.DataNodeCreatorFunc {
 return func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
 cli, err := creator(ctx, addr, nodeID)
 assert.NoError(t, err)
@@ -92,7 +93,7 @@ func TestGetDataNodeMetrics(t *testing.T) {
 return nil, errors.New("mocked fail")
 })
-info, err = svr.getDataNodeMetrics(ctx, req, NewSession(&NodeInfo{}, mockFailClientCreator))
+info, err = svr.getDataNodeMetrics(ctx, req, session.NewSession(&session.NodeInfo{}, mockFailClientCreator))
 assert.NoError(t, err)
 assert.True(t, info.HasError)
@@ -104,7 +105,7 @@ func TestGetDataNodeMetrics(t *testing.T) {
 }, nil
 })
-info, err = svr.getDataNodeMetrics(ctx, req, NewSession(&NodeInfo{}, mockFailClientCreator))
+info, err = svr.getDataNodeMetrics(ctx, req, session.NewSession(&session.NodeInfo{}, mockFailClientCreator))
 assert.NoError(t, err)
 assert.True(t, info.HasError)
 assert.Equal(t, "mocked error", info.ErrorReason)
@@ -117,7 +118,7 @@ func TestGetDataNodeMetrics(t *testing.T) {
 }, nil
 })
-info, err = svr.getDataNodeMetrics(ctx, req, NewSession(&NodeInfo{}, mockFailClientCreator))
+info, err = svr.getDataNodeMetrics(ctx, req, session.NewSession(&session.NodeInfo{}, mockFailClientCreator))
 assert.NoError(t, err)
 assert.True(t, info.HasError)
 }


@@ -7,6 +7,8 @@ import (
 datapb "github.com/milvus-io/milvus/internal/proto/datapb"
 mock "github.com/stretchr/testify/mock"
+session "github.com/milvus-io/milvus/internal/datacoord/session"
 )
 // MockCluster is an autogenerated mock type for the Cluster type
@@ -188,15 +190,15 @@ func (_c *MockCluster_FlushChannels_Call) RunAndReturn(run func(context.Context,
 }
 // GetSessions provides a mock function with given fields:
-func (_m *MockCluster) GetSessions() []*Session {
+func (_m *MockCluster) GetSessions() []*session.Session {
 ret := _m.Called()
-var r0 []*Session
-if rf, ok := ret.Get(0).(func() []*Session); ok {
+var r0 []*session.Session
+if rf, ok := ret.Get(0).(func() []*session.Session); ok {
 r0 = rf()
 } else {
 if ret.Get(0) != nil {
-r0 = ret.Get(0).([]*Session)
+r0 = ret.Get(0).([]*session.Session)
 }
 }
@@ -220,12 +222,12 @@ func (_c *MockCluster_GetSessions_Call) Run(run func()) *MockCluster_GetSessions
 return _c
 }
-func (_c *MockCluster_GetSessions_Call) Return(_a0 []*Session) *MockCluster_GetSessions_Call {
+func (_c *MockCluster_GetSessions_Call) Return(_a0 []*session.Session) *MockCluster_GetSessions_Call {
 _c.Call.Return(_a0)
 return _c
 }
-func (_c *MockCluster_GetSessions_Call) RunAndReturn(run func() []*Session) *MockCluster_GetSessions_Call {
+func (_c *MockCluster_GetSessions_Call) RunAndReturn(run func() []*session.Session) *MockCluster_GetSessions_Call {
 _c.Call.Return(run)
 return _c
 }
@@ -470,11 +472,11 @@ func (_c *MockCluster_QuerySlots_Call) RunAndReturn(run func() map[int64]int64)
 }
 // Register provides a mock function with given fields: node
-func (_m *MockCluster) Register(node *NodeInfo) error {
+func (_m *MockCluster) Register(node *session.NodeInfo) error {
 ret := _m.Called(node)
 var r0 error
-if rf, ok := ret.Get(0).(func(*NodeInfo) error); ok {
+if rf, ok := ret.Get(0).(func(*session.NodeInfo) error); ok {
 r0 = rf(node)
 } else {
 r0 = ret.Error(0)
@@ -489,14 +491,14 @@ type MockCluster_Register_Call struct {
 }
 // Register is a helper method to define mock.On call
-// - node *NodeInfo
+// - node *session.NodeInfo
 func (_e *MockCluster_Expecter) Register(node interface{}) *MockCluster_Register_Call {
 return &MockCluster_Register_Call{Call: _e.mock.On("Register", node)}
 }
-func (_c *MockCluster_Register_Call) Run(run func(node *NodeInfo)) *MockCluster_Register_Call {
+func (_c *MockCluster_Register_Call) Run(run func(node *session.NodeInfo)) *MockCluster_Register_Call {
 _c.Call.Run(func(args mock.Arguments) {
-run(args[0].(*NodeInfo))
+run(args[0].(*session.NodeInfo))
 })
 return _c
 }
@@ -506,17 +508,17 @@ func (_c *MockCluster_Register_Call) Return(_a0 error) *MockCluster_Register_Cal
 return _c
 }
-func (_c *MockCluster_Register_Call) RunAndReturn(run func(*NodeInfo) error) *MockCluster_Register_Call {
+func (_c *MockCluster_Register_Call) RunAndReturn(run func(*session.NodeInfo) error) *MockCluster_Register_Call {
 _c.Call.Return(run)
 return _c
 }
 // Startup provides a mock function with given fields: ctx, nodes
-func (_m *MockCluster) Startup(ctx context.Context, nodes []*NodeInfo) error {
+func (_m *MockCluster) Startup(ctx context.Context, nodes []*session.NodeInfo) error {
 ret := _m.Called(ctx, nodes)
 var r0 error
-if rf, ok := ret.Get(0).(func(context.Context, []*NodeInfo) error); ok {
+if rf, ok := ret.Get(0).(func(context.Context, []*session.NodeInfo) error); ok {
 r0 = rf(ctx, nodes)
 } else {
 r0 = ret.Error(0)
@@ -532,14 +534,14 @@ type MockCluster_Startup_Call struct {
 // Startup is a helper method to define mock.On call
 // - ctx context.Context
-// - nodes []*NodeInfo
+// - nodes []*session.NodeInfo
 func (_e *MockCluster_Expecter) Startup(ctx interface{}, nodes interface{}) *MockCluster_Startup_Call {
 return &MockCluster_Startup_Call{Call: _e.mock.On("Startup", ctx, nodes)}
 }
-func (_c *MockCluster_Startup_Call) Run(run func(ctx context.Context, nodes []*NodeInfo)) *MockCluster_Startup_Call {
+func (_c *MockCluster_Startup_Call) Run(run func(ctx context.Context, nodes []*session.NodeInfo)) *MockCluster_Startup_Call {
 _c.Call.Run(func(args mock.Arguments) {
-run(args[0].(context.Context), args[1].([]*NodeInfo))
+run(args[0].(context.Context), args[1].([]*session.NodeInfo))
 })
 return _c
 }
@@ -549,17 +551,17 @@ func (_c *MockCluster_Startup_Call) Return(_a0 error) *MockCluster_Startup_Call
 return _c
 }
-func (_c *MockCluster_Startup_Call) RunAndReturn(run func(context.Context, []*NodeInfo) error) *MockCluster_Startup_Call {
+func (_c *MockCluster_Startup_Call) RunAndReturn(run func(context.Context, []*session.NodeInfo) error) *MockCluster_Startup_Call {
 _c.Call.Return(run)
 return _c
 }
 // UnRegister provides a mock function with given fields: node
-func (_m *MockCluster) UnRegister(node *NodeInfo) error {
+func (_m *MockCluster) UnRegister(node *session.NodeInfo) error {
 ret := _m.Called(node)
 var r0 error
-if rf, ok := ret.Get(0).(func(*NodeInfo) error); ok {
+if rf, ok := ret.Get(0).(func(*session.NodeInfo) error); ok {
 r0 = rf(node)
 } else {
 r0 = ret.Error(0)
@@ -574,14 +576,14 @@ type MockCluster_UnRegister_Call struct {
 }
 // UnRegister is a helper method to define mock.On call
-// - node *NodeInfo
+// - node *session.NodeInfo
 func (_e *MockCluster_Expecter) UnRegister(node interface{}) *MockCluster_UnRegister_Call {
 return &MockCluster_UnRegister_Call{Call: _e.mock.On("UnRegister", node)}
 }
-func (_c *MockCluster_UnRegister_Call) Run(run func(node *NodeInfo)) *MockCluster_UnRegister_Call {
+func (_c *MockCluster_UnRegister_Call) Run(run func(node *session.NodeInfo)) *MockCluster_UnRegister_Call {
 _c.Call.Run(func(args mock.Arguments) {
-run(args[0].(*NodeInfo))
+run(args[0].(*session.NodeInfo))
 })
 return _c
 }
@@ -591,7 +593,7 @@ func (_c *MockCluster_UnRegister_Call) Return(_a0 error) *MockCluster_UnRegister
 return _c
 }
-func (_c *MockCluster_UnRegister_Call) RunAndReturn(run func(*NodeInfo) error) *MockCluster_UnRegister_Call {
+func (_c *MockCluster_UnRegister_Call) RunAndReturn(run func(*session.NodeInfo) error) *MockCluster_UnRegister_Call {
 _c.Call.Return(run)
 return _c
 }

File diff suppressed because it is too large


@ -37,6 +37,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/datacoord/broker" "github.com/milvus-io/milvus/internal/datacoord/broker"
"github.com/milvus-io/milvus/internal/datacoord/session"
datanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" datanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client"
indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
@ -84,10 +85,6 @@ type (
Timestamp = typeutil.Timestamp Timestamp = typeutil.Timestamp
) )
type dataNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error)
type indexNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error)
type rootCoordCreatorFunc func(ctx context.Context) (types.RootCoordClient, error) type rootCoordCreatorFunc func(ctx context.Context) (types.RootCoordClient, error)
// makes sure Server implements `DataCoord` // makes sure Server implements `DataCoord`
@ -114,7 +111,7 @@ type Server struct {
segmentManager Manager segmentManager Manager
allocator allocator.Allocator allocator allocator.Allocator
cluster Cluster cluster Cluster
sessionManager SessionManager sessionManager session.DataNodeManager
channelManager ChannelManager channelManager ChannelManager
rootCoordClient types.RootCoordClient rootCoordClient types.RootCoordClient
garbageCollector *garbageCollector garbageCollector *garbageCollector
@ -146,13 +143,13 @@ type Server struct {
enableActiveStandBy bool enableActiveStandBy bool
activateFunc func() error activateFunc func() error
dataNodeCreator dataNodeCreatorFunc dataNodeCreator session.DataNodeCreatorFunc
indexNodeCreator indexNodeCreatorFunc indexNodeCreator session.IndexNodeCreatorFunc
rootCoordClientCreator rootCoordCreatorFunc rootCoordClientCreator rootCoordCreatorFunc
// indexCoord types.IndexCoord // indexCoord types.IndexCoord
// segReferManager *SegmentReferenceManager // segReferManager *SegmentReferenceManager
indexNodeManager *IndexNodeManager indexNodeManager *session.IndexNodeManager
indexEngineVersionManager IndexEngineVersionManager indexEngineVersionManager IndexEngineVersionManager
taskScheduler *taskScheduler taskScheduler *taskScheduler
@ -187,7 +184,7 @@ func WithCluster(cluster Cluster) Option {
} }
// WithDataNodeCreator returns an `Option` setting DataNode create function // WithDataNodeCreator returns an `Option` setting DataNode create function
func WithDataNodeCreator(creator dataNodeCreatorFunc) Option { func WithDataNodeCreator(creator session.DataNodeCreatorFunc) Option {
return func(svr *Server) { return func(svr *Server) {
svr.dataNodeCreator = creator svr.dataNodeCreator = creator
} }
@ -487,7 +484,7 @@ func (s *Server) initCluster() error {
return nil return nil
} }
s.sessionManager = NewSessionManagerImpl(withSessionCreator(s.dataNodeCreator)) s.sessionManager = session.NewDataNodeManagerImpl(session.WithDataNodeCreator(s.dataNodeCreator))
var err error var err error
s.channelManager, err = NewChannelManager(s.watchClient, s.handler, s.sessionManager, s.allocator, withCheckerV2()) s.channelManager, err = NewChannelManager(s.watchClient, s.handler, s.sessionManager, s.allocator, withCheckerV2())
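A minimal sketch, not part of this commit, of wiring the renamed manager with a custom client creator; customCreator is hypothetical and grpcdatanodeclient is the client alias used inside the session package later in this diff:
    customCreator := func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
        return grpcdatanodeclient.NewClient(ctx, addr, nodeID)
    }
    mgr := session.NewDataNodeManagerImpl(session.WithDataNodeCreator(customCreator))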
@ -553,19 +550,19 @@ func (s *Server) initServiceDiscovery() error {
} }
log.Info("DataCoord success to get DataNode sessions", zap.Any("sessions", sessions)) log.Info("DataCoord success to get DataNode sessions", zap.Any("sessions", sessions))
datanodes := make([]*NodeInfo, 0, len(sessions)) datanodes := make([]*session.NodeInfo, 0, len(sessions))
legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue()) legacyVersion, err := semver.Parse(paramtable.Get().DataCoordCfg.LegacyVersionWithoutRPCWatch.GetValue())
if err != nil { if err != nil {
log.Warn("DataCoord failed to init service discovery", zap.Error(err)) log.Warn("DataCoord failed to init service discovery", zap.Error(err))
} }
for _, session := range sessions { for _, s := range sessions {
info := &NodeInfo{ info := &session.NodeInfo{
NodeID: session.ServerID, NodeID: s.ServerID,
Address: session.Address, Address: s.Address,
} }
if session.Version.LTE(legacyVersion) { if s.Version.LTE(legacyVersion) {
info.IsLegacy = true info.IsLegacy = true
} }
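A hedged sketch of the NodeInfo construction this hunk performs for each discovered DataNode session; sess and legacyVersion stand in for the loop variable and parsed version above, and the direct AddSession call is illustrative only (the server collects the nodes into the datanodes slice shown above):
    node := &session.NodeInfo{
        NodeID:  sess.ServerID,
        Address: sess.Address,
    }
    if sess.Version.LTE(legacyVersion) {
        node.IsLegacy = true // mark nodes older than the RPC-watch cutover
    }
    mgr.AddSession(node)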
@ -677,7 +674,7 @@ func (s *Server) initTaskScheduler(manager storage.ChunkManager) {
func (s *Server) initIndexNodeManager() { func (s *Server) initIndexNodeManager() {
if s.indexNodeManager == nil { if s.indexNodeManager == nil {
s.indexNodeManager = NewNodeManager(s.ctx, s.indexNodeCreator) s.indexNodeManager = session.NewNodeManager(s.ctx, s.indexNodeCreator)
} }
} }
@ -858,7 +855,7 @@ func (s *Server) handleSessionEvent(ctx context.Context, role string, event *ses
Version: event.Session.ServerID, Version: event.Session.ServerID,
Channels: []*datapb.ChannelStatus{}, Channels: []*datapb.ChannelStatus{},
} }
node := &NodeInfo{ node := &session.NodeInfo{
NodeID: event.Session.ServerID, NodeID: event.Session.ServerID,
Address: event.Session.Address, Address: event.Session.Address,
} }


@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/datacoord/broker" "github.com/milvus-io/milvus/internal/datacoord/broker"
"github.com/milvus-io/milvus/internal/datacoord/session"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
@ -57,7 +58,6 @@ import (
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/lock"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
@ -2463,7 +2463,7 @@ func TestOptions(t *testing.T) {
t.Run("WithCluster", func(t *testing.T) { t.Run("WithCluster", func(t *testing.T) {
defer kv.RemoveWithPrefix("") defer kv.RemoveWithPrefix("")
sessionManager := NewSessionManagerImpl() sessionManager := session.NewDataNodeManagerImpl()
channelManager, err := NewChannelManager(kv, newMockHandler(), sessionManager, allocator.NewMockAllocator(t)) channelManager, err := NewChannelManager(kv, newMockHandler(), sessionManager, allocator.NewMockAllocator(t))
assert.NoError(t, err) assert.NoError(t, err)
@ -2505,7 +2505,7 @@ func TestHandleSessionEvent(t *testing.T) {
defer cancel() defer cancel()
alloc := allocator.NewMockAllocator(t) alloc := allocator.NewMockAllocator(t)
sessionManager := NewSessionManagerImpl() sessionManager := session.NewDataNodeManagerImpl()
channelManager, err := NewChannelManager(kv, newMockHandler(), sessionManager, alloc) channelManager, err := NewChannelManager(kv, newMockHandler(), sessionManager, alloc)
assert.NoError(t, err) assert.NoError(t, err)
@ -2549,7 +2549,7 @@ func TestHandleSessionEvent(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
dataNodes := svr.cluster.GetSessions() dataNodes := svr.cluster.GetSessions()
assert.EqualValues(t, 1, len(dataNodes)) assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "DN127.0.0.101", dataNodes[0].info.Address) assert.EqualValues(t, "DN127.0.0.101", dataNodes[0].Address())
evt = &sessionutil.SessionEvent{ evt = &sessionutil.SessionEvent{
EventType: sessionutil.SessionDelEvent, EventType: sessionutil.SessionDelEvent,
@ -3126,7 +3126,7 @@ func closeTestServer(t *testing.T, svr *Server) {
} }
func Test_CheckHealth(t *testing.T) { func Test_CheckHealth(t *testing.T) {
getSessionManager := func(isHealthy bool) *SessionManagerImpl { getSessionManager := func(isHealthy bool) *session.DataNodeManagerImpl {
var client *mockDataNodeClient var client *mockDataNodeClient
if isHealthy { if isHealthy {
client = &mockDataNodeClient{ client = &mockDataNodeClient{
@ -3140,16 +3140,12 @@ func Test_CheckHealth(t *testing.T) {
} }
} }
sm := NewSessionManagerImpl() sm := session.NewDataNodeManagerImpl(session.WithDataNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
sm.sessions = struct {
lock.RWMutex
data map[int64]*Session
}{data: map[int64]*Session{1: {
client: client,
clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
return client, nil return client, nil
}, }))
}}} sm.AddSession(&session.NodeInfo{
NodeID: 1,
})
return sm return sm
} }
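The rewritten helper seeds sessions through the exported API instead of poking the private sessions struct; a sketch of the same pattern, assuming client is a stubbed DataNodeClient:
    sm := session.NewDataNodeManagerImpl(session.WithDataNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
        return client, nil
    }))
    sm.AddSession(&session.NodeInfo{NodeID: 1})
    ids := sm.GetSessionIDs() // now reports node 1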


@ -20,6 +20,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/datacoord/session"
"github.com/milvus-io/milvus/internal/metastore/mocks" "github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
mocks2 "github.com/milvus-io/milvus/internal/mocks" mocks2 "github.com/milvus-io/milvus/internal/mocks"
@ -43,7 +44,7 @@ type ServerSuite struct {
func WithChannelManager(cm ChannelManager) Option { func WithChannelManager(cm ChannelManager) Option {
return func(svr *Server) { return func(svr *Server) {
svr.sessionManager = NewSessionManagerImpl(withSessionCreator(svr.dataNodeCreator)) svr.sessionManager = session.NewDataNodeManagerImpl(session.WithDataNodeCreator(svr.dataNodeCreator))
svr.channelManager = cm svr.channelManager = cm
svr.cluster = NewClusterImpl(svr.sessionManager, svr.channelManager) svr.cluster = NewClusterImpl(svr.sessionManager, svr.channelManager)
} }


@ -0,0 +1,7 @@
reviewers:
- sunby
- xiaocai2333
- congqixia
approvers:
- maintainers


@ -0,0 +1,3 @@
# Session Package
The `session` package contains the node session and worker manager abstractions for datanodes and indexnodes.
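A hypothetical at-a-glance sketch of the two managers the package exposes; ctx and indexNodeCreator are assumed to be supplied by the caller:
    dnMgr := session.NewDataNodeManagerImpl()              // datanode session manager
    inMgr := session.NewNodeManager(ctx, indexNodeCreator) // indexnode worker manager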


@ -0,0 +1,11 @@
package session
import (
"context"
"github.com/milvus-io/milvus/internal/types"
)
type DataNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error)
type IndexNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error)
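With the creator types now exported, callers outside the package can inject their own constructors; a hedged example for tests, where mockClient is an assumed IndexNodeClient double:
    var creator session.IndexNodeCreatorFunc = func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) {
        return mockClient, nil
    }
    nm := session.NewNodeManager(context.Background(), creator)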


@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package datacoord package session
import ( import (
"context" "context"
@ -48,8 +48,8 @@ const (
querySlotTimeout = 10 * time.Second querySlotTimeout = 10 * time.Second
) )
//go:generate mockery --name=SessionManager --structname=MockSessionManager --output=./ --filename=mock_session_manager.go --with-expecter --inpackage // DataNodeManager is the interface for datanode session manager.
type SessionManager interface { type DataNodeManager interface {
AddSession(node *NodeInfo) AddSession(node *NodeInfo)
DeleteSession(node *NodeInfo) DeleteSession(node *NodeInfo)
GetSessionIDs() []int64 GetSessionIDs() []int64
@ -75,33 +75,33 @@ type SessionManager interface {
Close() Close()
} }
var _ SessionManager = (*SessionManagerImpl)(nil) var _ DataNodeManager = (*DataNodeManagerImpl)(nil)
// SessionManagerImpl provides the grpc interfaces of cluster // DataNodeManagerImpl provides the grpc interfaces of cluster
type SessionManagerImpl struct { type DataNodeManagerImpl struct {
sessions struct { sessions struct {
lock.RWMutex lock.RWMutex
data map[int64]*Session data map[int64]*Session
} }
sessionCreator dataNodeCreatorFunc sessionCreator DataNodeCreatorFunc
} }
// SessionOpt provides a way to set params in SessionManagerImpl // SessionOpt provides a way to set params in SessionManagerImpl
type SessionOpt func(c *SessionManagerImpl) type SessionOpt func(c *DataNodeManagerImpl)
func withSessionCreator(creator dataNodeCreatorFunc) SessionOpt { func WithDataNodeCreator(creator DataNodeCreatorFunc) SessionOpt {
return func(c *SessionManagerImpl) { c.sessionCreator = creator } return func(c *DataNodeManagerImpl) { c.sessionCreator = creator }
} }
func defaultSessionCreator() dataNodeCreatorFunc { func defaultSessionCreator() DataNodeCreatorFunc {
return func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
return grpcdatanodeclient.NewClient(ctx, addr, nodeID) return grpcdatanodeclient.NewClient(ctx, addr, nodeID)
} }
} }
// NewSessionManagerImpl creates a new SessionManagerImpl // NewDataNodeManagerImpl creates a new DataNodeManagerImpl
func NewSessionManagerImpl(options ...SessionOpt) *SessionManagerImpl { func NewDataNodeManagerImpl(options ...SessionOpt) *DataNodeManagerImpl {
m := &SessionManagerImpl{ m := &DataNodeManagerImpl{
sessions: struct { sessions: struct {
lock.RWMutex lock.RWMutex
data map[int64]*Session data map[int64]*Session
@ -115,7 +115,7 @@ func NewSessionManagerImpl(options ...SessionOpt) *SessionManagerImpl {
} }
// AddSession creates a new session // AddSession creates a new session
func (c *SessionManagerImpl) AddSession(node *NodeInfo) { func (c *DataNodeManagerImpl) AddSession(node *NodeInfo) {
c.sessions.Lock() c.sessions.Lock()
defer c.sessions.Unlock() defer c.sessions.Unlock()
@ -125,7 +125,7 @@ func (c *SessionManagerImpl) AddSession(node *NodeInfo) {
} }
// GetSession return a Session related to nodeID // GetSession return a Session related to nodeID
func (c *SessionManagerImpl) GetSession(nodeID int64) (*Session, bool) { func (c *DataNodeManagerImpl) GetSession(nodeID int64) (*Session, bool) {
c.sessions.RLock() c.sessions.RLock()
defer c.sessions.RUnlock() defer c.sessions.RUnlock()
s, ok := c.sessions.data[nodeID] s, ok := c.sessions.data[nodeID]
@ -133,7 +133,7 @@ func (c *SessionManagerImpl) GetSession(nodeID int64) (*Session, bool) {
} }
// DeleteSession removes the node session // DeleteSession removes the node session
func (c *SessionManagerImpl) DeleteSession(node *NodeInfo) { func (c *DataNodeManagerImpl) DeleteSession(node *NodeInfo) {
c.sessions.Lock() c.sessions.Lock()
defer c.sessions.Unlock() defer c.sessions.Unlock()
@ -145,7 +145,7 @@ func (c *SessionManagerImpl) DeleteSession(node *NodeInfo) {
} }
// GetSessionIDs returns IDs of all live DataNodes. // GetSessionIDs returns IDs of all live DataNodes.
func (c *SessionManagerImpl) GetSessionIDs() []int64 { func (c *DataNodeManagerImpl) GetSessionIDs() []int64 {
c.sessions.RLock() c.sessions.RLock()
defer c.sessions.RUnlock() defer c.sessions.RUnlock()
@ -157,7 +157,7 @@ func (c *SessionManagerImpl) GetSessionIDs() []int64 {
} }
// GetSessions gets all node sessions // GetSessions gets all node sessions
func (c *SessionManagerImpl) GetSessions() []*Session { func (c *DataNodeManagerImpl) GetSessions() []*Session {
c.sessions.RLock() c.sessions.RLock()
defer c.sessions.RUnlock() defer c.sessions.RUnlock()
@ -168,7 +168,7 @@ func (c *SessionManagerImpl) GetSessions() []*Session {
return ret return ret
} }
func (c *SessionManagerImpl) getClient(ctx context.Context, nodeID int64) (types.DataNodeClient, error) { func (c *DataNodeManagerImpl) getClient(ctx context.Context, nodeID int64) (types.DataNodeClient, error) {
c.sessions.RLock() c.sessions.RLock()
session, ok := c.sessions.data[nodeID] session, ok := c.sessions.data[nodeID]
c.sessions.RUnlock() c.sessions.RUnlock()
@ -181,11 +181,11 @@ func (c *SessionManagerImpl) getClient(ctx context.Context, nodeID int64) (types
} }
// Flush is a grpc interface. It will send req to nodeID asynchronously // Flush is a grpc interface. It will send req to nodeID asynchronously
func (c *SessionManagerImpl) Flush(ctx context.Context, nodeID int64, req *datapb.FlushSegmentsRequest) { func (c *DataNodeManagerImpl) Flush(ctx context.Context, nodeID int64, req *datapb.FlushSegmentsRequest) {
go c.execFlush(ctx, nodeID, req) go c.execFlush(ctx, nodeID, req)
} }
func (c *SessionManagerImpl) execFlush(ctx context.Context, nodeID int64, req *datapb.FlushSegmentsRequest) { func (c *DataNodeManagerImpl) execFlush(ctx context.Context, nodeID int64, req *datapb.FlushSegmentsRequest) {
log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID), zap.String("channel", req.GetChannelName())) log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID), zap.String("channel", req.GetChannelName()))
cli, err := c.getClient(ctx, nodeID) cli, err := c.getClient(ctx, nodeID)
if err != nil { if err != nil {
@ -196,7 +196,7 @@ func (c *SessionManagerImpl) execFlush(ctx context.Context, nodeID int64, req *d
defer cancel() defer cancel()
resp, err := cli.FlushSegments(ctx, req) resp, err := cli.FlushSegments(ctx, req)
if err := VerifyResponse(resp, err); err != nil { if err := merr.CheckRPCCall(resp, err); err != nil {
log.Error("flush call (perhaps partially) failed", zap.Error(err)) log.Error("flush call (perhaps partially) failed", zap.Error(err))
} else { } else {
log.Info("flush call succeeded") log.Info("flush call succeeded")
@ -204,8 +204,8 @@ func (c *SessionManagerImpl) execFlush(ctx context.Context, nodeID int64, req *d
} }
// Compaction is a grpc interface. It will send request to DataNode with provided `nodeID` synchronously. // Compaction is a grpc interface. It will send request to DataNode with provided `nodeID` synchronously.
func (c *SessionManagerImpl) Compaction(ctx context.Context, nodeID int64, plan *datapb.CompactionPlan) error { func (c *DataNodeManagerImpl) Compaction(ctx context.Context, nodeID int64, plan *datapb.CompactionPlan) error {
ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(ctx, paramtable.Get().DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second))
defer cancel() defer cancel()
cli, err := c.getClient(ctx, nodeID) cli, err := c.getClient(ctx, nodeID)
if err != nil { if err != nil {
@ -214,7 +214,7 @@ func (c *SessionManagerImpl) Compaction(ctx context.Context, nodeID int64, plan
} }
resp, err := cli.CompactionV2(ctx, plan) resp, err := cli.CompactionV2(ctx, plan)
if err := VerifyResponse(resp, err); err != nil { if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to execute compaction", zap.Int64("node", nodeID), zap.Error(err), zap.Int64("planID", plan.GetPlanID())) log.Warn("failed to execute compaction", zap.Int64("node", nodeID), zap.Error(err), zap.Int64("planID", plan.GetPlanID()))
return err return err
} }
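The package-local VerifyResponse helper is replaced throughout this file by merr.CheckRPCCall, which folds the transport error and a non-success status into one error; a hedged sketch of the resulting call-site pattern:
    resp, err := cli.CompactionV2(ctx, plan)
    if err := merr.CheckRPCCall(resp, err); err != nil {
        return err // either the RPC error or an error derived from the returned status
    }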
@ -224,12 +224,12 @@ func (c *SessionManagerImpl) Compaction(ctx context.Context, nodeID int64, plan
} }
// SyncSegments is a grpc interface. It will send request to DataNode with provided `nodeID` synchronously. // SyncSegments is a grpc interface. It will send request to DataNode with provided `nodeID` synchronously.
func (c *SessionManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegmentsRequest) error { func (c *DataNodeManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegmentsRequest) error {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Int64("planID", req.GetPlanID()), zap.Int64("planID", req.GetPlanID()),
) )
ctx, cancel := context.WithTimeout(context.Background(), Params.DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(context.Background(), paramtable.Get().DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second))
cli, err := c.getClient(ctx, nodeID) cli, err := c.getClient(ctx, nodeID)
cancel() cancel()
if err != nil { if err != nil {
@ -240,7 +240,7 @@ func (c *SessionManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegments
err = retry.Do(context.Background(), func() error { err = retry.Do(context.Background(), func() error {
// doesn't set timeout // doesn't set timeout
resp, err := cli.SyncSegments(context.Background(), req) resp, err := cli.SyncSegments(context.Background(), req)
if err := VerifyResponse(resp, err); err != nil { if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to sync segments", zap.Error(err)) log.Warn("failed to sync segments", zap.Error(err))
return err return err
} }
@ -256,7 +256,7 @@ func (c *SessionManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegments
} }
// GetCompactionPlansResults returns map[planID]*pair[nodeID, *CompactionPlanResults] // GetCompactionPlansResults returns map[planID]*pair[nodeID, *CompactionPlanResults]
func (c *SessionManagerImpl) GetCompactionPlansResults() (map[int64]*typeutil.Pair[int64, *datapb.CompactionPlanResult], error) { func (c *DataNodeManagerImpl) GetCompactionPlansResults() (map[int64]*typeutil.Pair[int64, *datapb.CompactionPlanResult], error) {
ctx := context.Background() ctx := context.Background()
errorGroup, ctx := errgroup.WithContext(ctx) errorGroup, ctx := errgroup.WithContext(ctx)
@ -270,7 +270,7 @@ func (c *SessionManagerImpl) GetCompactionPlansResults() (map[int64]*typeutil.Pa
log.Info("Cannot Create Client", zap.Int64("NodeID", nodeID)) log.Info("Cannot Create Client", zap.Int64("NodeID", nodeID))
return err return err
} }
ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(ctx, paramtable.Get().DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second))
defer cancel() defer cancel()
resp, err := cli.GetCompactionState(ctx, &datapb.CompactionStateRequest{ resp, err := cli.GetCompactionState(ctx, &datapb.CompactionStateRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
@ -308,7 +308,7 @@ func (c *SessionManagerImpl) GetCompactionPlansResults() (map[int64]*typeutil.Pa
return rst, nil return rst, nil
} }
func (c *SessionManagerImpl) GetCompactionPlanResult(nodeID int64, planID int64) (*datapb.CompactionPlanResult, error) { func (c *DataNodeManagerImpl) GetCompactionPlanResult(nodeID int64, planID int64) (*datapb.CompactionPlanResult, error) {
ctx := context.Background() ctx := context.Background()
c.sessions.RLock() c.sessions.RLock()
s, ok := c.sessions.data[nodeID] s, ok := c.sessions.data[nodeID]
@ -322,7 +322,7 @@ func (c *SessionManagerImpl) GetCompactionPlanResult(nodeID int64, planID int64)
log.Info("Cannot Create Client", zap.Int64("NodeID", nodeID)) log.Info("Cannot Create Client", zap.Int64("NodeID", nodeID))
return nil, err return nil, err
} }
ctx, cancel := context.WithTimeout(context.Background(), Params.DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(context.Background(), paramtable.Get().DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second))
defer cancel() defer cancel()
resp, err2 := cli.GetCompactionState(ctx, &datapb.CompactionStateRequest{ resp, err2 := cli.GetCompactionState(ctx, &datapb.CompactionStateRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
@ -352,7 +352,7 @@ func (c *SessionManagerImpl) GetCompactionPlanResult(nodeID int64, planID int64)
return result, nil return result, nil
} }
func (c *SessionManagerImpl) FlushChannels(ctx context.Context, nodeID int64, req *datapb.FlushChannelsRequest) error { func (c *DataNodeManagerImpl) FlushChannels(ctx context.Context, nodeID int64, req *datapb.FlushChannelsRequest) error {
log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID), log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID),
zap.Time("flushTs", tsoutil.PhysicalTime(req.GetFlushTs())), zap.Time("flushTs", tsoutil.PhysicalTime(req.GetFlushTs())),
zap.Strings("channels", req.GetChannels())) zap.Strings("channels", req.GetChannels()))
@ -364,7 +364,7 @@ func (c *SessionManagerImpl) FlushChannels(ctx context.Context, nodeID int64, re
log.Info("SessionManagerImpl.FlushChannels start") log.Info("SessionManagerImpl.FlushChannels start")
resp, err := cli.FlushChannels(ctx, req) resp, err := cli.FlushChannels(ctx, req)
err = VerifyResponse(resp, err) err = merr.CheckRPCCall(resp, err)
if err != nil { if err != nil {
log.Warn("SessionManagerImpl.FlushChannels failed", zap.Error(err)) log.Warn("SessionManagerImpl.FlushChannels failed", zap.Error(err))
return err return err
@ -373,14 +373,14 @@ func (c *SessionManagerImpl) FlushChannels(ctx context.Context, nodeID int64, re
return nil return nil
} }
func (c *SessionManagerImpl) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error { func (c *DataNodeManagerImpl) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error {
log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID)) log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID))
cli, err := c.getClient(ctx, nodeID) cli, err := c.getClient(ctx, nodeID)
if err != nil { if err != nil {
log.Info("failed to get dataNode client", zap.Error(err)) log.Info("failed to get dataNode client", zap.Error(err))
return err return err
} }
ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(ctx, paramtable.Get().DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second))
defer cancel() defer cancel()
resp, err := cli.NotifyChannelOperation(ctx, req) resp, err := cli.NotifyChannelOperation(ctx, req)
if err := merr.CheckRPCCall(resp, err); err != nil { if err := merr.CheckRPCCall(resp, err); err != nil {
@ -390,7 +390,7 @@ func (c *SessionManagerImpl) NotifyChannelOperation(ctx context.Context, nodeID
return nil return nil
} }
func (c *SessionManagerImpl) CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) { func (c *DataNodeManagerImpl) CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.String("channel", info.GetVchan().GetChannelName()), zap.String("channel", info.GetVchan().GetChannelName()),
@ -402,7 +402,7 @@ func (c *SessionManagerImpl) CheckChannelOperationProgress(ctx context.Context,
return nil, err return nil, err
} }
ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(ctx, paramtable.Get().DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second))
defer cancel() defer cancel()
resp, err := cli.CheckChannelOperationProgress(ctx, info) resp, err := cli.CheckChannelOperationProgress(ctx, info)
if err := merr.CheckRPCCall(resp, err); err != nil { if err := merr.CheckRPCCall(resp, err); err != nil {
@ -413,7 +413,7 @@ func (c *SessionManagerImpl) CheckChannelOperationProgress(ctx context.Context,
return resp, nil return resp, nil
} }
func (c *SessionManagerImpl) PreImport(nodeID int64, in *datapb.PreImportRequest) error { func (c *DataNodeManagerImpl) PreImport(nodeID int64, in *datapb.PreImportRequest) error {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Int64("jobID", in.GetJobID()), zap.Int64("jobID", in.GetJobID()),
@ -429,10 +429,10 @@ func (c *SessionManagerImpl) PreImport(nodeID int64, in *datapb.PreImportRequest
return err return err
} }
status, err := cli.PreImport(ctx, in) status, err := cli.PreImport(ctx, in)
return VerifyResponse(status, err) return merr.CheckRPCCall(status, err)
} }
func (c *SessionManagerImpl) ImportV2(nodeID int64, in *datapb.ImportRequest) error { func (c *DataNodeManagerImpl) ImportV2(nodeID int64, in *datapb.ImportRequest) error {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Int64("jobID", in.GetJobID()), zap.Int64("jobID", in.GetJobID()),
@ -447,10 +447,10 @@ func (c *SessionManagerImpl) ImportV2(nodeID int64, in *datapb.ImportRequest) er
return err return err
} }
status, err := cli.ImportV2(ctx, in) status, err := cli.ImportV2(ctx, in)
return VerifyResponse(status, err) return merr.CheckRPCCall(status, err)
} }
func (c *SessionManagerImpl) QueryPreImport(nodeID int64, in *datapb.QueryPreImportRequest) (*datapb.QueryPreImportResponse, error) { func (c *DataNodeManagerImpl) QueryPreImport(nodeID int64, in *datapb.QueryPreImportRequest) (*datapb.QueryPreImportResponse, error) {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Int64("jobID", in.GetJobID()), zap.Int64("jobID", in.GetJobID()),
@ -464,13 +464,13 @@ func (c *SessionManagerImpl) QueryPreImport(nodeID int64, in *datapb.QueryPreImp
return nil, err return nil, err
} }
resp, err := cli.QueryPreImport(ctx, in) resp, err := cli.QueryPreImport(ctx, in)
if err = VerifyResponse(resp.GetStatus(), err); err != nil { if err = merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
return nil, err return nil, err
} }
return resp, nil return resp, nil
} }
func (c *SessionManagerImpl) QueryImport(nodeID int64, in *datapb.QueryImportRequest) (*datapb.QueryImportResponse, error) { func (c *DataNodeManagerImpl) QueryImport(nodeID int64, in *datapb.QueryImportRequest) (*datapb.QueryImportResponse, error) {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Int64("jobID", in.GetJobID()), zap.Int64("jobID", in.GetJobID()),
@ -484,13 +484,13 @@ func (c *SessionManagerImpl) QueryImport(nodeID int64, in *datapb.QueryImportReq
return nil, err return nil, err
} }
resp, err := cli.QueryImport(ctx, in) resp, err := cli.QueryImport(ctx, in)
if err = VerifyResponse(resp.GetStatus(), err); err != nil { if err = merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
return nil, err return nil, err
} }
return resp, nil return resp, nil
} }
func (c *SessionManagerImpl) DropImport(nodeID int64, in *datapb.DropImportRequest) error { func (c *DataNodeManagerImpl) DropImport(nodeID int64, in *datapb.DropImportRequest) error {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Int64("jobID", in.GetJobID()), zap.Int64("jobID", in.GetJobID()),
@ -504,10 +504,10 @@ func (c *SessionManagerImpl) DropImport(nodeID int64, in *datapb.DropImportReque
return err return err
} }
status, err := cli.DropImport(ctx, in) status, err := cli.DropImport(ctx, in)
return VerifyResponse(status, err) return merr.CheckRPCCall(status, err)
} }
func (c *SessionManagerImpl) CheckHealth(ctx context.Context) error { func (c *DataNodeManagerImpl) CheckHealth(ctx context.Context) error {
group, ctx := errgroup.WithContext(ctx) group, ctx := errgroup.WithContext(ctx)
ids := c.GetSessionIDs() ids := c.GetSessionIDs()
@ -531,7 +531,7 @@ func (c *SessionManagerImpl) CheckHealth(ctx context.Context) error {
return group.Wait() return group.Wait()
} }
func (c *SessionManagerImpl) QuerySlot(nodeID int64) (*datapb.QuerySlotResponse, error) { func (c *DataNodeManagerImpl) QuerySlot(nodeID int64) (*datapb.QuerySlotResponse, error) {
log := log.With(zap.Int64("nodeID", nodeID)) log := log.With(zap.Int64("nodeID", nodeID))
ctx, cancel := context.WithTimeout(context.Background(), querySlotTimeout) ctx, cancel := context.WithTimeout(context.Background(), querySlotTimeout)
defer cancel() defer cancel()
@ -541,18 +541,18 @@ func (c *SessionManagerImpl) QuerySlot(nodeID int64) (*datapb.QuerySlotResponse,
return nil, err return nil, err
} }
resp, err := cli.QuerySlot(ctx, &datapb.QuerySlotRequest{}) resp, err := cli.QuerySlot(ctx, &datapb.QuerySlotRequest{})
if err = VerifyResponse(resp.GetStatus(), err); err != nil { if err = merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
return nil, err return nil, err
} }
return resp, nil return resp, nil
} }
func (c *SessionManagerImpl) DropCompactionPlan(nodeID int64, req *datapb.DropCompactionPlanRequest) error { func (c *DataNodeManagerImpl) DropCompactionPlan(nodeID int64, req *datapb.DropCompactionPlanRequest) error {
log := log.With( log := log.With(
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.Int64("planID", req.GetPlanID()), zap.Int64("planID", req.GetPlanID()),
) )
ctx, cancel := context.WithTimeout(context.Background(), Params.DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(context.Background(), paramtable.Get().DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second))
defer cancel() defer cancel()
cli, err := c.getClient(ctx, nodeID) cli, err := c.getClient(ctx, nodeID)
if err != nil { if err != nil {
@ -565,11 +565,11 @@ func (c *SessionManagerImpl) DropCompactionPlan(nodeID int64, req *datapb.DropCo
} }
err = retry.Do(context.Background(), func() error { err = retry.Do(context.Background(), func() error {
ctx, cancel := context.WithTimeout(context.Background(), Params.DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second)) ctx, cancel := context.WithTimeout(context.Background(), paramtable.Get().DataCoordCfg.CompactionRPCTimeout.GetAsDuration(time.Second))
defer cancel() defer cancel()
resp, err := cli.DropCompactionPlan(ctx, req) resp, err := cli.DropCompactionPlan(ctx, req)
if err := VerifyResponse(resp, err); err != nil { if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to drop compaction plan", zap.Error(err)) log.Warn("failed to drop compaction plan", zap.Error(err))
return err return err
} }
@ -585,7 +585,7 @@ func (c *SessionManagerImpl) DropCompactionPlan(nodeID int64, req *datapb.DropCo
} }
// Close release sessions // Close release sessions
func (c *SessionManagerImpl) Close() { func (c *DataNodeManagerImpl) Close() {
c.sessions.Lock() c.sessions.Lock()
defer c.sessions.Unlock() defer c.sessions.Unlock()
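A hedged usage sketch of the manager's fan-out health check and shutdown, assuming mgr is a previously constructed DataNodeManager:
    if err := mgr.CheckHealth(ctx); err != nil {
        log.Warn("at least one datanode session failed the health check", zap.Error(err))
    }
    defer mgr.Close() // releases sessions, as the comment above notes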


@ -1,4 +1,20 @@
package datacoord // Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import ( import (
"context" "context"
@ -13,25 +29,30 @@ import (
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/testutils" "github.com/milvus-io/milvus/pkg/util/testutils"
) )
func TestSessionManagerSuite(t *testing.T) { func TestDataNodeManagerSuite(t *testing.T) {
suite.Run(t, new(SessionManagerSuite)) suite.Run(t, new(DataNodeManagerSuite))
} }
type SessionManagerSuite struct { type DataNodeManagerSuite struct {
testutils.PromMetricsSuite testutils.PromMetricsSuite
dn *mocks.MockDataNodeClient dn *mocks.MockDataNodeClient
m *SessionManagerImpl m *DataNodeManagerImpl
} }
func (s *SessionManagerSuite) SetupTest() { func (s *DataNodeManagerSuite) SetupSuite() {
paramtable.Init()
}
func (s *DataNodeManagerSuite) SetupTest() {
s.dn = mocks.NewMockDataNodeClient(s.T()) s.dn = mocks.NewMockDataNodeClient(s.T())
s.m = NewSessionManagerImpl(withSessionCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { s.m = NewDataNodeManagerImpl(WithDataNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
return s.dn, nil return s.dn, nil
})) }))
@ -39,11 +60,11 @@ func (s *SessionManagerSuite) SetupTest() {
s.MetricsEqual(metrics.DataCoordNumDataNodes, 1) s.MetricsEqual(metrics.DataCoordNumDataNodes, 1)
} }
func (s *SessionManagerSuite) SetupSubTest() { func (s *DataNodeManagerSuite) SetupSubTest() {
s.SetupTest() s.SetupTest()
} }
func (s *SessionManagerSuite) TestExecFlush() { func (s *DataNodeManagerSuite) TestExecFlush() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -68,7 +89,7 @@ func (s *SessionManagerSuite) TestExecFlush() {
}) })
} }
func (s *SessionManagerSuite) TestNotifyChannelOperation() { func (s *DataNodeManagerSuite) TestNotifyChannelOperation() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -101,7 +122,7 @@ func (s *SessionManagerSuite) TestNotifyChannelOperation() {
}) })
} }
func (s *SessionManagerSuite) TestCheckCHannelOperationProgress() { func (s *DataNodeManagerSuite) TestCheckCHannelOperationProgress() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -142,7 +163,7 @@ func (s *SessionManagerSuite) TestCheckCHannelOperationProgress() {
}) })
} }
func (s *SessionManagerSuite) TestImportV2() { func (s *DataNodeManagerSuite) TestImportV2() {
mockErr := errors.New("mock error") mockErr := errors.New("mock error")
s.Run("PreImport", func() { s.Run("PreImport", func() {


@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package datacoord package session
import ( import (
"context" "context"
@ -24,46 +24,53 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client"
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/lock" "github.com/milvus-io/milvus/pkg/util/lock"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) {
return indexnodeclient.NewClient(ctx, addr, nodeID, paramtable.Get().DataCoordCfg.WithCredential.GetAsBool())
}
type WorkerManager interface { type WorkerManager interface {
AddNode(nodeID UniqueID, address string) error AddNode(nodeID typeutil.UniqueID, address string) error
RemoveNode(nodeID UniqueID) RemoveNode(nodeID typeutil.UniqueID)
StoppingNode(nodeID UniqueID) StoppingNode(nodeID typeutil.UniqueID)
PickClient() (UniqueID, types.IndexNodeClient) PickClient() (typeutil.UniqueID, types.IndexNodeClient)
ClientSupportDisk() bool ClientSupportDisk() bool
GetAllClients() map[UniqueID]types.IndexNodeClient GetAllClients() map[typeutil.UniqueID]types.IndexNodeClient
GetClientByID(nodeID UniqueID) (types.IndexNodeClient, bool) GetClientByID(nodeID typeutil.UniqueID) (types.IndexNodeClient, bool)
} }
// IndexNodeManager is used to manage the client of IndexNode. // IndexNodeManager is used to manage the client of IndexNode.
type IndexNodeManager struct { type IndexNodeManager struct {
nodeClients map[UniqueID]types.IndexNodeClient nodeClients map[typeutil.UniqueID]types.IndexNodeClient
stoppingNodes map[UniqueID]struct{} stoppingNodes map[typeutil.UniqueID]struct{}
lock lock.RWMutex lock lock.RWMutex
ctx context.Context ctx context.Context
indexNodeCreator indexNodeCreatorFunc indexNodeCreator IndexNodeCreatorFunc
} }
// NewNodeManager is used to create a new IndexNodeManager. // NewNodeManager is used to create a new IndexNodeManager.
func NewNodeManager(ctx context.Context, indexNodeCreator indexNodeCreatorFunc) *IndexNodeManager { func NewNodeManager(ctx context.Context, indexNodeCreator IndexNodeCreatorFunc) *IndexNodeManager {
return &IndexNodeManager{ return &IndexNodeManager{
nodeClients: make(map[UniqueID]types.IndexNodeClient), nodeClients: make(map[typeutil.UniqueID]types.IndexNodeClient),
stoppingNodes: make(map[UniqueID]struct{}), stoppingNodes: make(map[typeutil.UniqueID]struct{}),
lock: lock.RWMutex{}, lock: lock.RWMutex{},
ctx: ctx, ctx: ctx,
indexNodeCreator: indexNodeCreator, indexNodeCreator: indexNodeCreator,
} }
} }
// setClient sets IndexNode client to node manager. // SetClient sets IndexNode client to node manager.
func (nm *IndexNodeManager) setClient(nodeID UniqueID, client types.IndexNodeClient) { func (nm *IndexNodeManager) SetClient(nodeID typeutil.UniqueID, client types.IndexNodeClient) {
log.Debug("set IndexNode client", zap.Int64("nodeID", nodeID)) log.Debug("set IndexNode client", zap.Int64("nodeID", nodeID))
nm.lock.Lock() nm.lock.Lock()
defer nm.lock.Unlock() defer nm.lock.Unlock()
@ -73,7 +80,7 @@ func (nm *IndexNodeManager) setClient(nodeID UniqueID, client types.IndexNodeCli
} }
// RemoveNode removes the unused client of IndexNode. // RemoveNode removes the unused client of IndexNode.
func (nm *IndexNodeManager) RemoveNode(nodeID UniqueID) { func (nm *IndexNodeManager) RemoveNode(nodeID typeutil.UniqueID) {
log.Debug("remove IndexNode", zap.Int64("nodeID", nodeID)) log.Debug("remove IndexNode", zap.Int64("nodeID", nodeID))
nm.lock.Lock() nm.lock.Lock()
defer nm.lock.Unlock() defer nm.lock.Unlock()
@ -82,7 +89,7 @@ func (nm *IndexNodeManager) RemoveNode(nodeID UniqueID) {
metrics.IndexNodeNum.WithLabelValues().Set(float64(len(nm.nodeClients))) metrics.IndexNodeNum.WithLabelValues().Set(float64(len(nm.nodeClients)))
} }
func (nm *IndexNodeManager) StoppingNode(nodeID UniqueID) { func (nm *IndexNodeManager) StoppingNode(nodeID typeutil.UniqueID) {
log.Debug("IndexCoord", zap.Int64("Stopping node with ID", nodeID)) log.Debug("IndexCoord", zap.Int64("Stopping node with ID", nodeID))
nm.lock.Lock() nm.lock.Lock()
defer nm.lock.Unlock() defer nm.lock.Unlock()
@ -90,7 +97,7 @@ func (nm *IndexNodeManager) StoppingNode(nodeID UniqueID) {
} }
// AddNode adds the client of IndexNode. // AddNode adds the client of IndexNode.
func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error { func (nm *IndexNodeManager) AddNode(nodeID typeutil.UniqueID, address string) error {
log.Debug("add IndexNode", zap.Int64("nodeID", nodeID), zap.String("node address", address)) log.Debug("add IndexNode", zap.Int64("nodeID", nodeID), zap.String("node address", address))
var ( var (
nodeClient types.IndexNodeClient nodeClient types.IndexNodeClient
@ -103,18 +110,18 @@ func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error {
return err return err
} }
nm.setClient(nodeID, nodeClient) nm.SetClient(nodeID, nodeClient)
return nil return nil
} }
func (nm *IndexNodeManager) PickClient() (UniqueID, types.IndexNodeClient) { func (nm *IndexNodeManager) PickClient() (typeutil.UniqueID, types.IndexNodeClient) {
nm.lock.Lock() nm.lock.Lock()
defer nm.lock.Unlock() defer nm.lock.Unlock()
// Note: In order to quickly end other goroutines, an error is returned when the client is successfully selected // Note: In order to quickly end other goroutines, an error is returned when the client is successfully selected
ctx, cancel := context.WithCancel(nm.ctx) ctx, cancel := context.WithCancel(nm.ctx)
var ( var (
pickNodeID = UniqueID(0) pickNodeID = typeutil.UniqueID(0)
nodeMutex = sync.Mutex{} nodeMutex = sync.Mutex{}
wg = sync.WaitGroup{} wg = sync.WaitGroup{}
) )
@ -209,11 +216,11 @@ func (nm *IndexNodeManager) ClientSupportDisk() bool {
return false return false
} }
func (nm *IndexNodeManager) GetAllClients() map[UniqueID]types.IndexNodeClient { func (nm *IndexNodeManager) GetAllClients() map[typeutil.UniqueID]types.IndexNodeClient {
nm.lock.RLock() nm.lock.RLock()
defer nm.lock.RUnlock() defer nm.lock.RUnlock()
allClients := make(map[UniqueID]types.IndexNodeClient, len(nm.nodeClients)) allClients := make(map[typeutil.UniqueID]types.IndexNodeClient, len(nm.nodeClients))
for nodeID, client := range nm.nodeClients { for nodeID, client := range nm.nodeClients {
if _, ok := nm.stoppingNodes[nodeID]; !ok { if _, ok := nm.stoppingNodes[nodeID]; !ok {
allClients[nodeID] = client allClients[nodeID] = client
@ -223,7 +230,7 @@ func (nm *IndexNodeManager) GetAllClients() map[UniqueID]types.IndexNodeClient {
return allClients return allClients
} }
func (nm *IndexNodeManager) GetClientByID(nodeID UniqueID) (types.IndexNodeClient, bool) { func (nm *IndexNodeManager) GetClientByID(nodeID typeutil.UniqueID) (types.IndexNodeClient, bool) {
nm.lock.RLock() nm.lock.RLock()
defer nm.lock.RUnlock() defer nm.lock.RUnlock()
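A hedged sketch of driving the now-exported IndexNodeManager from outside the package; the address is illustrative and creator is any IndexNodeCreatorFunc:
    nm := session.NewNodeManager(context.Background(), creator)
    if err := nm.AddNode(1, "indexnode-1:21121"); err != nil {
        log.Warn("failed to add indexnode", zap.Error(err))
    }
    nodeID, cli := nm.PickClient()
    if cli != nil {
        log.Info("picked indexnode", zap.Int64("nodeID", nodeID))
    }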


@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package datacoord package session
import ( import (
"context" "context"
@ -29,9 +29,12 @@ import (
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/lock" "github.com/milvus-io/milvus/pkg/util/lock"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
func TestIndexNodeManager_AddNode(t *testing.T) { func TestIndexNodeManager_AddNode(t *testing.T) {
paramtable.Init()
nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc) nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc)
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
@ -46,6 +49,7 @@ func TestIndexNodeManager_AddNode(t *testing.T) {
} }
func TestIndexNodeManager_PickClient(t *testing.T) { func TestIndexNodeManager_PickClient(t *testing.T) {
paramtable.Init()
getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient { getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient {
ic := mocks.NewMockIndexNodeClient(t) ic := mocks.NewMockIndexNodeClient(t)
ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err) ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err)
@ -57,7 +61,7 @@ func TestIndexNodeManager_PickClient(t *testing.T) {
t.Run("multiple unavailable IndexNode", func(t *testing.T) { t.Run("multiple unavailable IndexNode", func(t *testing.T) {
nm := &IndexNodeManager{ nm := &IndexNodeManager{
ctx: context.TODO(), ctx: context.TODO(),
nodeClients: map[UniqueID]types.IndexNodeClient{ nodeClients: map[typeutil.UniqueID]types.IndexNodeClient{
1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{
Status: merr.Status(err), Status: merr.Status(err),
}, err), }, err),
@ -92,11 +96,12 @@ func TestIndexNodeManager_PickClient(t *testing.T) {
selectNodeID, client := nm.PickClient() selectNodeID, client := nm.PickClient()
assert.NotNil(t, client) assert.NotNil(t, client)
assert.Contains(t, []UniqueID{8, 9}, selectNodeID) assert.Contains(t, []typeutil.UniqueID{8, 9}, selectNodeID)
}) })
} }
func TestIndexNodeManager_ClientSupportDisk(t *testing.T) { func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
paramtable.Init()
getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient { getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient {
ic := mocks.NewMockIndexNodeClient(t) ic := mocks.NewMockIndexNodeClient(t)
ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err) ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err)
@ -109,7 +114,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
nm := &IndexNodeManager{ nm := &IndexNodeManager{
ctx: context.Background(), ctx: context.Background(),
lock: lock.RWMutex{}, lock: lock.RWMutex{},
nodeClients: map[UniqueID]types.IndexNodeClient{ nodeClients: map[typeutil.UniqueID]types.IndexNodeClient{
1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{
Status: merr.Success(), Status: merr.Success(),
TaskSlots: 1, TaskSlots: 1,
@ -127,7 +132,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
nm := &IndexNodeManager{ nm := &IndexNodeManager{
ctx: context.Background(), ctx: context.Background(),
lock: lock.RWMutex{}, lock: lock.RWMutex{},
nodeClients: map[UniqueID]types.IndexNodeClient{ nodeClients: map[typeutil.UniqueID]types.IndexNodeClient{
1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{
Status: merr.Success(), Status: merr.Success(),
TaskSlots: 1, TaskSlots: 1,
@ -145,7 +150,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
nm := &IndexNodeManager{ nm := &IndexNodeManager{
ctx: context.Background(), ctx: context.Background(),
lock: lock.RWMutex{}, lock: lock.RWMutex{},
nodeClients: map[UniqueID]types.IndexNodeClient{}, nodeClients: map[typeutil.UniqueID]types.IndexNodeClient{},
} }
support := nm.ClientSupportDisk() support := nm.ClientSupportDisk()
@ -156,7 +161,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
nm := &IndexNodeManager{ nm := &IndexNodeManager{
ctx: context.Background(), ctx: context.Background(),
lock: lock.RWMutex{}, lock: lock.RWMutex{},
nodeClients: map[UniqueID]types.IndexNodeClient{ nodeClients: map[typeutil.UniqueID]types.IndexNodeClient{
1: getMockedGetJobStatsClient(nil, err), 1: getMockedGetJobStatsClient(nil, err),
}, },
} }
@ -169,7 +174,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
nm := &IndexNodeManager{ nm := &IndexNodeManager{
ctx: context.Background(), ctx: context.Background(),
lock: lock.RWMutex{}, lock: lock.RWMutex{},
nodeClients: map[UniqueID]types.IndexNodeClient{ nodeClients: map[typeutil.UniqueID]types.IndexNodeClient{
1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{
Status: merr.Status(err), Status: merr.Status(err),
TaskSlots: 0, TaskSlots: 0,
@ -185,6 +190,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) {
} }
func TestNodeManager_StoppingNode(t *testing.T) { func TestNodeManager_StoppingNode(t *testing.T) {
paramtable.Init()
nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc) nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc)
err := nm.AddNode(1, "indexnode-1") err := nm.AddNode(1, "indexnode-1")
assert.NoError(t, err) assert.NoError(t, err)

File diff suppressed because it is too large


@ -1,6 +1,6 @@
// Code generated by mockery v2.32.4. DO NOT EDIT. // Code generated by mockery v2.32.4. DO NOT EDIT.
package datacoord package session
import ( import (
types "github.com/milvus-io/milvus/internal/types" types "github.com/milvus-io/milvus/internal/types"


@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package datacoord package session
import ( import (
"context" "context"
@ -40,18 +40,36 @@ type Session struct {
lock.Mutex lock.Mutex
info *NodeInfo info *NodeInfo
client types.DataNodeClient client types.DataNodeClient
clientCreator dataNodeCreatorFunc clientCreator DataNodeCreatorFunc
isDisposed bool isDisposed bool
} }
// NewSession creates a new session // NewSession creates a new session
func NewSession(info *NodeInfo, creator dataNodeCreatorFunc) *Session { func NewSession(info *NodeInfo, creator DataNodeCreatorFunc) *Session {
return &Session{ return &Session{
info: info, info: info,
clientCreator: creator, clientCreator: creator,
} }
} }
// NodeID returns node id for session.
// If internal info is nil, return -1 instead.
func (n *Session) NodeID() int64 {
if n.info == nil {
return -1
}
return n.info.NodeID
}
// Address returns address of session internal node info.
// If internal info is nil, return empty string instead.
func (n *Session) Address() string {
if n.info == nil {
return ""
}
return n.info.Address
}
// GetOrCreateClient gets or creates a new client for session // GetOrCreateClient gets or creates a new client for session
func (n *Session) GetOrCreateClient(ctx context.Context) (types.DataNodeClient, error) { func (n *Session) GetOrCreateClient(ctx context.Context) (types.DataNodeClient, error) {
n.Lock() n.Lock()
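The new accessors guard against a nil info; a hedged sketch, with creator standing in for any DataNodeCreatorFunc:
    sess := session.NewSession(&session.NodeInfo{NodeID: 7, Address: "dn-7:21124"}, creator)
    log.Info("session created",
        zap.Int64("nodeID", sess.NodeID()),    // -1 when info is nil
        zap.String("address", sess.Address()), // "" when info is nil
    )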


@ -23,6 +23,7 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus/internal/datacoord/session"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/logutil"
@ -35,10 +36,10 @@ type SyncSegmentsScheduler struct {
meta *meta meta *meta
channelManager ChannelManager channelManager ChannelManager
sessions SessionManager sessions session.DataNodeManager
} }
func newSyncSegmentsScheduler(m *meta, channelManager ChannelManager, sessions SessionManager) *SyncSegmentsScheduler { func newSyncSegmentsScheduler(m *meta, channelManager ChannelManager, sessions session.DataNodeManager) *SyncSegmentsScheduler {
return &SyncSegmentsScheduler{ return &SyncSegmentsScheduler{
quit: make(chan struct{}), quit: make(chan struct{}),
wg: sync.WaitGroup{}, wg: sync.WaitGroup{},


@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datacoord/session"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/util/lock" "github.com/milvus-io/milvus/pkg/util/lock"
) )
@ -321,7 +322,7 @@ func (s *SyncSegmentsSchedulerSuite) Test_newSyncSegmentsScheduler() {
cm := NewMockChannelManager(s.T()) cm := NewMockChannelManager(s.T())
cm.EXPECT().FindWatcher(mock.Anything).Return(100, nil) cm.EXPECT().FindWatcher(mock.Anything).Return(100, nil)
sm := NewMockSessionManager(s.T()) sm := session.NewMockDataNodeManager(s.T())
sm.EXPECT().SyncSegments(mock.Anything, mock.Anything).RunAndReturn(func(i int64, request *datapb.SyncSegmentsRequest) error { sm.EXPECT().SyncSegments(mock.Anything, mock.Anything).RunAndReturn(func(i int64, request *datapb.SyncSegmentsRequest) error {
for _, seg := range request.GetSegmentInfos() { for _, seg := range request.GetSegmentInfos() {
if seg.GetState() == commonpb.SegmentState_Flushed { if seg.GetState() == commonpb.SegmentState_Flushed {
@ -348,7 +349,7 @@ func (s *SyncSegmentsSchedulerSuite) Test_newSyncSegmentsScheduler() {
func (s *SyncSegmentsSchedulerSuite) Test_SyncSegmentsFail() { func (s *SyncSegmentsSchedulerSuite) Test_SyncSegmentsFail() {
cm := NewMockChannelManager(s.T()) cm := NewMockChannelManager(s.T())
sm := NewMockSessionManager(s.T()) sm := session.NewMockDataNodeManager(s.T())
sss := newSyncSegmentsScheduler(s.m, cm, sm) sss := newSyncSegmentsScheduler(s.m, cm, sm)
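Tests now consume the generated mock from the session package; a hedged sketch of stubbing it with a fixed return instead of the RunAndReturn inspection used above:
    sm := session.NewMockDataNodeManager(s.T())
    sm.EXPECT().SyncSegments(mock.Anything, mock.Anything).Return(nil)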


@ -24,6 +24,7 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/datacoord/session"
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
@ -53,7 +54,7 @@ type taskScheduler struct {
meta *meta meta *meta
policy buildIndexPolicy policy buildIndexPolicy
nodeManager WorkerManager nodeManager session.WorkerManager
chunkManager storage.ChunkManager chunkManager storage.ChunkManager
indexEngineVersionManager IndexEngineVersionManager indexEngineVersionManager IndexEngineVersionManager
handler Handler handler Handler
@ -61,7 +62,7 @@ type taskScheduler struct {
func newTaskScheduler( func newTaskScheduler(
ctx context.Context, ctx context.Context,
metaTable *meta, nodeManager WorkerManager, metaTable *meta, nodeManager session.WorkerManager,
chunkManager storage.ChunkManager, chunkManager storage.ChunkManager,
indexEngineVersionManager IndexEngineVersionManager, indexEngineVersionManager IndexEngineVersionManager,
handler Handler, handler Handler,


@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datacoord/session"
"github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/metastore"
catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks" catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
@ -803,7 +804,7 @@ func (s *taskSchedulerSuite) scheduler(handler Handler) {
}) })
in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil) in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil)
workerManager := NewMockWorkerManager(s.T()) workerManager := session.NewMockWorkerManager(s.T())
workerManager.EXPECT().PickClient().Return(s.nodeID, in) workerManager.EXPECT().PickClient().Return(s.nodeID, in)
workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true) workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true)
@ -931,7 +932,7 @@ func (s *taskSchedulerSuite) Test_analyzeTaskFailCase() {
ctx := context.Background() ctx := context.Background()
catalog := catalogmocks.NewDataCoordCatalog(s.T()) catalog := catalogmocks.NewDataCoordCatalog(s.T())
workerManager := NewMockWorkerManager(s.T()) workerManager := session.NewMockWorkerManager(s.T())
mt := createMeta(catalog, mt := createMeta(catalog,
&analyzeMeta{ &analyzeMeta{
@ -988,7 +989,7 @@ func (s *taskSchedulerSuite) Test_analyzeTaskFailCase() {
in := mocks.NewMockIndexNodeClient(s.T()) in := mocks.NewMockIndexNodeClient(s.T())
workerManager := NewMockWorkerManager(s.T()) workerManager := session.NewMockWorkerManager(s.T())
mt := createMeta(catalog, s.createAnalyzeMeta(catalog), &indexMeta{ mt := createMeta(catalog, s.createAnalyzeMeta(catalog), &indexMeta{
RWMutex: sync.RWMutex{}, RWMutex: sync.RWMutex{},
@ -1222,7 +1223,7 @@ func (s *taskSchedulerSuite) Test_indexTaskFailCase() {
catalog := catalogmocks.NewDataCoordCatalog(s.T()) catalog := catalogmocks.NewDataCoordCatalog(s.T())
in := mocks.NewMockIndexNodeClient(s.T()) in := mocks.NewMockIndexNodeClient(s.T())
workerManager := NewMockWorkerManager(s.T()) workerManager := session.NewMockWorkerManager(s.T())
mt := createMeta(catalog, mt := createMeta(catalog,
&analyzeMeta{ &analyzeMeta{
@ -1383,7 +1384,7 @@ func (s *taskSchedulerSuite) Test_indexTaskWithMvOptionalScalarField() {
catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil) catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil)
in := mocks.NewMockIndexNodeClient(s.T()) in := mocks.NewMockIndexNodeClient(s.T())
workerManager := NewMockWorkerManager(s.T()) workerManager := session.NewMockWorkerManager(s.T())
workerManager.EXPECT().PickClient().Return(s.nodeID, in) workerManager.EXPECT().PickClient().Return(s.nodeID, in)
workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true) workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true)