diff --git a/internal/proxy/task_delete_streaming.go b/internal/proxy/task_delete_streaming.go index 8c37ea3f33..85b3ee75f2 100644 --- a/internal/proxy/task_delete_streaming.go +++ b/internal/proxy/task_delete_streaming.go @@ -69,7 +69,7 @@ func (dt *deleteTaskByStreamingService) Execute(ctx context.Context) (err error) zap.Duration("prepare duration", dt.tr.RecordSpan())) resp := streaming.WAL().AppendMessages(ctx, msgs...) - if resp.UnwrapFirstError(); err != nil { + if err := resp.UnwrapFirstError(); err != nil { log.Ctx(ctx).Warn("append messages to wal failed", zap.Error(err)) return err } diff --git a/internal/querycoordv2/balance/channel_level_score_balancer.go b/internal/querycoordv2/balance/channel_level_score_balancer.go index d9b13b9622..8393bf5d81 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer.go @@ -263,6 +263,7 @@ func (b *ChannelLevelScoreBalancer) genChannelPlan(ctx context.Context, replica channelsToMove := make([]*meta.DmChannel, 0) for _, node := range onlineNodes { channels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(node)) + channels = sortIfChannelAtWALLocated(channels) if len(channels) <= average { nodeWithLessChannel = append(nodeWithLessChannel, node) diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 6f172910d3..a2076f4851 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -366,6 +366,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(ctx context.Context, br *balanceR channelsToMove := make([]*meta.DmChannel, 0) for _, node := range rwNodes { channels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(node)) + channels = sortIfChannelAtWALLocated(channels) if len(channels) <= average { nodeWithLessChannel = append(nodeWithLessChannel, node) diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 97b114d76e..d4564cf744 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -661,6 +661,7 @@ func (b *ScoreBasedBalancer) genChannelPlan(ctx context.Context, br *balanceRepo br.AddRecord(StrRecordf("node %d skip balance since current score(%f) lower than assigned one (%f)", node, currentScore, assignedScore)) continue } + channels = sortIfChannelAtWALLocated(channels) for _, ch := range channels { channelScore := b.calculateChannelScore(ch, replica.GetCollectionID()) diff --git a/internal/querycoordv2/balance/utils.go b/internal/querycoordv2/balance/utils.go index aa01032a1f..095ace6512 100644 --- a/internal/querycoordv2/balance/utils.go +++ b/internal/querycoordv2/balance/utils.go @@ -19,12 +19,15 @@ package balance import ( "context" "fmt" + "sort" "time" "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/task" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" ) @@ -204,3 +207,24 @@ func PrintCurrentReplicaDist(replica *meta.Replica, log.Info(distInfo) } + +// sortIfChannelAtWALLocated sorts the channels by the weight 
of the node where the WAL is located. +// Channels whose node hosts the WAL are placed at the tail of the slice. +func sortIfChannelAtWALLocated(channels []*meta.DmChannel) []*meta.DmChannel { + if !streamingutil.IsStreamingServiceEnabled() { + return channels + } + weighter := func(ch *meta.DmChannel) int { + nodeID := snmanager.StaticStreamingNodeManager.GetWALLocated(ch.GetChannelName()) + if ch.Node == nodeID { + // this node hosts the WAL for the channel, + // so assign weight 1 to push it to the tail of the slice. + return 1 + } + return 0 + } + sort.Slice(channels, func(i, j int) bool { + return weighter(channels[i]) < weighter(channels[j]) + }) + return channels +} diff --git a/internal/querycoordv2/meta/replica.go b/internal/querycoordv2/meta/replica.go index 21fbdd2032..fc644034d4 100644 --- a/internal/querycoordv2/meta/replica.go +++ b/internal/querycoordv2/meta/replica.go @@ -107,6 +107,8 @@ func NewReplicaWithPriority(replica *querypb.Replica, priority commonpb.LoadPrio replicaPB: proto.Clone(replica).(*querypb.Replica), rwNodes: typeutil.NewUniqueSet(replica.Nodes...), roNodes: typeutil.NewUniqueSet(replica.RoNodes...), + rwSQNodes: typeutil.NewUniqueSet(replica.RwSqNodes...), + roSQNodes: typeutil.NewUniqueSet(replica.RoSqNodes...), loadPriority: priority, } } @@ -212,7 +214,7 @@ func (replica *Replica) NodesCount() int { // Contains checks if the node is in rw nodes of the replica. func (replica *Replica) Contains(node int64) bool { - return replica.ContainRONode(node) || replica.ContainRWNode(node) || replica.ContainSQNode(node) || replica.ContainRWSQNode(node) + return replica.ContainRONode(node) || replica.ContainRWNode(node) || replica.ContainSQNode(node) } // ContainRONode checks if the node is in ro nodes of the replica.
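A minimal, self-contained sketch of the weighting trick that sortIfChannelAtWALLocated applies above. The dmChannel struct, node IDs, and channel names below are invented stand-ins (the real code resolves the WAL node per channel via snmanager.StaticStreamingNodeManager): channels whose node hosts the WAL get weight 1 and sink to the tail, so a balancer that picks move candidates from the head prefers the other channels.

package main

import (
	"fmt"
	"sort"
)

// dmChannel is a simplified stand-in for meta.DmChannel.
type dmChannel struct {
	Name string
	Node int64
}

// sortChannelsWALLast pushes channels served by the WAL-located node to the tail,
// mirroring the weight-based ordering used by sortIfChannelAtWALLocated.
func sortChannelsWALLast(channels []dmChannel, walNode int64) []dmChannel {
	weight := func(ch dmChannel) int {
		if ch.Node == walNode {
			return 1 // channel already lives on the WAL node: keep it at the tail
		}
		return 0
	}
	sort.Slice(channels, func(i, j int) bool {
		return weight(channels[i]) < weight(channels[j])
	})
	return channels
}

func main() {
	chs := []dmChannel{{"by-dev-ch-1", 7}, {"by-dev-ch-2", 3}, {"by-dev-ch-3", 7}}
	// channels on node 7 (the WAL node in this sketch) end up at the tail
	fmt.Println(sortChannelsWALLast(chs, 7))
}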
diff --git a/internal/querycoordv2/meta/replica_manager.go b/internal/querycoordv2/meta/replica_manager.go index 586609304f..78b61fc35d 100644 --- a/internal/querycoordv2/meta/replica_manager.go +++ b/internal/querycoordv2/meta/replica_manager.go @@ -133,7 +133,10 @@ func (m *ReplicaManager) Recover(ctx context.Context, collections []int64) error log.Info("recover replica", zap.Int64("collectionID", replica.GetCollectionID()), zap.Int64("replicaID", replica.GetID()), - zap.Int64s("nodes", replica.GetNodes()), + zap.Int64s("rwNodes", replica.GetNodes()), + zap.Int64s("roNodes", replica.GetRoNodes()), + zap.Int64s("rwSQNodes", replica.GetRwSqNodes()), + zap.Int64s("roSQNodes", replica.GetRoSqNodes()), ) } else { err := m.catalog.ReleaseReplica(ctx, replica.GetCollectionID(), replica.GetID()) @@ -477,6 +480,10 @@ func (m *ReplicaManager) RecoverNodesInCollection(ctx context.Context, collectio zap.Int64s("newIncomingNodes", incomingNode), zap.Bool("enableChannelExclusiveMode", mutableReplica.IsChannelExclusiveModeEnabled()), zap.Any("channelNodeInfos", mutableReplica.replicaPB.GetChannelNodeInfos()), + zap.Int64s("rwNodes", mutableReplica.GetRWNodes()), + zap.Int64s("roNodes", mutableReplica.GetRONodes()), + zap.Int64s("rwSQNodes", mutableReplica.GetRWSQNodes()), + zap.Int64s("roSQNodes", mutableReplica.GetROSQNodes()), ) modifiedReplicas = append(modifiedReplicas, mutableReplica.IntoReplica()) }) @@ -636,7 +643,12 @@ func (m *ReplicaManager) RecoverSQNodesInCollection(ctx context.Context, collect zap.Int64("replicaID", assignment.GetReplicaID()), zap.Int64s("newRONodes", roNodes), zap.Int64s("roToRWNodes", recoverableNodes), - zap.Int64s("newIncomingNodes", incomingNode)) + zap.Int64s("newIncomingNodes", incomingNode), + zap.Int64s("rwNodes", mutableReplica.GetRWNodes()), + zap.Int64s("roNodes", mutableReplica.GetRONodes()), + zap.Int64s("rwSQNodes", mutableReplica.GetRWSQNodes()), + zap.Int64s("roSQNodes", mutableReplica.GetROSQNodes()), + ) modifiedReplicas = append(modifiedReplicas, mutableReplica.IntoReplica()) }) return m.put(ctx, modifiedReplicas...) diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index d17486da01..0b640a4f22 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -31,10 +31,12 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/coordinator/snmanager" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" @@ -237,6 +239,12 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { log.Warn(msg, zap.Error(err)) return err } + + if err := ex.checkIfShardLeaderIsStreamingNode(view); err != nil { + log.Warn("shard leader is not a streamingnode, skip load segment", zap.Error(err)) + return err + } + log = log.With(zap.Int64("shardLeader", view.Node)) // NOTE: for balance segment task, expected load and release execution on the same shard leader @@ -259,6 +267,26 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { return nil } +// checkIfShardLeaderIsStreamingNode checks if the shard leader is a streamingnode. +// Because the L0 management at 2.6 and 2.5 is different, so when upgrading mixcoord, +// the new mixcoord will make a wrong plan when balancing a segment from one query node to another by 2.5 delegator. +// We need to balance the 2.5 delegator to 2.6 delegator before balancing any segment by 2.6 mixcoord. +func (ex *Executor) checkIfShardLeaderIsStreamingNode(view *meta.DmChannel) error { + if !streamingutil.IsStreamingServiceEnabled() { + return nil + } + + node := ex.nodeMgr.Get(view.Node) + if node == nil { + return merr.WrapErrServiceInternal(fmt.Sprintf("node %d is not found", view.Node)) + } + nodes := snmanager.StaticStreamingNodeManager.GetStreamingQueryNodeIDs() + if !nodes.Contain(view.Node) { + return merr.WrapErrServiceInternal(fmt.Sprintf("channel %s at node %d is not working at streamingnode, skip load segment", view.GetChannelName(), view.Node)) + } + return nil +} + func (ex *Executor) releaseSegment(task *SegmentTask, step int) { defer ex.removeTask(task, step) startTs := time.Now() diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index f2c6119267..29905308e9 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -647,6 +647,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) }) log.Debug("query segments...", + zap.Uint64("mvcc", req.GetReq().GetMvccTimestamp()), zap.Int("sealedNum", sealedNum), zap.Int("growingNum", len(growing)), ) diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index cdcf672c3f..3a327ce6b5 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -74,12 +74,14 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s countRet := result.GetFieldsData()[0].GetScalars().GetLongData().GetData()[0] if allRetrieveCount != countRet { log.Debug("count segment done with delete", + zap.Uint64("mvcc", req.GetReq().GetMvccTimestamp()), zap.String("channel", s.LoadInfo().GetInsertChannel()), zap.Int64("segmentID", s.ID()), zap.Int64("allRetrieveCount", allRetrieveCount), zap.Int64("countRet", countRet)) } else { log.Debug("count segment done", + zap.Uint64("mvcc", req.GetReq().GetMvccTimestamp()), zap.String("channel", s.LoadInfo().GetInsertChannel()), zap.Int64("segmentID", s.ID()), zap.Int64("allRetrieveCount", 
allRetrieveCount), diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 2e846e5665..b9af7114b7 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -521,6 +521,7 @@ func (s *LocalSegment) ResetIndexesLazyLoad(lazyState bool) { func (s *LocalSegment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) { log := log.Ctx(ctx).WithLazy( + zap.Uint64("mvcc", searchReq.MVCC()), zap.Int64("collectionID", s.Collection()), zap.Int64("segmentID", s.ID()), zap.String("segmentType", s.segmentType.String()), @@ -572,7 +573,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) zap.Int64("collectionID", s.Collection()), zap.Int64("partitionID", s.Partition()), zap.Int64("segmentID", s.ID()), - zap.Int64("msgID", plan.MsgID()), + zap.Uint64("mvcc", plan.Timestamp), zap.String("segmentType", s.segmentType.String()), ) diff --git a/internal/util/segcore/plan.go b/internal/util/segcore/plan.go index 6bfdfdc9be..c71aaa589c 100644 --- a/internal/util/segcore/plan.go +++ b/internal/util/segcore/plan.go @@ -135,6 +135,10 @@ func (req *SearchRequest) GetNumOfQuery() int64 { return int64(numQueries) } +func (req *SearchRequest) MVCC() typeutil.Timestamp { + return req.mvccTimestamp +} + func (req *SearchRequest) Plan() *SearchPlan { return req.plan } diff --git a/tests/integration/balance/balance_test.go b/tests/integration/balance/balance_test.go index f40ab16479..23a314872b 100644 --- a/tests/integration/balance/balance_test.go +++ b/tests/integration/balance/balance_test.go @@ -185,9 +185,7 @@ func (s *BalanceTestSuit) TestBalanceOnSingleReplica() { s.NoError(err) s.True(merr.Ok(resp.GetStatus())) log.Info("balance on single replica", zap.Int("channel", len(resp2.Channels)), zap.Int("segments", len(resp.Segments))) - // TODO: https://github.com/milvus-io/milvus/issues/42966 - // return len(resp2.Channels) == 1 && len(resp.Segments) == 2 - return len(resp.Segments) == 2 + return len(resp2.Channels) == 1 && len(resp.Segments) == 2 }, 30*time.Second, 1*time.Second) // check total segment number and total channel number @@ -372,7 +370,13 @@ func (s *BalanceTestSuit) TestConcurrentBalanceChannelAndSegment() { resp, err := qn.MustGetClient(ctx).GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{}) s.NoError(err) s.True(merr.Ok(resp.GetStatus())) - log.Info("segments on query node", zap.Int64("nodeID", qn.GetNodeID()), zap.Int("channel", len(resp.Channels)), zap.Int("segments", len(resp.Segments))) + log.Info("segments on query node before balance", zap.Int64("nodeID", qn.GetNodeID()), zap.Int("channel", len(resp.Channels)), zap.Int("segments", len(resp.Segments))) + } + for _, sn := range s.Cluster.GetAllStreamingNodes() { + resp, err := sn.MustGetClient(ctx).GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{}) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + log.Info("channel on streaming node before balance", zap.Int64("nodeID", sn.GetNodeID()), zap.Int("channel", len(resp.Channels)), zap.Int("segments", len(resp.Segments))) } // then we add 1 query node, expected segment and channel will be move to new query node concurrently @@ -387,6 +391,12 @@ func (s *BalanceTestSuit) TestConcurrentBalanceChannelAndSegment() { s.True(merr.Ok(resp.GetStatus())) log.Info("segments on query node", zap.Int64("nodeID", qn.GetNodeID()), zap.Int("channel", len(resp.Channels)), zap.Int("segments", len(resp.Segments))) } + for _, sn := 
range s.Cluster.GetAllStreamingNodes() { + resp, err := sn.MustGetClient(ctx).GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{}) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + log.Info("channel on streaming node", zap.Int64("nodeID", sn.GetNodeID()), zap.Int("channel", len(resp.Channels)), zap.Int("segments", len(resp.Segments))) + } resp, err := qn1.MustGetClient(ctx).GetDataDistribution(ctx, &querypb.GetDataDistributionRequest{}) s.NoError(err) @@ -394,9 +404,7 @@ func (s *BalanceTestSuit) TestConcurrentBalanceChannelAndSegment() { s.NoError(err) s.True(merr.Ok(resp.GetStatus())) log.Info("concurrent balance channel and segment", zap.Int("channel1", len(resp2.Channels)), zap.Int("segments1", len(resp.Segments))) - // TODO: https://github.com/milvus-io/milvus/issues/42966 - // return len(resp2.Channels) == 2 && len(resp.Segments) >= 20 - return len(resp.Segments) >= 20 + return len(resp2.Channels) == 2 && len(resp.Segments) >= 20 }, 30*time.Second, 1*time.Second) // expected concurrent balance will execute successfully, shard serviceable won't be broken diff --git a/tests/integration/coorddownsearch/search_after_coord_down_test.go b/tests/integration/coorddownsearch/search_after_coord_down_test.go index 87e9f99121..b138852ced 100644 --- a/tests/integration/coorddownsearch/search_after_coord_down_test.go +++ b/tests/integration/coorddownsearch/search_after_coord_down_test.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/metric" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/tests/integration" ) @@ -52,6 +53,11 @@ const ( var searchCollectionName = "" +func (s *CoordDownSearch) SetupSuite() { + s.WithMilvusConfig(paramtable.Get().LogCfg.Level.Key, "debug") + s.MiniClusterSuite.SetupSuite() +} + func (s *CoordDownSearch) loadCollection(collectionName string, dim int) { c := s.Cluster dbName := ""
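The CoordDownSearch change above follows a common testify pattern: the embedding suite overrides SetupSuite to register extra configuration (here, a debug log level) and then delegates to the embedded suite's SetupSuite. Below is a minimal standalone sketch of that pattern; baseSuite, its config map, and the "log.level" key are hypothetical stand-ins, not the Milvus integration.MiniClusterSuite API.

package main

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/suite"
)

// baseSuite plays the role of the shared cluster suite.
type baseSuite struct {
	suite.Suite
	config map[string]string
}

// WithConfig records a config override to apply at cluster startup.
func (s *baseSuite) WithConfig(key, value string) {
	if s.config == nil {
		s.config = map[string]string{}
	}
	s.config[key] = value
}

// SetupSuite starts the (pretend) cluster with whatever overrides were recorded.
func (s *baseSuite) SetupSuite() {
	fmt.Println("starting cluster with config overrides:", s.config)
}

type coordDownSuite struct {
	baseSuite
}

// SetupSuite overrides the base hook: register the extra config first, then
// delegate to the embedded suite, mirroring the CoordDownSearch change above.
func (s *coordDownSuite) SetupSuite() {
	s.WithConfig("log.level", "debug")
	s.baseSuite.SetupSuite()
}

func TestCoordDownSuite(t *testing.T) {
	suite.Run(t, new(coordDownSuite))
}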