diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 9b17263a6b..a41f63cfcb 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -59,7 +59,6 @@ import ( "github.com/milvus-io/milvus/pkg/v2/kv" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" - "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/expr" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" @@ -497,11 +496,6 @@ func (s *Server) startQueryCoord() error { go s.handleNodeUpLoop() go s.watchNodes(revision) - // check replica changes after restart - s.checkLoadConfigChanges(s.ctx) - // watch load config changes - s.watchLoadConfigChanges() - // check whether old node exist, if yes suspend auto balance until all old nodes down s.updateBalanceConfigLoop(s.ctx) @@ -513,6 +507,9 @@ func (s *Server) startQueryCoord() error { s.afterStart() s.UpdateStateCode(commonpb.StateCode_Healthy) sessionutil.SaveServerInfo(typeutil.MixCoordRole, s.session.GetServerID()) + // check replica changes after restart + // Note: this should be called after start progress is done + s.watchLoadConfigChanges() return nil } @@ -843,7 +840,12 @@ func (s *Server) updateBalanceConfig() bool { return false } -func (s *Server) checkLoadConfigChanges(ctx context.Context) { +func (s *Server) applyLoadConfigChanges(ctx context.Context, newReplicaNum int32, newRGs []string) { + if newReplicaNum <= 0 && len(newRGs) == 0 { + log.Info("invalid cluster level load config, skip it", zap.Int32("replica_num", newReplicaNum), zap.Strings("resource_groups", newRGs)) + return + } + // try to check load config changes after restart, and try to update replicas collectionIDs := s.meta.GetAll(ctx) collectionIDs = lo.Filter(collectionIDs, func(collectionID int64, _ int) bool { @@ -854,73 +856,44 @@ func (s *Server) checkLoadConfigChanges(ctx context.Context) { } return true }) - replicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsUint32() - rgs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings() + + if len(collectionIDs) == 0 { + log.Info("no collection to update load config, skip it") + return + } + log.Info("apply load config changes", zap.Int64s("collectionIDs", collectionIDs), - zap.Int32("replicaNum", int32(replicaNum)), - zap.Strings("resourceGroups", rgs)) - s.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{ - CollectionIDs: collectionIDs, - ReplicaNumber: int32(replicaNum), - ResourceGroups: rgs, - }) + zap.Int32("replicaNum", newReplicaNum), + zap.Strings("resourceGroups", newRGs)) + err := s.updateLoadConfig(ctx, collectionIDs, newReplicaNum, newRGs) + if err != nil { + log.Warn("failed to update load config", zap.Error(err)) + } } func (s *Server) watchLoadConfigChanges() { + // first apply load config change from params + replicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsUint32() + rgs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings() + s.applyLoadConfigChanges(s.ctx, int32(replicaNum), rgs) + log := log.Ctx(s.ctx) replicaNumHandler := config.NewHandler("watchReplicaNumberChanges", func(e *config.Event) { log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType)) - - collectionIDs := s.meta.GetAll(s.ctx) - if len(collectionIDs) == 0 { - log.Warn("no collection loaded, skip to trigger update load config") - return - } - collectionIDs = lo.Filter(collectionIDs, func(collectionID int64, _ int) bool { - collection := s.meta.GetCollection(s.ctx, collectionID) - if collection.UserSpecifiedReplicaMode { - log.Info("collection is user specified replica mode, skip update load config", zap.Int64("collectionID", collectionID)) - return false - } - return true - }) - replicaNum, err := strconv.ParseInt(e.Value, 10, 64) if err != nil { log.Warn("invalid cluster level load config, skip it", zap.String("key", e.Key), zap.String("value", e.Value)) return } - if replicaNum <= 0 { - log.Info("invalid cluster level load config, skip it", zap.Int64("replica_num", replicaNum)) - return - } rgs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings() - s.UpdateLoadConfig(s.ctx, &querypb.UpdateLoadConfigRequest{ - CollectionIDs: collectionIDs, - ReplicaNumber: int32(replicaNum), - ResourceGroups: rgs, - }) + s.applyLoadConfigChanges(s.ctx, int32(replicaNum), rgs) }) paramtable.Get().Watch(paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.Key, replicaNumHandler) rgHandler := config.NewHandler("watchResourceGroupChanges", func(e *config.Event) { log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType)) - collectionIDs := s.meta.GetAll(s.ctx) - if len(collectionIDs) == 0 { - log.Warn("no collection loaded, skip to trigger update load config") - return - } - collectionIDs = lo.Filter(collectionIDs, func(collectionID int64, _ int) bool { - collection := s.meta.GetCollection(s.ctx, collectionID) - if collection.UserSpecifiedReplicaMode { - log.Info("collection is user specified replica mode, skip update load config", zap.Int64("collectionID", collectionID)) - return false - } - return true - }) - if len(e.Value) == 0 { log.Warn("invalid cluster level load config, skip it", zap.String("key", e.Key), zap.String("value", e.Value)) return @@ -928,17 +901,8 @@ func (s *Server) watchLoadConfigChanges() { rgs := strings.Split(e.Value, ",") rgs = lo.Map(rgs, func(rg string, _ int) string { return strings.TrimSpace(rg) }) - if len(rgs) == 0 { - log.Info("invalid cluster level load config, skip it", zap.Strings("resource_groups", rgs)) - return - } - replicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsInt64() - s.UpdateLoadConfig(s.ctx, &querypb.UpdateLoadConfigRequest{ - CollectionIDs: collectionIDs, - ReplicaNumber: int32(replicaNum), - ResourceGroups: rgs, - }) + s.applyLoadConfigChanges(s.ctx, int32(replicaNum), rgs) }) paramtable.Get().Watch(paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.Key, rgHandler) } diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index f725a3cbd3..6502458bca 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -437,8 +437,8 @@ func (suite *ServerSuite) TestUpdateAutoBalanceConfigLoop() { }) } -func TestCheckLoadConfigChanges(t *testing.T) { - mockey.PatchConvey("TestCheckLoadConfigChanges", t, func() { +func TestApplyLoadConfigChanges(t *testing.T) { + mockey.PatchConvey("TestApplyLoadConfigChanges", t, func() { ctx := context.Background() // Create mock server @@ -483,23 +483,29 @@ func TestCheckLoadConfigChanges(t *testing.T) { // Mock UpdateLoadConfig to capture the call var updateLoadConfigCalled bool - var capturedRequest *querypb.UpdateLoadConfigRequest - mockey.Mock((*Server).UpdateLoadConfig).To(func(s *Server, ctx context.Context, req *querypb.UpdateLoadConfigRequest) (*commonpb.Status, error) { + var capturedCollectionIDs []int64 + var capturedReplicaNum int32 + var capturedRGs []string + mockey.Mock((*Server).updateLoadConfig).To(func(s *Server, ctx context.Context, collectionIDs []int64, newReplicaNum int32, newRGs []string) error { updateLoadConfigCalled = true - capturedRequest = req - return merr.Success(), nil + capturedCollectionIDs = collectionIDs + capturedReplicaNum = newReplicaNum + capturedRGs = newRGs + return nil }).Build() - // Call checkLoadConfigChanges - testServer.checkLoadConfigChanges(ctx) + replicaNum := paramtable.Get().QueryCoordCfg.ClusterLevelLoadReplicaNumber.GetAsUint32() + rgs := paramtable.Get().QueryCoordCfg.ClusterLevelLoadResourceGroups.GetAsStrings() + // Call applyLoadConfigChanges + testServer.applyLoadConfigChanges(ctx, int32(replicaNum), rgs) // Verify UpdateLoadConfig was called assert.True(t, updateLoadConfigCalled, "UpdateLoadConfig should be called") // Verify that only collections with IsUserSpecifiedReplicaMode = false are included - assert.Equal(t, []int64{1001}, capturedRequest.CollectionIDs, "Only collections with IsUserSpecifiedReplicaMode = false should be included") - assert.Equal(t, int32(2), capturedRequest.ReplicaNumber, "ReplicaNumber should match cluster level config") - assert.Equal(t, []string{"default"}, capturedRequest.ResourceGroups, "ResourceGroups should match cluster level config") + assert.Equal(t, []int64{1001}, capturedCollectionIDs, "Only collections with IsUserSpecifiedReplicaMode = false should be included") + assert.Equal(t, int32(2), capturedReplicaNum, "ReplicaNumber should match cluster level config") + assert.Equal(t, []string{"default"}, capturedRGs, "ResourceGroups should match cluster level config") }) } diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index 4617028591..62db8ad3dd 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -25,7 +25,6 @@ import ( "github.com/samber/lo" "go.uber.org/zap" "golang.org/x/sync/errgroup" - "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -1186,8 +1185,21 @@ func (s *Server) UpdateLoadConfig(ctx context.Context, req *querypb.UpdateLoadCo return merr.Status(errors.Wrap(err, msg)), nil } - jobs := make([]job.Job, 0, len(req.GetCollectionIDs())) - for _, collectionID := range req.GetCollectionIDs() { + err := s.updateLoadConfig(ctx, req.GetCollectionIDs(), req.GetReplicaNumber(), req.GetResourceGroups()) + if err != nil { + msg := "failed to update load config" + log.Warn(msg, zap.Error(err)) + return merr.Status(errors.Wrap(err, msg)), nil + } + + log.Info("update load config request finished") + + return merr.Success(), nil +} + +func (s *Server) updateLoadConfig(ctx context.Context, collectionIDs []int64, newReplicaNum int32, newRGs []string) error { + jobs := make([]job.Job, 0, len(collectionIDs)) + for _, collectionID := range collectionIDs { collection := s.meta.GetCollection(ctx, collectionID) if collection == nil { err := merr.WrapErrCollectionNotLoaded(collectionID) @@ -1196,13 +1208,16 @@ func (s *Server) UpdateLoadConfig(ctx context.Context, req *querypb.UpdateLoadCo } collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(ctx, collection.GetCollectionID()).Collect() - left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups()) + left, right := lo.Difference(collectionUsedRG, newRGs) rgChanged := len(left) > 0 || len(right) > 0 - replicaChanged := collection.GetReplicaNumber() != req.GetReplicaNumber() + replicaChanged := collection.GetReplicaNumber() != newReplicaNum - subReq := proto.Clone(req).(*querypb.UpdateLoadConfigRequest) - subReq.CollectionIDs = []int64{collectionID} - if len(req.ResourceGroups) == 0 { + subReq := &querypb.UpdateLoadConfigRequest{ + CollectionIDs: []int64{collectionID}, + ReplicaNumber: newReplicaNum, + ResourceGroups: newRGs, + } + if len(subReq.GetResourceGroups()) == 0 { subReq.ResourceGroups = collectionUsedRG rgChanged = false } @@ -1239,14 +1254,7 @@ func (s *Server) UpdateLoadConfig(ctx context.Context, req *querypb.UpdateLoadCo } } - if err != nil { - msg := "failed to update load config" - log.Warn(msg, zap.Error(err)) - return merr.Status(errors.Wrap(err, msg)), nil - } - log.Info("update load config request finished") - - return merr.Success(), nil + return err } func (s *Server) ListLoadedSegments(ctx context.Context, req *querypb.ListLoadedSegmentsRequest) (*querypb.ListLoadedSegmentsResponse, error) {