diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go index 3dd7188974..848f57dace 100644 --- a/internal/distributed/indexnode/service.go +++ b/internal/distributed/indexnode/service.go @@ -202,10 +202,10 @@ func (s *Server) Stop() error { return err } } - s.loopCancel() if s.indexnode != nil { s.indexnode.Stop() } + s.loopCancel() if s.etcdCli != nil { defer s.etcdCli.Close() } diff --git a/internal/indexcoord/index_coord.go b/internal/indexcoord/index_coord.go index 8cffdb30d6..a9970ae85c 100644 --- a/internal/indexcoord/index_coord.go +++ b/internal/indexcoord/index_coord.go @@ -1084,6 +1084,10 @@ func (i *IndexCoord) watchNodeLoop() { } }() i.metricsCacheManager.InvalidateSystemInfoMetrics() + case sessionutil.SessionUpdateEvent: + serverID := event.Session.ServerID + log.Info("IndexCoord watchNodeLoop SessionUpdateEvent", zap.Int64("serverID", serverID)) + i.nodeManager.StoppingNode(serverID) case sessionutil.SessionDelEvent: serverID := event.Session.ServerID log.Info("IndexCoord watchNodeLoop SessionDelEvent", zap.Int64("serverID", serverID)) diff --git a/internal/indexcoord/index_coord_test.go b/internal/indexcoord/index_coord_test.go index a76130adae..44eb8fcf54 100644 --- a/internal/indexcoord/index_coord_test.go +++ b/internal/indexcoord/index_coord_test.go @@ -28,7 +28,6 @@ import ( "time" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/common" @@ -44,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" ) @@ -517,6 +517,81 @@ func testIndexCoord(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) + t.Run("WatchNodeState", func(t *testing.T) { + allClients := ic.nodeManager.GetAllClients() + nodeSession := sessionutil.NewSession(context.Background(), Params.EtcdCfg.MetaRootPath, etcdCli) + originNodeID := Params.IndexCoordCfg.GetNodeID() + defer func() { + Params.IndexCoordCfg.SetNodeID(originNodeID) + }() + Params.IndexCoordCfg.SetNodeID(100) + nodeSession.Init(typeutil.IndexNodeRole, "127.0.0.1:11111", false, true) + nodeSession.Register() + + addNodeChan := make(chan struct{}) + go func() { + for { + time.Sleep(200 * time.Millisecond) + if len(ic.nodeManager.GetAllClients()) > len(allClients) { + close(addNodeChan) + break + } + } + }() + select { + case <-addNodeChan: + case <-time.After(10 * time.Second): + assert.Fail(t, "fail to add node") + } + var newNodeID UniqueID = -1 + for id := range ic.nodeManager.GetAllClients() { + if _, ok := allClients[id]; !ok { + newNodeID = id + break + } + } + + nodeSession.GoingStop() + stoppingNodeChan := make(chan struct{}) + go func() { + for { + time.Sleep(200 * time.Millisecond) + ic.nodeManager.lock.RLock() + _, ok := ic.nodeManager.stoppingNodes[newNodeID] + ic.nodeManager.lock.RUnlock() + if ok { + close(stoppingNodeChan) + break + } + } + }() + select { + case <-stoppingNodeChan: + case <-time.After(10 * time.Second): + assert.Fail(t, "fail to stop node") + } + + nodeSession.Revoke(time.Second) + deleteNodeChan := make(chan struct{}) + go func() { + for { + time.Sleep(200 * time.Millisecond) + ic.nodeManager.lock.RLock() + _, ok := ic.nodeManager.stoppingNodes[newNodeID] + ic.nodeManager.lock.RUnlock() + if !ok { + close(deleteNodeChan) + break + } + } + }() + select { + case <-deleteNodeChan: + case <-time.After(10 * time.Second): + assert.Fail(t, "fail to stop node") + } + }) + // Stop IndexCoord err = ic.Stop() assert.NoError(t, err) diff --git a/internal/indexcoord/node_manager.go b/internal/indexcoord/node_manager.go index fbc2c358e8..0942edfc63 100644 --- a/internal/indexcoord/node_manager.go +++ b/internal/indexcoord/node_manager.go @@ -34,16 +34,18 @@ import ( // NodeManager is used by IndexCoord to manage the client of IndexNode. type NodeManager struct { - nodeClients map[UniqueID]types.IndexNode - pq *PriorityQueue - lock sync.RWMutex - ctx context.Context + nodeClients map[UniqueID]types.IndexNode + stoppingNodes map[UniqueID]struct{} + pq *PriorityQueue + lock sync.RWMutex + ctx context.Context } // NewNodeManager is used to create a new NodeManager. func NewNodeManager(ctx context.Context) *NodeManager { return &NodeManager{ - nodeClients: make(map[UniqueID]types.IndexNode), + nodeClients: make(map[UniqueID]types.IndexNode), + stoppingNodes: make(map[UniqueID]struct{}), pq: &PriorityQueue{ policy: PeekClientV1, }, @@ -73,11 +75,19 @@ func (nm *NodeManager) RemoveNode(nodeID UniqueID) { log.Info("IndexCoord", zap.Any("Remove node with ID", nodeID)) nm.lock.Lock() delete(nm.nodeClients, nodeID) + delete(nm.stoppingNodes, nodeID) nm.lock.Unlock() nm.pq.Remove(nodeID) metrics.IndexCoordIndexNodeNum.WithLabelValues().Dec() } +func (nm *NodeManager) StoppingNode(nodeID UniqueID) { + log.Info("IndexCoord", zap.Any("Stopping node with ID", nodeID)) + nm.lock.Lock() + defer nm.lock.Unlock() + nm.stoppingNodes[nodeID] = struct{}{} +} + // AddNode adds the client of IndexNode. func (nm *NodeManager) AddNode(nodeID UniqueID, address string) error { @@ -224,7 +234,9 @@ func (nm *NodeManager) GetAllClients() map[UniqueID]types.IndexNode { allClients := make(map[UniqueID]types.IndexNode, len(nm.nodeClients)) for nodeID, client := range nm.nodeClients { - allClients[nodeID] = client + if _, ok := nm.stoppingNodes[nodeID]; !ok { + allClients[nodeID] = client + } } return allClients diff --git a/internal/indexcoord/node_manager_test.go b/internal/indexcoord/node_manager_test.go index d7868b9cc6..790cea3cb7 100644 --- a/internal/indexcoord/node_manager_test.go +++ b/internal/indexcoord/node_manager_test.go @@ -249,3 +249,18 @@ func TestNodeManager_ClientSupportDisk(t *testing.T) { assert.False(t, support) }) } + +func TestNodeManager_StoppingNode(t *testing.T) { + nm := NewNodeManager(context.Background()) + err := nm.AddNode(1, "indexnode-1") + assert.NoError(t, err) + assert.Equal(t, 1, len(nm.GetAllClients())) + + nm.StoppingNode(1) + assert.Equal(t, 0, len(nm.GetAllClients())) + assert.Equal(t, 1, len(nm.stoppingNodes)) + + nm.RemoveNode(1) + assert.Equal(t, 0, len(nm.GetAllClients())) + assert.Equal(t, 0, len(nm.stoppingNodes)) +} diff --git a/internal/indexnode/indexnode.go b/internal/indexnode/indexnode.go index 1ed2541cad..c715082593 100644 --- a/internal/indexnode/indexnode.go +++ b/internal/indexnode/indexnode.go @@ -39,15 +39,13 @@ import ( "time" "unsafe" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/commonpbutil" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/hardware" "github.com/milvus-io/milvus/internal/util/initcore" @@ -55,6 +53,8 @@ import ( "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" ) // TODO add comments @@ -84,7 +84,8 @@ type IndexNode struct { sched *TaskScheduler - once sync.Once + once sync.Once + stopOnce sync.Once factory dependency.Factory storageFactory StorageFactory @@ -229,22 +230,34 @@ func (i *IndexNode) Start() error { // Stop closes the server. func (i *IndexNode) Stop() error { - // https://github.com/milvus-io/milvus/issues/12282 - i.UpdateStateCode(commonpb.StateCode_Abnormal) - // cleanup all running tasks - deletedTasks := i.deleteAllTasks() - for _, task := range deletedTasks { - if task.cancel != nil { - task.cancel() + i.stopOnce.Do(func() { + i.UpdateStateCode(commonpb.StateCode_Stopping) + log.Info("Index node stopping") + err := i.session.GoingStop() + if err != nil { + log.Warn("session fail to go stopping state", zap.Error(err)) + } else { + i.waitTaskFinish() } - } - i.loopCancel() - if i.sched != nil { - i.sched.Close() - } - i.session.Revoke(time.Second) - log.Debug("Index node stopped.") + // https://github.com/milvus-io/milvus/issues/12282 + i.UpdateStateCode(commonpb.StateCode_Abnormal) + log.Info("Index node abnormal") + // cleanup all running tasks + deletedTasks := i.deleteAllTasks() + for _, task := range deletedTasks { + if task.cancel != nil { + task.cancel() + } + } + i.loopCancel() + if i.sched != nil { + i.sched.Close() + } + i.session.Revoke(time.Second) + + log.Info("Index node stopped.") + }) return nil } @@ -258,86 +271,6 @@ func (i *IndexNode) SetEtcdClient(client *clientv3.Client) { i.etcdCli = client } -func (i *IndexNode) isHealthy() bool { - code := i.stateCode.Load().(commonpb.StateCode) - return code == commonpb.StateCode_Healthy -} - -//// BuildIndex receives request from IndexCoordinator to build an index. -//// Index building is asynchronous, so when an index building request comes, IndexNode records the task and returns. -//func (i *IndexNode) BuildIndex(ctx context.Context, request *indexpb.BuildIndexRequest) (*commonpb.Status, error) { -// if i.stateCode.Load().(commonpb.StateCode) != commonpb.StateCode_Healthy { -// return &commonpb.Status{ -// ErrorCode: commonpb.ErrorCode_UnexpectedError, -// Reason: "state code is not healthy", -// }, nil -// } -// log.Info("IndexNode building index ...", -// zap.Int64("clusterID", request.ClusterID), -// zap.Int64("IndexBuildID", request.IndexBuildID), -// zap.Int64("Version", request.IndexVersion), -// zap.Int("binlog paths num", len(request.DataPaths)), -// zap.Any("TypeParams", request.TypeParams), -// zap.Any("IndexParams", request.IndexParams)) -// -// sp, ctx2 := trace.StartSpanFromContextWithOperationName(i.loopCtx, "IndexNode-CreateIndex") -// defer sp.Finish() -// sp.SetTag("IndexBuildID", strconv.FormatInt(request.IndexBuildID, 10)) -// metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(Params.IndexNodeCfg.GetNodeID(), 10), metrics.TotalLabel).Inc() -// -// t := &IndexBuildTask{ -// BaseTask: BaseTask{ -// ctx: ctx2, -// done: make(chan error), -// }, -// req: request, -// cm: i.chunkManager, -// etcdKV: i.etcdKV, -// nodeID: Params.IndexNodeCfg.GetNodeID(), -// serializedSize: 0, -// } -// -// ret := &commonpb.Status{ -// ErrorCode: commonpb.ErrorCode_Success, -// } -// -// err := i.sched.IndexBuildQueue.Enqueue(t) -// if err != nil { -// log.Warn("IndexNode failed to schedule", zap.Int64("indexBuildID", request.IndexBuildID), zap.Error(err)) -// ret.ErrorCode = commonpb.ErrorCode_UnexpectedError -// ret.Reason = err.Error() -// metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(Params.IndexNodeCfg.GetNodeID(), 10), metrics.FailLabel).Inc() -// return ret, nil -// } -// log.Info("IndexNode successfully scheduled", zap.Int64("indexBuildID", request.IndexBuildID)) -// -// metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(Params.IndexNodeCfg.GetNodeID(), 10), metrics.SuccessLabel).Inc() -// return ret, nil -//} -// -//// GetTaskSlots gets how many task the IndexNode can still perform. -//func (i *IndexNode) GetTaskSlots(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { -// if i.stateCode.Load().(commonpb.StateCode) != commonpb.StateCode_Healthy { -// return &indexpb.GetTaskSlotsResponse{ -// Status: &commonpb.Status{ -// ErrorCode: commonpb.ErrorCode_UnexpectedError, -// Reason: "state code is not healthy", -// }, -// }, nil -// } -// -// log.Info("IndexNode GetTaskSlots received") -// ret := &indexpb.GetTaskSlotsResponse{ -// Status: &commonpb.Status{ -// ErrorCode: commonpb.ErrorCode_Success, -// }, -// } -// -// ret.Slots = int64(i.sched.GetTaskSlots()) -// log.Info("IndexNode GetTaskSlots success", zap.Int64("slots", ret.Slots)) -// return ret, nil -//} - // GetComponentStates gets the component states of IndexNode. func (i *IndexNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { log.Debug("get IndexNode components states ...") @@ -394,7 +327,7 @@ func (i *IndexNode) GetNodeID() int64 { // ShowConfigurations returns the configurations of indexNode matching req.Pattern func (i *IndexNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - if !i.isHealthy() { + if !commonpbutil.IsHealthyOrStopping(i.stateCode) { log.Warn("IndexNode.ShowConfigurations failed", zap.Int64("nodeId", Params.IndexNodeCfg.GetNodeID()), zap.String("req", req.Pattern), diff --git a/internal/indexnode/indexnode_service.go b/internal/indexnode/indexnode_service.go index b3f53b0b75..18b703bd65 100644 --- a/internal/indexnode/indexnode_service.go +++ b/internal/indexnode/indexnode_service.go @@ -1,3 +1,19 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package indexnode import ( @@ -6,22 +22,22 @@ import ( "strconv" "github.com/golang/protobuf/proto" - "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/util/commonpbutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/trace" + "go.uber.org/zap" ) func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { - stateCode := i.stateCode.Load().(commonpb.StateCode) - if stateCode != commonpb.StateCode_Healthy { + if !commonpbutil.IsHealthy(i.stateCode) { + stateCode := i.stateCode.Load().(commonpb.StateCode) log.Ctx(ctx).Warn("index node not ready", zap.Int32("state", int32(stateCode)), zap.String("ClusterID", req.ClusterID), zap.Int64("IndexBuildID", req.BuildID)) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -94,8 +110,8 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest } func (i *IndexNode) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { - stateCode := i.stateCode.Load().(commonpb.StateCode) - if stateCode != commonpb.StateCode_Healthy { + if !commonpbutil.IsHealthyOrStopping(i.stateCode) { + stateCode := i.stateCode.Load().(commonpb.StateCode) log.Ctx(ctx).Warn("index node not ready", zap.Int32("state", int32(stateCode)), zap.String("ClusterID", req.ClusterID)) return &indexpb.QueryJobsResponse{ Status: &commonpb.Status{ @@ -144,9 +160,9 @@ func (i *IndexNode) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest } func (i *IndexNode) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) (*commonpb.Status, error) { - log.Ctx(ctx).Debug("drop index build jobs", zap.String("ClusterID", req.ClusterID), zap.Int64s("IndexBuildIDs", req.BuildIDs)) - stateCode := i.stateCode.Load().(commonpb.StateCode) - if stateCode != commonpb.StateCode_Healthy { + log.Ctx(ctx).Info("drop index build jobs", zap.String("ClusterID", req.ClusterID), zap.Int64s("IndexBuildIDs", req.BuildIDs)) + if !commonpbutil.IsHealthyOrStopping(i.stateCode) { + stateCode := i.stateCode.Load().(commonpb.StateCode) log.Ctx(ctx).Warn("index node not ready", zap.Int32("state", int32(stateCode)), zap.String("ClusterID", req.ClusterID)) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -172,8 +188,8 @@ func (i *IndexNode) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) } func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - stateCode := i.stateCode.Load().(commonpb.StateCode) - if stateCode != commonpb.StateCode_Healthy { + if !commonpbutil.IsHealthyOrStopping(i.stateCode) { + stateCode := i.stateCode.Load().(commonpb.StateCode) log.Ctx(ctx).Warn("index node not ready", zap.Int32("state", int32(stateCode))) return &indexpb.GetJobStatsResponse{ Status: &commonpb.Status{ @@ -211,7 +227,7 @@ func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsReq // GetMetrics gets the metrics info of IndexNode. // TODO(dragondriver): cache the Metrics and set a retention to the cache func (i *IndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - if !i.isHealthy() { + if !commonpbutil.IsHealthyOrStopping(i.stateCode) { log.Ctx(ctx).Warn("IndexNode.GetMetrics failed", zap.Int64("nodeID", i.GetNodeID()), zap.String("req", req.Request), diff --git a/internal/indexnode/indexnode_service_test.go b/internal/indexnode/indexnode_service_test.go index 781e597ee9..5f544b9568 100644 --- a/internal/indexnode/indexnode_service_test.go +++ b/internal/indexnode/indexnode_service_test.go @@ -8,14 +8,13 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/internal/util/metautil" - - "github.com/stretchr/testify/assert" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/util/metautil" "github.com/milvus-io/milvus/internal/util/metricsinfo" + "github.com/stretchr/testify/assert" ) func genStorageConfig() *indexpb.StorageConfig { @@ -358,6 +357,10 @@ func TestAbnormalIndexNode(t *testing.T) { metricsResp, err := in.GetMetrics(ctx, &milvuspb.GetMetricsRequest{}) assert.Nil(t, err) assert.Equal(t, metricsResp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + + configurationResp, err := in.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{}) + assert.Nil(t, err) + assert.Equal(t, configurationResp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) } func TestGetMetrics(t *testing.T) { diff --git a/internal/indexnode/indexnode_test.go b/internal/indexnode/indexnode_test.go index 08e2852f71..927a2e9b5b 100644 --- a/internal/indexnode/indexnode_test.go +++ b/internal/indexnode/indexnode_test.go @@ -20,6 +20,7 @@ import ( "context" "os" "testing" + "time" "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/stretchr/testify/assert" @@ -480,6 +481,7 @@ func TestComponentState(t *testing.T) { assert.Equal(t, state.Status.ErrorCode, commonpb.ErrorCode_Success) assert.Equal(t, state.State.StateCode, commonpb.StateCode_Healthy) + assert.Nil(t, in.Stop()) assert.Nil(t, in.Stop()) state, err = in.GetComponentStates(ctx) assert.Nil(t, err) @@ -518,6 +520,41 @@ func TestGetStatisticChannel(t *testing.T) { assert.Equal(t, ret.Status.ErrorCode, commonpb.ErrorCode_Success) } +func TestIndexTaskWhenStoppingNode(t *testing.T) { + var ( + factory = &mockFactory{ + chunkMgr: &mockChunkmgr{}, + } + ctx = context.TODO() + ) + Params.Init() + in, err := NewIndexNode(ctx, factory) + assert.Nil(t, err) + + in.loadOrStoreTask("cluster-1", 1, &taskInfo{ + state: commonpb.IndexState_InProgress, + }) + in.loadOrStoreTask("cluster-2", 2, &taskInfo{ + state: commonpb.IndexState_Finished, + }) + + assert.True(t, in.hasInProgressTask()) + go func() { + time.Sleep(2 * time.Second) + in.storeTaskState("cluster-1", 1, commonpb.IndexState_Finished, "") + }() + noTaskChan := make(chan struct{}) + go func() { + in.waitTaskFinish() + close(noTaskChan) + }() + select { + case <-noTaskChan: + case <-time.After(5 * time.Second): + assert.Fail(t, "timeout task chan") + } +} + func TestInitErr(t *testing.T) { // var ( // factory = &mockFactory{} diff --git a/internal/indexnode/taskinfo_ops.go b/internal/indexnode/taskinfo_ops.go index 69b1745fcd..e442ae879c 100644 --- a/internal/indexnode/taskinfo_ops.go +++ b/internal/indexnode/taskinfo_ops.go @@ -1,13 +1,14 @@ package indexnode import ( - "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/internal/common" - "go.uber.org/zap" + "time" + "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/indexpb" + "go.uber.org/zap" ) func (i *IndexNode) loadOrStoreTask(ClusterID string, buildID UniqueID, info *taskInfo) *taskInfo { @@ -91,3 +92,40 @@ func (i *IndexNode) deleteAllTasks() []*taskInfo { } return deleted } + +func (i *IndexNode) hasInProgressTask() bool { + i.stateLock.Lock() + defer i.stateLock.Unlock() + for _, info := range i.tasks { + if info.state == commonpb.IndexState_InProgress { + return true + } + } + return false +} + +func (i *IndexNode) waitTaskFinish() { + if !i.hasInProgressTask() { + return + } + + gracefulTimeout := Params.IndexNodeCfg.GracefulStopTimeout + timer := time.NewTimer(time.Duration(gracefulTimeout) * time.Second) + + for { + select { + case <-time.Tick(time.Second): + if !i.hasInProgressTask() { + return + } + case <-timer.C: + log.Warn("timeout, the index node has some progress task") + for _, info := range i.tasks { + if info.state == commonpb.IndexState_InProgress { + log.Warn("progress task", zap.Any("info", info)) + } + } + return + } + } +} diff --git a/internal/util/commonpbutil/commonpbutil.go b/internal/util/commonpbutil/commonpbutil.go index edaf9f0555..3c02136148 100644 --- a/internal/util/commonpbutil/commonpbutil.go +++ b/internal/util/commonpbutil/commonpbutil.go @@ -17,6 +17,7 @@ package commonpbutil import ( + "sync/atomic" "time" "github.com/milvus-io/milvus-proto/go-api/commonpb" @@ -100,3 +101,19 @@ func UpdateMsgBase(msgBase *commonpb.MsgBase, options ...MsgBaseOptions) *common } return msgBaseRt } + +func IsHealthy(stateCode atomic.Value) bool { + code, ok := stateCode.Load().(commonpb.StateCode) + if !ok { + return false + } + return code == commonpb.StateCode_Healthy +} + +func IsHealthyOrStopping(stateCode atomic.Value) bool { + code, ok := stateCode.Load().(commonpb.StateCode) + if !ok { + return false + } + return code == commonpb.StateCode_Healthy || code == commonpb.StateCode_Stopping +} diff --git a/internal/util/commonpbutil/commonpbutil_test.go b/internal/util/commonpbutil/commonpbutil_test.go new file mode 100644 index 0000000000..5a73cc5b1a --- /dev/null +++ b/internal/util/commonpbutil/commonpbutil_test.go @@ -0,0 +1,79 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package commonpbutil + +import ( + "sync/atomic" + "testing" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/stretchr/testify/assert" +) + +func TestIsHealthy(t *testing.T) { + { + v := atomic.Value{} + v.Store(1) + assert.False(t, IsHealthy(v)) + } + + { + v := atomic.Value{} + v.Store(commonpb.StateCode_Abnormal) + assert.False(t, IsHealthy(v)) + } + + { + v := atomic.Value{} + v.Store(commonpb.StateCode_Stopping) + assert.False(t, IsHealthy(v)) + } + + { + v := atomic.Value{} + v.Store(commonpb.StateCode_Healthy) + assert.True(t, IsHealthy(v)) + } +} + +func TestIsHealthyOrStopping(t *testing.T) { + { + v := atomic.Value{} + v.Store(1) + assert.False(t, IsHealthyOrStopping(v)) + } + + { + v := atomic.Value{} + v.Store(commonpb.StateCode_Abnormal) + assert.False(t, IsHealthyOrStopping(v)) + } + + { + v := atomic.Value{} + v.Store(commonpb.StateCode_Stopping) + assert.True(t, IsHealthyOrStopping(v)) + } + + { + v := atomic.Value{} + v.Store(commonpb.StateCode_Healthy) + assert.True(t, IsHealthyOrStopping(v)) + } +} diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go index f2c1c924f0..c82cce539c 100644 --- a/internal/util/paramtable/component_param.go +++ b/internal/util/paramtable/component_param.go @@ -1661,6 +1661,8 @@ type indexNodeConfig struct { EnableDisk bool DiskCapacityLimit int64 MaxDiskUsagePercentage float64 + + GracefulStopTimeout int64 } func (p *indexNodeConfig) init(base *BaseTable) { @@ -1670,6 +1672,7 @@ func (p *indexNodeConfig) init(base *BaseTable) { p.initEnableDisk() p.initDiskCapacity() p.initMaxDiskUsagePercentage() + p.initGracefulStopTimeout() } // InitAlias initializes an alias for the IndexNode role. @@ -1729,3 +1732,13 @@ func (p *indexNodeConfig) initMaxDiskUsagePercentage() { } p.MaxDiskUsagePercentage = float64(maxDiskUsagePercentage) / 100 } + +func (p *indexNodeConfig) initGracefulStopTimeout() { + timeout := p.Base.LoadWithDefault2([]string{"indexNode.gracefulStopTimeout", "common.gracefulStopTimeout"}, + strconv.FormatInt(DefaultGracefulStopTimeout, 10)) + var err error + p.GracefulStopTimeout, err = strconv.ParseInt(timeout, 10, 64) + if err != nil { + panic(err) + } +} diff --git a/internal/util/paramtable/component_param_test.go b/internal/util/paramtable/component_param_test.go index 7f911a041d..17de238b8a 100644 --- a/internal/util/paramtable/component_param_test.go +++ b/internal/util/paramtable/component_param_test.go @@ -70,12 +70,15 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, Params.GracefulStopTimeout, int64(DefaultGracefulStopTimeout)) assert.Equal(t, CParams.QueryNodeCfg.GracefulStopTimeout, Params.GracefulStopTimeout) + assert.Equal(t, CParams.IndexNodeCfg.GracefulStopTimeout, Params.GracefulStopTimeout) t.Logf("default grafeful stop timeout = %d", Params.GracefulStopTimeout) Params.Base.Save("common.gracefulStopTimeout", "50") Params.initGracefulStopTimeout() assert.Equal(t, Params.GracefulStopTimeout, int64(50)) CParams.QueryNodeCfg.initGracefulStopTimeout() assert.Equal(t, CParams.QueryNodeCfg.GracefulStopTimeout, int64(50)) + CParams.IndexNodeCfg.initGracefulStopTimeout() + assert.Equal(t, CParams.IndexNodeCfg.GracefulStopTimeout, int64(50)) // -- proxy -- assert.Equal(t, Params.ProxySubName, "by-dev-proxy") @@ -379,5 +382,9 @@ func TestComponentParam(t *testing.T) { Params.UpdatedTime = time.Now() t.Logf("UpdatedTime: %v", Params.UpdatedTime) + + Params.Base.Save("indexNode.gracefulStopTimeout", "100") + Params.initGracefulStopTimeout() + assert.Equal(t, Params.GracefulStopTimeout, int64(100)) }) }