From 66027790a2b453234079378033437d374b1c0bab Mon Sep 17 00:00:00 2001 From: congqixia Date: Sun, 29 Jan 2023 17:45:49 +0800 Subject: [PATCH] Implement detailed lifetime control for querynode (#21851) Signed-off-by: Congqi Xia --- internal/querynode/impl.go | 93 ++++++++-------------- internal/querynode/impl_test.go | 13 +-- internal/querynode/query_node.go | 13 ++- internal/util/lifetime/lifetime.go | 101 ++++++++++++++++++++++++ internal/util/lifetime/lifetime_test.go | 64 +++++++++++++++ 5 files changed, 209 insertions(+), 75 deletions(-) create mode 100644 internal/util/lifetime/lifetime.go create mode 100644 internal/util/lifetime/lifetime_test.go diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 8295d05df9..63b1806a74 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -47,13 +47,11 @@ import ( ) // isHealthy checks if QueryNode is healthy -func (node *QueryNode) isHealthy() bool { - code := node.stateCode.Load().(commonpb.StateCode) +func (node *QueryNode) isHealthy(code commonpb.StateCode) bool { return code == commonpb.StateCode_Healthy } -func (node *QueryNode) isHealthyOrStopping() bool { - code := node.stateCode.Load().(commonpb.StateCode) +func (node *QueryNode) isHealthyOrStopping(code commonpb.StateCode) bool { return code == commonpb.StateCode_Healthy || code == commonpb.StateCode_Stopping } @@ -64,15 +62,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon ErrorCode: commonpb.ErrorCode_Success, }, } - code, ok := node.stateCode.Load().(commonpb.StateCode) - if !ok { - errMsg := "unexpected error in type assertion" - stats.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: errMsg, - } - return stats, nil - } + code := node.lifetime.GetState() nodeID := common.NotRegisteredID if node.session != nil && node.session.Registered() { nodeID = node.GetSession().ServerID @@ -83,7 +73,7 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon StateCode: code, } stats.State = info - log.Debug("Get QueryNode component state done", zap.Any("stateCode", info.StateCode)) + log.Debug("Get QueryNode component state done", zap.String("stateCode", info.StateCode.String())) return stats, nil } @@ -171,12 +161,11 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que }, } - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { failRet.Status.Reason = msgQueryNodeIsUnhealthy(node.GetSession().ServerID) return failRet, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() traceID := trace.SpanFromContext(ctx).SpanContext().TraceID() log.Ctx(ctx).Debug("received GetStatisticRequest", @@ -301,7 +290,7 @@ func (node *QueryNode) getStatisticsWithDmlChannel(ctx context.Context, req *que func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { nodeID := node.GetSession().ServerID // check node healthy - if !node.isHealthy() { + if !node.lifetime.Add(node.isHealthy) { err := fmt.Errorf("query node %d is not ready", nodeID) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -309,8 +298,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC } return status, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() // check target matches if in.GetBase().GetTargetID() != nodeID { @@ -392,7 +380,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { // check node healthy nodeID := node.GetSession().ServerID - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { err := fmt.Errorf("query node %d is not ready", nodeID) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -400,8 +388,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC } return status, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() // check target matches if req.GetBase().GetTargetID() != nodeID { @@ -452,7 +439,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { nodeID := node.GetSession().ServerID // check node healthy - if !node.isHealthy() { + if !node.lifetime.Add(node.isHealthy) { err := fmt.Errorf("query node %d is not ready", nodeID) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -460,8 +447,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment } return status, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() // check target matches if in.GetBase().GetTargetID() != nodeID { @@ -538,7 +524,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment // ReleaseCollection clears all data related to this collection on the querynode func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -546,8 +532,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas } return status, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() dct := &releaseCollectionTask{ baseTask: baseTask{ @@ -586,7 +571,7 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas // ReleasePartitions clears all data related to this partition on the querynode func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -594,8 +579,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas } return status, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() dct := &releasePartitionsTask{ baseTask: baseTask{ @@ -635,7 +619,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas // ReleaseSegments remove the specified segments from query node according segmentIDs, partitionIDs, and collectionID func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { nodeID := node.GetSession().ServerID - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { err := fmt.Errorf("query node %d is not ready", nodeID) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -643,8 +627,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS } return status, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() // check target matches if in.GetBase().GetTargetID() != nodeID { @@ -684,7 +667,7 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS // GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ... func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID) res := &querypb.GetSegmentInfoResponse{ Status: &commonpb.Status{ @@ -694,8 +677,7 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen } return res, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() var segmentInfos []*querypb.SegmentInfo @@ -828,12 +810,11 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc() } }() - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID) return failRet, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() if node.queryShardService == nil { failRet.Status.Reason = "queryShardService is nil" @@ -979,12 +960,11 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel, metrics.FailLabel).Inc() } }() - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { failRet.Status.Reason = msgQueryNodeIsUnhealthy(nodeID) return failRet, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() log.Ctx(ctx).Debug("queryWithDmlChannel receives query request", zap.Bool("fromShardLeader", req.GetFromShardLeader()), @@ -1197,14 +1177,13 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i // SyncReplicaSegments syncs replica node & segments states func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { - if !node.isHealthy() { + if !node.lifetime.Add(node.isHealthy) { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: msgQueryNodeIsUnhealthy(node.GetSession().ServerID), }, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() log.Info("Received SyncReplicaSegments request", zap.String("vchannelName", req.GetVchannelName())) @@ -1225,7 +1204,7 @@ func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.Syn // ShowConfigurations returns the configurations of queryNode matching req.Pattern func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { nodeID := node.GetSession().ServerID - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { log.Warn("QueryNode.ShowConfigurations failed", zap.Int64("nodeId", nodeID), zap.String("req", req.Pattern), @@ -1239,8 +1218,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S Configuations: nil, }, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() configList := make([]*commonpb.KeyValuePair, 0) for key, value := range Params.GetComponentConfigurations("querynode", req.Pattern) { @@ -1263,7 +1241,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S // GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ... func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { nodeID := node.GetSession().ServerID - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { log.Ctx(ctx).Warn("QueryNode.GetMetrics failed", zap.Int64("nodeId", nodeID), zap.String("req", req.Request), @@ -1277,8 +1255,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR Response: "", }, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() metricType, err := metricsinfo.ParseMetricType(req.Request) if err != nil { @@ -1333,7 +1310,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get zap.Int64("msg-id", req.GetBase().GetMsgID()), zap.Int64("node-id", nodeID), ) - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { log.Warn("QueryNode.GetMetrics failed", zap.Error(errQueryNodeIsUnhealthy(nodeID))) @@ -1344,8 +1321,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get }, }, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() // check target matches if req.GetBase().GetTargetID() != nodeID { @@ -1426,7 +1402,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel())) nodeID := node.GetSession().ServerID // check node healthy - if !node.isHealthyOrStopping() { + if !node.lifetime.Add(node.isHealthyOrStopping) { err := fmt.Errorf("query node %d is not ready", nodeID) status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -1434,8 +1410,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi } return status, nil } - node.wg.Add(1) - defer node.wg.Done() + defer node.lifetime.Done() // check target matches if req.GetBase().GetTargetID() != nodeID { diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 483a3f302d..2fd211bb89 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -22,7 +22,6 @@ import ( "math/rand" "runtime" "sync" - "sync/atomic" "testing" "github.com/milvus-io/milvus-proto/go-api/commonpb" @@ -49,21 +48,18 @@ func TestImpl_GetComponentStates(t *testing.T) { assert.NoError(t, err) node.session.UpdateRegistered(true) + node.UpdateStateCode(commonpb.StateCode_Healthy) rsp, err := node.GetComponentStates(ctx) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + assert.Equal(t, commonpb.StateCode_Healthy, rsp.GetState().GetStateCode()) node.UpdateStateCode(commonpb.StateCode_Abnormal) rsp, err = node.GetComponentStates(ctx) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) - - node.stateCode = atomic.Value{} - node.stateCode.Store("invalid") - rsp, err = node.GetComponentStates(ctx) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode) + assert.Equal(t, commonpb.StateCode_Abnormal, rsp.GetState().GetStateCode()) } func TestImpl_GetTimeTickChannel(t *testing.T) { @@ -519,8 +515,7 @@ func TestImpl_isHealthy(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) - isHealthy := node.isHealthy() - assert.True(t, isHealthy) + assert.True(t, node.isHealthy(node.lifetime.GetState())) } func TestImpl_ShowConfigurations(t *testing.T) { diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 17dc91c7de..97763d402a 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -35,7 +35,6 @@ import ( "runtime" "runtime/debug" "sync" - "sync/atomic" "syscall" "time" "unsafe" @@ -50,6 +49,7 @@ import ( "github.com/milvus-io/milvus/internal/util/gc" "github.com/milvus-io/milvus/internal/util/hardware" "github.com/milvus-io/milvus/internal/util/initcore" + "github.com/milvus-io/milvus/internal/util/lifetime" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/sessionutil" @@ -83,10 +83,9 @@ type QueryNode struct { queryNodeLoopCtx context.Context queryNodeLoopCancel context.CancelFunc - wg sync.WaitGroup + lifetime lifetime.Lifetime[commonpb.StateCode] - stateCode atomic.Value - stopOnce sync.Once + stopOnce sync.Once //call once initOnce sync.Once @@ -143,11 +142,11 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) *QueryNode { queryNodeLoopCancel: cancel, factory: factory, IsStandAlone: os.Getenv(metricsinfo.DeployModeEnvKey) == metricsinfo.StandaloneDeployMode, + lifetime: lifetime.NewLifetime(commonpb.StateCode_Abnormal), } queryNode.tSafeReplica = newTSafeReplica() queryNode.scheduler = newTaskScheduler(ctx1, queryNode.tSafeReplica) - queryNode.UpdateStateCode(commonpb.StateCode_Abnormal) return queryNode } @@ -355,7 +354,7 @@ func (node *QueryNode) Stop() error { } node.UpdateStateCode(commonpb.StateCode_Abnormal) - node.wg.Wait() + node.lifetime.Wait() node.queryNodeLoopCancel() // close services @@ -383,7 +382,7 @@ func (node *QueryNode) Stop() error { // UpdateStateCode updata the state of query node, which can be initializing, healthy, and abnormal func (node *QueryNode) UpdateStateCode(code commonpb.StateCode) { - node.stateCode.Store(code) + node.lifetime.SetState(code) } // SetEtcdClient assigns parameter client to its member etcdCli diff --git a/internal/util/lifetime/lifetime.go b/internal/util/lifetime/lifetime.go new file mode 100644 index 0000000000..80f8db8e34 --- /dev/null +++ b/internal/util/lifetime/lifetime.go @@ -0,0 +1,101 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// package lifetime provides common component lifetime control logic. +package lifetime + +import ( + "sync" +) + +// Lifetime interface for lifetime control. +type Lifetime[T any] interface { + // SetState is the method to change lifetime state. + SetState(state T) + // GetState returns current state. + GetState() T + // Add records a task is running, returns false if the lifetime is not healthy. + Add(isHealthy IsHealthy[T]) bool + // Done records a task is done. + Done() + // Wait waits until all tasks are done. + Wait() +} + +// IsHealthy function type for lifetime healthy check. +type IsHealthy[T any] func(T) bool + +var _ Lifetime[any] = (*lifetime[any])(nil) + +// NewLifetime returns a new instance of Lifetime with init state and isHealthy logic. +func NewLifetime[T any](initState T) Lifetime[T] { + return &lifetime[T]{ + state: initState, + } +} + +// lifetime implementation of Lifetime. +// users shall not care about the internal fields of this struct. +type lifetime[T any] struct { + // wg is used for keeping record each running task. + wg sync.WaitGroup + // state is the "atomic" value to store component state. + state T + // mut is the rwmutex to control each task and state change event. + mut sync.RWMutex + // isHealthy is the method to check whether is legal to add a task. + isHealthy func(int32) bool +} + +// SetState is the method to change lifetime state. +func (l *lifetime[T]) SetState(state T) { + l.mut.Lock() + defer l.mut.Unlock() + + l.state = state +} + +// GetState returns current state. +func (l *lifetime[T]) GetState() T { + l.mut.RLock() + defer l.mut.RUnlock() + + return l.state +} + +// Add records a task is running, returns false if the lifetime is not healthy. +func (l *lifetime[T]) Add(isHealthy IsHealthy[T]) bool { + l.mut.RLock() + defer l.mut.RUnlock() + + // check lifetime healthy + if !isHealthy(l.state) { + return false + } + + l.wg.Add(1) + return true +} + +// Done records a task is done. +func (l *lifetime[T]) Done() { + l.wg.Done() +} + +// Wait waits until all tasks are done. +func (l *lifetime[T]) Wait() { + l.wg.Wait() +} diff --git a/internal/util/lifetime/lifetime_test.go b/internal/util/lifetime/lifetime_test.go new file mode 100644 index 0000000000..f964f56a98 --- /dev/null +++ b/internal/util/lifetime/lifetime_test.go @@ -0,0 +1,64 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lifetime + +import ( + "testing" + "time" + + "github.com/stretchr/testify/suite" +) + +type LifetimeSuite struct { + suite.Suite +} + +func (s *LifetimeSuite) TestNormal() { + l := NewLifetime[int32](0) + isHealthy := func(state int32) bool { return state == 0 } + + state := l.GetState() + s.EqualValues(0, state) + + s.True(l.Add(isHealthy)) + + l.SetState(1) + s.False(l.Add(isHealthy)) + + signal := make(chan struct{}) + go func() { + l.Wait() + close(signal) + }() + + select { + case <-signal: + s.FailNow("signal closed before all tasks done") + default: + } + + l.Done() + select { + case <-signal: + case <-time.After(time.Second): + s.FailNow("signal not closed after all tasks done") + } +} + +func TestLifetime(t *testing.T) { + suite.Run(t, new(LifetimeSuite)) +}