diff --git a/internal/datanode/broker/broker.go b/internal/datanode/broker/broker.go
index 234d62dd7b..404981a8a0 100644
--- a/internal/datanode/broker/broker.go
+++ b/internal/datanode/broker/broker.go
@@ -22,13 +22,15 @@ type coordBroker struct {
 	*dataCoordBroker
 }
 
-func NewCoordBroker(rc types.RootCoordClient, dc types.DataCoordClient) Broker {
+func NewCoordBroker(rc types.RootCoordClient, dc types.DataCoordClient, serverID int64) Broker {
 	return &coordBroker{
 		rootCoordBroker: &rootCoordBroker{
-			client: rc,
+			client:   rc,
+			serverID: serverID,
 		},
 		dataCoordBroker: &dataCoordBroker{
-			client: dc,
+			client:   dc,
+			serverID: serverID,
 		},
 	}
 }
diff --git a/internal/datanode/broker/datacoord.go b/internal/datanode/broker/datacoord.go
index e3e57cd839..0b6ad491d7 100644
--- a/internal/datanode/broker/datacoord.go
+++ b/internal/datanode/broker/datacoord.go
@@ -14,18 +14,18 @@ import (
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
-	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/tsoutil"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 type dataCoordBroker struct {
-	client types.DataCoordClient
+	client   types.DataCoordClient
+	serverID int64
 }
 
 func (dc *dataCoordBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error) {
 	req := &datapb.AssignSegmentIDRequest{
-		NodeID:            paramtable.GetNodeID(),
+		NodeID:            dc.serverID,
 		PeerRole:          typeutil.ProxyRole,
 		SegmentIDRequests: reqs,
 	}
@@ -48,7 +48,7 @@ func (dc *dataCoordBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.Dat
 	req := &datapb.ReportDataNodeTtMsgsRequest{
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(dc.serverID),
 		),
 		Msgs: msgs,
 	}
@@ -69,7 +69,7 @@ func (dc *dataCoordBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int6
 	infoResp, err := dc.client.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(dc.serverID),
 		),
 		SegmentIDs:       segmentIDs,
 		IncludeUnHealthy: true,
@@ -96,7 +96,7 @@ func (dc *dataCoordBroker) UpdateChannelCheckpoint(ctx context.Context, channelN
 	req := &datapb.UpdateChannelCheckpointRequest{
 		Base: commonpbutil.NewMsgBase(
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(dc.serverID),
 		),
 		VChannel: channelName,
 		Position: cp,
diff --git a/internal/datanode/broker/datacoord_test.go b/internal/datanode/broker/datacoord_test.go
index b4564aba38..ab773c8d2c 100644
--- a/internal/datanode/broker/datacoord_test.go
+++ b/internal/datanode/broker/datacoord_test.go
@@ -33,7 +33,7 @@ func (s *dataCoordSuite) SetupSuite() {
 
 func (s *dataCoordSuite) SetupTest() {
 	s.dc = mocks.NewMockDataCoordClient(s.T())
-	s.broker = NewCoordBroker(nil, s.dc)
+	s.broker = NewCoordBroker(nil, s.dc, 1)
 }
 
 func (s *dataCoordSuite) resetMock() {
diff --git a/internal/datanode/broker/rootcoord.go b/internal/datanode/broker/rootcoord.go
index 47129f8487..de41bd5865 100644
--- a/internal/datanode/broker/rootcoord.go
+++ b/internal/datanode/broker/rootcoord.go
@@ -13,12 +13,12 @@ import (
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
-	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 type rootCoordBroker struct {
-	client types.RootCoordClient
+	client   types.RootCoordClient
+	serverID int64
 }
 
 func (rc *rootCoordBroker) DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, timestamp typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error) {
@@ -29,7 +29,7 @@ func (rc *rootCoordBroker) DescribeCollection(ctx context.Context, collectionID
 	req := &milvuspb.DescribeCollectionRequest{
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(rc.serverID),
 		),
 		// please do not specify the collection name alone after database feature.
 		CollectionID: collectionID,
@@ -89,7 +89,7 @@ func (rc *rootCoordBroker) AllocTimestamp(ctx context.Context, num uint32) (uint
 	req := &rootcoordpb.AllocTimestampRequest{
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(rc.serverID),
 		),
 		Count: num,
 	}
diff --git a/internal/datanode/broker/rootcoord_test.go b/internal/datanode/broker/rootcoord_test.go
index e08279fe2f..383fa6f290 100644
--- a/internal/datanode/broker/rootcoord_test.go
+++ b/internal/datanode/broker/rootcoord_test.go
@@ -33,7 +33,7 @@ func (s *rootCoordSuite) SetupSuite() {
 
 func (s *rootCoordSuite) SetupTest() {
 	s.rc = mocks.NewMockRootCoordClient(s.T())
-	s.broker = NewCoordBroker(s.rc, nil)
+	s.broker = NewCoordBroker(s.rc, nil, 1)
 }
 
 func (s *rootCoordSuite) resetMock() {
diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go
index 859e190b34..fe212e25f5 100644
--- a/internal/datanode/data_node.go
+++ b/internal/datanode/data_node.go
@@ -26,12 +26,12 @@ import (
 	"math/rand"
 	"os"
 	"sync"
-	"sync/atomic"
 	"syscall"
 	"time"
 
 	"github.com/cockroachdb/errors"
 	clientv3 "go.etcd.io/etcd/client/v3"
+	"go.uber.org/atomic"
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@@ -84,6 +84,7 @@ var Params *paramtable.ComponentParam = paramtable.Get()
 // `segmentCache` stores all flushing and flushed segments.
 type DataNode struct {
 	ctx       context.Context
+	serverID  int64
 	cancel    context.CancelFunc
 	Role      string
 	stateCode atomic.Value // commonpb.StateCode_Initializing
@@ -127,7 +128,7 @@ type DataNode struct {
 }
 
 // NewDataNode will return a DataNode with abnormal state.
-func NewDataNode(ctx context.Context, factory dependency.Factory) *DataNode {
+func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64) *DataNode {
 	rand.Seed(time.Now().UnixNano())
 	ctx2, cancel2 := context.WithCancel(ctx)
 	node := &DataNode{
@@ -138,6 +139,7 @@ func NewDataNode(ctx context.Context, factory dependency.Factory) *DataNode {
 		rootCoord: nil,
 		dataCoord: nil,
 		factory:   factory,
+		serverID:  serverID,
 
 		segmentCache:       newCache(),
 		compactionExecutor: newCompactionExecutor(),
@@ -189,9 +191,10 @@ func (node *DataNode) SetDataCoordClient(ds types.DataCoordClient) error {
 
 // Register register datanode to etcd
 func (node *DataNode) Register() error {
+	log.Debug("node begin to register to etcd", zap.String("serverName", node.session.ServerName), zap.Int64("ServerID", node.session.ServerID))
 	node.session.Register()
-	metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.DataNodeRole).Inc()
+	metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.DataNodeRole).Inc()
 	log.Info("DataNode Register Finished")
 
 	// Start liveness check
 	node.session.LivenessCheck(node.ctx, func() {
@@ -199,7 +202,7 @@ func (node *DataNode) Register() error {
 		if err := node.Stop(); err != nil {
 			log.Fatal("failed to stop server", zap.Error(err))
 		}
-		metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.DataNodeRole).Dec()
+		metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.DataNodeRole).Dec()
 		// manually send signal to starter goroutine
 		if node.session.TriggerKill {
 			if p, err := os.FindProcess(os.Getpid()); err == nil {
@@ -232,6 +235,10 @@ func (node *DataNode) initRateCollector() error {
 	return nil
 }
 
+func (node *DataNode) GetNodeID() int64 {
+	return node.serverID
+}
+
 func (node *DataNode) Init() error {
 	var initError error
 	node.initOnce.Do(func() {
@@ -244,24 +251,24 @@ func (node *DataNode) Init() error {
 			return
 		}
 
-		node.broker = broker.NewCoordBroker(node.rootCoord, node.dataCoord)
+		node.broker = broker.NewCoordBroker(node.rootCoord, node.dataCoord, node.GetNodeID())
 
 		err := node.initRateCollector()
 		if err != nil {
-			log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", paramtable.GetNodeID()), zap.Error(err))
+			log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", node.GetNodeID()), zap.Error(err))
 			initError = err
 			return
 		}
-		log.Info("DataNode server init rateCollector done", zap.Int64("node ID", paramtable.GetNodeID()))
+		log.Info("DataNode server init rateCollector done", zap.Int64("node ID", node.GetNodeID()))
 
-		node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.DataNodeRole, paramtable.GetNodeID())
-		log.Info("DataNode server init dispatcher client done", zap.Int64("node ID", paramtable.GetNodeID()))
+		node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.DataNodeRole, node.GetNodeID())
+		log.Info("DataNode server init dispatcher client done", zap.Int64("node ID", node.GetNodeID()))
 
-		alloc, err := allocator.New(context.Background(), node.rootCoord, paramtable.GetNodeID())
+		alloc, err := allocator.New(context.Background(), node.rootCoord, node.GetNodeID())
 		if err != nil {
 			log.Error("failed to create id allocator",
 				zap.Error(err),
-				zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", paramtable.GetNodeID()))
+				zap.String("role", typeutil.DataNodeRole), zap.Int64("DataNode ID", node.GetNodeID()))
 			initError = err
 			return
 		}
@@ -292,7 +299,7 @@ func (node *DataNode) Init() error {
 		node.channelCheckpointUpdater = newChannelCheckpointUpdater(node)
 
-		log.Info("init datanode done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", node.address))
+		log.Info("init datanode done", zap.Int64("nodeID", node.GetNodeID()), zap.String("Address", node.address))
 	})
 	return initError
 }
@@ -354,7 +361,7 @@ func (node *DataNode) Start() error {
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO),
 			commonpbutil.WithMsgID(0),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(node.GetNodeID()),
 		),
 		Count: 1,
 	})
diff --git a/internal/datanode/data_sync_service.go b/internal/datanode/data_sync_service.go
index 968ddb877a..0a27a2f5a5 100644
--- a/internal/datanode/data_sync_service.go
+++ b/internal/datanode/data_sync_service.go
@@ -40,7 +40,6 @@ import (
 	"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/conc"
-	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
@@ -351,7 +350,7 @@ func getServiceWithChannel(initCtx context.Context, node *DataNode, info *datapb
 		resendTTCh = make(chan resendTTMsg, 100)
 	)
 
-	node.writeBufferManager.Register(channelName, metacache, storageV2Cache, writebuffer.WithMetaWriter(syncmgr.BrokerMetaWriter(node.broker)), writebuffer.WithIDAllocator(node.allocator))
+	node.writeBufferManager.Register(channelName, metacache, storageV2Cache, writebuffer.WithMetaWriter(syncmgr.BrokerMetaWriter(node.broker, config.serverID)), writebuffer.WithIDAllocator(node.allocator))
 	ctx, cancel := context.WithCancel(node.ctx)
 	ds := &dataSyncService{
 		ctx:    ctx,
@@ -410,7 +409,7 @@ func getServiceWithChannel(initCtx context.Context, node *DataNode, info *datapb
 	}
 
 	m.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()})
-	metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
+	metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(config.serverID)).Inc()
 	log.Info("datanode AsProducer", zap.String("TimeTickChannelName", Params.CommonCfg.DataCoordTimeTick.GetValue()))
 
 	m.EnableProduce(true)
diff --git a/internal/datanode/event_manager.go b/internal/datanode/event_manager.go
index 464a0c875a..93479774e6 100644
--- a/internal/datanode/event_manager.go
+++ b/internal/datanode/event_manager.go
@@ -33,7 +33,6 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/util/logutil"
-	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
 const retryWatchInterval = 20 * time.Second
@@ -93,7 +92,7 @@ func (node *DataNode) StartWatchChannels(ctx context.Context) {
 // serves the corner case for etcd connection lost and missing some events
 func (node *DataNode) checkWatchedList() error {
 	// REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name}
-	prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID()))
+	prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.serverID))
 	keys, values, err := node.watchKv.LoadWithPrefix(prefix)
 	if err != nil {
 		return err
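
data_node.go introduces the instance-scoped accessor that the rest of the patch leans on: wherever a log field or metric label previously called paramtable.GetNodeID(), it now goes through the node itself. A hedged sketch of the substitution (the helper name is illustrative; the patch inlines the expression):

    // metricNodeLabel is a hypothetical helper showing the pattern applied
    // throughout: fmt.Sprint(node.GetNodeID()) replaces
    // fmt.Sprint(paramtable.GetNodeID()), so the label stays correct even
    // when several nodes run in one process.
    func metricNodeLabel(node *DataNode) string {
        return fmt.Sprint(node.GetNodeID()) // node.serverID, set in NewDataNode
    }
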
diff --git a/internal/datanode/flow_graph_manager.go b/internal/datanode/flow_graph_manager.go
index b1832f3884..7efc1d3eee 100644
--- a/internal/datanode/flow_graph_manager.go
+++ b/internal/datanode/flow_graph_manager.go
@@ -62,6 +62,7 @@ func (fm *fgManagerImpl) AddFlowgraph(ds *dataSyncService) {
 
 func (fm *fgManagerImpl) AddandStartWithEtcdTickler(dn *DataNode, vchan *datapb.VchannelInfo, schema *schemapb.CollectionSchema, tickler *etcdTickler) error {
 	log := log.With(zap.String("channel", vchan.GetChannelName()))
+	log.Warn(fmt.Sprintf("debug AddandStartWithEtcdTickler %d", dn.GetNodeID()))
 	if fm.flowgraphs.Contain(vchan.GetChannelName()) {
 		log.Warn("try to add an existed DataSyncService")
 		return nil
diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go
index 854c593149..34474e5439 100644
--- a/internal/datanode/mock_test.go
+++ b/internal/datanode/mock_test.go
@@ -83,7 +83,7 @@ var segID2SegInfo = map[int64]*datapb.SegmentInfo{
 
 func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode {
 	factory := dependency.NewDefaultFactory(true)
-	node := NewDataNode(ctx, factory)
+	node := NewDataNode(ctx, factory, 1)
 	node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}})
 
 	node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID())
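
The test helper keeps the explicit serverID consistent with the session it installs; a sketch of the invariant callers should preserve (values as in newIDLEDataNodeMock above):

    node := NewDataNode(ctx, factory, 1)
    node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}})
    // NewDataNode's serverID and the session's ServerID must agree; otherwise
    // the target-ID check in FlushSegments rejects requests addressed to this node.
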
diff --git a/internal/datanode/services.go b/internal/datanode/services.go
index 44316023ea..7432be8cc0 100644
--- a/internal/datanode/services.go
+++ b/internal/datanode/services.go
@@ -94,13 +94,13 @@ func (node *DataNode) GetComponentStates(ctx context.Context, req *milvuspb.GetC
 // So if receiving calls to flush segment A, DataNode should guarantee the segment to be flushed.
 func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) {
 	metrics.DataNodeFlushReqCounter.WithLabelValues(
-		fmt.Sprint(paramtable.GetNodeID()),
+		fmt.Sprint(node.GetNodeID()),
 		metrics.TotalLabel).Inc()
 
 	log := log.Ctx(ctx)
 
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.FlushSegments failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.FlushSegments failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 
 		return merr.Status(err), nil
 	}
@@ -111,6 +111,7 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen
 			zap.Int64("serverID", serverID),
 		)
 
+		log.Info(fmt.Sprintf("debug by FlushSegments:%v:%v", serverID, req.GetBase().GetTargetID()))
 		return merr.Status(merr.WrapErrNodeNotMatch(req.GetBase().GetTargetID(), serverID)), nil
 	}
 
@@ -133,7 +134,7 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen
 	log.Info("sending segments to WriteBuffer Manager")
 
 	metrics.DataNodeFlushReqCounter.WithLabelValues(
-		fmt.Sprint(paramtable.GetNodeID()),
+		fmt.Sprint(node.GetNodeID()),
 		metrics.SuccessLabel).Inc()
 	return merr.Success(), nil
 }
@@ -166,7 +167,7 @@ func (node *DataNode) GetStatisticsChannel(ctx context.Context, req *internalpb.
 func (node *DataNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
 	log.Debug("DataNode.ShowConfigurations", zap.String("pattern", req.Pattern))
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.ShowConfigurations failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.ShowConfigurations failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 
 		return &internalpb.ShowConfigurationsResponse{
 			Status: merr.Status(err),
@@ -191,7 +192,7 @@ func (node *DataNode) ShowConfigurations(ctx context.Context, req *internalpb.Sh
 // GetMetrics return datanode metrics
 func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.GetMetrics failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.GetMetrics failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 
 		return &milvuspb.GetMetricsResponse{
 			Status: merr.Status(err),
@@ -201,7 +202,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
 	metricType, err := metricsinfo.ParseMetricType(req.Request)
 	if err != nil {
 		log.Warn("DataNode.GetMetrics failed to parse metric type",
-			zap.Int64("nodeID", paramtable.GetNodeID()),
+			zap.Int64("nodeID", node.GetNodeID()),
 			zap.String("req", req.Request),
 			zap.Error(err))
 
@@ -213,7 +214,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
 	if metricType == metricsinfo.SystemInfoMetrics {
 		systemInfoMetrics, err := node.getSystemInfoMetrics(ctx, req)
 		if err != nil {
-			log.Warn("DataNode GetMetrics failed", zap.Int64("nodeID", paramtable.GetNodeID()), zap.Error(err))
+			log.Warn("DataNode GetMetrics failed", zap.Int64("nodeID", node.GetNodeID()), zap.Error(err))
 			return &milvuspb.GetMetricsResponse{
 				Status: merr.Status(err),
 			}, nil
@@ -223,7 +224,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
 	}
 
 	log.RatedWarn(60, "DataNode.GetMetrics failed, request metric type is not implemented yet",
-		zap.Int64("nodeID", paramtable.GetNodeID()),
+		zap.Int64("nodeID", node.GetNodeID()),
 		zap.String("req", req.Request),
 		zap.String("metric_type", metricType))
 
@@ -237,7 +238,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
 func (node *DataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) {
 	log := log.Ctx(ctx).With(zap.Int64("planID", req.GetPlanID()))
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.Compaction failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.Compaction failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 		return merr.Status(err), nil
 	}
 
@@ -307,7 +308,7 @@ func (node *DataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan
 // return status of all compaction plans
 func (node *DataNode) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) {
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.GetCompactionState failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.GetCompactionState failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 		return &datapb.CompactionStateResponse{
 			Status: merr.Status(err),
 		}, nil
@@ -330,7 +331,7 @@ func (node *DataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegments
 	)
 
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.SyncSegments failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.SyncSegments failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 		return merr.Status(err), nil
 	}
 
@@ -366,7 +367,7 @@ func (node *DataNode) NotifyChannelOperation(ctx context.Context, req *datapb.Ch
 		zap.Int("operation count", len(req.GetInfos())))
 
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.NotifyChannelOperation failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.NotifyChannelOperation failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 		return merr.Status(err), nil
 	}
 
@@ -389,7 +390,7 @@ func (node *DataNode) CheckChannelOperationProgress(ctx context.Context, req *da
 	log.Info("DataNode receives CheckChannelOperationProgress")
 
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
-		log.Warn("DataNode.CheckChannelOperationProgress failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err))
+		log.Warn("DataNode.CheckChannelOperationProgress failed", zap.Int64("nodeId", node.GetNodeID()), zap.Error(err))
 		return &datapb.ChannelOperationProgressResponse{
 			Status: merr.Status(err),
 		}, nil
@@ -406,7 +407,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 		zap.String("database name", req.GetImportTask().GetDatabaseName()),
 		zap.Strings("channel names", req.GetImportTask().GetChannelNames()),
 		zap.Int64s("working dataNodes", req.WorkingNodes),
-		zap.Int64("node ID", paramtable.GetNodeID()),
+		zap.Int64("node ID", node.GetNodeID()),
 	}
 	log.Info("DataNode receive import request", logFields...)
 	defer func() {
@@ -416,7 +417,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 	importResult := &rootcoordpb.ImportResult{
 		Status:     merr.Success(),
 		TaskId:     req.GetImportTask().TaskId,
-		DatanodeId: paramtable.GetNodeID(),
+		DatanodeId: node.GetNodeID(),
 		State:      commonpb.ImportState_ImportStarted,
 		Segments:   make([]int64, 0),
 		AutoIds:    make([]int64, 0),
@@ -513,7 +514,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 }
 
 func (node *DataNode) FlushChannels(ctx context.Context, req *datapb.FlushChannelsRequest) (*commonpb.Status, error) {
-	log := log.Ctx(ctx).With(zap.Int64("nodeId", paramtable.GetNodeID()),
+	log := log.Ctx(ctx).With(zap.Int64("nodeId", node.GetNodeID()),
 		zap.Time("flushTs", tsoutil.PhysicalTime(req.GetFlushTs())),
 		zap.Strings("channels", req.GetChannels()))
 
@@ -557,7 +558,7 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor
 		return nil
 	}, retry.Attempts(getFlowGraphServiceAttempts))
 	if err != nil {
-		logFields = append(logFields, zap.Int64("node ID", paramtable.GetNodeID()))
+		logFields = append(logFields, zap.Int64("node ID", node.GetNodeID()))
 		log.Error("channel not found in current DataNode", logFields...)
 		return &datapb.AddImportSegmentResponse{
 			Status: &commonpb.Status{
@@ -660,7 +661,7 @@ func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil
 		importResult := &rootcoordpb.ImportResult{
 			Status:     merr.Success(),
 			TaskId:     req.GetImportTask().TaskId,
-			DatanodeId: paramtable.GetNodeID(),
+			DatanodeId: node.GetNodeID(),
 			State:      commonpb.ImportState_ImportStarted,
 			Segments:   []int64{segmentID},
 			AutoIds:    make([]int64, 0),
@@ -732,7 +733,7 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo
 		err := node.broker.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{
 			Base: commonpbutil.NewMsgBase(
 				commonpbutil.WithTimeStamp(ts), // Pass current timestamp downstream.
-				commonpbutil.WithSourceID(paramtable.GetNodeID()),
+				commonpbutil.WithSourceID(node.GetNodeID()),
 			),
 			SegmentId:   segmentID,
 			ChannelName: targetChName,
@@ -742,7 +743,7 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo
 			SaveBinlogPathReq: &datapb.SaveBinlogPathsRequest{
 				Base: commonpbutil.NewMsgBase(
 					commonpbutil.WithTimeStamp(ts),
-					commonpbutil.WithSourceID(paramtable.GetNodeID()),
+					commonpbutil.WithSourceID(node.GetNodeID()),
 				),
 				SegmentID:    segmentID,
 				CollectionID: req.GetImportTask().GetCollectionId(),
diff --git a/internal/datanode/stats_updater.go b/internal/datanode/stats_updater.go
index cc44fff208..2f25c88136 100644
--- a/internal/datanode/stats_updater.go
+++ b/internal/datanode/stats_updater.go
@@ -12,7 +12,6 @@ import (
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
-	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/tsoutil"
 )
 
@@ -63,7 +62,7 @@ func (u *mqStatsUpdater) send(ts Timestamp, segmentIDs []int64) error {
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt),
 			commonpbutil.WithTimeStamp(ts),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(u.config.serverID),
 		),
 		ChannelName: u.config.vChannelName,
 		Timestamp:   ts,
diff --git a/internal/datanode/syncmgr/meta_writer.go b/internal/datanode/syncmgr/meta_writer.go
index 4d506be2e5..0e82f6cfe6 100644
--- a/internal/datanode/syncmgr/meta_writer.go
+++ b/internal/datanode/syncmgr/meta_writer.go
@@ -13,7 +13,6 @@ import (
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
-	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/retry"
 )
 
@@ -25,14 +24,16 @@ type MetaWriter interface {
 }
 
 type brokerMetaWriter struct {
-	broker broker.Broker
-	opts   []retry.Option
+	broker   broker.Broker
+	opts     []retry.Option
+	serverID int64
 }
 
-func BrokerMetaWriter(broker broker.Broker, opts ...retry.Option) MetaWriter {
+func BrokerMetaWriter(broker broker.Broker, serverID int64, opts ...retry.Option) MetaWriter {
 	return &brokerMetaWriter{
-		broker: broker,
-		opts:   opts,
+		broker:   broker,
+		serverID: serverID,
+		opts:     opts,
 	}
 }
 
@@ -82,7 +83,7 @@ func (b *brokerMetaWriter) UpdateSync(pack *SyncTask) error {
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(0),
 			commonpbutil.WithMsgID(0),
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(b.serverID),
 		),
 		SegmentID:    pack.segmentID,
 		CollectionID: pack.collectionID,
@@ -165,7 +166,7 @@ func (b *brokerMetaWriter) UpdateSyncV2(pack *SyncTaskV2) error {
 	req := &datapb.SaveBinlogPathsRequest{
 		Base: commonpbutil.NewMsgBase(
-			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			commonpbutil.WithSourceID(b.serverID),
 		),
 		SegmentID:    pack.segmentID,
 		CollectionID: pack.collectionID,
@@ -214,7 +215,7 @@ func (b *brokerMetaWriter) DropChannel(channelName string) error {
 	err := retry.Do(context.Background(), func() error {
 		status, err := b.broker.DropVirtualChannel(context.Background(), &datapb.DropVirtualChannelRequest{
 			Base: commonpbutil.NewMsgBase(
-				commonpbutil.WithSourceID(paramtable.GetNodeID()),
+				commonpbutil.WithSourceID(b.serverID),
 			),
 			ChannelName: channelName,
 		})
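
BrokerMetaWriter's new signature means every construction site must thread the ID through. The updated call sites in this patch take the form:

    // As registered in data_sync_service.go above; retry options stay optional.
    writer := syncmgr.BrokerMetaWriter(node.broker, config.serverID)
    // or, with an explicit retry policy as the tests below do:
    writer = syncmgr.BrokerMetaWriter(node.broker, config.serverID, retry.Attempts(1))
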
diff --git a/internal/datanode/syncmgr/meta_writer_test.go b/internal/datanode/syncmgr/meta_writer_test.go
index 23d54d9be4..fc1d921b70 100644
--- a/internal/datanode/syncmgr/meta_writer_test.go
+++ b/internal/datanode/syncmgr/meta_writer_test.go
@@ -30,7 +30,7 @@ func (s *MetaWriterSuite) SetupSuite() {
 func (s *MetaWriterSuite) SetupTest() {
 	s.broker = broker.NewMockBroker(s.T())
 	s.metacache = metacache.NewMockMetaCache(s.T())
-	s.writer = BrokerMetaWriter(s.broker, retry.Attempts(1))
+	s.writer = BrokerMetaWriter(s.broker, 1, retry.Attempts(1))
 }
 
 func (s *MetaWriterSuite) TestNormalSave() {
diff --git a/internal/datanode/syncmgr/sync_manager_test.go b/internal/datanode/syncmgr/sync_manager_test.go
index 5ddb5e332a..ec416fca72 100644
--- a/internal/datanode/syncmgr/sync_manager_test.go
+++ b/internal/datanode/syncmgr/sync_manager_test.go
@@ -160,7 +160,7 @@ func (s *SyncManagerSuite) TestSubmit() {
 	manager, err := NewSyncManager(s.chunkManager, s.allocator)
 	s.NoError(err)
 	task := s.getSuiteSyncTask()
-	task.WithMetaWriter(BrokerMetaWriter(s.broker))
+	task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 	task.WithTimeRange(50, 100)
 	task.WithCheckpoint(&msgpb.MsgPosition{
 		ChannelName: s.channelName,
@@ -192,7 +192,7 @@ func (s *SyncManagerSuite) TestCompacted() {
 	manager, err := NewSyncManager(s.chunkManager, s.allocator)
 	s.NoError(err)
 	task := s.getSuiteSyncTask()
-	task.WithMetaWriter(BrokerMetaWriter(s.broker))
+	task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 	task.WithTimeRange(50, 100)
 	task.WithCheckpoint(&msgpb.MsgPosition{
 		ChannelName: s.channelName,
@@ -235,7 +235,7 @@ func (s *SyncManagerSuite) TestBlock() {
 
 	go func() {
 		task := s.getSuiteSyncTask()
-		task.WithMetaWriter(BrokerMetaWriter(s.broker))
+		task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 		task.WithTimeRange(50, 100)
 		task.WithCheckpoint(&msgpb.MsgPosition{
 			ChannelName: s.channelName,
diff --git a/internal/datanode/syncmgr/task_test.go b/internal/datanode/syncmgr/task_test.go
index 13b8318734..d03aa278d7 100644
--- a/internal/datanode/syncmgr/task_test.go
+++ b/internal/datanode/syncmgr/task_test.go
@@ -190,7 +190,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
 
 	s.Run("without_data", func() {
 		task := s.getSuiteSyncTask()
-		task.WithMetaWriter(BrokerMetaWriter(s.broker))
+		task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 		task.WithTimeRange(50, 100)
 		task.WithCheckpoint(&msgpb.MsgPosition{
 			ChannelName: s.channelName,
@@ -205,7 +205,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
 	s.Run("with_insert_delete_cp", func() {
 		task := s.getSuiteSyncTask()
 		task.WithTimeRange(50, 100)
-		task.WithMetaWriter(BrokerMetaWriter(s.broker))
+		task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 		task.WithCheckpoint(&msgpb.MsgPosition{
 			ChannelName: s.channelName,
 			MsgID:       []byte{1, 2, 3, 4},
@@ -223,7 +223,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
 	s.Run("with_statslog", func() {
 		task := s.getSuiteSyncTask()
 		task.WithTimeRange(50, 100)
-		task.WithMetaWriter(BrokerMetaWriter(s.broker))
+		task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 		task.WithCheckpoint(&msgpb.MsgPosition{
 			ChannelName: s.channelName,
 			MsgID:       []byte{1, 2, 3, 4},
@@ -246,7 +246,7 @@ func (s *SyncTaskSuite) TestRunNormal() {
 	s.Run("with_delta_data", func() {
 		task := s.getSuiteSyncTask()
 		task.WithTimeRange(50, 100)
-		task.WithMetaWriter(BrokerMetaWriter(s.broker))
+		task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 		task.WithCheckpoint(&msgpb.MsgPosition{
 			ChannelName: s.channelName,
 			MsgID:       []byte{1, 2, 3, 4},
@@ -278,7 +278,7 @@ func (s *SyncTaskSuite) TestRunL0Segment() {
 		Value: []byte("test_data"),
 	}
 	task.WithTimeRange(50, 100)
-	task.WithMetaWriter(BrokerMetaWriter(s.broker))
+	task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 	task.WithCheckpoint(&msgpb.MsgPosition{
 		ChannelName: s.channelName,
 		MsgID:       []byte{1, 2, 3, 4},
@@ -315,7 +315,7 @@ func (s *SyncTaskSuite) TestCompactToNull() {
 	s.metacache.EXPECT().GetSegmentByID(s.segmentID).Return(seg, true)
 
 	task := s.getSuiteSyncTask()
-	task.WithMetaWriter(BrokerMetaWriter(s.broker))
+	task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 	task.WithTimeRange(50, 100)
 	task.WithCheckpoint(&msgpb.MsgPosition{
 		ChannelName: s.channelName,
@@ -379,7 +379,7 @@ func (s *SyncTaskSuite) TestRunError() {
 	s.broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(errors.New("mocked"))
 
 	task := s.getSuiteSyncTask()
-	task.WithMetaWriter(BrokerMetaWriter(s.broker, retry.Attempts(1)))
+	task.WithMetaWriter(BrokerMetaWriter(s.broker, 1, retry.Attempts(1)))
 	task.WithTimeRange(50, 100)
 	task.WithCheckpoint(&msgpb.MsgPosition{
 		ChannelName: s.channelName,
diff --git a/internal/datanode/syncmgr/taskv2_test.go b/internal/datanode/syncmgr/taskv2_test.go
index ea29dba8d9..9367689ed1 100644
--- a/internal/datanode/syncmgr/taskv2_test.go
+++ b/internal/datanode/syncmgr/taskv2_test.go
@@ -221,7 +221,7 @@ func (s *SyncTaskSuiteV2) TestRunNormal() {
 
 	s.Run("without_insert_delete", func() {
 		task := s.getSuiteSyncTask()
-		task.WithMetaWriter(BrokerMetaWriter(s.broker))
+		task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 		task.WithTimeRange(50, 100)
 		task.WithCheckpoint(&msgpb.MsgPosition{
 			ChannelName: s.channelName,
@@ -236,7 +236,7 @@ func (s *SyncTaskSuiteV2) TestRunNormal() {
 	s.Run("with_insert_delete_cp", func() {
 		task := s.getSuiteSyncTask()
 		task.WithTimeRange(50, 100)
-		task.WithMetaWriter(BrokerMetaWriter(s.broker))
+		task.WithMetaWriter(BrokerMetaWriter(s.broker, 1))
 		task.WithCheckpoint(&msgpb.MsgPosition{
 			ChannelName: s.channelName,
 			MsgID:       []byte{1, 2, 3, 4},
diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go
index f94dbe3e86..7458925a45 100644
--- a/internal/distributed/datanode/client/client.go
+++ b/internal/distributed/datanode/client/client.go
@@ -43,10 +43,11 @@ type Client struct {
 	grpcClient grpcclient.GrpcClient[datapb.DataNodeClient]
 	sess       *sessionutil.Session
 	addr       string
+	serverID   int64
 }
 
 // NewClient creates a client for DataNode.
-func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) {
+func NewClient(ctx context.Context, addr string, serverID int64) (*Client, error) {
 	if addr == "" {
 		return nil, fmt.Errorf("address is empty")
 	}
@@ -61,12 +62,13 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error)
 		addr:       addr,
 		grpcClient: grpcclient.NewClientBase[datapb.DataNodeClient](config, "milvus.proto.data.DataNode"),
 		sess:       sess,
+		serverID:   serverID,
 	}
 	// node shall specify node id
-	client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.DataNodeRole, nodeID))
+	client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.DataNodeRole, serverID))
 	client.grpcClient.SetGetAddrFunc(client.getAddr)
 	client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
-	client.grpcClient.SetNodeID(nodeID)
+	client.grpcClient.SetNodeID(serverID)
 	client.grpcClient.SetSession(sess)
 
 	return client, nil
@@ -120,7 +122,7 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) {
 		return client.WatchDmChannels(ctx, req)
 	})
@@ -142,7 +144,7 @@ func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsReq
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) {
 		return client.FlushSegments(ctx, req)
 	})
@@ -153,7 +155,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*internalpb.ShowConfigurationsResponse, error) {
 		return client.ShowConfigurations(ctx, req)
 	})
@@ -164,7 +166,7 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*milvuspb.GetMetricsResponse, error) {
 		return client.GetMetrics(ctx, req)
 	})
@@ -181,7 +183,7 @@ func (c *Client) GetCompactionState(ctx context.Context, req *datapb.CompactionS
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.CompactionStateResponse, error) {
 		return client.GetCompactionState(ctx, req)
 	})
@@ -192,7 +194,7 @@ func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest, opts
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) {
 		return client.Import(ctx, req)
 	})
@@ -202,7 +204,7 @@ func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegme
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.ResendSegmentStatsResponse, error) {
 		return client.ResendSegmentStats(ctx, req)
 	})
@@ -213,7 +215,7 @@ func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegm
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.serverID))
 	return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.AddImportSegmentResponse, error) {
 		return client.AddImportSegment(ctx, req)
 	})
diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go
index 9aa1ce2535..e09435acaa 100644
--- a/internal/distributed/datanode/service.go
+++ b/internal/distributed/datanode/service.go
@@ -90,7 +90,8 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
 		},
 	}
 
-	s.datanode = dn.NewDataNode(s.ctx, s.factory)
+	s.serverID.Store(paramtable.GetNodeID())
+	s.datanode = dn.NewDataNode(s.ctx, s.factory, s.serverID.Load())
 
 	return s, nil
 }
@@ -246,6 +247,7 @@ func (s *Server) init() error {
 	s.SetEtcdClient(s.etcdCli)
 	s.datanode.SetAddress(Params.GetAddress())
 	log.Info("DataNode address", zap.String("address", Params.IP+":"+strconv.Itoa(Params.Port.GetAsInt())))
+	log.Info("DataNode serverID", zap.Int64("serverID", s.serverID.Load()))
 
 	err = s.startGrpc()
 	if err != nil {
diff --git a/internal/distributed/datanode/service_test.go b/internal/distributed/datanode/service_test.go
index 88110e4d19..124bde6b25 100644
--- a/internal/distributed/datanode/service_test.go
+++ b/internal/distributed/datanode/service_test.go
@@ -91,6 +91,10 @@ func (m *MockDataNode) GetAddress() string {
 	return ""
 }
 
+func (m *MockDataNode) GetNodeID() int64 {
+	return 2
+}
+
 func (m *MockDataNode) SetRootCoordClient(rc types.RootCoordClient) error {
 	return m.err
 }
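
On the client side the ID is captured once in NewClient and reused to fill each request's MsgBase. A usage sketch (the package alias and addr variable are illustrative):

    cli, err := dnclient.NewClient(ctx, addr, serverID)
    if err != nil {
        return err
    }
    // Every wrapped RPC now calls
    // commonpbutil.FillMsgBaseFromClient(c.serverID) on the cloned request,
    // instead of reading the caller's process-global paramtable ID.
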
diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go
index 863df81a90..c1f7af3b5a 100644
--- a/internal/distributed/querynode/client/client.go
+++ b/internal/distributed/querynode/client/client.go
@@ -41,6 +41,7 @@ type Client struct {
 	grpcClient grpcclient.GrpcClient[querypb.QueryNodeClient]
 	addr       string
 	sess       *sessionutil.Session
+	nodeID     int64
 }
 
 // NewClient creates a new QueryNode client.
@@ -59,6 +60,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error)
 		addr:       addr,
 		grpcClient: grpcclient.NewClientBase[querypb.QueryNodeClient](config, "milvus.proto.query.QueryNode"),
 		sess:       sess,
+		nodeID:     nodeID,
 	}
 	// node shall specify node id
 	client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.QueryNodeRole, nodeID))
@@ -122,7 +124,7 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChanne
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.WatchDmChannels(ctx, req)
 	})
@@ -133,7 +135,7 @@ func (c *Client) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannel
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.UnsubDmChannel(ctx, req)
 	})
@@ -144,7 +146,7 @@ func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequ
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.LoadSegments(ctx, req)
 	})
@@ -155,7 +157,7 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.ReleaseCollection(ctx, req)
 	})
@@ -166,7 +168,7 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.LoadPartitions(ctx, req)
 	})
@@ -177,7 +179,7 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.ReleasePartitions(ctx, req)
 	})
@@ -188,7 +190,7 @@ func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmen
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.ReleaseSegments(ctx, req)
 	})
@@ -253,7 +255,7 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.GetSegmentInfoResponse, error) {
 		return client.GetSegmentInfo(ctx, req)
 	})
@@ -264,7 +266,7 @@ func (c *Client) SyncReplicaSegments(ctx context.Context, req *querypb.SyncRepli
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.SyncReplicaSegments(ctx, req)
 	})
@@ -275,7 +277,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*internalpb.ShowConfigurationsResponse, error) {
 		return client.ShowConfigurations(ctx, req)
 	})
@@ -286,7 +288,7 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*milvuspb.GetMetricsResponse, error) {
 		return client.GetMetrics(ctx, req)
 	})
@@ -302,7 +304,7 @@ func (c *Client) GetDataDistribution(ctx context.Context, req *querypb.GetDataDi
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.GetDataDistributionResponse, error) {
 		return client.GetDataDistribution(ctx, req)
 	})
@@ -312,7 +314,7 @@ func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
+		commonpbutil.FillMsgBaseFromClient(c.nodeID))
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.SyncDistribution(ctx, req)
 	})
@@ -323,7 +325,7 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr
 	req = typeutil.Clone(req)
 	commonpbutil.UpdateMsgBase(
 		req.GetBase(),
-		commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()),
+		commonpbutil.FillMsgBaseFromClient(c.nodeID),
 	)
 	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*commonpb.Status, error) {
 		return client.Delete(ctx, req)
diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go
index 2c02285d1b..c0b56c10c3 100644
--- a/internal/distributed/querynode/service.go
+++ b/internal/distributed/querynode/service.go
@@ -132,6 +132,7 @@ func (s *Server) init() error {
 		log.Error("QueryNode init error: ", zap.Error(err))
 		return err
 	}
+	s.serverID.Store(s.querynode.GetNodeID())
 
 	return nil
 }
diff --git a/internal/distributed/querynode/service_test.go b/internal/distributed/querynode/service_test.go
index 7565aa59e1..fc979387e6 100644
--- a/internal/distributed/querynode/service_test.go
+++ b/internal/distributed/querynode/service_test.go
@@ -91,6 +91,7 @@ func Test_NewServer(t *testing.T) {
 	mockQN.EXPECT().SetAddress(mock.Anything).Maybe()
 	mockQN.EXPECT().UpdateStateCode(mock.Anything).Maybe()
 	mockQN.EXPECT().Init().Return(nil).Maybe()
+	mockQN.EXPECT().GetNodeID().Return(2).Maybe()
 	server.querynode = mockQN
 
 	t.Run("Run", func(t *testing.T) {
@@ -285,6 +286,7 @@ func Test_Run(t *testing.T) {
 	mockQN.EXPECT().SetAddress(mock.Anything).Maybe()
 	mockQN.EXPECT().UpdateStateCode(mock.Anything).Maybe()
 	mockQN.EXPECT().Init().Return(nil).Maybe()
+	mockQN.EXPECT().GetNodeID().Return(2).Maybe()
 	server.querynode = mockQN
 	err = server.Run()
 	assert.Error(t, err)
diff --git a/internal/mocks/mock_datanode.go b/internal/mocks/mock_datanode.go
index e9f4a5d7d6..21e3762e40 100644
--- a/internal/mocks/mock_datanode.go
+++ b/internal/mocks/mock_datanode.go
@@ -568,6 +568,47 @@ func (_c *MockDataNode_GetMetrics_Call) RunAndReturn(run func(context.Context, *
 	return _c
 }
 
+// GetNodeID provides a mock function with given fields:
+func (_m *MockDataNode) GetNodeID() int64 {
+	ret := _m.Called()
+
+	var r0 int64
+	if rf, ok := ret.Get(0).(func() int64); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(int64)
+	}
+
+	return r0
+}
+
+// MockDataNode_GetNodeID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeID'
+type MockDataNode_GetNodeID_Call struct {
+	*mock.Call
+}
+
+// GetNodeID is a helper method to define mock.On call
+func (_e *MockDataNode_Expecter) GetNodeID() *MockDataNode_GetNodeID_Call {
+	return &MockDataNode_GetNodeID_Call{Call: _e.mock.On("GetNodeID")}
+}
+
+func (_c *MockDataNode_GetNodeID_Call) Run(run func()) *MockDataNode_GetNodeID_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockDataNode_GetNodeID_Call) Return(_a0 int64) *MockDataNode_GetNodeID_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockDataNode_GetNodeID_Call) RunAndReturn(run func() int64) *MockDataNode_GetNodeID_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // GetStateCode provides a mock function with given fields:
 func (_m *MockDataNode) GetStateCode() commonpb.StateCode {
 	ret := _m.Called()
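
With the mock regenerated, suites that drive Server.init can satisfy the new GetNodeID expectation. A sketch, assuming the usual mockery-generated constructor:

    mockDN := mocks.NewMockDataNode(t)
    mockDN.EXPECT().GetNodeID().Return(int64(2)).Maybe()
    // .Maybe() keeps the expectation optional, matching the querynode
    // service tests above.
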
diff --git a/internal/mocks/mock_querynode.go b/internal/mocks/mock_querynode.go
index c96f9b3e2b..e6288f04b1 100644
--- a/internal/mocks/mock_querynode.go
+++ b/internal/mocks/mock_querynode.go
@@ -291,6 +291,47 @@ func (_c *MockQueryNode_GetMetrics_Call) RunAndReturn(run func(context.Context,
 	return _c
 }
 
+// GetNodeID provides a mock function with given fields:
+func (_m *MockQueryNode) GetNodeID() int64 {
+	ret := _m.Called()
+
+	var r0 int64
+	if rf, ok := ret.Get(0).(func() int64); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(int64)
+	}
+
+	return r0
+}
+
+// MockQueryNode_GetNodeID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeID'
+type MockQueryNode_GetNodeID_Call struct {
+	*mock.Call
+}
+
+// GetNodeID is a helper method to define mock.On call
+func (_e *MockQueryNode_Expecter) GetNodeID() *MockQueryNode_GetNodeID_Call {
+	return &MockQueryNode_GetNodeID_Call{Call: _e.mock.On("GetNodeID")}
+}
+
+func (_c *MockQueryNode_GetNodeID_Call) Run(run func()) *MockQueryNode_GetNodeID_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *MockQueryNode_GetNodeID_Call) Return(_a0 int64) *MockQueryNode_GetNodeID_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockQueryNode_GetNodeID_Call) RunAndReturn(run func() int64) *MockQueryNode_GetNodeID_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // GetSegmentInfo provides a mock function with given fields: _a0, _a1
 func (_m *MockQueryNode) GetSegmentInfo(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
 	ret := _m.Called(_a0, _a1)
diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go
index d4afc8df1c..f431868e52 100644
--- a/internal/querynodev2/handlers.go
+++ b/internal/querynodev2/handlers.go
@@ -37,7 +37,6 @@ import (
 	"github.com/milvus-io/milvus/pkg/metrics"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
-	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/timerecord"
 )
 
@@ -184,10 +183,10 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
 	)
 
 	var err error
-	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
+	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
 	defer func() {
 		if err != nil {
-			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
+			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
 		}
 	}()
 
@@ -244,13 +243,13 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
 	))
 
 	latency := tr.ElapseSpan()
-	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
-	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.Leader).Inc()
+	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
+	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.Leader).Inc()
 	return resp, nil
 }
 
 func (node *QueryNode) queryChannelStream(ctx context.Context, req *querypb.QueryRequest, channel string, srv streamrpc.QueryStreamServer) error {
-	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
+	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc()
 	msgID := req.Req.Base.GetMsgID()
 	log := log.Ctx(ctx).With(
 		zap.Int64("msgID", msgID),
@@ -262,7 +261,7 @@ func (node *QueryNode) queryChannelStream(ctx context.Context, req *querypb.Quer
 	var err error
 	defer func() {
 		if err != nil {
-			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
+			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc()
 		}
 	}()
 
@@ -344,10 +343,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
 	defer node.lifetime.Done()
 
 	var err error
-	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
+	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
 	defer func() {
 		if err != nil {
-			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.Leader).Inc()
+			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.Leader).Inc()
 		}
 	}()
 
@@ -394,10 +393,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
 
 	// update metric to prometheus
 	latency := tr.ElapseSpan()
-	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
-	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
-	metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq()))
-	metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk()))
+	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
+	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
+	metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(req.Req.GetNq()))
+	metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(req.Req.GetTopk()))
 
 	return resp, nil
 }
@@ -415,10 +414,10 @@ func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.Hyb
 	defer node.lifetime.Done()
 
 	var err error
-	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
+	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
 	defer func() {
 		if err != nil {
-			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
+			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
 		}
 	}()
 
@@ -449,11 +448,11 @@ func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.Hyb
 
 	// update metric to prometheus
 	latency := tr.ElapseSpan()
-	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
-	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
+	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
+	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
 	for _, searchReq := range req.GetReq().GetReqs() {
-		metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetNq()))
-		metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetTopk()))
+		metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetNq()))
+		metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(node.GetNodeID())).Observe(float64(searchReq.GetTopk()))
 	}
 	return result, nil
 }
diff --git a/internal/querynodev2/metrics_info.go b/internal/querynodev2/metrics_info.go
index d3bbc0527b..e2f1b9b5b0 100644
--- a/internal/querynodev2/metrics_info.go
+++ b/internal/querynodev2/metrics_info.go
@@ -114,7 +114,7 @@ func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error
 			return seg.MemSize()
 		})
 		totalGrowingSize += size
-		metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
+		metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(node.GetNodeID()),
 			fmt.Sprint(collection), segments.SegmentTypeGrowing.String()).Set(float64(size))
 	}
 
@@ -126,7 +126,7 @@ func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error
 		size := lo.SumBy(segs, func(seg segments.Segment) int64 {
 			return seg.MemSize()
 		})
-		metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
+		metrics.QueryNodeEntitiesSize.WithLabelValues(fmt.Sprint(node.GetNodeID()),
 			fmt.Sprint(collection), segments.SegmentTypeSealed.String()).Set(float64(size))
 	}
 
@@ -148,7 +148,7 @@ func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error
 		QueryQueue:          qqms,
 		GrowingSegmentsSize: totalGrowingSize,
 		Effect: metricsinfo.NodeEffect{
-			NodeID:        paramtable.GetNodeID(),
+			NodeID:        node.GetNodeID(),
 			CollectionIDs: collections.Collect(),
 		},
 	}, nil
@@ -163,7 +163,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
 	if err != nil {
 		return &milvuspb.GetMetricsResponse{
 			Status:        merr.Status(err),
-			ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, paramtable.GetNodeID()),
+			ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, node.GetNodeID()),
 		}, nil
 	}
 	hardwareInfos := metricsinfo.HardwareMetrics{
@@ -179,7 +179,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
 
 	nodeInfos := metricsinfo.QueryNodeInfos{
 		BaseComponentInfos: metricsinfo.BaseComponentInfos{
-			Name:          metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
+			Name:          metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, node.GetNodeID()),
 			HardwareInfos: hardwareInfos,
 			SystemInfo:    metricsinfo.DeployMetrics{},
 			CreatedTime:   paramtable.GetCreateTime().String(),
@@ -199,13 +199,13 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest,
 		return &milvuspb.GetMetricsResponse{
 			Status:        merr.Status(err),
 			Response:      "",
-			ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
+			ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, node.GetNodeID()),
 		}, nil
 	}
 
 	return &milvuspb.GetMetricsResponse{
 		Status:        merr.Success(),
 		Response:      resp,
-		ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
+		ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, node.GetNodeID()),
 	}, nil
 }
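
In querynodev2 the same substitution keeps per-node Prometheus series distinct when several QueryNodes share a process (e.g. standalone deployments or integration tests). A sketch with hypothetical nodes nodeA and nodeB:

    // Previously both nodes would write to one series labeled with the
    // process-global ID; now each reports under its own instance ID.
    for _, n := range []*QueryNode{nodeA, nodeB} {
        metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(n.GetNodeID())).Set(0)
    }
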
log.Info("QueryNode init session", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("node address", node.session.Address)) + node.serverID = node.session.ServerID + log.Info("QueryNode init session", zap.Int64("nodeID", node.GetNodeID()), zap.String("node address", node.session.Address)) return nil } @@ -164,13 +166,13 @@ func (node *QueryNode) initSession() error { func (node *QueryNode) Register() error { node.session.Register() // start liveness check - metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.QueryNodeRole).Inc() + metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.QueryNodeRole).Inc() node.session.LivenessCheck(node.ctx, func() { - log.Error("Query Node disconnected from etcd, process will exit", zap.Int64("Server Id", paramtable.GetNodeID())) + log.Error("Query Node disconnected from etcd, process will exit", zap.Int64("Server Id", node.GetNodeID())) if err := node.Stop(); err != nil { log.Fatal("failed to stop server", zap.Error(err)) } - metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.QueryNodeRole).Dec() + metrics.NumNodes.WithLabelValues(fmt.Sprint(node.GetNodeID()), typeutil.QueryNodeRole).Dec() // manually send signal to starter goroutine if node.session.TriggerKill { if p, err := os.FindProcess(os.Getpid()); err == nil { @@ -263,6 +265,10 @@ func getIndexEngineVersion() (minimal, current int32) { return int32(cMinimal), int32(cCurrent) } +func (node *QueryNode) GetNodeID() int64 { + return node.serverID +} + func (node *QueryNode) CloseSegcore() { // safe stop initcore.CleanRemoteChunkManager() @@ -301,7 +307,7 @@ func (node *QueryNode) Init() error { initError = err return } - metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(localUsedSize / 1024 / 1024)) + metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(float64(localUsedSize / 1024 / 1024)) node.chunkManager, err = node.factory.NewPersistentStorageChunkManager(node.ctx) if err != nil { @@ -317,7 +323,7 @@ func (node *QueryNode) Init() error { log.Info("queryNode init scheduler", zap.String("policy", schedulePolicy)) node.clusterManager = cluster.NewWorkerManager(func(ctx context.Context, nodeID int64) (cluster.Worker, error) { - if nodeID == paramtable.GetNodeID() { + if nodeID == node.GetNodeID() { return NewLocalWorker(node), nil } @@ -350,7 +356,7 @@ func (node *QueryNode) Init() error { } else { node.loader = segments.NewLoader(node.manager, node.chunkManager) } - node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, paramtable.GetNodeID()) + node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, node.GetNodeID()) // init pipeline manager node.pipelineManager = pipeline.NewManager(node.manager, node.tSafeManager, node.dispClient, node.delegators) @@ -373,7 +379,7 @@ func (node *QueryNode) Init() error { } log.Info("query node init successfully", - zap.Int64("queryNodeID", paramtable.GetNodeID()), + zap.Int64("queryNodeID", node.GetNodeID()), zap.String("Address", node.address), ) }) @@ -392,9 +398,9 @@ func (node *QueryNode) Start() error { mmapEnabled := len(mmapDirPath) > 0 node.UpdateStateCode(commonpb.StateCode_Healthy) - registry.GetInMemoryResolver().RegisterQueryNode(paramtable.GetNodeID(), node) + registry.GetInMemoryResolver().RegisterQueryNode(node.GetNodeID(), node) log.Info("query node start successfully", - zap.Int64("queryNodeID", paramtable.GetNodeID()), + zap.Int64("queryNodeID", 
node.GetNodeID()), zap.String("Address", node.address), zap.Bool("mmapEnabled", mmapEnabled), ) @@ -432,7 +438,7 @@ func (node *QueryNode) Stop() error { select { case <-timeoutCh: - log.Warn("migrate data timed out", zap.Int64("ServerID", paramtable.GetNodeID()), + log.Warn("migrate data timed out", zap.Int64("ServerID", node.GetNodeID()), zap.Int64s("sealedSegments", lo.Map(sealedSegments, func(s segments.Segment, i int) int64 { return s.ID() })), @@ -444,14 +450,14 @@ func (node *QueryNode) Stop() error { break outer case <-time.After(time.Second): - metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(len(sealedSegments))) - metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(channelNum)) + metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(float64(len(sealedSegments))) + metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(float64(channelNum)) } } metrics.StoppingBalanceNodeNum.WithLabelValues().Set(0) - metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(0) - metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(0) + metrics.StoppingBalanceSegmentNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(0) + metrics.StoppingBalanceChannelNum.WithLabelValues(fmt.Sprint(node.GetNodeID())).Set(0) } node.UpdateStateCode(commonpb.StateCode_Abnormal) diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index f0eebc6687..7a30455ed1 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -66,7 +66,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context, req *milvuspb.Get log.Debug("QueryNode current state", zap.Int64("NodeID", nodeID), zap.String("StateCode", code.String())) if node.session != nil && node.session.Registered() { - nodeID = paramtable.GetNodeID() + nodeID = node.GetNodeID() } info := &milvuspb.ComponentInfo{ NodeID: nodeID, @@ -112,7 +112,7 @@ func (node *QueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatis } defer node.lifetime.Done() - err := merr.CheckTargetID(req.GetReq().GetBase()) + err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase()) if err != nil { log.Warn("target ID check failed", zap.Error(err)) return &internalpb.GetStatisticsResponse{ @@ -200,7 +200,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm log := log.Ctx(ctx).With( zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", channel.GetChannelName()), - zap.Int64("currentNodeID", paramtable.GetNodeID()), + zap.Int64("currentNodeID", node.GetNodeID()), ) log.Info("received watch channel request", @@ -214,7 +214,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm defer node.lifetime.Done() // check target matches - if err := merr.CheckTargetID(req.GetBase()); err != nil { + if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil { return merr.Status(err), nil } @@ -347,7 +347,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC log := log.Ctx(ctx).With( zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannelName()), - zap.Int64("currentNodeID", paramtable.GetNodeID()), + zap.Int64("currentNodeID", node.GetNodeID()), ) log.Info("received unsubscribe channel request") @@ -359,7 +359,7 @@ func (node *QueryNode) UnsubDmChannel(ctx 
context.Context, req *querypb.UnsubDmC defer node.lifetime.Done() // check target matches - if err := merr.CheckTargetID(req.GetBase()); err != nil { + if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil { return merr.Status(err), nil } @@ -412,7 +412,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen zap.Int64("partitionID", segment.GetPartitionID()), zap.String("shard", segment.GetInsertChannel()), zap.Int64("segmentID", segment.GetSegmentID()), - zap.Int64("currentNodeID", paramtable.GetNodeID()), + zap.Int64("currentNodeID", node.GetNodeID()), ) log.Info("received load segments request", @@ -426,7 +426,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen defer node.lifetime.Done() // check target matches - if err := merr.CheckTargetID(req.GetBase()); err != nil { + if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil { return merr.Status(err), nil } @@ -529,7 +529,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release zap.Int64("collectionID", req.GetCollectionID()), zap.String("shard", req.GetShard()), zap.Int64s("segmentIDs", req.GetSegmentIDs()), - zap.Int64("currentNodeID", paramtable.GetNodeID()), + zap.Int64("currentNodeID", node.GetNodeID()), ) log.Info("received release segment request", @@ -544,7 +544,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release defer node.lifetime.Done() // check target matches - if err := merr.CheckTargetID(req.GetBase()); err != nil { + if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil { return merr.Status(err), nil } @@ -630,8 +630,8 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen DmChannel: segment.Shard(), PartitionID: segment.Partition(), CollectionID: segment.Collection(), - NodeID: paramtable.GetNodeID(), - NodeIds: []int64{paramtable.GetNodeID()}, + NodeID: node.GetNodeID(), + NodeIds: []int64{node.GetNodeID()}, MemSize: segment.MemSize(), NumRows: segment.InsertCount(), IndexName: indexName, @@ -669,10 +669,10 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe } defer node.lifetime.Done() - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.FromLeader).Inc() defer func() { if !merr.Ok(resp.GetStatus()) { - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.FromLeader).Inc() } }() @@ -693,7 +693,7 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe return resp, nil } - task := tasks.NewSearchTask(searchCtx, collection, node.manager, req) + task := tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID) if err := node.scheduler.Add(task); err != nil { log.Warn("failed to search channel", zap.Error(err)) resp.Status = merr.Status(err) @@ -713,8 +713,8 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe )) latency := tr.ElapseSpan() - metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, 
metrics.FromLeader).Observe(float64(latency.Milliseconds())) - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() resp = task.Result() resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds() @@ -750,7 +750,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( } defer node.lifetime.Done() - err := merr.CheckTargetID(req.GetReq().GetBase()) + err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase()) if err != nil { log.Warn("target ID check failed", zap.Error(err)) return &internalpb.SearchResults{ @@ -807,12 +807,12 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( return resp, nil } reduceLatency := tr.RecordSpan() - metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards). + metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards). Observe(float64(reduceLatency.Milliseconds())) collector.Rate.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq())) collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req))) - metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel). + metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.SearchLabel). Add(float64(proto.Size(req))) if result.GetCostAggregation() != nil { @@ -836,19 +836,19 @@ func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSear if err := node.lifetime.Add(merr.IsHealthy); err != nil { return &querypb.HybridSearchResult{ Base: &commonpb.MsgBase{ - SourceID: paramtable.GetNodeID(), + SourceID: node.GetNodeID(), }, Status: merr.Status(err), }, nil } defer node.lifetime.Done() - err := merr.CheckTargetID(req.GetReq().GetBase()) + err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase()) if err != nil { log.Warn("target ID check failed", zap.Error(err)) return &querypb.HybridSearchResult{ Base: &commonpb.MsgBase{ - SourceID: paramtable.GetNodeID(), + SourceID: node.GetNodeID(), }, Status: merr.Status(err), }, nil @@ -856,7 +856,7 @@ func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSear resp := &querypb.HybridSearchResult{ Base: &commonpb.MsgBase{ - SourceID: paramtable.GetNodeID(), + SourceID: node.GetNodeID(), }, Status: merr.Success(), } @@ -916,11 +916,11 @@ func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSear resp.ChannelsMvcc = channelsMvcc reduceLatency := tr.RecordSpan() - metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards). + metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards). Observe(float64(reduceLatency.Milliseconds())) collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req))) - metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.HybridSearchLabel). 
+ metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.HybridSearchLabel). Add(float64(proto.Size(req))) if resp.GetCostAggregation() != nil { @@ -950,10 +950,10 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ } defer node.lifetime.Done() - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc() defer func() { if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc() } }() @@ -995,8 +995,8 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ // TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency latency := tr.ElapseSpan() - metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() result := task.Result() result.GetCostAggregation().ResponseTime = latency.Milliseconds() result.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ() @@ -1031,7 +1031,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i } defer node.lifetime.Done() - err := merr.CheckTargetID(req.GetReq().GetBase()) + err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase()) if err != nil { log.Warn("target ID check failed", zap.Error(err)) return &internalpb.RetrieveResults{ @@ -1080,12 +1080,12 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i }, nil } reduceLatency := tr.RecordSpan() - metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards). + metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards). 
Observe(float64(reduceLatency.Milliseconds())) if !req.FromShardLeader { collector.Rate.Add(metricsinfo.NQPerSecond, 1) - metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) + metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) } if ret.GetCostAggregation() != nil { @@ -1116,7 +1116,7 @@ func (node *QueryNode) QueryStream(req *querypb.QueryRequest, srv querypb.QueryN } defer node.lifetime.Done() - err := merr.CheckTargetID(req.GetReq().GetBase()) + err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase()) if err != nil { log.Warn("target ID check failed", zap.Error(err)) return err @@ -1151,7 +1151,7 @@ func (node *QueryNode) QueryStream(req *querypb.QueryRequest, srv querypb.QueryN } collector.Rate.Add(metricsinfo.NQPerSecond, 1) - metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) + metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) return nil } @@ -1170,10 +1170,10 @@ func (node *QueryNode) QueryStreamSegments(req *querypb.QueryRequest, srv queryp ) resp := &internalpb.RetrieveResults{} - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc() defer func() { if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc() } }() @@ -1207,8 +1207,8 @@ func (node *QueryNode) QueryStreamSegments(req *querypb.QueryRequest, srv queryp // TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency latency := tr.ElapseSpan() - metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) - metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() + metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() return nil } @@ -1221,7 +1221,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { if err := node.lifetime.Add(merr.IsHealthy); err != nil { log.Warn("QueryNode.ShowConfigurations failed", - zap.Int64("nodeId", paramtable.GetNodeID()), + zap.Int64("nodeId", node.GetNodeID()), zap.String("req", req.Pattern), zap.Error(err)) @@ -1251,7 +1251,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) 
(*milvuspb.GetMetricsResponse, error) { if err := node.lifetime.Add(merr.IsHealthy); err != nil { log.Warn("QueryNode.GetMetrics failed", - zap.Int64("nodeId", paramtable.GetNodeID()), + zap.Int64("nodeId", node.GetNodeID()), zap.String("req", req.Request), zap.Error(err)) @@ -1265,7 +1265,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR metricType, err := metricsinfo.ParseMetricType(req.Request) if err != nil { log.Warn("QueryNode.GetMetrics failed to parse metric type", - zap.Int64("nodeId", paramtable.GetNodeID()), + zap.Int64("nodeId", node.GetNodeID()), zap.String("req", req.Request), zap.Error(err)) @@ -1278,7 +1278,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR queryNodeMetrics, err := getSystemInfoMetrics(ctx, req, node) if err != nil { log.Warn("QueryNode.GetMetrics failed", - zap.Int64("nodeId", paramtable.GetNodeID()), + zap.Int64("nodeId", node.GetNodeID()), zap.String("req", req.Request), zap.String("metricType", metricType), zap.Error(err)) @@ -1287,7 +1287,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR }, nil } log.RatedDebug(50, "QueryNode.GetMetrics", - zap.Int64("nodeID", paramtable.GetNodeID()), + zap.Int64("nodeID", node.GetNodeID()), zap.String("req", req.Request), zap.String("metricType", metricType), zap.Any("queryNodeMetrics", queryNodeMetrics)) @@ -1296,7 +1296,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR } log.Debug("QueryNode.GetMetrics failed, request metric type is not implemented yet", - zap.Int64("nodeID", paramtable.GetNodeID()), + zap.Int64("nodeID", node.GetNodeID()), zap.String("req", req.Request), zap.String("metricType", metricType)) @@ -1308,7 +1308,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { log := log.Ctx(ctx).With( zap.Int64("msgID", req.GetBase().GetMsgID()), - zap.Int64("nodeID", paramtable.GetNodeID()), + zap.Int64("nodeID", node.GetNodeID()), ) if err := node.lifetime.Add(merr.IsHealthy); err != nil { log.Warn("QueryNode.GetDataDistribution failed", @@ -1321,7 +1321,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get defer node.lifetime.Done() // check target matches - if err := merr.CheckTargetID(req.GetBase()); err != nil { + if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil { return &querypb.GetDataDistributionResponse{ Status: merr.Status(err), }, nil @@ -1393,7 +1393,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get return &querypb.GetDataDistributionResponse{ Status: merr.Success(), - NodeID: paramtable.GetNodeID(), + NodeID: node.GetNodeID(), Segments: segmentVersionInfos, Channels: channelVersionInfos, LeaderViews: leaderViews, @@ -1402,7 +1402,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) { log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), - zap.String("channel", req.GetChannel()), zap.Int64("currentNodeID", paramtable.GetNodeID())) + zap.String("channel", req.GetChannel()), zap.Int64("currentNodeID", node.GetNodeID())) // check node healthy if err := node.lifetime.Add(merr.IsHealthy); err != nil { return 
merr.Status(err), nil @@ -1410,7 +1410,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi defer node.lifetime.Done() // check target matches - if err := merr.CheckTargetID(req.GetBase()); err != nil { + if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil { return merr.Status(err), nil } @@ -1510,7 +1510,7 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) ( defer node.lifetime.Done() // check target matches - if err := merr.CheckTargetID(req.GetBase()); err != nil { + if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil { return merr.Status(err), nil } diff --git a/internal/querynodev2/tasks/task.go b/internal/querynodev2/tasks/task.go index ffb9df7900..791f8340b9 100644 --- a/internal/querynodev2/tasks/task.go +++ b/internal/querynodev2/tasks/task.go @@ -48,6 +48,7 @@ type SearchTask struct { originNqs []int64 others []*SearchTask notifier chan error + serverID int64 tr *timerecord.TimeRecorder scheduleSpan trace.Span @@ -57,6 +58,7 @@ func NewSearchTask(ctx context.Context, collection *segments.Collection, manager *segments.Manager, req *querypb.SearchRequest, + serverID int64, ) *SearchTask { ctx, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "schedule") return &SearchTask{ @@ -74,6 +76,7 @@ func NewSearchTask(ctx context.Context, notifier: make(chan error, 1), tr: timerecord.NewTimeRecorderWithTrace(ctx, "searchTask"), scheduleSpan: span, + serverID: serverID, } } @@ -83,13 +86,17 @@ func (t *SearchTask) Username() string { return t.req.Req.GetUsername() } +func (t *SearchTask) GetNodeID() int64 { + return t.serverID +} + func (t *SearchTask) IsGpuIndex() bool { return t.collection.IsGpuIndex() } func (t *SearchTask) PreExecute() error { // Update task wait time metric before execute - nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10) + nodeID := strconv.FormatInt(t.GetNodeID(), 10) inQueueDuration := t.tr.ElapseSpan() // Update in queue metric for prometheus. @@ -180,7 +187,7 @@ func (t *SearchTask) Execute() error { task.result = &internalpb.SearchResults{ Base: &commonpb.MsgBase{ - SourceID: paramtable.GetNodeID(), + SourceID: t.GetNodeID(), }, Status: merr.Success(), MetricType: metricType, @@ -211,7 +218,7 @@ func (t *SearchTask) Execute() error { } defer segments.DeleteSearchResultDataBlobs(blobs) metrics.QueryNodeReduceLatency.WithLabelValues( - fmt.Sprint(paramtable.GetNodeID()), + fmt.Sprint(t.GetNodeID()), metrics.SearchLabel, metrics.ReduceSegments). 
Observe(float64(tr.RecordSpan().Milliseconds())) @@ -234,7 +241,7 @@ func (t *SearchTask) Execute() error { task.result = &internalpb.SearchResults{ Base: &commonpb.MsgBase{ - SourceID: paramtable.GetNodeID(), + SourceID: t.GetNodeID(), }, Status: merr.Success(), MetricType: metricType, @@ -294,9 +301,9 @@ func (t *SearchTask) Merge(other *SearchTask) bool { func (t *SearchTask) Done(err error) { if !t.merged { - metrics.QueryNodeSearchGroupSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.groupSize)) - metrics.QueryNodeSearchGroupNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.nq)) - metrics.QueryNodeSearchGroupTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.topk)) + metrics.QueryNodeSearchGroupSize.WithLabelValues(fmt.Sprint(t.GetNodeID())).Observe(float64(t.groupSize)) + metrics.QueryNodeSearchGroupNQ.WithLabelValues(fmt.Sprint(t.GetNodeID())).Observe(float64(t.nq)) + metrics.QueryNodeSearchGroupTopK.WithLabelValues(fmt.Sprint(t.GetNodeID())).Observe(float64(t.topk)) } t.notifier <- err for _, other := range t.others { diff --git a/internal/types/types.go b/internal/types/types.go index 3239ed8b81..f5852cba6f 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -75,6 +75,7 @@ type DataNodeComponent interface { SetAddress(address string) GetAddress() string + GetNodeID() int64 // SetEtcdClient set etcd client for DataNode SetEtcdClient(etcdClient *clientv3.Client) @@ -283,6 +284,7 @@ type QueryNodeComponent interface { SetAddress(address string) GetAddress() string + GetNodeID() int64 // SetEtcdClient set etcd client for QueryNode SetEtcdClient(etcdClient *clientv3.Client) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 4d4383d9c6..228c0500c9 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -293,9 +293,9 @@ func AnalyzeState(role string, nodeID int64, state *milvuspb.ComponentStates) er return nil } -func CheckTargetID(msg *commonpb.MsgBase) error { - if msg.GetTargetID() != paramtable.GetNodeID() { - return WrapErrNodeNotMatch(paramtable.GetNodeID(), msg.GetTargetID()) +func CheckTargetID(actualNodeID int64, msg *commonpb.MsgBase) error { + if msg.GetTargetID() != actualNodeID { + return WrapErrNodeNotMatch(actualNodeID, msg.GetTargetID()) } return nil diff --git a/tests/integration/datanode/datanode_test.go b/tests/integration/datanode/datanode_test.go new file mode 100644 index 0000000000..0fd620ce2f --- /dev/null +++ b/tests/integration/datanode/datanode_test.go @@ -0,0 +1,309 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package datanode + +import ( + "context" + "fmt" + "math/rand" + "strconv" + "sync" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/tests/integration" +) + +type DataNodeSuite struct { + integration.MiniClusterSuite + maxGoRoutineNum int + dim int + numCollections int + rowsPerCollection int + waitTimeInSec time.Duration + prefix string +} + +func (s *DataNodeSuite) setupParam() { + s.maxGoRoutineNum = 100 + s.dim = 128 + s.numCollections = 2 + s.rowsPerCollection = 100 + s.waitTimeInSec = time.Second * 1 +} + +func (s *DataNodeSuite) loadCollection(collectionName string) { + c := s.Cluster + dbName := "" + schema := integration.ConstructSchema(collectionName, s.dim, true) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: common.DefaultShardsNum, + }) + s.NoError(err) + + err = merr.Error(createCollectionStatus) + s.NoError(err) + + showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.GetStatus())) + + batchSize := 500000 + for start := 0; start < s.rowsPerCollection; start += batchSize { + rowNum := batchSize + if start+batchSize > s.rowsPerCollection { + rowNum = s.rowsPerCollection - start + } + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, s.dim) + hashKeys := integration.GenerateHashKeys(rowNum) + insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.GetStatus())) + } + log.Info("=========================Data insertion finished=========================") + + // flush + flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + + segments, err := c.MetaWatcher.ShowSegments() + s.NoError(err) + s.NotEmpty(segments) + s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName) + log.Info("=========================Data flush finished=========================") + + // create index + createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexFaissIvfFlat, metric.IP), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + s.NoError(err) + s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField) + log.Info("=========================Index 
created=========================") + + // load + loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + err = merr.Error(loadStatus) + s.NoError(err) + s.WaitForLoad(context.TODO(), collectionName) + log.Info("=========================Collection loaded=========================") +} + +func (s *DataNodeSuite) checkCollections() bool { + req := &milvuspb.ShowCollectionsRequest{ + DbName: "", + TimeStamp: 0, // means now + } + resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req) + s.NoError(err) + s.Equal(len(resp.CollectionIds), s.numCollections) + notLoaded := 0 + loaded := 0 + for _, name := range resp.CollectionNames { + loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{ + DbName: "", + CollectionName: name, + }) + s.NoError(err) + if loadProgress.GetProgress() != int64(100) { + notLoaded++ + } else { + loaded++ + } + } + log.Info(fmt.Sprintf("loading status: %d/%d", loaded, len(resp.GetCollectionNames()))) + return notLoaded == 0 +} + +func (s *DataNodeSuite) search(collectionName string) { + c := s.Cluster + var err error + // Query + queryReq := &milvuspb.QueryRequest{ + Base: nil, + CollectionName: collectionName, + PartitionNames: nil, + Expr: "", + OutputFields: []string{"count(*)"}, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + } + queryResult, err := c.Proxy.Query(context.TODO(), queryReq) + s.NoError(err) + s.Equal(len(queryResult.FieldsData), 1) + numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0] + s.Equal(numEntities, int64(s.rowsPerCollection)) + + // Search + expr := fmt.Sprintf("%s > 0", integration.Int64Field) + nq := 10 + topk := 10 + roundDecimal := -1 + radius := 10 + + params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP) + params["radius"] = radius + searchReq := integration.ConstructSearchRequest("", collectionName, expr, + integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal) + + searchResult, _ := c.Proxy.Search(context.TODO(), searchReq) + + err = merr.Error(searchResult.GetStatus()) + s.NoError(err) +} + +func (s *DataNodeSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart int, wg *sync.WaitGroup) { + for idx := 0; idx < collectionBatchSize; idx++ { + collectionName := prefix + "_" + strconv.Itoa(idxStart+idx) + s.loadCollection(collectionName) + } + wg.Done() +} + +func (s *DataNodeSuite) setupData() { + // Add the second data node + s.Cluster.AddDataNode() + goRoutineNum := s.maxGoRoutineNum + if goRoutineNum > s.numCollections { + goRoutineNum = s.numCollections + } + collectionBatchSize := s.numCollections / goRoutineNum + log.Info(fmt.Sprintf("=========================test with dim=%d, s.rowsPerCollection=%d, s.numCollections=%d, goRoutineNum=%d==================", s.dim, s.rowsPerCollection, s.numCollections, goRoutineNum)) + log.Info("=========================Start to inject data=========================") + s.prefix = "TestDataNodeUtil" + funcutil.GenRandomStr() + searchName := s.prefix + "_0" + wg := sync.WaitGroup{} + for idx := 0; idx < goRoutineNum; idx++ { + wg.Add(1) + go s.insertBatchCollections(s.prefix, collectionBatchSize, idx*collectionBatchSize, &wg) + } + wg.Wait() + log.Info("=========================Data injection finished=========================") + s.checkCollections() + log.Info(fmt.Sprintf("=========================start 
to search %s=========================", searchName)) + s.search(searchName) + log.Info("=========================Search finished=========================") + time.Sleep(s.waitTimeInSec) + s.checkCollections() + log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName)) + s.search(searchName) + log.Info("=========================Search2 finished=========================") + s.checkAllCollectionsReady() +} + +func (s *DataNodeSuite) checkAllCollectionsReady() { + goRoutineNum := s.maxGoRoutineNum + if goRoutineNum > s.numCollections { + goRoutineNum = s.numCollections + } + collectionBatchSize := s.numCollections / goRoutineNum + for i := 0; i < goRoutineNum; i++ { + for idx := 0; idx < collectionBatchSize; idx++ { + collectionName := s.prefix + "_" + strconv.Itoa(i*collectionBatchSize+idx) + s.search(collectionName) + queryReq := &milvuspb.QueryRequest{ + CollectionName: collectionName, + Expr: "", + OutputFields: []string{"count(*)"}, + } + _, err := s.Cluster.Proxy.Query(context.TODO(), queryReq) + s.NoError(err) + } + } +} + +func (s *DataNodeSuite) checkDNRestarts(idx int) { + // Stop all data nodes + s.Cluster.StopAllDataNodes() + // Add new data nodes. + dn1 := s.Cluster.AddDataNode() + dn2 := s.Cluster.AddDataNode() + time.Sleep(s.waitTimeInSec) + cn := fmt.Sprintf("new_collection_r_%d", idx) + s.loadCollection(cn) + s.search(cn) + // Randomly stop one data node. + if rand.Intn(2) == 0 { + dn1.Stop() + } else { + dn2.Stop() + } + time.Sleep(s.waitTimeInSec) + cn = fmt.Sprintf("new_collection_x_%d", idx) + s.loadCollection(cn) + s.search(cn) +} + +func (s *DataNodeSuite) TestSwapDN() { + s.setupParam() + s.setupData() + // Test case with new data nodes added + s.Cluster.AddDataNode() + s.Cluster.AddDataNode() + time.Sleep(s.waitTimeInSec) + cn := "new_collection_a" + s.loadCollection(cn) + s.search(cn) + + // Test case with all data nodes replaced + for idx := 0; idx < 5; idx++ { + s.checkDNRestarts(idx) + } +} + +func TestDataNodeUtil(t *testing.T) { + suite.Run(t, new(DataNodeSuite)) +} diff --git a/tests/integration/minicluster_v2.go b/tests/integration/minicluster_v2.go index 9a90bd7541..0f7328555e 100644 --- a/tests/integration/minicluster_v2.go +++ b/tests/integration/minicluster_v2.go @@ -26,6 +26,7 @@ import ( "github.com/cockroachdb/errors" clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -118,13 +119,20 @@ type MiniClusterV2 struct { IndexNode *grpcindexnode.Server MetaWatcher MetaWatcher + ptmu sync.Mutex + querynodes []*grpcquerynode.Server + qnid atomic.Int64 + datanodes []*grpcdatanode.Server + dnid atomic.Int64 } type OptionV2 func(cluster *MiniClusterV2) func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2, error) { cluster := &MiniClusterV2{ - ctx: ctx, + ctx: ctx, + qnid: *atomic.NewInt64(10000), + dnid: *atomic.NewInt64(20000), } paramtable.Init() cluster.params = DefaultParams() @@ -238,6 +246,62 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2, return cluster, nil } +func (cluster *MiniClusterV2) AddQueryNode() *grpcquerynode.Server { + cluster.ptmu.Lock() + defer cluster.ptmu.Unlock() + cluster.qnid.Inc() + id := cluster.qnid.Load() + oid := paramtable.GetNodeID() + log.Info(fmt.Sprintf("adding extra querynode with id:%d", id)) + paramtable.SetNodeID(id) + node, err := grpcquerynode.NewServer(context.TODO(), cluster.factory) + if err != nil { + return nil + } + err = 
node.Run() + if err != nil { + return nil + } + paramtable.SetNodeID(oid) + + req := &milvuspb.GetComponentStatesRequest{} + resp, err := node.GetComponentStates(context.TODO(), req) + if err != nil { + return nil + } + log.Info(fmt.Sprintf("querynode %d ComponentStates:%v", id, resp)) + cluster.querynodes = append(cluster.querynodes, node) + return node +} + +func (cluster *MiniClusterV2) AddDataNode() *grpcdatanode.Server { + cluster.ptmu.Lock() + defer cluster.ptmu.Unlock() + cluster.dnid.Inc() + id := cluster.dnid.Load() + oid := paramtable.GetNodeID() + log.Info(fmt.Sprintf("adding extra datanode with id:%d", id)) + paramtable.SetNodeID(id) + node, err := grpcdatanode.NewServer(context.TODO(), cluster.factory) + if err != nil { + return nil + } + err = node.Run() + if err != nil { + return nil + } + paramtable.SetNodeID(oid) + + req := &milvuspb.GetComponentStatesRequest{} + resp, err := node.GetComponentStates(context.TODO(), req) + if err != nil { + return nil + } + log.Info(fmt.Sprintf("datanode %d ComponentStates:%v", id, resp)) + cluster.datanodes = append(cluster.datanodes, node) + return node +} + +func (cluster *MiniClusterV2) Start() error { log.Info("mini cluster start") err := cluster.RootCoord.Run() @@ -301,10 +365,8 @@ func (cluster *MiniClusterV2) Stop() error { cluster.Proxy.Stop() log.Info("mini cluster proxy stopped") - cluster.DataNode.Stop() - log.Info("mini cluster dataNode stopped") - cluster.QueryNode.Stop() - log.Info("mini cluster queryNode stopped") + cluster.StopAllDataNodes() + cluster.StopAllQueryNodes() cluster.IndexNode.Stop() log.Info("mini cluster indexNode stopped") @@ -323,6 +385,26 @@ func (cluster *MiniClusterV2) Stop() error { return nil } +func (cluster *MiniClusterV2) StopAllQueryNodes() { + cluster.QueryNode.Stop() + log.Info("mini cluster main queryNode stopped") + numExtraQN := len(cluster.querynodes) + for _, node := range cluster.querynodes { + node.Stop() + } + log.Info(fmt.Sprintf("mini cluster stopped %d extra querynodes", numExtraQN)) +} + +func (cluster *MiniClusterV2) StopAllDataNodes() { + cluster.DataNode.Stop() + log.Info("mini cluster main dataNode stopped") + numExtraDN := len(cluster.datanodes) + for _, node := range cluster.datanodes { + node.Stop() + } + log.Info(fmt.Sprintf("mini cluster stopped %d extra datanodes", numExtraDN)) +} + +func (cluster *MiniClusterV2) GetContext() context.Context { return cluster.ctx } diff --git a/tests/integration/querynode/querynode_test.go b/tests/integration/querynode/querynode_test.go new file mode 100644 index 0000000000..558b632182 --- /dev/null +++ b/tests/integration/querynode/querynode_test.go @@ -0,0 +1,305 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package querynode + +import ( + "context" + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/tests/integration" +) + +type QueryNodeSuite struct { + integration.MiniClusterSuite + maxGoRoutineNum int + dim int + numCollections int + rowsPerCollection int + waitTimeInSec time.Duration + prefix string +} + +func (s *QueryNodeSuite) setupParam() { + s.maxGoRoutineNum = 100 + s.dim = 128 + s.numCollections = 2 + s.rowsPerCollection = 100 + s.waitTimeInSec = time.Second * 10 +} + +func (s *QueryNodeSuite) loadCollection(collectionName string, dim int) { + c := s.Cluster + dbName := "" + schema := integration.ConstructSchema(collectionName, dim, true) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: common.DefaultShardsNum, + }) + s.NoError(err) + + err = merr.Error(createCollectionStatus) + s.NoError(err) + + showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.GetStatus())) + + batchSize := 500000 + for start := 0; start < s.rowsPerCollection; start += batchSize { + rowNum := batchSize + if start+batchSize > s.rowsPerCollection { + rowNum = s.rowsPerCollection - start + } + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) + hashKeys := integration.GenerateHashKeys(rowNum) + insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.GetStatus())) + } + log.Info("=========================Data insertion finished=========================") + + // flush + flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + + segments, err := c.MetaWatcher.ShowSegments() + s.NoError(err) + s.NotEmpty(segments) + s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName) + log.Info("=========================Data flush finished=========================") + + // create index + createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + s.NoError(err) + s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField) + log.Info("=========================Index 
created=========================") + + // load + loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + err = merr.Error(loadStatus) + s.NoError(err) + s.WaitForLoad(context.TODO(), collectionName) + log.Info("=========================Collection loaded=========================") +} + +func (s *QueryNodeSuite) checkCollections() bool { + req := &milvuspb.ShowCollectionsRequest{ + DbName: "", + TimeStamp: 0, // means now + } + resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req) + s.NoError(err) + s.Equal(len(resp.CollectionIds), s.numCollections) + notLoaded := 0 + loaded := 0 + for _, name := range resp.CollectionNames { + loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{ + DbName: "", + CollectionName: name, + }) + s.NoError(err) + if loadProgress.GetProgress() != int64(100) { + notLoaded++ + } else { + loaded++ + } + } + log.Info(fmt.Sprintf("loading status: %d/%d", loaded, len(resp.GetCollectionNames()))) + return notLoaded == 0 +} + +func (s *QueryNodeSuite) search(collectionName string, dim int) { + c := s.Cluster + var err error + // Query + queryReq := &milvuspb.QueryRequest{ + Base: nil, + CollectionName: collectionName, + PartitionNames: nil, + Expr: "", + OutputFields: []string{"count(*)"}, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + } + queryResult, err := c.Proxy.Query(context.TODO(), queryReq) + s.NoError(err) + s.Equal(len(queryResult.FieldsData), 1) + numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0] + s.Equal(numEntities, int64(s.rowsPerCollection)) + + // Search + expr := fmt.Sprintf("%s > 0", integration.Int64Field) + nq := 10 + topk := 10 + roundDecimal := -1 + radius := 10 + + params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP) + params["radius"] = radius + searchReq := integration.ConstructSearchRequest("", collectionName, expr, + integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal) + + searchResult, _ := c.Proxy.Search(context.TODO(), searchReq) + + err = merr.Error(searchResult.GetStatus()) + s.NoError(err) +} + +func (s *QueryNodeSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart, dim int, wg *sync.WaitGroup) { + for idx := 0; idx < collectionBatchSize; idx++ { + collectionName := prefix + "_" + strconv.Itoa(idxStart+idx) + s.loadCollection(collectionName, dim) + } + wg.Done() +} + +func (s *QueryNodeSuite) setupData() { + // Add the second query node + s.Cluster.AddQueryNode() + goRoutineNum := s.maxGoRoutineNum + if goRoutineNum > s.numCollections { + goRoutineNum = s.numCollections + } + collectionBatchSize := s.numCollections / goRoutineNum + log.Info(fmt.Sprintf("=========================test with s.dim=%d, s.rowsPerCollection=%d, s.numCollections=%d, goRoutineNum=%d==================", s.dim, s.rowsPerCollection, s.numCollections, goRoutineNum)) + log.Info("=========================Start to inject data=========================") + s.prefix = "TestQueryNodeUtil" + funcutil.GenRandomStr() + searchName := s.prefix + "_0" + wg := sync.WaitGroup{} + for idx := 0; idx < goRoutineNum; idx++ { + wg.Add(1) + go s.insertBatchCollections(s.prefix, collectionBatchSize, idx*collectionBatchSize, s.dim, &wg) + } + wg.Wait() + log.Info("=========================Data injection finished=========================") + s.checkCollections() + 
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName)) + s.search(searchName, s.dim) + log.Info("=========================Search finished=========================") + time.Sleep(s.waitTimeInSec) + s.checkCollections() + log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName)) + s.search(searchName, s.dim) + log.Info("=========================Search2 finished=========================") + s.checkAllCollectionsReady() +} + +func (s *QueryNodeSuite) checkAllCollectionsReady() { + goRoutineNum := s.maxGoRoutineNum + if goRoutineNum > s.numCollections { + goRoutineNum = s.numCollections + } + collectionBatchSize := s.numCollections / goRoutineNum + for i := 0; i < goRoutineNum; i++ { + for idx := 0; idx < collectionBatchSize; idx++ { + collectionName := s.prefix + "_" + strconv.Itoa(i*collectionBatchSize+idx) + s.search(collectionName, s.dim) + queryReq := &milvuspb.QueryRequest{ + CollectionName: collectionName, + Expr: "", + OutputFields: []string{"count(*)"}, + } + _, err := s.Cluster.Proxy.Query(context.TODO(), queryReq) + s.NoError(err) + } + } +} + +func (s *QueryNodeSuite) checkQNRestarts() { + // Stop all query nodes + s.Cluster.StopAllQueryNodes() + // Add new Query nodes. + s.Cluster.AddQueryNode() + s.Cluster.AddQueryNode() + + time.Sleep(s.waitTimeInSec) + for i := 0; i < 1000; i++ { + time.Sleep(s.waitTimeInSec) + if s.checkCollections() { + break + } + } + s.checkAllCollectionsReady() +} + +func (s *QueryNodeSuite) TestSwapQN() { + s.setupParam() + s.setupData() + // Test case with one query node stopped + s.Cluster.QueryNode.Stop() + time.Sleep(s.waitTimeInSec) + s.checkAllCollectionsReady() + // Test case with new Query nodes added + s.Cluster.AddQueryNode() + s.Cluster.AddQueryNode() + time.Sleep(s.waitTimeInSec) + s.checkAllCollectionsReady() + + // Test case with all query nodes replaced + for idx := 0; idx < 2; idx++ { + s.checkQNRestarts() + } +} + +func TestQueryNodeUtil(t *testing.T) { + suite.Run(t, new(QueryNodeSuite)) +}
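The thread running through this patch: each node component captures its own serverID at session init and passes it explicitly (broker constructors, SearchTask, merr.CheckTargetID) instead of reading the process-global paramtable.GetNodeID(). That is what lets the new MiniClusterV2 helpers (AddQueryNode/AddDataNode) host several nodes inside one process without them misreporting their identity. Below is a minimal, self-contained Go sketch of the pattern, not Milvus code: MsgBase, QueryNode, and the error text are simplified stand-ins for commonpb.MsgBase, the QueryNode struct, and merr.WrapErrNodeNotMatch.

package main

import "fmt"

// MsgBase is a simplified stand-in for commonpb.MsgBase.
type MsgBase struct{ TargetID int64 }

func (m *MsgBase) GetTargetID() int64 {
	if m == nil {
		return 0
	}
	return m.TargetID
}

// CheckTargetID mirrors the reworked merr.CheckTargetID: the caller supplies
// its own ID, so each node instance in a shared process validates requests
// against its own identity rather than a global that the last-started node
// overwrote via paramtable.SetNodeID.
func CheckTargetID(actualNodeID int64, msg *MsgBase) error {
	if msg.GetTargetID() != actualNodeID {
		return fmt.Errorf("node not match: expected=%d, actual=%d", msg.GetTargetID(), actualNodeID)
	}
	return nil
}

// QueryNode stands in for the component that now stores serverID at session init.
type QueryNode struct{ serverID int64 }

func (node *QueryNode) GetNodeID() int64 { return node.serverID }

func main() {
	// Two query nodes coexisting in one process, each with its own identity,
	// as the integration tests create via cluster.AddQueryNode().
	qn1 := &QueryNode{serverID: 10001}
	qn2 := &QueryNode{serverID: 10002}

	req := &MsgBase{TargetID: 10002}
	fmt.Println(CheckTargetID(qn1.GetNodeID(), req)) // node not match: expected=10002, actual=10001
	fmt.Println(CheckTargetID(qn2.GetNodeID(), req)) // <nil>
}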