Mirror of https://gitee.com/milvus-io/milvus.git, synced 2026-01-07 19:31:51 +08:00
fix: add load config watcher to avoid load config modification lost (#46784)
issue: #46778

Signed-off-by: chyezh <chyezh@outlook.com>
commit 56e82c78e1 (parent 5f2e430941)
internal/querycoordv2/load_config_watcher.go (new normal file, 157 lines)
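The gist of the fix: cluster-level load config changes are no longer applied inline from the paramtable event callbacks. The callbacks now only nudge a dedicated LoadConfigWatcher, whose background loop re-reads the latest ClusterLevelLoadReplicaNumber / ClusterLevelLoadResourceGroups values on every pass and retries failed applications with backoff, so a modification that arrives while a previous apply is failing or still in flight is no longer dropped. A minimal, self-contained sketch of that trigger-and-coalesce pattern (illustration only, not the Milvus code in the diff below; watcher, applyFn and the intervals are placeholder names and values):

package main

import (
    "fmt"
    "time"
)

// watcher coalesces config-change triggers: every pass of the loop re-reads
// whatever the latest configuration is, so bursts of events collapse into a
// single apply and a failed apply is retried instead of being forgotten.
type watcher struct {
    triggerCh chan struct{}
    applyFn   func() error // stand-in for something like applyLoadConfigChanges
}

// Trigger never blocks the config callback; if a trigger is already pending,
// the next loop pass will pick up the newest values anyway.
func (w *watcher) Trigger() {
    select {
    case w.triggerCh <- struct{}{}:
    default:
    }
}

func (w *watcher) run() {
    backoff := 10 * time.Millisecond
    for {
        select {
        case <-w.triggerCh: // a config change arrived
        case <-time.After(time.Minute): // periodic safety-net re-check
        }
        if err := w.applyFn(); err != nil {
            time.Sleep(backoff) // crude stand-in for a backoff timer
            if backoff < 10*time.Minute {
                backoff *= 2
            }
            continue
        }
        backoff = 10 * time.Millisecond // reset after a successful apply
    }
}

func main() {
    w := &watcher{
        triggerCh: make(chan struct{}, 1),
        applyFn:   func() error { fmt.Println("re-read and apply latest load config"); return nil },
    }
    go w.run()
    w.Trigger() // this is all a config event handler has to do
    time.Sleep(100 * time.Millisecond)
}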
@@ -0,0 +1,157 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package querycoordv2

import (
    "time"

    "github.com/samber/lo"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus/pkg/v2/log"
    "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
    "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
    "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

// NewLoadConfigWatcher creates a new load config watcher.
func NewLoadConfigWatcher(s *Server) *LoadConfigWatcher {
    w := &LoadConfigWatcher{
        triggerCh: make(chan struct{}, 10),
        notifier:  syncutil.NewAsyncTaskNotifier[struct{}](),
        s:         s,
    }
    w.SetLogger(log.With(log.FieldModule(typeutil.QueryCoordRole), log.FieldComponent("load_config_watcher")))
    go w.background()
    return w
}

// LoadConfigWatcher is a watcher for load config changes.
type LoadConfigWatcher struct {
    log.Binder
    triggerCh chan struct{}
    notifier  *syncutil.AsyncTaskNotifier[struct{}]
    s         *Server

    previousReplicaNum int32
    previousRGs        []string
}

// Trigger triggers a load config change.
func (w *LoadConfigWatcher) Trigger() {
    select {
    case <-w.notifier.Context().Done():
    case w.triggerCh <- struct{}{}:
    }
}

// background is the background task for load config watcher.
func (w *LoadConfigWatcher) background() {
    defer func() {
        w.notifier.Finish(struct{}{})
        w.Logger().Info("load config watcher stopped")
    }()
    w.Logger().Info("load config watcher started")

    balanceTimer := typeutil.NewBackoffTimer(typeutil.BackoffTimerConfig{
        Default: time.Minute,
        Backoff: typeutil.BackoffConfig{
            InitialInterval: 10 * time.Millisecond,
            Multiplier:      2,
            MaxInterval:     10 * time.Minute,
        },
    })

    for {
        nextTimer, _ := balanceTimer.NextTimer()
        select {
        case <-w.notifier.Context().Done():
            return
        case <-w.triggerCh:
            w.Logger().Info("load config watcher triggered")
        case <-nextTimer:
        }
        if err := w.applyLoadConfigChanges(); err != nil {
            balanceTimer.EnableBackoff()
            continue
        }
        balanceTimer.DisableBackoff()
    }
}

// applyLoadConfigChanges applies the load config changes.
func (w *LoadConfigWatcher) applyLoadConfigChanges() error {
    newReplicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsInt32()
    newRGs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings()

    if newReplicaNum == 0 && len(newRGs) == 0 {
        // default cluster level load config, nothing to do for it.
        return nil
    }

    if newReplicaNum <= 0 || len(newRGs) == 0 {
        w.Logger().Info("illegal cluster level load config, skip it", zap.Int32("replica_num", newReplicaNum), zap.Strings("resource_groups", newRGs))
        return nil
    }

    if len(newRGs) != 1 && len(newRGs) != int(newReplicaNum) {
        w.Logger().Info("illegal cluster level load config, skip it", zap.Int32("replica_num", newReplicaNum), zap.Strings("resource_groups", newRGs))
        return nil
    }

    left, right := lo.Difference(w.previousRGs, newRGs)
    rgChanged := len(left) > 0 || len(right) > 0
    if w.previousReplicaNum == newReplicaNum && !rgChanged {
        w.Logger().Info("no need to update load config, skip it", zap.Int32("replica_num", newReplicaNum), zap.Strings("resource_groups", newRGs))
        return nil
    }

    // try to check load config changes after restart, and try to update replicas
    collectionIDs := w.s.meta.GetAll(w.notifier.Context())
    collectionIDs = lo.Filter(collectionIDs, func(collectionID int64, _ int) bool {
        collection := w.s.meta.GetCollection(w.notifier.Context(), collectionID)
        if collection.UserSpecifiedReplicaMode {
            w.Logger().Info("collection is user specified replica mode, skip update load config", zap.Int64("collectionID", collectionID))
            return false
        }
        return true
    })

    if len(collectionIDs) == 0 {
        w.Logger().Info("no collection to update load config, skip it")
    }

    if err := w.s.updateLoadConfig(w.notifier.Context(), collectionIDs, newReplicaNum, newRGs); err != nil {
        w.Logger().Warn("failed to update load config", zap.Error(err))
        return err
    }
    w.Logger().Info("apply load config changes",
        zap.Int64s("collectionIDs", collectionIDs),
        zap.Int32("previousReplicaNum", w.previousReplicaNum),
        zap.Strings("previousResourceGroups", w.previousRGs),
        zap.Int32("replicaNum", newReplicaNum),
        zap.Strings("resourceGroups", newRGs))
    w.previousReplicaNum = newReplicaNum
    w.previousRGs = newRGs
    return nil
}

// Close closes the load config watcher.
func (w *LoadConfigWatcher) Close() {
    w.notifier.Cancel()
    w.notifier.BlockUntilFinish()
}
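Before doing anything, applyLoadConfigChanges filters out configurations it will not act on: the all-default case (0 replicas, no groups) is silently ignored, a non-positive replica number or empty group list is rejected as illegal, and the number of resource groups must be either 1 or exactly the replica number. A standalone restatement of those guards, purely for illustration (shouldApplyLoadConfig is a hypothetical helper, not part of the patch):

package main

import "fmt"

// shouldApplyLoadConfig mirrors the guard clauses at the top of applyLoadConfigChanges.
func shouldApplyLoadConfig(replicaNum int32, rgs []string) bool {
    if replicaNum == 0 && len(rgs) == 0 {
        return false // default cluster-level config: nothing to apply
    }
    if replicaNum <= 0 || len(rgs) == 0 {
        return false // illegal: logged and skipped by the watcher
    }
    if len(rgs) != 1 && len(rgs) != int(replicaNum) {
        return false // group count must be 1 or equal to the replica number
    }
    return true
}

func main() {
    fmt.Println(shouldApplyLoadConfig(2, []string{"default"}))        // true: one group shared by both replicas
    fmt.Println(shouldApplyLoadConfig(2, []string{"default", "rg1"})) // true: one group per replica
    fmt.Println(shouldApplyLoadConfig(3, []string{"rg1", "rg2"}))     // false: 2 groups cannot host 3 replicas
    fmt.Println(shouldApplyLoadConfig(0, nil))                        // false: defaults, nothing to do
}

The same 0/1/N rule shows up again in AssignReplica's error message further down, which is why an illegal combination is skipped here rather than forwarded.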
@@ -20,8 +20,6 @@ import (
    "context"
    "fmt"
    "os"
    "strconv"
    "strings"
    "sync"
    "syscall"
    "time"
@@ -137,6 +135,9 @@ type Server struct {
    // for balance streaming node request
    // now only used for run analyzer and validate analyzer
    nodeIdx atomic.Uint32

    // load config watcher
    loadConfigWatcher *LoadConfigWatcher
}

func NewQueryCoord(ctx context.Context) (*Server, error) {
@@ -552,6 +553,11 @@ func (s *Server) Stop() error {
    // job scheduler -> checker controller -> task scheduler -> dist controller -> cluster -> session
    // observers -> dist controller

    if s.loadConfigWatcher != nil {
        log.Info("stop load config watcher...")
        s.loadConfigWatcher.Close()
    }

    if s.jobScheduler != nil {
        log.Info("stop job scheduler...")
        s.jobScheduler.Stop()
@@ -877,70 +883,15 @@ func (s *Server) updateBalanceConfig() bool {
    return false
}

func (s *Server) applyLoadConfigChanges(ctx context.Context, newReplicaNum int32, newRGs []string) {
    if newReplicaNum <= 0 && len(newRGs) == 0 {
        log.Info("invalid cluster level load config, skip it", zap.Int32("replica_num", newReplicaNum), zap.Strings("resource_groups", newRGs))
        return
    }

    // try to check load config changes after restart, and try to update replicas
    collectionIDs := s.meta.GetAll(ctx)
    collectionIDs = lo.Filter(collectionIDs, func(collectionID int64, _ int) bool {
        collection := s.meta.GetCollection(ctx, collectionID)
        if collection.UserSpecifiedReplicaMode {
            log.Info("collection is user specified replica mode, skip update load config", zap.Int64("collectionID", collectionID))
            return false
        }
        return true
    })

    if len(collectionIDs) == 0 {
        log.Info("no collection to update load config, skip it")
        return
    }

    log.Info("apply load config changes",
        zap.Int64s("collectionIDs", collectionIDs),
        zap.Int32("replicaNum", newReplicaNum),
        zap.Strings("resourceGroups", newRGs))
    err := s.updateLoadConfig(ctx, collectionIDs, newReplicaNum, newRGs)
    if err != nil {
        log.Warn("failed to update load config", zap.Error(err))
    }
}

func (s *Server) watchLoadConfigChanges() {
    // first apply load config change from params
    replicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsUint32()
    rgs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings()
    s.applyLoadConfigChanges(s.ctx, int32(replicaNum), rgs)
    w := NewLoadConfigWatcher(s)
    s.loadConfigWatcher = w
    w.Trigger()

    log := log.Ctx(s.ctx)
    replicaNumHandler := config.NewHandler("watchReplicaNumberChanges", func(e *config.Event) {
        log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType))
        replicaNum, err := strconv.ParseInt(e.Value, 10, 64)
        if err != nil {
            log.Warn("invalid cluster level load config, skip it", zap.String("key", e.Key), zap.String("value", e.Value))
            return
        }
        rgs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings()

        s.applyLoadConfigChanges(s.ctx, int32(replicaNum), rgs)
    })
    replicaNumHandler := config.NewHandler("watchReplicaNumberChanges", func(e *config.Event) { w.Trigger() })
    paramtable.Get().Watch(paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.Key, replicaNumHandler)

    rgHandler := config.NewHandler("watchResourceGroupChanges", func(e *config.Event) {
        log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType))
        if len(e.Value) == 0 {
            log.Warn("invalid cluster level load config, skip it", zap.String("key", e.Key), zap.String("value", e.Value))
            return
        }

        rgs := strings.Split(e.Value, ",")
        rgs = lo.Map(rgs, func(rg string, _ int) string { return strings.TrimSpace(rg) })
        replicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsInt64()
        s.applyLoadConfigChanges(s.ctx, int32(replicaNum), rgs)
    })
    rgHandler := config.NewHandler("watchResourceGroupChanges", func(e *config.Event) { w.Trigger() })
    paramtable.Get().Watch(paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.Key, rgHandler)
}
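Design note on the hunk above: in the old version each config handler parsed e.Value itself (strconv.ParseInt for the replica number, strings.Split for the resource groups), combined it with the other value read from paramtable, and called s.applyLoadConfigChanges directly; a failing updateLoadConfig was only logged as a warning, so that modification was effectively lost until the next config event arrived. The new handlers reduce to w.Trigger(): the watcher re-reads both keys from paramtable on every pass, deduplicates unchanged values against previousReplicaNum/previousRGs, and keeps retrying failed applications under the backoff timer. The initial w.Trigger() right after NewLoadConfigWatcher(s) also replaces the old "first apply load config change from params" step at startup.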
@@ -58,6 +58,7 @@ import (
    "github.com/milvus-io/milvus/pkg/v2/util/merr"
    "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
    "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
    "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
    "github.com/milvus-io/milvus/pkg/v2/util/tikv"
)
@@ -420,10 +421,10 @@ func TestApplyLoadConfigChanges(t *testing.T) {
        }).Build()

        // Mock paramtable.ParamItem.GetAsUint32() for ClusterLevelLoadReplicaNumber
        mockey.Mock((*paramtable.ParamItem).GetAsUint32).Return(uint32(2)).Build()
        mockey.Mock((*paramtable.ParamItem).GetAsInt32).Return(int32(2)).Build()

        // Mock paramtable.ParamItem.GetAsStrings() for ClusterLevelLoadResourceGroups
        mockey.Mock((*paramtable.ParamItem).GetAsStrings).Return([]string{"default"}).Build()
        mockey.Mock((*paramtable.ParamItem).GetAsStrings).Return([]string{"default", "rg1"}).Build()

        // Mock UpdateLoadConfig to capture the call
        var updateLoadConfigCalled bool
@@ -438,10 +439,13 @@ func TestApplyLoadConfigChanges(t *testing.T) {
            return nil
        }).Build()

        replicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsUint32()
        rgs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings()
        // Call applyLoadConfigChanges
        testServer.applyLoadConfigChanges(ctx, int32(replicaNum), rgs)
        watcher := &LoadConfigWatcher{
            s:         testServer,
            notifier:  syncutil.NewAsyncTaskNotifier[struct{}](),
            triggerCh: make(chan struct{}, 10),
        }
        watcher.applyLoadConfigChanges()

        // Verify UpdateLoadConfig was called
        assert.True(t, updateLoadConfigCalled, "UpdateLoadConfig should be called")
@@ -449,7 +453,12 @@ func TestApplyLoadConfigChanges(t *testing.T) {
        // Verify that only collections with IsUserSpecifiedReplicaMode = false are included
        assert.Equal(t, []int64{1001}, capturedCollectionIDs, "Only collections with IsUserSpecifiedReplicaMode = false should be included")
        assert.Equal(t, int32(2), capturedReplicaNum, "ReplicaNumber should match cluster level config")
        assert.Equal(t, []string{"default"}, capturedRGs, "ResourceGroups should match cluster level config")
        assert.Equal(t, []string{"default", "rg1"}, capturedRGs, "ResourceGroups should match cluster level config")

        watcher = NewLoadConfigWatcher(testServer)
        watcher.Trigger()
        time.Sleep(1 * time.Second)
        watcher.Close()
    })
}
@@ -111,7 +111,7 @@ func AssignReplica(ctx context.Context, m *meta.Meta, resourceGroups []string, r
        return nil, merr.WrapErrParameterInvalidMsg("replica=[%d] resource group=[%s], resource group num can only be 0, 1 or same as replica number", replicaNumber, strings.Join(resourceGroups, ","))
    }

    if streamingutil.IsStreamingServiceEnabled() {
    if streamingutil.IsStreamingServiceEnabled() && checkNodeNum {
        streamingNodeCount := snmanager.StaticStreamingNodeManager.GetStreamingQueryNodeIDs().Len()
        if replicaNumber > int32(streamingNodeCount) {
            return nil, merr.WrapErrStreamingNodeNotEnough(streamingNodeCount, int(replicaNumber), fmt.Sprintf("when load %d replica count", replicaNumber))