diff --git a/configs/milvus.yaml b/configs/milvus.yaml
index 62c2cda1f0..2d72607e9d 100644
--- a/configs/milvus.yaml
+++ b/configs/milvus.yaml
@@ -74,6 +74,10 @@ queryCoord:
   address: localhost
   port: 19531
   autoHandoff: true
+  autoBalance: false
+  overloadedMemoryThresholdPercentage: 90
+  balanceIntervalSeconds: 60
+  memoryUsageMaxDifferencePercentage: 30

   grpc:
     serverMaxRecvSize: 2147483647 # math.MaxInt32
diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go
index 00f7d374ba..15705e9055 100644
--- a/internal/querycoord/cluster.go
+++ b/internal/querycoord/cluster.go
@@ -25,6 +25,7 @@ import (
     etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
     "github.com/milvus-io/milvus/internal/log"
+    "github.com/milvus-io/milvus/internal/proto/commonpb"
     "github.com/milvus-io/milvus/internal/proto/internalpb"
     "github.com/milvus-io/milvus/internal/proto/milvuspb"
     "github.com/milvus-io/milvus/internal/proto/querypb"
@@ -59,6 +60,7 @@ type Cluster interface {
     releasePartitions(ctx context.Context, nodeID int64, in *querypb.ReleasePartitionsRequest) error
     getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
     getSegmentInfoByNode(ctx context.Context, nodeID int64, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
+    getSegmentInfoByID(ctx context.Context, segmentID UniqueID) (*querypb.SegmentInfo, error)

     registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error
     getNodeInfoByID(nodeID int64) (Node, error)
@@ -69,7 +71,7 @@ type Cluster interface {
     offlineNodes() (map[int64]Node, error)
     hasNode(nodeID int64) bool

-    allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error
+    allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error
     allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error

     getSessionVersion() int64
@@ -381,6 +383,37 @@ func (c *queryNodeCluster) releasePartitions(ctx context.Context, nodeID int64,
     return fmt.Errorf("ReleasePartitions: can't find query node by nodeID, nodeID = %d", nodeID)
 }

+func (c *queryNodeCluster) getSegmentInfoByID(ctx context.Context, segmentID UniqueID) (*querypb.SegmentInfo, error) {
+    c.RLock()
+    defer c.RUnlock()
+
+    segmentInfo, err := c.clusterMeta.getSegmentInfoByID(segmentID)
+    if err != nil {
+        return nil, err
+    }
+    if node, ok := c.nodes[segmentInfo.NodeID]; ok {
+        res, err := node.getSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
+            Base: &commonpb.MsgBase{
+                MsgType: commonpb.MsgType_SegmentInfo,
+            },
+            CollectionID: segmentInfo.CollectionID,
+        })
+        if err != nil {
+            return nil, err
+        }
+        if res != nil {
+            for _, info := range res.Infos {
+                if info.SegmentID == segmentID {
+                    return info, nil
+                }
+            }
+        }
+        return nil, fmt.Errorf("getSegmentInfoByID: can't find segment %d on query node %d", segmentID, segmentInfo.NodeID)
+    }
+
+    return nil, fmt.Errorf("getSegmentInfoByID: can't find query node by nodeID, nodeID = %d", segmentInfo.NodeID)
+}
+
 func (c *queryNodeCluster) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) {
     c.RLock()
     defer c.RUnlock()
@@ -650,8 +683,8 @@ func (c *queryNodeCluster) getCollectionInfosByID(ctx context.Context, nodeID in
     return nil
 }

-func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error {
-    return c.segmentAllocator(ctx, reqs, c, wait, excludeNodeIDs)
+func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error {
+    return c.segmentAllocator(ctx, reqs, c, wait, excludeNodeIDs, includeNodeIDs)
 }

 func (c *queryNodeCluster) allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error {
diff --git a/internal/querycoord/mock_querynode_server_test.go b/internal/querycoord/mock_querynode_server_test.go
index 7fc1043f93..2e9734e2d3 100644
--- a/internal/querycoord/mock_querynode_server_test.go
+++ b/internal/querycoord/mock_querynode_server_test.go
@@ -32,9 +32,11 @@ import (
 )

 const (
-    defaultTotalmemPerNode = 6000000000
+    defaultTotalmemPerNode = 6000000
 )

+var GlobalSegmentInfos = make(map[UniqueID]*querypb.SegmentInfo)
+
 type queryNodeServerMock struct {
     querypb.QueryNodeServer
     ctx context.Context
@@ -58,9 +60,9 @@ type queryNodeServerMock struct {
     getSegmentInfos func() (*querypb.GetSegmentInfoResponse, error)
     getMetrics      func() (*milvuspb.GetMetricsResponse, error)

-    totalMem     uint64
-    memUsage     uint64
-    memUsageRate float64
+    segmentInfos map[UniqueID]*querypb.SegmentInfo
+
+    totalMem uint64
 }

 func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
@@ -81,9 +83,9 @@ func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
         getSegmentInfos: returnSuccessGetSegmentInfoResult,
         getMetrics:      returnSuccessGetMetricsResult,

-        totalMem:     defaultTotalmemPerNode,
-        memUsage:     uint64(0),
-        memUsageRate: float64(0),
+        segmentInfos: GlobalSegmentInfos,
+
+        totalMem: defaultTotalmemPerNode,
     }
 }

@@ -194,12 +196,19 @@ func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.Lo
     if err != nil {
         return returnFailedResult()
     }
-    totalNumRow := int64(0)
     for _, info := range req.Infos {
-        totalNumRow += info.NumOfRows
+        segmentInfo := &querypb.SegmentInfo{
+            SegmentID:    info.SegmentID,
+            PartitionID:  info.PartitionID,
+            CollectionID: info.CollectionID,
+            NodeID:       qs.queryNodeID,
+            SegmentState: querypb.SegmentState_sealed,
+            MemSize:      info.NumOfRows * int64(sizePerRecord),
+            NumRows:      info.NumOfRows,
+        }
+        qs.segmentInfos[info.SegmentID] = segmentInfo
     }
-    qs.memUsage += uint64(totalNumRow) * uint64(sizePerRecord)
-    qs.memUsageRate = float64(qs.memUsage) / float64(qs.totalMem)
+
     return qs.loadSegment()
 }

@@ -215,8 +224,19 @@ func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb
     return qs.releaseSegments()
 }

-func (qs *queryNodeServerMock) GetSegmentInfo(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
-    return qs.getSegmentInfos()
+func (qs *queryNodeServerMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
+    segmentInfos := make([]*querypb.SegmentInfo, 0)
+    for _, info := range qs.segmentInfos {
+        if info.CollectionID == req.CollectionID && info.NodeID == qs.queryNodeID {
+            segmentInfos = append(segmentInfos, info)
+        }
+    }
+
+    res, err := qs.getSegmentInfos()
+    if err == nil {
+        res.Infos = segmentInfos
+    }
+    return res, err
 }

 func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
@@ -227,13 +247,20 @@ func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.Get
     if response.Status.ErrorCode != commonpb.ErrorCode_Success {
         return nil, errors.New("query node do task failed")
     }
+
+    totalMemUsage := uint64(0)
+    for _, info := range qs.segmentInfos {
+        if info.NodeID == qs.queryNodeID {
+            totalMemUsage += uint64(info.MemSize)
+        }
+    }
     nodeInfos := metricsinfo.QueryNodeInfos{
         BaseComponentInfos: metricsinfo.BaseComponentInfos{
             Name: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, qs.queryNodeID),
             HardwareInfos: metricsinfo.HardwareMetrics{
                 IP:          qs.queryNodeIP,
                 Memory:      qs.totalMem,
-                MemoryUsage: qs.memUsage,
+                MemoryUsage: totalMemUsage,
             },
             Type: typeutil.QueryNodeRole,
             ID:   qs.queryNodeID,
diff --git a/internal/querycoord/param_table.go b/internal/querycoord/param_table.go
index 4353e37419..6911cfa54b 100644
--- a/internal/querycoord/param_table.go
+++ b/internal/querycoord/param_table.go
@@ -70,6 +70,12 @@ type ParamTable struct {

     //---- Handoff ---
     AutoHandoff bool
+
+    //---- Balance ---
+    AutoBalance                         bool
+    OverloadedMemoryThresholdPercentage float64
+    BalanceIntervalSeconds              int64
+    MemoryUsageMaxDifferencePercentage  float64
 }

 // Params are variables of the ParamTable type
@@ -117,6 +123,12 @@ func (p *ParamTable) Init() {

     p.initDmlChannelName()
     p.initDeltaChannelName()
+
+    //---- Balance ---
+    p.initAutoBalance()
+    p.initOverloadedMemoryThresholdPercentage()
+    p.initBalanceIntervalSeconds()
+    p.initMemoryUsageMaxDifferencePercentage()
 }

 func (p *ParamTable) initQueryCoordAddress() {
@@ -271,6 +283,42 @@ func (p *ParamTable) initAutoHandoff() {
     }
 }

+func (p *ParamTable) initAutoBalance() {
+    balanceStr := p.LoadWithDefault("queryCoord.autoBalance", "false")
+    autoBalance, err := strconv.ParseBool(balanceStr)
+    if err != nil {
+        panic(err)
+    }
+    p.AutoBalance = autoBalance
+}
+
+func (p *ParamTable) initOverloadedMemoryThresholdPercentage() {
+    overloadedMemoryThresholdPercentage := p.LoadWithDefault("queryCoord.overloadedMemoryThresholdPercentage", "90")
+    thresholdPercentage, err := strconv.ParseInt(overloadedMemoryThresholdPercentage, 10, 64)
+    if err != nil {
+        panic(err)
+    }
+    p.OverloadedMemoryThresholdPercentage = float64(thresholdPercentage) / 100
+}
+
+func (p *ParamTable) initBalanceIntervalSeconds() {
+    balanceInterval := p.LoadWithDefault("queryCoord.balanceIntervalSeconds", "60")
+    interval, err := strconv.ParseInt(balanceInterval, 10, 64)
+    if err != nil {
+        panic(err)
+    }
+    p.BalanceIntervalSeconds = interval
+}
+
+func (p *ParamTable) initMemoryUsageMaxDifferencePercentage() {
+    maxDiff := p.LoadWithDefault("queryCoord.memoryUsageMaxDifferencePercentage", "30")
+    diffPercentage, err := strconv.ParseInt(maxDiff, 10, 64)
+    if err != nil {
+        panic(err)
+    }
+    p.MemoryUsageMaxDifferencePercentage = float64(diffPercentage) / 100
+}
+
 func (p *ParamTable) initDmlChannelName() {
     config, err := p.Load("msgChannel.chanNamePrefix.rootCoordDml")
     if err != nil {
diff --git a/internal/querycoord/query_coord.go b/internal/querycoord/query_coord.go
index 9a5ed25187..19d59fdc22 100644
--- a/internal/querycoord/query_coord.go
+++ b/internal/querycoord/query_coord.go
@@ -14,6 +14,9 @@ package querycoord
 import (
     "context"
     "errors"
+    "fmt"
+    "math"
     "math/rand"
+    "sort"
     "strconv"
@@ -184,6 +187,9 @@ func (qc *QueryCoord) Start() error {
     qc.loopWg.Add(1)
     go qc.watchHandoffSegmentLoop()

+    qc.loopWg.Add(1)
+    go qc.loadBalanceSegmentLoop()
+
     go qc.session.LivenessCheck(qc.loopCtx, func() {
         log.Error("Query Coord disconnected from etcd, process will exit", zap.Int64("Server Id", qc.session.ServerID))
         if err := qc.Stop(); err != nil {
@@ -563,3 +569,179 @@ func (qc *QueryCoord) processHandoffAfterIndexDone(ctx context.Context, indexedC
         }
     }
 }
+
+func (qc *QueryCoord) loadBalanceSegmentLoop() {
+    ctx, cancel := context.WithCancel(qc.loopCtx)
+    defer cancel()
+    defer qc.loopWg.Done()
+    log.Debug("query coordinator start load balance segment loop")
+
+    timer := time.NewTicker(time.Duration(Params.BalanceIntervalSeconds) * time.Second)
+
+    for {
+        select {
+        case <-ctx.Done():
+            return
+        case <-timer.C:
+            onlineNodes, err := qc.cluster.onlineNodes()
+            if err != nil {
+                log.Warn("loadBalanceSegmentLoop: there are no online query nodes to balance")
+                continue
+            }
+            // get mem info of online nodes from the cluster
+            nodeID2MemUsageRate := make(map[int64]float64)
+            nodeID2MemUsage := make(map[int64]uint64)
+            nodeID2TotalMem := make(map[int64]uint64)
+            nodeID2SegmentInfos := make(map[int64]map[UniqueID]*querypb.SegmentInfo)
+            onlineNodeIDs := make([]int64, 0)
+            for nodeID := range onlineNodes {
+                nodeInfo, err := qc.cluster.getNodeInfoByID(nodeID)
+                if err != nil {
+                    log.Warn("loadBalanceSegmentLoop: get node info from query node failed", zap.Int64("nodeID", nodeID), zap.Error(err))
+                    delete(onlineNodes, nodeID)
+                    continue
+                }
+
+                updateSegmentInfoDone := true
+                latestSegmentInfos := make(map[UniqueID]*querypb.SegmentInfo)
+                segmentInfos := qc.meta.getSegmentInfosByNode(nodeID)
+                for _, segmentInfo := range segmentInfos {
+                    latestInfo, err := qc.cluster.getSegmentInfoByID(ctx, segmentInfo.SegmentID)
+                    if err != nil {
+                        log.Warn("loadBalanceSegmentLoop: get segment info from query node failed", zap.Int64("nodeID", nodeID), zap.Error(err))
+                        delete(onlineNodes, nodeID)
+                        updateSegmentInfoDone = false
+                        break
+                    }
+                    latestSegmentInfos[segmentInfo.SegmentID] = latestInfo
+                }
+                if updateSegmentInfoDone {
+                    nodeID2MemUsageRate[nodeID] = nodeInfo.(*queryNode).memUsageRate
+                    nodeID2MemUsage[nodeID] = nodeInfo.(*queryNode).memUsage
+                    nodeID2TotalMem[nodeID] = nodeInfo.(*queryNode).totalMem
+                    onlineNodeIDs = append(onlineNodeIDs, nodeID)
+                    nodeID2SegmentInfos[nodeID] = latestSegmentInfos
+                }
+            }
+            log.Debug("loadBalanceSegmentLoop: memory usage rate of all online query nodes", zap.Any("mem rate", nodeID2MemUsageRate))
+            if len(onlineNodeIDs) <= 1 {
+                log.Warn("loadBalanceSegmentLoop: there are too few online query nodes to balance", zap.Int64s("onlineNodeIDs", onlineNodeIDs))
+                continue
+            }
+
+            // check which nodes need balance and determine which segments on these nodes need to be migrated to other nodes
+            memoryInsufficient := false
+            loadBalanceTasks := make([]*loadBalanceTask, 0)
+            for {
+                var selectedSegmentInfo *querypb.SegmentInfo
+                sort.Slice(onlineNodeIDs, func(i, j int) bool {
+                    return nodeID2MemUsageRate[onlineNodeIDs[i]] > nodeID2MemUsageRate[onlineNodeIDs[j]]
+                })
+
+                // the sourceNode has the highest memoryUsageRate of all online query nodes
+                sourceNodeID := onlineNodeIDs[0]
+                dstNodeID := onlineNodeIDs[len(onlineNodeIDs)-1]
+                memUsageRateDiff := nodeID2MemUsageRate[sourceNodeID] - nodeID2MemUsageRate[dstNodeID]
+                // if the memoryUsageRate of the source node is greater than 90%, or the max memUsageRateDiff is greater than 30%,
+                // then migrate segments from the source node to other query nodes
+                if nodeID2MemUsageRate[sourceNodeID] > Params.OverloadedMemoryThresholdPercentage ||
+                    memUsageRateDiff > Params.MemoryUsageMaxDifferencePercentage {
+                    segmentInfos := nodeID2SegmentInfos[sourceNodeID]
+                    // select the segment that needs balance on the source node
+                    selectedSegmentInfo, err = chooseSegmentToBalance(sourceNodeID, dstNodeID, segmentInfos, nodeID2MemUsage, nodeID2TotalMem, nodeID2MemUsageRate)
+                    if err == nil && selectedSegmentInfo != nil {
+                        req := &querypb.LoadBalanceRequest{
+                            Base: &commonpb.MsgBase{
+                                MsgType: commonpb.MsgType_LoadBalanceSegments,
+                            },
+                            BalanceReason:    querypb.TriggerCondition_loadBalance,
+                            SourceNodeIDs:    []UniqueID{sourceNodeID},
+                            DstNodeIDs:       []UniqueID{dstNodeID},
+                            SealedSegmentIDs: []UniqueID{selectedSegmentInfo.SegmentID},
+                        }
+                        baseTask := newBaseTask(qc.loopCtx, querypb.TriggerCondition_loadBalance)
+                        balanceTask := &loadBalanceTask{
+                            baseTask:           baseTask,
+                            LoadBalanceRequest: req,
+                            rootCoord:          qc.rootCoordClient,
+                            dataCoord:          qc.dataCoordClient,
+                            cluster:            qc.cluster,
+                            meta:               qc.meta,
+                        }
+                        loadBalanceTasks = append(loadBalanceTasks, balanceTask)
+                        nodeID2MemUsage[sourceNodeID] -= uint64(selectedSegmentInfo.MemSize)
+                        nodeID2MemUsage[dstNodeID] += uint64(selectedSegmentInfo.MemSize)
+                        nodeID2MemUsageRate[sourceNodeID] = float64(nodeID2MemUsage[sourceNodeID]) / float64(nodeID2TotalMem[sourceNodeID])
+                        nodeID2MemUsageRate[dstNodeID] = float64(nodeID2MemUsage[dstNodeID]) / float64(nodeID2TotalMem[dstNodeID])
+                        delete(nodeID2SegmentInfos[sourceNodeID], selectedSegmentInfo.SegmentID)
+                        nodeID2SegmentInfos[dstNodeID][selectedSegmentInfo.SegmentID] = selectedSegmentInfo
+                        continue
+                    }
+                }
+                if err != nil {
+                    // not enough memory on the query nodes to balance, so notify the proxy to stop inserting
+                    memoryInsufficient = true
+                }
+                // if memoryInsufficient == false:
+                // every query node's memoryUsageRate is less than 90%, and the max memUsageRateDiff is less than 30%,
+                // so this balance loop is done
+                break
+            }
+            if !memoryInsufficient {
+                for _, t := range loadBalanceTasks {
+                    qc.scheduler.Enqueue(t)
+                    log.Debug("loadBalanceSegmentLoop: enqueue a loadBalance task", zap.Any("task", t))
+                    err = t.waitToFinish()
+                    if err != nil {
+                        // if failed, wait for the next balance loop;
+                        // it may be that the collection/partition of the balanced segment has been released,
+                        // or some other abnormal error occurred
+                        log.Error("loadBalanceSegmentLoop: balance task execute failed", zap.Any("task", t), zap.Error(err))
+                    } else {
+                        log.Debug("loadBalanceSegmentLoop: balance task execute success", zap.Any("task", t))
+                    }
+                }
+                log.Debug("loadBalanceSegmentLoop: load balance done in this loop", zap.Any("tasks", loadBalanceTasks))
+            } else {
+                // not enough memory on the query nodes to balance, so notify the proxy to stop inserting
+                //TODO:: xige-16
+                log.Error("loadBalanceSegmentLoop: query node has insufficient memory, stop inserting data")
+            }
+        }
+    }
+}
+
+func chooseSegmentToBalance(sourceNodeID int64, dstNodeID int64,
+    segmentInfos map[UniqueID]*querypb.SegmentInfo,
+    nodeID2MemUsage map[int64]uint64,
+    nodeID2TotalMem map[int64]uint64,
+    nodeID2MemUsageRate map[int64]float64) (*querypb.SegmentInfo, error) {
+    memoryInsufficient := true
+    minMemDiffPercentage := 1.0
+    var selectedSegmentInfo *querypb.SegmentInfo
+    for _, info := range segmentInfos {
+        dstNodeMemUsageAfterBalance := nodeID2MemUsage[dstNodeID] + uint64(info.MemSize)
+        dstNodeMemUsageRateAfterBalance := float64(dstNodeMemUsageAfterBalance) / float64(nodeID2TotalMem[dstNodeID])
+        // if the memUsageRate of dstNode would exceed OverloadedMemoryThresholdPercentage after balance, then this segment can't be moved
+        if dstNodeMemUsageRateAfterBalance < Params.OverloadedMemoryThresholdPercentage {
+            memoryInsufficient = false
+            sourceNodeMemUsageAfterBalance := nodeID2MemUsage[sourceNodeID] - uint64(info.MemSize)
+            sourceNodeMemUsageRateAfterBalance := float64(sourceNodeMemUsageAfterBalance) / float64(nodeID2TotalMem[sourceNodeID])
+            // assume all query nodes have the same memory capacity
+            // if the memUsageRateDiff between the two nodes does not become smaller after balance, there is no need to balance
+            diffBeforeBalance := nodeID2MemUsageRate[sourceNodeID] - nodeID2MemUsageRate[dstNodeID]
+            diffAfterBalance := dstNodeMemUsageRateAfterBalance - sourceNodeMemUsageRateAfterBalance
+            if diffAfterBalance < diffBeforeBalance {
+                if math.Abs(diffAfterBalance) < minMemDiffPercentage {
+                    // track the best candidate: the segment that minimizes the diff after balance
+                    minMemDiffPercentage = math.Abs(diffAfterBalance)
+                    selectedSegmentInfo = info
+                }
+            }
+        }
+    }
+
+    if memoryInsufficient {
+        return nil, errors.New("all query nodes have insufficient memory")
+    }
+
+    return selectedSegmentInfo, nil
+}
diff --git a/internal/querycoord/query_coord_test.go b/internal/querycoord/query_coord_test.go
index 8bf8c06f28..855ec8389f 100644
--- a/internal/querycoord/query_coord_test.go
+++ b/internal/querycoord/query_coord_test.go
@@ -43,6 +43,7 @@ func refreshParams() {
     Params.MetaRootPath = Params.MetaRootPath + suffix
     Params.DmlChannelPrefix = "Dml"
     Params.DeltaChannelPrefix = "delta"
+    GlobalSegmentInfos = make(map[UniqueID]*querypb.SegmentInfo)
 }

 func TestMain(m *testing.M) {
@@ -490,3 +491,73 @@ func TestHandoffSegmentLoop(t *testing.T) {
     err = removeAllSession()
     assert.Nil(t, err)
 }
+
+func TestLoadBalanceSegmentLoop(t *testing.T) {
+    refreshParams()
+    Params.BalanceIntervalSeconds = 10
+    baseCtx := context.Background()
+
+    queryCoord, err := startQueryCoord(baseCtx)
+    assert.Nil(t, err)
+    queryCoord.cluster.(*queryNodeCluster).segmentAllocator = shuffleSegmentsToQueryNode
+
+    queryNode1, err := startQueryNodeServer(baseCtx)
+    assert.Nil(t, err)
+    waitQueryNodeOnline(queryCoord.cluster, queryNode1.queryNodeID)
+
+    loadCollectionTask := genLoadCollectionTask(baseCtx, queryCoord)
+    err = queryCoord.scheduler.Enqueue(loadCollectionTask)
+    assert.Nil(t, err)
+    waitTaskFinalState(loadCollectionTask, taskExpired)
+
+    partitionID := defaultPartitionID
+    for {
+        req := &querypb.LoadPartitionsRequest{
+            Base: &commonpb.MsgBase{
+                MsgType: commonpb.MsgType_LoadPartitions,
+            },
+            CollectionID: defaultCollectionID,
+            PartitionIDs: []UniqueID{partitionID},
+            Schema:       genCollectionSchema(defaultCollectionID, false),
+        }
+        baseTask := newBaseTask(baseCtx, querypb.TriggerCondition_grpcRequest)
+        loadPartitionTask := &loadPartitionTask{
+            baseTask:              baseTask,
+            LoadPartitionsRequest: req,
+            dataCoord:             queryCoord.dataCoordClient,
+            cluster:               queryCoord.cluster,
+            meta:                  queryCoord.meta,
+        }
+        err = queryCoord.scheduler.Enqueue(loadPartitionTask)
+        assert.Nil(t, err)
+        waitTaskFinalState(loadPartitionTask, taskExpired)
+        nodeInfo, err := queryCoord.cluster.getNodeInfoByID(queryNode1.queryNodeID)
+        assert.Nil(t, err)
+        if nodeInfo.(*queryNode).memUsageRate >= Params.OverloadedMemoryThresholdPercentage {
+            break
+        }
+        partitionID++
+    }
+
+    queryNode2, err := startQueryNodeServer(baseCtx)
+    assert.Nil(t, err)
+    waitQueryNodeOnline(queryCoord.cluster, queryNode2.queryNodeID)
+
+    // if a sealed segment has been balanced to query node2, then balance works
+    for {
+        segmentInfos, err := queryCoord.cluster.getSegmentInfoByNode(baseCtx, queryNode2.queryNodeID, &querypb.GetSegmentInfoRequest{
+            Base: &commonpb.MsgBase{
+                MsgType: commonpb.MsgType_LoadBalanceSegments,
+            },
+            CollectionID: defaultCollectionID,
+        })
+        assert.Nil(t, err)
+        if len(segmentInfos) > 0 {
+            break
+        }
+    }
+
+    queryCoord.Stop()
+    err = removeAllSession()
+    assert.Nil(t, err)
+}
diff --git a/internal/querycoord/segment_allocator.go b/internal/querycoord/segment_allocator.go
index 5c6a626510..dd4af2a6bf 100644
--- a/internal/querycoord/segment_allocator.go
+++ b/internal/querycoord/segment_allocator.go
@@ -24,19 +24,17 @@ import (
     "github.com/milvus-io/milvus/internal/util/typeutil"
 )

-const MaxMemUsagePerNode = 0.9
-
 func defaultSegAllocatePolicy() SegmentAllocatePolicy {
     return shuffleSegmentsToQueryNodeV2
 }

 // SegmentAllocatePolicy helper function definition to allocate Segment to queryNode
-type SegmentAllocatePolicy func(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error
+type SegmentAllocatePolicy func(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error

 // shuffleSegmentsToQueryNode shuffle segments to online nodes
 // returned are node ids for each segment, which satisfies:
 // len(returnedNodeIds) == len(segmentIDs) && segmentIDs[i] is assigned to returnedNodeIds[i]
-func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
+func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error {
     if len(reqs) == 0 {
         return nil
     }
@@ -57,6 +55,10 @@ func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegment

     nodeID2NumSegemnt := make(map[int64]int)
     for nodeID := range availableNodes {
+        if len(includeNodeIDs) > 0 && !nodeIncluded(nodeID, includeNodeIDs) {
+            delete(availableNodes, nodeID)
+            continue
+        }
         numSegments, err := cluster.getNumSegments(nodeID)
         if err != nil {
             delete(availableNodes, nodeID)
@@ -87,7 +89,7 @@ func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegment
     }
 }

-func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
+func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) error {
     // key = offset, value = segmentSize
     if len(reqs) == 0 {
         return nil
@@ -118,6 +120,10 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
         delete(availableNodes, id)
     }
     for nodeID := range availableNodes {
+        if len(includeNodeIDs) > 0 && !nodeIncluded(nodeID, includeNodeIDs) {
+            delete(availableNodes, nodeID)
+            continue
+        }
         // statistic nodeInfo, used memory, memory usage of every query node
         nodeInfo, err := cluster.getNodeInfoByID(nodeID)
         if err != nil {
@@ -127,7 +133,7 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
         }
         queryNodeInfo := nodeInfo.(*queryNode)
         // avoid allocate segment to node which memUsageRate is high
-        if queryNodeInfo.memUsageRate >= MaxMemUsagePerNode {
-            log.Debug("shuffleSegmentsToQueryNodeV2: queryNode memUsageRate large than MaxMemUsagePerNode", zap.Int64("nodeID", nodeID), zap.Float64("current rate", queryNodeInfo.memUsageRate))
+        if queryNodeInfo.memUsageRate >= Params.OverloadedMemoryThresholdPercentage {
+            log.Debug("shuffleSegmentsToQueryNodeV2: queryNode memUsageRate larger than OverloadedMemoryThresholdPercentage", zap.Int64("nodeID", nodeID), zap.Float64("current rate", queryNodeInfo.memUsageRate))
             delete(availableNodes, nodeID)
             continue
@@ -152,7 +158,7 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
     for _, nodeID := range nodeIDSlice {
         memUsageAfterLoad := memUsage[nodeID] + uint64(sizeOfReq)
         memUsageRateAfterLoad := float64(memUsageAfterLoad) / float64(totalMem[nodeID])
-        if memUsageRateAfterLoad > MaxMemUsagePerNode {
+        if memUsageRateAfterLoad > Params.OverloadedMemoryThresholdPercentage {
             continue
         }
         reqs[offset].DstNodeID = nodeID
@@ -181,3 +187,13 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
         }
     }
 }
+
+func nodeIncluded(nodeID int64, includeNodeIDs []int64) bool {
+    for _, id := range includeNodeIDs {
+        if id == nodeID {
+            return true
+        }
+    }
+
+    return false
+}
diff --git a/internal/querycoord/segment_allocator_test.go b/internal/querycoord/segment_allocator_test.go
index 65377a42d4..c40f42ebd9 100644
--- a/internal/querycoord/segment_allocator_test.go
+++ b/internal/querycoord/segment_allocator_test.go
@@ -70,7 +70,7 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) {
     reqs := []*querypb.LoadSegmentsRequest{firstReq, secondReq}

     t.Run("Test shuffleSegmentsWithoutQueryNode", func(t *testing.T) {
-        err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
+        err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil, nil)
         assert.NotNil(t, err)
     })

@@ -82,7 +82,7 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) {
     waitQueryNodeOnline(cluster, node1ID)

     t.Run("Test shuffleSegmentsToQueryNode", func(t *testing.T) {
-        err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
+        err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil, nil)
         assert.Nil(t, err)

         assert.Equal(t, node1ID, firstReq.DstNodeID)
@@ -98,7 +98,7 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) {
     cluster.stopNode(node1ID)

     t.Run("Test shuffleSegmentsToQueryNodeV2", func(t *testing.T) {
-        err = shuffleSegmentsToQueryNodeV2(baseCtx, reqs, cluster, false, nil)
+        err = shuffleSegmentsToQueryNodeV2(baseCtx, reqs, cluster, false, nil, nil)
         assert.Nil(t, err)

         assert.Equal(t, node2ID, firstReq.DstNodeID)
diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 8cfa158f28..56d760f85d 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -453,7 +453,7 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
     }

-    internalTasks, err := assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil)
+    internalTasks, err := assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil, nil)
     if err != nil {
         log.Warn("loadCollectionTask: assign child task failed", zap.Int64("collectionID", collectionID))
         lct.setResultInfo(err)
@@ -783,7 +783,7 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
         }
     }

-    internalTasks, err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil)
+    internalTasks, err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil, nil)
     if err != nil {
         log.Warn("loadPartitionTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
         lpt.setResultInfo(err)
@@ -1082,7 +1082,7 @@ func (lst *loadSegmentTask) reschedule(ctx context.Context) ([]task, error) {
     }
     lst.excludeNodeIDs = append(lst.excludeNodeIDs, lst.DstNodeID)
     //TODO:: wait or not according msgType
-    reScheduledTasks, err := assignInternalTask(ctx, collectionID, lst.getParentTask(), lst.meta, lst.cluster, loadSegmentReqs, nil, nil, false, lst.excludeNodeIDs)
+    reScheduledTasks, err := assignInternalTask(ctx, collectionID, lst.getParentTask(), lst.meta, lst.cluster, loadSegmentReqs, nil, nil, false, lst.excludeNodeIDs, nil)
     if err != nil {
         log.Error("loadSegment reschedule failed", zap.Int64s("excludeNodes", lst.excludeNodeIDs), zap.Error(err))
         return nil, err
@@ -1257,7 +1257,7 @@ func (wdt *watchDmChannelTask) reschedule(ctx context.Context) ([]task, error) {
     }
     wdt.excludeNodeIDs = append(wdt.excludeNodeIDs, wdt.NodeID)
     //TODO:: wait or not according msgType
-    reScheduledTasks, err := assignInternalTask(ctx, collectionID, wdt.parentTask, wdt.meta, wdt.cluster, nil, watchDmChannelReqs, nil, false, wdt.excludeNodeIDs)
+    reScheduledTasks, err := assignInternalTask(ctx, collectionID, wdt.parentTask, wdt.meta, wdt.cluster, nil, watchDmChannelReqs, nil, false, wdt.excludeNodeIDs, nil)
     if err != nil {
         log.Error("watchDmChannel reschedule failed", zap.Int64s("excludeNodes", wdt.excludeNodeIDs), zap.Error(err))
         return nil, err
@@ -1557,7 +1557,7 @@ func (ht *handoffTask) execute(ctx context.Context) error {
             ht.setResultInfo(err)
             return err
         }
-        internalTasks, err := assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil)
+        internalTasks, err := assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil, nil)
         if err != nil {
             log.Error("handoffTask: assign child task failed", zap.Any("segmentInfo", segmentInfo))
             ht.setResultInfo(err)
@@ -1774,7 +1774,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
             }
         }

-        internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs)
+        internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs, lbt.DstNodeIDs)
         if err != nil {
             log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
             lbt.setResultInfo(err)
@@ -1925,7 +1925,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
             }

             // TODO:: assignInternalTask with multi collection
-            internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs)
+            internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs, lbt.DstNodeIDs)
             if err != nil {
                 log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
                 lbt.setResultInfo(err)
@@ -2006,11 +2006,11 @@ func assignInternalTask(ctx context.Context,
     loadSegmentRequests []*querypb.LoadSegmentsRequest,
     watchDmChannelRequests []*querypb.WatchDmChannelsRequest,
     watchDeltaChannelRequests []*querypb.WatchDeltaChannelsRequest,
-    wait bool, excludeNodeIDs []int64) ([]task, error) {
+    wait bool, excludeNodeIDs []int64, includeNodeIDs []int64) ([]task, error) {
     sp, _ := trace.StartSpanFromContext(ctx)
     defer sp.Finish()
     internalTasks := make([]task, 0)
-    err := cluster.allocateSegmentsToQueryNode(ctx, loadSegmentRequests, wait, excludeNodeIDs)
+    err := cluster.allocateSegmentsToQueryNode(ctx, loadSegmentRequests, wait, excludeNodeIDs, includeNodeIDs)
     if err != nil {
         log.Error("assignInternalTask: assign segment to node failed", zap.Any("load segments requests", loadSegmentRequests))
         return nil, err
diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go
index a0cc19743b..e31b2cfa4e 100644
--- a/internal/querycoord/task_test.go
+++ b/internal/querycoord/task_test.go
@@ -694,7 +694,7 @@ func Test_AssignInternalTask(t *testing.T) {
         loadSegmentRequests = append(loadSegmentRequests, req)
     }

-    internalTasks, err := assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil)
+    internalTasks, err := assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil, nil)
     assert.Nil(t, err)

     assert.NotEqual(t, 1, len(internalTasks))
diff --git a/internal/querynode/param_table.go b/internal/querynode/param_table.go
index fd3708b586..841e865954 100644
--- a/internal/querynode/param_table.go
+++ b/internal/querynode/param_table.go
@@ -87,6 +87,9 @@ type ParamTable struct {

     // recovery
     skipQueryChannelRecovery bool
+
+    // memory limit
+    OverloadedMemoryThresholdPercentage float64
 }

 // Params is a package scoped variable of type ParamTable.
@@ -146,6 +149,7 @@ func (p *ParamTable) Init() {
     p.initRoleName()

     p.initSkipQueryChannelRecovery()
+    p.initOverloadedMemoryThresholdPercentage()
 }

 func (p *ParamTable) initCacheSize() {
@@ -346,3 +350,12 @@ func (p *ParamTable) initRoleName() {
 func (p *ParamTable) initSkipQueryChannelRecovery() {
     p.skipQueryChannelRecovery = p.ParseBool("msgChannel.skipQueryChannelRecovery", false)
 }
+
+func (p *ParamTable) initOverloadedMemoryThresholdPercentage() {
+    overloadedMemoryThresholdPercentage := p.LoadWithDefault("queryCoord.overloadedMemoryThresholdPercentage", "90")
+    thresholdPercentage, err := strconv.ParseInt(overloadedMemoryThresholdPercentage, 10, 64)
+    if err != nil {
+        panic(err)
+    }
+    p.OverloadedMemoryThresholdPercentage = float64(thresholdPercentage) / 100
+}
diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go
index d90d9eb486..8c8ba3dca0 100644
--- a/internal/querynode/segment_loader.go
+++ b/internal/querynode/segment_loader.go
@@ -525,7 +525,6 @@ func (loader *segmentLoader) estimateSegmentSize(segment *Segment,
 }

 func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentSizes map[UniqueID]int64) error {
-    const thresholdFactor = 0.9
     usedMem, err := getUsedMemory()
     if err != nil {
         return err
@@ -548,16 +547,16 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentSize
             zap.Any("usedMem", usedMem),
             zap.Any("segmentTotalSize", segmentTotalSize),
             zap.Any("currentSegmentSize", size),
-            zap.Any("thresholdFactor", thresholdFactor),
+            zap.Any("thresholdFactor", Params.OverloadedMemoryThresholdPercentage),
         )
-        if int64(usedMem)+segmentTotalSize+size > int64(float64(totalMem)*thresholdFactor) {
+        if int64(usedMem)+segmentTotalSize+size > int64(float64(totalMem)*Params.OverloadedMemoryThresholdPercentage) {
             return errors.New(fmt.Sprintln("load segment failed, OOM if load, "+
                 "collectionID = ", collectionID, ", ",
                 "usedMem = ", usedMem, ", ",
                 "segmentTotalSize = ", segmentTotalSize, ", ",
                 "currentSegmentSize = ", size, ", ",
                 "totalMem = ", totalMem, ", ",
-                "thresholdFactor = ", thresholdFactor,
+                "thresholdFactor = ", Params.OverloadedMemoryThresholdPercentage,
             ))
         }
     }
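
Note on the new config knobs: milvus.yaml stores overloadedMemoryThresholdPercentage and memoryUsageMaxDifferencePercentage as whole percentages (90 and 30), and the ParamTable init functions divide by 100, so the rest of the code compares them against memory usage rates in [0, 1]. A minimal standalone sketch of the trigger predicate used by loadBalanceSegmentLoop follows; needsBalance and the sample rates are illustrative names and values, not code from this patch.

package main

import "fmt"

// needsBalance mirrors the trigger in loadBalanceSegmentLoop: balance when the
// busiest node exceeds the overload threshold OR the spread between the busiest
// and the idlest node exceeds the allowed difference. Thresholds are fractions,
// e.g. 0.9 and 0.3 for the defaults 90 and 30.
func needsBalance(sourceRate, dstRate, overloadedThreshold, maxDiff float64) bool {
    return sourceRate > overloadedThreshold || (sourceRate-dstRate) > maxDiff
}

func main() {
    fmt.Println(needsBalance(0.92, 0.40, 0.9, 0.3)) // true: source node is overloaded
    fmt.Println(needsBalance(0.60, 0.35, 0.9, 0.3)) // false: diff 0.25 is under 0.3
    fmt.Println(needsBalance(0.80, 0.30, 0.9, 0.3)) // true: diff 0.50 exceeds 0.3
}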
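
Working through chooseSegmentToBalance with made-up numbers may help: the function admits a segment only if the destination stays under the overload threshold after the move, and prefers the candidate that leaves the smallest usage-rate gap (assuming, as its comment says, that all query nodes have the same memory capacity). The figures below are hypothetical, only reusing the scale of defaultTotalmemPerNode from the mock.

package main

import "fmt"

func main() {
    // Hypothetical figures: both nodes have 6_000_000 bytes of memory, the
    // source node uses 5_400_000 (rate 0.9), the destination uses 1_800_000
    // (rate 0.3), and the candidate segment is 1_200_000 bytes.
    total := 6_000_000.0
    srcUsed, dstUsed, segSize := 5_400_000.0, 1_800_000.0, 1_200_000.0

    dstRateAfter := (dstUsed + segSize) / total // 0.5, below the 0.9 threshold, so the move is allowed
    srcRateAfter := (srcUsed - segSize) / total // 0.7

    diffBefore := srcUsed/total - dstUsed/total // 0.6
    diffAfter := dstRateAfter - srcRateAfter    // -0.2; |diffAfter| < diffBefore, so this segment qualifies
    fmt.Println(dstRateAfter, srcRateAfter, diffBefore, diffAfter)
}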
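
Finally, the includeNodeIDs parameter threaded through SegmentAllocatePolicy and assignInternalTask is what lets a loadBalanceTask pin re-allocation to its DstNodeIDs: a nil or empty list means any available node (the old behavior), while a non-empty list restricts the candidate set via nodeIncluded. The sketch below demonstrates only that filtering semantics under assumed inputs; filterNodes is a hypothetical helper, not the allocator itself.

package main

import "fmt"

// filterNodes keeps only the nodes named in includeNodeIDs; an empty include
// list means "no restriction", matching how the shuffle policies treat it.
func filterNodes(available []int64, includeNodeIDs []int64) []int64 {
    if len(includeNodeIDs) == 0 {
        return available
    }
    included := make(map[int64]bool, len(includeNodeIDs))
    for _, id := range includeNodeIDs {
        included[id] = true
    }
    out := make([]int64, 0, len(available))
    for _, id := range available {
        if included[id] {
            out = append(out, id)
        }
    }
    return out
}

func main() {
    nodes := []int64{1, 2, 3}
    fmt.Println(filterNodes(nodes, nil))        // [1 2 3]: normal load, no restriction
    fmt.Println(filterNodes(nodes, []int64{3})) // [3]: balance destination only
}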